In [None]:
import numpy as np
import pandas as pd
import scanpy as sc
import matplotlib.pyplot as plt
import seaborn as sns
import matplotlib.lines as mlines
import matplotlib as mpl


import sklearn
import os
import scipy
import glob
import string
import logomaker

import torch
#import torch_geometric

from pathlib import Path
from statannotations.Annotator import Annotator

import Utils as ut
%matplotlib inline
# Per-dataset metadata table; the cells below iterate over its "DSs", "FsSs",
# "LBs" and "Names" columns to build file paths and display names.
ds_infos=pd.read_csv("Tables/Datasets_infos.tsv", sep="\t", index_col=0)

# Metrics

## F1 Acc Prec Rec

In [None]:
# Classification metrics of the GNN on the cells flagged GNN_set == "Test",
# one row per (dataset, feature space) listed in ds_infos. The result is written
# to Tables/F1PrecRec.tsv.gz; per-dataset confusion matrices are kept in `cfs`.
METRIC_COLUMNS=["Accuracy","Precision","Recall","F1W","F1M","Dataset","FeaturesSpace","Name"]
cfs={}  # confusion matrices keyed by "<dataset>_<featurespace>"; must exist before the loop fills it
records=[]
for dataset, featurespace, label, name in zip(ds_infos["DSs"], ds_infos["FsSs"], ds_infos["LBs"], ds_infos["Names"]):
    path=f"Datasets/{dataset}/FeatureSpaces/{featurespace}/XAI/{dataset}_{featurespace}_GRAE_kNN_{label}_HPO"
    file=f"{path}_Predictions.tsv.gz"
    if not os.path.isfile(file):
        # Report the missing predictions file and continue with the next dataset.
        print(file)
        continue
    preds=pd.read_csv(file, sep="\t", index_col=0)
    preds=preds[preds["GNN_set"]=="Test"]
    classes=sorted(set(preds[label]))
    # labels=classes keeps the matrix shape in sync with the row/column names
    # even if a predicted class never occurs among the true labels.
    cfs[f"{dataset}_{featurespace}"]=pd.DataFrame(
        sklearn.metrics.confusion_matrix(preds[label], preds.GNN_prediction, labels=classes),
        columns=classes, index=classes)
    acc=np.around((preds[label]==preds["GNN_prediction"]).mean(), decimals=4)
    precision, recall, f1w, _=sklearn.metrics.precision_recall_fscore_support(preds[label], preds.GNN_prediction, average="weighted")
    f1m=sklearn.metrics.precision_recall_fscore_support(preds[label], preds.GNN_prediction, average="macro")[2]
    records.append([acc, precision, recall, f1w, f1m, dataset, featurespace, name])

# Build the table once instead of concatenating one row at a time.
tot=pd.DataFrame(records, columns=METRIC_COLUMNS)
tot=tot.astype({"Accuracy": float, "Precision": float, "Recall": float, "F1W": float, "F1M": float})
# Use the plural feature-space names expected by the figure cells.
tot["FeaturesSpace"]=tot["FeaturesSpace"].replace({"Peak": "Peaks", "Window": "Windows"})
tot.to_csv("Tables/F1PrecRec.tsv.gz", sep="\t", compression="gzip")

In [None]:
# Reload the metrics table from disk and display only the summary columns.
summary_columns=["Name","FeaturesSpace","Accuracy","F1M"]
df=pd.read_csv("Tables/F1PrecRec.tsv.gz", sep="\t", index_col=0)[summary_columns]
df

## Stability, Specificity, vs DA

In [None]:
# Specificity and stability of the XAI features, one row per
# *Specificity.tsv.gz file found for each dataset listed in ds_infos.
# The result is written to Tables/StabSpec.tsv.gz.
TOP_N=50  # divisor applied to both tables; presumably the number of top features kept per run (confirm)
records=[]
for dataset, featurespace, label, name in zip(ds_infos["DSs"], ds_infos["FsSs"], ds_infos["LBs"], ds_infos["Names"]):
    xai_dir=f"Datasets/{dataset}/FeatureSpaces/{featurespace}/XAI/{label}"
    files=sorted(glob.glob(f"{xai_dir}/{dataset}_{featurespace}_GRAE_kNN_{label}*Specificity.tsv.gz"))
    for file in files:
        spec=pd.read_csv(file, sep="\t", index_col=0)
        # The first column header of the specificity table is used as the cell-type label
        # and selects the matching stability table.
        ct=str(spec.columns[0])
        stab_file=f"{xai_dir}/{dataset}_{featurespace}_GRAE_kNN_{label}_HPO_{ct}_Stability.tsv.gz"
        stab=pd.read_csv(stab_file, sep="\t", index_col=0)
        records.append({
            # Median over columns of the per-column medians.
            "Stability": (stab/TOP_N).median().median(),
            # Select the first column's median explicitly: float() on a Series is deprecated in pandas.
            "Specificity": float(1-(spec/TOP_N).median().iloc[0]),
            "Dataset": dataset,
            "FeaturesSpace": featurespace,
            "CT": ct,
            "Name": name,
        })

# Build the table once instead of concatenating one row at a time.
tot=pd.DataFrame(records, columns=["Stability","Specificity","Dataset","FeaturesSpace", "CT", "Name"])
tot["FeaturesSpace"]=tot["FeaturesSpace"].replace({"Peak": "Peaks", "Window": "Windows"})
tot.to_csv("Tables/StabSpec.tsv.gz", compression="gzip", sep="\t")

# "Difference" between XAI and DE feature lists: for every column shared by the
# two Top50Features tables, 1 - |DE ∩ XAI| / 50. Written to Tables/XAIvsDEA.tsv.gz.
records=[]
for dataset, featurespace, label, name in zip(ds_infos["DSs"], ds_infos["FsSs"], ds_infos["LBs"], ds_infos["Names"]):
    de_file=f"Datasets/{dataset}/FeatureSpaces/{featurespace}/XAI/DE/{dataset}_{featurespace}_GRAE_kNN_{label}_DETop50Features.tsv.gz"
    xai_file=f"Datasets/{dataset}/FeatureSpaces/{featurespace}/XAI/{dataset}_{featurespace}_GRAE_kNN_{label}_HPO_XAITop50Features.tsv.gz"
    # Plain `and`: combining `== True` with `&` parses as a chained comparison
    # because `&` binds tighter than `==`.
    if not (os.path.isfile(xai_file) and os.path.isfile(de_file)):
        print(de_file, xai_file)
        continue
    de=pd.read_csv(de_file, sep="\t", index_col=0)
    xai=pd.read_csv(xai_file, sep="\t", index_col=0)
    for col in ut.intersection([de.columns, xai.columns]):
        shared=ut.intersection([de[col], xai[col]])
        records.append([dataset, featurespace, col, 1-len(shared)/50, name])

# Built in one go, so the rows get a unique 0..n-1 index.
xaide=pd.DataFrame(records, columns=["Dataset","FeaturesSpace", "CT", "Difference","Name"])
xaide["FeaturesSpace"]=xaide["FeaturesSpace"].replace({"Peak": "Peaks", "Window": "Windows"})
xaide.to_csv("Tables/XAIvsDEA.tsv.gz", compression="gzip", sep="\t")

## Artifacts

In [None]:
# Compare the top-50 XAI and DA features of each cell type with two possible
# sources of bias: the 50 features with the largest summed signal in that cell
# type ("Expression") and the 50 features with the largest var["n_cells"]
# ("Coverage"). Written to Tables/CoverageExpr.tsv.gz.
# NOTE(review): the "Intersection" column stores 1 - overlap/50, i.e. the same
# quantity that the XAI-vs-DE table calls "Difference" — confirm the intended meaning.
ARTIFACT_COLUMNS=["Dataset","FeaturesSpace", "CT", "Intersection","Method","Name","Bias"]
expression_rows=[]
coverage_rows=[]
for dataset, featurespace, label, name in zip(ds_infos["DSs"], ds_infos["FsSs"], ds_infos["LBs"], ds_infos["Names"]):
    de_file=f"Datasets/{dataset}/FeatureSpaces/{featurespace}/XAI/DE/{dataset}_{featurespace}_GRAE_kNN_{label}_DETop50Features.tsv.gz"
    xai_file=f"Datasets/{dataset}/FeatureSpaces/{featurespace}/XAI/{dataset}_{featurespace}_GRAE_kNN_{label}_HPO_XAITop50Features.tsv.gz"
    # Plain `and`: combining `== True` with `&` parses as a chained comparison.
    if not (os.path.isfile(xai_file) and os.path.isfile(de_file)):
        print(de_file, xai_file)
        continue
    de=pd.read_csv(de_file, sep="\t", index_col=0)
    xai=pd.read_csv(xai_file, sep="\t", index_col=0)
    # The count matrix is read once per dataset, and only when both feature tables exist.
    adata=sc.read_h5ad(f"Datasets/{dataset}/FeatureSpaces/{featurespace}/CM/{dataset}_{featurespace}_Def.h5ad")
    # NOTE(review): var["n_cells"] is not changed by subsetting cells, so this
    # ranking is the same for every cell type of a dataset — confirm this is intended.
    most_covered=adata.var.sort_values(by="n_cells")[::-1].index[:50]
    for col in ut.intersection([de.columns, xai.columns]):
        ad=adata[adata.obs[adata.obs[label]==col].index]
        expr=pd.DataFrame(data=ad.X.sum(axis=0), columns=adata.var.index).T
        expr.columns=["Expression"]
        most_expressed=expr.sort_values(by="Expression")[::-1][:50].index
        for method, table in (("XAI", xai), ("DA", de)):
            expression_rows.append([dataset, featurespace, col, 1-len(ut.intersection([most_expressed, table[col]]))/50, method, name, "Expression"])
            coverage_rows.append([dataset, featurespace, col, 1-len(ut.intersection([most_covered, table[col]]))/50, method, name, "Coverage"])

