In [None]:
import collections
import os
import string
import time
from pathlib import Path

import brokenaxes
import grae
import matplotlib
import matplotlib.patches as mpatches
import numpy as np
import pandas as pd
import scanpy as sc
import scipy.sparse
import seaborn as sns
import sklearn
import sklearn.metrics
import torch
import uncertainties
from brokenaxes import brokenaxes as brokax
from grae.models import GRAE
from matplotlib import pyplot as plt
from matplotlib.gridspec import GridSpec

import Models as mod
import Utils as ut

%matplotlib inline

# Method names in the order their colours are assigned, and the order in which
# they are shown in figure legends.
AEs = ["scVI", "TopoAE", "VAE", "PeakVI", "GRAE", "pca"]
AEs_order = ["GRAE", "TopoAE", "VAE", "PeakVI", "scVI", "pca"]

# Colours: the first five entries of ut.colors_to_use_pastel plus the entry at index 8.
new_colors = list(ut.colors_to_use_pastel[:5])
new_colors.append(ut.colors_to_use_pastel[8])

# Method -> colour lookup, re-keyed so that iteration follows AEs_order.
_color_by_method = dict(zip(AEs, new_colors))
palette = collections.OrderedDict()
for _method in AEs_order:
    palette[_method] = _color_by_method[_method]

Replace the folder paths with the appropriate ones for your machine.

# Robustness

## CMs and Jobs

In [None]:
# For every dataset / feature space, read the QC count matrix and store ten
# randomly corrupted copies of it per dropout level (0%, 5%, ..., 50%) as
# sparse layers named "X_<dropout>_<run>", then write the result next to the
# input as "<dataset>_<featurespace>_Dropout.h5ad".
datasets=["10XhsBrain3kMO", "10XhsBrain3kMO","Kidney", "10XhsPBMC10kMO","10XhsPBMC10kMO", "MouseBrain"]
featurespaces=["Peak","GEX","Peak", "Peak", "GEX", "Peak"]
jobs=["BrP", "BrG", "KiP", "PbP", "PbG", "MbP"]
for dataset, featurespace, job in zip(datasets, featurespaces, jobs):
    print(dataset, featurespace)
    cm_dir = f"Datasets/{dataset}/FeatureSpaces/{featurespace}/CM"
    adata = sc.read_h5ad(f"{cm_dir}/{dataset}_{featurespace}_QC.h5ad")
    # The dense input does not change between dropout levels or runs: build it once.
    inp = torch.tensor(adata.X.toarray())
    for d in np.linspace(0, 50, 11).astype(int):
        print(d)
        # A freshly constructed torch.nn.Dropout is in training mode, so each call
        # draws a new random mask; per the PyTorch docs the surviving entries are
        # also rescaled by 1/(1-p).
        # NOTE(review): no random seed is set, so the layers differ between executions.
        m = torch.nn.Dropout(p=d/100)
        for run in range(0, 10):
            output = m(inp)
            adata.layers[f"X_{d}_{run}"] = scipy.sparse.csr_matrix(output.numpy())
    adata.write(f"{cm_dir}/{dataset}_{featurespace}_Dropout.h5ad", compression="gzip")

In [None]:
# Submit one Slurm job per dataset / feature space through Run_Exp.sh
# (arguments: experiment "Robustness", autoencoder "VAE", dataset, feature space).
datasets=["10XhsBrain3kMO", "10XhsBrain3kMO","Kidney", "10XhsPBMC10kMO","10XhsPBMC10kMO", "MouseBrain"]
featurespaces=["Peak","GEX","Peak", "Peak", "GEX", "Peak"]
jobs=["BrP", "BrG", "KiP", "PbP", "PbG", "MbP"]
for dataset, featurespace, job in zip(datasets, featurespaces, jobs):
    print(dataset, featurespace, job)
    # The original cell also picked a hyper-parameter table path per feature space,
    # but never passed it to the command; that dead assignment has been dropped.
    exit_status = os.system(f"sbatch Run_Exp.sh Robustness VAE {dataset} {featurespace}")
    if exit_status != 0:
        # os.system does not raise on failure, so report it instead of ignoring it.
        print(f"sbatch submission failed for {dataset} {featurespace} (exit status {exit_status})")

## Metrics

