In [None]:
import pandas as pd
import numpy as np
import torch

from speos.preprocessing.handler import InputHandler
from speos.utils.config import Config
from speos.preprocessing.datasets import DatasetBootstrapper

In [None]:
import os
# Work from the parent directory so the relative paths used below (the config YAML,
# hsps/, data/gwas/) resolve. NOTE(review): not idempotent — re-running this cell
# moves up one more directory level.
os.chdir("..")

In [None]:
# Build the preprocessor described by the UC config and construct its graph.
# NOTE(review): the meaning of adjacency=False is defined inside speos; not visible here.
config = Config()
config.parse_yaml("config_uc_only_nohetio_film_newstorage.yaml")
prepro = InputHandler(config).get_preprocessor()
prepro.build_graph(adjacency=False)
# NOTE(review): `data` is not used again in this notebook.
data = prepro.get_data()

In [None]:

# Dataset whose `data.x` (feature matrix) and `data.y` (seeds the core-gene mask) are used below.
dataset = DatasetBootstrapper(holdout_size=config.input.holdout_size, name=config.name, config=config).get_dataset()

In [None]:
import json

# Score threshold separating strong (>= 11) from weak (1-10) candidate genes.
STRONG_CS_THRESHOLD = 11

# Read the outer results once; element 0 maps HGNC symbol -> score.
with open("/mnt/storage/speos/results/uc_film_nohetioouter_results.json", "r") as file:
    outer_scores = json.load(file)[0]

strong_results = [key for key, value in outer_scores.items() if value >= STRONG_CS_THRESHOLD]
weak_results = [key for key, value in outer_scores.items() if 1 <= value < STRONG_CS_THRESHOLD]

# Graph node ids of the strong and weak candidates.
indices = torch.LongTensor([prepro.hgnc2id[hgnc] for hgnc in strong_results])
indices_weak = torch.LongTensor([prepro.hgnc2id[hgnc] for hgnc in weak_results])

# The original cell left `results` bound to the weak list; keep that binding.
results = weak_results

In [None]:
# 0/1 mask of core genes: the dataset labels plus the strong candidates.
# `.long()` returns `y` itself when it is already int64, so clone to avoid writing the
# candidates back into `dataset.data.y`.
coregenes = dataset.data.y.long().clone()
coregenes[indices] = 1

# 0/1 mask of the weak candidates.
coregenes_weak = torch.zeros_like(coregenes)
coregenes_weak[indices_weak] = 1

coregenes.sum()

In [None]:
hsps = pd.read_csv("hsps/uc.txt", header=None, index_col=None, sep="\t")
hsp_indices = [prepro.hgnc2id[hgnc] for hgnc in hsps.iloc[:, 0]]
new_y = torch.zeros_like(dataset.data.y)
new_y[np.asarray(hsp_indices)] = 1

In [None]:
gwas_zstat = pd.read_csv("data/gwas/UC.genes.out", header=0, index_col="GENE", usecols=["GENE", "ZSTAT"], sep=" ")
gwas_zstat = gwas_zstat.loc[gwas_zstat.index.isin(prepro.entrez2id.keys()), :]
len(gwas_zstat)

In [None]:
gwas_zstat = gwas_zstat.rename(prepro.entrez2id)

In [None]:
all_gwas_zstat = torch.zeros_like(new_y)
all_sign_indices = gwas_zstat[gwas_zstat["ZSTAT"] > 5].index
all_gwas_zstat[all_sign_indices] = 1

all_gwas_zstat[all_gwas_zstat.logical_and(coregenes)] = 0
all_gwas_zstat[all_gwas_zstat.logical_and(coregenes_weak)] = 0
all_gwas_zstat.sum()

# must be 70

In [None]:
features = pd.DataFrame(data=dataset.data.x.numpy(), columns=prepro.get_feature_names()).rename(prepro.id2hgnc, axis=0)


def _rows_where(mask):
    """Rows of `features` at the positions where the 1-D tensor `mask` is nonzero.

    `.flatten()` (rather than `.squeeze()`) keeps a list even for a single hit, so
    `.iloc` always returns a DataFrame.
    """
    return features.iloc[mask.nonzero().flatten().tolist(), :]


