In [1]:
import deconV as dv
import scout

import glob, tqdm, time, os
import torch
import matplotlib
import matplotlib.pyplot as plt
from matplotlib import rcParams

import pandas as pd
import numpy as np
import scanpy as sc
import scvi
import seaborn as sns
import tqdm
import scout

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributions as D

import plotly.express as px

%load_ext autoreload
%autoreload 2

Global seed set to 0


In [2]:
params = {
    "cell_type_key": "cellType",
    "fig_fmt": "png",


    "selected_ct": ["0", "1", "2"],
    "bulk_file": "../../data/synthetic100/bulk.csv",
    "ref_annot_file": "../../data/GSE136148/pdata.tsv",
    "ref_file": "../../data/GSE136148/sc.tsv",
    
    "outdir": "out",
    "figsize": (8,8),
    "dpi": 80,
}

sc.settings.set_figure_params(dpi=80, facecolor='white')
plt.rc("patch", edgecolor="black", facecolor="royalblue", linewidth=1.0)
plt.rc("axes", facecolor="white", edgecolor="black", linewidth=1.0)

sc.settings.verbosity = 0

In [3]:
sadata = dv.tl.read_data(params["ref_file"])
print(f"scRNA-seq data - cells: {sadata.shape[0]}, genes: {sadata.shape[1]}")

scRNA-seq data - cells: 3022, genes: 33694


In [4]:
print("Reading pheno data...")
pheno_df = pd.read_csv(params["ref_annot_file"], sep="\t", index_col=0)
pheno_df.index.name = None

Reading pheno data...


In [5]:
common_cells = list(set(pheno_df.index.tolist()) & set(sadata.obs_names.tolist()))
len(common_cells)

3022

In [6]:
sadata = sadata[common_cells, :].copy()
pheno_df = pheno_df.loc[common_cells, :].copy()
sadata.obs[params["cell_type_key"]] = pheno_df[params["cell_type_key"]].tolist()
sadata.obs.groupby(params["cell_type_key"]).size()

cellType
0    1988
1     686
2     319
3      29
dtype: int64

In [7]:
print("Reading bulk data...")
bulk_df = pd.read_csv(params["bulk_file"], sep=",", index_col=None)
if bulk_df.iloc[:,0].dtype == "O":
    bulk_df.set_index(bulk_df.columns[0], inplace=True)
print(f"bulk RNA-seq data - samples: {bulk_df.shape[0]}, genes: {bulk_df.shape[1]}")

Reading bulk data...
bulk RNA-seq data - samples: 8424, genes: 100


In [8]:
if params["selected_ct"] is not None and len(params["selected_ct"]) > 0:
    sadata = sadata[sadata.obs[params["cell_type_key"]].astype("str").isin(params["selected_ct"]), :].copy()

sadata.obs[params["cell_type_key"]] = sadata.obs[params["cell_type_key"]].astype("category")
sadata.obs.groupby(params["cell_type_key"]).size()

cellType
0    1988
1     686
2     319
dtype: int64

In [9]:
true_df = pd.read_csv("../../data/synthetic100/bulk_proportions.csv", index_col=0)
true_df.drop(columns="n_cells", inplace=True)

In [10]:
sc.pp.filter_cells(sadata, min_genes=200)
sc.pp.filter_genes(sadata, min_cells=3)
adata = dv.tl.combine(sadata, bulk_df)
scout.tl.scale_log_center(adata, target_sum=None, exclude_highly_expressed=True)

scRNA-seq data - cells: 2993, genes: 8424
bulk RNA-seq data - samples: 100, genes: 8424


In [11]:
decon = dv.DeconV(adata, cell_type_key="cellType", sub_type_key=None, layer="counts")

Added results to: adata.uns['de']['cellType']


In [12]:
decon.filter_outliers(dropout_factor_quantile=0.9, pseudobulk_lims=(-10, 10), aggregate="max")

In [13]:
decon.init_dataset(
    weight_type=None, weight_agg="min",
    inverse_weight=False, log_weight=False, quantiles=(0, 1)
)

In [14]:
class CustomDataset(torch.utils.data.Dataset):
    def __init__(self, adata, true_df) -> None:
        super().__init__()
        self.x = torch.tensor(adata.varm["bulk"].T, dtype=torch.float32)
        self.y = torch.tensor(true_df.values, dtype=torch.float32)

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        return self.x[idx], self.y[idx]

In [15]:
decon.init_dataset(None)
mu, _ = decon.get_signature(quantiles=(0,1))

In [16]:
class PSM(nn.Module):
    def __init__(
        self, loc,
        gene_weights=None,
    ):
        super().__init__()
        self.loc = loc
        self.gene_weights = gene_weights
        self.log_cell_counts = nn.Parameter(
            torch.ones(loc.shape[1]) * torch.log(torch.tensor(1000) / loc.shape[1])
        )
        self.opt = torch.optim.Adam(self.parameters(), lr=0.5)

    def get_proportions(self):
        return F.softmax(self.log_cell_counts, 0)
        
    def _fit(self, bulk):
        for i in range(50):
            self.opt.zero_grad()
            loss = self(bulk)
            loss.backward()
            self.opt.step()

    def forward(self, x):
        loc = torch.sum(self.loc * self.log_cell_counts.exp(), 1)
        return (-D.Poisson(loc).log_prob(x) * F.softmax(self.gene_weights, 0)).mean()

class LR(nn.Module):
    def __init__(self, loc, n_genes, n_cell_types):
        super(LR, self).__init__()
        self.n_genes = n_genes
        self.n_cell_types = n_cell_types
        self.loc = loc

        self.weights = nn.Parameter(torch.ones(self.n_genes, 1))
        self.opt = torch.optim.Adam(self.parameters(), lr=0.1)

    def forward(self, bulk, true_proportions):
        psm = PSM(self.loc, self.weights)
        _t = time.time()
        psm._fit(bulk)

        _t = time.time()
        proportions = psm.get_proportions()
        loss = torch.sum((proportions - true_proportions) ** 2)
        return loss

In [None]:
dataset = CustomDataset(adata, true_df)

lrm = LR(loc=mu, n_genes=adata.n_vars, n_cell_types=3)
optimizer = torch.optim.Adam(lrm.parameters(), lr=0.1)
loss_fn = torch.nn.MSELoss()
pbar = tqdm.tqdm(range(100))

for epoch in pbar:
    L = 0.0
    c = 0
    for bulk, true_proportions in list(dataset)[:10]:
        optimizer.zero_grad()
        loss = lrm(bulk.round(), true_proportions)
        loss.backward()
        L += loss.item()
        c += 1
        optimizer.step()

        pbar.set_postfix(
            {
                "loss": f"{loss.item():.3f}",
            }
        )

    print(f"Epoch loss: {L/c:.5f}")
