In [1]:
import scanpy as sc
import scout
import plotly.graph_objects as go
import plotly.express as px
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributions as D
import pandas as pd
import scvi
import tqdm
import numpy as np

%load_ext autoreload
%autoreload 2

Global seed set to 0
  new_rank_zero_deprecation(
  return new_rank_zero_deprecation(*args, **kwargs)


In [11]:
adata = sc.read_h5ad("../../data/GSE136148/adata.h5ad")

In [12]:
fig = px.scatter(
    x = adata.var["bulk"],
    y = adata.var["pseudo"],
    log_x=True, log_y=True,
)
fig.update_xaxes(
    scaleanchor="y",
    scaleratio=1,
)
fig.add_trace(
    go.Scatter(
        x = [0.1, max(adata.var["bulk"].max(), adata.var["pseudo"].max())],
        y = [0.1, max(adata.var["bulk"].max(), adata.var["pseudo"].max())],
        mode="lines"
    )
)
fig.update_layout(
    width=700, height=500,
    margin=dict(l=10, r=10, t=10, b=10),
)
fig

In [13]:
scout.tl.scale_log_center(adata)

In [14]:
X = []
for i, cell_type in enumerate(adata.obs["cell_type"].cat.categories):
    _x = adata[adata.obs["cell_type"] == cell_type, :].layers["counts"]
    X.append(torch.tensor(_x, dtype=torch.float32))

In [15]:
cell_types = adata.obs["cell_type"].cat.categories.values
cell_types

array([0, 1, 2])

In [17]:
adata[adata.obs["cell_type"] == 0, 14].layers["scvi_mean"].mean(), adata[adata.obs["cell_type"] == 0, 14].layers["counts"].mean()

(ArrayView(1.4902279, dtype=float32), ArrayView(1.4595537, dtype=float32))

In [49]:
def plot_distribution(gene):
    gene_idx = adata.var.index.get_loc(gene)
    df = pd.DataFrame(dict(counts=adata[:, gene].layers["counts"].flatten(), cell_type=adata.obs["cell_type"].values))

    fig = px.histogram(
        df.sort_values("cell_type"), x="counts", color="cell_type",
        marginal="box", nbins=100, histnorm="probability density",
        color_discrete_sequence=scout.ply.SC_DEFAULT_COLORS, barmode="overlay"
    )
    fig.update_traces(marker_line_width=1, marker_line_color="black")

    for i, cell_type in enumerate(adata.obs["cell_type"].cat.categories):
        view = adata[adata.obs["cell_type"] == cell_type, gene_idx]
        l = torch.tensor(view.layers["counts"].mean())
        # l = torch.tensor(view.layers["scvi_mean"].mean())
        pois = D.Poisson(l)
        xx = torch.tensor(list(range(0, int(view.layers["ncounts"].max()))))
        fig.add_trace(go.Scatter(x=xx, y=pois.log_prob(xx).exp(), mode="lines", line_color=scout.ply.SC_DEFAULT_COLORS[i], name=f"{cell_type}"))

    fig.update_layout(
        width=800, height=500,
        margin=dict(l=10, r=10, t=10, b=10),
    )
    
    return fig

plot_distribution("COX6B1")

In [19]:
plot_distribution("NPW")

In [51]:
scout.ply.violin(adata, adata.var.index[14], "cell_type", layer="ncounts")

In [52]:
loc = torch.empty(adata.n_vars, 3)
scale = torch.empty(adata.n_vars, 3)

for i, _ in enumerate(cell_types):
    loc[:, i] = torch.tensor(adata[adata.obs["cell_type"] == cell_types[i], :].layers["ncounts"].mean(axis=0))
    scale[:, i] = torch.tensor(adata[adata.obs["cell_type"] == cell_types[i], :].layers["ncounts"].std(axis=0))

Y = torch.tensor(adata.var["bulk"].values, dtype=torch.int)

In [36]:
adata.layers["counts"].std()

10.092584

In [50]:
torch.tensor([1,2,3]) * torch.tensor([1,2,3])

tensor([1, 4, 9])

In [102]:
class PSM(nn.Module):
    def __init__(
        self, loc, gene_weights=None, gene_scale=None, lib_size=torch.tensor(1000), norm=torch.tensor(1.0)
    ):
        super().__init__()
        self.loc = loc
        self.gene_weights = gene_weights
        self.gene_scale = gene_scale
        self.norm = norm
        self.log_cell_counts = nn.Parameter(torch.ones(loc.shape[1]) * torch.log(lib_size / loc.shape[1]))
        self.eps = torch.tensor(1e-5)

    def get_proportions(self):
        with torch.no_grad():
            return F.softmax(self.log_cell_counts, 0)

    def get_lib_size(self):
        with torch.no_grad():
            return torch.sum(self.log_cell_counts.exp())

    def get_distribution(self):
        loc = torch.sum(self.loc * self.log_cell_counts.exp(), 1)

        assert torch.isnan(loc).sum() == 0

        # return D.Poisson(loc)
        # turn D.Normal(loc, 10)

    def forward(self, x):
        d = self.get_distribution()

        if self.gene_weights != None:
            return (-d.log_prob(x) * self.gene_weights).mean() * self.norm

        return -d.log_prob(x).mean() * self.norm

In [103]:
class NSM(nn.Module):
    def __init__(
        self, loc, scale, gene_weights=None, gene_scale=None, lib_size=torch.tensor(2000000), norm=torch.tensor(1.0)
    ):
        super().__init__()
        self.loc = loc
        self.scale = scale
        self.gene_weights = gene_weights
        self.gene_scale = gene_scale
        self.log_cell_counts = nn.Parameter(torch.ones(loc.shape[1]) * torch.log(lib_size / loc.shape[1]))
        self.norm = norm
        self.eps = torch.tensor(1e-5)

    def get_proportions(self):
        with torch.no_grad():
            return F.softmax(self.log_cell_counts, 0)

    def get_lib_size(self):
        with torch.no_grad():
            return torch.sum(self.log_cell_counts.exp())

    def get_distribution(self):
        loc = torch.sum(self.loc * self.log_cell_counts.exp(), 1)
        scale = torch.max(
            self.eps, torch.sqrt(torch.sum(self.scale**2 * self.log_cell_counts.exp(), 1))
        )

        assert torch.isnan(loc).sum() == 0

        return D.Normal(loc, scale)

    def forward(self, x):
        d = self.get_distribution()

        if self.gene_weights != None:
            return (-d.log_prob(x) * self.gene_weights).mean() * self.norm

        return -d.log_prob(x).mean() * self.norm

In [104]:
poisson_model = PSM(loc)
normal_model = NSM(loc, scale)

In [105]:
def fmt_c(w):
    return " ".join([f"{v:.2f}" for v in w])

In [106]:
def _fit(model, y):
    pbar = tqdm.tqdm(range(10000))
    optim = torch.optim.Adam(model.parameters(), lr=0.01)
    for i in pbar:
        optim.zero_grad()
        loss = model(y)
        loss.backward()
        # torch.nn.utils.clip_grad_norm_(model.parameters(), 0.00001)
        optim.step()
        if i % 50 == 0:
            pbar.set_postfix({"loss":f"{loss.item():.1f}", "p":fmt_c(model.get_proportions()), "lib_size":f"{model.get_lib_size().item():.1f}"})


In [107]:
_fit(normal_model, pseudo)
normal_model.log_cell_counts.exp()

100%|██████████| 10000/10000 [00:20<00:00, 496.42it/s, loss=4.3, p=0.65 0.23 0.12, lib_size=2505.4]     


tensor([1642.7958,  576.0049,  285.8341], grad_fn=<ExpBackward0>)

In [108]:
pseudo = torch.tensor(np.round(adata.var["pseudo"].values))

In [115]:
_fit(poisson_model, Y)

100%|██████████| 10000/10000 [00:11<00:00, 868.48it/s, loss=36945836.0, p=0.76 0.24 0.00, lib_size=1293.3]


In [110]:
poisson_model.log_cell_counts.exp()

tensor([1667.6390,  651.6412,  156.4780], grad_fn=<ExpBackward0>)

In [111]:
adata.obs.groupby("cell_type").apply(len)

cell_type
0    1434
1     483
2     314
dtype: int64

In [113]:
adata.var["pb_ratio"].mean()

inf