In [1]:
import sys
import numpy as np

from sklearn.preprocessing import StandardScaler

import torch
import torch.distributions as td
import os
os.environ['TF_DETERMINISTIC_OPS'] = '1'

import yaml

sys.path.append('/home/erik.ohara/macaw/')
from utils.helpers import dict2namespace
from macaw import MACAW

import pickle
import pandas as pd

In [2]:
nevecs = 50
ncauses = 2
ncomps = 1500
nbasecomps = 25
ukbb_path = '/home/erik.ohara/UKBB'
pca_path = '/work/forkert_lab/erik/PCA3D'
macaw_path = '/home/erik.ohara/macaw'
reshaped_path = '/work/forkert_lab/erik/MACAW/reshaped/3D'
ukbb_path_T1_slices = '/work/forkert_lab/erik/T1_warped/train'
scalers_path = '/wor+k/forkert_lab/erik/MACAW/scalers/PCA3D_1500_experiments
output_path = f"/work/forkert_lab/erik/MACAW/models/PCA3D_1500_experiments/{nevecs}"

In [3]:
data = np.load(reshaped_path + '/reshaped_3D_train.npy')
data = data.reshape(data.shape[0],-1)
print("Data train loaded")

Data train loaded


In [4]:
data_val = np.load(reshaped_path + '/reshaped_3D_val.npy')
data_val = data_val.reshape(data_val.shape[0],-1)

In [5]:
with open(pca_path + "/evecs.pkl",'rb') as f:  
    evecs3D = pickle.load(f)

In [6]:
def encode(data, evecs):
    return np.matmul(data,evecs.T)

def decode(data,evecs):
    return np.matmul(data,evecs)

In [7]:
# Getting Age and Sex data
data_path = ukbb_path + '/ukbb_img.csv'
df = pd.read_csv(data_path,low_memory=False)
min_age = df['Age'].min()

sex = df['Sex'] 
age = df['Age'] - min_age

## Changing from here 

In [8]:
# Loading configurations
with open(macaw_path + '/config/ukbbHPCA.yaml', 'r') as f:
    config_raw = yaml.load(f, Loader=yaml.FullLoader)
config = dict2namespace(config_raw)
config.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [9]:
# Priors
P_sex = np.sum(sex)/len(sex)

unique_values, counts = np.unique(age, return_counts=True)
P_age = counts/np.sum(counts)
priors = [(slice(0,1),td.Bernoulli(torch.tensor([P_sex]).to(config.device))), # sex
          (slice(1,2),td.Categorical(torch.tensor([P_age]).to(config.device))), # age
          (slice(ncauses,nbasecomps+ncauses),td.Normal(torch.zeros(nbasecomps).to(config.device), torch.ones(nbasecomps).to(config.device))), # base_comps
          (slice(nbasecomps+ncauses,nevecs+ncauses),td.Normal(torch.zeros(nevecs-nbasecomps).to(config.device), torch.ones(nevecs-nbasecomps).to(config.device))), # new_comps
         ]

  (slice(1,2),td.Categorical(torch.tensor([P_age]).to(config.device))), # age


In [10]:
df = pd.read_csv(ukbb_path + '/train.csv',low_memory=False)
df_val = pd.read_csv(ukbb_path + '/val.csv',low_memory=False)

In [11]:
sex = df['Sex'] 
sex_val = df_val['Sex'] 
age = df['Age'] - min_age
age_val = df_val['Age'] - min_age

# causal Graph
sex_to_latents = [(0,i) for i in range(ncauses,nevecs+ncauses)]
age_to_latents = [(1,i) for i in range(ncauses,nevecs+ncauses)]
autoregressive_latents = [(i,j) for i in range(ncauses,nevecs+ncauses-1) for j in range(i+1,nevecs+ncauses)]
edges = sex_to_latents + age_to_latents + autoregressive_latents

In [12]:
loss_vals_all= []
scalers = {}

In [None]:
for e in range(0,ncomps-nbasecomps,nevecs-nbasecomps):
#for e in range(0,ncomps,nevecs):
    if not os.path.exists(output_path):
        os.makedirs(output_path)
    save_path = output_path + f'/macaw_ukbb_PCA3D_{e}.pt'

    encoded_data = encode(data,evecs3D[e:e+nevecs])
    encoded_data_val = encode(data_val,evecs3D[e:e+nevecs])
    scaler = StandardScaler()
    encoded_data = scaler.fit_transform(encoded_data)
    encoded_data_val = scaler.transform(encoded_data_val)
    scalers[f"{e}"] = scaler
    print(e)
    
    if not os.path.exists(save_path):    
        #X = np.hstack([causes, encoded_data])  
        X = np.hstack([np.array(sex)[:,np.newaxis], np.array(age)[:,np.newaxis], encoded_data])   
        X_val = np.hstack([np.array(sex_val)[:,np.newaxis], np.array(age_val)[:,np.newaxis], encoded_data_val])   

        macaw = MACAW.MACAW(config)
        #loss_vals = macaw.fit(X,edges, augment=True)
        loss_vals = macaw.fit_with_priors(X,edges, priors, validation=X_val)
        df_loss_vals = pd.DataFrame(loss_vals)
        if (df_loss_vals.isnull().values.any()):
            print("Tem um nulo no {}".format(e))
            break
        loss_vals_all.append(loss_vals)
        
        #macaw.save_best_model()
        torch.save(macaw,save_path)