is_strong = coregenes.bool()
is_weak = coregenes_weak.bool()
is_hsp = new_y.bool()

strongcore_features = _rows_where(is_strong)
weakcore_features = _rows_where(is_weak)
hsp_features = _rows_where(is_hsp)
gwas_hsp_features = _rows_where(all_gwas_zstat)
# Logical OR instead of `1 - (a + b + c)`: a gene in two of the sets (e.g. an HSP gene that
# is also a core gene) summed to 2 there, giving -1, which `nonzero()` counted as peripheral.
peripheral_features = _rows_where(~(is_strong | is_weak | is_hsp))
eligible_features = _rows_where(~is_weak)
eligible_features_weakcore = _rows_where(~is_strong)

In [None]:
from speos.visualization.settings import *
import matplotlib.pyplot as plt
from statsmodels.stats.multitest import fdrcorrection
from scipy.stats import fisher_exact

def get_fragments(values, sign_list):
    """Split `values` into runs of consecutive positions where `sign_list` is truthy.

    Returns one ``[positions, values_at_positions]`` pair per run, e.g.
    ``[[[0, 1, 2], [0.1, 0.2, 0.3]], [[4, 5, 6], [0.1, 0.2, 0.3]]]``.
    `sign_list` must provide ``.tolist()`` (e.g. a NumPy array).
    """
    runs = []
    open_run = None
    for position, flag in enumerate(sign_list.tolist()):
        if not flag:
            # A falsy entry closes the current run; the next truthy one starts a new run.
            open_run = None
            continue
        if open_run is None:
            open_run = [[], []]
            runs.append(open_run)
        open_run[0].append(position)
        open_run[1].append(values[position])
    return runs
        
            

