In [None]:
import pandas as pd
import numpy as np
import torch

from speos.preprocessing.handler import InputHandler
from speos.utils.config import Config
from speos.preprocessing.datasets import DatasetBootstrapper

In [None]:
import os

# Work from the project root (the parent of the notebook's starting directory) so the relative
# config/data paths below resolve. The root is computed only once per kernel: a plain
# `os.chdir("..")` would climb one more directory every time this cell is re-run.
if "PROJECT_ROOT" not in globals():
    PROJECT_ROOT = os.path.abspath(os.path.join(os.getcwd(), ".."))
os.chdir(PROJECT_ROOT)

In [None]:
# Load the experiment config and build the preprocessed graph inputs.
# The YAML path is relative, so this relies on the working directory set in the previous cell.
config = Config()
config.parse_yaml("config_uc_only_nohetio_film_newstorage.yaml")
prepro = InputHandler(config).get_preprocessor()
prepro.build_graph(adjacency=False)
# NOTE(review): `data` is not used by any later cell in this notebook; `dataset` is used instead.
data = prepro.get_data()

In [None]:

# Dataset for this experiment, with the holdout size and name taken from the same config.
dataset = DatasetBootstrapper(
    holdout_size=config.input.holdout_size,
    name=config.name,
    config=config,
).get_dataset()

In [None]:
import json

# First element of the JSON list maps HGNC symbol -> integer score (presumably the number of
# outer cross-validation runs that flagged the gene — TODO confirm). Scores >= 11 select the
# strong set (`indices`), scores 1..10 the weak set (`indices_weak`).
OUTER_RESULTS_PATH = "/mnt/storage/speos/results/uc_film_nohetioouter_results.json"
STRONG_MIN_SCORE = 11

# Read and parse the file once; both selections below use the same mapping.
with open(OUTER_RESULTS_PATH, "r") as file:
    outer_scores = json.load(file)[0]

strong_hgnc = [key for key, value in outer_scores.items() if value >= STRONG_MIN_SCORE]
weak_hgnc = [key for key, value in outer_scores.items() if 1 <= value < STRONG_MIN_SCORE]

indices = torch.LongTensor([prepro.hgnc2id[hgnc] for hgnc in strong_hgnc])
indices_weak = torch.LongTensor([prepro.hgnc2id[hgnc] for hgnc in weak_hgnc])

In [None]:
# Strong gene set: the labels in `dataset.data.y` (presumably the known positives — TODO confirm)
# plus every gene in `indices`.
# `.to(torch.long, copy=True)` always returns a fresh tensor. `.long()` returns `y` itself when it
# is already int64, and the in-place assignment below would then silently modify the dataset labels.
coregenes = dataset.data.y.to(torch.long, copy=True)
coregenes[indices] = 1

# Weak gene set: only the genes in `indices_weak`.
coregenes_weak = torch.zeros_like(coregenes)
coregenes_weak[indices_weak] = 1

# Last expression, so the size of the strong set is actually displayed.
coregenes.sum()

In [None]:
# Binary mask over all genes: 1 for every HGNC symbol listed in the first column of hsps/uc.txt.
hsps = pd.read_csv("hsps/uc.txt", sep="\t", header=None, index_col=None)
hsp_indices = [prepro.hgnc2id[symbol] for symbol in hsps[0]]

new_y = torch.zeros_like(dataset.data.y)
new_y[np.asarray(hsp_indices)] = 1

In [None]:
# Gene-level GWAS z-statistics, indexed by the GENE column (Entrez IDs, as matched below).
gwas_zstat = pd.read_csv(
    "data/gwas/UC.genes.out",
    sep=" ",
    header=0,
    index_col="GENE",
    usecols=["GENE", "ZSTAT"],
)
# Keep only genes the preprocessor can map from Entrez ID to a node id.
gwas_zstat = gwas_zstat[gwas_zstat.index.isin(prepro.entrez2id.keys())]
len(gwas_zstat)

In [None]:
# Re-index the GWAS table from Entrez gene IDs to the preprocessor's ids (used as tensor positions below).
# NOTE(review): not idempotent — re-running this cell without first re-running the loading cell
# passes already-converted ids through `entrez2id` again, which may remap any id that also happens
# to be an Entrez ID. Re-run the previous cell before this one.
gwas_zstat = gwas_zstat.rename(prepro.entrez2id)

