In [1]:
import os
import shutil
import unittest
from catvae.trainer import LightningVAE
from catvae.sim import multinomial_bioms
from biom import Table
from biom.util import biom_open
import numpy as np
from pytorch_lightning import Trainer
import argparse
import seaborn as sns

In [2]:
import torch
# Display the installed PyTorch version so the run environment is recorded
# alongside the results.
torch.__version__

'1.6.0'

# Simulate data

In [3]:
# Seed NumPy's global RNG before simulating so the generated counts are
# reproducible from a fresh kernel.
np.random.seed(0)
sims = multinomial_bioms(k=4, D=10, N=500, M=100)
Y = sims['Y']

# Rows of Y are labelled as samples and columns as observations; the matrix
# is transposed below because Table takes observations on the rows.
samp_ids = [str(row) for row in range(Y.shape[0])]
obs_ids = [str(col) for col in range(Y.shape[1])]

# Cut the rows into tenths: the first eight tenths train the model, the
# ninth is the test split, and everything after it is the validation split.
parts = Y.shape[0] // 10
train_end = parts * 8
test_end = parts * 9
train = Table(Y[:train_end].T, obs_ids, samp_ids[:train_end])
test = Table(Y[train_end:test_end].T, obs_ids, samp_ids[train_end:test_end])
valid = Table(Y[test_end:].T, obs_ids, samp_ids[test_end:])

# NOTE(review): this binds `tree` to the entire simulation result rather
# than to a single entry of it — confirm which was intended.
tree = sims

# Persist each split as '<name>.biom', tagging the file with the split name.
for split_name, split_table in (('train', train), ('test', test), ('valid', valid)):
    with biom_open(f'{split_name}.biom', 'w') as f:
        split_table.to_hdf5(f, split_name)

# Train model

In [4]:
output_dir = 'output'

# Training options as flag -> value pairs, flattened into an argv-style list
# for parser.parse_args below (insertion order of the dict is preserved).
cli_options = {
    '--train-biom': 'train.biom',
    '--test-biom': 'test.biom',
    '--val-biom': 'valid.biom',
    '--output-directory': output_dir,
    '--epochs': '1024',
    '--batch-size': '50',
    '--num-workers': '10',
    '--scheduler': 'cosine',
    '--learning-rate': '1e-2',
    '--n-latent': '4',
    '--gpus': '1',
}
argv = [token for flag_and_value in cli_options.items()
        for token in flag_and_value]

# The model class contributes its own options; the two trainer-level flags
# are registered here on top of them.
parser = LightningVAE.add_model_specific_args(
    argparse.ArgumentParser(add_help=False))
parser.add_argument('--num-workers', type=int)
parser.add_argument('--gpus', type=int)
args = parser.parse_args(argv)

# Build the model from the parsed options and hand it the eigenvectors and
# eigenvalues produced by the simulation.
model = LightningVAE(args)
model.set_eigs(sims['eigvectors'], sims['eigs'])

# NOTE(review): the output saved with this cell ends in
# "RuntimeError: mat1 dim 1 must match mat2 dim 0" right after the
# 'Validation sanity check' progress bar — the shape mismatch is unresolved.
trainer_options = dict(
    max_epochs=args.epochs,
    gpus=args.gpus,
    check_val_every_n_epoch=10,  # validate once every 10 epochs
    fast_dev_run=False,
)
trainer = Trainer(**trainer_options)
trainer.fit(model)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type         | Params
---------------------------------------
0 | model | LinearCatVAE | 81    


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…

torch.Size([9, 4])
torch.Size([9, 4]) torch.Size([9])
torch.Size([9, 10]) torch.Size([50, 9])


RuntimeError: mat1 dim 1 must match mat2 dim 0

In [None]:
# Display the `gt_eigvectors` attribute of the model (presumably the
# eigenvectors handed to set_eigs in the training cell — TODO confirm).
model.gt_eigvectors

In [None]:
# Shell escape: list the contents of the current working directory.
!ls

In [None]:
# Load the TensorBoard notebook extension so the %tensorboard magic is available.
%load_ext tensorboard

In [None]:
# Open TensorBoard inline, reading event files from the lightning_logs directory.
%tensorboard --logdir lightning_logs

In [None]:
# Display the model object currently attached to the trainer.
trainer.model

# Evaluate the model

In [None]:
# Shell escape: list the checkpoint files of run `version_0` so the file name
# used when loading a checkpoint can be checked against what exists.
!ls lightning_logs/version_0/checkpoints

In [None]:
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint

# Restore a trained model from a saved checkpoint and move it to the GPU.
checkpoint_dir = 'lightning_logs/version_0/checkpoints'
# NOTE(review): the checkpoint file name is hard-coded; verify it against the
# files actually present in `checkpoint_dir`.
path = f'{checkpoint_dir}/epoch=399.ckpt'
# Load through LightningVAE — the class this notebook imports and trains.
# The previous `LightningCountVAE` is never imported here, so the cell
# failed with NameError before reaching the checkpoint.
model = LightningVAE.load_from_checkpoint(path).cuda()

In [None]:
# Display the validation table built in the simulation cell.
valid

In [None]:
import torch
# NOTE(review): every other project import in this notebook comes from
# `catvae`; confirm that `mavi` is really the package that should provide
# BiomDataset in this environment.
from mavi.dataset.biom import BiomDataset
from skbio.stats.composition import alr_inv, closure
# Wrap the validation table so individual samples can be fetched by index.
valid_dataset = BiomDataset(valid)