In [None]:
# Reconstruction error of models trained on dropout-corrupted data.
#
# For every dataset / feature space, run (0-9) and dropout level (0, 10, ..., 50)
# each saved checkpoint that exists on disk is loaded, the clean QC matrix is
# passed through it, and the mean squared error between the clean matrix and its
# reconstruction is recorded. Results are written to Tables/Robustness.tsv.gz.
datasets=["10XhsBrain3kMO", "10XhsBrain3kMO","Kidney", "10XhsPBMC10kMO","10XhsPBMC10kMO", "MouseBrain"]
featurespaces=["Peak","GEX","Peak", "Peak", "GEX", "Peak"]
jobs=["BrP", "BrG", "KiP", "PbP", "PbG", "MbP"]
result_columns=["Dataset","FS","AE","Dropout","MSE","Run"]


def log_progress(run, dp, AE):
    """Print which run / dropout level / autoencoder is being evaluated, with a timestamp."""
    print(f"Run {run}/10 and dropout {dp}", time.strftime("%a, %d %b %Y %H:%M:%S", time.localtime()), "AE is --> ", AE,  flush=True)


def make_record(dataset, fs, AE, dp, mse, run):
    """Build one result row; native int/float keep the numeric columns numeric."""
    return dict(zip(result_columns, [dataset, fs, AE, int(dp), float(mse), run]))


# Rows are collected in a list and turned into a DataFrame once at the end:
# concatenating inside the loop is quadratic, and packing mixed values into a
# NumPy array (as before) silently converted MSE and Dropout to strings.
records=[]
for dataset, fs, job in zip(datasets, featurespaces, jobs):
    adata=sc.read_h5ad(f"Datasets/{dataset}/FeatureSpaces/{fs}/CM/{dataset}_{fs}_QC.h5ad")
    # Dense copies of the clean matrix are shared by every model evaluated below.
    x_dense=adata.X.toarray()
    x_tensor=torch.tensor(x_dense, dtype=torch.float32)
    n_features=adata.shape[1]
    # Layer sizes are derived from the feature count; presumably these must match
    # the sizes used when the checkpoints were trained, so the formula is kept as is.
    ae_kwargs={"input_dim" : n_features, "hidden_dim" : int(n_features**(1/2)), "latent_dim" : int(n_features**(1/3))}
    dropout_dir=f"Datasets/{dataset}/FeatureSpaces/{fs}/Dropout"

    for run in range(0, 10):
        for dp in np.linspace(0, 50, 6).astype(int):
            AE="VAE"
            model_name=f"{dropout_dir}/VAE/{dataset}_{fs}_VAE_{dp}_{run}.pth"
            if os.path.isfile(model_name):
                log_progress(run, dp, AE)
                model = mod.VAutoencoder(ae_kwargs=dict(ae_kwargs))
                # map_location keeps loading working on CPU-only machines; the
                # reconstruction is converted to NumPy on the CPU anyway.
                model.load_state_dict(torch.load(model_name, map_location="cpu"))
                with torch.no_grad():
                    # encode() returns a sequence; its first element is what gets decoded.
                    reconstruction = model.decode(model.encode(x_tensor)[0]).numpy()
                mse=sklearn.metrics.mean_squared_error(x_dense, reconstruction)
                records.append(make_record(dataset, fs, AE, dp, mse, run))

            AE="TopoAE"
            model_name=f"{dropout_dir}/TopoAE/{dataset}_{fs}_TopoAE_{dp}_{run}.pth"
            if os.path.isfile(model_name):
                log_progress(run, dp, AE)
                model = mod.TopologicallyRegularizedAutoencoder(ae_kwargs=dict(ae_kwargs))
                model.load_state_dict(torch.load(model_name, map_location="cpu"))
                with torch.no_grad():
                    reconstruction = model.decode(model.encode(x_tensor)).numpy()
                mse=sklearn.metrics.mean_squared_error(x_dense, reconstruction)
                records.append(make_record(dataset, fs, AE, dp, mse, run))

            AE="GRAE"
            model_name=f"{dropout_dir}/GRAE/{dataset}_{fs}_GRAE_{dp}_{run}.pth"
            if os.path.isfile(model_name):
                log_progress(run, dp, AE)
                model = GRAE(n_components=ae_kwargs["latent_dim"])
                model.load(model_name)
                # GRAE works on its own dataset wrapper; targets and labels are dummy ones.
                grae_dataset=grae.data.base_dataset.BaseDataset(x_dense, np.ones(adata.shape[0]), "none", 0.85, 42, np.ones(adata.X.shape[0]))
                mse=sklearn.metrics.mean_squared_error(x_dense, model.inverse_transform(model.transform(grae_dataset)))
                records.append(make_record(dataset, fs, AE, dp, mse, run))

            # NOTE(review): `scvi` is used below but is not imported in the visible
            # import cell; add `import scvi` there before running this cell.
            if fs == "GEX":
                AE="scVI"
                path=f"{dropout_dir}/scVI/{dataset}_{fs}_scVI_{dp}_{run}"
                model_name=f"{path}/model.pt"
                if os.path.isfile(model_name):
                    log_progress(run, dp, AE)
                    scvi.model.LinearSCVI.setup_anndata(adata=adata)
                    # load() is a classmethod that RETURNS the restored model. The
                    # previous code called it on a fresh instance and dropped the
                    # result, so an untrained model was being scored.
                    model = scvi.model.LinearSCVI.load(path, adata=adata)
                    # Kept from the original; the restored model should already be flagged as trained.
                    model.is_trained=True
                    mse=sklearn.metrics.mean_squared_error(x_dense, model.get_normalized_expression(n_samples=1))
                    records.append(make_record(dataset, fs, AE, dp, mse, run))
            else:
                AE="PeakVI"
                path=f"{dropout_dir}/PeakVI/{dataset}_{fs}_PeakVI_{dp}_{run}"
                model_name=f"{path}/model.pt"
                if os.path.isfile(model_name):
                    log_progress(run, dp, AE)
                    scvi.model.PEAKVI.setup_anndata(adata=adata)
                    # Same fix as for scVI: use the model returned by load().
                    model = scvi.model.PEAKVI.load(path, adata=adata)
                    model.is_trained=True
                    mse=sklearn.metrics.mean_squared_error(x_dense, model.get_accessibility_estimates())
                    records.append(make_record(dataset, fs, AE, dp, mse, run))

