In [None]:
import os
os.chdir("/root/data/DBP_sa_bc/")
from os.path import join as pj
import argparse
import sys
sys.path.append("modules")
import utils
import numpy as np
import scib
import scib.metrics as me
import anndata as ad
import scipy
import pandas as pd
import re

In [None]:
# Run options. Declared through argparse so the same cell works both as a
# script (flags on the command line) and inside an interactive kernel, where
# the defaults below are used.
parser = argparse.ArgumentParser()
parser.add_argument('--task', type=str, default='lung_ts')
parser.add_argument('--experiment', type=str, default='e54')
parser.add_argument('--model', type=str, default='default')
parser.add_argument('--init_model', type=str, default='sp_00001899')
parser.add_argument('--method', type=str, default='DBP_sa_bc')
# Number of leading (re-ordered) latent dimensions kept for evaluation.
# The default is an int to match type=int (it used to be the string '41').
parser.add_argument('--K', type=int, default=41)
# parse_known_args ignores arguments it does not recognise instead of exiting,
# which is what makes this usable from an interactive session.
o, _ = parser.parse_known_args()  # for python interactive
# o = parser.parse_args()

In [None]:
# Resolve directories and merge the TOML configuration into the options object `o`.
K = o.K
# Directory of the model predictions; break_index.csv is read from here later on.
break_index_dir = pj("result", o.task, o.experiment, o.model, "predict", o.init_model)

# DBP_sa_bc-style methods keep one result directory per experiment / checkpoint;
# every other method has a single directory per task.
if "DBP_sa_bc" in o.method:
    result_dir = pj("result", "comparison", o.task, o.method, o.experiment, o.init_model)
else:
    result_dir = pj("result", "comparison", o.task, o.method)

# Task-name suffixes are stripped so that task variants share the data config
# entry of their base task.
cfg_task = re.sub("_atlas|_generalize|_transfer|_ref_.*", "", o.task)
data_config = utils.load_toml("configs/data.toml")[cfg_task]
vars(o).update(data_config)

# Parse model.toml once; a non-default model overrides keys of the "default" section.
model_toml = utils.load_toml("configs/model.toml")
model_config = model_toml["default"]
if o.model != "default":
    model_config.update(model_toml[o.model])
# Applied after the data config, so model settings win on duplicate keys.
vars(o).update(model_config)

# Only the first two return values are used; the remaining ones are discarded.
o.s_joint, o.combs, *_ = utils.gen_all_batch_ids(o.s_joint, o.combs)


In [None]:
# Load cell type labels
# Each supported task keeps its annotation in a different CSV file / column:
#   task -> (file name inside <raw_data_dir>/label/, index of the label column)
label_sources = {
    "hae": ("labels.csv", 13),
    "wnn_rna": ("meta.csv", 10),
    "lung_ts": ("meta.csv", 13),
}
# NOTE(review): the lookup uses o.task verbatim, while the data config is looked
# up with task suffixes stripped (cfg_task) - confirm whether suffixed task names
# should load labels as well.
if o.task not in label_sources:
    # Fail here rather than later with an undefined (or stale) `labels` variable.
    raise ValueError(o.task + ": no cell type label source configured for this task!")
label_file, label_col = label_sources[o.task]

labels = []
for raw_data_dir in o.raw_data_dirs:
    label = utils.load_csv(pj(raw_data_dir, "label", label_file))
    # Transpose rows into columns, pick the label column and skip its first
    # entry (presumably the header row - TODO confirm).
    labels += utils.transpose_list(label)[label_col][1:]
    # lung_ts alternative column (disabled):
    # labels += utils.transpose_list(label)[14][1:]