# All "Expression" rows first, then all "Coverage" rows.
xaide=pd.DataFrame(expression_rows+coverage_rows, columns=ARTIFACT_COLUMNS)
xaide["Intersection"]=xaide["Intersection"].astype(float)
xaide["FeaturesSpace"]=xaide["FeaturesSpace"].replace({"Peak": "Peaks", "Window": "Windows"})
xaide.to_csv("Tables/CoverageExpr.tsv.gz", compression="gzip", sep="\t")

## STN

In [None]:
# "STN" of the top-50 XAI and DA features in each cell type: mean of the
# sub-matrix (cells of that type x selected features) divided by its standard
# deviation. Written to Tables/STN.tsv.gz.
stn_rows=[]
for dataset, featurespace, label, name in zip(ds_infos["DSs"], ds_infos["FsSs"], ds_infos["LBs"], ds_infos["Names"]):
    de_file=f"Datasets/{dataset}/FeatureSpaces/{featurespace}/XAI/DE/{dataset}_{featurespace}_GRAE_kNN_{label}_DETop50Features.tsv.gz"
    xai_file=f"Datasets/{dataset}/FeatureSpaces/{featurespace}/XAI/{dataset}_{featurespace}_GRAE_kNN_{label}_HPO_XAITop50Features.tsv.gz"
    # Plain `and`: combining `== True` with `&` parses as a chained comparison.
    if not (os.path.isfile(xai_file) and os.path.isfile(de_file)):
        print(de_file, xai_file)
        continue
    de=pd.read_csv(de_file, sep="\t", index_col=0)
    xai=pd.read_csv(xai_file, sep="\t", index_col=0)
    # Read the count matrix only when both feature tables exist.
    adata=sc.read_h5ad(f"Datasets/{dataset}/FeatureSpaces/{featurespace}/CM/{dataset}_{featurespace}_Def.h5ad")
    adata.var_names_make_unique()

    for col in ut.intersection([de.columns, xai.columns]):
        try:
            ad=adata[adata.obs[adata.obs[label]==col].index]
            # XAI is handled first, so its row is kept even if the DA computation fails.
            for method, feats in (("XAI", xai[col]), ("DA", de[col].dropna())):
                sub=ad[:, feats].X  # .todense() below assumes a sparse matrix
                stn_rows.append([dataset, featurespace, col, sub.mean()/sub.todense().std(), method, name])
        except Exception as err:
            # Best effort: report which cell type failed and why, then continue.
            # `Exception` (not a bare except) lets KeyboardInterrupt through.
            print(dataset, featurespace, label, col, repr(err))

stn_df=pd.DataFrame(stn_rows, columns=["Dataset","FeaturesSpace", "CT", "STN","Method","Name"])
stn_df["FeaturesSpace"]=stn_df["FeaturesSpace"].replace({"Peak": "Peaks", "Window": "Windows"})
stn_df["STN"]=stn_df["STN"].astype(float)
stn_df.to_csv("Tables/STN.tsv.gz", compression="gzip", sep="\t")

## Results stability

In [None]:
# Specificity and stability of the XAI features for several thresholds, one row
# per (dataset, threshold, cell type). Written to Tables/StabSpec_Threshold.tsv.gz.
thres=[10, 25, 50, 75, 100, 125, 150, 200, 250, 300, 350, 500, 750, 1000, 1500, 2000]
records=[]
for dataset, featurespace, label, ds in zip(ds_infos["DSs"], ds_infos["FsSs"], ds_infos["LBs"], ds_infos["Names"]):
    print(dataset, featurespace)
    name=f"{dataset}_{featurespace}_GRAE_kNN_{label}"
    matrix=f"Datasets/{dataset}/FeatureSpaces/{featurespace}/CM/{dataset}_{featurespace}_Def.h5ad"
    obs = sc.read(matrix).obs
    # The cell-type list does not depend on the threshold, so compute it once.
    cell_types=sorted(set(obs[label].dropna()))
    for th in thres:
        for ct in cell_types:
            run_prefix=f"Datasets/{dataset}/FeatureSpaces/{featurespace}/XAI/Runs/CTS/{ct}/{name}_{ct}_{th}"
            spec=pd.read_csv(f"{run_prefix}_Specificity.tsv.gz", sep="\t", index_col=0)
            stab=pd.read_csv(f"{run_prefix}_Stability.tsv.gz", sep="\t", index_col=0)
            records.append({
                # Median over columns of the per-column medians.
                "Stability": (stab/th).median().median(),
                # Select the first column's median explicitly: float() on a Series is deprecated in pandas.
                "Specificity": float(1-(spec/th).median().iloc[0]),
                "Dataset": dataset,
                "FeaturesSpace": featurespace,
                "CT": str(spec.columns[0]),
                "Name": ds,
                "Threshold": th,
            })

# Build the table once instead of concatenating one row at a time.
tot=pd.DataFrame(records, columns=["Stability","Specificity","Dataset","FeaturesSpace", "CT", "Name","Threshold"])
tot["FeaturesSpace"]=tot["FeaturesSpace"].replace({"Peak": "Peaks", "Window": "Windows"})
tot.to_csv("Tables/StabSpec_Threshold.tsv.gz", compression="gzip", sep="\t")

In [None]:
# "Difference" between the first t rows of the XAI and DE feature tables for
# several thresholds t: 1 - |DE[:t] ∩ XAI[:t]| / t per shared column.
# Written to Tables/DAInter_Threshold.tsv.gz.
thres=[10, 25, 50, 75, 100, 125, 150, 200, 250, 300, 350, 500, 750, 1000, 1500, 2000]
records=[]
for dataset, featurespace, label, name in zip(ds_infos["DSs"], ds_infos["FsSs"], ds_infos["LBs"], ds_infos["Names"]):
    print(dataset, featurespace)
    de_file=f"Datasets/{dataset}/FeatureSpaces/{featurespace}/XAI/DE/{dataset}_{featurespace}_GRAE_kNN_{label}_DEFeatures.tsv.gz"
    xai_file=f"Datasets/{dataset}/FeatureSpaces/{featurespace}/XAI/{dataset}_{featurespace}_GRAE_kNN_{label}_HPO_XAIFeatures.tsv.gz"
    # Plain `and`: combining `== True` with `&` parses as a chained comparison.
    if not (os.path.isfile(xai_file) and os.path.isfile(de_file)):
        # Reported once per dataset rather than once per threshold.
        print(de_file, xai_file)
        continue
    # The tables do not depend on the threshold: read them once and re-slice per threshold.
    de_full=pd.read_csv(de_file, sep="\t", index_col=0)
    xai_full=pd.read_csv(xai_file, sep="\t", index_col=0)
    shared_cols=ut.intersection([de_full.columns, xai_full.columns])
    for t in thres:
        de=de_full[:t]
        xai=xai_full[:t]
        for col in shared_cols:
            records.append([dataset, featurespace, col, 1-len(ut.intersection([de[col], xai[col]]))/t, name, t])

xaide=pd.DataFrame(records, columns=["Dataset","FeaturesSpace", "CT", "Difference","Name","Threshold"])
xaide["FeaturesSpace"]=xaide["FeaturesSpace"].replace({"Peak": "Peaks", "Window": "Windows"})
xaide.to_csv("Tables/DAInter_Threshold.tsv.gz", compression="gzip", sep="\t")

# Figures

## Figure5

In [None]:
# Gather every column of the *_XAIFeaturesImportance tables, grouped by feature
# space, then take the position-wise mean and standard error over those columns.
# Only the first 250 positions are kept.
xaide=pd.DataFrame(columns=["Dataset","FeaturesSpace", "CT", "Difference","Name","Threshold"])
fss = {space : [] for space in ("GEX", "Peak", "Window")}
for dataset, featurespace, label, name in zip(ds_infos["DSs"], ds_infos["FsSs"], ds_infos["LBs"], ds_infos["Names"]):
    xai_file=f"Datasets/{dataset}/FeatureSpaces/{featurespace}/XAI/{dataset}_{featurespace}_GRAE_kNN_{label}_HPO_XAIFeaturesImportance.tsv.gz"
    xai=pd.read_csv(xai_file, sep="\t", index_col=0)
    fss[featurespace].extend(np.array(xai[col]) for col in xai.columns)

distributions = {}
distributions_std = {}
for key, vectors in fss.items():
    stacked=pd.DataFrame(vectors)  # one row per collected column
    distributions[key]=stacked.mean()[:250]
    distributions_std[key]=stacked.sem()[:250]

In [None]:
# Figure 5, upper part (saved as Figures/Figure5_Up.png): mean importance per
# rank with standard-error bars (top) and the drop in mean importance between
# consecutive ranks (bottom), for the three feature spaces.
params = {'axes.labelsize': 15,
         'axes.titlesize': 15,
         'xtick.labelsize' : 15,
         'ytick.labelsize': 15,
         "lines.linewidth" : 4,
         "figure.dpi" : 300,
         "figure.figsize": [15, 10]}
plt.rcParams.update(params)
fs_order=["Peaks","GEX","Windows"]
feats_palette = {fs_order[i] : ut.colors_to_use_bright[9:12][i] for i in range(len(fs_order))}
fig, axs = plt.subplots(2,1)