df=pd.DataFrame(records, columns=result_columns)
df.to_csv("Tables/Robustness.tsv.gz", sep="\t", compression="gzip")

In [None]:
# Robustness figure drawn from the in-memory `df`, `datasets`, `featurespaces`
# and `jobs`: one panel per dataset / feature space showing mean MSE
# (+- 3 SEM over runs) against the dropout rate for every method present.
# NOTE(review): the "Figure 2" cell further down saves to the same file path.
xticks=[0, 0.1, 0.2, 0.3, 0.4, 0.5]

params = {'axes.labelsize': 20,
         'axes.titlesize': 20,
         'xtick.labelsize' : 15,
         'ytick.labelsize': 15,
         "lines.linewidth" : 4,
         "figure.dpi" : 300,
         "figure.figsize": [18, 10]}
plt.rcParams.update(params)
fig = plt.figure()

gs = GridSpec(2,3, wspace=0.5, hspace=0.6)
spss = {f"sps{i}" : sp for i, sp in enumerate(gs)}
# Plain axes on the same grid cells; they only carry the outer labels,
# panel letters and the legend (their ticks and spines are hidden below).
ax1 = fig.add_subplot(gs[0, 0])
ax2 = fig.add_subplot(gs[0, 1])
ax3 = fig.add_subplot(gs[1, 0])
ax4 = fig.add_subplot(gs[1, 1])
ax5 = fig.add_subplot(gs[0, 2])
ax6 = fig.add_subplot(gs[1, 2])

# `df` may hold Dropout/MSE as strings when it comes straight from the metrics
# cell; averaging needs numbers. Work on a copy so `df` itself is untouched.
plot_df = df.assign(Dropout=pd.to_numeric(df["Dropout"]), MSE=pd.to_numeric(df["MSE"]))

