In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import scanpy as sc
import anndata
from sklearn.decomposition import PCA
from sklearn.metrics import adjusted_rand_score, adjusted_mutual_info_score
from tqdm import trange

In [None]:
# Method identifiers as they appear in result paths, and the matching display
# names used for figure tick labels (same order).
methods = ("fastscbatch", "scbatch", "combat", "mnn", "limma", "scdml", "raw")
methods_alias = ("fast-scBatch", "scBatch", "ComBat", "MNN", "Limma", "scDML", "Raw")

# n = 960
# One ARI / AMI score per method for each of datasets 1-10.
ARI_960 = {method: np.zeros(10) for method in methods}
AMI_960 = {method: np.zeros(10) for method in methods}
for i in trange(1, 11):
	for method in methods:
		# "raw" is read from the input data directory; every other method is read
		# from that method's own results directory.
		if method == "raw":
			adata = sc.read_h5ad(f"./data/data{i}.h5ad")
		else:
			adata = sc.read_h5ad(f"../method/{method}/results/{method}_simu{i}.h5ad")
		# NOTE(review): the neighbour graph is built with use_rep="X", so the PCA
		# computed here looks unused downstream -- confirm "X" (not "X_pca") is intended.
		sc.tl.pca(adata)
		sc.pp.neighbors(adata, n_neighbors=20, use_rep="X")
		sc.tl.leiden(adata, resolution=0.5)

		# ARI and AMI compare partitions only, so any consistent integer encoding
		# works. Encode the full label value: looking at just the final character
		# would merge e.g. clusters "1" and "11" (or "Group1" and "Group11").
		leiden_labels = pd.factorize(adata.obs["leiden"])[0]
		true_labels = pd.factorize(adata.obs["Group"])[0]

		ARI_960[method][i-1] = adjusted_rand_score(true_labels, leiden_labels)
		AMI_960[method][i-1] = adjusted_mutual_info_score(true_labels, leiden_labels)

In [None]:
# n = 600
# One ARI / AMI score per method for each of datasets 11-20.
ARI_600 = {method: np.zeros(10) for method in methods}
AMI_600 = {method: np.zeros(10) for method in methods}

for i in trange(11, 21):
	for method in methods:
		# "raw" is read from the input data directory; every other method is read
		# from that method's own results directory.
		if method == "raw":
			adata = sc.read_h5ad(f"./data/data{i}.h5ad")
		else:
			adata = sc.read_h5ad(f"../method/{method}/results/{method}_simu{i}.h5ad")
		# NOTE(review): the neighbour graph is built with use_rep="X", so the PCA
		# computed here looks unused downstream -- confirm "X" (not "X_pca") is intended.
		sc.tl.pca(adata)
		sc.pp.neighbors(adata, n_neighbors=20, use_rep="X")
		sc.tl.leiden(adata, resolution=0.5)

		# ARI and AMI compare partitions only, so any consistent integer encoding
		# works. Encode the full label value: looking at just the final character
		# would merge e.g. clusters "1" and "11" (or "Group1" and "Group11").
		leiden_labels = pd.factorize(adata.obs["leiden"])[0]
		true_labels = pd.factorize(adata.obs["Group"])[0]

		# Dataset ids start at 11 in this cell; shift to a 0-based slot.
		ARI_600[method][i-11] = adjusted_rand_score(true_labels, leiden_labels)
		AMI_600[method][i-11] = adjusted_mutual_info_score(true_labels, leiden_labels)

In [None]:
# n = 360
# One ARI / AMI score per method for each of datasets 21-30.
ARI_360 = {method: np.zeros(10) for method in methods}
AMI_360 = {method: np.zeros(10) for method in methods}