# Optional coarse-graining of the lung_ts labels (disabled):
# replacements = {'Alveolar_Type1':'Alveolar', 'Alveolar_Type2':'Alveolar', 
#                 'B_cell_mature':'B', 'B_cell_naive':'B', 
#                 'Basal':'Basal', 'Blood_vessel':'Blood_vessel', 
#                 'Ciliated':'Ciliated', 'DC_1':'DC',
#                 'DC_2':'DC', 'DC_Monocyte_Dividing':'DC_Monocyte_Dividing',
#                 'DC_activated':'DC_activated', 'DC_plasmacytoid':'DC_plasmacytoid', 
#                 'Fibroblast':'Fibroblast', 'Lymph_vessel':'Lymph_vessel',
#                 'Macrophage_Dividing':'Macrophag', 'Macrophage_MARCOneg':'Macrophag', 
#                 'Macrophage_MARCOpos':'Macrophag', 'Mast_cells':'Mast_cells', 
#                 'Monocyte':'Monocyte', 'Muscle_cells' :'Muscle_cells',
#                 'NK':'NK', 'NK_Dividing':'NK', 
#                 'Plasma_cells':'Plasma_cells', 'Secretory_club':'Secretory_club', 
#                 'T_CD4':'T_CD4', 'T_CD8_CytT':'T_CD8_CytT', 
#                 'T_cells_Dividing' :'T', 'T_regulatory':'T'}
# replaced_list = [replacements[value] if value in replacements else value for value in labels]
labels = np.array(labels)
print(np.unique(labels))

In [None]:
# Read the model's predicted latent variables.
# utils.load_predicted receives the whole options object, so the prediction
# directory and the modality list are set on `o` before the call.
pred_dir_parts = ("result", o.task, o.experiment, o.model, "predict", o.init_model)
o.pred_dir = pj(*pred_dir_parts)
o.mods = ["rna"]
pred = utils.load_predicted(o)

In [None]:
# Classify the evaluated method by the kind of output it produces:
#   "embed" - a low-dimensional embedding,
#   "graph" - a cell-cell connectivity graph.
embed_methods = ["DBP_sa_bc", "mofa", "scmomat", "stabmap", "scvaeit"]
graph_methods = [
    "midas_feat+wnn", 
    "harmony+wnn", 
    "pca+wnn",
    "seurat_cca+wnn",
    "seurat_rpca+wnn",
    "scanorama_embed+wnn",
    "scanorama_feat+wnn",
    "liger+wnn",
    "bbknn",
    ]
if o.method in embed_methods:
    output_type = "embed"
elif o.method in graph_methods:
    output_type = "graph"
else:
    # Explicit exception instead of `assert False`: assertions are removed when
    # Python runs with -O, which would let an unknown method slip through.
    raise ValueError(o.method+": invalid method!")

In [None]:
# Evaluation settings shared by the cells below.

# Names of the AnnData slots used for the embedding and the annotations
batch_key, label_key, cluster_key = "batch", "label", "cluster"
embed = "X_emb"

# Distance metric name for silhouette-type scores
si_metric = "euclidean"

# Fraction of cells to subsample; multiplied by 100 where a percentage is passed on
subsample = 0.5

# Verbosity flag forwarded to the metric functions
verbose = False

In [None]:
# Assemble the latent matrix that gets evaluated.
z_joint = pred["z"]["joint"]
w = pred["w"]["joint"]
s = pred["s"]["joint"]

# First o.dim_c columns of the joint latent z, scaled element-wise by w
c = z_joint[:, :o.dim_c] * w

# Column ordering read from break_index.csv
break_index_path = pj(break_index_dir, "break_index.csv")
index = np.loadtxt(break_index_path, dtype=int, delimiter=",")

# Re-order the columns by `index`, then keep only the first K of them
c_ord = c[:, index]
c_bre = c_ord[:, :K]

In [None]:
# Build the AnnData object that the metrics are computed on.
# The branches differ only in where the embedding / graph comes from; the
# batch and label annotations are identical and are attached once below.
graph_output_methods = [
    "midas_feat+wnn", 
    "harmony+wnn", 
    "pca+wnn",
    "seurat_cca+wnn",
    "seurat_rpca+wnn",
    "scanorama_embed+wnn",
    "scanorama_feat+wnn",
    "liger+wnn",
    "bbknn",
    ]

if o.method == "DBP_sa_bc":
    # Own method: the truncated latent matrix is both X and the embedding.
    adata = ad.AnnData(c_bre)
    adata.obsm[embed] = c_bre
elif o.method in ["mofa", "stabmap"]:
    # X is an all-zero placeholder with the shape of c.
    adata = ad.AnnData(c*0)
    embeddings = utils.load_csv(pj(result_dir, "embeddings.csv"))
    # [1:, 1:] drops the first row and first column of the CSV
    # (presumably header and index - TODO confirm).
    adata.obsm[embed] = np.array(embeddings)[1:, 1:].astype(np.float32)
