In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import scanpy as sc
import anndata
from sklearn.decomposition import PCA
from sklearn.metrics import adjusted_rand_score, adjusted_mutual_info_score
from tqdm import trange

In [None]:
# For every integration method (plus the uncorrected "raw" data): build a kNN
# graph on adata.X, embed it with UMAP, plot the embedding coloured by batch
# (row 0) and by cell type (row 1), then score how well Leiden clusters agree
# with the "celltype" annotation via ARI / AMI.
methods = ("fastscbatch", "scbatch", "combat", "mnn", "limma", "scdml", "raw")
ARI = {method: 0 for method in methods}
AMI = {method: 0 for method in methods}
# Derive the grid size from `methods` so adding/removing a method needs no
# other edits (the layout was previously hard-coded to 7 columns).
n_methods = len(methods)

fig, ax = plt.subplots(2, n_methods, figsize=(3 * n_methods, 7))
bat = None

for col, method in enumerate(methods):
	if method == "raw":
		adata = sc.read_h5ad("./data.h5ad")
	else:
		adata = sc.read_h5ad(f"../../method/{method}/results/{method}_usoskin.h5ad")
	# Reuse the batch labels from the first file loaded for every method so all
	# panels share one batch annotation. The assignment aligns on obs_names.
	# NOTE(review): presumably every file has the same obs_names; cells absent
	# from the first file would receive NaN batch labels — confirm.
	if bat is None:
		bat = adata.obs["batch"]
	adata.obs["batch"] = bat
	sc.pp.neighbors(adata, n_neighbors=25, use_rep="X", random_state=0)
	sc.tl.umap(adata)
	sc.pl.umap(adata, color="batch", ax=ax[0, col], show=False)
	sc.pl.umap(adata, color="celltype", ax=ax[1, col], show=False)
	sc.tl.leiden(adata)
	ARI[method] = adjusted_rand_score(adata.obs["leiden"], adata.obs["celltype"])
	AMI[method] = adjusted_mutual_info_score(adata.obs["leiden"], adata.obs["celltype"])

# Tidy the grid: keep a single legend in the right-most column, drop the
# per-panel axis labels, and use the method name as a column header on row 0.
last_col = n_methods - 1
for i in range(2):
	for j in range(n_methods):
		panel = ax[i, j]
		legend = panel.get_legend()
		# get_legend() returns None when no legend was drawn; guard so the
		# cleanup cannot crash with AttributeError.
		if j != last_col and legend is not None:
			legend.remove()
		panel.set_xlabel("")
		panel.set_ylabel("")
		panel.set_title(methods[j] if i == 0 else "")
ax[0, 0].set_ylabel("Batch")
ax[1, 0].set_ylabel("Celltype")
plt.tight_layout()
plt.savefig("../../figures/usoskin_umap.png")

In [None]:
# Compare clustering agreement (ARI / AMI) across the integration methods only.
# The uncorrected "raw" baseline is dropped with pop(..., None) rather than
# `del`, so re-running this cell no longer raises KeyError; the resulting
# ARI/AMI contents are the same as before.
ARI.pop("raw", None)
AMI.pop("raw", None)
fig, ax = plt.subplots(1, 2, figsize=(10, 5))
for panel, (title, scores) in zip(ax, (("ARI", ARI), ("AMI", AMI))):
	sns.barplot(x=list(scores.keys()), y=list(scores.values()), ax=panel, palette="viridis")
	panel.set_title(title)
plt.tight_layout()
plt.savefig("../../figures/usoskin_metrics.png")