In [1]:
import numpy as np
import pickle
import yaml
import pandas as pd
import sys
import torch
import torch.distributions as td
sys.path.append('/home/erik.ohara/macaw/')
from utils.helpers import dict2namespace
from macaw import MACAW
import matplotlib.pyplot as plt

In [2]:
# --- Locations on disk -------------------------------------------------------
# Source code / config and tabular data.
macaw_path = '/home/erik.ohara/macaw'
data_path = '/home/erik.ohara/SFCN_PD_scanner'
# Fitted PCA, reshaped image arrays and trained MACAW checkpoints.
pca_path = '/work/forkert_lab/erik/PCA/mitacs_no_crop_2'
reshaped_path = '/work/forkert_lab/erik/MACAW/reshaped/mitacs_no_crop/all'
model_path = '/work/forkert_lab/erik/MACAW/models/PD_PCA3_age_sex_indsite_whole_obj_rever'

# --- Model dimensions --------------------------------------------------------
ncauses = 3       # number of causal variables before the site is one-hot expanded
nevecs = 50       # width (in PCA components) of the window handled by one model
ncomps = 575      # upper bound of the PCA components swept by the window loop
nbasecomps = 25   # components of each window that get the "base_comps" prior

In [3]:
# Load the reshaped train / validation arrays that are later fed to pca.transform.
train_file = reshaped_path + '/reshaped_3D_train.npy'
val_file = reshaped_path + '/reshaped_3D_val.npy'
data_train = np.load(train_file)
data_val = np.load(val_file)

In [None]:
# Restore the fitted PCA object.
# NOTE(review): pickle.load can execute arbitrary code; only use with trusted files.
pca_file = pca_path + '/pca.pkl'
with open(pca_file, 'rb') as pca_handle:
    pca = pickle.load(pca_handle)

In [None]:
def one_hot(a, veclen=10):
    """One-hot encode a collection of integer class labels.

    Parameters
    ----------
    a : array-like of int
        Class labels (list, numpy array, pandas Series, ...). The input is
        flattened, so one output row is produced per label.
    veclen : int, default 10
        Number of classes, i.e. the width of the encoded matrix.

    Returns
    -------
    numpy.ndarray
        Float array of shape ``(number_of_labels, veclen)`` with a single 1
        per row at the column given by the label.

    Raises
    ------
    IndexError
        If any label lies outside ``[0, veclen)``. Negative labels are
        rejected explicitly because numpy would otherwise wrap them around
        and silently mark a column counted from the end.
    """
    labels = np.asarray(a).reshape(-1)
    b = np.zeros((labels.size, veclen))
    if labels.size == 0:
        # Nothing to encode; return the empty (0, veclen) matrix.
        return b
    if labels.min() < 0 or labels.max() >= veclen:
        raise IndexError(
            f"one_hot labels must lie in [0, {veclen}); "
            f"got values in [{labels.min()}, {labels.max()}]"
        )
    b[np.arange(labels.size), labels] = 1
    return b

def inverse_one_hot(b):
    """Map each row of a one-hot matrix back to its class index.

    For every row of ``b`` the position of the largest entry (along axis 1)
    is returned.
    """
    class_indices = np.argmax(b, axis=1)
    return class_indices

In [None]:
# Read the full cohort table plus the train / validation splits.
df = pd.read_csv(data_path + '/all_df_2.csv',low_memory=False)
print(f"The original size of the dataframe is {df.shape}")
df_train = pd.read_csv(data_path + '/split/all/df_train.csv',low_memory=False)
df_val = pd.read_csv(data_path + '/split/all/df_val.csv',low_memory=False)
print(f"The size of the df_train is {df_train.shape}")
# Label fixed: this line reports the validation split, not the training split.
print(f"The size of the df_val is {df_val.shape}")

In [None]:
# Covariates for the full cohort and for each split.
# Ages are shifted by the minimum age of the *full* cohort so that every split
# shares the same zero-based age scale.
age_offset = df['Age'].astype(int).min()

site = df['Site_3']
age = df['Age'].astype(int) - age_offset
sex = df['Sex_bin']
number_sites = len(df['Site_3'].unique())

# Total number of causal columns: one column per site (one-hot) plus the two
# scalar causes (age, sex). Computed from scratch instead of updating ncauses
# in place, so re-running this cell does not keep growing the value.
n_scalar_causes = 2
ncauses = number_sites + n_scalar_causes

