In [1]:
import numpy as np
import pickle
import sys
# Root of the local MACAW checkout. It is appended to sys.path before the
# `macaw` / `utils.helpers` imports that follow (presumably so they resolve
# against this checkout).
# NOTE(review): absolute, user-specific path; adjust per machine.
macaw_path = '/home/erik.ohara/macaw'
sys.path.append(macaw_path +'/')
from macaw import MACAW
from utils.helpers import dict2namespace
import yaml
import torch
import torch.distributions as td
import pandas as pd

In [2]:
# --- Filesystem locations -------------------------------------------------
# Directory holding the saved diff_*.npy arrays inspected in this notebook.
path = '/work/forkert_lab/erik/MACAW/cf_images/macaw_vqvae8_50nevecs_2_zero'
# Directory holding the pickled encoded validation data.
vqvae_path = '/work/forkert_lab/erik/MACAW/models/vqvae3D_8'
# Directory holding the UKBB CSV tables (ukbb_img.csv, val.csv).
ukbb_path = '/home/erik.ohara/UKBB'
# Directory holding the MACAW checkpoint file.
model_path = "/work/forkert_lab/erik/MACAW/models/macaw_vqvae8_50nevecs_2"

# --- Model dimensions -----------------------------------------------------
nevecs = 50        # encoded components kept per channel
ncauses = 2        # causal variables placed in front of the encodings (sex, age)
ncomps = 10625     # presumably the full component count per channel -- TODO confirm
nbasecomps = 25    # split point between the two Normal prior groups

In [3]:
# Read the experiment configuration shipped with the MACAW checkout.
# NOTE(review): yaml.FullLoader can build more than plain data types; this is
# only acceptable because the file is a trusted local config.
config_file = macaw_path + '/config/ukbbVQVAE.yaml'
with open(config_file, 'r') as cfg_fh:
    config_raw = yaml.load(cfg_fh, Loader=yaml.FullLoader)

config = dict2namespace(config_raw)
# Use the GPU when one is visible, otherwise fall back to the CPU.
config.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [4]:
# Load the array saved as diff_0.npy in the `path` directory.
diff_0 = np.load(path + "/diff_0.npy")

In [5]:
# Number of columns in diff_0 (size of axis 1).
diff_0.shape[1]

402

In [6]:
# Column 2 of diff_0 for every row (with ncauses = 2 this is presumably the
# first non-cause column -- TODO confirm the column layout of the saved array).
diff_0[:,2]

array([0.88822135, 0.90224285, 0.89696575, ..., 0.88498354, 0.88383241,
       0.88658599])

In [7]:
# Print the index of every column of diff_0 whose entries do not sum to zero.
n_cols = diff_0.shape[1]
for ev in range(n_cols):
    col_total = diff_0[:, ev].sum()
    if col_total == 0:
        continue
    print(ev)

2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51


In [8]:
# Load the array saved as diff_50.npy and print the index of every column
# whose entries do not sum to zero.
diff_1 = np.load(path + "/diff_50.npy")
for ev in range(diff_1.shape[1]):
    # Inspect diff_1 itself. The previous version tested diff_0 here (a
    # copy-paste slip), so this cell merely repeated the diff_0 result.
    if diff_1[:, ev].sum() != 0:
        print(ev)

2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51


In [9]:
# Load the pickled encoded validation data from the VQ-VAE model directory.
# NOTE(review): pickle.load can execute arbitrary code; only use on files
# produced by a trusted pipeline.
encoded_val_file = vqvae_path + '/encoded_data_val_all.pkl'
with open(encoded_val_file, 'rb') as file_val_data:
    encoded_data_val_all = pickle.load(file_val_data)

In [10]:
# Shape of the encoded validation array; axis 1 is later read as the channel
# count and axis 2 is sliced to the first `nevecs` entries.
encoded_data_val_all.shape

(2369, 8, 10625)

In [11]:
# Peek at the first 50 entries along the last axis for row 0, index 7 on axis 1.
encoded_data_val_all[0,7,:50]

