In [1]:
#import argparse
import datetime
import sys
import json
from collections import defaultdict
from pathlib import Path
from tempfile import mkdtemp

import numpy as np
import torch
from torch import optim
import matplotlib.pyplot as plt
import math
from sklearn.mixture import GaussianMixture

import models
import objectives_dev as objectives
#from utils import Logger, Timer, save_model, save_vars, unpack_data
from utils_dev import Logger, Timer, save_model, save_vars, unpack_data, EarlyStopping

In [2]:
#args
experiment = 'test'
model = 'VADE_rna_atac' 
obj = 'elbo'
K = 10
looser = False
llik_scaling = 0
batch_size = 128
epochs = 10
n_centroids = 10
latent_dim = 8
num_hidden_layers = 2
hidden_dim =128
learn_prior = False
logp = False
print_freq = 0
no_analytics = False
seed = 1
dataSize = []

class params():
    
    def __init__(self,
                 experiment,
                 model,
                 obj,
                 K,
                 looser,
                 llik_scaling,
                 batch_size,
                 epochs,
                 n_centroids,
                 latent_dim,
                 num_hidden_layers,
                 hidden_dim,
                 learn_prior,
                 logp,
                 print_freq,
                 no_analytics,
                 seed,
                dataSize):
        
        self.experiment = experiment
        self.model = model
        self.obj = obj
        self.K = K
        self.looser = looser
        self.llik_scaling = llik_scaling
        self.batch_size = batch_size
        self.epochs = epochs
        self.n_centroids = n_centroids
        self.latent_dim = latent_dim
        self.num_hidden_layers = num_hidden_layers
        self.hidden_dim = hidden_dim
        self.learn_prior = learn_prior
        self.logp = logp
        self.print_freq = print_freq
        self.no_analytics = no_analytics
        self.seed = seed
        self.dataSize = dataSize
        
args = params(experiment,
                model,
                 obj,
                 K,
                 looser,
                 llik_scaling,
                 batch_size,
                 epochs,
                 n_centroids,
                 latent_dim,
                 num_hidden_layers,
                 hidden_dim,
                 learn_prior,
                 logp,
                 print_freq,
                 no_analytics,
                 seed,
                 dataSize)

In [3]:
# random seed
# https://pytorch.org/docs/stable/notes/randomness.html
torch.backends.cudnn.benchmark = True
torch.manual_seed(args.seed)
np.random.seed(args.seed)

In [4]:
device = torch.device("cpu")

In [5]:
# load model
modelC = getattr(models, 'VAE_{}'.format(args.model))
model = modelC(args).to(device)

In [None]:
train_loader = model.getDataLoaders(batch_size=args.batch_size, device=device) 

Loading  data ...
Original data contains 41036 cells x 29589 peaks
Finished loading takes 0.33 min
Loading  data ...


In [None]:
#MMVAE
for i, d in enumerate(train_loader):
    if i == 0:
        data0 = d[0]
        data1 = d[1]
    elif i < 5:
        data0 = torch.cat([data0, d[0]], dim=0)
        data1 = torch.cat([data1, d[1]], dim=0)
data = [data0.to(device), data1.to(device)]

In [78]:
#data = torch.tensor(train_loader.dataset.data.todense())

In [38]:
#testing m_elbo_naive_vade
x = data
qz_xs, px_zs, zss = model(x)
n_centroids = model.params.n_centroids
lpx_zs, klds = [], []

In [14]:
print(zss[0].shape)
print(px_zs[0][0].batch_shape)

torch.Size([640, 8])
torch.Size([640, 28767])


In [31]:
for r, qz_x in enumerate(qz_xs):
        zs = zss[r]
        gamma, mu_c, var_c, pi = model.vaes[r].get_gamma(zs)
        
        mu, logvar = model.vaes[r]._qz_x_params
        mu_expand = mu.unsqueeze(2).expand(mu.size(0), mu.size(1), n_centroids)
        logvar_expand = logvar.unsqueeze(2).expand(logvar.size(0), logvar.size(1), n_centroids)
        
        # log p(z|c)
        lpz_c = -0.5*torch.sum(gamma*torch.sum(math.log(2*math.pi) + \
                                           torch.log(var_c) + \
                                           torch.exp(logvar_expand)/var_c + \
                                           (mu_expand-mu_c)**2/var_c, dim=1), dim=1)
        # log p(c)
        lpc = torch.sum(gamma*torch.log(pi), 1)

        # log q(z|x)
        lqz_x = -0.5*torch.sum(1+logvar+math.log(2*math.pi), 1) #see VaDE paper
        #lqz_x = qz_x.log_prob(zs).sum(-1)
        #lqz_x = lqz_x.squeeze()

        # log q(c|x)
        lqc_x = torch.sum(gamma*torch.log(gamma), 1)
    
        kld = -lpz_c - lpc + lqz_x + lqc_x 
        #kld = kl_divergence(qz_x, model.pz(*model.pz_params))
        klds.append(kld)