for i in trange(21, 31):
	for method in methods:
		# "raw" is read from the input data directory; every other method is read
		# from that method's own results directory.
		if method == "raw":
			adata = sc.read_h5ad(f"./data/data{i}.h5ad")
		else:
			adata = sc.read_h5ad(f"../method/{method}/results/{method}_simu{i}.h5ad")
		# NOTE(review): the neighbour graph is built with use_rep="X", so the PCA
		# computed here looks unused downstream -- confirm "X" (not "X_pca") is intended.
		sc.tl.pca(adata)
		sc.pp.neighbors(adata, n_neighbors=20, use_rep="X")
		sc.tl.leiden(adata, resolution=0.5)

		# ARI and AMI compare partitions only, so any consistent integer encoding
		# works. Encode the full label value: looking at just the final character
		# would merge e.g. clusters "1" and "11" (or "Group1" and "Group11").
		leiden_labels = pd.factorize(adata.obs["leiden"])[0]
		true_labels = pd.factorize(adata.obs["Group"])[0]

		# Dataset ids start at 21 in this cell; shift to a 0-based slot.
		ARI_360[method][i-21] = adjusted_rand_score(true_labels, leiden_labels)
		AMI_360[method][i-21] = adjusted_mutual_info_score(true_labels, leiden_labels)

In [None]:
# 2 x 3 grid of box plots: top row ARI, bottom row AMI; one column per
# dataset size, one box per method within each panel.
fig, axes = plt.subplots(2, 3, figsize=(15, 10))

# (ARI scores, AMI scores, column title), in left-to-right column order.
panels = (
	(ARI_360, AMI_360, "n=360"),
	(ARI_600, AMI_600, "n=600"),
	(ARI_960, AMI_960, "n=960"),
)

for col, (ari_scores, ami_scores, title) in enumerate(panels):
	for row, scores in enumerate((ari_scores, ami_scores)):
		ax = axes[row, col]
		# Each method is drawn by its own call and placed at its index on the x axis.
		for pos, method in enumerate(methods):
			sns.boxplot(scores[method], ax=ax, positions=[pos], width=0.6)
		# Decorate the axis once, after every box has been drawn.
		ax.set_xticks(range(len(methods_alias)))
		ax.set_xticklabels(methods_alias, rotation=30, fontsize=14)
		# Pin the x-range explicitly so the first and last boxes are fully visible
		# regardless of what limits the individual boxplot calls left behind.
		ax.set_xlim(-0.5, len(methods_alias) - 0.5)
	# Only the top row carries the dataset-size title.
	axes[0, col].set_title(title, fontsize=15)

# Only the left-most column carries the metric name.
axes[0, 0].set_ylabel("ARI", fontsize=15)
axes[1, 0].set_ylabel("AMI", fontsize=15)

plt.tight_layout()
plt.savefig("../figures/simu_boxplot.png")

In [None]:
# UMAP panels for a single dataset: one column per method, top row coloured
# by "Batch", bottom row coloured by "Group".
fig, axes = plt.subplots(2, len(methods), figsize=(15, 5))

# Index of the dataset that is visualised.
i = 7
# Column headings, in the same order as `methods`.
column_titles = ("fast_scBatch", "scBatch", "ComBat", "MNN", "limma", "scDML", "Raw (log count)")
last_col = len(methods) - 1

for col, method in enumerate(methods):
	# "raw" is read from the input data directory; every other method is read
	# from that method's own results directory.
	if method == "raw":
		adata = sc.read_h5ad(f"./data/data{i}.h5ad")
	else:
		adata = sc.read_h5ad(f"../method/{method}/results/{method}_simu{i}.h5ad")
	# NOTE(review): neighbours are built with use_rep="X", and the Leiden labels
	# are never plotted in this cell, so the PCA and Leiden steps look unused
	# here -- confirm before removing them.
	sc.tl.pca(adata)
	sc.pp.neighbors(adata, n_neighbors=20, use_rep="X")
	sc.tl.leiden(adata, resolution=0.5)
	sc.tl.umap(adata)

	batch_ax, group_ax = axes[0][col], axes[1][col]
	sc.pl.umap(adata, show=False, color=["Batch"], ax=batch_ax, size=50, title="")
	sc.pl.umap(adata, show=False, color=["Group"], ax=group_ax, size=50, title="")

	# Keep the legends on the last column only; guard against a panel that was
	# drawn without a legend.
	if col != last_col:
		for ax in (batch_ax, group_ax):
			legend = ax.get_legend()
			if legend is not None:
				legend.remove()

axes[0][0].set_ylabel("colored by batch")
axes[1][0].set_ylabel("colored by cell type")
for ax, title in zip(axes[0], column_titles):
	ax.set_title(title)

fig.tight_layout()
plt.savefig("../figures/simu_umap.png")