for i, (dataset, featurespace, job) in enumerate(zip(datasets, featurespaces, jobs)):
    d=plot_df[(plot_df["Dataset"]==dataset) & (plot_df["FS"]==featurespace)]
    methods_present=set(d["AE"])
    has_peakvi="PeakVI" in methods_present
    # Panels containing PeakVI get a broken y axis with two visible ranges.
    y_breaks=((0.0062, 0.021), (0.255,0.263)) if has_peakvi else None
    bax = brokax(ylims=y_breaks, hspace=0.3, subplot_spec=spss[f"sps{i}"])
    for AE in AEs_order:
        if AE not in methods_present:
            continue
        mse_by_dropout=d[d["AE"]==AE].groupby("Dropout")["MSE"]
        y=mse_by_dropout.mean()
        yerr=3*mse_by_dropout.sem()
        # x is taken from the grouped dropout levels (percent -> fraction) so it
        # always has the same length as y, even if a level is missing for a method.
        bax.errorbar(x=y.index.to_numpy()/100, y=y, yerr=yerr, ecolor=palette[AE], elinewidth=10, marker="o", c=palette[AE])

    # With a broken axis the x ticks go on the second sub-axis, otherwise on the only one.
    tick_axis=bax.axs[1] if has_peakvi else bax.axs[0]
    tick_axis.set_xticks(xticks, xticks)

    if not has_peakvi:
        bax.axs[0].set_ylim([0.01, 0.05])
        yticks=[0.015, 0.025, 0.045]
        bax.axs[0].set_yticks(yticks, yticks)
    bax.spines['top'][0].set_visible(False)
    bax.spines['right'][0].set_visible(False)
    try:
        bax.axs[0].set_xlabel('', rotation=0, labelpad=10)
        bax.axs[0].set_ylabel('', rotation=90, labelpad=10)
    except Exception:
        # Fall back to the brokenaxes-level setters; a bare `except` would also
        # have swallowed KeyboardInterrupt.
        bax.set_xlabel('', rotation=0, labelpad=10)
        bax.set_ylabel('', rotation=90, labelpad=10)
    bax.set_title(f"{dataset}\n{featurespace}\n", loc='left', y=0.975, size=17)

for ax in [ax1, ax2, ax3, ax4, ax5, ax6]:
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.spines['left'].set_visible(False)
    ax.set_yticks([])
    ax.set_xticks([], [])
    ax.set_xlabel('Dropout', rotation=0, labelpad=30)
    ax.set_ylabel('MSE', rotation=90, labelpad=60)
handles=[mpatches.Patch(color=palette[key], label=key) for key in AEs_order]
ax5.legend(bbox_to_anchor=(1.67, 0.2), title="Method", fontsize=17, title_fontsize=20, handles=handles)
# Panel letters run row by row (ax5 is the top-right panel).
for i, ax in enumerate([ax1, ax2, ax5, ax3, ax4, ax6]):
    ax.text(-0.1, 1.075, string.ascii_uppercase[i], transform=ax.transAxes, size=20, weight='bold',rotation=0)

plt.savefig(f"Figures/Figure2.png", format="png", dpi=300, bbox_inches='tight')
plt.show()

# Homogeneity

In [None]:
# Read the "_MS" AnnData file of every dataset / feature space and keep the
# objects in a dict keyed by the short job code.
datasets = [
    "10XhsBrain3kMO", "10XhsBrain3kMO", "Kidney",
    "10XhsPBMC10kMO", "10XhsPBMC10kMO", "MouseBrain",
]
featurespaces = ["Peak", "GEX", "Peak", "Peak", "GEX", "Peak"]
jobs = ["BrP", "BrG", "KiP", "PbP", "PbG", "MbP"]

adatas = dict()
for dataset, featurespace, job in zip(datasets, featurespaces, jobs):
    print(dataset, featurespace)
    ms_file = f"Datasets/{dataset}/FeatureSpaces/{featurespace}/CM/{dataset}_{featurespace}_MS.h5ad"
    adatas.update({job: sc.read_h5ad(ms_file)})