def plot_enrichment_fisher(featurename, eligible_features, strongcore_features, peripheral_features, hsp_features, ax=None, fig=None):
    core_enrichment = []
    peripheral_enrichment = []
    hsp_enrichment = []
    core_pval = []
    core_arrays = []
    peripheral_pval = []
    peripheral_arrays = []
    hsp_pval = []
    hsp_arrays = []

    for i in range(100, -1, -1):
      threshold = np.quantile(eligible_features[featurename], i / 100)
      total_count = (eligible_features[featurename] >= threshold).sum()
      core_count = (strongcore_features[featurename] >= threshold).sum()
      peripheral_count = (peripheral_features[featurename] >= threshold).sum()
      hsp_count = (hsp_features[featurename] >= threshold).sum()

      array = np.asarray([[core_count, total_count-core_count],
              [len(strongcore_features) - core_count, len(eligible_features) - len(strongcore_features) - total_count + core_count]])
        
      assert array[0, :].sum() == total_count
      assert array[1, :].sum() == len(eligible_features) - total_count
      assert array[:, 0].sum() == len(strongcore_features)
      assert array[:, 1].sum() == len(eligible_features) - len(strongcore_features)

      core_arrays.append(array.flatten())
      core_enrichment.append(fisher_exact(array)[0])
      core_pval.append(fisher_exact(array)[1])
      
      array = np.asarray([[peripheral_count, total_count-peripheral_count],
              [len(peripheral_features) - peripheral_count, len(eligible_features) - len(peripheral_features) - total_count + peripheral_count]])
        
      assert array[0, :].sum() == total_count
      assert array[1, :].sum() == len(eligible_features) - total_count
      assert array[:, 0].sum() == len(peripheral_features)
      assert array[:, 1].sum() == len(eligible_features) - len(peripheral_features)

      peripheral_arrays.append(array.flatten())
      peripheral_enrichment.append(fisher_exact(array)[0])
      peripheral_pval.append(fisher_exact(array)[1])

      array = np.asarray([[hsp_count, total_count-hsp_count],
              [len(hsp_features) - hsp_count, len(eligible_features) - len(hsp_features) - total_count + hsp_count]])

        
      assert array[0, :].sum() == total_count
      assert array[1, :].sum() == len(eligible_features) - total_count
      assert array[:, 0].sum() == len(hsp_features)
      assert array[:, 1].sum() == len(eligible_features) - len(hsp_features)

      hsp_arrays.append(array.flatten())
      hsp_enrichment.append(fisher_exact(array)[0])
      hsp_pval.append(fisher_exact(array)[1])
    
    

    core_enrichment = np.asarray(core_enrichment)
    peripheral_enrichment = np.asarray(peripheral_enrichment)
    hsp_enrichment = np.asarray(hsp_enrichment)
    total_sign = fdrcorrection(core_pval + peripheral_pval + hsp_pval)[0]
    total_fdr = fdrcorrection(core_pval + peripheral_pval + hsp_pval)[1]
    core_sign = total_sign[:len(core_pval)]
    core_fdr = total_fdr[:len(core_pval)]

    peripheral_sign = total_sign[len(core_pval):-len(hsp_pval)]
    peripheral_fdr = total_fdr[len(core_pval):-len(hsp_pval)]

    hsp_sign = total_sign[-len(hsp_pval):]
    hsp_fdr = total_fdr[-len(hsp_pval):]

    if ax is None:
        fig, ax = plt.subplots()
    for fragment in get_fragments(core_enrichment, core_sign):
        ax.plot(fragment[0], fragment[1], color="#01016f")
    for fragment in get_fragments(peripheral_enrichment, peripheral_sign):
        ax.plot(fragment[0], fragment[1], color="#5a5a5a")
    for fragment in get_fragments(hsp_enrichment, hsp_sign):
        ax.plot(fragment[0], fragment[1], color="#d8031c")


    ax.plot(np.arange(101), core_enrichment, color="lightblue", zorder=-5)
    #ax.plot(np.arange(101)[peripheral_sign], peripheral_enrichment[peripheral_sign], color="#5a5a5a")
    ax.plot(np.arange(101), peripheral_enrichment, color="lightgray", zorder=-5)
    #ax.plot(np.arange(101)[hsp_sign], hsp_enrichment[hsp_sign], color="#d8031c")
    ax.plot(np.arange(101), hsp_enrichment, color="pink", zorder=-5)
    ax.set_xticks((0, 20, 40, 60, 80, 100))
    ax.set_xticklabels((100, 80, 60, 40, 20, 0))
    ax.set_xlim(0, 100)
    ax.set_xlabel("Percentile ({})".format(featurename),  fontsize=7)
    ax.set_ylabel("Odds Ratio", fontsize=7)
    ax.hlines(1, 0, 100, colors="black", linewidth=1)
    ax.set_yscale("symlog")
    ax.set_yticks([ 1/10, 1/2, 1,2, 5, 10, 20])
    ax.set_yticklabels([0.1, 1/2, 1, 2, 5, 10, 20])
    ax.grid(which="major", axis="y", ls="--", color="lightgray", zorder=-5)

    core_values = np.concatenate((core_arrays, core_enrichment.reshape(-1,1), np.asarray(core_pval).reshape(-1,1), core_fdr.reshape(-1,1)), axis=1)
    periph_values = np.concatenate((peripheral_arrays, peripheral_enrichment.reshape(-1,1), np.asarray(peripheral_pval).reshape(-1,1), peripheral_fdr.reshape(-1,1)), axis=1)
    hsp_values = np.concatenate((hsp_arrays, hsp_enrichment.reshape(-1,1), np.asarray(hsp_pval).reshape(-1,1), hsp_fdr.reshape(-1,1)), axis=1)

    columns = []
    for group in ["Core", "Peripheral", "HSP"]:
        columns.extend(["{} in P.".format(group), "not {} in P.".format(group), "{} not in P.".format(group), "not {} not in P.".format(group), "{} Odds Ratio".format(group), "{} p-value".format(group), "{} FDR".format(group)])
    
    df = pd.DataFrame(data = np.concatenate((np.arange(100,-1,-1).reshape(-1,1), core_values, periph_values, hsp_values), axis=1),
                      columns=["Percentile"] + columns)
    
    return fig, ax, df

# Test with one tissue

In [None]:
fig, ax, df = plot_enrichment_fisher('Small Intestine - Terminal Ileum',  eligible_features, strongcore_features, peripheral_features, gwas_hsp_features)


In [None]:
prepro.get_feature_names()

