In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
# Standard imports
import numpy as np

# pertpy is needed to download the Kang data
import pertpy
import scanpy as sc

# This will download the data to ./data/kang_2018.h5ad
adata = pertpy.data.kang_2018()
# Store counts separately in the layers: keep an independent copy of adata.X
# under layers["counts"] so it stays available next to layers added later.
adata.layers["counts"] = adata.X.copy()

  from .autonotebook import tqdm as notebook_tqdm
OMP: Info #276: omp_set_nested routine deprecated, please use omp_set_max_active_levels instead.


In [3]:
import pylemur.pp.basic

# Apply pylemur's shifted log transform to adata.X and keep the result as its
# own "logcounts" layer; this is the layer the model below is configured with.
adata.layers["logcounts"] = pylemur.pp.basic.shifted_log_transform(adata.X)

In [4]:
import dask.array as da

def get_input_arr():
    """Return the "logcounts" layer of the notebook-level ``adata`` as a dask array.

    ``adata`` is looked up when the function is called, not when it is
    defined, so the layer must exist by the time this is invoked.
    """
    logcounts = adata.layers["logcounts"]
    return da.from_array(logcounts)

In [5]:
# NOTE(review): scratch cell - a (1, 3) array used only for the np.squeeze shape
# check in the next cell; `x` is not referenced by any later cell visible here.
# Consider removing before sharing.
x = np.array([[1, 2, 3]])

In [6]:
np.squeeze(x).shape

(3,)

In [7]:
# NOTE(review): scratch arrays; 24673 equals n_obs in the adata summary printed
# further down. Neither A nor B is used by any later cell visible here -
# consider removing, or derive the size from adata.n_obs if they are needed.
A = np.ones((24673, 2))
B = np.ones((24673,))

In [8]:
B

array([1., 1., 1., ..., 1., 1., 1.])

In [9]:
adata

AnnData object with n_obs × n_vars = 24673 × 15706
    obs: 'nCount_RNA', 'nFeature_RNA', 'tsne1', 'tsne2', 'label', 'cluster', 'cell_type', 'replicate', 'nCount_SCT', 'nFeature_SCT', 'integrated_snn_res.0.4', 'seurat_clusters'
    var: 'name'
    obsm: 'X_pca', 'X_umap'
    layers: 'counts', 'logcounts'

In [None]:
import pylemur

# Set up LEMUR on the "logcounts" layer with design formula "~ label" and
# n_embedding=15. get_input_arr is handed over as a callable (it is not called here).
model = pylemur.tl.LEMUR(adata, get_input_arr, design = "~ label", n_embedding=15, layer = "logcounts")
# Fit, then run the harmony alignment (presumably requires a fitted model - TODO confirm).
model.fit()
# max_iter=2 is a temporary cap on the alignment, per the TODO below.
model.align_with_harmony(max_iter=2) # TODO: remove max_iter=2
print(model)

In [None]:
# The model.cond(**kwargs) call specifies the condition for the prediction.
# Predictions are computed in order: first for "ctrl", then for "stim".
ctrl_pred, stim_pred = (
    model.predict(new_condition=model.cond(label=condition_label))
    for condition_label in ("ctrl", "stim")
)

In [None]:
# Recalculate the UMAP on the embedding calculated by LEMUR
# Expose the model's embedding under obsm["embedding"] so scanpy can use it
adata.obsm["embedding"] = model.embedding
# Build the neighbor graph from that representation (use_rep="embedding")
sc.pp.neighbors(adata, use_rep="embedding")
# NOTE(review): adata already carries obsm["X_umap"] from the downloaded data;
# sc.tl.umap presumably overwrites it - confirm if the original is still needed.
sc.tl.umap(adata)
# Color the UMAP by "label" and by "cell_type"
sc.pl.umap(adata, color=["label", "cell_type"])

In [None]:
import matplotlib.pyplot as plt
# Predicted expression change between the two conditions, stored as a layer
# so scanpy can color by it.
adata.layers["diff"] = stim_pred - ctrl_pred
# Also try CXCL10, IL8, and FBXO40
sel_gene = "TSC22D3"

# One row of three panels; the figure is three default widths wide.
fsize = plt.rcParams['figure.figsize']
fig = plt.figure(figsize=(fsize[0] * 3, fsize[1]))
axs = []
for ax in (fig.add_subplot(1, 3, position) for position in (1, 2, 3)):
    ax.set_aspect('equal')
    axs.append(ax)

# Panel 1: predicted difference on a diverging colormap centered at zero.
sc.pl.umap(
    adata,
    layer="diff",
    color=[sel_gene],
    cmap=plt.get_cmap("seismic"),
    vcenter=0,
    vmin=-4,
    vmax=4,
    title="Pred diff (stim - ctrl)",
    ax=axs[0],
    show=False,
)
# Panel 2: "logcounts" layer restricted to the control cells.
sc.pl.umap(
    adata[adata.obs["label"] == "ctrl"],
    layer="logcounts",
    color=[sel_gene],
    vmin=0,
    vmax=4,
    title="Ctrl expr",
    ax=axs[1],
    show=False,
)
# Panel 3: "logcounts" layer restricted to the stimulated cells; this last call
# keeps the default `show`, unlike the two calls above.
sc.pl.umap(
    adata[adata.obs["label"] == "stim"],
    layer="logcounts",
    color=[sel_gene],
    vmin=0,
    vmax=4,
    title="Stim expr",
    ax=axs[2],
)
