In [None]:
import torch
import torch.nn as nn

import collections
from typing import List, Tuple

import pyro
from pyro.infer import SVI, Trace_ELBO, TraceGraph_ELBO
from pyro.optim import CosineAnnealingWarmRestarts

import numpy as np
import matplotlib.pyplot as plt
import scipy.special as ssp

import umap

from torch.optim import SGD
from torch.distributions.utils import probs_to_logits
from torch.utils.data import DataLoader, TensorDataset

In [None]:
import drvish.util.plot as drplt

from drvish.data import split_dataset, split_labeled_dataset
from drvish.train import train_until_plateau, evaluate, AggMo, train, evaluate
from drvish.util import build_dr_dataset

from drvish.models import NBVAE, DRNBVAE

from drvish.models.modules import LinearMultiBias

In [None]:
# Sizes of the simulated dataset; every value below is forwarded to
# build_dr_dataset and reused when constructing the models.
n_classes, n_cells_per_class = 8, 512
n_features, n_latent = 128, 8
n_drugs, n_conditions = 2, 8

In [None]:
# Simulate the full dataset from the size constants defined above.
dataset_kw = dict(
    n_classes=n_classes,
    n_latent=n_latent,
    n_cells_per_class=n_cells_per_class,
    n_features=n_features,
    n_drugs=n_drugs,
    n_conditions=n_conditions,
    library_kw=dict(loc=5.5, scale=0.5),
    class_kw=dict(scale=2.0, sparsity=0.5),
)
(exp, classes, progs, z,
 doses, drs, lib_size, umis) = build_dr_dataset(**dataset_kw)

In [None]:
# One drug-response panel per drug, laid out in a two-column grid.
response_panels = [
    drplt.drug_response(response.reshape(-1, n_conditions), dose, classes)
    for response, dose in zip(drs, doses)
]
drplt.make_grid(*response_panels, n_cols=2)

In [None]:
# Hold out everything at index >= 4 on axis 1 (class labels >= 4) as the
# `te_*` set and keep the rest under the original names.
# NOTE: this rebinds umis/exp/classes/drs in place, so the cell is not
# re-runnable without first re-running the dataset-building cell.
te_umis, umis = umis[:, 4:, :], umis[:, :4, :]
te_exp, exp = exp[:, 4:, :], exp[:, :4, :]
te_classes, classes = classes[classes >= 4], classes[classes < 4]
te_drs, drs = [d[:, 4:, :] for d in drs], [d[:, :4, :] for d in drs]

In [None]:
# Flatten the retained counts into a (4 * n_cells_per_class, n_features) matrix.
flat_shape = (4 * n_cells_per_class, n_features)
umis_flat = umis.reshape(flat_shape)

# Per drug: logit of the responses averaged over axis 0; drugs are stacked on dim 2.
logit_means_per_drug = [torch.tensor(ssp.logit(d).mean(0)) for d in drs]
dr_means = torch.stack(logit_means_per_drug, dim=2)

In [None]:
# Unlabeled train/validation loaders and a plain NBVAE.
# NOTE(review): tr_dl, val_dl and nbvae are all rebound by the following cells
# (split_labeled_dataset / DRNBVAE) before the SVI object is built, so this
# cell looks like a leftover -- confirm before removing.
tr_dl, val_dl = split_dataset(
    torch.tensor(umis_flat, dtype=torch.float), batch_size=128, train_p=0.875
)

pyro.clear_param_store()
nbvae = NBVAE(n_input=n_features, n_latent=16, layers=[256, 256])

In [None]:
# Labeled loaders: counts plus class labels, with the mean logit responses as target.
labeled_split_kw = dict(
    labels=classes,
    target=dr_means,
    batch_size=128,
    train_p=0.875,
)
tr_dl, val_dl = split_labeled_dataset(
    torch.tensor(umis_flat, dtype=torch.float), **labeled_split_kw
)

In [None]:
# Start from an empty Pyro parameter store, then build the drug-response model.
pyro.clear_param_store()

drnbvae_kw = dict(
    n_input=n_features,
    n_classes=n_classes - 4,  # only the classes kept for training
    n_drugs=n_drugs,
    n_conditions=n_conditions,
    n_latent=16,
    layers=[256, 256],
    lam_scale=1.0,
    bias_scale=1.0,
)
nbvae = DRNBVAE(**drnbvae_kw)

In [None]:
# Optimizer + learning-rate schedule settings handed to the Pyro scheduler wrapper.
schedule_config = {
    "optimizer": AggMo,
    "T_0": 10,
    "eta_min": 1e-6,
    "optim_args": {"lr": 5e-4, "betas": [0.0, 0.9, 0.99], "nesterov": True},
}
clip_config = {"clip_norm": 20.0}

scheduler = CosineAnnealingWarmRestarts(schedule_config, clip_config)
svi = SVI(nbvae.model, nbvae.guide, scheduler, loss=TraceGraph_ELBO())

In [None]:
# First training run; keeps the per-iteration train/validation loss histories.
train_loss, val_loss = train_until_plateau(
    svi,
    scheduler,
    tr_dl,
    val_dl,
    verbose=True,
)

In [None]:
from drvish.train import cos_annealing_factor

In [None]:
def caf(e):
    """Return ``cos_annealing_factor(e % 10, 10) * 2.0``.

    The ``e % 10`` wraps the step index, so the factor repeats with a period
    of 10 steps; the result is scaled by 2.
    """
    # NOTE(review): not referenced by any other visible cell -- confirm it is still needed.
    return cos_annealing_factor(e % 10, 10) * 2.0