In [32]:
r = 0
for d, px_z in enumerate(px_zs[r]):
            lpx_z = px_z.log_prob(x[d]) * model.vaes[d].llik_scaling
            #lpx_zs.append(lpx_z.view(*px_z.batch_shape[:], -1).sum(-1).squeeze()) #added squeeze()
            lpx_zs.append(lpx_z.sum(-1))

In [37]:
print(lpx_z.shape)
print(lpx_zs[2].shape)

torch.Size([640, 64754])
torch.Size([640])


In [35]:
lpx_z.shape
lpx_z.sum(-1).shape
lpx_zs[0].shape

torch.Size([640, 28767])

In [23]:
lpx_z.shape

torch.Size([640, 64754])

In [39]:
for r, qz_x in enumerate(qz_xs):
        zs = zss[r]
        gamma, mu_c, var_c, pi = model.vaes[r].get_gamma(zs)
        
        mu, logvar = model.vaes[r]._qz_x_params
        mu_expand = mu.unsqueeze(2).expand(mu.size(0), mu.size(1), n_centroids)
        logvar_expand = logvar.unsqueeze(2).expand(logvar.size(0), logvar.size(1), n_centroids)
        
        # log p(z|c)
        lpz_c = -0.5*torch.sum(gamma*torch.sum(math.log(2*math.pi) + \
                                           torch.log(var_c) + \
                                           torch.exp(logvar_expand)/var_c + \
                                           (mu_expand-mu_c)**2/var_c, dim=1), dim=1)
        # log p(c)
        lpc = torch.sum(gamma*torch.log(pi), 1)

        # log q(z|x)
        lqz_x = -0.5*torch.sum(1+logvar+math.log(2*math.pi), 1) #see VaDE paper
        #lqz_x = qz_x.log_prob(zs).sum(-1)
        #lqz_x = lqz_x.squeeze()

        # log q(c|x)
        lqc_x = torch.sum(gamma*torch.log(gamma), 1)
    
        kld = -lpz_c - lpc + lqz_x + lqc_x 
        #kld = kl_divergence(qz_x, model.pz(*model.pz_params))
        klds.append(kld)
        
        for d, px_z in enumerate(px_zs[r]):
            lpx_z = px_z.log_prob(x[d]) * model.vaes[d].llik_scaling
            #lpx_zs.append(lpx_z.view(*px_z.batch_shape[:2], -1).sum(-1).squeeze()) #added squeeze()
            lpx_zs.append(lpx_z.sum(-1)) #added squeeze()
    #obj = (1 / len(model.vaes)) * (torch.stack(lpx_zs).sum(0) - torch.stack(klds).sum(0))
obj = (1 / len(model.vaes)) * (torch.stack(lpx_zs).sum(0) - torch.stack(klds).sum(0))

In [40]:
torch.stack(lpx_zs).size()

torch.Size([4, 640])

In [41]:
torch.stack(klds).size()

torch.Size([2, 640])

In [42]:
torch.stack(lpx_zs).sum(0).size()

torch.Size([640])

In [43]:
torch.stack(klds).sum(0).size()

torch.Size([640])

In [44]:
lpx_zs

