In [1]:
import anndata as ad
import scanpy as sc
import gc
import sys
import cellanova as cnova
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sea

In [2]:
def get_mean_std(adata, batch_key):
    """Return the per-gene mean and standard deviation of ``adata.X``.

    Steps, in order:

    1. If the largest entry of ``adata.X`` exceeds 15 (presumably a sign that
       the matrix still holds raw counts -- TODO confirm the threshold), cells
       and genes are filtered and the matrix is normalised to 1e4 counts per
       cell and log1p-transformed.
    2. Highly variable genes are flagged per ``batch_key``; when more than
       3000 genes are present only the 3000 flagged ones are kept.
    3. ``X`` is converted to a dense array if it is not one already.

    The scanpy calls and the dense conversion are applied to the object that
    was passed in, so the caller's ``adata`` is modified.

    Parameters
    ----------
    adata : AnnData
        Expression data; ``adata.obs`` must contain ``batch_key``.
    batch_key : str
        obs column forwarded to ``sc.pp.highly_variable_genes``.

    Returns
    -------
    (mean, std) : tuple of arrays
        Column-wise statistics computed with ``keepdims=True``, i.e. one row
        and one column per retained gene.
    """
    exceeds_log_range = np.max(adata.X) > 15
    if exceeds_log_range:
        sc.pp.filter_cells(adata, min_genes=300)
        sc.pp.filter_genes(adata, min_cells=10)
        sc.pp.normalize_per_cell(adata, counts_per_cell_after=1e4)
        sc.pp.log1p(adata)

    # Flag at most 3000 genes; with fewer genes available every gene is requested.
    n_genes = adata.shape[1]
    sc.pp.highly_variable_genes(adata, n_top_genes=min(n_genes, 3000), batch_key=batch_key)
    if n_genes > 3000:
        hvg_flags = adata.var["highly_variable"]
        adata = adata[:, hvg_flags].copy()

    # Anything that is not already a NumPy array is densified before the statistics.
    if not isinstance(adata.X, np.ndarray):
        adata.X = adata.X.toarray()

    print(adata.X.shape)
    gene_mean = np.mean(adata.X, axis=0, keepdims=True)
    gene_std = np.std(adata.X, axis=0, keepdims=True)
    return gene_mean, gene_std

In [3]:
def _as_dense(matrix):
    """Return ``matrix`` unchanged if it is a NumPy array, else its ``toarray()`` form."""
    return matrix if isinstance(matrix, np.ndarray) else matrix.toarray()