# (key in `distributions`, legend/palette label, zorder): Peaks are drawn on top.
series_specs=[("Peak", "Peaks", 15), ("GEX", "GEX", 10), ("Window", "Windows", 5)]
for key, fs_label, z in series_specs:
    mean_imp=distributions[key]
    ranks=np.arange(1, len(mean_imp)+1)
    axs[0].errorbar(x=ranks, y=mean_imp, yerr=distributions_std[key], color=feats_palette[fs_label], ls="None", zorder=z, label=fs_label)
    # Decay rate: importance at rank i minus importance at rank i+1.
    axs[1].plot(ranks[:-1], -np.diff(np.asarray(mean_imp)), color=feats_palette[fs_label], zorder=z, label=fs_label)

axs[0].set_ylabel("Importance", labelpad=15)
axs[0].set_yscale("log")
# axhline's xmin/xmax are axes fractions in [0, 1], not data coordinates: 1 spans the full width.
axs[0].axhline(y=0.1, xmin=0, xmax=1, color="purple", alpha=0.5, linestyle="-", ms=1, zorder=1, linewidth = 3)

axs[1].set_ylabel("Decay rate of importance", labelpad=15)
axs[1].set_ylim([0, 0.035])

for i, ax in enumerate(axs):
    ax.set_xlabel("Rank")
    ax.set_xlim([0,250])
    # Dashed guides at ranks 10, 50 and 200.
    ax.axvline(x=10, ymin=0, ymax=1, color=ut.colors_to_use_pastel[3], linestyle="--", zorder=1, linewidth = 3)
    ax.axvline(x=50, ymin=0, ymax=1, color=ut.colors_to_use_bright[-6], linestyle="--", zorder=1, linewidth = 3)
    ax.axvline(x=200, ymin=0, ymax=1, color=ut.colors_to_use_pastel[3], linestyle="--", zorder=1, linewidth = 3)
    # Panel letters start at "B".
    ax.text(-0.075, 1.15, string.ascii_uppercase[i+1], transform=ax.transAxes, size=20, weight='bold',rotation=0)
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)

# Legend with one round marker per feature space instead of the plotted artists.
handles=[mlines.Line2D([], [], color=feats_palette[key], marker='o', linestyle='None', markersize=10, label=key)
         for key in feats_palette.keys()]
axs[1].legend(handles=handles, bbox_to_anchor=(1.3, 0.5), title="Feature space", fontsize=17, title_fontsize=20)
# Fixed ticks are set after set_yscale("log") so the scale change cannot reset them.
axs[0].set_yticks(np.logspace(-2, 0, 3), ['{:.0e}'.format(i) for i in np.logspace(-2, 0, 3)[:-1]]+[1])

fig.tight_layout(w_pad=-5)
plt.savefig("Figures/Figure5_Up.png", format="png", dpi=300, bbox_inches='tight')
plt.show()

In [None]:
# Side-by-side (1x2) variant of the importance / decay-rate panels.
# This cell only displays the figure; the savefig call is commented out.
params = {'axes.labelsize': 15,
         'axes.titlesize': 15,
         'xtick.labelsize' : 15,
         'ytick.labelsize': 15,
         "lines.linewidth" : 4,
         "figure.dpi" : 300,
         "figure.figsize": [15, 5]}
plt.rcParams.update(params)
fs_order=["Peaks","GEX","Windows"]
feats_palette = {fs_order[i] : ut.colors_to_use_bright[9:12][i] for i in range(len(fs_order))}
fig, axs = plt.subplots(1,2)

# (key in `distributions`, legend/palette label, zorder): Peaks are drawn on top.
series_specs=[("Peak", "Peaks", 15), ("GEX", "GEX", 10), ("Window", "Windows", 5)]
for key, fs_label, z in series_specs:
    mean_imp=distributions[key]
    ranks=np.arange(1, len(mean_imp)+1)
    axs[0].errorbar(x=ranks, y=mean_imp, yerr=distributions_std[key], color=feats_palette[fs_label], ls="None", zorder=z, label=fs_label)
    # Decay rate: importance at rank i minus importance at rank i+1.
    axs[1].plot(ranks[:-1], -np.diff(np.asarray(mean_imp)), color=feats_palette[fs_label], zorder=z, label=fs_label)

axs[0].set_ylabel("Importance", labelpad=15)
axs[0].set_yscale("log")
# axhline's xmin/xmax are axes fractions in [0, 1], not data coordinates: 1 spans the full width.
axs[0].axhline(y=0.1, xmin=0, xmax=1, color="purple", alpha=0.5, linestyle="-", ms=1, zorder=1, linewidth = 3)

axs[1].set_ylabel("Decay rate of importance", labelpad=15)
axs[1].set_ylim([0, 0.035])

for i, ax in enumerate(axs):
    ax.set_xlabel("Rank")
    ax.set_xlim([0,250])
    # Dashed guides at ranks 10, 50 and 200.
    ax.axvline(x=10, ymin=0, ymax=1, color=ut.colors_to_use_pastel[3], linestyle="--", zorder=1, linewidth = 3)
    ax.axvline(x=50, ymin=0, ymax=1, color=ut.colors_to_use_bright[-6], linestyle="--", zorder=1, linewidth = 3)
    ax.axvline(x=200, ymin=0, ymax=1, color=ut.colors_to_use_pastel[3], linestyle="--", zorder=1, linewidth = 3)
    # Panel letters start at "B".
    ax.text(-0.075, 1.15, string.ascii_uppercase[i+1], transform=ax.transAxes, size=20, weight='bold',rotation=0)
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)

# Legend with one round marker per feature space instead of the plotted artists.
handles=[mlines.Line2D([], [], color=feats_palette[key], marker='o', linestyle='None', markersize=10, label=key)
         for key in feats_palette.keys()]
axs[1].legend(handles=handles, bbox_to_anchor=(1.3, 0.7), title="Feature space", fontsize=17, title_fontsize=20)
# Fixed ticks are set after set_yscale("log") so the scale change cannot reset them.
axs[0].set_yticks(np.logspace(-2, 0, 3), ['{:.0e}'.format(i) for i in np.logspace(-2, 0, 3)[:-1]]+[1])

fig.tight_layout(h_pad=2)
#plt.savefig("Figures/Figure5_Up.png", format="png", dpi=300, bbox_inches='tight')
plt.show()

In [None]:
# Load the summary tables used by the Figure 5 panels; all are tab-separated
# with the first column as index.
tsv_opts={"sep": "\t", "index_col": 0}
perf=pd.read_csv("Tables/F1PrecRec.tsv.gz", **tsv_opts)
stabspec=pd.read_csv("Tables/StabSpec.tsv.gz", **tsv_opts)
xaide=pd.read_csv("Tables/XAIvsDEA.tsv.gz", **tsv_opts)
covexpr=pd.read_csv("Tables/CoverageExpr.tsv.gz", **tsv_opts)
stn=pd.read_csv("Tables/STN.tsv.gz", **tsv_opts)

In [None]:
# Figure 5, middle part (saved as Figures/Figure5_Middle.png): per-dataset
# mean +/- SEM of XAI specificity, XAI stability and the XAI-vs-DA difference,
# drawn as thick horizontal error bars coloured by feature space.
params = {'axes.labelsize': 15,
         'axes.titlesize': 15,
         'xtick.labelsize' : 15,
         'ytick.labelsize': 15,
         "lines.linewidth" : 4,
         "figure.dpi" : 300,
         "figure.figsize": [15, 5]}
plt.rcParams.update(params)
fs_order=["Peaks","GEX","Windows"]
feats_palette = {fs_order[i] : ut.colors_to_use_bright[9:12][i] for i in range(len(fs_order))}
fig, axs = plt.subplots(1,3)
axs=axs.flatten()


def _wrap_names(table):
    """Return a copy of `table` whose "Name" labels are split over two lines.

    Working on a copy avoids writing into a filtered slice of the source table.
    """
    wrapped=table.copy()
    for plain in ("Breast cancer", "Human brain", "Mouse brain"):
        wrapped["Name"]=wrapped["Name"].str.replace(plain, plain.replace(" ", "\n"), regex=False)
    return wrapped


# (axes, source table, metric column). Only the first panel shows the dataset
# names, so all three panels use the same wrapped labels to keep rows sorted alike.
panels=[(axs[0], stabspec, "Specificity"), (axs[1], stabspec, "Stability"), (axs[2], xaide, "Difference")]
for fs in fs_order:
    for ax, table, metric in panels:
        d=_wrap_names(table[table["FeaturesSpace"]==fs])
        order=sorted(set(d["Name"]))  # same order as the groupby keys below
        grouped=d.groupby("Name")[metric]
        # In the first panel the Peaks bars are drawn above the other feature spaces.
        extra={"zorder": 10 if fs == "Peaks" else 1} if ax is axs[0] else {}
        ax.errorbar(x=np.array(grouped.mean()), y=order, xerr=np.array(grouped.sem()), ls='none', ecolor=feats_palette[fs],
                    elinewidth=20, marker="o", c="black", label=fs, **extra)

# Dataset names are shown on the first panel only.
axs[1].set_yticks([])
axs[2].set_yticks([])
axs[1].set_ylabel("")
axs[2].set_ylabel("")
axs[0].set_ylabel("")
axs[0].yaxis.set_label_coords(-0.6, y=0.5)

axs[0].set_xlabel("XAIF specificity", labelpad=10)
axs[1].set_xlabel("XAIF stability", labelpad=10)
axs[2].set_xlabel("Difference\nXAIF & DAF", labelpad=10)

