In [None]:
#import argparse
import datetime
import sys
import json
from collections import defaultdict
from pathlib import Path
from tempfile import mkdtemp

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
from torch.utils.data import Subset, DataLoader
from torchnet.dataset import TensorDataset, ResampleDataset
import matplotlib.pyplot as plt
import math

import models
import objectives_dev as objectives
#from utils import Logger, Timer, save_model, save_vars, unpack_data
from utils_dev import Logger, Timer, save_model, save_vars, unpack_data, EarlyStopping, vade_kld

In [None]:
#args
experiment = 'test'
model = 'rna_atac_dev' #VAE試しに使う
obj = 'elbo'
K = 1
looser = False
llik_scaling = 0
batch_size = 1024
epochs = 100
n_centroids = 10
latent_dim = 20
num_hidden_layers = 1
hidden_dim = [128, 128]
learn_prior = False
logp = False
print_freq = 0
no_analytics = False
seed = 1
dataSize = []
r_dim = a_dim = []

class params():
    
    def __init__(self,
                 experiment,
                 model,
                 obj,
                 K,
                 looser,
                 llik_scaling,
                 batch_size,
                 epochs,
                 n_centroids,
                 latent_dim,
                 num_hidden_layers,
                 hidden_dim,
                 learn_prior,
                 logp,
                 print_freq,
                 no_analytics,
                 seed,
                dataSize,
                r_dim,
                a_dim):
        
        self.experiment = experiment
        self.model = model
        self.obj = obj
        self.K = K
        self.looser = looser
        self.llik_scaling = llik_scaling
        self.batch_size = batch_size
        self.epochs = epochs
        self.n_centroids = n_centroids
        self.latent_dim = latent_dim
        self.num_hidden_layers = num_hidden_layers
        self.hidden_dim = hidden_dim
        self.learn_prior = learn_prior
        self.logp = logp
        self.print_freq = print_freq
        self.no_analytics = no_analytics
        self.seed = seed
        self.dataSize = dataSize
        self.r_dim = r_dim
        self.a_dim = a_dim
        
args = params(experiment,
                model,
                 obj,
                 K,
                 looser,
                 llik_scaling,
                 batch_size,
                 epochs,
                 n_centroids,
                 latent_dim,
                 num_hidden_layers,
                 hidden_dim,
                 learn_prior,
                 logp,
                 print_freq,
                 no_analytics,
                 seed,
                 dataSize,
                 r_dim,
                 a_dim)

In [3]:
# random seed
# https://pytorch.org/docs/stable/notes/randomness.html
torch.backends.cudnn.benchmark = True
torch.manual_seed(args.seed)
np.random.seed(args.seed)

In [4]:
device = torch.device("cpu")

In [None]:
# set up run path
#runId = datetime.datetime.now().isoformat()
runId ='test'
experiment_dir = Path('../experiments/' + args.experiment)
experiment_dir.mkdir(parents=True, exist_ok=True)
runPath = mkdtemp(prefix=runId, dir=str(experiment_dir))
print(runPath)

In [None]:
#train_loader = model.getDataLoaders(batch_size=args.batch_size, device=device) #for train only

In [5]:
dataset_path = '../data/Paired-seq/combined/'
r_dataset = torch.load(dataset_path + 'r_dataset.rar')
a_dataset = torch.load(dataset_path + 'a_dataset.rar')

In [6]:
num = 5000
#num = 25845
r_dataset = Subset(r_dataset, list(range(num)))
a_dataset = Subset(a_dataset, list(range(num)))

In [7]:
train_dataset= TensorDataset([
    #ResampleDataset(r_dataset),
    #ResampleDataset(a_dataset)
    r_dataset,
    a_dataset
    ])

In [8]:
train_loader = DataLoader(train_dataset, batch_size=args.batch_size)

In [9]:
#args.r_dim = r_dataset.data.shape[1]
#args.a_dim = a_dataset.data.shape[1]
args.r_dim = r_dataset.dataset.shape[1]
args.a_dim = a_dataset.dataset.shape[1]
r_dataset = a_dataset = train_dataset = None

In [10]:
# load model
modelC = getattr(models, 'VAE_{}'.format(args.model))
model = modelC(args).to(device)

In [11]:
# preparation for training
optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()),
                       lr=1e-4, amsgrad=True)

In [None]:
#pre_objective  = getattr(objectives, 'elbo_ae') 
pre_objective  = getattr(objectives, 'm_elbo_naive_ae') 
#pretrained_path = '../data/Paired-seq/combined/RNA-seq/'
pretrained_path = '../data/Paired-seq/combined/subset'