In [None]:
# GWAS hit mask: genes with ZSTAT > 5 that are in neither the strong nor the weak gene set.
all_gwas_zstat = torch.zeros_like(new_y)
all_sign_indices = gwas_zstat[gwas_zstat["ZSTAT"] > 5].index
all_gwas_zstat[all_sign_indices] = 1

# Clearing every strong/weak position is equivalent to clearing only the overlapping hits,
# because all other positions are already 0.
in_core_sets = (coregenes != 0) | (coregenes_weak != 0)
all_gwas_zstat[in_core_sets] = 0
all_gwas_zstat.sum()

# expected result: 70

In [None]:
def _rows_where(frame, mask):
    """Return the rows of ``frame`` at the positions where ``mask`` is non-zero, always as a DataFrame.

    ``nonzero(as_tuple=True)[0]`` stays a 1-D index tensor even for a single hit. The former
    ``nonzero().squeeze()`` collapsed to a scalar in that case, which made ``iloc`` return a Series.
    """
    positions = mask.nonzero(as_tuple=True)[0].tolist()
    return frame.iloc[positions, :]


features = pd.DataFrame(data=dataset.data.x.numpy(), columns=prepro.get_feature_names()).rename(prepro.id2hgnc, axis=0)
strongcore_features = _rows_where(features, coregenes)
hsp_features = _rows_where(features, new_y)
gwas_hsp_features = _rows_where(features, all_gwas_zstat)
# Peripheral genes are in none of the three sets. Testing `sum == 0` (instead of `1 - sum`)
# keeps genes that are in two sets out: for them 1 - 2 = -1, which is non-zero and was selected.
peripheral_features = _rows_where(features, (coregenes + coregenes_weak + new_y) == 0)
# Eligible background: every gene not in the weak set.
eligible_features = _rows_where(features, coregenes_weak == 0)

In [None]:

def plot_enrichment(featurename,  features, strongcore_features, peripheral_features, hsp_features, ax=None, fig=None):
    """Plot under/over-representation of three gene groups above sliding percentile thresholds.

    For each percentile p = 100, 99, ..., 0 the threshold is the p-th percentile of
    ``features[featurename]``. For every group the plotted ratio is

        (fraction of the group with value >= threshold) / (fraction of ``features`` with value >= threshold)

    so 1 means "no over-/under-representation". Only strictly positive ratios are drawn.

    Args:
        featurename: column of the feature frames to analyse.
        features: reference frame; defines the percentiles and the background fraction.
        strongcore_features: rows of the core group (dark blue line).
        peripheral_features: rows of the peripheral group (grey line).
        hsp_features: rows of the hsp group (red line).
        ax: axes to draw into; a new figure and axes are created when None.
        fig: returned unchanged when ``ax`` is given.

    Returns:
        ``(fig, ax)``.
    """
    column = features[featurename]
    # Percentiles run 100 -> 0, so x position j corresponds to percentile 100 - j
    # (hence the reversed tick labels below). Computed once instead of inside a loop.
    thresholds = np.quantile(column, np.arange(100, -1, -1) / 100)
    total_fraction = np.array([(column >= threshold).sum() / len(features) for threshold in thresholds])

    def _ratio(group):
        # Fraction of `group` at/above each threshold, relative to the background fraction.
        values = group[featurename]
        group_fraction = np.array([(values >= threshold).sum() / len(group) for threshold in thresholds])
        return group_fraction / total_fraction

    curves = (
        (_ratio(strongcore_features), "#01016f"),
        (_ratio(peripheral_features), "#5a5a5a"),
        (_ratio(hsp_features), "#d8031c"),
    )

    if ax is None:
        # Imported lazily: pyplot is only imported in a later cell of this notebook.
        import matplotlib.pyplot as plt
        fig, ax = plt.subplots()
    positions = np.arange(len(thresholds))
    for ratio, color in curves:
        drawn = ratio > 0
        ax.plot(positions[drawn], ratio[drawn], color=color)
    ax.set_xticks((0, 20, 40, 60, 80, 100))
    ax.set_xticklabels((100, 80, 60, 40, 20, 0))
    ax.set_xlim(0, 100)
    ax.set_xlabel("Percentile ({})".format(featurename),  fontsize=7)
    ax.set_ylabel("Under/Over-\nrepresentation", fontsize=8)

    ax.set_yscale("symlog")
    ax.set_yticks([ 1/3, 1/2, 1, 2, 3, 4, 5,10])
    ax.set_yticklabels([0.33, 1/2, 1, 2, 3, 4, 5, 10], fontsize=5)
    ax.grid(which="major", axis="y", ls="--", color="lightgray", zorder=-5)
    return fig, ax