for i in range(0,3):
    axs[i].spines['top'].set_visible(False)
    axs[i].spines['right'].set_visible(False)
    # Panel letters start at "D".
    axs[i].text(-0.1, 1.15, string.ascii_uppercase[i+3], transform=axs[i].transAxes, size=20, weight='bold',rotation=0)

axs[0].set_xlim([0.8, 1])
xticks=[0.8, 0.85, 0.9, 0.95, 1]
axs[0].set_xticks(xticks,xticks)
axs[1].set_xlim([0.9, 1])
xticks=[0.9, 0.95, 1]
axs[1].set_xticks(xticks,xticks)
axs[2].set_xlim([0, 1])
xticks=[0, 0.25, 0.5, 0.75, 1]
axs[2].set_xticks(xticks,xticks)

fig.tight_layout(w_pad=0.4)
plt.savefig("Figures/Figure5_Middle.png", format="png", dpi=300, bbox_inches='tight')
plt.show()

In [None]:
# Figure 5, lower part (saved as Figures/Figure5_Down.png): XAI vs DA box plots
# per feature space for the "Expression" rows, the "Coverage" rows and the STN
# table, each annotated with Mann-Whitney tests between the two methods.
params = {'axes.labelsize': 15,
         'axes.titlesize': 15,
         'xtick.labelsize' : 15,
         'ytick.labelsize': 15,
         "lines.linewidth" : 4,
         "figure.dpi" : 300,
         "figure.figsize": [15, 5]}
plt.rcParams.update(params)
fig, axs = plt.subplots(1,3)
axs=axs.flatten()

pairs=[((feat,"XAI"),(feat,"DA")) for feat in ["Peaks","GEX","Windows"]]
configuration = {'test':'Mann-Whitney', 'text_format':'star'}

# One (y column, table) pair per panel, left to right.
panel_inputs=[("Intersection", covexpr[covexpr["Bias"]=="Expression"]),
              ("Intersection", covexpr[covexpr["Bias"]=="Coverage"]),
              ("STN", stn)]
for ax, (y_col, table) in zip(axs, panel_inputs):
    fig_args = {'y': y_col, 'x': "FeaturesSpace", "hue" : "Method", 'data': table,
                'dodge': True, "palette" : ut.colors_to_use_bright[::-1], "showfliers" : True}
    sns.boxplot(ax=ax, **fig_args, showmeans=True)
    annotator = Annotator(ax=ax, orient="v", pairs=pairs, plot='boxplot', verbose=False, **fig_args)
    annotator.configure(**configuration).apply_test().annotate()

# Keep a single legend, on the right-most panel.
axs[0].get_legend().remove()
axs[1].get_legend().remove()
axs[2].legend(bbox_to_anchor=(1.3, 1), title="Method", fontsize=17, title_fontsize=20)

y_labels=["Intersection with most\nexpressed or open features", "Intersection with most\ncovered features", "STN"]
y_limits=[[0,1], [0,1], [0,2.6]]
for i, (ax, y_label, y_limit) in enumerate(zip(axs, y_labels, y_limits)):
    ax.set_xlabel("Feature space", labelpad=10)
    ax.set_ylabel(y_label, labelpad=10)
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    # Panel letters start at "G".
    ax.text(-0.2, 1.15, string.ascii_uppercase[i+6], transform=ax.transAxes, size=20, weight='bold',rotation=0)
    ax.set_ylim(y_limit)

fig.tight_layout(w_pad=4)
plt.savefig("Figures/Figure5_Down.png", format="png", dpi=300, bbox_inches='tight')
plt.show()

## Figure6

In [None]:
# Load the raw peak matrix, keep its unnormalised values in layers["counts"],
# normalise X, then restrict to the cells of the processed ("Def") object and
# take over that object's annotations.
atac=sc.read_h5ad("Datasets/10XhsBrain3kMO/FeatureSpaces/Peak/CM/10XhsBrain3kMO_Peak_Raw.h5ad")
atac.layers["counts"]=atac.X.copy()
defatac=sc.read_h5ad("Datasets/10XhsBrain3kMO/FeatureSpaces/Peak/CM/10XhsBrain3kMO_Peak_Def.h5ad")
# Normalisation runs on all raw cells, before the subsetting below.
sc.pp.normalize_total(atac)
# Keep only (and in the order of) the cells present in the processed object.
atac=atac[defatac.obs.index]
# Copy cell metadata, unstructured data, embeddings and pairwise matrices.
atac.obs=defatac.obs
atac.uns=defatac.uns
atac.obsm=defatac.obsm
atac.obsp=defatac.obsp
atac

In [None]:
# Run scanMotifGenomeWide.pl with the known3.motif file on the chm13v2.0 FASTA
# (8 threads) and redirect its output to SOX9_TFBS.bed in the same directory.
motif_dir="Datasets/10XhsBrain3kMO/FeatureSpaces/Peak/XAI/MotifAnalysis/HOMER_XAIonly/Astrocyte_progenitor/knownResults"
scan_cmd=(f"scanMotifGenomeWide.pl {motif_dir}/known3.motif "
          f"../AnnotRef/hs/T2T/chm13v2.0.fa -p 8 > {motif_dir}/SOX9_TFBS.bed")
exit_status=os.system(scan_cmd)
# Fail here rather than let a later cell read an empty or partial bed file.
if exit_status != 0:
    raise RuntimeError(f"scanMotifGenomeWide.pl failed with exit status {exit_status}")
exit_status

In [None]:
# Read the header-less, tab-separated motif scan output and preview it.
sox9_hits_path="Datasets/10XhsBrain3kMO/FeatureSpaces/Peak/XAI/MotifAnalysis/HOMER_XAIonly/Astrocyte_progenitor/knownResults/SOX9_TFBS.bed"
df=pd.read_csv(sox9_hits_path, header=None, sep="\t")
df.head()

In [None]:
# Write columns 1-3 of the scan output as a header-less bed file, then print
# (not run) the shell command that intersects the ATAC fragments with it.
tfbs_bed="Datasets/10XhsBrain3kMO/FragFile/SOX9_TFBSAll.bed"
df[[1,2,3]].to_csv(tfbs_bed, sep="\t", header=None, index=None)
# The printed command references the file written above (SOX9_TFBSAll.bed).
print("zcat Datasets/10XhsBrain3kMO/FragFile/atac_fragments.tsv.gz |"
      f"bedtools intersect -a stdin -b {tfbs_bed} "
      "> Datasets/10XhsBrain3kMO/FragFile/SOX9_TFBSAll_Openness.tsv")

In [None]:
# Sum column 4 of the intersect output per value of column 3 (used as the cell
# identifier) to obtain one "TFBS openness" value per cell.
openness_path="Datasets/10XhsBrain3kMO/FragFile/SOX9_TFBSAll_Openness.tsv"
TFBS=pd.read_csv(openness_path, sep="\t", header=None, index_col=3)
TFBS["cell"]=TFBS.index
per_cell_sum=TFBS.groupby("cell")[4].sum()
TFBS=pd.DataFrame(per_cell_sum)
TFBS.rename({4 : "TFBS openness"}, axis=1, inplace=True)
print(TFBS.shape)
TFBS.head()

In [None]:
# Keep the cells present in both the ATAC object and the openness table, and
# store the openness divided by each cell's total raw counts in obs["SOX9_TFBS"].
TFBS=TFBS.loc[ut.intersection([atac.obs.index, TFBS.index])]
# Explicit copy: obs is modified below, which must not happen on a view.
atac=atac[TFBS.index].copy()
# Flatten the row sums so a 1-D vector is stored whether the layer is sparse or dense.
atac.obs["counts"]=np.asarray(atac.layers["counts"].sum(axis=1)).ravel()
# Start from a float column so the float ratios can be assigned without a dtype change.
atac.obs["SOX9_TFBS"]=0.0
atac.obs.loc[TFBS.index, "SOX9_TFBS"]=TFBS["TFBS openness"]/atac.obs["counts"]
atac

In [None]:
# Load the raw expression matrix, keep its unnormalised values in
# layers["counts"], normalise and log1p-transform X, then restrict to the cells
# of the processed ("Def") object and take over that object's annotations.
gex=sc.read_h5ad("Datasets/10XhsBrain3kMO/FeatureSpaces/GEX/CM/10XhsBrain3kMO_GEX_Raw.h5ad")
gex.layers["counts"]=gex.X.copy()
defgex=sc.read_h5ad("Datasets/10XhsBrain3kMO/FeatureSpaces/GEX/CM/10XhsBrain3kMO_GEX_Def.h5ad")
# Normalisation and log transform run on all raw cells, before the subsetting below.
sc.pp.normalize_total(gex)
sc.pp.log1p(gex)
# Keep only (and in the order of) the cells present in the processed object.
gex=gex[defgex.obs.index]
# Copy cell metadata, unstructured data, embeddings and pairwise matrices.
gex.obs=defgex.obs
gex.uns=defgex.uns
gex.obsm=defgex.obsm
gex.obsp=defgex.obsp
gex

In [None]:
# Restrict both modalities to their shared cells, store the 2-D GRAE coordinates
# of each modality in the obs of both objects, and extract the astrocyte and
# astrocyte-progenitor cells.
inter=ut.intersection([atac.obs.index, gex.obs.index])
# Explicit copies: obs/obsm are modified below, which must not happen on views.
atac=atac[inter].copy()
gex=gex[inter].copy()