In [None]:
fig, ax, df = plot_enrichment_fisher('Colon - Sigmoid',  eligible_features, weakcore_features, peripheral_features, gwas_hsp_features)


In [None]:
fig, ax, df = plot_enrichment_fisher('Colon - Transverse',  eligible_features_weakcore, weakcore_features, peripheral_features, gwas_hsp_features)


In [None]:
fig, axes = plt.subplots(2, 3, figsize=(full_width*cm, 7*cm), sharex=False, sharey=True)

axes = axes.flatten()

for i, (featurename, ax) in enumerate(zip(("Cells - EBV-transformed lymphocytes","Whole Blood","Spleen", "Artery - Tibial", 'Brain - Frontal Cortex (BA9)', 'Brain - Anterior cingulate cortex (BA24)'), axes)):
     fig, ax, df = plot_enrichment_fisher(featurename,  eligible_features, strongcore_features, peripheral_features, hsp_features, ax, fig)
     df.to_csv(featurename +"_by_snp.tsv", sep="\t", index=False)
     if i % 3 != 0:
          ax.set_ylabel("")
plt.tight_layout()

plt.savefig("input_features_fisher.svg")

In [None]:
fig, axes = plt.subplots(2, 3, figsize=(full_width*cm, 7*cm), sharex=False, sharey=True)

axes = axes.flatten()

for i, (featurename, ax) in enumerate(zip(("Cells - EBV-transformed lymphocytes","Whole Blood","Spleen", "Artery - Tibial", 'Brain - Frontal Cortex (BA9)', 'Brain - Anterior cingulate cortex (BA24)'), axes)):
     fig, ax, df = plot_enrichment_fisher(featurename,  eligible_features, strongcore_features, peripheral_features, gwas_hsp_features, ax, fig)
     df.to_csv(featurename +"_zscore.tsv", sep="\t", index=False)
     if i % 3 != 0:
          ax.set_ylabel("")
plt.tight_layout()

plt.savefig("input_features_fisher_gwas.svg")

In [None]:
fig, axes = plt.subplots(2, 3, figsize=(full_width*cm, 7*cm), sharex=False, sharey=True)

axes = axes.flatten()

for i, (featurename, ax) in enumerate(zip(("Cells - EBV-transformed lymphocytes","Whole Blood","Small Intestine - Terminal Ileum", "Artery - Tibial", 'Brain - Frontal Cortex (BA9)', 'Brain - Anterior cingulate cortex (BA24)'), axes)):
     fig, ax, df = plot_enrichment_fisher(featurename,  eligible_features, strongcore_features, peripheral_features, gwas_hsp_features, ax, fig)
     df.to_csv(featurename +"_zscore.tsv", sep="\t", index=False)
     if i % 3 != 0:
          ax.set_ylabel("")
plt.tight_layout()

plt.savefig("input_features_fisher_gwas_intestine.svg")

In [None]:
fig, axes = plt.subplots(2, 3, figsize=(full_width*cm, 7*cm), sharex=False, sharey=True)

axes = axes.flatten()

for i, (featurename, ax) in enumerate(zip(("Cells - EBV-transformed lymphocytes","Whole Blood","Spleen", "Colon - Sigmoid", 'Colon - Transverse', 'Small Intestine - Terminal Ileum'), axes)):
     fig, ax, df = plot_enrichment_fisher(featurename,  eligible_features, strongcore_features, peripheral_features, gwas_hsp_features, ax, fig)
     #df.to_csv(featurename +"_zscore.tsv", sep="\t", index=False)
     if i % 3 != 0:
          ax.set_ylabel("")
plt.tight_layout()

plt.savefig("input_features_fisher_gwas_strongcore_intestine.svg")

In [None]:
fig, axes = plt.subplots(2, 3, figsize=(full_width*cm, 7*cm), sharex=False, sharey=True)

axes = axes.flatten()

for i, (featurename, ax) in enumerate(zip(("Cells - EBV-transformed lymphocytes","Whole Blood","Spleen", "Colon - Sigmoid", 'Colon - Transverse', 'Small Intestine - Terminal Ileum'), axes)):
     fig, ax, df = plot_enrichment_fisher(featurename,  eligible_features_weakcore, weakcore_features, peripheral_features, gwas_hsp_features, ax, fig)
     #df.to_csv(featurename +"_zscore.tsv", sep="\t", index=False)
     if i % 3 != 0:
          ax.set_ylabel("")