# The one-hot width is tied to number_sites (rather than one_hot's default) so
# it always matches the site slice used by the priors and the causal graph.
site_one_hot = one_hot(site, veclen=number_sites)

site_train = df_train['Site_3']
age_train = df_train['Age'].astype(int) - age_offset
site_one_hot_train = one_hot(site_train, veclen=number_sites)
sex_train = df_train['Sex_bin']

site_val = df_val['Site_3']
age_val = df_val['Age'].astype(int) - age_offset
site_one_hot_val = one_hot(site_val, veclen=number_sites)
sex_val = df_val['Sex_bin']

In [None]:
# Causal graph: every cause node points at every latent node.
# Node layout: [0, number_sites) one-hot site columns, then age, then sex,
# followed by the nevecs latent nodes starting at index ncauses.
latent_nodes = range(ncauses, nevecs + ncauses)
age_node = number_sites
sex_node = number_sites + 1

site_to_latents = [(site_node, latent)
                   for site_node in range(number_sites)
                   for latent in latent_nodes]
age_to_latents = [(age_node, latent) for latent in latent_nodes]
sex_to_latents = [(sex_node, latent) for latent in latent_nodes]

# Site -> age/sex edges and autoregressive latent -> latent edges are currently
# not part of the graph.
edges = site_to_latents + age_to_latents + sex_to_latents

In [None]:
# Read the MACAW YAML configuration and convert it to a namespace.
# NOTE(review): yaml.FullLoader is meant for trusted files only.
config_file = macaw_path + '/config/mitacs.yaml'
with open(config_file, 'r') as config_handle:
    config_raw = yaml.load(config_handle, Loader=yaml.FullLoader)

config = dict2namespace(config_raw)

# Prefer the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    config.device = torch.device('cuda')
else:
    config.device = torch.device('cpu')

In [None]:
# Priors
# Empirical site frequencies over the full cohort.
unique_values, counts = np.unique(site, return_counts=True)
print(unique_values)
P_site = counts/np.sum(counts)

# Empirical mean of the sex column, used as the Bernoulli probability.
P_sex = np.sum(sex)/len(sex)

# Frequencies of the ages that actually occur (kept for the shape report below).
unique_values_age, counts_age = np.unique(age, return_counts=True)
P_age = counts_age/np.sum(counts_age)

# Dense age distribution over every integer age in [0, age.max()]; ages nobody
# has get probability 0. np.bincount yields the same counts as np.unique but
# already zero-filled, replacing the nested search loop with repeated np.append.
age_counts = np.bincount(np.asarray(age), minlength=age.max()+1)
new_P_age = age_counts/age_counts.sum()
print(f"P_age shape: {P_age.shape}")
print(f"new_P_age shape: {new_P_age.shape}")

# Each entry pairs the slice of columns a prior applies to with its distribution.
# Probability tensors are built from (1, n) arrays, which is the same shape and
# dtype as torch.tensor([array]) without going through a list of ndarrays.
site_probs = torch.tensor(P_site[np.newaxis, :]).to(config.device)
age_probs = torch.tensor(new_P_age[np.newaxis, :]).to(config.device)
sex_probs = torch.tensor([P_sex]).to(config.device)

priors = [(slice(0,number_sites),td.OneHotCategorical(site_probs)), # site
          (slice(number_sites,number_sites+1),td.Categorical(age_probs)), # age
          (slice(number_sites+1,ncauses),td.Bernoulli(sex_probs)), # sex
          (slice(ncauses,nbasecomps+ncauses),td.Normal(torch.zeros(nbasecomps).to(config.device), torch.ones(nbasecomps).to(config.device))), # base_comps
          (slice(nbasecomps+ncauses,nevecs+ncauses),td.Normal(torch.zeros(nevecs-nbasecomps).to(config.device), torch.ones(nevecs-nbasecomps).to(config.device))), # new_comps
         ]

In [None]:
# Restore the per-window scalers stored next to the model checkpoints.
# NOTE(review): pickle.load can execute arbitrary code; only use with trusted files.
scalers_file = model_path + '/scalers.pkl'
with open(scalers_file, 'rb') as scalers_handle:
    scalers = pickle.load(scalers_handle)

In [None]:
# Project both splits onto the PCA components (train first, then validation).
all_encoded_obs_train, all_encoded_obs_val = (
    pca.transform(data_train),
    pca.transform(data_val),
)

