In [12]:
#import stuff 
import yaml
import os
import sys
import shutil
import numpy as np
import torch
import h5py

from itertools import cycle

from torch.backends import cudnn
import torch.optim as optim
import torch.nn as nn
from torch.autograd import Variable, grad

from src.data import LoadDataset
from src.ufdn import LoadModel

#We'll need to modify this for 3+ domains
from src.util import vae_loss, calc_gradient_penalty, interpolate_vae_3d

from tensorboardX import SummaryWriter 

In [2]:
#set this to the path of this yaml file 
#get it from my github branch 
config_path = 'config/tcga_domains.yaml'

In [6]:
conf = yaml.load(open(config_path,'r'))
exp_name = conf['exp_setting']['exp_name']
#img_size is only used in conv nets
#originally it was 64
img_size = conf['exp_setting']['img_size']
#20,501 img_depth
img_depth = conf['exp_setting']['img_depth']
domains = conf['exp_setting']['domains']
number_of_domains = len(domains)


data_root = conf['exp_setting']['data_root']
batch_size = conf['trainer']['batch_size']


enc_dim = conf['model']['vae']['encoder'][-1][1] #latent space dimension #100
code_dim = conf['model']['vae']['code_dim'] #number of domains #currently 3 
vae_learning_rate = conf['model']['vae']['lr'] #learning rate #10e-4
vae_betas = tuple(conf['model']['vae']['betas']) #used for adam optimizer


In [7]:
#load the model in a blank form 
vae = LoadModel('vae',conf['model']['vae'],img_size,img_depth)

In [9]:
#load in the trained params
#put statedict.pt in the same directory as this ipynb 
vae.load_state_dict(torch.load('statedict.pt'))

In [17]:
#set to eval mode, very important (so it doesn't train)
vae.eval()

UFDN(
  (enc_0): Sequential(
    (0): Linear(in_features=20501, out_features=500, bias=True)
    (1): BatchNorm1d(500, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): LeakyReLU(negative_slope=0.2)
  )
  (enc_mu): Sequential(
    (0): Linear(in_features=500, out_features=100, bias=True)
  )
  (enc_logvar): Sequential(
    (0): Linear(in_features=500, out_features=100, bias=True)
  )
  (dec_0): Sequential(
    (0): Linear(in_features=133, out_features=500, bias=True)
    (1): BatchNorm1d(500, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): LeakyReLU(negative_slope=0.2)
  )
  (dec_1): Sequential(
    (0): Linear(in_features=500, out_features=20501, bias=True)
  )
)

In [13]:
#get the data 

tcga = h5py.File('../TCGAProject/tcga_01.h5', mode='r') # Adjust for correct file
cancers = list(tcga['tcga/train'])
tcga_stack = np.vstack(list([tcga['tcga/train/'+c] for c in cancers]))

In [18]:
#get a random encoding 
enc_1 = vae(Variable(torch.FloatTensor(tcga_stack[0,:]).unsqueeze(0)),return_enc=True).cpu().data.numpy()

In [22]:
enc_1[0].shape

AttributeError: 'numpy.ndarray' object has no attribute 'expand'

In [25]:
encodings = np.zeros((7301,100))

In [29]:
#get all the encodings 
for i in range(tcga_stack.shape[0]):
    if i%100==0: 
        print(i)
    encodings[i] = vae(Variable(torch.FloatTensor(tcga_stack[i,:]).unsqueeze(0)),return_enc=True).cpu().data.numpy()

0
100
200
300
400
500
600
700
800
900
1000
1100
1200
1300
1400
1500
1600
1700
1800
1900
2000
2100
2200
2300
2400
2500
2600
2700
2800
2900
3000
3100
3200
3300
3400
3500
3600
3700
3800
3900
4000
4100
4200
4300
4400
4500
4600
4700
4800
4900
5000
5100
5200
5300
5400
5500
5600
5700
5800
5900
6000
6100
6200
6300
6400
6500
6600
6700
6800
6900
7000
7100
7200
7300


In [30]:
np.save(arr=encodings, file='encodings.npy')

In [53]:
encodings

array([[ 1.53615141,  2.03998041,  0.26760939, ...,  2.68303204,
        -0.45182565, -1.81147921],
       [ 1.53098238,  2.02159905,  0.27904192, ...,  2.67480159,
        -0.45245257, -1.78775108],
       [ 1.52760839,  2.03231525,  0.25800234, ...,  2.67778993,
        -0.46561608, -1.80481398],
       ...,
       [ 1.5598805 ,  2.02180123,  0.29629704, ...,  2.66306639,
        -0.45519075, -1.82134151],
       [ 1.55105269,  2.03682613,  0.29667199, ...,  2.66759586,
        -0.43666753, -1.80562031],
       [ 1.54031563,  2.02915335,  0.29366675, ...,  2.66511512,
        -0.43831924, -1.79955018]])

In [32]:
np.load('encodings.npy')

array([[ 1.53615141,  2.03998041,  0.26760939, ...,  2.68303204,
        -0.45182565, -1.81147921],
       [ 1.53098238,  2.02159905,  0.27904192, ...,  2.67480159,
        -0.45245257, -1.78775108],
       [ 1.52760839,  2.03231525,  0.25800234, ...,  2.67778993,
        -0.46561608, -1.80481398],
       ...,
       [ 1.5598805 ,  2.02180123,  0.29629704, ...,  2.66306639,
        -0.45519075, -1.82134151],
       [ 1.55105269,  2.03682613,  0.29667199, ...,  2.66759586,
        -0.43666753, -1.80562031],
       [ 1.54031563,  2.02915335,  0.29366675, ...,  2.66511512,
        -0.43831924, -1.79955018]])