# Both objects are indexed by `inter`, so the coordinate arrays line up row by row.
# The first coordinate is multiplied by -1 (mirrored x axis).
for tag, source in (("GEX", gex), ("ATAC", atac)):
    coords=source.obsm["X_GRAE_2D"].T
    for target in (gex, atac):
        target.obs[f"X_GRAE_{tag}"]=coords[0]*-1
        target.obs[f"Y_GRAE_{tag}"]=coords[1]

astroprog=gex[gex.obs["CellType"]=="Astrocyte_progenitor"].obs.index
astro=gex[gex.obs["CellType"]=="Astrocyte"].obs.index

astrodata_atac=atac[ut.flat_list([astro, astroprog])].copy()
astrodata_gex=gex[ut.flat_list([astro, astroprog])].copy()
# Display labels for the two cell types. The result is assigned back to the
# column: an in-place replace on `obs["CT"]` acts on a temporary object and is
# a no-op under pandas Copy-on-Write.
astrodata_gex.obs["CT"]=astrodata_gex.obs["CellType"].astype(str).replace(
    {"Astrocyte": "Astrocytes", "Astrocyte_progenitor" : "Astrocytes\nprogenitors"})

# Store the mirrored ATAC GRAE coordinates under "X_umap" so sc.pl.umap draws them.
gex.obsm["X_umap"]=np.array([atac.obsm["X_GRAE_2D"].T[0]*-1, atac.obsm["X_GRAE_2D"].T[1]]).T

In [None]:
# Copy the per-cell SOX9_TFBS values from the ATAC object onto the astrocyte
# expression subset and use the mirrored ATAC GRAE coordinates as its "X_umap".
interall=ut.intersection([astrodata_gex.obs.index, atac.obs.index])
# Float default for cells without a value, so the float values below keep the column dtype.
astrodata_gex.obs["SOX9_TFBS"]=0.0
astrodata_gex.obs.loc[interall, "SOX9_TFBS"]=atac.obs.loc[interall, "SOX9_TFBS"]
astrodata_gex.obsm["X_umap"]=np.array([astrodata_atac.obsm["X_GRAE_2D"].T[0]*-1, astrodata_atac.obsm["X_GRAE_2D"].T[1]]).T

In [None]:
# Read the motif matrix: skip the first line of the file and name the four
# tab-separated columns A, C, G, T.
sox9_motif_path='Datasets/10XhsBrain3kMO/FeatureSpaces/Peak/XAI/MotifAnalysis/HOMER_XAIonly/Astrocyte_progenitor/knownResults/known3.motif'
SOX9 = pd.read_csv(sox9_motif_path, sep="\t", skiprows=1, header=None, names=["A","C","G","T"])

In [None]:
# Figure 6, upper part (saved as Figures/Figure6_Up.png), on a 2x2 grid:
#   axs[0] SOX9 motif logo, axs[1] cell types, axs[2] SOX9_TFBS values,
#   axs[3] SOX9 expression. The three embedding panels use obsm["X_umap"] of
#   astrodata_gex, which an earlier cell fills with the mirrored ATAC GRAE coordinates.
params = {'axes.labelsize': 15,
         'axes.titlesize': 15,
         'xtick.labelsize' : 15,
         'ytick.labelsize': 15,
         "lines.linewidth" : 4,
         "figure.dpi" : 300,
         "figure.figsize": [12, 7]}
plt.rcParams.update(params) 
fig, axs = plt.subplots(2,2)
axs=axs.flatten()

# Fixed colours for the two display labels stored in obs["CT"].
cts_colors={"Astrocytes" : ut.colors_to_use_bright[-6], "Astrocytes\nprogenitors" : ut.colors_to_use_bright[-4]}

cmap=mpl.cm.inferno
sc.pl.umap(astrodata_gex, color="CT", ax=axs[1], frameon=True, show=False, palette=cts_colors, legend_fontsize=13, size=200,
          legend_loc="lower right", alpha=0.7)
# scanpy's own colorbar is disabled (colorbar_loc=None); a colorbar with the
# same limits and explicit ticks is added by hand instead.
ax=sc.pl.umap(astrodata_gex, color="SOX9_TFBS", ax=axs[2], frameon=False, show=False, legend_fontsize=10, size=200, vmin=0, vmax=1.2, color_map=cmap, colorbar_loc=None)
sm = plt.cm.ScalarMappable(cmap=cmap)
sm.set_clim(vmin=0, vmax=1.2)
fig.colorbar(ax=ax, mappable=sm, label="Norm. fragments counts", ticks=[0, 0.25, 0.5, 0.75, 1, 1.2])

# Same approach for the SOX9 expression panel, with colour limits 0-2.
ax=sc.pl.umap(astrodata_gex, color="SOX9", ax=axs[3], frameon=False, show=False, vmin=0, vmax=2, legend_fontsize=10, size=200, color_map=cmap, colorbar_loc=None)
sm = plt.cm.ScalarMappable(cmap=cmap)
sm.set_clim(vmin=0, vmax=2)
fig.colorbar(ax=ax, mappable=sm, label="Log norm. counts", ticks=[0, 0.5, 1, 1.5, 2])


axs[1].set_title("", y=1.1)
axs[2].set_title("SOX9 TFBS", y=1.05)
axs[3].set_title("SOX9", y=1.05)

# Embedding panels: hide all spines and add panel letters B-D.
for i in range(1,4):
    axs[i].spines['top'].set_visible(False)
    axs[i].spines['right'].set_visible(False)
    axs[i].spines['left'].set_visible(False)
    axs[i].spines['bottom'].set_visible(False)
    # Only axs[1] (the cell-type panel) gets axis labels.
    if i < 2:
        axs[i].set_xlabel("GRAE-1", labelpad=20)
        axs[i].set_ylabel("GRAE-2")
    axs[i].text(-0.1, 1.15, string.ascii_uppercase[i], transform=axs[i].transAxes, size=20, weight='bold',rotation=0)  

# Panel A: sequence logo drawn from the SOX9 matrix loaded in an earlier cell.
crp_logo1 = logomaker.Logo(SOX9, shade_below=.5,fade_below=.5, ax=axs[0], color_scheme="colorblind_safe")
crp_logo1.style_spines(visible=False)
crp_logo1.style_xticks(fmt='%d', anchor=0)
crp_logo1.ax.xaxis.set_ticks_position('none')
crp_logo1.ax.xaxis.set_tick_params(pad=-1)
axs[0].set_title("SOX9 Motif", y=1.05)
axs[0].text(-0.1, 1.15, string.ascii_uppercase[0], transform=axs[0].transAxes, size=20, weight='bold',rotation=0)  

fig.tight_layout(w_pad=2, h_pad=2)
plt.savefig("Figures/Figure6_Up.png", format="png", dpi=300, bbox_inches='tight')
plt.show()

In [None]:
# Figure 6, lower part (saved as Figures/Figure6_Down.png): DNAH7 and EFEMP1
# coloured on the astrocyte embedding, with one shared colour scale (0-2.55).
params = {'axes.labelsize': 15,
         'axes.titlesize': 15,
         'xtick.labelsize' : 15,
         'ytick.labelsize': 15,
         "lines.linewidth" : 4,
         "figure.dpi" : 300,
         "figure.figsize": [12, 4]}
plt.rcParams.update(params) 
fig, axs = plt.subplots(1,2)
axs=axs.flatten()

cmap=mpl.cm.inferno

# scanpy's own colorbars are disabled (colorbar_loc=None); a single colorbar
# with the same limits is attached by hand to the axes returned by the second call.
sc.pl.umap(astrodata_gex, color="DNAH7", ax=axs[0], frameon=False, show=False, legend_fontsize=10, size=200, vmin=0, vmax=2.55, color_map=cmap, colorbar_loc=None)
ax=sc.pl.umap(astrodata_gex, color="EFEMP1", ax=axs[1], frameon=False, show=False, vmin=0, vmax=2.55, size=200, color_map=cmap, colorbar_loc=None)
sm = plt.cm.ScalarMappable(cmap=cmap)
sm.set_clim(vmin=0, vmax=2.55)
fig.colorbar(ax=ax, mappable=sm, label="Log norm. counts", ticks=[0, 0.5, 1, 1.5, 2, 2.5])
# Hide all spines and add panel letters E and F.
for i in range(0,2):
    axs[i].spines['top'].set_visible(False)
    axs[i].spines['right'].set_visible(False)
    axs[i].spines['left'].set_visible(False)
    axs[i].spines['bottom'].set_visible(False)
    axs[i].text(-0.1, 1.15, string.ascii_uppercase[i+4], transform=axs[i].transAxes, size=20, weight='bold',rotation=0)  

fig.tight_layout()
plt.savefig("Figures/Figure6_Down.png", format="png", dpi=300, bbox_inches='tight')
plt.show()

## SuppFig3

In [None]:
# Gather every column of the *_XAIFeaturesImportance tables, grouped by feature
# space, then take the position-wise mean and standard error over those columns.
# No truncation is applied here: all positions are kept.
xaide=pd.DataFrame(columns=["Dataset","FeaturesSpace", "CT", "Difference","Name","Threshold"])
fss = {space : [] for space in ("GEX", "Peak", "Window")}
for dataset, featurespace, label, name in zip(ds_infos["DSs"], ds_infos["FsSs"], ds_infos["LBs"], ds_infos["Names"]):
    xai_file=f"Datasets/{dataset}/FeatureSpaces/{featurespace}/XAI/{dataset}_{featurespace}_GRAE_kNN_{label}_HPO_XAIFeaturesImportance.tsv.gz"
    xai=pd.read_csv(xai_file, sep="\t", index_col=0)
    fss[featurespace].extend(np.array(xai[col]) for col in xai.columns)