array([0.0326902 , 0.03645178, 0.03649143, 0.0364942 , 0.03642333,
       0.03647562, 0.03648647, 0.03659338, 0.03669963, 0.03606642,
       0.03520264, 0.03727198, 0.03776916, 0.0381999 , 0.03804722,
       0.03806691, 0.03735187, 0.03700652, 0.03694885, 0.03670905,
       0.03653187, 0.03654395, 0.03654039, 0.03697416, 0.03193894,
       0.03582074, 0.03589864, 0.03650068, 0.03630941, 0.03707056,
       0.03721583, 0.03704946, 0.03621329, 0.03545854, 0.03565918,
       0.03942621, 0.04174054, 0.04166941, 0.04194109, 0.03874976,
       0.03812994, 0.03853172, 0.03990479, 0.03794438, 0.03713609,
       0.03684948, 0.03658118, 0.03653095, 0.03616025, 0.03462523],
      dtype=float32)

In [12]:
# Build the prior distributions for the causal variables and the latents,
# using empirical frequencies from the full UKBB table.
data_all_path = ukbb_path + '/ukbb_img.csv'
df_all = pd.read_csv(data_all_path, low_memory=False)
min_age = df_all['Age'].min()

sex = df_all['Sex']
age = df_all['Age'] - min_age  # shift so the smallest age maps to 0
P_sex = np.sum(sex)/len(sex)   # mean of the Sex column (a probability if Sex is 0/1 coded)
unique_values, counts = np.unique(age, return_counts=True)
# NOTE(review): `counts` has one entry per *observed* age value. If any age
# between the minimum and maximum has no subjects, the Categorical index no
# longer equals `age - min_age`. Verify that `unique_values` is contiguous.
P_age = counts/np.sum(counts)

# Stack into a single (1, n_ages) ndarray before handing it to torch:
# torch.tensor([ndarray]) takes the slow list-of-arrays path and raises a
# UserWarning. Shape and dtype of the resulting tensor are unchanged.
age_probs = torch.tensor(np.array([P_age])).to(config.device)

# Each entry pairs a column slice of the data vector with its prior.
priors = [(slice(0,1),td.Bernoulli(torch.tensor([P_sex]).to(config.device))), # sex
          (slice(1,2),td.Categorical(age_probs)), # age
          (slice(ncauses,nbasecomps+ncauses),td.Normal(torch.zeros(nbasecomps).to(config.device), torch.ones(nbasecomps).to(config.device))), # base_comps
          (slice(nbasecomps+ncauses,nevecs+ncauses),td.Normal(torch.zeros(nevecs-nbasecomps).to(config.device), torch.ones(nevecs-nbasecomps).to(config.device))), # new_comps
         ]

  (slice(1,2),td.Categorical(torch.tensor([P_age]).to(config.device))), # age


In [13]:
# Causal graph as a list of (parent, child) column-index pairs.
# Columns 0 and 1 are sex and age; the latents occupy
# ncauses .. ncauses + nevecs - 1.
latent_ids = range(ncauses, nevecs + ncauses)

# Sex and age each point at every latent.
sex_to_latents = [(0, child) for child in latent_ids]
age_to_latents = [(1, child) for child in latent_ids]

# Every latent points at all latents with a larger index.
autoregressive_latents = [
    (parent, child)
    for pos, parent in enumerate(latent_ids)
    for child in latent_ids[pos + 1:]
]

edges = sex_to_latents + age_to_latents + autoregressive_latents

In [14]:
# Channel count, taken from axis 1 of the encoded validation array.
n_channels = encoded_data_val_all.shape[1]

In [15]:
# Width of one data vector: the cause columns plus `nevecs` values for each
# of the `n_channels` channels.
datashape1 = ncauses + (nevecs * n_channels)

# Instantiate MACAW and load the saved checkpoint together with the graph
# and priors defined above.
checkpoint_file = model_path + '/macaw_ukbb_PCA3D_0.pt'
macaw = MACAW.MACAW(config)
macaw.load_model(checkpoint_file, edges, priors, datashape1)

In [16]:
# Validation split: subject ids plus the two causal variables.
data_path = ukbb_path + '/val.csv'
df = pd.read_csv(data_path, low_memory=False)

all_eid = df[['eid']].to_numpy()
sex = df['Sex']