In [None]:
from speos.visualization.settings import *
import matplotlib.pyplot as plt
from statsmodels.stats.multitest import fdrcorrection
from scipy.stats import fisher_exact

def get_fragments(values, sign_list):
    """Split ``values`` into maximal runs of consecutive positions where ``sign_list`` is truthy.

    Returns a list with one ``[positions, values_at_positions]`` pair per run, in order.
    Example: flags ``[T, T, T, F, T, T, T]`` give
    ``[[[0, 1, 2], [v0, v1, v2]], [[4, 5, 6], [v4, v5, v6]]]``.
    ``sign_list`` must provide ``.tolist()`` (e.g. a numpy array).
    """
    runs = []
    current_run = None
    for position, flag in enumerate(sign_list.tolist()):
        if not flag:
            # A falsy flag closes the current run (if any); the next truthy flag starts a new one.
            current_run = None
            continue
        if current_run is None:
            current_run = [[], []]
            runs.append(current_run)
        current_run[0].append(position)
        current_run[1].append(values[position])
    return runs
        
            

def _fisher_enrichment(group_count, group_size, above_count, eligible_size):
    """Run Fisher's exact test of group membership vs. "feature value >= threshold".

    The 2x2 contingency table is::

        [[group & above,  rest & above],
         [group & below,  rest & below]]

    where "rest" is every eligible gene outside the group (this assumes the group is a subset of
    the eligible genes — TODO confirm for each caller).

    Returns:
        ``(odds_ratio, p_value)`` as reported by ``scipy.stats.fisher_exact``.
    """
    table = np.asarray([[group_count, above_count - group_count],
                        [group_size - group_count, eligible_size - group_size - above_count + group_count]])
    # The table must reproduce the marginal totals: rows = (above, below), columns = (group, rest).
    assert table[0, :].sum() == above_count
    assert table[1, :].sum() == eligible_size - above_count
    assert table[:, 0].sum() == group_size
    assert table[:, 1].sum() == eligible_size - group_size
    # Run the exact test once and take both outputs from the same result.
    result = fisher_exact(table)
    return result[0], result[1]