distributions = {}
distributions_std = {}
for key, vectors in fss.items():
    stacked=pd.DataFrame(vectors)  # one row per collected column
    distributions[key]=stacked.mean()
    distributions_std[key]=stacked.sem()

In [None]:
# Supplementary figure 3: mean importance per rank with standard-error bars
# (top) and the drop in mean importance between consecutive ranks (bottom),
# over all ranks, for the three feature spaces.
params = {'axes.labelsize': 15,
         'axes.titlesize': 15,
         'xtick.labelsize' : 15,
         'ytick.labelsize': 15,
         "lines.linewidth" : 4,
         "figure.dpi" : 300,
         "figure.figsize": [15, 10]}
plt.rcParams.update(params)
fs_order=["Peaks","GEX","Windows"]
feats_palette = {fs_order[i] : ut.colors_to_use_bright[9:12][i] for i in range(len(fs_order))}
fig, axs = plt.subplots(2,1)

# (key in `distributions`, legend/palette label, zorder): Peaks are drawn on top.
series_specs=[("Peak", "Peaks", 15), ("GEX", "GEX", 10), ("Window", "Windows", 5)]
for key, fs_label, z in series_specs:
    mean_imp=distributions[key]
    ranks=np.arange(1, len(mean_imp)+1)
    axs[0].errorbar(x=ranks, y=mean_imp, yerr=distributions_std[key], color=feats_palette[fs_label], ls="None", zorder=z, label=fs_label)
    # Decay rate: importance at rank i minus importance at rank i+1.
    axs[1].plot(ranks[:-1], -np.diff(np.asarray(mean_imp)), color=feats_palette[fs_label], zorder=z, label=fs_label)

axs[0].set_ylabel("Importance", labelpad=15)
# Switch to the log scale first: set_yscale installs the scale's default
# locator and formatter, which would discard ticks set before it.
axs[0].set_yscale("log")
axs[0].set_yticks(np.logspace(-5, 0, 6), ['{:.0e}'.format(i) for i in np.logspace(-5, 0, 6)[:-1]]+[1])

axs[1].set_ylabel("Decay rate of importance", labelpad=15)
axs[1].set_ylim([0, 0.035])


for i, ax in enumerate(axs):
    ax.set_xlabel("Rank")
    ax.set_xscale("log")
    ax.axvline(x=10, ymin=0, ymax=1, color=ut.colors_to_use_pastel[3], linestyle="--", zorder=1, linewidth = 3)
    ax.axvline(x=50, ymin=0, ymax=1, color=ut.colors_to_use_bright[-6], linestyle="--", zorder=1, linewidth = 3)
    ax.axvline(x=200, ymin=0, ymax=1, color=ut.colors_to_use_pastel[3], linestyle="--", zorder=1, linewidth = 3)
    ax.axhline(y=0.1, xmin=0, xmax=5000, color="purple", alpha=0.5, linestyle="-", ms=1, zorder=1, linewidth = 3)
    ax.text(-0.075, 1.15, string.ascii_uppercase[i], transform=ax.transAxes, size=20, weight='bold',rotation=0)  
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)

handles=[]
for key in feats_palette.keys():
    handles.append(mlines.Line2D([], [], color=feats_palette[key], marker='o', linestyle='None', markersize=10, label=key))
axs[0].legend(handles=handles, bbox_to_anchor=(1.3, 0.5), title="Feature space", fontsize=17, title_fontsize=20)
   
fig.tight_layout(w_pad=-5)
plt.savefig("Figures/SuppFig3.png", format="png", dpi=300, bbox_inches='tight')
plt.show()

## SuppFig4

In [None]:
# Pre-computed threshold sweeps; the first column of each file is used as index.
tot=pd.read_csv("Tables/StabSpec_Threshold.tsv.gz", sep="\t", index_col=0)
# "Threshold" is cast to int in this table only.
xaide=pd.read_csv("Tables/DAInter_Threshold.tsv.gz", sep="\t", index_col=0).astype({"Threshold": int})
# x positions used by the plotting cell for the per-threshold statistics.
thres=[10, 25, 50, 75, 100, 125, 150, 200, 250, 300, 350, 500, 750, 1000, 1500, 2000]

In [None]:
# Supplementary Figure 4: mean +/- SEM per threshold and feature space for
# Specificity (A), Stability (B) and the "Difference" column of `xaide` (C).
params = {'axes.labelsize': 15,
         'axes.titlesize': 15,
         'xtick.labelsize' : 15,
         'ytick.labelsize': 15,
         "lines.linewidth" : 4,
         "figure.dpi" : 300,
         "figure.figsize": [10, 12]}
plt.rcParams.update(params) 
fs_order=["Peaks","GEX","Windows"]
feats_palette = {fs_order[i] : ut.colors_to_use_bright[9:12][i] for i in range(len(fs_order))}
fig, axs = plt.subplots(3,1)
axs=axs.flatten()

# (panel, source table, column to summarise)
panels = [(axs[0], tot, "Specificity"),
          (axs[1], tot, "Stability"),
          (axs[2], xaide, "Difference")]
for panel, table, metric in panels:
    for feat in fs_order:
        # x is the hard-coded `thres` list, so the table is assumed to contain
        # exactly those thresholds (groupby returns them in sorted order).
        per_threshold = table[table["FeaturesSpace"]==feat].groupby("Threshold")[metric]
        panel.errorbar(x=thres, y=per_threshold.mean(), yerr=per_threshold.sem(),
                       c=feats_palette[feat], zorder=10, label=feat, linewidth=3)

# Panel A
# NOTE(review): this panel is named via set_title while B and C use set_ylabel - confirm.
axs[0].set_title("Specificity")
axs[0].set_xscale("log")
axs[0].axhline(y=0.75, xmin=0, xmax=5000, color=ut.colors_to_use_pastel[-6], linestyle="-", zorder=1, linewidth = 2)
axs[0].set_ylim([0, 1])
axs[0].set_yticks([0, 0.25, 0.5, 0.75, 0.9,1], [0, 0.25, 0.5, 0.75, 0.9,1])
axs[0].legend(bbox_to_anchor=(1, 0.6), title="Feature space", fontsize=17, title_fontsize=20)

# Panel B
axs[1].set_ylabel("Stability")
axs[1].set_xscale("log")
axs[1].set_ylim([0.9, 1])
axs[1].set_yticks([0.9, 0.95,1], [0.9, 0.95, 1])

# Panel C
axs[2].set_ylim([0,1])
axs[2].set_xscale("log")
axs[2].set_ylabel("Relative intersection\n between XAIF and DAF")

# Shared decoration: guides at 10 / 50 / 200 features, open frame, panel letter.
for i, panel in enumerate(axs):
    panel.axvline(x=10, ymin=0, ymax=1, color=ut.colors_to_use_pastel[3], linestyle="--", zorder=1, linewidth = 2)
    panel.axvline(x=50, ymin=0, ymax=1, color=ut.colors_to_use_bright[-6], linestyle="-", zorder=1, linewidth = 2)
    panel.axvline(x=200, ymin=0, ymax=1, color=ut.colors_to_use_pastel[3], linestyle="--", zorder=1, linewidth = 2)
    for side in ('top', 'right'):
        panel.spines[side].set_visible(False)
    panel.set_xlabel("Number of features")
    panel.text(-0.1, 1.15, string.ascii_uppercase[i], transform=panel.transAxes, size=20, weight='bold',rotation=0)

fig.tight_layout()
plt.savefig("Figures/SuppFig4.png", format="png", dpi=300, bbox_inches='tight')
plt.show()

## SuppFig5

In [None]:
def _xaif_difference_table(suffix_a, suffix_b, top_n=50):
    """Compare the first `top_n` rows of two per-dataset feature tables.

    For every row of `ds_infos`, both files are looked up under
    ``Datasets/<dataset>/FeatureSpaces/<featurespace>/XAI/<dataset>_<featurespace>_``
    followed by the given suffix (``{label}`` is substituted). For each column
    present in both tables one record is produced with
    ``Difference = 1 - |shared entries| / top_n``.

    Parameters
    ----------
    suffix_a, suffix_b : str
        File-name templates of the two tables to compare.
    top_n : int
        Number of leading rows kept from each table; also the denominator.

    Returns
    -------
    pandas.DataFrame
        Columns Dataset, FeaturesSpace, CT, Difference, Name. "Peak"/"Window"
        are renamed to "Peaks"/"Windows" to match the plotting labels.
        Datasets with a missing file are skipped and both paths are printed.
    """
    columns = ["Dataset","FeaturesSpace", "CT", "Difference","Name"]
    records = []
    for dataset, featurespace, label, name in zip(ds_infos["DSs"], ds_infos["FsSs"], ds_infos["LBs"], ds_infos["Names"]):
        stem = f"Datasets/{dataset}/FeatureSpaces/{featurespace}/XAI/{dataset}_{featurespace}_"
        file_a = stem + suffix_a.format(label=label)
        file_b = stem + suffix_b.format(label=label)
        # Plain boolean `and`: the previous `a == True & b == True` only worked
        # by accident of operator precedence and comparison chaining.
        if not (os.path.isfile(file_a) and os.path.isfile(file_b)):
            print(file_a, file_b)
            continue
        table_a = pd.read_csv(file_a, sep="\t", index_col=0)[:top_n]
        table_b = pd.read_csv(file_b, sep="\t", index_col=0)[:top_n]
        for col in ut.intersection([table_a.columns, table_b.columns]):
            shared = ut.intersection([table_a[col], table_b[col]])
            # NOTE(review): this is 1 - overlap, yet the plotting cell labels the
            # axis "Relative intersection" - confirm which one is intended.
            records.append({"Dataset": dataset, "FeaturesSpace": featurespace, "CT": col,
                            "Difference": 1 - len(shared)/top_n, "Name": name})
    # Build the frame once (no repeated concat) so "Difference" stays numeric.
    table = pd.DataFrame(records, columns=columns)
    table["FeaturesSpace"] = table["FeaturesSpace"].replace({"Peak": "Peaks", "Window": "Windows"})
    return table


