In [1]:
import os
import shutil
import unittest
from catvae.trainer import MultBatchVAE, BiomDataModule
from catvae.sim import multinomial_batch_bioms
from biom import Table
from biom.util import biom_open
import numpy as np
from pytorch_lightning import Trainer
import torch
import argparse
import seaborn as sns
import pandas as pd

from scipy.stats import pearsonr
import matplotlib.pyplot as plt
from scipy.spatial.distance import pdist

from pytorch_lightning.profiler import AdvancedProfiler

In [3]:
# Simulate batch-affected multinomial counts, split them into
# train / test / validation BIOM tables, and write every input file that
# the data module and model below read from disk.

# Seed both RNGs: NumPy alone leaves torch-side randomness unpinned.
np.random.seed(0)
torch.manual_seed(0)

k = 20   # passed as n_latent when the model is built
C = 3    # passed as n_batches when the model is built
D = 100  # passed as n_input when the model is built
sims = multinomial_batch_bioms(k=k, D=D, N=2000, M=1e6, C=C)
Y = sims['Y']

# Rows of Y are samples. Split them 80 / 10 / 10 in their existing order;
# any remainder from the integer division lands in the validation split.
parts = Y.shape[0] // 10
train_end, test_end = parts * 8, parts * 9
samp_ids = list(map(str, range(Y.shape[0])))
obs_ids = list(map(str, range(Y.shape[1])))

# Each slice is transposed before being handed to Table, which is given
# observation ids first and sample ids second.
train = Table(Y[:train_end].T, obs_ids, samp_ids[:train_end])
test = Table(Y[train_end:test_end].T, obs_ids, samp_ids[train_end:test_end])
valid = Table(Y[test_end:].T, obs_ids, samp_ids[test_end:])

with biom_open('train.biom', 'w') as f:
    train.to_hdf5(f, 'train')
with biom_open('test.biom', 'w') as f:
    test.to_hdf5(f, 'test')
with biom_open('valid.biom', 'w') as f:
    valid.to_hdf5(f, 'valid')

# Sample metadata: one batch label per sample, indexed by sample id.
md = pd.DataFrame({'batch_category': sims['batch_idx']}, index=samp_ids)
md.index.name = 'sampleid'
md.to_csv('metadata.txt', sep='\t')

# Batch priors, read back by the model through `batch_prior=`.
batch_priors = pd.Series(sims['alphaILR'])
batch_priors.to_csv('batch_priors.txt', sep='\t')

# Bind the tree itself (not the whole `sims` dict) and write it out; the
# model reads this file through `basis=`.
tree = sims['tree']
tree.write('basis.nwk')

'basis.nwk'

Run the batch-effect removal VAE

In [4]:
# Build the data module and the batch-aware VAE from the files written by
# the simulation cell.
output_dir = 'output'

biom_paths = ('train.biom', 'test.biom', 'valid.biom')
dm = BiomDataModule(
    *biom_paths,
    metadata='metadata.txt',
    batch_category='batch_category',
    batch_size=50,
)

# Collect the model settings in one place so they are easy to scan and tune.
vae_options = dict(
    n_input=D,
    n_latent=k,
    n_hidden=16,
    n_batches=C,
    basis='basis.nwk',
    batch_prior='batch_priors.txt',
    dropout=0.5,
    bias=True,
    batch_norm=True,
    encoder_depth=1,
    learning_rate=0.1,
    scheduler='cosine',
    transform='pseudocount',
)
model = MultBatchVAE(**vae_options)

# Show the module tree so the submodule names are visible to the reader.
print(model)

MultBatchVAE(
  (vae): LinearBatchVAE(
    (encoder): Encoder(
      (encoder): Linear(in_features=99, out_features=20, bias=True)
    )
    (decoder): ParametrizedLinear(
      in_features=20, out_features=99, bias=True
      (parametrizations): ModuleDict(
        (weight): GrassmannianTall(n=99, k=20, triv=expm)
      )
    )
    (beta): Embedding(3, 99)
  )
)


In [7]:
# Fit the model on CPU for a few epochs, validating after every epoch.
# Profiling and automatic batch-size scaling are intentionally left off.
trainer_options = dict(
    max_epochs=3,
    gpus=0,
    check_val_every_n_epoch=1,
    fast_dev_run=False,
)
trainer = Trainer(**trainer_options)
trainer.fit(model, dm)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores

  | Name | Type           | Params