[tensor([-45481.0820, -45815.5078, -45500.1055, -45471.0859, -45309.5078,
         -45517.2617, -45259.9023, -45536.1289, -45447.3750, -45130.4219,
         -45401.4805, -45461.4414, -45534.7031, -45509.4727, -45564.8125,
         -45349.8203, -45756.3359, -46071.3125, -45739.6406, -46565.4102,
         -46067.2500, -45610.1680, -45867.7852, -46698.4141, -45655.1484,
         -45940.5508, -45369.4180, -45738.8164, -46203.5508, -45277.1914,
         -45232.6680, -45312.1289, -45440.0469, -45459.8203, -45321.5977,
         -45327.2070, -45171.2031, -45265.2461, -45167.4023, -45830.4961,
         -45469.5195, -45989.8945, -45391.9297, -45173.0664, -45582.9609,
         -45408.8906, -45842.9883, -45121.2188, -45159.4141, -45583.5625,
         -45302.8008, -45347.5078, -45804.3789, -45565.1016, -46423.5078,
         -45678.2188, -45540.0312, -45088.7109, -45329.0898, -45654.6914,
         -45566.7812, -45588.0586, -45184.1406, -45193.5742, -45124.2227,
         -45850.5117, -45731.4258, -46

In [11]:
klds

[tensor([2.9031, 2.9450, 2.9001, 2.8915, 2.8980, 2.8980, 2.8965, 2.9089, 2.9048,
         2.8948, 2.9001, 2.8968, 2.8984, 2.8965, 2.9040, 2.8994, 2.9270, 2.9120,
         2.8978, 2.9266, 2.9470, 2.9132, 2.9204, 2.9092, 2.9009, 3.1161, 2.9014,
         2.8988, 2.8985, 2.8940, 2.8965, 2.8958, 2.8973, 2.8992, 2.8996, 2.8943,
         2.8959, 2.8991, 2.8963, 2.9189, 2.9005, 2.9129, 2.8979, 2.8990, 2.9236,
         2.9035, 2.8991, 2.8978, 2.8973, 2.9023, 2.8957, 2.8966, 2.9038, 2.9015,
         2.9130, 2.9180, 2.8971, 2.8966, 2.8971, 2.9035, 2.9025, 2.9056, 2.8958,
         2.8964, 2.8961, 2.9046, 2.9070, 3.2065, 2.9108, 2.9085, 2.8985, 2.8968,
         2.8993, 2.8949, 2.8977, 2.9222, 2.8943, 2.8957, 2.8961, 2.8963, 2.9584,
         2.8958, 2.8947, 2.8975, 2.8964, 2.8977, 2.9069, 2.8972, 2.8962, 2.8976,
         2.8979, 2.8965, 2.9680, 2.8949, 2.8954, 2.8941, 2.8969, 2.9155, 2.9141,
         2.8966, 2.9329, 2.8970, 2.8992, 2.8999, 2.9028, 2.8959, 2.8935, 2.9012,
         2.9094, 2.8941, 2.9

In [85]:
obj

tensor([-90895.1562, -91120.8359, -90986.3516, -90651.7344, -90497.9531,
        -90869.2500, -90395.2812, -90977.2812, -90943.7188, -90427.0703,
        -90544.5234, -91013.0391, -90788.5078, -90915.5625, -90725.4688,
        -90791.5391, -90939.2578, -91260.2266, -91509.3047, -91652.6328,
        -91397.5469, -90844.5312, -91117.0781, -91913.4062, -90865.3047,
        -91082.5547, -90441.0312, -90880.8281, -91525.3984, -90604.9062,
        -90319.5859, -90578.5781, -90600.2656, -90644.2422, -90455.0156,
        -90833.7422, -90345.3203, -90597.0312, -90196.0000, -91242.6250,
        -90771.5625, -91230.9062, -90565.1016, -90481.7656, -90919.6562,
        -90662.9609, -91101.5625, -90427.6328, -90434.4844, -90756.3750,
        -90729.3984, -90446.0469, -91382.9219, -90756.0156, -91611.5234,
        -90815.5391, -90817.2969, -90207.6797, -90798.1641, -90833.5234,
        -90776.3828, -90775.3359, -90476.5547, -90426.2188, -90218.3672,
        -91322.4688, -91157.3281, -92146.2188, -912

In [51]:
#testing init_gmm_params
output = []
for i, data in enumerate(train):
            
    #qz_xs, px_zs, zss = self.forward(data)
    zss = model.latents(data, sampling=True)
    zs = sum(zss)/len(zss)
            
    output.append(zs.squeeze().detach().cpu())

In [55]:
output = []
for i, dataT in enumerate(train):      
    qz_x, px_z, zs = model.forward(dataT)
    output.append(zs.squeeze().detach().cpu())

In [55]:
print(output[0].size())
print(len(output))

torch.Size([128, 8])
321


In [57]:
output = torch.cat(output).numpy()

In [61]:
output.shape

(41036, 8)

In [64]:
gmm = GaussianMixture(n_components=model.params.n_centroids, covariance_type='diag', init_params='kmeans')
gmm.fit(output)

GaussianMixture(covariance_type='diag', init_params='kmeans', max_iter=100,
                means_init=None, n_components=10, n_init=1,
                precisions_init=None, random_state=None, reg_covar=1e-06,
                tol=0.001, verbose=0, verbose_interval=10, warm_start=False,
                weights_init=None)

In [65]:
model.mu_c.data.copy_(torch.from_numpy(gmm.means_.T.astype(np.float32)))
model.var_c.data.copy_(torch.from_numpy(gmm.covariances_.T.astype(np.float32)))

tensor([[0.5193, 0.4833, 0.5020, 0.5097, 0.2943, 0.4887, 0.4805, 0.5040, 0.3346,
         0.4367],
        [0.4068, 0.3201, 0.3961, 0.4090, 0.4443, 0.2943, 0.4284, 0.4582, 0.4863,
         0.4227],
        [0.4992, 0.3733, 0.4824, 0.4973, 0.4022, 0.3890, 0.5086, 0.4577, 0.4273,
         0.4227],
        [0.4780, 0.4070, 0.4891, 0.4700, 0.4314, 0.4584, 0.5266, 0.3081, 0.5033,
         0.5196],
        [0.3786, 0.3934, 0.3860, 0.2956, 0.3736, 0.4044, 0.3142, 0.3504, 0.4847,
         0.4113],
        [0.3949, 0.5473, 0.4276, 0.4756, 0.4108, 0.4325, 0.4815, 0.4822, 0.3721,
         0.3432],
        [0.4713, 0.4763, 0.2943, 0.4517, 0.3992, 0.4393, 0.4676, 0.3818, 0.4008,
         0.4733],
        [0.3038, 0.4008, 0.4738, 0.4675, 0.4891, 0.3956, 0.4517, 0.4708, 0.4639,
         0.3636]])

In [None]:
#testing init_gmm_params_separate

In [9]:
output = [[] for m in model.vaes]
for i, dataT in enumerate(train_loader):

    if device == torch.device('cuda'):
        data = [d.to(device) for d in dataT] #GPU
    else:
        data = dataT
        #_, _, _ = self.forward(data)
    lats = model.latents(data, sampling=False)
    for m, o in enumerate(output):
        o.append(lats[m].detach().cpu())

In [11]:
gmm = GaussianMixture(n_components=model.params.n_centroids, covariance_type='diag', init_params='kmeans')
for m,o in enumerate(output):
    o = torch.cat(o).numpy()
    gmm.fit(o)
    model.vaes[m].pi.data.copy_(torch.from_numpy(gmm.weights_.astype(np.float32)))
    model.vaes[m].mu_c.data.copy_(torch.from_numpy(gmm.means_.T.astype(np.float32)))
    model.vaes[m].var_c.data.copy_(torch.from_numpy(gmm.covariances_.T.astype(np.float32)))

In [12]:
model.vaes[0].pi

Parameter containing:
tensor([4.9956e-01, 4.0000e-04, 3.6207e-03, 3.0672e-02, 1.8135e-01, 4.0000e-04,
        2.3129e-01, 8.8569e-03, 1.2321e-02, 3.1529e-02], requires_grad=True)

In [23]:
model.vaes[1].mu_c

Parameter containing:
tensor([[ 0.0507,  0.0557,  0.0528,  0.0485,  0.0542,  0.0496,  0.0499,  0.0499,
          0.0541,  0.0559],
        [ 0.0091,  0.0177,  0.0091,  0.0094,  0.0098,  0.0112,  0.0123,  0.0161,
          0.0136,  0.0163],
        [-0.0095, -0.0110, -0.0098, -0.0088, -0.0106, -0.0105, -0.0109, -0.0123,
         -0.0121, -0.0133],
        [ 0.0317,  0.0450,  0.0364,  0.0373,  0.0379,  0.0314,  0.0309,  0.0356,
          0.0321,  0.0367],
        [ 0.0520,  0.0474,  0.0516,  0.0495,  0.0498,  0.0511,  0.0516,  0.0487,
          0.0514,  0.0512],
        [ 0.0765,  0.0735,  0.0710,  0.0804,  0.0770,  0.0788,  0.0754,  0.0761,
          0.0736,  0.0670],
        [ 0.0379,  0.0396,  0.0381,  0.0356,  0.0406,  0.0371,  0.0371,  0.0362,
          0.0398,  0.0396],
        [-0.0249, -0.0208, -0.0213, -0.0243, -0.0284, -0.0246, -0.0212, -0.0181,
         -0.0232, -0.0170]], requires_grad=True)

[tensor([[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]]),
 tensor([[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]])]