In [1]:
import anndata as ad
import scanpy as sc
import gc
import sys
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sea
import scipy
import torch
import os

from scDisInFact import scdisinfact, create_scdisinfact_dataset

from metrics import calculate_metrics

# --- Notebook-wide configuration -------------------------------------------
sc.settings.verbosity = 0
sc.settings.set_figure_params(dpi=300)
pd.set_option('display.max_columns', None)

# Seed every RNG this notebook drives directly. NumPy alone is not enough
# because the model is trained with PyTorch.
# NOTE(review): scdisinfact is also constructed with its own `seed` argument
# further down; whether that re-seeds torch internally is not visible here.
seed = 10
np.random.seed(seed)
torch.manual_seed(seed)

# Use the first CUDA device when one is available, otherwise fall back to CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

R_HOME is already set to: /vast/palmer/apps/avx2/software/R/4.3.2-foss-2022b-patched/lib64/R


Unable to determine R library path: Command '('/vast/palmer/apps/avx2/software/R/4.3.2-foss-2022b-patched/lib64/R/bin/Rscript', '-e', 'cat(Sys.getenv("LD_LIBRARY_PATH"))')' returned non-zero exit status 127.
  anndata2ri.activate()


In [None]:
def _restore_input_order(grouped_values, row_positions):
    """Scatter rows of ``grouped_values`` back to their original positions.

    Parameters
    ----------
    grouped_values : numpy.ndarray
        Array whose first axis is in "grouped" order.
    row_positions : array-like of int
        ``row_positions[i]`` is the original row index of ``grouped_values[i]``.

    Returns
    -------
    numpy.ndarray
        Array with the same shape and dtype as ``grouped_values`` where row
        ``row_positions[i]`` holds ``grouped_values[i]``.

    Raises
    ------
    ValueError
        If ``row_positions`` is not a permutation of ``0..n_rows-1``.
    """
    row_positions = np.asarray(row_positions, dtype=np.int64)
    n_rows = grouped_values.shape[0]
    is_permutation = (
        row_positions.shape[0] == n_rows
        and np.array_equal(np.sort(row_positions), np.arange(n_rows))
    )
    if not is_permutation:
        raise ValueError("row_positions must be a permutation of the input row indices")
    restored = np.empty_like(grouped_values)
    restored[row_positions] = grouped_values
    return restored