In [None]:
def pretrain(epoch, agg):
    model.train()
    b_loss = 0
    for i, dataT in enumerate(train_loader):
        
        #data = unpack_data(dataT, device=device) #unimodal
        data = dataT #multimodal
        optimizer.zero_grad()
        #loss = -objective(model, data, K=args.K)
        loss = -pre_objective(model, data, K=args.K)
        loss.backward()
        optimizer.step()
        b_loss += loss.item()
        if args.print_freq > 0 and i % args.print_freq == 0:
            print("iteration {:04d}: loss: {:6.3f}".format(i, loss.item() / args.batch_size))
    agg['train_loss'].append(b_loss / len(train_loader.dataset))
    print('====> Epoch: {:03d} Train loss: {:.4f}'.format(epoch, agg['train_loss'][-1]))

In [None]:
with Timer('MM-VAE') as t:
        agg = defaultdict(list)
        pretrain_epoch = 5
        for epoch in range(1, pretrain_epoch + 1):
            pretrain(epoch, agg)
            save_model(model, pretrained_path + '/model.rar')
            save_vars(agg, pretrained_path + '/losses.rar')

In [None]:
print('Loading model {} from {}'.format(model.modelName, pretrained_path))
model.load_state_dict(torch.load(pretrained_path + '/model.rar', map_location=device))
model._pz_params = model._pz_params

In [None]:
rescue = '../experiments/test/2020-04-28T17:13:31.932565bjdxippz'
print('Loading model {} from {}'.format(model.modelName, rescue))
model.load_state_dict(torch.load(rescue + '/model.rar.old', map_location=device))
model._pz_params = model._pz_params

In [None]:
fit = False
model.init_gmm_params(train_loader, fit=fit, var=0.1, device=device)
#model.init_gmm_params_separate(train_loader, device=device)

In [None]:
pre_pi= model._pz_params[0].detach()
pre_mu = model._pz_params[1].detach()
pre_var = model._pz_params[2].detach()

In [None]:
print(pre_pi)
print(pre_mu)
print(pre_var)

In [None]:
print(model._pz_params[0]/sum(model._pz_params[0]))
print(model._pz_params[1])
print(model._pz_params[2])

In [12]:
#training
optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()),
                       lr=1e-4, amsgrad=True)

In [13]:
#objective = getattr(objectives, 'elbo_vade') 
objective = getattr(objectives, 'm_elbo_naive_vade') 
#objective = getattr(objectives, 'm_elbo_vade') 
#objective = getattr(objectives, 'm_elbo_vade_warmup') 
#objective = getattr(objectives, 'm_elbo_vade_separate') 

In [None]:
def train(epoch, agg, W=30):
    model.train()
    b_loss = 0
    adj = 1
    #beta = (epoch - 1) / W  if epoch <= W else 1
    
    alpha = 100
    beta = alpha * (epoch - 1) / W if epoch<=W else alpha
    for i, dataT in enumerate(train_loader):
        
        #data = unpack_data(dataT, device=device) #unimodal
        data = dataT #multimodal
        optimizer.zero_grad()
        if objective==getattr(objectives, 'm_elbo_vade_warmup'):
            loss = -objective(model, data, beta, K=args.K)
        else:
            loss = -objective(model, data, adj=adj, K=args.K)
        loss.backward()
        optimizer.step()
        b_loss += loss.item()
        if args.print_freq > 0 and i % args.print_freq == 0:
            print("iteration {:04d}: loss: {:6.3f}".format(i, loss.item() / args.batch_size))
    agg['train_loss'].append(b_loss / len(train_loader.dataset))
    print('====> Epoch: {:03d} Train loss: {:.4f}'.format(epoch, agg['train_loss'][-1]))

In [14]:
model.train()
b_loss = 0
adj = 1
#beta = (epoch - 1) / W  if epoch <= W else 1

alpha = 100
#beta = alpha * (epoch - 1) / W if epoch<=W else alpha
for i, dataT in enumerate(train_loader):
    
    #data = unpack_data(dataT, device=device) #unimodal
    data = dataT #multimodal
    optimizer.zero_grad()
    if objective==getattr(objectives, 'm_elbo_vade_warmup'):
        loss = -objective(model, data, beta, K=args.K)
    else:
        loss = -objective(model, data, adj=adj, K=args.K)
    loss.backward()
    optimizer.step()
    b_loss += loss.item()
    if i == 0:
        break

In [16]:
from graphviz import Source
from torchviz import make_dot
arch = make_dot(loss)
Source(arch).render('../data/arch.png')