elif o.method in ["scmomat", "scvaeit"]:
    adata = ad.AnnData(c*0)
    embeddings = utils.load_csv(pj(result_dir, "embeddings.csv"))
    # These files are used as-is: every row and column is treated as numeric.
    adata.obsm[embed] = np.array(embeddings).astype(np.float32)
elif o.method in graph_output_methods:
    # `import scipy` alone does not guarantee that the io submodule is loaded.
    import scipy.io
    adata = ad.AnnData(c*0)
    # Graph-output methods provide a connectivity matrix instead of an
    # embedding, so adata.obsm[embed] is not set in this branch.
    adata.obsp["connectivities"] = scipy.io.mmread(pj(result_dir, "connectivities.mtx")).tocsr()
    adata.uns["neighbors"] = {'connectivities_key': 'connectivities'}
else:
    # Without this, an unmatched method would leave `adata` undefined or, in a
    # live kernel, silently reuse the object from a previous run.
    raise ValueError(o.method+": invalid method!")

# Batch and cell-type annotations, stored as categoricals.
adata.obs[batch_key] = s.astype(str)
adata.obs[batch_key] = adata.obs[batch_key].astype("category")
adata.obs[label_key] = labels
adata.obs[label_key] = adata.obs[label_key].astype("category")

######

#####

In [None]:
# Compute the batch-correction and biological-conservation metrics, print each
# one as soon as it is available, and save them as a one-row Excel sheet.
results = {}


def _report(name, value):
    """Store `value` under `name` in `results` and print it as "<name>: <value>"."""
    results[name] = value
    print(name + ": " + str(value))


# Silhouette-based metrics read adata.obsm[embed]. Graph-output methods never
# set that slot, so those metrics are reported as NaN for them instead of
# aborting the whole cell.
has_embed = embed in adata.obsm.keys()

# Clustering comes first: it writes the cluster assignment to
# adata.obs[cluster_key], which NMI and ARI read below.
print('clustering...')
res_max, nmi_max, nmi_all = scib.clustering.opt_louvain(adata, label_key=label_key,
    cluster_key=cluster_key, function=me.nmi, use_rep=embed, verbose=verbose, inplace=True)

_report('graph_conn', me.graph_connectivity(adata, label_key=label_key))

_report('batch_ASW', me.silhouette_batch(adata, label_key=label_key, batch_key=batch_key,
    embed=embed, verbose=verbose) if has_embed else np.nan)

type_ = "knn" if output_type == "graph" else None
_report('kBET', me.kBET(adata, batch_key=batch_key, label_key=label_key, embed=embed, 
    type_=type_, verbose=verbose))

# results['iLISI'] = me.ilisi_graph(adata, batch_key=batch_key, type_="knn",
#     subsample=subsample*100, n_cores=1, verbose=verbose)
# print("iLISI: " + str(results['iLISI']))
# `subsample` is a fraction; it is passed on multiplied by 100.
_report('iLISI', me.ilisi_graph(adata, batch_key=batch_key, type_="knn",
    subsample=subsample*100, verbose=verbose))

_report('label_ASW', me.silhouette(adata, label_key=label_key, embed=embed)
    if has_embed else np.nan)

_report('il_score_ASW', me.isolated_labels(adata, label_key=label_key, batch_key=batch_key,
    embed=embed, cluster=False, verbose=verbose) if has_embed else np.nan)

_report('il_score_f1', me.isolated_labels(adata, label_key=label_key, batch_key=batch_key,
    embed=embed, cluster=True, verbose=verbose))

_report('NMI', me.nmi(adata, group1=cluster_key, group2=label_key, method='arithmetic'))

_report('ARI', me.ari(adata, group1=cluster_key, group2=label_key))

_report('cLISI', me.clisi_graph(adata, batch_key=batch_key, label_key=label_key, type_="knn",
    subsample=subsample*100,  verbose=verbose))

# Aggregate scores (disabled):
# results = {k: float(v) for k, v in results.items()}
# results['batch_score'] = np.nanmean([results['graph_conn'], results['batch_ASW'], results['kBET'], results['iLISI']])
# results['bio_score'] = np.nanmean([results['label_ASW'], results['il_score_ASW'], results['il_score_f1'], 
#                                    results['NMI'], results['ARI'], results['cLISI']])
# results["overall_score"] = float(0.4 * results['batch_score'] + 0.6 * results['bio_score'])