In [None]:
# For every graph stored in `.obsp` of each AnnData object, and for every cell
# (graph row), record:
#   AbsoluteHetero - number of distinct CellType labels among the cells whose
#                    entry in that row is > 0
#   NN             - number of such cells
#   N_CT           - number of distinct non-missing CellType labels in the dataset
# One output row per (dataset, graph, cell); written to Tables/Heteorgeneity.tsv.gz.
result_columns=["Dataset","FS","Method","AbsoluteHetero","NN","N_CT"]
records=[]
for dataset, featurespace, key in zip(datasets, featurespaces, jobs):
    cell_types=adatas[key].obs["CellType"].to_numpy()
    n_ct=len(set(adatas[key].obs.CellType.dropna()))
    for graph in adatas[key].obsp.keys():
        print(key, graph)
        # Walk the rows of the sparse graph directly instead of densifying the
        # full cells x cells matrix.
        graph_csr=adatas[key].obsp[graph].tocsr()
        if not graph_csr.has_canonical_format:
            # Merge duplicate stored entries on a copy, so thresholding the stored
            # values matches thresholding the dense row.
            graph_csr=graph_csr.copy()
            graph_csr.sum_duplicates()
        indptr, indices, values = graph_csr.indptr, graph_csr.indices, graph_csr.data
        for row in range(graph_csr.shape[0]):
            start, stop = indptr[row], indptr[row + 1]
            neighbours=indices[start:stop][values[start:stop] > 0]
            records.append({
                "Dataset": dataset,
                "FS": featurespace,
                "Method": graph,
                # Same counting as before: set() over the neighbours' labels
                # (a missing label, if present, is counted too).
                "AbsoluteHetero": len(set(cell_types[neighbours])),
                "NN": len(neighbours),
                "N_CT": n_ct,
            })
# Build the frame once; concatenating one row at a time is quadratic in the number of cells.
df=pd.DataFrame(records, columns=result_columns)
df.to_csv("Tables/Heteorgeneity.tsv.gz", sep="\t", compression="gzip")

# Figures

## Figure 2

In [None]:
# Load the robustness results and switch to the display names used in the figure.
df=pd.read_csv("Tables/Robustness.tsv.gz", sep="\t", index_col=0)

# Panel order for the figure. `jobs` is listed in the same order as
# `datasets` / `featurespaces` so that zipping the three lists pairs each job
# code with its own dataset (the previous order was left over from the raw listing).
datasets=["Kidney","Human brain","Human brain", "Mouse brain", "PBMC","PBMC"]
featurespaces=["Peaks","Peaks", "GEX", "Peaks", "Peaks", "GEX"]
jobs=["KiP", "BrP", "BrG", "MbP", "PbP", "PbG"]

# Raw identifiers -> display names (whole-value replacement).
dataset_labels={
    "10XhsBrain3kMO": "Human brain",
    "10XhsPBMC10kMO": "PBMC",
    "MouseBrain": "Mouse brain",
}
df["FS"]=df["FS"].replace("Peak","Peaks")
df["Dataset"]=df["Dataset"].replace(dataset_labels)
df.head()

In [None]:
# Figure 2: one panel per dataset / feature space showing mean MSE
# (+- 3 SEM over runs) against the dropout rate for every method present.
xticks=[0, 0.1, 0.2, 0.3, 0.4, 0.5]

params = {'axes.labelsize': 20,
         'axes.titlesize': 20,
         'xtick.labelsize' : 15,
         'ytick.labelsize': 15,
         "lines.linewidth" : 4,
         "figure.dpi" : 300,
         "figure.figsize": [19, 10]}
plt.rcParams.update(params)
fig = plt.figure()

gs = GridSpec(2,3, wspace=0.5, hspace=0.6)
spss = {f"sps{i}" : sp for i, sp in enumerate(gs)}
# Plain axes on the same grid cells; they only carry the outer labels,
# panel letters and the legend (their ticks and spines are hidden below).
ax1 = fig.add_subplot(gs[0, 0])
ax2 = fig.add_subplot(gs[0, 1])
ax3 = fig.add_subplot(gs[1, 0])
ax4 = fig.add_subplot(gs[1, 1])
ax5 = fig.add_subplot(gs[0, 2])
ax6 = fig.add_subplot(gs[1, 2])

# Make sure the columns being averaged / plotted are numeric, on a copy of `df`.
plot_df = df.assign(Dropout=pd.to_numeric(df["Dropout"]), MSE=pd.to_numeric(df["MSE"]))

