In [1]:
import numpy as np
import scanpy as sc
import cinemaot as co
import matplotlib.colors as colors
import matplotlib.pyplot as plt

import random
import torch
import sklearn
import os
def set_seed(seed: int) -> None:
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random`` module, NumPy's legacy global generator and
    PyTorch (CPU, plus all CUDA devices when CUDA is available, where cuDNN is
    also switched to deterministic mode).

    scikit-learn has no global seed of its own: estimators given
    ``random_state=None`` draw from NumPy's global generator, which is seeded
    here. (Calling ``sklearn.utils.check_random_state(seed)`` only builds and
    returns a new RandomState; it does not seed anything globally.)

    Args:
        seed: integer seed applied to every generator.
    """
    random.seed(seed)
    np.random.seed(seed)

    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)  # covers the current device and any others
        # Deterministic cuDNN kernels (may slow training down).
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    # PYTHONHASHSEED is read only at interpreter start-up: this does not change
    # str/bytes hashing in the running kernel, it only propagates to child
    # processes started afterwards.
    os.environ['PYTHONHASHSEED'] = str(seed)

set_seed(123)

from metrics import calculate_metrics

import warnings
# Silence every warning category for cleaner notebook output; the explicit
# FutureWarning filter is already implied by the blanket one.
warnings.simplefilter("ignore")
warnings.simplefilter("ignore", category=FutureWarning)

In [2]:
from scCAPE import sccape
from scCAPE import plotting
import numpy as np
import pandas as pd
import scanpy as sc
import matplotlib.pyplot as plt
import os
import pickle
import gseapy as gp
from fractions import Fraction
import scipy
from scipy.sparse import csr_matrix

In [3]:
import pertpy as pt
import scanpy as sc
from scvi import REGISTRY_KEYS



In [None]:
import os
import sklearn
import scipy
from sklearn.neighbors import NearestNeighbors

In [None]:
def run_cinema_ot(adata, condition_key, cell_type_key, exp_label, ref_label, dataset_name):
    """Fit CINEMA-OT on one dataset, visualise its confounder embedding and score it.

    Steps:
    - optional reduction to 2000 highly variable genes, then PCA;
    - ``co.cinemaot.cinemaot_unweighted`` fit;
    - build ``obsm['cf']``, where every ``ref_label`` cell's confounder
      coordinates are replaced by the OT-weighted average of the ``exp_label``
      cells it is matched to;
    - UMAP plot, then ``calculate_metrics`` with ``condition_key`` as batch and
      ``cell_type_key`` as label.

    Args:
        adata: input AnnData. It is copied, so the caller's object is not modified.
        condition_key: ``adata.obs`` column holding the condition.
        cell_type_key: ``adata.obs`` column holding the cell type. It is also
            CINEMA-OT's ``preweight_label``.
        exp_label: condition value passed as CINEMA-OT's ``expr_label``.
        ref_label: condition value passed as CINEMA-OT's ``ref_label``.
        dataset_name: stem used for the output file names.

    Writes ``./cinema_ot/<dataset_name>_ITE.h5ad`` and
    ``./cinema_ot/<dataset_name>.h5ad``. It passes
    ``./cinema_ot/<dataset_name>.csv`` as the metrics ``savepath``.
    """
    import warnings
    warnings.filterwarnings("ignore")
    warnings.filterwarnings("ignore", category=FutureWarning)

    out_dir = "./cinema_ot"
    # Create the output directory up front; every write below needs it.
    os.makedirs(out_dir, exist_ok=True)

    print(adata)
    # Copy before HVG selection so the caller's adata.var is not annotated in place.
    adata_ = adata.copy()
    if adata_.shape[1] > 2000:
        sc.pp.highly_variable_genes(adata_, n_top_genes=2000)
        adata_ = adata_[:, adata_.var["highly_variable"]].copy()
    print(adata_)

    sc.pp.pca(adata_)
    print(condition_key, ref_label, exp_label, cell_type_key)
    cf, ot, de = co.cinemaot.cinemaot_unweighted(adata_, obs_label=condition_key, ref_label=ref_label,
                                                 expr_label=exp_label, mode='parametric', thres=0.5,
                                                 smoothness=1e-5, eps=1e-3, preweight_label=cell_type_key)
    de.write(os.path.join(out_dir, dataset_name + "_ITE.h5ad"))

    ref_mask = (adata_.obs[condition_key] == ref_label).to_numpy()
    exp_mask = (adata_.obs[condition_key] == exp_label).to_numpy()
    # Row-normalise the transport plan (assumed (n_ref, n_exp), as the product
    # below requires) so each reference cell becomes a convex combination of
    # the confounder coordinates of its matched exp_label cells.
    matched_cf = cf.copy()
    matched_cf[ref_mask, :] = (ot / np.sum(ot, axis=1)[:, None]) @ cf[exp_mask, :]
    adata_.obsm['cf'] = matched_cf

    sc.pp.neighbors(adata_, use_rep='cf')
    sc.tl.umap(adata_, random_state=1)
    print("Effect modifier (called confounder in this paper) space:")
    sc.pl.umap(adata_, color=[condition_key, cell_type_key], wspace=0.5)

    # R_HOME must be set before rpy2.robjects is first imported, because that
    # import initialises the embedded R.
    os.environ["R_HOME"] = "/home/xx244/.conda/envs/benchmark/lib/R"
    import rpy2.robjects as robjects  # noqa: F401 -- imported for its R-initialising side effect
    import anndata2ri
    anndata2ri.activate()

    from metrics import calculate_metrics

    adata_.write(os.path.join(out_dir, dataset_name + ".h5ad"))

    metric_kwargs = dict(batch_key=condition_key, celltype_key=cell_type_key, all=True,
                         n_neighbors=15, embed='cf',
                         savepath=os.path.join(out_dir, dataset_name + ".csv"))
    try:
        calculate_metrics(adata_, **metric_kwargs)
    except Exception as err:
        # Retry once (presumably for transient R/rpy2 failures -- unconfirmed).
        # A second failure propagates. KeyboardInterrupt is no longer swallowed.
        print("calculate_metrics failed, retrying once:", err)
        calculate_metrics(adata_, **metric_kwargs)

adata = sc.read_h5ad("../data/ASD1.h5ad")
run_cinema_ot(
    adata=adata,
    condition_key="perturb01",
    cell_type_key="CellType",
    exp_label="mutated",
    ref_label="nan",  # the reference group is the literal string 'nan' in perturb01
    dataset_name="ASD1",
)

In [None]:
def mixscape(adata, obs_label, ref_label, expr_label, nn=20, return_te=True):
    """Mixscape-style kNN matching of reference cells onto ``expr_label`` cells.

    Runs ``sc.pp.pca`` on ``adata``, which writes ``adata.obsm['X_pca']`` in place.
    For every cell with ``obs[obs_label] == ref_label``, it finds the ``nn``
    nearest cells with ``obs[obs_label] == expr_label`` in PCA space. It then
    replaces that reference cell's PCA coordinates with their mean.

    Args:
        adata: AnnData (or any object exposing ``obs``, ``obsm`` and ``X``).
        obs_label: ``adata.obs`` column holding the condition.
        ref_label: condition value of the cells whose embedding is replaced.
        expr_label: condition value of the neighbour pool.
        nn: number of neighbours averaged per reference cell.
        return_te: also return the per-cell treatment-effect estimate.

    Returns:
        ``(mixscape_pca, mixscapematrix)`` or ``(mixscape_pca, mixscapematrix, te2)``.
        - ``mixscape_pca`` is a copy of ``X_pca`` with the reference rows replaced.
        - ``mixscapematrix`` is the dense 0/1 kNN connectivity matrix
          (reference cells x expr cells).
        - ``te2 = X_ref - mean(X of matched expr cells)``.
    """
    sc.pp.pca(adata)
    x_pca = adata.obsm['X_pca']
    ref_mask = np.asarray(adata.obs[obs_label] == ref_label)
    expr_mask = np.asarray(adata.obs[obs_label] == expr_label)

    nbrs = NearestNeighbors(n_neighbors=nn, algorithm='ball_tree').fit(x_pca[expr_mask, :])
    mixscapematrix = nbrs.kneighbors_graph(x_pca[ref_mask, :]).toarray()
    # Divide by the actual neighbour count of each row (== nn) rather than a
    # fixed constant, so the result is a true mean for any nn.
    weights = mixscapematrix / np.sum(mixscapematrix, axis=1)[:, None]

    mixscape_pca = x_pca.copy()
    mixscape_pca[ref_mask, :] = weights @ x_pca[expr_mask, :]
    if return_te:
        te2 = adata.X[ref_mask, :] - weights @ adata.X[expr_mask, :]
        return mixscape_pca, mixscapematrix, te2
    return mixscape_pca, mixscapematrix

def run_mixscape(adata, condition_key, cell_type_key, exp_label, ref_label, dataset_name):
    """Build a Mixscape-style embedding for one dataset, plot it and score it.

    Steps:
    - densify the data and, if there are more than 2000 genes, keep the 2000
      highly variable ones;
    - call the notebook's ``mixscape`` helper, which replaces each
      ``ref_label`` cell's PCA coordinates with the mean of its 20 nearest
      ``exp_label`` cells;
    - store the result as ``obsm['ef']``, plot it with UMAP and score it with
      ``calculate_metrics``.

    Args:
        adata: input AnnData. It is copied, so the caller's object is neither
            densified nor annotated.
        condition_key: ``adata.obs`` column holding the condition.
        cell_type_key: ``adata.obs`` column holding the cell type.
        exp_label: condition value of the neighbour pool.
        ref_label: condition value of the cells whose embedding is replaced.
        dataset_name: stem used for the output file names.

    Writes ``./mixscape/<dataset_name>.h5ad``. It passes
    ``./mixscape/<dataset_name>.csv`` as the metrics ``savepath``.
    """
    import warnings
    warnings.filterwarnings("ignore")
    warnings.filterwarnings("ignore", category=FutureWarning)

    out_dir = "./mixscape"
    os.makedirs(out_dir, exist_ok=True)

    # Work on a private copy so the caller's AnnData is left untouched.
    adata_ = adata.copy()
    # issparse also recognises the newer scipy sparse *array* types.
    if scipy.sparse.issparse(adata_.X):
        adata_.X = adata_.X.toarray()

    print(adata_)
    if adata_.shape[1] > 2000:
        sc.pp.highly_variable_genes(adata_, n_top_genes=2000)
        adata_ = adata_[:, adata_.var["highly_variable"]].copy()

    # pertpy's pt.tl.Mixscape().perturbation_signature is a possible
    # alternative; this notebook uses its own mixscape helper instead.
    # PCA runs on adata_ itself (not on a view), so X_pca is kept with the saved object.
    mixscape_pca, _ = mixscape(adata=adata_, obs_label=condition_key, ref_label=ref_label,
                               expr_label=exp_label, nn=20, return_te=False)

    adata_.obsm["ef"] = np.array(mixscape_pca)
    sc.pp.neighbors(adata_, use_rep="ef")
    sc.tl.umap(adata_, random_state=1)
    sc.pl.umap(adata_, color=[condition_key, cell_type_key], wspace=0.5)

    adata_.write(os.path.join(out_dir, dataset_name + ".h5ad"))

    # R_HOME must be set before rpy2.robjects is first imported, because that
    # import initialises the embedded R.
    os.environ["R_HOME"] = "/home/xx244/.conda/envs/benchmark/lib/R"
    import rpy2.robjects as robjects  # noqa: F401 -- imported for its R-initialising side effect
    import anndata2ri
    anndata2ri.activate()

    from metrics import calculate_metrics

    metric_kwargs = dict(batch_key=condition_key, celltype_key=cell_type_key, all=True,
                         n_neighbors=15, embed='ef',
                         savepath=os.path.join(out_dir, dataset_name + ".csv"))
    try:
        calculate_metrics(adata_, **metric_kwargs)
    except Exception as err:
        # Retry once (presumably for transient R/rpy2 failures -- unconfirmed).
        # A second failure propagates. KeyboardInterrupt is no longer swallowed.
        print("calculate_metrics failed, retrying once:", err)
        calculate_metrics(adata_, **metric_kwargs)

adata = sc.read_h5ad("../data/ASD1.h5ad")
run_mixscape(
    adata=adata,
    condition_key="perturb01",
    cell_type_key="CellType",
    exp_label="mutated",
    ref_label="nan",  # the reference group is the literal string 'nan' in perturb01
    dataset_name="ASD1",
)

In [None]:
def run_scgen(adata, condition_key, cell_type_key, exp_label, ref_label, dataset_name):
    """Train pertpy's scGen on one dataset, plot its latent space and score it.

    ``condition_key`` is registered as scGen's batch and ``cell_type_key`` as
    its labels. The latent representation is stored as ``obsm['latent']``,
    plotted with UMAP and scored with ``calculate_metrics``.

    Args:
        adata: input AnnData. It is copied, so the caller's object is not modified.
        condition_key: ``adata.obs`` column holding the condition.
        cell_type_key: ``adata.obs`` column holding the cell type.
        exp_label: unused; kept so all runners share one signature.
        ref_label: unused; kept so all runners share one signature.
        dataset_name: stem used for the output names.

    Saves the model to ``./scgen/<dataset_name>.pt`` and writes
    ``./scgen/<dataset_name>.h5ad``. It passes ``./scgen/<dataset_name>.csv``
    as the metrics ``savepath``.
    """
    import warnings
    warnings.filterwarnings("ignore")
    warnings.filterwarnings("ignore", category=FutureWarning)

    out_dir = "./scgen"
    os.makedirs(out_dir, exist_ok=True)

    adata = adata.copy()
    print(adata)
    if adata.shape[1] > 2000:
        sc.pp.highly_variable_genes(adata, n_top_genes=2000)
        adata = adata[:, adata.var["highly_variable"]].copy()

    pt.tl.Scgen.setup_anndata(adata, batch_key=condition_key, labels_key=cell_type_key)
    scgen_model = pt.tl.Scgen(adata)
    scgen_model.train(
        max_epochs=100,
        batch_size=32,
        early_stopping=True,
        early_stopping_patience=25,
        # accelerator="cpu",
    )
    scgen_model.save(os.path.join(out_dir, dataset_name + ".pt"), overwrite=True)

    latent_X = scgen_model.get_latent_representation()
    latent_adata = sc.AnnData(X=latent_X, obs=adata.obs.copy())
    latent_adata.obsm["latent"] = latent_X

    # Build the graph on the evaluated embedding itself. With use_rep=None,
    # scanpy would first run PCA on X whenever it has more than 50 columns.
    sc.pp.neighbors(latent_adata, use_rep="latent")
    sc.tl.umap(latent_adata)
    sc.pl.umap(
        latent_adata,
        color=[condition_key, cell_type_key],
        wspace=0.4,
        frameon=False
    )
    latent_adata.write(os.path.join(out_dir, dataset_name + ".h5ad"))

    # R_HOME must be set before rpy2.robjects is first imported, because that
    # import initialises the embedded R.
    os.environ["R_HOME"] = "/home/xx244/.conda/envs/benchmark/lib/R"
    import rpy2.robjects as robjects  # noqa: F401 -- imported for its R-initialising side effect
    import anndata2ri
    anndata2ri.activate()

    from metrics import calculate_metrics

    metric_kwargs = dict(batch_key=condition_key, celltype_key=cell_type_key, all=True,
                         n_neighbors=15, embed='latent',
                         savepath=os.path.join(out_dir, dataset_name + ".csv"))
    try:
        calculate_metrics(latent_adata, **metric_kwargs)
    except Exception as err:
        # Retry once (presumably for transient R/rpy2 failures -- unconfirmed).
        # A second failure propagates. KeyboardInterrupt is no longer swallowed.
        print("calculate_metrics failed, retrying once:", err)
        calculate_metrics(latent_adata, **metric_kwargs)

adata = sc.read_h5ad("../data/ASD1.h5ad")
run_scgen(
    adata=adata,
    condition_key="perturb01",
    cell_type_key="CellType",
    exp_label="mutated",
    ref_label="nan",  # the reference group is the literal string 'nan' in perturb01
    dataset_name="ASD1",
)

In [None]:
import os
import multiprocessing
# Disabled alternative settings for the multiprocessing start method and BLAS
# thread counts. Uncomment to apply them (the second set_start_method call
# would override the first):
# multiprocessing.set_start_method('spawn', force=True)
# multiprocessing.set_start_method('forkserver', force=True)
# warnings.filterwarnings('ignore', category=DeprecationWarning, module='multiprocessing')
#
# os.environ['OMP_NUM_THREADS'] = '1'
# os.environ['OPENBLAS_NUM_THREADS'] = '1'
# os.environ['MKL_NUM_THREADS'] = '1'
def run_sccape(adata, condition_key, cell_type_key, exp_label, ref_label, dataset_name):
    """Run scCAPE (oNMF + CAPE training) on one dataset, plot its basal space and score it.

    Args:
        adata: input AnnData. It is copied, so the caller's object is not modified.
        condition_key: ``adata.obs`` column holding the condition. It is copied
            to ``obs['condition']``, the perturbation key passed to CAPE.
        cell_type_key: ``adata.obs`` column holding the cell type.
        exp_label: unused; kept so all runners share one signature.
        ref_label: unused; kept so all runners share one signature.
        dataset_name: stem used for the output names.

    Writes ``./scCAPE/tmp.h5ad`` (training input) and
    ``./scCAPE/<dataset_name>.h5ad``. It passes ``./scCAPE/<dataset_name>.csv``
    as the metrics ``savepath``. scCAPE's results are read back from
    ``z_<dataset_name>/CAPE/``.

    Raises:
        FileNotFoundError: if no basal embedding exists after training.
    """
    import warnings
    warnings.filterwarnings("ignore")
    warnings.filterwarnings("ignore", category=FutureWarning)

    out_dir = "./scCAPE"
    os.makedirs(out_dir, exist_ok=True)
    dataset_name1 = "z_" + dataset_name

    print(adata)
    # Copy so neither HVG selection nor the 'condition' column touches the caller's object.
    adata = adata.copy()
    if adata.shape[1] > 2000:
        sc.pp.highly_variable_genes(adata, n_top_genes=2000)
        adata = adata[:, adata.var["highly_variable"]].copy()

    adata.obs["condition"] = adata.obs[condition_key].copy()
    tmp_path = os.path.join(out_dir, "tmp.h5ad")
    adata.write(tmp_path)

    # oNMF is given the transpose of AnnData's cells x genes matrix.
    data_x = csr_matrix(adata.X.T)
    sccape.onmf(data=data_x, dataset_name=dataset_name1, ncells=2000, nfactors=list(range(5, 16)), nreps=2, niters=500)

    basal_path = os.path.join(dataset_name1, 'CAPE', 'model_index=0_basal.h5ad')
    train_error = None
    try:
        sccape.CAPE_train(data_path=tmp_path, dataset_name=dataset_name1, perturbation_key='condition', split_key=None,
                          max_epochs=300, lambda_adv=0.5, lambda_ort=0.5, patience=5, model_index=0, hparams=None,
                          verbose=True)
    except Exception as e:
        train_error = e
        print("Error occurred during scCAPE training:", str(e))

    if not os.path.exists(basal_path):
        raise FileNotFoundError(f"scCAPE basal embedding not found at {basal_path}") from train_error
    if train_error is not None:
        # Training failed but a file exists: it may be left over from an earlier run.
        print(f"WARNING: training raised an error; {basal_path} may be stale output from a previous run.")

    latent_adata = sc.read_h5ad(basal_path)
    latent_adata.obsm["latent"] = latent_adata.X
    latent_adata.obs = adata.obs.copy()

    # Build the graph on the evaluated embedding itself. With use_rep=None,
    # scanpy would first run PCA on X whenever it has more than 50 columns.
    sc.pp.neighbors(latent_adata, use_rep="latent")
    sc.tl.umap(latent_adata)
    sc.pl.umap(
        latent_adata,
        color=[condition_key, cell_type_key],
        wspace=0.4,
        frameon=False
    )
    latent_adata.write(os.path.join(out_dir, dataset_name + ".h5ad"))

    # R_HOME must be set before rpy2.robjects is first imported, because that
    # import initialises the embedded R.
    os.environ["R_HOME"] = "/home/xx244/.conda/envs/benchmark/lib/R"
    import rpy2.robjects as robjects  # noqa: F401 -- imported for its R-initialising side effect
    import anndata2ri
    anndata2ri.activate()

    from metrics import calculate_metrics

    metric_kwargs = dict(batch_key=condition_key, celltype_key=cell_type_key, all=True,
                         n_neighbors=15, embed='latent',
                         savepath=os.path.join(out_dir, dataset_name + ".csv"))
    try:
        calculate_metrics(latent_adata, **metric_kwargs)
    except Exception as err:
        # Retry once (presumably for transient R/rpy2 failures -- unconfirmed).
        # A second failure propagates. KeyboardInterrupt is no longer swallowed.
        print("calculate_metrics failed, retrying once:", err)
        calculate_metrics(latent_adata, **metric_kwargs)

adata = sc.read_h5ad("../data/ASD1.h5ad")
run_sccape(
    adata=adata,
    condition_key="perturb01",
    cell_type_key="CellType",
    exp_label="mutated",
    ref_label="nan",  # the reference group is the literal string 'nan' in perturb01
    dataset_name="ASD1",
)

In [None]:
def evaluate_ASD1(adata, embed,
                  batch_keys=("Perturbation", "perturb01", "Batch"),
                  celltype_key="CellType",
                  library_path="/gpfs/gibbs/project/wang_zuoheng/xx244/R/4.3/"):
    """Score one embedding of the ASD1 data against several 'batch' covariates.

    For each key in ``batch_keys``, the module-level ``calculate_metrics`` is
    called with that key as the batch to be mixed and ``celltype_key`` as the
    cell-type label (15 neighbours, all metrics).

    Args:
        adata: AnnData holding the embedding in ``adata.obsm[embed]``.
        embed: ``obsm`` key of the embedding to evaluate.
        batch_keys: ``adata.obs`` columns evaluated in turn as the batch.
        celltype_key: ``adata.obs`` column used as the cell type.
        library_path: R library directory prepended to ``.libPaths()``.
    """
    import warnings
    warnings.filterwarnings("ignore")
    warnings.filterwarnings("ignore", category=FutureWarning)

    # Prepare the R bridge once; prepending the same library path repeatedly is redundant.
    import rpy2.robjects as robjects
    import anndata2ri
    anndata2ri.activate()
    robjects.r(f'.libPaths(c("{library_path}", .libPaths()))')

    for batch_key in batch_keys:
        # Labels are derived from the keys actually used. The previous hand-written
        # messages mentioned "phase" although CellType is the label.
        print(f"Evaluating mixing of '{batch_key}' with '{celltype_key}' as cell type")
        calculate_metrics(adata, batch_key=batch_key, celltype_key=celltype_key, all=True, n_neighbors=15,
                          embed=embed)
        print("=" * 20)

model_names = ["cinema_ot", "mixscape", "scCAPE", "scgen"]
# obsm key under which each method's runner stored its embedding.
embed_names = ['cf', 'ef', 'latent', 'latent']

for model_name, embed_name in zip(model_names, embed_names):
    print(model_name)
    adata = sc.read_h5ad(f"./{model_name}/ASD1.h5ad")
    evaluate_ASD1(adata, embed_name)