plt.tight_layout()
plt.savefig("input_features_fisher_gwas_weakcore_intestine.svg")

In [None]:
features["Colon - Sigmoid"].sort_values()[-20:]

In [None]:
id2entrez = {value: key for key, value in prepro.entrez2id.items()}

In [None]:
gwas_zstat_hgnc = gwas_zstat.rename(prepro.id2hgnc)

In [None]:
zstat_colon_sigmoid = [gwas_zstat_hgnc.loc[_id].item() for _id in features["Colon - Sigmoid"].sort_values().index]

In [None]:
zstat_terminal_ileum = [gwas_zstat_hgnc.loc[_id].item() for _id in features["Small Intestine - Terminal Ileum"].sort_values().index]

In [None]:
zstat_ebv = [gwas_zstat_hgnc.loc[_id].item() for _id in features['Cells - EBV-transformed lymphocytes'].sort_values().index]
zstat_brain = [gwas_zstat_hgnc.loc[_id].item() for _id in features['Brain - Anterior cingulate cortex (BA24)'].sort_values().index]

In [None]:

ysmoothed = gaussian_filter1d(zstat_brain, sigma=100)


plt.plot(ysmoothed)

In [None]:

ysmoothed = gaussian_filter1d(zstat_colon_sigmoid, sigma=100)


plt.plot(ysmoothed)

In [None]:
ysmoothed = gaussian_filter1d(zstat_terminal_ileum, sigma=100)


plt.plot(ysmoothed)

In [None]:
from scipy.ndimage.filters import gaussian_filter1d

ysmoothed = gaussian_filter1d(zstat_ebv, sigma=100)

plt.plot(ysmoothed)

In [None]:
features.columns

In [None]:
# Spot checks: raw sigmoid-colon values of individual genes.
features.loc["CDH1", "Colon - Sigmoid"]

In [None]:
features.loc["HNF4A", "Colon - Sigmoid"]

In [None]:
features.loc["GNA12", "Colon - Sigmoid"]

In [None]:
# GNA12's UC GWAS z-statistic as stored in the feature matrix.
features.loc["GNA12", 'ZSTAT UC']

In [None]:
# The 20 features with the highest values for GNA12.
features.loc["GNA12", :].sort_values()[-20:]

In [None]:
features.loc["SLC26A3", "Colon - Sigmoid"]

In [None]:
features.loc["ORMDL3", "Colon - Sigmoid"]

In [None]:
# 80th percentile of sigmoid-colon values across all genes, as a reference for the checks above.
features.loc[:, "Colon - Sigmoid"].quantile(0.80)

In [None]:
# The 15 features with the highest values for ORMDL3.
features.loc["ORMDL3", :].sort_values()[-15:]

In [None]:
import scipy.stats as spstats

In [None]:
spstats.spearmanr(zstat_ebv, np.arange(len(zstat_ebv)))

In [None]:
spstats.spearmanr(zstat_colon_sigmoid, np.arange(len(zstat_ebv)))

In [None]:
spstats.spearmanr(zstat_brain, np.arange(len(zstat_ebv)))

In [None]:
coregenes

In [None]:
coregene_features

In [None]:
from umap import UMAP

In [None]:
# Fixed seed so the embedding is reproducible across runs.
reducer = UMAP(random_state=0)

# Shift each column so every value is >= 0.01 before taking the log.
transformed = reducer.fit_transform(np.log(strongcore_features + strongcore_features.min(axis=0).abs() + 0.01))

In [None]:
input_attributions = []
genes= []
for _id in coregenes.nonzero().flatten().tolist() + ["HNF4A", "LAMB1"]:
    try:
        hgnc = prepro.id2hgnc[_id]
    except KeyError:
        hgnc = _id
    try:
        input_attributions.append(torch.load("/mnt/storage/speos/explanations/{}_ig_attr_self_total_{}.pt".format(config.name, hgnc)).detach().cpu().numpy().tolist())
        genes.append(hgnc)
    except FileNotFoundError:
        continue

