In [2]:
#import argparse
import datetime
import sys
import json
from collections import defaultdict
from pathlib import Path
from tempfile import mkdtemp

import numpy as np
import torch
from torch import optim
import matplotlib.pyplot as plt

import models
import objectives_dev as objectives
#from utils import Logger, Timer, save_model, save_vars, unpack_data
from utils_dev import Logger, Timer, save_model, save_vars, unpack_data, EarlyStopping

In [20]:
#args
experiment = 'test'
model = 'rna_atac_dev' #VAE試しに使う
obj = 'elbo'
K = 10
looser = False
llik_scaling = 0
batch_size = 128
epochs = 10
latent_dim = 8
num_hidden_layers = 2
hidden_dim =128
learn_prior = False
logp = False
print_freq = 0
no_analytics = False
seed = 1
dataSize = []

class params():
    
    def __init__(self,
                 experiment,
                 model,
                 obj,
                 K,
                 looser,
                 llik_scaling,
                 batch_size,
                 epochs,
                 latent_dim,
                 num_hidden_layers,
                 hidden_dim,
                 learn_prior,
                 logp,
                 print_freq,
                 no_analytics,
                 seed,
                dataSize):
        
        self.experiment = experiment
        self.model = model
        self.obj = obj
        self.K = K
        self.looser = looser
        self.llik_scaling = llik_scaling
        self.batch_size = batch_size
        self.epochs = epochs
        self.latent_dim = latent_dim
        self.num_hidden_layers = num_hidden_layers
        self.hidden_dim = hidden_dim
        self.learn_prior = learn_prior
        self.logp = logp
        self.print_freq = print_freq
        self.no_analytics = no_analytics
        self.seed = seed
        self.dataSize = dataSize
        
args = params(experiment,
                model,
                 obj,
                 K,
                 looser,
                 llik_scaling,
                 batch_size,
                 epochs,
                 latent_dim,
                 num_hidden_layers,
                 hidden_dim,
                 learn_prior,
                 logp,
                 print_freq,
                 no_analytics,
                 seed,
                 dataSize)

In [21]:
# random seed
# https://pytorch.org/docs/stable/notes/randomness.html
torch.backends.cudnn.benchmark = True
torch.manual_seed(args.seed)
np.random.seed(args.seed)

In [22]:
device = torch.device("cpu")

In [23]:
# load model
modelC = getattr(models, 'VAE_{}'.format(args.model))
model = modelC(args).to(device)

In [24]:
print(model)

