In [1]:
import os
import shutil
import unittest
from catvae.trainer import LightningCountVAE
from catvae.sim import multinomial_bioms
from biom import Table
from biom.util import biom_open
import numpy as np
from pytorch_lightning import Trainer
import argparse
import seaborn as sns

In [2]:
# Show the installed PyTorch version so the outputs below can be tied to it.
import torch
torch.__version__

'1.6.0'

# Simulate data

In [3]:
# Simulate count data with a fixed seed, then cut the samples (rows of Y)
# into three contiguous groups: the first 8 tenths for training, the next
# tenth for testing and everything that remains for validation.
np.random.seed(0)
sims = multinomial_bioms(k=4, D=30, N=250, M=10000)
Y = sims['Y']
parts = Y.shape[0] // 10  # number of samples in one tenth
samp_ids = [str(row) for row in range(Y.shape[0])]
obs_ids = [str(col) for col in range(Y.shape[1])]

train_rows = slice(None, parts * 8)
test_rows = slice(parts * 8, parts * 9)
valid_rows = slice(parts * 9, None)

# Table receives the transposed block of Y together with the observation ids
# and the sample ids of the selected rows.
train = Table(Y[train_rows].T, obs_ids, samp_ids[train_rows])
test = Table(Y[test_rows].T, obs_ids, samp_ids[test_rows])
valid = Table(Y[valid_rows].T, obs_ids, samp_ids[valid_rows])
tree = sims  # alias of the full simulation dict (the tree itself is sims['tree'])

# Persist every split as '<name>.biom', using <name> as the generated-by tag.
for split_name, split_table in (('train', train), ('test', test), ('valid', valid)):
    with biom_open(f'{split_name}.biom', 'w') as f:
        split_table.to_hdf5(f, split_name)

In [4]:
# Write the simulated tree to 'basis.nwk'; the training cell passes this path as --basis-file.
sims['tree'].write('basis.nwk')

'basis.nwk'

# Train model

In [5]:
# Configure and fit the count VAE on the simulated data.
#
# The test and validation options point at the held-out BIOM files written by
# the simulation cell. Previously all three options pointed at 'train.biom',
# so validation/test metrics were computed on the training samples and
# 'test.biom' / 'valid.biom' were never read.
output_dir = 'output'
argv = [
    '--train-biom', 'train.biom',
    '--test-biom', 'test.biom',
    '--val-biom', 'valid.biom',
    '--basis-file', 'basis.nwk',
    '--output-directory', output_dir,
    '--epochs', '30000',
    '--batch-size', '200',
    '--num-workers', '10',
    '--scheduler', 'cosine',
    '--learning-rate', '1e-1',
    '--n-latent', '4',
    '--gpus', '1'
]

# Model options come from the Lightning module itself; the two runtime options
# (--num-workers, --gpus) are registered here.
parser = argparse.ArgumentParser(add_help=False)
parser = LightningCountVAE.add_model_specific_args(parser)
parser.add_argument('--num-workers', type=int)
parser.add_argument('--gpus', type=int)

# `args` is the parsed Namespace only; the raw option list lives in `argv`
# so one name is not reused for two different kinds of object.
args = parser.parse_args(argv)
model = LightningCountVAE(args)
# Hand the simulated eigenvectors / eigenvalues to the model before fitting.
model.set_eigs(sims['eigvectors'], sims['eigs'])

# Optional knobs left disabled: profiler=profiler, auto_scale_batch_size='power'
trainer = Trainer(
    max_epochs=args.epochs,
    gpus=args.gpus,
    check_val_every_n_epoch=10,
    fast_dev_run=False,
)
trainer.fit(model)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type         | Params
---------------------------------------
0 | model | LinearCatVAE | 6 K   


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…