In [33]:
#messing aroudn to figure out the decoding interpolations 


import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Variable, grad


In [50]:
attr_inters = 5
id_inters = 3
attr_max = 1.0
attr_dim=3
random_test=False
return_each_layer=False
sd =1
disentangle_dim=None
    
attr_min = 1.0-attr_max

alphas = np.linspace(attr_min, attr_max, attr_inters)
if disentangle_dim:
    alphas = [torch.FloatTensor([*([1 - alpha]*int((attr_dim-disentangle_dim)/3)),
                                 *([0]*int((attr_dim-disentangle_dim)/3)),
                                 *([alpha]*int((attr_dim-disentangle_dim)/3)),
                                 *([ v for i in range(int(disentangle_dim/2)) for v in [1-alpha,alpha]])]) for alpha in alphas]\
            +[torch.FloatTensor([*([0]*int((attr_dim-disentangle_dim)/3)),
                                 *([alpha]*int((attr_dim-disentangle_dim)/3)),
                                 *([1-alpha]*int((attr_dim-disentangle_dim)/3)),
                                 *([ v for i in range(int(disentangle_dim/2)) for v in [alpha,1-alpha]])]) for alpha in alphas[1:]]\
            +[torch.FloatTensor([*([alpha]*int((attr_dim-disentangle_dim)/3)),
                                 *([1 - alpha]*int((attr_dim-disentangle_dim)/3)),
                                 *([0]*int((attr_dim-disentangle_dim)/3)),
                                 *([ v for i in range(int(disentangle_dim/2)) for v in [1-alpha,alpha]])]) for alpha in alphas[1:-1]]
else:
    alphas = [torch.FloatTensor([*([1 - alpha]*int(attr_dim/3)),
                                 *([0]*int(attr_dim/3)),
                                 *([alpha]*int(attr_dim/3))]) for alpha in alphas]\
            +[torch.FloatTensor([*([0]*int(attr_dim/3)),
                                 *([alpha]*int(attr_dim/3)),
                                 *([1-alpha]*int(attr_dim/3))]) for alpha in alphas[1:]]\
            +[torch.FloatTensor([*([alpha]*int(attr_dim/3)),
                                 *([1 - alpha]*int(attr_dim/3)),
                                 *([0]*int(attr_dim/3))]) for alpha in alphas[1:-1]]

In [51]:
alphas

[tensor([1., 0., 0.]),
 tensor([0.7500, 0.0000, 0.2500]),
 tensor([0.5000, 0.0000, 0.5000]),
 tensor([0.2500, 0.0000, 0.7500]),
 tensor([0., 0., 1.]),
 tensor([0.0000, 0.2500, 0.7500]),
 tensor([0.0000, 0.5000, 0.5000]),
 tensor([0.0000, 0.7500, 0.2500]),
 tensor([0., 1., 0.]),
 tensor([0.2500, 0.7500, 0.0000]),
 tensor([0.5000, 0.5000, 0.0000]),
 tensor([0.7500, 0.2500, 0.0000])]

In [52]:
for alpha in alphas:
    alpha = Variable(alpha.unsqueeze(0).expand((1, attr_dim)))
    print(alpha)

tensor([[1., 0., 0.]])
tensor([[0.7500, 0.0000, 0.2500]])
tensor([[0.5000, 0.0000, 0.5000]])
tensor([[0.2500, 0.0000, 0.7500]])
tensor([[0., 0., 1.]])
tensor([[0.0000, 0.2500, 0.7500]])
tensor([[0.0000, 0.5000, 0.5000]])
tensor([[0.0000, 0.7500, 0.2500]])
tensor([[0., 1., 0.]])
tensor([[0.2500, 0.7500, 0.0000]])
tensor([[0.5000, 0.5000, 0.0000]])
tensor([[0.7500, 0.2500, 0.0000]])


In [69]:
vae.decode(Variable(torch.FloatTensor(encodings[1000,:]).unsqueeze(0)), Variable(torch.FloatTensor(np.identity(33)[3,:]).unsqueeze(0)))

tensor([[-1.4640e-02, -1.1962e-02, -3.6094e-02,  ...,  7.7365e-02,
          2.7635e-02, -2.4737e-02]], grad_fn=<ThAddmmBackward>)

In [64]:
encodings[1000,:]

array([ 1.54030132,  2.02600002,  0.27444807, -1.54520559,  2.70614004,
        0.25061393,  1.10731661, -1.0014708 ,  4.08568716, -1.56517661,
        3.10664129, -2.52854395,  2.67715669,  2.08753204,  3.8629775 ,
       -1.10294557,  3.14427829,  2.57238388, -1.33304691,  3.34982204,
       -4.08677578, -4.98691845, -2.21853876,  2.30798244, -2.00586915,
        3.21945572,  0.50464016, -1.49799097, -5.0915184 ,  0.39732355,
        1.00865614, -0.36023962,  1.97710395,  1.6248368 ,  1.06481194,
       -0.74596322, -1.54945481, -1.3093797 ,  0.58001065,  2.5515008 ,
       -6.48347139, -2.7555809 , -0.72744966, -0.28303605, -3.69317007,
       -3.40090013,  3.80709147,  4.36505938,  1.09659648,  4.76109028,
       -1.56292117,  1.60711229, -0.20046866, -0.27617255, -0.25786683,
       -2.89299488, -1.18031538, -6.29413462,  0.64341903,  3.25232434,
        1.69875801, -1.01921535, -0.97148287, -2.11943579,  0.24789819,
        4.7263093 , -1.41808462,  1.49636734,  1.92666674,  3.33