In [None]:
# Sanity check: show the shape of the PCA-encoded training data. The previous
# version printed a zero array of this shape, which carried no other information.
all_encoded_obs_train.shape

In [None]:
def _assemble_flow_input(site_one_hot_block, age_values, sex_values, encoded_block):
    """Stack [one-hot site | age | sex | scaled PCA window] column-wise for the flow."""
    return np.hstack([site_one_hot_block,
                      np.array(age_values)[:,np.newaxis],
                      np.array(sex_values)[:,np.newaxis],
                      encoded_block])


# Number of columns each model sees: all causal columns plus one PCA window.
datashape1 = ncauses + nevecs

# Output buffers: scaled inputs (x) and forward-flow outputs (z) for every component.
all_z_obs_train = np.zeros(all_encoded_obs_train.shape)
all_z_obs_val = np.zeros(all_encoded_obs_val.shape)
all_x_obs_train = np.zeros(all_encoded_obs_train.shape)
all_x_obs_val = np.zeros(all_encoded_obs_val.shape)

# Slide a window of nevecs components in steps of (nevecs - nbasecomps), so
# consecutive windows overlap by nbasecomps columns. Overlapping columns are
# overwritten by the later window.
for e in range(0,ncomps-nbasecomps,nevecs-nbasecomps):
    encoded_data_train = all_encoded_obs_train[:,e:e+nevecs]
    encoded_data_val = all_encoded_obs_val[:,e:e+nevecs]

    # Load the model trained for this window and move all of its parts to the device.
    macaw = MACAW.MACAW(config)
    macaw.load_model(model_path + f'/macaw_pd_PCA_{e}.pt', edges, priors, datashape1)
    macaw.model.to(config.device)
    for each_flow in macaw.flow_list:
        each_flow.to(config.device)
        each_flow.device = config.device
    macaw.device = config.device

    # Apply the scaler stored for this window (keys are the window start as a string).
    scaler = scalers[f"{e}"]
    encoded_data_train = scaler.transform(encoded_data_train)
    encoded_data_val = scaler.transform(encoded_data_val)

    X_train = _assemble_flow_input(site_one_hot_train, age_train, sex_train, encoded_data_train)
    X_val = _assemble_flow_input(site_one_hot_val, age_val, sex_val, encoded_data_val)

    z_obs_train = macaw._forward_flow(X_train)
    z_obs_val = macaw._forward_flow(X_val)

    # Keep the scaled inputs and the latent part of the flow output (the first
    # ncauses columns belong to the causal variables and are dropped).
    all_x_obs_train[:,e:e+nevecs] = encoded_data_train
    all_x_obs_val[:,e:e+nevecs] = encoded_data_val
    all_z_obs_train[:,e:e+nevecs] = z_obs_train[:,ncauses:]
    all_z_obs_val[:,e:e+nevecs] = z_obs_val[:,ncauses:]

In [None]:
# Inspect the scaled input values of component 11 for the training set.
all_x_obs_train[:, 11]

In [None]:
# Inspect the forward-flow output of component 11 for the training set.
all_z_obs_train[:, 11]

# X_obs train distribution

In [None]:
# Histogram grid of the scaled inputs: one panel per component, 4 x 3 panels.
fig, axs = plt.subplots(4, 3, figsize=(15, 10))
plot_limit = 12
# Work with a flat sequence of axes so panels can be addressed by component index.
axs = axs.flatten()

# Draw at most plot_limit components (fewer if the data has fewer columns).
n_panels = min(plot_limit, all_x_obs_train.shape[1])
for i in range(n_panels):
    panel = axs[i]
    panel.hist(all_x_obs_train[:, i], bins='auto')
    panel.set_title(f"PCA {i} ")

plt.tight_layout()
plt.show()

# Z_obs train distribution

In [None]:
# Histogram grid of the forward-flow outputs: one panel per component, 4 x 3 panels.
fig, axs = plt.subplots(4, 3, figsize=(15, 10))
plot_limit = 12
# Work with a flat sequence of axes so panels can be addressed by component index.
axs = axs.flatten()

# Draw at most plot_limit components (fewer if the data has fewer columns).
n_panels = min(plot_limit, all_z_obs_train.shape[1])
for i in range(n_panels):
    panel = axs[i]
    panel.hist(all_z_obs_train[:, i], bins='auto')
    panel.set_title(f"PCA {i} ")

plt.tight_layout()
plt.show()