# NN table vs. GRAE top-50 table.
tot = _xaif_difference_table("GRAE_kNN_{label}_HPO_NN_XAIFeatures.tsv.gz",
                             "GRAE_kNN_{label}_HPO_XAITop50Features.tsv.gz")
# GRAE top-50 table vs. AE top-50 table.
tot1 = _xaif_difference_table("GRAE_kNN_{label}_HPO_XAITop50Features.tsv.gz",
                              "AE_kNN_{label}_HPO_XAITop50Features.tsv.gz")

In [None]:
# Supplementary Figure 5: per-dataset mean +/- SEM of "Difference" from `tot`
# (left panel) and `tot1` (right panel), one series per feature space.
params = {'axes.labelsize': 15,
         'axes.titlesize': 15,
         'xtick.labelsize' : 15,
         'ytick.labelsize': 15,
         "lines.linewidth" : 4,
         "figure.dpi" : 300,
         "figure.figsize": [15, 6]}
plt.rcParams.update(params) 
fig, axs = plt.subplots(1,2, sharey=True)

for i,fs in enumerate(["Peaks","GEX","Windows"]):
    # Left panel first, then right panel, for each feature space: the panels
    # share a categorical y-axis, so the order of the calls is kept as is.
    for panel, table in zip(axs, (tot, tot1)):
        d=table[table["FeaturesSpace"]==fs]
        per_name=d.groupby("Name")["Difference"]
        panel.errorbar(x=np.array(per_name.mean()),
                       y=sorted(set(d["Name"])),  # same order as the sorted groupby keys
                       xerr=np.array(per_name.sem()),
                       ls='none', ecolor=ut.colors_to_use_bright[9:12][i],
                       elinewidth=20, marker="o", c="black", label=fs)

left, right = axs[0], axs[1]
left.set_ylabel("Dataset")
left.set_xlabel("Relative intersection\nbetween XAIF of GAT and NN")
left.set_xlim([0,1])
left.legend().remove()

right.set_xlabel("Relative intersection\nbetween XAIF of GRAE and AE")
right.legend(bbox_to_anchor=(1.1, 1), title="Feature space", fontsize=17, title_fontsize=20)
right.set_xlim([0,1])

for i, panel in enumerate(axs):
    for side in ('top', 'right'):
        panel.spines[side].set_visible(False)
    panel.text(-0.1, 1.15, string.ascii_uppercase[i], transform=panel.transAxes, size=20, weight='bold',rotation=0)

fig.tight_layout()
plt.savefig("Figures/SuppFig5.png", format="png", dpi=300, bbox_inches='tight')
plt.show()

## SuppFig6

In [None]:
# Supplementary Figure 6: sequence logos of selected known motifs.
brain_homer = 'Datasets/10XhsBrain3kMO/FeatureSpaces/Peak/XAI/MotifAnalysis/HOMER_XAIonly'
pbmc_homer = 'Datasets/10XhsPBMC10kMO/FeatureSpaces/Peak/XAI/MotifAnalysis/HOMER_XAIonly'
motif_files = [
    f'{brain_homer}/Astrocyte_progenitor/knownResults/known3.motif',    # SOX9
    f'{brain_homer}/Astrocyte_progenitor/knownResults/known5.motif',    # SOX17
    f'{brain_homer}/Astrocyte_progenitor/knownResults/known7.motif',    # SOX1
    f'{brain_homer}/Inhibitory_neuron_MAF/knownResults/known14.motif',  # ASCL1
    f'{brain_homer}/Microglia/knownResults/known12.motif',              # JUNb
    f'{brain_homer}/Microglia/knownResults/known15.motif',              # FOSL2
    f'{brain_homer}/Microglia/knownResults/known16.motif',              # FOS
    f'{pbmc_homer}/DCp/knownResults/known1.motif',                      # RUNX1
    f'{pbmc_homer}/DCp/knownResults/known5.motif',                      # RUNX2
]
# Each file is read as a tab-separated A/C/G/T matrix; its first line is skipped.
SOX9, SOX17, SOX1, ASCL1, JUNb, FOSL2, FOS, RUNX1, RUNX2 = [
    pd.read_csv(motif_file, sep="\t", header=None, names=["A","C","G","T"], skiprows=1)
    for motif_file in motif_files
]

# NOTE(review): SOX9 is loaded above but is not part of the plotted list - confirm.
logos=[SOX17, SOX1, JUNb, FOSL2, FOS, ASCL1,  RUNX1, RUNX2]
names=["SOX17","SOX1","JUNb", "FOSL2", "FOS", "ASCL1", "RUNX1", "RUNX2"]
fig, axs = plt.subplots(3,3, figsize=(17, 10), dpi=300, sharey=True)
axs=axs.flatten()
for panel, title, letter, matrix in zip(axs, names, string.ascii_uppercase, logos):
    logo = logomaker.Logo(matrix, shade_below=.5,fade_below=.5, ax=panel, color_scheme="colorblind_safe")
    logo.style_spines(visible=False)
    logo.style_xticks(fmt='%d', anchor=0)
    logo.ax.xaxis.set_ticks_position('none')
    logo.ax.xaxis.set_tick_params(pad=-1)
    panel.set_title(title)
    panel.text(-0.045, 1.075, letter, transform=panel.transAxes, size=20, weight='bold',rotation=0)
# Eight logos on a 3x3 grid: drop the unused last panel.
axs[-1].remove()
fig.tight_layout(w_pad=3, h_pad=3)
plt.savefig("Figures/SuppFig6.png", format="png", dpi=300, bbox_inches='tight')
plt.show()

In [None]:
# Export the significant (q < 0.05) known motifs of the DA and SEAGALL (XAI)
# HOMER runs: one table per cell type plus one combined table.
dss=["10XhsBrain3kMO", "10XhsBrain3kMO", "10XhsBrain3kMO", "10XhsPBMC10kMO"]
# Same spelling for every brain entry so "Data set" has one value per dataset
# (the third entry used to read "Human Brain").
names=["Human brain","Human brain","Human brain","PBMC"]
cts=["Astrocyte_progenitor", "Inhibitory_neuron_MAF", "Microglia", "DCp"]
cols=['Motif Name', 'Consensus','q-value (Benjamini)']
order=['Data set',
 'Cell type',
    'DA Motif Name',
 'DA Consensus',
 'DA q-value (Benjamini)',
 'SEAGALL Motif Name',
 'SEAGALL Consensus',
 'SEAGALL q-value (Benjamini)']

# Make sure the output folder exists before anything is written.
os.makedirs("Tables/Motif", exist_ok=True)

frames=[]
for ds, ct, n in zip(dss, cts, names):
    print(ds, ct)
    xai=pd.read_csv(f"Datasets/{ds}/FeatureSpaces/Peak/XAI/MotifAnalysis/HOMER_XAIonly/{ct}/knownResults.txt", sep="\t")
    de=pd.read_csv(f"Datasets/{ds}/FeatureSpaces/Peak/XAI/DE/MotifAnalysis/HOMER_DEonly/{ct}/knownResults.txt", sep="\t")
    # Filter, prefix and annotate in one chain so no filtered slice is mutated.
    xai=(xai.loc[xai["q-value (Benjamini)"]<0.05, cols]
            .rename(columns=lambda col: f"SEAGALL {col}")
            .assign(**{"Data set": n, "Cell type": ct}))
    de=(de.loc[de["q-value (Benjamini)"]<0.05, cols]
          .rename(columns=lambda col: f"DA {col}")
          .assign(**{"Data set": n, "Cell type": ct}))
    # DA rows first, then SEAGALL rows; columns of the other source stay NaN.
    df=pd.concat([de, xai], ignore_index=True)
    frames.append(df)
    df.to_csv(f"Tables/Motif/{ds}_{ct}.tsv", sep="\t", index=False)

# Latest cell type first, matching the previous prepend-style accumulation.
tot=pd.concat(frames[::-1], axis=0, ignore_index=True)[order]
tot.to_csv("Tables/Motif/All_Motif.tsv", sep="\t", index=False)

## SuppFig7

In [None]:
# Read the GEX and Peak AnnData files of the brain and PBMC datasets, in this order.
brain, brain_atac, pbmc, pbmc_atac = (
    sc.read_h5ad(h5ad_file)
    for h5ad_file in (
        "Datasets/10XhsBrain3kMO//FeatureSpaces/GEX/CM/10XhsBrain3kMO_GEX_Def.h5ad",
        "Datasets/10XhsBrain3kMO/FeatureSpaces/Peak/CM/10XhsBrain3kMO_Peak_Def.h5ad",
        "Datasets/10XhsPBMC10kMO/FeatureSpaces/GEX/CM/10XhsPBMC10kMO_GEX_Def.h5ad",
        "Datasets/10XhsPBMC10kMO/FeatureSpaces/Peak/CM/10XhsPBMC10kMO_Peak_Def.h5ad",
    )
)

