In [1]:
import torch
import torch.distributions as td
import sys
import yaml
import pandas as pd
import numpy as np
from tensorboardX import SummaryWriter
# Root of the local MACAW checkout. It is appended to sys.path before the
# project imports that follow (utils, macaw, compression) — presumably those
# packages live in this directory; verify against the checkout layout.
macaw_path = '/home/erik.ohara/macaw'
sys.path.append(macaw_path +'/')
from utils.helpers import dict2namespace
from macaw import MACAW
from compression.vqvae import vqvae

In [2]:
# Sizes used further down to lay out the prior slices and the causal graph.
ncauses, nevecs, nbasecomps = 2, 50, 25

# Filesystem locations of the dataset folder and the saved model folders.
ukbb_path = '/home/erik.ohara/UKBB'
vqvae_path = '/work/forkert_lab/erik/MACAW/models/vqvae3D_8'
model_path_ae = "/work/forkert_lab/erik/MACAW/models/macaw_AE3D_4000"
model_path = "/work/forkert_lab/erik/MACAW/models/PCA3D_15000_new"

In [3]:
# Use CUDA when PyTorch reports it as available; otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [4]:
# Parse the VQ-VAE YAML configuration from the MACAW checkout, convert it with
# dict2namespace, and record the device selected earlier on the result.
with open(macaw_path + '/compression/vqvae/vqvae.yaml') as vqvae_yaml:
    config_raw_vq_vae = yaml.full_load(vqvae_yaml)
config_vq_vae = dict2namespace(config_raw_vq_vae)
config_vq_vae.device = device

In [5]:
# TensorBoard writer that is passed to the VQ-VAE constructor later on.
# The leading '/' is required: macaw_path has no trailing slash, so without it
# the log directory resolves to '/home/erik.ohara/macawlogs/vqvae_MM' (a sibling
# of the checkout) instead of '<macaw_path>/logs/vqvae_MM'.
writer = SummaryWriter(macaw_path + '/logs/vqvae_MM')

In [6]:
# Read the saved VQ-VAE checkpoint file, mapping its tensors onto `device`.
vqvae_checkpoint = torch.load(
    vqvae_path + '/vqvae_UKBB_best.pt',
    map_location=device,
)

In [7]:
# Show the dimensions of the 'encoder.3.weight' tensor stored under the
# checkpoint's 'model' entry.
vqvae_checkpoint['model']['encoder.3.weight'].size()

torch.Size([8, 8, 4, 4])

In [8]:
# Construct the VQ-VAE from its configuration and the TensorBoard writer, then
# call load_checkpoint on the saved file. The call's return value is left as the
# last expression, so it is displayed as the cell output.
model_vqvae = vqvae.VQVAE(config_vq_vae, writer)
best_checkpoint_file = vqvae_path + '/vqvae_UKBB_best.pt'
model_vqvae.load_checkpoint(best_checkpoint_file)