In [None]:
# Continue training the same SVI object for at least five more cycles.
train_loss2, val_loss2 = train_until_plateau(
    svi,
    scheduler,
    tr_dl,
    val_dl,
    verbose=True,
    min_cycles=5,
)

# Append the new histories to the ones from the first run (in place).
for history, continuation in ((train_loss, train_loss2), (val_loss, val_loss2)):
    history.extend(continuation)

In [None]:
# Train/validation loss curves, with an inset that zooms in on iterations >= k.
fig, ax = plt.subplots(1, 1, figsize=(12, 10))

k = 10  # first index shown in the zoomed inset

steps = np.arange(len(train_loss))

# Main panel: index 0 is skipped so the first loss value does not set the y-range.
ax.plot(steps[1:], train_loss[1:], label="train")
ax.plot(steps[1:], val_loss[1:], label="validation")
ax.set_xlabel("iteration")
ax.set_ylabel("loss")

# Inset: same curves from index k onward, tightly autoscaled.
axin = ax.inset_axes([0.2, 0.4, 0.7, 0.3])
axin.plot(steps[k:], train_loss[k:], label="train")
axin.plot(steps[k:], val_loss[k:], label="validation")
axin.autoscale(tight=True)
ax.indicate_inset_zoom(axin, label=None)

# Hide the inset's tick labels without replacing its tick formatter
# (set_*ticklabels("") would install a FixedFormatter on auto-located ticks).
axin.tick_params(labelbottom=False, labelleft=False)

# Attach the legend to the main axes explicitly instead of relying on pyplot's
# notion of the "current" axes.
ax.legend()
plt.show()

In [None]:
# Encode the training counts and compute the model's per-class response logits.
class_t = torch.tensor(classes)
umis_t = torch.tensor(umis_flat, dtype=torch.float)

# Pure inference: no gradients are needed, so skip building the autograd graph
# for this full-dataset forward pass.
# NOTE(review): the model is left in whatever train/eval mode it is currently in;
# if the encoder contains dropout or batch-norm layers, call nbvae.eval() first -- confirm.
with torch.no_grad():
    z_loc, _ = nbvae.encoder(umis_t)
    mean_dr_logit = nbvae.lmb.calc_response(z_loc, class_t).detach().numpy()

In [None]:
# Same evaluation as for the training data, on the held-out classes.
te_umis_flat = te_umis.reshape(4 * n_cells_per_class, n_features)
te_dr_means = np.dstack([ssp.logit(d).mean(0) for d in te_drs])

te_umi_t = torch.tensor(te_umis_flat, dtype=torch.float)

# Use the held-out set's own labels rather than the training label tensor.
# te_classes holds labels >= 4 while the model was built with n_classes - 4
# classes, so shift them down to start at 0.
te_class_t = torch.tensor(te_classes) - 4

# Pure inference: no gradients needed.
with torch.no_grad():
    te_z, _ = nbvae.encoder(te_umi_t)
    te_mean_dr_logit = nbvae.lmb.calc_response(te_z, te_class_t).detach().numpy()

In [None]:
def _embed(matrix):
    """Fit a fresh default-configured UMAP on `matrix` and return its embedding."""
    return umap.UMAP().fit_transform(matrix)


# Training data: square-root-transformed counts vs. the encoder's latent means.
x = _embed(np.sqrt(umis_flat))
x2 = _embed(z_loc.detach().numpy())

# Held-out data: same two views.
te_x = _embed(np.sqrt(te_umis_flat))
te_x2 = _embed(te_z.detach().numpy())

In [None]:
# 2x2 grid: rows = training / held-out, columns = count-space / latent-space UMAP.
fig, ax = plt.subplots(2, 2, figsize=(12, 12))

panels = (
    (ax[0, 0], x, classes),
    (ax[0, 1], x2, classes),
    (ax[1, 0], te_x, te_classes),
    (ax[1, 1], te_x2, te_classes),
)
for axis, embedding, labels in panels:
    axis.scatter(embedding[:, 0], embedding[:, 1], c=labels)

plt.show()

In [None]:
# Training classes: simulated mean response (blue) vs. the model's response (green),
# one row per class and one column per drug.
class_ids = np.unique(classes)

# squeeze=False keeps `ax` two-dimensional even when there is a single row or column.
fig, ax = plt.subplots(len(class_ids), n_drugs, figsize=(12, 20), squeeze=False)

for i in range(n_drugs):
    for j, c in enumerate(class_ids):
        # The slices are already 1-D, so no transpose is needed
        # (`.T` on a non-2-D torch tensor is deprecated).
        ax[j, i].plot(ssp.expit(dr_means[j, :, i]), color='b', label="simulated mean")
        ax[j, i].plot(ssp.expit(mean_dr_logit[j, :, i]), color='g', label="model")
        ax[j, i].set_title(f"class {c}, drug {i}")

ax[0, 0].legend()
plt.show()

In [None]:
# Held-out classes: simulated mean response (blue) vs. the model's response (green),
# one row per class and one column per drug.
te_class_ids = np.unique(te_classes)  # label rows with the held-out class ids

# squeeze=False keeps `ax` two-dimensional even when there is a single row or column.
fig, ax = plt.subplots(len(te_class_ids), n_drugs, figsize=(12, 20), squeeze=False)

for i in range(n_drugs):
    for j, c in enumerate(te_class_ids):
        # The slices are already 1-D, so no transpose is needed.
        ax[j, i].plot(ssp.expit(te_dr_means[j, :, i]), color='b', label="simulated mean")
        ax[j, i].plot(ssp.expit(te_mean_dr_logit[j, :, i]), color='g', label="model")
        ax[j, i].set_title(f"held-out class {c}, drug {i}")

ax[0, 0].legend()
plt.show()