In [3]:
import deconV.deconV as dv

import glob, tqdm, time, os, warnings

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributions as D

import matplotlib.pyplot as plt
from matplotlib import rcParams

import chart_studio.plotly as py
import plotly.graph_objects as go
import plotly.express as px

import pandas as pd
import numpy as np
import scanpy as sc
import scvi
import seaborn as sns
import tqdm
import scout

%load_ext autoreload
%autoreload 2

Global seed set to 0


In [4]:
params = {
    "jupyter": True,
    "tqdm": True,
    "cell_type_key": "cellType",
    "layer": "ncounts",
    "index_col": 0,
    "selected_ct": ["alpha", "delta", "gamma", "beta"],
    # "selected_ct": ["0", "1", "2"],
    "model_type": "poisson",
    "ignore_others": True,
    "n_top_genes": -1,
    "plot_pseudo_bulk": False,
    "lr": 0.01,
    "epochs": 5000,
    "fig_fmt": "png",
    "indir": "../../data/xin/",
    # "indir": "../../data/GSE136148/",
    "outdir": "out",
    "figsize": (8, 8),
    "dpi": 80,
}


In [233]:
adata = sc.read_csv(os.path.join(params["indir"], "sc.tsv"), delimiter="\t")
adata.obs = pd.read_csv(os.path.join(params["indir"], "pdata.tsv"), sep="\t", index_col=0).loc[adata.obs.index]
adata = adata[adata.obs["cellType"].isin(params["selected_ct"]), :].copy()
sc.pp.filter_cells(adata, min_genes=200)
sc.pp.filter_genes(adata, min_cells=3)
scout.tl.scale_log_center(adata, target_sum=None)
adata

AnnData object with n_obs × n_vars = 748 × 18025
    obs: 'sampleID', 'SubjectName', 'cellTypeID', 'cellType', 'n_genes'
    var: 'n_cells'
    uns: 'log1p'
    layers: 'counts', 'ncounts', 'centered', 'logcentered'

In [234]:
adata.layers["ncounts"].max()

227246.08

In [235]:
bdata = sc.read_csv(os.path.join(params["indir"], "bulk.tsv"), delimiter="\t")
bdata

AnnData object with n_obs × n_vars = 18 × 39849

In [236]:
def create_dataset(adata, bdata, label_key, layer="counts"):
    common_genes = list(set(adata.var.index.values) & set(bdata.var.index.values))
    _adata = adata[:, common_genes].copy()
    _bdata = bdata[:, common_genes].copy()

    X = torch.tensor(_adata.layers[layer].round())

    Y = torch.tensor(_bdata.X)

    return X, Y


In [237]:
X, Y = create_dataset(adata, bdata, params["cell_type_key"], layer="ncounts")
X.shape, Y.shape

(torch.Size([748, 17390]), torch.Size([18, 17390]))

In [273]:
class Encoder(nn.Module):
    def __init__(self, in_features, n_latent=10, hidden_dim=128):
        super().__init__()
        self.in_features = in_features
        self.n_latent = n_latent
        self.hidden_dim = hidden_dim

        self.enc = nn.Sequential(
            nn.Linear(in_features=self.in_features, out_features=self.hidden_dim),
            nn.ReLU(),
            nn.Linear(in_features=self.hidden_dim, out_features=self.hidden_dim),
            nn.ReLU()
        )

        self.mu_head = nn.Linear(in_features=self.hidden_dim, out_features=self.n_latent)
        self.logvar_head = nn.Linear(in_features=self.hidden_dim, out_features=self.n_latent)

    def forward(self, x):
        _x = self.enc(x)
        z_mu = self.mu_head(_x)
        z_logvar = torch.tanh(self.logvar_head(_x)) * 20
        return z_mu, z_logvar

    def sample(Self, z_mu, z_logvar):
        return z_mu + torch.exp(0.5 * z_logvar) * torch.randn(1)

class Decoder(nn.Module):
    def __init__(self, n_latent, out_features, hidden_dim=128):
        super().__init__()
        self.n_latent = n_latent
        self.out_features = out_features
        self.hidden_dim = hidden_dim

        self.dec = nn.Sequential(
            nn.Linear(in_features=self.n_latent, out_features=self.hidden_dim),
            nn.ReLU(),
            nn.Linear(in_features=self.hidden_dim, out_features=self.out_features)
        )

    def forward(self, z):
        log_lambda = self.dec(z)
        return (torch.tanh(log_lambda) * 20).exp()


In [301]:
def loss_kl(z_mu, z_logvar):
    return -0.5 * (1.0 + z_logvar - z_mu**2 - torch.exp(z_logvar)).sum() / len(z_mu)

def loss_loglik(x_lambda, x):
    dist = D.Poisson(x_lambda)
    return dist.log_prob(x).sum() / x_lambda.shape[1]


In [316]:
enc = Encoder(X.shape[1])
dec = Decoder(n_latent=enc.n_latent, out_features=X.shape[1])

enc_optim = torch.optim.Adam(enc.parameters(), lr=1e-5)
dec_optim = torch.optim.Adam(dec.parameters(), lr=1e-5)
train_loader = torch.utils.data.DataLoader(X, batch_size=32)

def fit(num_epochs):
    pbar = tqdm.tqdm(range(num_epochs))
    lh_losses = []
    kl_losses = []
    for epoch in pbar:
        epoch_loss = 0.0
        epoch_kl_loss = 0.0
        epoch_lh_loss = 0.0
        for x in train_loader:
            z_mu, z_logvar = enc(x)
            kl_loss = loss_kl(z_mu, z_logvar)
            epoch_kl_loss += kl_loss.item()
            z = enc.sample(z_mu, z_logvar)
            
            x_lambda = dec(z)
            lh_loss = loss_loglik(x_lambda, x)
            epoch_lh_loss += lh_loss.item()
            
            loss = kl_loss + lh_loss
            loss.backward()

            epoch_loss += loss.item()

            dec_optim.step()
            enc_optim.step()
        
        epoch_loss /= X.shape[0]
        epoch_lh_loss /= X.shape[0]
        epoch_kl_loss /= X.shape[0]

        lh_losses.append(epoch_lh_loss)
        kl_losses.append(epoch_kl_loss)
        
        pbar.set_postfix({"loss": f"{epoch_loss:.1f}"})

    kl_losses = np.array(kl_losses)
    lh_losses = np.array(lh_losses)
    kl_losses /= kl_losses.max()
    lh_losses /= lh_losses.max()

    return kl_losses, lh_losses


kl_losses, lh_losses = fit(100)


100%|██████████| 100/100 [00:47<00:00,  2.10it/s, loss=-460148312.3]


In [317]:
def plot_losses(**kwargs):
    fig = go.Figure()
    for loss_name, loss in kwargs.items():
        fig.add_trace(
            go.Scatter(
                x=list(range(0, len(loss))),
                y=loss,
                name=loss_name, showlegend=True
            )
        )
    fig.update_layout(
        width=1000, height=800
    )
    return fig

In [318]:
plot_losses(kl_losses=kl_losses, lh_losses=lh_losses)