def plot_enrichment_fisher(featurename,  eligible_features, strongcore_features, peripheral_features, hsp_features, ax=None, fig=None):
    """Plot Fisher's-exact-test enrichment of three gene groups above sliding percentile thresholds.

    For each percentile p = 100, 99, ..., 0 of ``eligible_features[featurename]``, each group is
    tested with a 2x2 table (group membership vs. value >= threshold, within the eligible genes).
    The odds ratios are drawn as faint background curves. Stretches flagged by a single
    ``fdrcorrection`` over all 3 x 101 p-values are overdrawn in the group's strong colour.

    Args:
        featurename: column to analyse.
        eligible_features: background gene set; defines the percentiles and the table totals.
        strongcore_features: rows of the core group (dark blue / light blue).
        peripheral_features: rows of the peripheral group (dark grey / light grey).
        hsp_features: rows of the hsp group (red / pink).
        ax: axes to draw into; a new figure and axes are created when None.
        fig: returned unchanged when ``ax`` is given.

    Returns:
        ``(fig, ax)``.
    """
    # (group rows, colour for significant stretches, colour for the background curve)
    groups = (
        (strongcore_features, "#01016f", "lightblue"),
        (peripheral_features, "#5a5a5a", "lightgray"),
        (hsp_features, "#d8031c", "pink"),
    )
    eligible_values = eligible_features[featurename]
    # Percentiles run 100 -> 0, so x position j corresponds to percentile 100 - j.
    thresholds = np.quantile(eligible_values, np.arange(100, -1, -1) / 100)

    enrichments = [[] for _ in groups]
    pvals = [[] for _ in groups]
    for threshold in thresholds:
        above_count = (eligible_values >= threshold).sum()
        for k, (group, _, _) in enumerate(groups):
            group_count = (group[featurename] >= threshold).sum()
            odds_ratio, p_value = _fisher_enrichment(group_count, len(group), above_count, len(eligible_features))
            enrichments[k].append(odds_ratio)
            pvals[k].append(p_value)

    # One FDR correction over all tests (core, then peripheral, then hsp), split back per group.
    significant = fdrcorrection([p for group_pvals in pvals for p in group_pvals])[0]
    n_points = len(thresholds)
    group_significant = [significant[k * n_points:(k + 1) * n_points] for k in range(len(groups))]
    enrichments = [np.asarray(enrichment) for enrichment in enrichments]

    if ax is None:
        fig, ax = plt.subplots()
    # Significant stretches in the strong colour.
    # NOTE(review): a stretch of a single percentile is a one-point line and is not visible
    # without a marker.
    for (_, strong_color, _), enrichment, sign in zip(groups, enrichments, group_significant):
        for positions, values in get_fragments(enrichment, sign):
            ax.plot(positions, values, color=strong_color)
    # Full curves, faintly, behind everything else.
    x = np.arange(n_points)
    for (_, _, faint_color), enrichment in zip(groups, enrichments):
        ax.plot(x, enrichment, color=faint_color, zorder=-5)
    ax.set_xticks((0, 20, 40, 60, 80, 100))
    ax.set_xticklabels((100, 80, 60, 40, 20, 0))
    ax.set_xlim(0, 100)
    ax.set_xlabel("Percentile ({})".format(featurename),  fontsize=7)
    ax.set_ylabel("Depletion/Enrichment", fontsize=7)
    ax.hlines(1, 0, 100, colors="black", linewidth=1)
    ax.set_yscale("symlog")
    ax.set_yticks([ 1/10, 1/2, 1,2, 5, 10, 20])
    ax.set_yticklabels([0.1, 1/2, 1, 2, 5, 10, 20])
    ax.grid(which="major", axis="y", ls="--", color="lightgray", zorder=-5)
    return fig, ax

In [None]:
# Single panel: enrichment along the 'Small Intestine - Terminal Ileum' feature,
# using the GWAS-derived gene set as the third group.
fig, ax = plot_enrichment_fisher(
    featurename='Small Intestine - Terminal Ileum',
    eligible_features=eligible_features,
    strongcore_features=strongcore_features,
    peripheral_features=peripheral_features,
    hsp_features=gwas_hsp_features,
)


In [None]:
fig, axes = plt.subplots(2, 3, figsize=(full_width*cm, 7*cm), sharex=False, sharey=True)
axes = axes.flatten()

# One panel per feature, filled row by row on the 2 x 3 grid (hsp gene set as third group).
panel_features = (
    "Cells - EBV-transformed lymphocytes",
    "Whole Blood",
    "Spleen",
    "Artery - Tibial",
    "Brain - Frontal Cortex (BA9)",
    "Brain - Anterior cingulate cortex (BA24)",
)
for panel, (featurename, ax) in enumerate(zip(panel_features, axes)):
    fig, ax = plot_enrichment_fisher(featurename, eligible_features, strongcore_features, peripheral_features, hsp_features, ax, fig)
    # The y axis is shared, so only the first column keeps its label.
    if panel % 3:
        ax.set_ylabel("")
plt.tight_layout()

plt.savefig("input_features_fisher.svg")

In [None]:
fig, axes = plt.subplots(2, 3, figsize=(full_width*cm, 7*cm), sharex=False, sharey=True)
axes = axes.flatten()

# Same six panels, but with the GWAS-derived gene set as the third group.
panel_features = (
    "Cells - EBV-transformed lymphocytes",
    "Whole Blood",
    "Spleen",
    "Artery - Tibial",
    "Brain - Frontal Cortex (BA9)",
    "Brain - Anterior cingulate cortex (BA24)",
)
for panel, (featurename, ax) in enumerate(zip(panel_features, axes)):
    fig, ax = plot_enrichment_fisher(featurename, eligible_features, strongcore_features, peripheral_features, gwas_hsp_features, ax, fig)
    # The y axis is shared, so only the first column keeps its label.
    if panel % 3:
        ax.set_ylabel("")
plt.tight_layout()

plt.savefig("input_features_fisher_gwas.svg")