In [2]:
import deconV as dv
import scout

import glob, tqdm, time, os
import torch
import matplotlib
import matplotlib.pyplot as plt
from matplotlib import rcParams

import torch
import torch.nn as nn
import torch.nn.functional as F

import pandas as pd
import numpy as np
import scanpy as sc
import scvi
import seaborn as sns
import tqdm
import scout

import plotly.express as px

%load_ext autoreload
%autoreload 2

Global seed set to 0


In [3]:
params = {
    "cell_type_key": "cellType",

    "selected_ct": ["0", "1", "2"],
    "bulk_file": "../../data/GSE136148/bulk.tsv",
    "ref_annot_file": "../../data/GSE136148/pdata.tsv",
    "ref_file": "../../data/GSE136148/sc.tsv",
}

sc.settings.set_figure_params(dpi=80, facecolor='white')
plt.rc("patch", edgecolor="black", facecolor="royalblue", linewidth=1.0)
plt.rc("axes", facecolor="white", edgecolor="black", linewidth=1.0)

sc.settings.verbosity = 0

In [4]:
sadata = dv.tl.read_data(params["ref_file"])
print(f"scRNA-seq data - cells: {sadata.shape[0]}, genes: {sadata.shape[1]}")

scRNA-seq data - cells: 3022, genes: 33694


In [5]:
print("Reading pheno data...")
pheno_df = pd.read_csv(params["ref_annot_file"], sep="\t", index_col=0)
pheno_df.index.name = None

Reading pheno data...


In [6]:
common_cells = list(set(pheno_df.index.tolist()) & set(sadata.obs_names.tolist()))
len(common_cells)

3022

In [7]:
sadata = sadata[common_cells, :].copy()
pheno_df = pheno_df.loc[common_cells, :].copy()
sadata.obs[params["cell_type_key"]] = pheno_df[params["cell_type_key"]].tolist()
sadata.obs.groupby(params["cell_type_key"]).size()

cellType
0    1988
1     686
2     319
3      29
dtype: int64

In [8]:
print("Reading bulk data...")
bulk_df = pd.read_csv(params["bulk_file"], sep="\t", index_col=None)
if bulk_df.iloc[:,0].dtype == "O":
    bulk_df.set_index(bulk_df.columns[0], inplace=True)
print(f"bulk RNA-seq data - samples: {bulk_df.shape[0]}, genes: {bulk_df.shape[1]}")

Reading bulk data...
bulk RNA-seq data - samples: 1, genes: 58387


In [9]:
if params["selected_ct"] is not None and len(params["selected_ct"]) > 0:
    sadata = sadata[sadata.obs[params["cell_type_key"]].astype("str").isin(params["selected_ct"]), :].copy()

sadata.obs[params["cell_type_key"]] = sadata.obs[params["cell_type_key"]].astype("category")
sadata.obs.groupby(params["cell_type_key"]).size()

cellType
0    1988
1     686
2     319
dtype: int64

In [10]:
sc.pp.filter_cells(sadata, min_genes=200)
sc.pp.filter_genes(sadata, min_cells=3)
adata = dv.tl.combine(sadata, bulk_df)
scout.tl.scale_log_center(adata, target_sum=None, exclude_highly_expressed=True)

scRNA-seq data - cells: 2993, genes: 18597
bulk RNA-seq data - samples: 1, genes: 18597


In [11]:
sc.pp.neighbors(adata)
sc.tl.umap(adata)
sc.external.tl.trimap(adata)
scout.tl.sub_cluster(adata, "cellType")
scout.ply.subplots([
    scout.ply.projection(adata, hue="cellType", obsm_layer="X_umap"),
    scout.ply.projection(adata, hue="sub_type", obsm_layer="X_umap"),
], subplot_titles=["Cell Type", "Sub Type (leiden)"])



In [57]:
def synthetic_pseudo_dataset(adata, groupby, proportions, n_cells):
    pseudo = np.zeros(adata.n_vars)

    idxs = np.random.choice(len(proportions), n_cells, p=proportions).tolist()
    counts = [(idx, idxs.count(idx)) for idx in set(idxs)]

    lens = adata.obs.groupby(groupby).size()

    print(counts)
    print(lens)

    for ct_i, c in counts:
        gex = adata[adata.obs[groupby].cat.codes == ct_i, :].layers["counts"]

        ridx = np.random.randint(0, lens[ct_i], (c))
        print(ridx.shape)
        print(gex.shape)
        print(ridx)
        for ri in ridx:
            pseudo += gex[ridx[ri], :]

    return pseudo


def pseudo_bulk_samples(adata, groupby, n_samples):
    n_groups = len(adata.obs[groupby].cat.categories)
    proportions = np.random.randint(0, 10000, (n_samples, n_groups))
    proportions = (proportions.T / proportions.sum(1)).T

    bulks = []

    columns = ["n_cells"] + adata.obs[groupby].cat.categories.tolist()
    pdata = pd.DataFrame([], columns=columns)

    for i in tqdm.tqdm(range(n_samples)):
        n_cells = np.random.randint(1000, 10000)
        bulks.append(synthetic_pseudo_dataset(adata, groupby, proportions[i], n_cells))
        pdata = pd.concat([pdata, pd.DataFrame([[n_cells] + proportions[i].tolist()], columns=columns, index=[i])])

    bulk_counts = pd.DataFrame(bulks, columns=adata.var_names).T
    return bulk_counts, pdata

In [58]:
pseudo_bulk_samples(adata, "cellType", 1)

  0%|          | 0/1 [00:00<?, ?it/s]

[(0, 928), (1, 2937), (2, 1249)]
cellType
0    1988
1     686
2     319
dtype: int64
(928,)
(1988, 18597)
[1798 1518 1500  655 1049  673  135 1178  779 1535 1865 1815    3  968
 1556  641  535   58 1022  919 1488  323  903 1253  726 1209 1931 1806
 1208  392  961  917 1221  574 1724  564 1700 1583 1784 1397  177  658
 1950  994  859 1638  739  545  184 1123  875  279  600  136  404 1909
  884 1535   16 1405 1394 1092  524  570 1803 1475  467  365 1127  565
  764  325 1590 1837 1935 1697  507 1274  660  131   80 1100  304   63
 1724 1590 1597  132 1275  855   72  453  217  639 1886 1945 1945  384
  192  441 1746 1812  686  506  386  780  452  620  490 1077 1075 1572
  873  117 1067 1539 1869   95 1163  846 1253   31  198  390   26 1963
 1633 1775 1145  286   22 1051 1561 1889  462 1440 1449  136  848  418
 1245 1055  739  585 1736 1875  211  421 1217  530 1762 1407 1760  657
  646 1464  520 1125 1370 1339 1061 1808 1840  586  993 1681  533  507
 1980  516  848 1097  959 1665 1080  202  




IndexError: index 1798 is out of bounds for axis 0 with size 928