def calculate_rowwise_correlation_scaled(adata1, adata2, std, batch_key="batch_all_with_condition"):
    """Compare two row-aligned expression matrices cell by cell.

    For every batch found in ``adata1.obs[batch_key]`` the matching rows of
    both objects are taken, each column is multiplied by ``std``, and for each
    row the Pearson correlation and the mean squared error between the two
    scaled rows are computed.

    Parameters
    ----------
    adata1, adata2 : AnnData-like
        Must expose ``.obs`` and support boolean row selection that yields
        ``.X`` and ``.obs_names``. Row masks are derived from ``adata1.obs``
        and applied to both objects, so both must hold the same cells in the
        same order.
    std : array-like
        Per-column scale factors, broadcastable against the rows of ``X``.
    batch_key : str
        obs column used to split the rows into batches.

    Returns
    -------
    pandas.DataFrame
        One row per cell with the columns ``"correlation"``, ``batch_key``,
        ``"barcode"`` (taken from ``adata1``) and ``"mse"``. Batches appear in
        the order returned by ``adata1.obs[batch_key].unique()``. A row whose
        scaled values are constant has an undefined correlation (NaN).

    Raises
    ------
    AssertionError
        If ``batch_key`` is missing from either ``obs`` or the two objects do
        not have the same number of rows.
    """
    assert batch_key in adata1.obs.columns, f"{batch_key} not found in adata1.obs"
    assert batch_key in adata2.obs.columns, f"{batch_key} not found in adata2.obs"
    assert len(adata1.obs) == len(adata2.obs), "adata1 and adata2 must contain the same number of cells"

    scale = np.asarray(std)
    columns = ["correlation", batch_key, "barcode", "mse"]
    frames = []

    for batch in adata1.obs[batch_key].unique():
        batch_mask = adata1.obs[batch_key] == batch
        subset1 = adata1[batch_mask]
        subset2 = adata2[batch_mask]

        # Densify BEFORE scaling: for scipy sparse matrices `*` means matrix
        # multiplication, so scaling a sparse X directly fails with a
        # dimension mismatch instead of scaling each column.
        data1 = _as_dense(subset1.X) * scale
        data2 = _as_dense(subset2.X) * scale

        # Row-wise Pearson correlation, accumulated in float64.
        centered1 = data1.astype(np.float64, copy=False)
        centered2 = data2.astype(np.float64, copy=False)
        centered1 = centered1 - centered1.mean(axis=1, keepdims=True)
        centered2 = centered2 - centered2.mean(axis=1, keepdims=True)
        cross = np.einsum("ij,ij->i", centered1, centered2)
        norm1 = np.sqrt(np.einsum("ij,ij->i", centered1, centered1))
        norm2 = np.sqrt(np.einsum("ij,ij->i", centered2, centered2))
        # Constant rows give 0/0; keep the resulting NaN but skip the warning.
        with np.errstate(divide="ignore", invalid="ignore"):
            correlation = cross / (norm1 * norm2)
        correlation = np.clip(correlation, -1.0, 1.0)

        mse = np.mean(np.square(data1 - data2), axis=1)

        frames.append(
            pd.DataFrame(
                {
                    "correlation": correlation,
                    batch_key: batch,
                    "barcode": subset1.obs_names.tolist(),
                    "mse": mse,
                },
                columns=columns,
            )
        )

    if not frames:
        return pd.DataFrame(columns=columns)
    return pd.concat(frames, ignore_index=True)

In [4]:
def _join_obs_columns(obs, keys, sep="__"):
    """Join the string form of the ``keys`` columns of ``obs`` row by row with ``sep``."""
    pieces = [obs[key].astype(str).to_numpy(dtype=object) for key in keys]
    if not pieces:
        return np.full(len(obs), "", dtype=str)
    labels = pieces[0]
    for piece in pieces[1:]:
        labels = labels + sep + piece
    return labels.astype(str)


def evaluate_cellanova_mse(adata, batch_key, condition_key, dataset_name, std):
    """Score how far the ``denoised`` layer deviates from ``adata.X`` and save the result.

    The function

    1. clears ``adata.raw``;
    2. adds two categorical obs columns to ``adata``: ``"batch_all"`` (the
       ``batch_key`` columns joined with ``"__"``) and
       ``"batch_all_with_condition"`` (the same with ``condition_key``
       appended);
    3. wraps ``adata.layers['denoised']`` (cast to float32) in a new AnnData
       that shares ``obs`` and ``var_names`` with ``adata``;
    4. calls ``calculate_rowwise_correlation_scaled`` on ``adata`` and that
       object, grouped by ``"batch_all_with_condition"``;
    5. writes the table to
       ``./cellanova/<dataset_name>_global_correlation_scaled.csv``.

    Warnings raised while this runs are suppressed; the caller's warning
    filters are restored on return.

    Parameters
    ----------
    adata : AnnData
        Must provide the ``'denoised'`` layer and the obs columns named by
        ``batch_key`` and ``condition_key``. Modified in place (see above).
    batch_key : str or list of str
        One or more obs columns identifying the batch. A list passed by the
        caller is not modified.
    condition_key : str
        obs column identifying the condition.
    dataset_name : str
        Prefix of the output CSV file name.
    std : array-like
        Per-gene scale factors forwarded to
        ``calculate_rowwise_correlation_scaled``.

    Returns
    -------
    None
    """
    import os
    import warnings

    # Work on a private list so the caller's argument is never mutated.
    batch_keys = [batch_key] if isinstance(batch_key, str) else list(batch_key)

    # Scope the suppression to this call instead of changing the global filters.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")

        adata.raw = None
        print("adata preprocessing...")

        batch_all = _join_obs_columns(adata.obs, batch_keys)
        adata.obs["batch_all"] = pd.Categorical(batch_all)
        print("batch_all", np.unique(batch_all))

        batch_all_with_condition = _join_obs_columns(adata.obs, batch_keys + [condition_key])
        adata.obs["batch_all_with_condition"] = pd.Categorical(batch_all_with_condition)
        print("batch_all_with_condition", np.unique(batch_all_with_condition))
        print("Finish preprocessing")

        # Cast the layer directly rather than passing the deprecated `dtype=`
        # argument to the AnnData constructor.
        denoised = adata.layers['denoised']
        integrated = ad.AnnData(denoised.astype(np.float32, copy=False))
        integrated.obs = adata.obs.copy()
        integrated.var_names = adata.var_names

        print("Calculating global distortion...")
        df_global_correlation = calculate_rowwise_correlation_scaled(
            adata, integrated, std, batch_key="batch_all_with_condition"
        )
        output_dir = "./cellanova/"
        os.makedirs(output_dir, exist_ok=True)
        df_global_correlation.to_csv(output_dir + dataset_name + "_global_correlation_scaled.csv")
        print("Finish")