def run_scdisinfact(adata, batch_key, condition_key, dataset_name, cell_type_label=None):
    """Train scDisInFact on ``adata``, save the results and plot latent UMAPs.

    Steps:
      1. If ``adata.X`` has values above 15 it is treated as un-normalized:
         cells/genes are filtered, counts are normalized per cell and log1p'd.
      2. At most 3000 highly variable genes are kept.
      3. A scDisInFact model is trained with the hyper-parameters hard-coded
         below and its weights are written under ``./scd/<dataset_name>/``.
      4. The ``mu_c`` encoder output is stored in ``adata.obsm['main_effect']``
         and the ``predict_counts`` output in ``adata.layers['denoised']``;
         the result is written to ``./scd/<dataset_name>_latent.h5ad``.
      5. UMAPs of the ``mu_c`` latent are plotted, coloured by condition,
         batch and (optionally) ``cell_type_label``.

    Parameters
    ----------
    adata : anndata.AnnData
        Input data. Filtering/normalization are applied in place. When more
        than 3000 genes are present the function continues on a copy, so the
        ``obsm``/``layers`` results are only available in the written file.
    batch_key : str
        Column of ``adata.obs`` holding the batch label.
    condition_key : str or list of str
        Column(s) of ``adata.obs`` holding the condition label(s).
    dataset_name : str
        Name used to build the output paths.
    cell_type_label : str, optional
        Column of ``adata.obs`` used to colour an additional UMAP.

    Returns
    -------
    None

    Raises
    ------
    TypeError
        If ``condition_key`` is neither a string nor a list of strings.
    ValueError
        If the row bookkeeping returned by the dataset split is inconsistent.
    """
    import warnings

    # Validate up front: the original only printed a message and kept going.
    if isinstance(condition_key, str):
        condition_key = [condition_key]
    if not isinstance(condition_key, list) or not all(isinstance(key, str) for key in condition_key):
        raise TypeError("Wrong condition_key, must be string or list of string")

    # Heuristic: values above 15 are taken to mean the matrix is not yet log-normalized.
    if np.max(adata.X) > 15:
        sc.pp.filter_cells(adata, min_genes=300)
        sc.pp.filter_genes(adata, min_cells=10)

        sc.pp.normalize_per_cell(adata, counts_per_cell_after=1e4)
        sc.pp.log1p(adata)

    if adata.shape[1] > 3000:
        sc.pp.highly_variable_genes(adata, n_top_genes=3000, batch_key=batch_key)
        adata = adata[:, adata.var["highly_variable"]].copy()
    else:
        sc.pp.highly_variable_genes(adata, n_top_genes=adata.shape[1], batch_key=batch_key)

    # Suppress all warnings from here on (this is a process-wide setting).
    warnings.filterwarnings("ignore")

    result_dir = "./scd/" + dataset_name + "/"
    os.makedirs(result_dir, exist_ok=True)

    # issparse also recognises sparse *arrays*, which are not spmatrix instances.
    if scipy.sparse.issparse(adata.X):
        adata.X = adata.X.toarray()

    counts = adata.X
    adata.obs["batch"] = adata.obs[batch_key].copy()
    meta_cells = adata.obs.copy()

    # Tag every cell with its input row. The per-dataset outputs below are
    # concatenated in the order of the pieces in `data_dict`, which is not
    # guaranteed to be the input order; this column lets us map them back.
    row_key = "_scdisinfact_input_row"
    meta_cells[row_key] = np.arange(adata.n_obs)

    data_dict = create_scdisinfact_dataset(counts, meta_cells, condition_key=condition_key, batch_key=batch_key, log_trans=False)

    # default setting of hyper-parameters
    reg_mmd_comm = 1e-4
    reg_mmd_diff = 1e-4
    reg_kl_comm = 1e-5
    reg_kl_diff = 1e-2
    reg_class = 1
    reg_gl = 1

    Ks = [8, 4]

    batch_size = 64
    nepochs = 100
    interval = 10
    lr = 5e-4
    lambs = [reg_mmd_comm, reg_mmd_diff, reg_kl_comm, reg_kl_diff, reg_class, reg_gl]
    model = scdisinfact(data_dict=data_dict, Ks=Ks, batch_size=batch_size, interval=interval, lr=lr,
                        reg_mmd_comm=reg_mmd_comm, reg_mmd_diff=reg_mmd_diff, reg_gl=reg_gl, reg_class=reg_class,
                        reg_kl_comm=reg_kl_comm, reg_kl_diff=reg_kl_diff, seed=0, device=device)
    model.train()
    model.train_model(nepochs=nepochs, recon_loss="NB")

    # Persist the weights, then reload them so inference runs on exactly what was saved.
    model_path = result_dir + f"model_{Ks}_{lambs}_{batch_size}_{nepochs}_{lr}.pth"
    torch.save(model.state_dict(), model_path)
    model.load_state_dict(torch.load(model_path, map_location=device))
    _ = model.eval()

    # Encoder pass per dataset; only the `mu_c` latent is used downstream.
    z_cs = []
    with torch.no_grad():
        for dataset in data_dict["datasets"]:
            dict_inf = model.inference(counts=dataset.counts_norm.to(model.device),
                                       batch_ids=dataset.batch_id[:, None].to(model.device), print_stat=True)
            z_cs.append(dict_inf["mu_c"].cpu().detach().numpy())

    # Map the grouped latent back to the row order of `adata`.
    # NOTE(review): assumes data_dict["meta_cells"][i] describes the cells of
    # data_dict["datasets"][i], as the original per-column concatenation did.
    row_positions = np.concatenate([np.asarray(piece[row_key].values) for piece in data_dict["meta_cells"]])
    latent = _restore_input_order(np.concatenate(z_cs, axis=0), row_positions)

    # AnnData of the latent itself, aligned with the original annotations.
    adata_latent = ad.AnnData(X=latent, obs=adata.obs.copy())

    denoised_counts = model.predict_counts(input_counts=counts, meta_cells=meta_cells.drop(columns=[row_key]),
                                           condition_keys=condition_key,
                                           batch_key=batch_key, predict_conds=None, predict_batch=None)
    adata.layers["denoised"] = denoised_counts
    adata.obsm['main_effect'] = latent

    adata.write_h5ad("./scd/" + dataset_name + "_latent.h5ad")

    # Build the neighbour graph directly on the latent (no PCA on top of it).
    sc.pp.neighbors(adata_latent, n_neighbors=15, use_rep="X")
    sc.tl.umap(adata_latent)
    sc.pl.umap(adata_latent, color=condition_key, ncols=1)
    sc.pl.umap(adata_latent, color=batch_key, ncols=1)
    if cell_type_label is not None:
        sc.pl.umap(adata_latent, color=cell_type_label)

# (EC)CITE-seq

Reference paper: https://www.nature.com/articles/s41588-021-00778-2

In [None]:
# ECCITE: batches are replicates, the condition is the perturbation,
# and the extra UMAP is coloured by cell-cycle "Phase".
adata = sc.read_h5ad("../data/ECCITE.h5ad")
run_scdisinfact(
    adata=adata,
    batch_key="replicate",
    condition_key="perturbation",
    dataset_name="ECCITE",
    cell_type_label="Phase",
)

# ASD

Data source (Broad Single Cell Portal, study SCP1184): https://singlecell.broadinstitute.org/single_cell/study/SCP1184/in-vivo-perturb-seq-reveals-neuronal-and-glial-abnormalities-associated-with-asd-risk-genes#study-download

In [None]:
# ASD: batch comes from the "Batch" column, the condition from "perturb01",
# and the extra UMAP is coloured by "CellType".
adata = sc.read_h5ad("../data/ASD.h5ad")
run_scdisinfact(
    adata=adata,
    batch_key="Batch",
    condition_key="perturb01",
    dataset_name="ASD",
    cell_type_label="CellType",
)

# ASD1

Data source (Broad Single Cell Portal, study SCP1184): https://singlecell.broadinstitute.org/single_cell/study/SCP1184/in-vivo-perturb-seq-reveals-neuronal-and-glial-abnormalities-associated-with-asd-risk-genes#study-download

In [None]:
# ASD1: same column layout as the ASD run, applied to the ASD1 file.
adata = sc.read_h5ad("../data/ASD1.h5ad")
run_scdisinfact(
    adata=adata,
    batch_key="Batch",
    condition_key="perturb01",
    dataset_name="ASD1",
    cell_type_label="CellType",
)