for i, (dataset, featurespace, job) in enumerate(zip(datasets, featurespaces, jobs)):
    d=plot_df[(plot_df["Dataset"]==dataset) & (plot_df["FS"]==featurespace)]
    methods_present=set(d["AE"])
    has_peakvi="PeakVI" in methods_present
    # Panels containing PeakVI get a broken y axis with two visible ranges.
    y_breaks=((0.0062, 0.021), (0.255,0.263)) if has_peakvi else None
    bax = brokax(ylims=y_breaks, hspace=0.3, subplot_spec=spss[f"sps{i}"])
    for AE in AEs_order:
        if AE not in methods_present:
            continue
        mse_by_dropout=d[d["AE"]==AE].groupby("Dropout")["MSE"]
        y=mse_by_dropout.mean()
        yerr=3*mse_by_dropout.sem()
        # x is taken from the grouped dropout levels (percent -> fraction) so it
        # always has the same length as y, even if a level is missing for a method.
        bax.errorbar(x=y.index.to_numpy()/100, y=y, yerr=yerr, ecolor=palette[AE], elinewidth=10, marker="o", c=palette[AE])

    # With a broken axis the x ticks go on the second sub-axis, otherwise on the only one.
    tick_axis=bax.axs[1] if has_peakvi else bax.axs[0]
    tick_axis.set_xticks(xticks, xticks)

    if not has_peakvi:
        bax.axs[0].set_ylim([0.01, 0.05])
        yticks=[0.015, 0.025, 0.045]
        bax.axs[0].set_yticks(yticks, yticks)
    bax.spines['top'][0].set_visible(False)
    bax.spines['right'][0].set_visible(False)
    try:
        bax.axs[0].set_xlabel('', rotation=0, labelpad=10)
        bax.axs[0].set_ylabel('', rotation=90, labelpad=10)
    except Exception:
        # Fall back to the brokenaxes-level setters; a bare `except` would also
        # have swallowed KeyboardInterrupt.
        bax.set_xlabel('', rotation=0, labelpad=10)
        bax.set_ylabel('', rotation=90, labelpad=10)
    bax.set_title(f"    {dataset} {featurespace}\n", loc='center', size=17)

for ax in [ax1, ax2, ax3, ax4, ax5, ax6]:
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.spines['left'].set_visible(False)
    ax.set_yticks([])
    ax.set_xticks([], [])
    ax.set_xlabel('Dropout', rotation=0, labelpad=30)
    ax.set_ylabel('MSE', rotation=90, labelpad=60)
handles=[mpatches.Patch(color=palette[key], label=key) for key in AEs_order]
ax5.legend(bbox_to_anchor=(1.67, 0.2), title="Method", fontsize=17, title_fontsize=20, handles=handles)
# Panel letters run row by row and start at "B" (index i+1).
for i, ax in enumerate([ax1, ax2, ax5, ax3, ax4, ax6]):
    ax.text(-0.1, 1.075, string.ascii_uppercase[i+1], transform=ax.transAxes, size=20, weight='bold',rotation=0)

plt.savefig(f"Figures/Figure2.png", format="png", dpi=300, bbox_inches='tight')
plt.show()

## Figure 3

In [None]:
# Load the per-cell neighbourhood table, derive the homogeneity score and
# switch to the display names used in the figure and table.
df=pd.read_csv("Tables/Heteorgeneity.tsv.gz", sep="\t", index_col=0)

# Heterogeneity = distinct labels in the neighbourhood / (neighbourhood size * number of cell types);
# homogeneity is its complement.
df["Heteorgeneity"]=np.array(df["AbsoluteHetero"]/(df["NN"]*df["N_CT"]))
df["Homogeneity"]=1-df["Heteorgeneity"]
# Literal suffix removal (regex=False makes this independent of the pandas default).
df["Method"]=df["Method"].str.replace("_kNN","", regex=False)
print(df.shape)

# Raw identifiers -> display names (whole-value replacement). "Kidney" already
# has its display name, so the former no-op replacement was dropped.
dataset_labels={
    "10XhsBrain3kMO": "Human brain",
    "10XhsPBMC10kMO": "PBMC",
    "MouseBrain": "Mouse brain",
}
df["FS"]=df["FS"].replace("Peak","Peaks")
df["Dataset"]=df["Dataset"].replace(dataset_labels)

# Panel order; `jobs` is listed in the same order as `datasets` / `featurespaces`
# so that zipping the three lists pairs each job code with its own dataset.
datasets=["Kidney","Human brain","Human brain", "Mouse brain", "PBMC","PBMC"]
featurespaces=["Peaks","Peaks", "GEX", "Peaks", "Peaks", "GEX"]
jobs=["KiP", "BrP", "BrG", "MbP", "PbP", "PbG"]
df.head()