RNA_ATAC(
  (vaes): ModuleList(
    (0): RNA(
      (enc): Enc(
        (enc): Sequential(
          (0): Sequential(
            (0): Linear(in_features=23758, out_features=128, bias=True)
            (1): ReLU(inplace=True)
          )
          (1): Sequential(
            (0): Linear(in_features=128, out_features=128, bias=True)
            (1): ReLU(inplace=True)
          )
        )
        (fc21): Linear(in_features=128, out_features=8, bias=True)
        (fc22): Linear(in_features=128, out_features=8, bias=True)
      )
      (dec): Dec(
        (dec): Sequential(
          (0): Sequential(
            (0): Linear(in_features=8, out_features=128, bias=True)
            (1): ReLU(inplace=True)
          )
          (1): Sequential(
            (0): Linear(in_features=128, out_features=128, bias=True)
            (1): ReLU(inplace=True)
          )
        )
        (fc31): Linear(in_features=128, out_features=23758, bias=True)
        (fc32): Linear(in_features=128, out_feature

In [8]:
# set up run path
runId = datetime.datetime.now().isoformat()
experiment_dir = Path('../experiments/' + args.experiment)
experiment_dir.mkdir(parents=True, exist_ok=True)
runPath = mkdtemp(prefix=runId, dir=str(experiment_dir))
print(runPath)

../experiments/junk/2020-04-18T14:25:44.897101jxrkmmr3


In [9]:
# preparation for training
optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()),
                       lr=1e-4, amsgrad=True)

In [10]:
train_loader, val_loader, test_loader = model.getDataLoaders(args.batch_size, device=device)

Loading  data ...
Original data contains 10309 cells x 33160 peaks
Finished loading takes 1.21 min
Loading  data ...
Original data contains 10309 cells x 244544 peaks
Finished loading takes 0.84 min


In [44]:
for m, vae in enumerate(model.vaes):
    args.dataSize.append(torch.Size([1, len(train_loader.dataset[0][m])]))

23758

In [9]:
#objective = getattr(objectives,
#                    ('m_' if hasattr(model, 'vaes') else '')
#                    + args.obj
#                    + ('_looser' if (args.looser and args.obj != 'elbo') else ''))
objective = getattr(objectives, 'm_elbo_warmup') #test warmup

t_objective = getattr(objectives, ('m_' if hasattr(model, 'vaes') else '') + 'iwae')

print(objective)

<function m_elbo_warmup at 0x130929bf8>


In [10]:
def train(epoch, agg, W=30):
    model.train()
    b_loss = 0
    beta = (epoch - 1) / W  if epoch <= W else 1
    for i, dataT in enumerate(train_loader):
        
        #data = unpack_data(dataT, device=device)
        data = dataT #mmvae_rna_atac
        optimizer.zero_grad()
        #loss = -objective(model, data, K=args.K)
        loss = -objective(model, data, beta, K=args.K)
        loss.backward()
        optimizer.step()
        b_loss += loss.item()
        if args.print_freq > 0 and i % args.print_freq == 0:
            print("iteration {:04d}: loss: {:6.3f}".format(i, loss.item() / args.batch_size))
    agg['train_loss'].append(b_loss / len(train_loader.dataset))
    print('====> Epoch: {:03d} Train loss: {:.4f}'.format(epoch, agg['train_loss'][-1]))

In [11]:
def validate(epoch, agg, W=30):
    model.eval()
    b_loss = 0
    beta = (epoch - 1) / W  if epoch <= W else 1
    with torch.no_grad():
        for i, dataT in enumerate(val_loader):
            #data = unpack_data(dataT, device=device)
            data = dataT #mmvae_rna_atac
            #loss = -t_objective(model, data, K=args.K)
            #loss = -objective(model, data, K=args.K)
            loss = -objective(model, data, beta, K=args.K)
            b_loss += loss.item()
    agg['val_loss'].append(b_loss / len(val_loader.dataset))
    print('====>             Validation loss: {:.4f}'.format(agg['val_loss'][-1]))

In [12]:
def test(epoch, agg, W=30):
    model.eval()
    b_loss = 0
    beta = (epoch - 1) / W  if epoch <= W else 1
    with torch.no_grad():
        for i, dataT in enumerate(test_loader):
            
            #data = unpack_data(dataT, device=device)
            data = dataT #mmvae_rna_atac
            
            #loss = -t_objective(model, data, K=args.K)
            #loss = -objective(model, data, K=args.K)
            loss = -objective(model, data, beta, K=args.K)
            b_loss += loss.item()
            #if i == 0:
            #    model.reconstruct(data, runPath, epoch)
            #model.reconstruct(data, runPath, epoch, sampling=False, N=1)
            #    if not args.no_analytics:
           # model.analyse(data, runPath, epoch)
    agg['test_loss'].append(b_loss / len(test_loader.dataset))
    print('====>             Test loss: {:.4f}'.format(agg['test_loss'][-1]))

In [13]:
def estimate_log_marginal(K):
    """Compute an IWAE estimate of the log-marginal likelihood of test data."""
    model.eval()
    marginal_loglik = 0
    with torch.no_grad():
        for dataT in test_loader:
            data = unpack_data(dataT, device=device)
            marginal_loglik += -t_objective(model, data, K).item()

    marginal_loglik /= len(test_loader.dataset)
    print('Marginal Log Likelihood (IWAE, K = {}): {:.4f}'.format(K, marginal_loglik))

In [14]:
#sys.stdout = Logger('{}/run.log'.format(runPath))
#print('Expt:', runPath)
#print('RunID:', runId)
            
with Timer('MM-VAE') as t:
        agg = defaultdict(list)
        # initialize the early_stopping object
        early_stopping = EarlyStopping(patience=7, verbose=True) 
        
        for epoch in range(1, args.epochs + 1):
            train(epoch, agg)
            #save_model(model, runPath + '/model.rar')
            save_vars(agg, runPath + '/losses.rar')
            
            # early_stopping needs the validation loss to check if it has decresed, 
            # and if it has, it will make a checkpoint of the current model
            validate(epoch, agg)
            early_stopping(agg['val_loss'][-1], model, runPath)
            if early_stopping.early_stop:
                print('Early stopping')
                break
            
            test(epoch, agg)
            #model.generate(runPath, epoch)
       # if args.logp:  # compute as tight a marginal likelihood as possible
           # estimate_log_marginal(5000)

====> Epoch: 001 Train loss: 66731.9997
====>             Validation loss: 55159.3871
Validation loss decreased (inf --> 55159.387136).  Saving model ...
====>             Test loss: 48476.2845
====> Epoch: 002 Train loss: 48256.8420
====>             Validation loss: 16817.6148
Validation loss decreased (55159.387136 --> 16817.614819).  Saving model ...
====>             Test loss: 16762.3241
====> Epoch: 003 Train loss: 13528.2078
====> [MM-VAE] Time: 528.746s or 00:08:48


KeyboardInterrupt: 

In [None]:
recon_t, recon_v, recon_s = model.getDataLoaders(batch_size=1024,device=device) #姑息

In [None]:
for i,d in enumerate(recon_s): #データ取得
    if i == 0:
        data = d #get first mini-batch

In [None]:
print(data.size())

In [None]:
model.reconstruct(data, runPath, epoch=1, sampling=True, N=1)

In [None]:
model.analyse(data, runPath, epoch=1)

In [None]:
#VAE
for i, d in enumerate(recon_t):
    if i == 0:
        data = d
    else:
        data = torch.cat([data, d], dim=0)

In [9]:
#MMVAE
for i, d in enumerate(train_loader):
    if i == 0:
        data0 = d[0]
        data1 = d[1]
    else:
        data0 = torch.cat([data0, d[0]], dim=0)
        data1 = torch.cat([data1, d[1]], dim=0)
data = [data0, data1]

In [10]:
print(data[0].shape)
print(data[1].shape)

torch.Size([8350, 23758])
torch.Size([8350, 220258])


In [None]:
#full_t, full_s = model.getDataLoaders(batch_size=,device=device) #姑息

In [None]:
#for i,d in enumerate(full_s): #full データ取得
#    if i == 0:
#        full_data = d

In [None]:
model.visualize_latent(data, runPath, epoch=1, tsne=True, sampling = False)

In [None]:
#for 2d latent
lats = model.latents(data, sampling = False)
if len(lats) == 2:
    lat_rna = lats[0]
    lat_atac = lats[1]

In [None]:
if len(lats) == 2:
    plt.figure()
    plt.scatter(lat_rna[:,0],lat_rna[:,1],s=0.5)
    plt.savefig('{}/lat_rna.png'.format(runPath ), dpi=1000)
    plt.close('all')

    plt.figure()
    plt.scatter(lat_atac[:,0],lat_atac[:,1],s=0.5)
    plt.savefig('{}/lat_atac.png'.format(runPath ), dpi=1000)
    plt.close('all')

else: 
    plt.figure()
    plt.scatter(lats[:,0],lats[:,1],s=0.5)
    plt.savefig('{}/lat.png'.format(runPath ), dpi=1000)
    plt.close('all')

In [None]:
mean_lats = sum(lats)/len(lats)

plt.figure()
plt.scatter(mean_lats[:,0],mean_lats[:,1],s=0.5)
plt.savefig('{}/lat_mean.png'.format(runPath ), dpi=1000)
plt.close('all')

In [None]:
print(mean_lats)

In [37]:
len(test_loader.dataset)

1032

In [21]:
for i,d in enumerate(train_loader):
    if i==0:
        atac = d[1]

In [22]:
print(atac.shape)

torch.Size([128, 56340])


In [35]:
print(sum(atac==1))
print(sum(sum(atac==1)))
print(len(sum(atac==1)))

tensor([ 3,  6,  4,  ..., 19,  8,  6])
tensor(242201)
56340


In [26]:
128*56340

7211520

In [27]:
6969319/7211520

0.9664147086883209

In [36]:
242201/7211520

0.03358529131167909

NameError: name 'data' is not defined