In [None]:
# For every validation sample collect:
#   counts     - the observed counts normalised with closure()
#   z          - the simulated latent vector of that sample
#   pred_z     - the model's 'qz_m' output for the sample
#   pred_probs - the model's 'px_mean' output mapped through alr_inv()
# (A stray `print(self...)` line pasted from elsewhere and the indentation of
# `z = []` made this cell fail before the loop; both are fixed here.)
z = []
pred_z = []
pred_probs = []
counts = []

# `valid` was built from rows `parts * 9` onward of the simulated matrix, so
# validation sample i corresponds to simulation row `parts * 9 + i`. The old
# offset `parts * 8` pointed at the test split instead.
valid_offset = parts * 9

for i in range(len(valid_dataset)):
    cnts, batch_idx = valid_dataset[i]
    counts.append(closure(cnts))
    # Add a leading batch dimension of size 1 and move the sample to the GPU.
    cnts = torch.Tensor(cnts).cuda().unsqueeze(0)
    # Add a pseudocount of 1 to every entry before inference.
    smoothed_cnts = cnts + 1
    res = model.model.inference(smoothed_cnts)
    z.append(sims['z'][valid_offset + i])
    pred_z.append(res['qz_m'].cpu().detach().numpy())
    pred_probs.append(alr_inv(res['px_mean'].cpu().detach().numpy()))

In [None]:
# Stack the per-sample results into 2-D arrays with one row per sample.
counts, pred_probs, pred_z = (
    np.vstack(per_sample) for per_sample in (counts, pred_probs, pred_z)
)

In [None]:
from scipy.spatial.distance import pdist

# Condensed pairwise-distance vectors between samples, first for the model's
# latents and then for the simulated ones (`z` is still a list of row
# vectors at this point; pdist converts it itself).
d_predz, dz = pdist(pred_z), pdist(z)

In [None]:
# Stack the collected simulated latent vectors into one array, one row per sample.
z = np.vstack(z)

In [None]:
# Heatmap of the model's latent means (rows follow the order of the validation loop).
sns.heatmap(pred_z)

In [None]:
# Heatmap of the simulated latent vectors, for visual comparison with pred_z.
sns.heatmap(z)

In [None]:
import matplotlib.pyplot as plt

# Each point is one pair of samples: distance between their model latents (x)
# against distance between their simulated latents (y).
plt.scatter(d_predz, dz)
# Label the axes so the figure can be read without the surrounding code.
plt.xlabel('pairwise distance, model latents (pred_z)')
plt.ylabel('pairwise distance, simulated latents (z)')

In [None]:
from scipy.stats import pearsonr
# Pearson correlation (and its p-value) between the two sets of pairwise
# latent distances.
pearsonr(d_predz, dz)

In [None]:
# Compare the geometry of the model's loadings with the simulated loadings
# via pairwise distances between their rows.
W = model.model.get_loadings()
d_estW = pdist(W)
dW = pdist(sims['W'])
plt.scatter(dW, d_estW)
# Label the axes so the figure can be read without the surrounding code.
plt.xlabel('pairwise distance, simulated loadings')
plt.ylabel('pairwise distance, model loadings')
print(pearsonr(dW, d_estW))

In [None]:
# Project the simulated latents through the simulated loadings (lam) and
# through the model's loadings (pred_lam), then compare them entry by entry.
# `z` already holds the simulated latents of exactly the samples that were
# evaluated, so it is reused here instead of re-slicing sims['z'] with a
# hard-coded offset and a hard-coded length of 50.
lam = np.ravel(z @ sims['W'].T)
# NOTE(review): pred_lam is built from the simulated `z`, not from `pred_z`;
# confirm that only the loadings are meant to differ between the two sides.
pred_lam = np.ravel(z @ W.T)
plt.scatter(lam, pred_lam)
# Label the axes so the figure can be read without the surrounding code.
plt.xlabel('z @ simulated loadings')
plt.ylabel('z @ model loadings')
print(pearsonr(lam, pred_lam))

In [None]:
# Sanity check: show the shapes of the arrays used in the comparison cells.
z.shape, W.shape, lam.shape, sims['z'].shape, sims['W'].shape

In [None]:
from scipy.stats import entropy

# One value per sample: entropy() with two arguments computes the
# Kullback-Leibler divergence of the observed proportions from the
# predicted probabilities.
ens = [entropy(counts[row], pred_probs[row]) for row in range(counts.shape[0])]

In [None]:
# Distribution of the per-sample divergences, using 20 histogram bins.
# NOTE(review): `distplot` is deprecated in newer seaborn releases in favour
# of `histplot`/`displot`; switch once the installed version is confirmed.
sns.distplot(ens, bins=20)

In [None]:
# Heatmap of the predicted probabilities, one row per validation sample.
sns.heatmap(pred_probs)

In [None]:
# Heatmap of the observed proportions (closure of the counts), for visual
# comparison with the predicted probabilities.
sns.heatmap(counts)

In [None]:
# Observed proportions against predicted probabilities, every matrix entry
# as one point, on log-log axes.
# NOTE(review): both inputs are proportions/probabilities, so adding 1 moves
# every value into [1, 2] before the log scaling — confirm the shift is
# intended rather than a leftover from plotting raw counts.
plt.scatter(counts.ravel()+1, pred_probs.ravel()+1)
plt.xscale('log')
plt.yscale('log')
# Label the axes so the figure can be read without the surrounding code.
plt.xlabel('observed proportion + 1')
plt.ylabel('predicted probability + 1')

In [None]:
# Pearson correlation between the flattened observed proportions and the
# flattened predicted probabilities (the +1 shift does not change it).
pearsonr(counts.ravel()+1, pred_probs.ravel()+1)

In [None]:
# Exponentiate the encoder's `variational_logvars` (presumably log-variances,
# which would make this the variances themselves — TODO confirm).
torch.exp(model.model.encoder.variational_logvars)

In [None]:
# Display the argument parser built in the training cell.
parser