tensor([[-6.6151e-04, -1.2041e-03,  1.8023e-03,  9.4212e-04],
        [ 1.1088e-03, -3.1966e-04,  6.2828e-04, -1.1048e-03],
        [-4.6673e-04,  1.1367e-03, -4.8688e-04,  5.9106e-04],
        [-9.5821e-04, -7.2369e-05,  2.0285e-04,  4.9363e-05],
        [-1.6317e-04, -1.0512e-04,  5.1375e-04,  6.3205e-04],
        [ 1.5434e-03, -3.0442e-04,  9.0776e-04, -8.7771e-04],
        [-2.7815e-04, -1.0713e-03, -6.0843e-04,  9.4978e-04],
        [ 1.8745e-04, -9.0518e-04, -8.7342e-04,  4.7239e-04],
        [-1.5246e-03,  1.5120e-04,  6.3772e-04,  3.1237e-04],
        [-9.2798e-04,  6.8641e-04, -1.2747e-03,  4.6764e-04],
        [ 8.1069e-04, -4.7660e-04,  5.9866e-04, -2.0696e-04],
        [-5.3539e-04,  1.7420e-03,  7.5694e-04,  7.6580e-04],
        [ 1.0877e-04, -1.6315e-03, -7.0067e-04, -7.6027e-04],
        [ 7.6635e-04,  1.2570e-05,  8.8399e-04, -2.2097e-03],
        [ 4.3665e-05,  1.2349e-03,  8.5998e-06, -9.9088e-04],
        [ 1.4860e-03,  1.0444e-03, -1.1456e-04,  7.4403e-04],
        

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

tensor([[-6.6151e-04, -1.2041e-03,  1.8023e-03,  9.4212e-04],
        [ 1.1088e-03, -3.1966e-04,  6.2828e-04, -1.1048e-03],
        [-4.6673e-04,  1.1367e-03, -4.8688e-04,  5.9106e-04],
        [-9.5821e-04, -7.2369e-05,  2.0285e-04,  4.9363e-05],
        [-1.6317e-04, -1.0512e-04,  5.1375e-04,  6.3205e-04],
        [ 1.5434e-03, -3.0442e-04,  9.0776e-04, -8.7771e-04],
        [-2.7815e-04, -1.0713e-03, -6.0843e-04,  9.4978e-04],
        [ 1.8745e-04, -9.0518e-04, -8.7342e-04,  4.7239e-04],
        [-1.5246e-03,  1.5120e-04,  6.3772e-04,  3.1237e-04],
        [-9.2798e-04,  6.8641e-04, -1.2747e-03,  4.6764e-04],
        [ 8.1069e-04, -4.7660e-04,  5.9866e-04, -2.0696e-04],
        [-5.3539e-04,  1.7420e-03,  7.5694e-04,  7.6580e-04],
        [ 1.0877e-04, -1.6315e-03, -7.0067e-04, -7.6027e-04],
        [ 7.6635e-04,  1.2570e-05,  8.8399e-04, -2.2097e-03],
        [ 4.3665e-05,  1.2349e-03,  8.5998e-06, -9.9088e-04],
        [ 1.4860e-03,  1.0444e-03, -1.1456e-04,  7.4403e-04],
        

RuntimeError: sparse tensors do not have strides
Exception raised from strides at /opt/conda/conda-bld/pytorch_1595629395347/work/aten/src/ATen/SparseTensorImpl.cpp:52 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x4d (0x7f3d7420877d in /home/juermieboop/miniconda3/envs/catvae/lib/python3.8/site-packages/torch/lib/libc10.so)
frame #1: at::SparseTensorImpl::strides() const + 0xb2 (0x7f3d52e88402 in /home/juermieboop/miniconda3/envs/catvae/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so)
frame #2: <unknown function> + 0xb3737b (0x7f3d531f037b in /home/juermieboop/miniconda3/envs/catvae/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so)
frame #3: <unknown function> + 0xb3d655 (0x7f3d531f6655 in /home/juermieboop/miniconda3/envs/catvae/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so)
frame #4: <unknown function> + 0xb3da80 (0x7f3d531f6a80 in /home/juermieboop/miniconda3/envs/catvae/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so)
frame #5: at::native::sum_out(at::Tensor&, at::Tensor const&, c10::ArrayRef<long>, bool, c10::optional<c10::ScalarType>) + 0x8f (0x7f3d531f6b6f in /home/juermieboop/miniconda3/envs/catvae/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so)
frame #6: at::native::sum(at::Tensor const&, c10::ArrayRef<long>, bool, c10::optional<c10::ScalarType>) + 0x4b (0x7f3d531f71fb in /home/juermieboop/miniconda3/envs/catvae/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so)
frame #7: at::native::sum(at::Tensor const&, c10::optional<c10::ScalarType>) + 0x38 (0x7f3d531f72c8 in /home/juermieboop/miniconda3/envs/catvae/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so)
frame #8: <unknown function> + 0xff5633 (0x7f3d536ae633 in /home/juermieboop/miniconda3/envs/catvae/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so)
frame #9: <unknown function> + 0x7f67be (0x7f3d52eaf7be in /home/juermieboop/miniconda3/envs/catvae/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so)
frame #10: at::sum(at::Tensor const&, c10::optional<c10::ScalarType>) + 0xf8 (0x7f3d535e70a8 in /home/juermieboop/miniconda3/envs/catvae/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so)
frame #11: <unknown function> + 0x2bb396b (0x7f3d5526c96b in /home/juermieboop/miniconda3/envs/catvae/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so)
frame #12: <unknown function> + 0x7f67be (0x7f3d52eaf7be in /home/juermieboop/miniconda3/envs/catvae/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so)
frame #13: at::Tensor::sum(c10::optional<c10::ScalarType>) const + 0xf8 (0x7f3d53758818 in /home/juermieboop/miniconda3/envs/catvae/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so)
frame #14: <unknown function> + 0x30c63e0 (0x7f3d5577f3e0 in /home/juermieboop/miniconda3/envs/catvae/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so)
frame #15: torch::autograd::Engine::evaluate_function(std::shared_ptr<torch::autograd::GraphTask>&, torch::autograd::Node*, torch::autograd::InputBuffer&, std::shared_ptr<torch::autograd::ReadyQueue> const&) + 0x3fd (0x7f3d5578485d in /home/juermieboop/miniconda3/envs/catvae/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so)
frame #16: torch::autograd::Engine::thread_main(std::shared_ptr<torch::autograd::GraphTask> const&) + 0x451 (0x7f3d55786401 in /home/juermieboop/miniconda3/envs/catvae/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so)
frame #17: torch::autograd::Engine::thread_init(int, std::shared_ptr<torch::autograd::ReadyQueue> const&, bool) + 0x89 (0x7f3d5577e579 in /home/juermieboop/miniconda3/envs/catvae/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so)
frame #18: torch::autograd::python::PythonEngine::thread_init(int, std::shared_ptr<torch::autograd::ReadyQueue> const&, bool) + 0x4a (0x7f3d5969e1ba in /home/juermieboop/miniconda3/envs/catvae/lib/python3.8/site-packages/torch/lib/libtorch_python.so)
frame #19: <unknown function> + 0xc9067 (0x7f3d922f2067 in /home/juermieboop/miniconda3/envs/catvae/lib/python3.8/site-packages/zmq/backend/cython/../../../../.././libstdc++.so.6)
frame #20: <unknown function> + 0x76db (0x7f3d94fa86db in /lib/x86_64-linux-gnu/libpthread.so.0)
frame #21: clone + 0x3f (0x7f3d94cd1a3f in /lib/x86_64-linux-gnu/libc.so.6)


In [None]:
# Scratch check: back-propagate through a (sparse @ dense) product and look at
# the gradient that lands on the sparse operand.
# NOTE(review): this looks like a minimal repro attempt for the
# "sparse tensors do not have strides" RuntimeError raised by trainer.fit — confirm.
a = torch.randn(2, 3).to_sparse().requires_grad_(True)
b = torch.randn(3, 2, requires_grad=True)
# Alternative entry point for the same product, left disabled:
#y = torch.sparse.mm(a, b)
y = a @ b
y.sum().backward()  
a.grad

In [None]:
# Load the TensorBoard notebook extension (needed before using the %tensorboard magic).
%load_ext tensorboard

In [None]:
# Open TensorBoard inline, pointed at the ./lightning_logs directory.
%tensorboard --logdir lightning_logs

In [None]:
# Display the module attached to the trainer (its repr is the cell output).
trainer.model

# Evaluate the model

In [None]:
# List the checkpoint files of run `version_1`; the next cell loads one of them by file name.
!ls lightning_logs/version_1/checkpoints

In [None]:
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint

# Restore the trained module from a hard-coded checkpoint of run `version_1`
# and move it to the GPU for the evaluation cells that follow.
checkpoint_dir = 'lightning_logs/version_1/checkpoints'
path = '{}/epoch=139.ckpt'.format(checkpoint_dir)
model = LightningCountVAE.load_from_checkpoint(path)
model = model.cuda()

In [None]:
# Compare the learned decoder weights with the simulated loadings through
# their pairwise row distances, and report the Pearson correlation of the two
# distance vectors.
#W = model.model.get_loadings()
from scipy.stats import pearsonr
import matplotlib.pyplot as plt
from scipy.spatial.distance import pdist

W = model.model.decoder.weight.detach().cpu().numpy()

d_estW = pdist(W)          # distances between rows of the estimated weights
dW = pdist(sims['W'])      # distances between rows of the simulated loadings
plt.scatter(dW, d_estW)
plt.plot(np.linspace(0, 4), np.linspace(0, 4), 'r')  # y = x reference line
# The x axis carries the simulated (actual) values and the y axis the model's
# estimate; both are pdist distances, not correlations. The labels used to be
# swapped and mis-named.
plt.xlabel('Actual pairwise distances')
plt.ylabel('Predicted pairwise distances')

print(pearsonr(dW, d_estW))

In [None]:
# Sum of squares of every column of W, then divide each column by the square
# root of its own sum so the columns of Weig have unit length.
eigvals = (W * W).sum(axis=0)
column_norms = np.sqrt(eigvals)
Weig = W / column_norms

In [None]:
# Gram matrix of the columns of Weig: off-diagonal entries show how far the
# normalised columns are from being mutually orthogonal.
Weig.T @ Weig

In [None]:
# Full SVD of W, keeping only the leading k factors where k is the number of
# columns of W.
k = W.shape[1]
Wu, Ws, Wv = np.linalg.svd(W)
Wu, Ws, Wv = Wu[:, :k], Ws[:k], Wv[:k, :]

In [None]:
# Gram matrix of the columns of the un-normalised decoder weights W.
W.T @ W

In [None]:
import torch
from catvae.dataset.biom import BiomDataset
from skbio.stats.composition import alr_inv, closure
# Wrap the held-out `valid` table in the project's dataset class; the next
# cell indexes it one sample at a time.
valid_dataset = BiomDataset(valid)

In [None]:
# Push every validation sample through the trained model and collect, per
# sample: the observed proportions (closure of the counts), the simulated
# latent vector, the encoder mean `qz_m`, and the predicted proportions
# (alr_inv of `px_mean`).
decoder_np = model.model.decoder.weight.detach().cpu().numpy()
z = []
pred_z = []
pred_probs = []
counts = []

# `valid` was built from rows parts * 9 onward of Y, so the matching
# ground-truth latents start at that same offset. The previous offset,
# parts * 8, addressed the rows of the *test* split instead.
# Assumes sims['z'] is row-aligned with sims['Y'] — TODO confirm.
valid_start = parts * 9

# Inference only: no autograd graph is needed for these forward passes.
with torch.no_grad():
    for i in range(len(valid_dataset)):
        cnts, batch_idx = valid_dataset[i]
        counts.append(closure(cnts))
        # Add a leading batch dimension of 1 and move the sample to the GPU.
        cnts = torch.Tensor(cnts).cuda().unsqueeze(0)
        res = model.forward(cnts)
        z.append(sims['z'][valid_start + i])
        pred_z.append(res['qz_m'].cpu().detach().numpy())
        pred_probs.append(alr_inv(res['px_mean'].cpu().detach().numpy()))

In [None]:
# Turn the three per-sample lists gathered by the inference loop into 2-D
# arrays with one row per validation sample.
counts, pred_probs, pred_z = map(np.vstack, (counts, pred_probs, pred_z))

In [None]:
from scipy.spatial.distance import pdist

# Pairwise sample-to-sample distances, once for the simulated latents (`z`)
# and once for the encoder means (`pred_z`).
dz = pdist(z)
d_predz = pdist(pred_z)

In [None]:
# Stack the list of per-sample simulated latent vectors into one 2-D array.
z = np.vstack(z)

In [None]:
# Heatmap of the encoder means (rows: validation samples, columns: latent dimensions).
sns.heatmap(pred_z)

In [None]:
# Heatmap of the simulated latent vectors, for visual comparison with the heatmap of pred_z.
sns.heatmap(z)

In [None]:
# Re-run the loading comparison: scatter the pairwise row distances of the
# simulated loadings against those of the decoder weights and print their
# Pearson correlation.
#W = model.model.get_loadings()
from scipy.stats import pearsonr
import matplotlib.pyplot as plt
from scipy.spatial.distance import pdist

W = model.model.decoder.weight.detach().cpu().numpy()

dW = pdist(sims['W'])
d_estW = pdist(W)
plt.scatter(dW, d_estW)
loading_corr = pearsonr(dW, d_estW)
print(loading_corr)

In [None]:
# Compare the simulated linear predictor (true z times true W) with the
# model's reconstruction (encoder means times decoder weights) on the
# validation samples.
#
# Two fixes relative to the earlier version of this cell:
#  * the true latents are sliced from the validation rows (parts * 9 onward,
#    matching how `valid` was built) and limited to the number of predicted
#    samples, instead of a hard-coded 50 rows starting at parts * 8;
#  * the prediction uses `pred_z` (model output) rather than the simulated
#    `z`, which mixed ground-truth latents with estimated weights.
valid_start = parts * 9
n_valid = len(pred_z)
lam = np.ravel(sims['z'][valid_start : valid_start + n_valid] @ sims['W'].T)
pred_lam = np.ravel(pred_z @ W.T)
plt.scatter(lam, pred_lam)
print(pearsonr(lam, pred_lam))

In [None]:
# Debug aid: show the shapes of the arrays used in the linear-predictor comparison.
z.shape, W.shape, lam.shape, sims['z'].shape, sims['W'].shape

In [None]:
from scipy.stats import entropy

# One value per validation sample: scipy's two-argument entropy (relative
# entropy) of the observed proportions against the predicted proportions.
ens = [
    entropy(counts[row], pred_probs[row])
    for row in range(counts.shape[0])
]

In [None]:
# Distribution of the per-sample relative entropies, drawn with 20 histogram bins.
# NOTE(review): seaborn has deprecated `distplot` in newer releases — confirm the pinned version.
sns.distplot(ens, bins=20)

In [None]:
# Heatmap of the predicted proportions (one row per validation sample).
sns.heatmap(pred_probs)

In [None]:
# Heatmap of the observed proportions (closure of the counts), to compare with the heatmap of pred_probs.
sns.heatmap(counts)

In [None]:
# Scatter every observed proportion against its predicted proportion on
# log-log axes; both are shifted by +1 before plotting.
observed_shifted = counts.ravel() + 1
predicted_shifted = pred_probs.ravel() + 1
plt.scatter(observed_shifted, predicted_shifted)
plt.xscale('log')
plt.yscale('log')

In [None]:
# Pearson correlation between the flattened observed and predicted proportions
# (both shifted by +1, which does not change the correlation).
pearsonr(counts.ravel()+1, pred_probs.ravel()+1)

In [None]:
# Exponentiate the encoder's `variational_logvars` attribute
# (presumably log-variances, going by the name — verify in the model code).
torch.exp(model.model.encoder.variational_logvars)

In [None]:
# Display the argparse parser built in the training cell (debug aid).
parser

In [None]:
# scratch work with least squares

from gneiss.balances import _balance_basis
from gneiss.cluster import random_linkage

d = 100
# First element returned by _balance_basis for a random linkage over d tips.
basis = _balance_basis(random_linkage(d))[0]

# A = I - (1/d) * ones: the identity minus a constant matrix with entries 1/d.
Id = np.eye(d)
dd = np.ones((d, d)) / d
A = Id - dd

# Random right-hand side, then the least-squares solution of A x = b.
b = np.random.randn(d)
lstsq_result = np.linalg.lstsq(A, b, rcond=None)
x, err, rank, s = lstsq_result