In [None]:
reduced_strongcore_features = features[features.index.isin(genes)]

In [None]:
reducer = UMAP()



In [None]:
normalized_features = np.log(reduced_strongcore_features + reduced_strongcore_features.min(axis=0).abs() + 0.01)

In [None]:
normalized_features["ZSTAT UC"] = reduced_strongcore_features["ZSTAT UC"]
normalized_features["NSNPS UC"] = reduced_strongcore_features["NSNPS UC"]
normalized_features["P UC"] = reduced_strongcore_features["P UC"]

In [None]:
transformed = reducer.fit_transform(normalized_features)

In [None]:
transformed = pd.DataFrame(data=transformed, columns=["Dim0", "Dim1"], index=reduced_strongcore_features.index)

In [None]:
attributions = pd.DataFrame(data=input_attributions, columns=prepro.get_feature_names(), index=genes)

In [None]:
genes_of_interest = ["CDH1", "ECM1", "IL10", "HNF4A", "LAMB1"]

In [None]:
from speos.visualization.settings import *
import matplotlib.pyplot as plt

fig, axs = plt.subplots(ncols=3,nrows=2, figsize=(full_width*cm, full_width*cm))

for ax, feature in zip(axs.flatten(), ["Colon - Sigmoid", "Small Intestine - Terminal Ileum", "Esophagus - Mucosa", "Whole Blood", "Cells - EBV-transformed lymphocytes", "ZSTAT UC"]):
    im = ax.scatter(
    transformed["Dim0"],
    transformed["Dim1"],
    c=normalized_features.loc[:, feature],
    s=5)
    fig.colorbar(im, ax=ax)
    ax.set_title(feature, fontsize=7)


In [None]:
fig, axs = plt.subplots(ncols=3,nrows=2, figsize=(full_width*cm, full_width*cm))

for ax, feature in zip(axs.flatten(), ["Colon - Sigmoid", "Small Intestine - Terminal Ileum", "Esophagus - Mucosa", "Whole Blood", "Cells - EBV-transformed lymphocytes", "ZSTAT UC"]):
    im = ax.scatter(
    transformed[:, 0],
    transformed[:, 1],
    c=attributions.loc[:, feature],
    s=5)
    fig.colorbar(im, ax=ax)
    ax.set_title(feature, fontsize=7)

In [None]:


def plot_expression_vs_attribution(feature, expression,  importance, embeddings, genes_of_interest, label_top_expressed):
    """Show one feature's values and its attributions side by side on a 2-D gene embedding.

    Args:
        feature: column of `expression` and `importance` that colours the points.
        expression: per-gene feature values (index = gene), drawn in the left panel.
        importance: per-gene attributions (index = gene), drawn in the right panel.
        embeddings: per-gene coordinates with columns "Dim0" and "Dim1".
        genes_of_interest: genes outlined and labelled in red on both panels; genes
            absent from `embeddings` are skipped with a warning.
        label_top_expressed: how many genes with the largest value (left panel) and the
            largest attribution (right panel) to outline and label in black.

    Returns:
        (fig, axs): the figure and its two axes.
    """
    import warnings

    # matplotlib consumes `c` positionally, so line both tables up with the embedding rows
    # (dropping repeated gene labels, which `reindex` cannot handle).
    expression = expression[~expression.index.duplicated(keep="first")].reindex(embeddings.index)
    importance = importance[~importance.index.duplicated(keep="first")].reindex(embeddings.index)

    fig, axs = plt.subplots(ncols=2, nrows=1, figsize=(full_width*cm*1.5, full_width*cm))

    def _outline(ax, gene, edgecolor, **text_kwargs):
        """Draw a hollow ring around `gene`'s point on `ax` and label it."""
        x = embeddings.loc[gene, "Dim0"]
        y = embeddings.loc[gene, "Dim1"]
        ax.scatter(x, y, c=(0, 0, 0, 0), edgecolors=edgecolor, linewidth=0.5, s=20)
        ax.text(s=gene, x=x, y=y, va="bottom", ha="left", **text_kwargs)

    panels = ((axs[0], expression, "Expression"), (axs[1], importance, "Importance"))
    for ax, values, title in panels:
        im = ax.scatter(embeddings["Dim0"], embeddings["Dim1"], c=values.loc[:, feature], s=20)
        fig.colorbar(im, ax=ax, label=title)
        ax.set_title(title, fontsize=10)

    for gene in genes_of_interest:
        if gene not in embeddings.index:
            warnings.warn("{} is not in the embedding; not highlighted".format(gene))
            continue
        for ax, _, _ in panels:
            _outline(ax, gene, "red", color="red")

    for ax, values, _ in panels:
        # `nlargest` drops NaN (genes without a value) and returns nothing for 0, unlike
        # `sort_values()[-k:]`, which returns every gene for k == 0 and ranks NaN last.
        for gene in values[feature].nlargest(label_top_expressed).index:
            _outline(ax, gene, "black")

    fig.suptitle(feature)
    plt.tight_layout()

    return fig, axs