In [None]:
# Summary table: mean homogeneity +- 3*SEM per method (rows) and
# dataset+feature space (columns), written to Tables/TableHomogeneity.tsv.
ae_order=["GRAE","TopoAE","VAE","PeakVI","scVI","pca"]
col_order=['Kidney+Peaks', 'Human brain+Peaks', 'Human brain+GEX', 'Mouse brain+Peaks', 'PBMC+Peaks','PBMC+GEX']

metric="Homogeneity"
value_column="Mean +- 3*SEM"
rows=[]
for dataset, featurespace, job in zip(datasets, featurespaces, jobs):
    print(dataset, featurespace)
    t=df[(df["Dataset"]==dataset) & (df["FS"]==featurespace)]
    # Mean and SEM come from ONE grouped object, so each method label is
    # guaranteed to be paired with its own statistics. (Previously the labels
    # came from a groupby without dropna() and the statistics from one with it,
    # which can misalign when a method has only missing values.)
    stats=t[[metric,"Method"]].dropna().groupby("Method")[metric].agg(["mean", "sem"])
    for ae, (m, s) in zip(stats.index, stats.to_numpy()):
        rows.append({
            "Ds&Fs": f"{dataset}+{featurespace}",
            "DR method": ae,
            # ufloat's string form gives "<mean>+/-<uncertainty>" rounded consistently.
            value_column: str(uncertainties.ufloat(m, 3*s)),
        })
d=pd.DataFrame(rows, columns=["Ds&Fs", "DR method", value_column])
defd=d.pivot(index="DR method", columns="Ds&Fs", values=value_column)[col_order].loc[ae_order]
# Literal replacement: "+/-" is not a valid regular expression, so regex=False is
# required on pandas versions where str.replace treats the pattern as a regex.
defd=defd.apply(lambda column: column.str.replace("+/-", u"\u00B1", regex=False))
defd.to_csv("Tables/TableHomogeneity.tsv", sep="\t")
defd

In [None]:
# Figure 3: bar chart of mean homogeneity per method, one panel per
# dataset / feature space, with error bars of 5 x SEM.
plt.rcParams.update({
    'axes.labelsize': 20,
    'axes.titlesize': 20,
    'xtick.labelsize': 15,
    'ytick.labelsize': 15,
    "lines.linewidth": 4,
    "figure.dpi": 300,
    "figure.figsize": [18, 10],
})

fig, axs = plt.subplots(2, 3)
axs = axs.flatten()
for i, (dataset, featurespace, job) in enumerate(zip(datasets, featurespaces, jobs)):
    panel = axs[i]
    d = df[(df["Dataset"] == dataset) & (df["FS"] == featurespace)]

    # Keep the global method order, restricted to the methods present in this panel.
    methods_here = set(d["Method"])
    order = [ae for ae in AEs_order if ae in methods_here]
    colors = [palette[ae] for ae in order]

    homogeneity_by_method = d.groupby("Method")["Homogeneity"]
    means = np.array(homogeneity_by_method.mean().loc[order])
    sems = 5 * np.array(homogeneity_by_method.sem().loc[order])

    panel.bar(
        x=order,
        height=means,
        yerr=sems,
        edgecolor=colors,
        facecolor=colors,
        capsize=10,
        ecolor="black",
        linewidth=5,
        alpha=0.75,
        width=0.8,
    )
    panel.set_ylim([0.95, 1])
    panel.set_title(f"{dataset} {featurespace}\n", loc='center', size=17)
    for side in ('top', 'right'):
        panel.spines[side].set_visible(False)
    # Panel letters start at "B" (index i+1).
    panel.text(-0.2, 1.15, string.ascii_uppercase[i+1], transform=panel.transAxes, size=20, weight='bold', rotation=0)
    panel.set_xlabel('Method', rotation=0, labelpad=10)
    panel.set_ylabel('Homogeneity', rotation=90, labelpad=10)

fig.tight_layout(h_pad=3, w_pad=2)
plt.savefig("Figures/Figure3.png", format="png", dpi=300, bbox_inches='tight')
plt.show()