----------------------------------------
0 | vae  | LinearBatchVAE | 4.5 K 
----------------------------------------
4.5 K     Trainable params
0         Non-trainable params
4.5 K     Total params
0.018     Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]



Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]



In [None]:
# Load the TensorBoard notebook extension and open it on the
# `lightning_logs` directory (presumably where the Trainer run above
# wrote its logs — confirm if a custom logger is configured).
%load_ext tensorboard
%tensorboard --logdir lightning_logs

In [None]:
# Compare the learned decoder weights against the simulated loadings via
# their pairwise row distances.
# The printed module tree registers the inner network as `vae`, so the
# decoder is reached through `model.vae`.
W = model.vae.decoder.weight.detach().cpu().numpy()
d_estW = pdist(W)

# Simulated loadings, scaled by the square root of the simulated eigenvalues.
simW = sims['W'] / np.sqrt(sims['eigs'])
dW = pdist(simW)

# x carries the simulated (actual) values and y the estimated (predicted)
# ones, so the axis labels follow that order. Both are `pdist` outputs.
plt.scatter(dW, d_estW, s=1)
plt.plot(np.linspace(0.3, 1), np.linspace(0.3, 1), 'r')  # identity reference
plt.xlabel('Actual correlations')
plt.ylabel('Predicted correlations')

print(pearsonr(dW, d_estW))

In [None]:
# Encode every simulated sample, passing its batch index, and compare
# pairwise distances in the learned latent space with those of the
# simulated latent coordinates.
x = torch.Tensor(sims['Y']).float()
b = torch.Tensor(sims['batch_idx']).long()

# The printed module tree registers the inner network as `vae`.
z = model.vae.encode(x, b)

dsimz = pdist(sims['z'])
dz = pdist(z.detach().cpu().numpy())

# x axis: distances from the model; y axis: distances from the simulation.
plt.scatter(dz, dsimz, s=1)
plt.xlabel('Predicted distance z')
plt.ylabel('Actual distance z')
print(pearsonr(dz, dsimz))

In [None]:
from sklearn.metrics import accuracy_score

# Score how often the model's batch prediction matches the simulated batch
# label. Reuses `x` from the preceding cell.
# NOTE(review): the printed model summary lists no `discriminator`
# submodule — confirm this attribute exists on the current MultBatchVAE.
batch_ids = torch.Tensor(sims['batch_idx']).long()
batch_pred = model.discriminator(x)

# The highest-scoring column of each prediction row is the predicted batch.
pred_labels = batch_pred.detach().cpu().numpy().argmax(axis=1)
true_labels = batch_ids.detach().cpu().numpy()

acc = accuracy_score(pred_labels, true_labels)
print(acc)

In [None]:
# Map the simulated counts to latent space through the model's public
# `to_latent` entry point and compare pairwise distances with the
# simulated latent coordinates.
# NOTE(review): no batch indices are passed here, unlike the `encode`
# cell — confirm `to_latent` does not require them.
x = torch.Tensor(sims['Y']).float()
z = model.to_latent(x)

z_est = z.detach().cpu().numpy()
dz = pdist(z_est)
dsimz = pdist(sims['z'])

# x axis: distances from the model; y axis: distances from the simulation.
plt.scatter(dz, dsimz, s=1)
plt.xlabel('Predicted distance z')
plt.ylabel('Actual distance z')

print(pearsonr(dz, dsimz))

In [None]:
# Heatmap of the latent coordinates for the training-split samples (the
# first `parts * 8` rows), ordered by batch label.
train_batches = md['batch_category'].values[:parts * 8]
i = np.argsort(train_batches)
sns.heatmap(z[i].detach().cpu().numpy())

In [None]:
# Compare the learned batch-effect embedding with the simulated batch
# effects via pairwise row distances after transposing both.
# The printed module tree shows `beta` under the `vae` submodule.
B = model.vae.beta.weight.detach().cpu().numpy().T
d_estB = pdist(B)
simB = sims['B'].T
dB = pdist(simB)

# x carries the simulated (actual) values and y the estimated (predicted)
# ones, so the axis labels follow that order. Both are `pdist` outputs.
plt.scatter(dB, d_estB, s=1)
plt.xlabel('Actual batch correlations')
plt.ylabel('Predicted batch correlations')

print(pearsonr(dB, d_estB))