In [None]:
# Restrict both brain objects to the same observations and plot gene expression
# on the 2D GRAE coordinates stored with the Peak object.
inter=ut.intersection([brain_atac.obs.index, brain.obs.index])
brain_atac=brain_atac[inter]
# Explicit copy: `brain` is modified below, and writing into an AnnData view
# would otherwise trigger an implicit copy.
brain=brain[inter].copy()
grae_2d=brain_atac.obsm["X_GRAE_2D"].T
# First GRAE component is mirrored (multiplied by -1), second one is kept.
brain.obsm["X_umap"]=np.array([grae_2d[0]*-1, grae_2d[1]]).T
# Display labels. The result is assigned back instead of calling
# .replace(..., inplace=True) on brain.obs["CT"]: that chained in-place call
# warns in recent pandas and leaves the column unchanged under Copy-on-Write.
brain.obs["CT"]=brain.obs["CellType"].replace({"Astrocyte": "Astrocytes", "Astrocyte_progenitor" : "Astrocytes progenitors"})

In [None]:
# Supplementary Figure 7, top row (brain): cell types (A) and expression of
# TLR2 (B) and RIPK2 (C) on the GRAE embedding.
params = {'axes.labelsize': 15,
         'axes.titlesize': 15,
         'xtick.labelsize' : 15,
         'ytick.labelsize': 15,
         "lines.linewidth" : 4,
         "figure.dpi" : 300,
         "figure.figsize": [18, 5]}
plt.rcParams.update(params) 
fig, axs = plt.subplots(1,3)
axs=axs.flatten()

# One pastel colour per cell type. The labels are sorted because the iteration
# order of a set of strings changes between kernel sessions, which made the
# colour assignment (and hence the figure) non-reproducible.
cts_colors={}
for i,ct in enumerate(sorted(set(brain.obs.CT))):
    cts_colors[ct]=ut.colors_to_use_pastel[i]
# Hand-picked colours for selected cell types.
cts_colors["Astrocytes"]=ut.colors_to_use_bright[-6]
cts_colors["Astrocytes progenitors"]=ut.colors_to_use_bright[-4]
cts_colors["Inhibitory_neuron_PVALB_SST"]=ut.colors_to_use_pastel[-2]
cts_colors["Microglia"]=ut.colors_to_use_bright[1]


cmap=mpl.cm.inferno

sc.pl.umap(brain, color="CT", ax=axs[0], frameon=True, show=False, palette=cts_colors, legend_fontsize=13, size=200,
          legend_loc="right margin", alpha=0.7)
# Re-create the legend to the left of panel A.
axs[0].legend(bbox_to_anchor=(-0.1, 1), title="", fontsize=12, frameon=False)
axs[0].set_title("")
ax=sc.pl.umap(brain, color="TLR2", ax=axs[1], frameon=False, show=False, legend_fontsize=10, size=200, vmin=0, vmax=2.5,
              color_map=cmap, colorbar_loc=None)

ax=sc.pl.umap(brain, color="RIPK2", ax=axs[2], frameon=False, show=False, vmin=0, vmax=2.5, 
              legend_fontsize=10, size=200, color_map=cmap, colorbar_loc=None)
# Single colour bar for both expression panels (same vmin/vmax), attached to the last one.
sm = plt.cm.ScalarMappable(cmap=cmap)
sm.set_clim(vmin=0, vmax=2.5)
fig.colorbar(ax=ax, mappable=sm, label="Log norm. counts", ticks=[0, 0.5, 1, 1.5, 2, 2.5])

for i in range(0,3):
    axs[i].spines['top'].set_visible(False)
    axs[i].spines['right'].set_visible(False)
    axs[i].spines['left'].set_visible(False)
    axs[i].spines['bottom'].set_visible(False)
    if i == 0:
        # Only the cell-type panel carries axis labels.
        axs[i].set_xlabel("GRAE-1", labelpad=20)
        axs[i].set_ylabel("GRAE-2")
    axs[i].text(-0.1, 1.15, string.ascii_uppercase[i], transform=axs[i].transAxes, size=20, weight='bold',rotation=0)  

fig.tight_layout(w_pad=-0.5)
plt.savefig("Figures/SuppFig7_Up.png", format="png", dpi=300, bbox_inches='tight')
plt.show()

In [None]:
# Keep only the observations shared by the PBMC Peak and GEX objects.
inter=ut.intersection([pbmc_atac.obs.index, pbmc.obs.index])
pbmc_atac, pbmc = pbmc_atac[inter], pbmc[inter]
# Reuse the 2D GRAE coordinates of the Peak object as the plotting embedding,
# with the first component multiplied by -1.
grae_2d=pbmc_atac.obsm["X_GRAE_2D"].T
pbmc.obsm["X_umap"]=np.array([grae_2d[0]*-1, grae_2d[1]]).T
pbmc.obs["CT"]=pbmc.obs["CellType"]

In [None]:
# Supplementary Figure 7, middle row (PBMC): cell types (D) and SOX4
# expression (E) on the GRAE embedding.
params = {'axes.labelsize': 15,
         'axes.titlesize': 15,
         'xtick.labelsize' : 15,
         'ytick.labelsize': 15,
         "lines.linewidth" : 4,
         "figure.dpi" : 300,
         "figure.figsize": [18, 5]}
plt.rcParams.update(params) 
fig, axs = plt.subplots(1,2)
axs=axs.flatten()

# One bright colour per cell type. The labels are sorted because the iteration
# order of a set of strings changes between kernel sessions, which made the
# colour assignment (and hence the figure) non-reproducible.
cts_colors={}
for i,ct in enumerate(sorted(set(pbmc.obs.CellType))):
    cts_colors[ct]=ut.colors_to_use_bright[i]
# Hand-picked colours for selected cell types.
cts_colors["T_MAIT"]=ut.colors_to_use_pastel[0]
cts_colors["NK"]=ut.colors_to_use_pastel[1]

cmap=mpl.cm.inferno

sc.pl.umap(pbmc, color="CellType", ax=axs[0], frameon=True, show=False, palette=cts_colors, legend_fontsize=13, size=200,
          legend_loc="right margin", alpha=0.7)
axs[0].set_title("")

ax=sc.pl.umap(pbmc, color="SOX4", ax=axs[1], frameon=False, show=False, legend_fontsize=10, size=200, vmin=0, vmax=2.5,
              color_map=cmap, colorbar_loc=None)
# Manual colour bar matching the vmin/vmax used for the expression panel.
sm = plt.cm.ScalarMappable(cmap=cmap)
sm.set_clim(vmin=0, vmax=2.5)
fig.colorbar(ax=ax, mappable=sm, label="Log norm. counts", ticks=[0, 0.5, 1, 1.5, 2, 2.5])


for i in range(0,2):
    axs[i].spines['top'].set_visible(False)
    axs[i].spines['right'].set_visible(False)
    axs[i].spines['left'].set_visible(False)
    axs[i].spines['bottom'].set_visible(False)
    if i == 0:
        # Only the cell-type panel carries axis labels.
        axs[i].set_xlabel("GRAE-1", labelpad=20)
        axs[i].set_ylabel("GRAE-2")
    # Panel letters continue after the three panels of the top row (D, E).
    axs[i].text(-0.1, 1.15, string.ascii_uppercase[i+3], transform=axs[i].transAxes, size=20, weight='bold',rotation=0)  

fig.tight_layout(w_pad=15)
plt.savefig("Figures/SuppFig7_Middle.png", format="png", dpi=300, bbox_inches='tight')
plt.show()

In [None]:
# Supplementary Figure 7, bottom row (PBMC): expression of CR1 (F), LAIR2 (G)
# and CD8B (H) on the GRAE embedding, with one shared colour bar.
params = {'axes.labelsize': 15,
         'axes.titlesize': 15,
         'xtick.labelsize' : 15,
         'ytick.labelsize': 15,
         "lines.linewidth" : 4,
         "figure.dpi" : 300,
         "figure.figsize": [18, 5]}
plt.rcParams.update(params) 
fig, axs = plt.subplots(1,3)
axs=axs.flatten()

# Cell-type palette, rebuilt as in the previous cell; no plot call in this cell uses it.
cts_colors={ct : ut.colors_to_use_bright[i] for i,ct in enumerate(set(pbmc.obs.CellType))}
cts_colors["T_MAIT"]=ut.colors_to_use_pastel[0]
cts_colors["NK"]=ut.colors_to_use_pastel[1]

cmap=mpl.cm.inferno

# Same colour range for every gene so the single colour bar applies to all panels.
for panel, gene in zip(axs, ("CR1", "LAIR2", "CD8B")):
    ax=sc.pl.umap(pbmc, color=gene, ax=panel, frameon=False, show=False, legend_fontsize=10, size=200,
                  vmin=0, vmax=2.5, color_map=cmap, colorbar_loc=None)
sm = plt.cm.ScalarMappable(cmap=cmap)
sm.set_clim(vmin=0, vmax=2.5)
# Colour bar attached to the last panel drawn.
fig.colorbar(sm, ax=ax, label="Log norm. counts", ticks=[0, 0.5, 1, 1.5, 2, 2.5])

for i, panel in enumerate(axs):
    for side in ('top', 'right', 'left', 'bottom'):
        panel.spines[side].set_visible(False)
    # Panel letters continue after the five panels of the rows above (F, G, H).
    panel.text(0.05, 1.15, string.ascii_uppercase[i+5], transform=panel.transAxes, size=20, weight='bold',rotation=0)

fig.tight_layout(w_pad=-2)
plt.savefig("Figures/SuppFig7_Down.png", format="png", dpi=300, bbox_inches='tight')
plt.show()