# NOTE(review): min_age is recomputed from the validation split here and
# overwrites the value derived from the full table earlier. If the two minima
# ever differ, `age` below is offset differently from the age prior. Confirm.
min_age = df['Age'].min()
print(f"Age min: {min_age}")
age = df['Age'] - min_age

Age min: 46


In [17]:
# Keep the first `nevecs` entries of the last axis for every channel, then
# flatten (channel, component) into a single feature axis per subject.
n_subjects = encoded_data_val_all.shape[0]
encoded_obs = encoded_data_val_all[:, :, :nevecs].reshape(n_subjects, -1)

# Column layout of X_obs: [sex, age, flattened encodings].
sex_col = np.array(sex)[:, np.newaxis]
age_col = np.array(age)[:, np.newaxis]
X_obs = np.hstack([sex_col, age_col, encoded_obs])

In [18]:
# Put the underlying model in evaluation mode and run both flow directions
# with gradient tracking disabled.
macaw.model.eval()
with torch.no_grad():
    # Forward pass on the observed data (presumably data -> latent space;
    # the semantics live in MACAW._forward_flow -- TODO confirm).
    z_obs = macaw._forward_flow(X_obs) 
    # Backward pass on the result; `cc` is not used again in the visible
    # cells (presumably a round-trip reconstruction of X_obs -- TODO confirm).
    cc = macaw._backward_flow(z_obs)

In [19]:
# Element-wise difference between the forward-flow output and the input.
diff = z_obs - X_obs

In [24]:
# Value range of the observed inputs: maximum on the first line, minimum on the second.
print(X_obs.max(), X_obs.min(), sep="\n")

35.0
-0.0013158321380615234


In [25]:
# Value range of the forward-flow output: maximum on the first line, minimum on the second.
print(z_obs.max(), z_obs.min(), sep="\n")

35.0
-6.807508


In [22]:
# Columns 2..51 of X_obs. With ncauses = 2 and nevecs = 50 this equals
# ncauses : ncauses + nevecs, i.e. the first `nevecs` encoded columns that
# follow the sex and age columns.
X_obs[:,2:52]

array([[0.09293685, 0.06166101, 0.05740533, ..., 0.0643393 , 0.06910811,
        0.0701291 ],
       [0.09292493, 0.06158376, 0.05752921, ..., 0.0643706 , 0.06907666,
        0.07014233],
       [0.09292539, 0.06164406, 0.05758187, ..., 0.0643594 , 0.06910204,
        0.07012931],
       ...,
       [0.08874224, 0.06118666, 0.05661191, ..., 0.0638795 , 0.06943066,
        0.07002006],
       [0.0887429 , 0.06111928, 0.05617022, ..., 0.06386361, 0.06943824,
        0.07001473],
       [0.12739179, 0.06210498, 0.06219234, ..., 0.068652  , 0.07362227,
        0.07592334]])

In [23]:
# Same column range (2..51, i.e. ncauses : ncauses + nevecs for the current
# constants) taken from the forward-flow output, for comparison with X_obs.
z_obs[:,2:52]

array([[ 0.8624756 ,  0.03508633, -0.13468707, ..., -0.04518509,
        -0.16890526, -0.3085804 ],
       [ 0.8587063 , -0.03079271, -0.06704944, ...,  0.22314453,
        -0.6955471 , -0.45117378],
       [ 0.85885304,  0.0206278 , -0.0382989 , ...,  0.4459133 ,
        -0.01369095, -0.13887024],
       ...,
       [-0.46371177, -0.3694709 , -0.5678869 , ...,  0.01689339,
         0.09090042,  0.03864288],
       [-0.46350464, -0.42694134, -0.809049  , ...,  0.6482372 ,
         0.68356514,  0.9192133 ],
       [11.755917  ,  0.41372424,  2.4789772 , ..., -2.8025513 ,
        10.994232  , 10.672541  ]], dtype=float32)

In [20]:
# Print the index of every column of `diff` (forward-flow output minus input)
# whose entries do not sum to zero.
for ev in range(diff.shape[1]):
    # Inspect `diff` itself. The previous version tested diff_0 here, so the
    # freshly computed array was never examined and the cell just repeated
    # the diff_0 result.
    if diff[:, ev].sum() != 0:
        print(ev)

2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