'../data/arch.png.pdf'

In [None]:
with Timer('MM-VAE') as t:
        agg = defaultdict(list)
        # initialize the early_stopping object
        early_stopping = EarlyStopping(patience=10, verbose=True) 
        
        for epoch in range(1, args.epochs + 1):
            train(epoch, agg)
            #save_model(model, runPath + '/model.rar')
            save_vars(agg, runPath + '/losses.rar')
            
            # early_stopping needs the validation loss to check if it has decresed, 
            # and if it has, it will make a checkpoint of the current model
            #validate(epoch, agg)
            #early_stopping(agg['val_loss'][-1], model, runPath)
            early_stopping(agg['train_loss'][-1], model, runPath)
            if early_stopping.early_stop:
                print('Early stopping')
                break
            
            #test(epoch, agg)

In [None]:
#MMVAE get all data
for i, d in enumerate(train_loader):
    if i == 0:
        data0 = d[0]
        data1 = d[1]
    else:
        data0 = torch.cat([data0, d[0]], dim=0)
        data1 = torch.cat([data1, d[1]], dim=0)
data = [data0.to(device), data1.to(device)]

In [None]:
model.visualize_latent(data, runPath, epoch=1, tsne=True, sampling=False)

In [None]:
#MMVAE get n data
n = 1
for i, d in enumerate(train_loader):
    if i == 0:
        data0 = d[0]
        data1 = d[1]
    elif i < n:
        data0 = torch.cat([data0, d[0]], dim=0)
        data1 = torch.cat([data1, d[1]], dim=0)
data = [data0.to(device), data1.to(device)]

In [None]:
#testing m_elbo_naive_vade
x = data
qz_xs, px_zs, zss = model(x)
n_centroids = model.params.n_centroids
lpx_zs, klds = [], []

In [None]:
model.vaes[0]._qz_x_params

In [None]:
for r, qz_x in enumerate(qz_xs):
    zs = zss[r]
    kld = vade_kld(model, zs, r)
    klds.append(kld)
        
    for d, px_z in enumerate(px_zs[r]):
        lpx_z = px_z.log_prob(x[d]) * model.vaes[d].llik_scaling
        #lpx_zs.append(lpx_z.view(*px_z.batch_shape[:2], -1).sum(-1).squeeze()) #added squeeze()
        lpx_zs.append(lpx_z.sum(-1))
#obj = (1 / len(model.vaes)) * (torch.stack(lpx_zs).sum(0) - torch.stack(klds).sum(0))
obj = (1 / len(model.vaes)) * (torch.stack(lpx_zs).sum(0) - torch.stack(klds).sum(0))

In [None]:
2**3

In [None]:
gamma

In [None]:
lgamma

In [None]:
torch.stack(lpx_zs).mean(1)

In [None]:
klds

In [None]:
r = 0
zs = zss[r]
n_centroids = model.params.n_centroids
gamma, lgamma, mu_c, var_c, pi = model.get_gamma(zs)
    
    #mu, logvar = model.vaes[r]._qz_x_params ミス 
mu, var = model.vaes[r]._qz_x_params
mu_expand = mu.unsqueeze(2).expand(mu.size(0), mu.size(1), n_centroids)
    #logvar_expand = logvar.unsqueeze(2).expand(logvar.size(0), logvar.size(1), n_centroids)
var_expand = var.unsqueeze(2).expand(var.size(0), var.size(1), n_centroids)
        
    #lpz_c = -0.5*torch.sum(gamma*torch.sum(math.log(2*math.pi) + \
    #                                       torch.log(var_c) + \
    #                                       torch.exp(logvar_expand)/var_c + \
    #                                       (mu_expand-mu_c)**2/var_c, dim=1), dim=1) # log p(z|c)
lpz_c = -0.5*torch.sum(gamma*torch.sum(math.log(2*math.pi) + \
                                           torch.log(var_c) + \
                                           var_expand/var_c + \
                                           (mu_expand-mu_c)**2/var_c, dim=1), dim=1) # log p(z|c)
lpc = torch.sum(gamma*torch.log(pi), 1) # log p(c)
lqz_x = -0.5*torch.sum(1+torch.log(var)+math.log(2*math.pi), 1) #see VaDE paper # log q(z|x)
lqc_x = torch.sum(gamma*(lgamma), 1) # log q(c|x)
    
kld = -lpz_c - lpc + lqz_x + lqc_x 

In [None]:
lpz_c

In [None]:
lpc

In [None]:
lqz_x

In [None]:
lqc_x

In [None]:
 -lpz_c - lpc + lqz_x + lqc_x 