(244,
 {'train_loss': [0.02070819027721882,
   0.009951869957149029,
   0.00868601817637682,
   0.008131437003612518,
   0.007512097712606192,
   0.006955557968467474,
   0.006557262036949396,
   0.006033552344888449,
   0.005535459611564875,
   0.005094088613986969,
   0.0047188918106257915,
   0.004407292697578669,
   0.004106500651687384,
   0.00391008798032999,
   0.003722347319126129,
   0.0035424442030489445,
   0.0034138059709221125,
   0.0033190518151968718,
   0.003206800203770399,
   0.0031246310099959373,
   0.0030677043832838535,
   0.003000636352226138,
   0.002947735833004117,
   0.0029183528386056423,
   0.002869099611416459,
   0.002857488812878728,
   0.00283345440402627,
   0.0028050108812749386,
   0.00277797132730484,
   0.0027777855284512043,
   0.002742487471550703,
   0.0027286880649626255,
   0.002720511518418789,
   0.002723422134295106,
   0.002717782976105809,
   0.0027037758845835924,
   0.002671454567462206,
   0.0026926298160105944,
   0.00266890786588192,

In [None]:
# Read the saved MACAW checkpoint onto `device`; the next cell indexes it by key.
load_dict = torch.load(
    model_path + '/macaw_ukbb_PCA3D_0.pt',
    map_location=device,
)

In [None]:
# List the entry names saved for the flow at index 3 of the checkpoint.
saved_flow_state = load_dict['model_state_dict_flows'][3]
saved_flow_state.keys()

In [None]:
# Parse config/ukbb.yaml from the MACAW checkout, convert it with
# dict2namespace, and attach the compute device (CUDA if available, else CPU).
with open(macaw_path + '/config/ukbb.yaml') as ukbb_yaml:
    config_raw = yaml.full_load(ukbb_yaml)
config = dict2namespace(config_raw)
config.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
# Read test.csv from the UKBB folder and pull out the columns used below.
data_path = ukbb_path + '/test.csv'
df = pd.read_csv(data_path, low_memory=False)

# Double brackets select a one-column frame, so this is a 2-D array.
all_eid = df[['eid']].to_numpy()
sex = df['Sex']

# Shift ages so that the smallest age in this table becomes 0.
age_column = df['Age']
min_age = age_column.min()
print(f"Age min: {min_age}")
age = age_column - min_age

In [None]:
def _standard_normal(size):
    """Return a Normal(0, 1) prior over `size` dimensions, placed on config.device."""
    return td.Normal(torch.zeros(size).to(config.device), torch.ones(size).to(config.device))


# Empirical statistics of the two causes, taken from the loaded table.
unique_values, counts = np.unique(age, return_counts=True)
P_sex = np.sum(sex) / len(sex)    # mean of the Sex column (a probability if Sex is coded 0/1 — TODO confirm)
P_age = counts / np.sum(counts)   # relative frequency of each distinct shifted age

# Column boundaries of the two groups of latent components.
latent_split = nbasecomps + ncauses
latent_end = nevecs + ncauses

# Each entry pairs a column slice with the prior distribution for those columns.
priors = [
    (slice(0, 1), td.Bernoulli(torch.tensor([P_sex]).to(config.device))),    # sex
    (slice(1, 2), td.Categorical(torch.tensor([P_age]).to(config.device))),  # age
    (slice(ncauses, latent_split), _standard_normal(nbasecomps)),            # base_comps
    (slice(latent_split, latent_end), _standard_normal(nevecs - nbasecomps)),  # new_comps
]

In [None]:
# Causal graph, expressed as a list of (from_index, to_index) pairs.
latent_ids = range(ncauses, nevecs + ncauses)

# Index 0 (sex) and index 1 (age) are each connected to every latent index.
sex_to_latents = [(0, latent) for latent in latent_ids]
age_to_latents = [(1, latent) for latent in latent_ids]

# Every latent index is connected to all latent indices that come after it.
autoregressive_latents = [
    (src, dst)
    for offset, src in enumerate(latent_ids, start=1)
    for dst in latent_ids[offset:]
]

edges = [*sex_to_latents, *age_to_latents, *autoregressive_latents]

In [None]:
# Instantiate the MACAW model class with the configuration namespace `config`.
macaw = MACAW.MACAW(config)

In [None]:
# Display the `hidden` attribute of the freshly built model.
macaw.hidden

In [None]:
# ncauses + nevecs is the end of the index range covered by the prior slices.
datashape1 = nevecs + ncauses

# Load the saved checkpoint into the model, together with the graph and priors.
macaw_checkpoint_file = model_path + '/macaw_ukbb_PCA3D_0.pt'
macaw.load_model(macaw_checkpoint_file, edges, priors, datashape1)

In [None]:
# List the state-dict entry names of the flow at index 3 of the loaded model.
loaded_flow = macaw.flow_list[3]
loaded_flow.state_dict().keys()

In [None]:
# Read the file saved under the AE model folder onto config.device. The loaded
# object is used through its attributes (flow_list, pdim) in the next cells.
macaw_ae = torch.load(
    model_path_ae + '/macaw_ukbb_PCA3D_0.pt',
    map_location=config.device,
)

In [None]:
# List the state-dict entry names of the first flow of the AE-based model.
ae_first_flow = macaw_ae.flow_list[0]
ae_first_flow.state_dict().keys()

In [None]:
# Display the `pdim` attribute of the loaded AE-based model object.
macaw_ae.pdim