plot_expression_vs_attribution(feature = "Colon - Sigmoid", expression=normalized_features,importance = attributions, embeddings=transformed, genes_of_interest=genes_of_interest, label_top_expressed=10)

In [None]:
plot_expression_vs_attribution(feature = "Small Intestine - Terminal Ileum", expression=normalized_features,importance = attributions, embeddings=transformed, genes_of_interest=genes_of_interest, label_top_expressed= 10)

In [None]:
plot_expression_vs_attribution(feature = "Cells - EBV-transformed lymphocytes", expression=normalized_features,importance = attributions, embeddings=transformed, genes_of_interest=genes_of_interest, label_top_expressed=10)

In [None]:
plot_expression_vs_attribution(feature = "Whole Blood", expression=normalized_features,importance = attributions, embeddings=transformed, genes_of_interest=genes_of_interest, label_top_expressed= 10)

In [None]:
plot_expression_vs_attribution(feature = "ZSTAT UC", expression=normalized_features,importance = attributions, embeddings=transformed, genes_of_interest=genes_of_interest, label_top_expressed= 10)

In [None]:
plot_expression_vs_attribution(feature = "Liver", expression=normalized_features,attributions = attributions, embeddings=transformed, genes_of_interest=genes_of_interest, label_top_expressed= 10)

In [None]:
plot_expression_vs_attribution(feature = "Lung", expression=normalized_features,importance = attributions, embeddings=transformed, genes_of_interest=genes_of_interest, label_top_expressed= 10)

In [None]:
plot_expression_vs_attribution(feature = "Thyroid", expression=normalized_features,attributions = attributions, embeddings=transformed, genes_of_interest=genes_of_interest, label_top_expressed= 10)

In [None]:
fig, axs = plt.subplots(ncols=2,nrows=1, figsize=(full_width*cm*1.5, full_width*cm))


with open("/mnt/storage/speos/results/uc_film_nohetioouter_results.json", "r") as file:
    results =  [key for key, value in json.load(file)[0].items() if value >= 11]

# Red: genes in `results` (score >= 11); black: every other embedded gene.
strong_candidates = set(results)
im = axs[0].scatter(
    transformed["Dim0"],
    transformed["Dim1"],
    c=["red" if gene in strong_candidates else "black" for gene in transformed.index],
    s=20)

# Title matches the colouring above (red = the CS >= 11 genes).
axs[0].set_title("CS11 (Red); Mendelian (Black)", fontsize=10)

In [None]:
# The 20 embedded genes with the smallest first UMAP coordinate.
transformed["Dim0"].sort_values()[:20]

In [None]:
# The 20 highest normalized feature values (log-shifted; GWAS columns raw) for single genes.
normalized_features.loc["LHX3", :].sort_values()[-20:]

In [None]:
normalized_features.loc["CHRNA1", :].sort_values()[-20:]

In [None]:
normalized_features.loc["CHRNG", :].sort_values()[-20:]

In [None]:
normalized_features.loc["FOXE1", :].sort_values()[-20:]

In [None]:
# The 20 features with the largest attribution for HNF4A.
# NOTE(review): if `attributions` holds HNF4A twice, `.loc` returns a frame and this fails.
attributions.loc["HNF4A", :].sort_values()[-20:]

In [None]:
attributions.loc["HNF4A", "Liver"]