# ECCITE dataset

In [5]:
# Per-gene mean/std are computed from the ECCITE file under ../data.
adata = sc.read_h5ad("../data/ECCITE.h5ad")
mean, std = get_mean_std(adata, batch_key='replicate')
print(mean, std)

# `adata` is then rebound to the ECCITE results file under ./cellanova,
# which is scored with the std values obtained above.
adata = sc.read_h5ad("./cellanova/" + "ECCITE" + "_results.h5ad")
evaluate_cellanova_mse(
    adata,
    batch_key='replicate',
    condition_key='perturbation',
    dataset_name="ECCITE",
    std=std,
)

(20729, 2000)
[[7.9479944e-03 7.7820746e-03 1.2972093e+00 ... 1.5478361e-04
  1.8493716e-04 2.7093486e-04]] [[0.0779786  0.09267601 0.85973036 ... 0.01117775 0.01213464 0.01424947]]
adata preprocessing...
batch_all ['rep1' 'rep2' 'rep3']
batch_all_with_condition ['rep1__NT' 'rep1__Perturbed' 'rep2__NT' 'rep2__Perturbed' 'rep3__NT'
 'rep3__Perturbed']
Finish preprocessing
Calculating global distortion...
Finish


# ASD dataset

In [6]:
# Per-gene mean/std are computed from the ASD1 file under ../data.
adata = sc.read_h5ad("../data/ASD1.h5ad")
mean, std = get_mean_std(adata, batch_key='Batch')
print(mean, std)

# `adata` is then rebound to the ASD1 results file under ./cellanova,
# which is scored with the std values obtained above.
adata = sc.read_h5ad("./cellanova/" + "ASD1" + "_results.h5ad")
evaluate_cellanova_mse(
    adata,
    batch_key='Batch',
    condition_key='perturb01',
    dataset_name="ASD1",
    std=std,
)

(40603, 2000)
[[0.01146966 0.00093528 0.06404175 ... 0.00046896 0.00179852 0.00077637]] [[0.23614296 0.06750284 0.58396368 ... 0.04894076 0.09167504 0.06467935]]
adata preprocessing...
batch_all ['1' '10' '11' '12' '13' '14' '15' '16' '17' '18' '2' '3' '4' '5' '6' '7'
 '8' '9']
batch_all_with_condition ['10__mutated' '10__nan' '11__mutated' '11__nan' '12__mutated' '12__nan'
 '13__mutated' '13__nan' '14__mutated' '14__nan' '15__mutated' '15__nan'
 '16__mutated' '16__nan' '17__mutated' '18__mutated' '18__nan'
 '1__mutated' '1__nan' '2__mutated' '2__nan' '3__mutated' '3__nan'
 '4__mutated' '4__nan' '5__mutated' '5__nan' '6__mutated' '6__nan'
 '7__mutated' '7__nan' '8__mutated' '8__nan' '9__mutated' '9__nan']
Finish preprocessing
Calculating global distortion...
Finish