# Column order of the saved sheet: batch-correction metrics, then biological ones.
metric_columns = [
    'graph_conn',
    'batch_ASW',
    'kBET',
    'iLISI',
    # 'batch_score',
    'label_ASW',
    'il_score_ASW',
    'il_score_f1',
    'NMI',
    'ARI',
    'cLISI',
    # 'bio_score',
    # 'overall_score',
]
df = pd.DataFrame({name: [results[name]] for name in metric_columns})
print(df)
utils.mkdirs(result_dir, remove_old=False)
# One file per K, so runs with different numbers of kept dimensions do not overwrite each other.
df.to_excel(pj(result_dir, "metrics_batch_bio_break"+str(o.K)+".xlsx"), index=False)

In [None]:
# Disabled variant of the metrics cell, kept for reference only. It computes
# NMI, ARI, kBET, il_score_f1, graph_conn, cLISI and iLISI, adds the aggregate
# batch / bio / overall scores, and writes to "metrics_batch_bio_break_k150.xlsx".
# Nothing in this cell is executed.

# results = {}

# print('clustering...')
# res_max, nmi_max, nmi_all = scib.clustering.opt_louvain(adata, label_key=label_key,
#     cluster_key=cluster_key, function=me.nmi, use_rep=embed, verbose=verbose, inplace=True)

# results['NMI'] = me.nmi(adata, group1=cluster_key, group2=label_key, method='arithmetic')
# print("NMI: " + str(results['NMI']))

# results['ARI'] = me.ari(adata, group1=cluster_key, group2=label_key)
# print("ARI: " + str(results['ARI']))

# type_ = "knn" if output_type == "graph" else None
# results['kBET'] = me.kBET(adata, batch_key=batch_key, label_key=label_key, embed=embed, 
#     type_=type_, verbose=verbose)
# print("kBET: " + str(results['kBET']))

# results['il_score_f1'] = me.isolated_labels(adata, label_key=label_key, batch_key=batch_key,
#     embed=embed, cluster=True, verbose=verbose)
# print("il_score_f1: " + str(results['il_score_f1']))

# results['graph_conn'] = me.graph_connectivity(adata, label_key=label_key)
# print("graph_conn: " + str(results['graph_conn']))

# # results['cLISI'] = me.clisi_graph(adata, batch_key=batch_key, label_key=label_key, type_="knn",
# #     subsample=subsample*100, n_cores=1, verbose=verbose)
# # print("cLISI: " + str(results['cLISI']))

# results['cLISI'] = me.clisi_graph(adata, batch_key=batch_key, label_key=label_key, type_="knn",
#     subsample=subsample*100,  verbose=verbose)
# print("cLISI: " + str(results['cLISI']))

# # results['iLISI'] = me.ilisi_graph(adata, batch_key=batch_key, type_="knn",
# #     subsample=subsample*100, n_cores=1, verbose=verbose)
# # print("iLISI: " + str(results['iLISI']))
# results['iLISI'] = me.ilisi_graph(adata, batch_key=batch_key, type_="knn",
#     subsample=subsample*100, verbose=verbose)
# print("iLISI: " + str(results['iLISI']))

# results = {k: float(v) for k, v in results.items()}
# results['batch_score'] = np.nanmean([results['iLISI'], results['graph_conn'], results['kBET']])
# results['bio_score'] = np.nanmean([results['NMI'], results['ARI'], results['il_score_f1'], results['cLISI']])
# results["overall_score"] = float(0.4 * results['batch_score'] + 0.6 * results['bio_score'])

# df = pd.DataFrame({
#     'iLISI':          [results['iLISI']],
#     'graph_conn':     [results['graph_conn']],
#     'kBET':           [results['kBET']],
#     'batch_score':    [results['batch_score']],
#     'NMI':            [results['NMI']],
#     'ARI':            [results['ARI']],
#     'il_score_f1':    [results['il_score_f1']],
#     'cLISI':          [results['cLISI']],
#     'bio_score':      [results['bio_score']],
#     'overall_score':  [results['overall_score']]
# })
# print(df)
# utils.mkdirs(result_dir, remove_old=False)
# df.to_excel(pj(result_dir, "metrics_batch_bio_break_k150.xlsx"), index=False)