In [None]:
import pandas as pd
import numpy as np
import torch

from speos.preprocessing.handler import InputHandler
from speos.utils.config import Config
from speos.preprocessing.datasets import DatasetBootstrapper

In [None]:
import os
# Switch the working directory to the parent folder; the relative paths opened
# further down (the config YAML, "hsps/...") are resolved after this call.
# NOTE(review): not idempotent — re-running this cell moves up one more level.
os.chdir("..")

In [None]:
# Load the run configuration and build the matching preprocessor from it.
config = Config()
config.parse_yaml("config_uc_only_nohetio_film_newstorage.yaml")
prepro = InputHandler(config).get_preprocessor()
# G is queried with in_degree()/out_degree() in later cells — presumably a
# directed networkx graph; TODO confirm.
G = prepro.get_graph()
# Last expression of the cell: its return value is displayed, not stored.
prepro.get_data()

In [None]:
# Spot-check the gene symbol -> node id mapping; the expected id is noted below.
[prepro.hgnc2id["TNFSF15"]]

# must be 15506

In [None]:
# Keep the preprocessor output under a name (the earlier call only displayed it).
data = prepro.get_data()


In [None]:

# Build the dataset object for this config; later cells read dataset.data.x,
# dataset.data.y and dataset.data.edge_index_dict from it.
dataset = DatasetBootstrapper(holdout_size=config.input.holdout_size, name=config.name, config=config).get_dataset()

In [None]:
from torch_geometric.nn.models import LabelPropagation

In [None]:
from speos.utils.nn_utils import typed_edges_to_sparse_tensor
from torch_geometric.utils import add_remaining_self_loops, to_undirected 

# Presumably merges the per-type edge dict into a single sparse adjacency
# (TODO confirm against speos.utils.nn_utils); the next cell reads it through
# .storage.row() / .storage.col(). `encoder` is not used in the visible cells.
edge_index, encoder = typed_edges_to_sparse_tensor(dataset.data.x, dataset.data.edge_index_dict)

In [None]:
from torch_sparse import SparseTensor

# Row and column index vectors of the sparse adjacency.
rows = edge_index.storage.row()
cols = edge_index.storage.col()

# Dense two-row edge list, and the same edges with both endpoints swapped.
edge_index_flat = torch.vstack((rows, cols))
edge_index_flat_reversed = torch.vstack((cols, rows))

# Value-free sparse adjacency rebuilt from the flat edge list.
edge_index_new = SparseTensor(row=edge_index_flat[0], col=edge_index_flat[1])

In [None]:
from typing import Callable, Optional

from torch import Tensor
from torch_geometric.nn.conv.gcn_conv import gcn_norm
from torch_geometric.typing import Adj, OptTensor, SparseTensor
from torch_geometric.utils import one_hot

class PULabelPropagation(LabelPropagation):
    """Label propagation that keeps the known positive nodes pinned.

    The forward pass follows the parent class (propagate, then blend with the
    initial labels using ``alpha``) with two differences: after every step the
    rows of the output are L1-normalised instead of clamped, and the class-1
    score of every known positive node is reset to 1.
    """

    def forward(
        self,
        y: Tensor,
        edge_index: Adj,
        mask: OptTensor = None,
        edge_weight: OptTensor = None,
        post_step: Optional[Callable[[Tensor], Tensor]] = None,
    ) -> Tensor:
        """Propagate labels for ``self.num_layers`` steps.

        Args:
            y (torch.Tensor): The label information. Either a 1-D long tensor
                of class indices (one-hot encoded internally) or an already
                encoded ``[num_nodes, num_classes]`` tensor. Column 1 is
                treated as the positive class, so at least two columns are
                required.
            edge_index (torch.Tensor or SparseTensor): The edge connectivity.
            mask (torch.Tensor, optional): A mask or index tensor denoting
                which nodes contribute their labels. Labels of all other
                nodes are zeroed and are not pinned either.
                (default: :obj:`None`)
            edge_weight (torch.Tensor, optional): The edge weights. When given
                together with a dense ``edge_index`` no GCN normalisation is
                applied. (default: :obj:`None`)
            post_step (callable, optional): Function applied to the output
                after every propagation step. If none is given, each row is
                L1-normalised. (default: :obj:`None`)

        Returns:
            torch.Tensor: The propagated ``[num_nodes, num_classes]`` scores.
        """
        if y.dtype == torch.long and y.size(0) == y.numel():
            y = one_hot(y.view(-1))

        out = y
        if mask is not None:
            out = torch.zeros_like(y)
            out[mask] = y[mask]

        # Known positives are read from the labels that are actually visible
        # (after masking). Deriving them from the raw `y` would re-insert the
        # labels hidden by `mask` at every step, and a 2-D boolean mask built
        # from an already one-hot `y` cannot index `out[..., 1]`.
        pos_mask = out[:, 1] > 0

        if isinstance(edge_index, SparseTensor) and not edge_index.has_value():
            edge_index = gcn_norm(edge_index, add_self_loops=False)
        elif isinstance(edge_index, Tensor) and edge_weight is None:
            edge_index, edge_weight = gcn_norm(edge_index, num_nodes=y.size(0),
                                               add_self_loops=False)

        if post_step is None:
            # Default post step: rescale every row to unit L1 norm.
            def post_step(scores: Tensor) -> Tensor:
                return torch.nn.functional.normalize(scores, p=1, dim=1)

        res = (1 - self.alpha) * out
        for _ in range(self.num_layers):
            # propagate_type: (y: Tensor, edge_weight: OptTensor)
            out = self.propagate(edge_index, x=out, edge_weight=edge_weight,
                                 size=None)
            out.mul_(self.alpha).add_(res)

            out = post_step(out)

            # Pin the positives; only column 1 is overwritten, the other
            # columns keep their propagated value.
            out[pos_mask, 1] = 1

        return out

In [None]:
import json

# NOTE(review): absolute, machine-specific path — consider deriving it from the config.
RESULTS_PATH = "/mnt/storage/speos/results/uc_film_nohetioouter_results.json"

# Parse the file once; its first element maps gene symbols to a numeric value.
with open(RESULTS_PATH, "r") as file:
    gene_values = json.load(file)[0]

# Node ids of the symbols whose value is at least 11.
results_strong = [key for key, value in gene_values.items() if value >= 11]
indices = torch.LongTensor([prepro.hgnc2id[hgnc] for hgnc in results_strong])

# Node ids of the symbols whose value lies in [1, 11). `results` keeps this
# second list, as it did before the two reads were merged.
results = [key for key, value in gene_values.items() if 1 <= value < 11]
indices_weak = torch.LongTensor([prepro.hgnc2id[hgnc] for hgnc in results])

 

In [None]:
# Number of nodes with a non-zero label; the expected value is noted below.
dataset.data.y.long().sum()
# must be 379

In [None]:
# Work on a copy: Tensor.long() hands back the very same tensor when the labels
# are already int64, so writing into it would silently modify dataset.data.y
# (and change the label count on every re-run).
coregenes = dataset.data.y.long().clone()
coregenes[indices] = 1

# Separate 0/1 indicator vector for the nodes listed in `indices_weak`.
coregenes_weak = torch.zeros_like(coregenes)
coregenes_weak[indices_weak] = 1

In [None]:
# Display the indicator vector (one entry per node).
coregenes

In [None]:
# In-degree of PARK7 in G; the expected value is noted below.
G.in_degree(prepro.hgnc2id["PARK7"])
# must be 120

In [None]:
# Out-degree of PARK7 in G; the expected value is noted below.
G.out_degree(prepro.hgnc2id["PARK7"])
# must be 0

# See if HSPs are "closer" to core genes

In [None]:
# Read the gene symbols from the first column of the tab-separated HSP list.
hsps = pd.read_csv("hsps/uc.txt", header=None, index_col=None, sep="\t")
# Translate symbols to node ids; a symbol missing from hgnc2id raises KeyError.
hsp_indices = [prepro.hgnc2id[hgnc] for hgnc in hsps.iloc[:, 0]]
# 0/1 indicator vector with the same shape and dtype as the labels.
new_y = torch.zeros_like(dataset.data.y)
new_y[np.asarray(hsp_indices)] = 1

In [None]:
# Rebuild the HSP indicator vector: 1 for every symbol listed in hsps/uc.txt.
hsps = pd.read_csv("hsps/uc.txt", header=None, index_col=None, sep="\t")
hsp_indices = [prepro.hgnc2id[hgnc] for hgnc in hsps.iloc[:, 0]]
new_y = torch.zeros_like(dataset.data.y)
new_y[np.asarray(hsp_indices)] = 1

import seaborn as sns
from speos.visualization.settings import *
import matplotlib.pyplot as plt
from scipy.stats import mannwhitneyu

fig, axes = plt.subplots(1, 6, figsize=(full_width*cm,5*cm), sharey=False)

# One panel per setting: 1/3/5 propagation steps on the stored edge list,
# then the same three depths on the reversed edge list.
layer_settings = (1, 3, 5, 1, 3, 5)
edge_settings = [edge_index_flat] * 3 + [edge_index_flat_reversed] * 3

for i, (num_layers, edges, ax) in enumerate(zip(layer_settings, edge_settings, axes)):
    model = LabelPropagation(num_layers=num_layers, alpha=0.9)
    out = model(new_y.long(), edges)

    df = pd.DataFrame()
    df["HGNC"] = list(prepro.id2hgnc.values())
    df["coregenes"] = coregenes
    df["weak_coregenes"] = coregenes_weak
    df["total_coregenes"] = coregenes_weak + coregenes
    df["hsp"] = new_y
    df["propagated"] = out[:, 1]

    # Keep rows with hsp == 0, weak_coregenes == 0 and a positive propagated score.
    keep = (df["hsp"] == 0) & (df["weak_coregenes"] == 0) & (df["propagated"] > 0)
    new_df = df[keep]
    is_core = new_df["coregenes"] == 1
    is_peripheral = new_df["coregenes"] == 0
    core_scores = new_df["propagated"][is_core]
    peripheral_scores = new_df["propagated"][is_peripheral]

    ax = sns.boxplot(new_df, x="coregenes", y="propagated", fliersize=0.3, ax=ax,
                     order=[1, 0], palette={0: "#04964d", 1: "darkgray"}, linewidth=1)
    # Only the first panel carries the y-axis label.
    ax.set_ylabel("Propagated z'" if i == 0 else "")
    ax.set_xlabel("")
    ax.set_xticklabels(["Core\n(n={})".format(is_core.sum()),
                        "Peripheral\n(n={})".format(is_peripheral.sum())])
    # Cut the y-axis at the 99th percentile of the plotted scores.
    topval = np.quantile(new_df["propagated"], 0.99)
    ax.set_ylim((0, topval))

    # Two-group Mann-Whitney U p-value, multiplied by the number of panels.
    pval = mannwhitneyu(core_scores, peripheral_scores)[1] * 6

    s = "n.s."
    for cutoff, label in ((0.001, "***"), (0.01, "**"), (0.05, "*")):
        if pval < cutoff:
            s = label
            break

    # Place the marker 20% above the higher of the two upper quartiles.
    y_text = max(np.quantile(core_scores, 0.75), np.quantile(peripheral_scores, 0.75)) * 1.2
    ax.text(0.5, y=y_text, s=s, fontsize=small_font, ha="center")

    ax.tick_params(axis='x', labelrotation=90)

plt.tight_layout()

plt.savefig("label_propagation_pyg.pdf", bbox_inches="tight")

# Undirected

In [None]:
from torch_geometric.utils import to_undirected

# Rebuild the HSP indicator vector: 1 for every symbol listed in hsps/uc.txt.
hsps = pd.read_csv("hsps/uc.txt", header=None, index_col=None, sep="\t")
hsp_indices = [prepro.hgnc2id[hgnc] for hgnc in hsps.iloc[:, 0]]
new_y = torch.zeros_like(dataset.data.y)
new_y[np.asarray(hsp_indices)] = 1

import seaborn as sns
from speos.visualization.settings import *
import matplotlib.pyplot as plt
from scipy.stats import mannwhitneyu

fig, axes = plt.subplots(1, 6, figsize=(full_width*cm,5*cm), sharey=False)

# Same six settings as the directed figure, but every edge list is symmetrised
# before propagation.
# NOTE(review): symmetrising the reversed edge list should give the same graph
# as symmetrising the original one, so panels 4-6 presumably repeat panels 1-3
# — confirm whether six panels (and the factor 6 below) are intended here.
layer_settings = (1, 3, 5, 1, 3, 5)
edge_settings = [edge_index_flat] * 3 + [edge_index_flat_reversed] * 3

for i, (num_layers, edges, ax) in enumerate(zip(layer_settings, edge_settings, axes)):
    model = LabelPropagation(num_layers=num_layers, alpha=0.9)
    out = model(new_y.long(), to_undirected(edges))

    df = pd.DataFrame()
    df["HGNC"] = list(prepro.id2hgnc.values())
    df["coregenes"] = coregenes
    df["weak_coregenes"] = coregenes_weak
    df["total_coregenes"] = coregenes_weak + coregenes
    df["hsp"] = new_y
    df["propagated"] = out[:, 1]

    # Keep rows with hsp == 0, weak_coregenes == 0 and a positive propagated score.
    keep = (df["hsp"] == 0) & (df["weak_coregenes"] == 0) & (df["propagated"] > 0)
    new_df = df[keep]
    is_core = new_df["coregenes"] == 1
    is_peripheral = new_df["coregenes"] == 0
    core_scores = new_df["propagated"][is_core]
    peripheral_scores = new_df["propagated"][is_peripheral]

    ax = sns.boxplot(new_df, x="coregenes", y="propagated", fliersize=0.3, ax=ax,
                     order=[1, 0], palette={0: "#04964d", 1: "darkgray"}, linewidth=1)
    # Only the first panel carries the y-axis label.
    ax.set_ylabel("Propagated z'" if i == 0 else "")
    ax.set_xlabel("")
    ax.set_xticklabels(["Core\n(n={})".format(is_core.sum()),
                        "Peripheral\n(n={})".format(is_peripheral.sum())])
    # Cut the y-axis at the 99th percentile of the plotted scores.
    topval = np.quantile(new_df["propagated"], 0.99)
    ax.set_ylim((0, topval))

    # Two-group Mann-Whitney U p-value, multiplied by the number of panels.
    pval = mannwhitneyu(core_scores, peripheral_scores)[1] * 6

    s = "n.s."
    for cutoff, label in ((0.001, "***"), (0.01, "**"), (0.05, "*")):
        if pval < cutoff:
            s = label
            break

    # Place the marker 20% above the higher of the two upper quartiles.
    y_text = max(np.quantile(core_scores, 0.75), np.quantile(peripheral_scores, 0.75)) * 1.2
    ax.text(0.5, y=y_text, s=s, fontsize=small_font, ha="center")

    ax.tick_params(axis='x', labelrotation=90)

plt.tight_layout()

plt.savefig("label_propagation_pyg_undirected.pdf", bbox_inches="tight")

# Get Connection Statistics

In [None]:
from torch_geometric.utils import degree
from collections import Counter
from speos.visualization.settings import *
import matplotlib.pyplot as plt
from matplotlib.patches import Patch

# Per-node degree: occurrences in row 0 of the edge list are counted as
# out-degree, occurrences in row 1 as in-degree.
num_nodes = dataset.data.x.shape[0]
out_degrees = degree(edge_index_flat[0, :], num_nodes)
in_degrees = degree(edge_index_flat[1, :], num_nodes)
total_degrees = in_degrees + out_degrees

# Node classes as boolean masks, always in the order peripheral, core, HSP.
# Peripheral = belongs to none of the three indicator vectors. Comparing the
# sum with 0 matters: `(1 - sum).nonzero()` would also select nodes that sit in
# two classes at once (1 - 2 = -1 is non-zero).
core_mask = coregenes.bool()
hsp_mask = new_y.bool()
peripheral_mask = (coregenes + coregenes_weak + new_y) == 0
group_masks = (peripheral_mask, core_mask, hsp_mask)
group_names = ("Peripheral", "Core Gene", "HSP")
group_colors = ("#5a5a5a", "#01016f", "#d8031c")
# Sizes of exactly the groups that are plotted; used for the legend and as
# denominator of the degree-0 percentages.
group_sizes = [int(group.sum()) for group in group_masks]

out_degree_peripheral, out_degree_core, out_degree_hsp = (out_degrees[group] for group in group_masks)
in_degree_peripheral, in_degree_core, in_degree_hsp = (in_degrees[group] for group in group_masks)
total_degree_peripheral, total_degree_core, total_degree_hsp = (total_degrees[group] for group in group_masks)

# Degree -> number of nodes, per class.
out_counter = [Counter(out_degrees[group].tolist()) for group in group_masks]
in_counter = [Counter(in_degrees[group].tolist()) for group in group_masks]
total_counter = [Counter(total_degrees[group].tolist()) for group in group_masks]

out_peripheral_counter, out_core_counter, out_hsp_counter = out_counter
in_peripheral_counter, in_core_counter, in_hsp_counter = in_counter
total_peripheral_counter, total_core_counter, total_hsp_counter = total_counter

fig, axes = plt.subplots(1,4, figsize=(full_width*cm*1.3,5*cm*1.3), sharey=True, width_ratios=(3,3,3, 1.2))

# Three scatter panels; xval is the right edge of the "Degree 0" annotation.
panels = zip((out_counter, in_counter, total_counter),
             axes[:3],
             ("Out-Degree", "In-Degree", "Total Degree"),
             (1e5, 1e3 * 1.3, 1e4 * 6.2))

for counters, ax, title, xval in panels:
    # Degree 0 cannot be drawn on the log-scaled x-axis, so it is reported as text.
    ax.text(xval, 1e3 * 2, "Degree 0:", color="black", fontsize=8, ha="right")
    for counter, color, yval, totalnum in zip(counters, group_colors, (1e3, 1e3 * 0.5, 1e3 * 0.25), group_sizes):
        if not counter:
            # Empty class: nothing to plot and no percentage to compute.
            continue
        x, y = zip(*counter.items())
        ax.scatter(x, y, marker='.', color=color, alpha=0.1)
        ax.text(xval, yval, "{} ({:.1f}%)".format(counter[0], (counter[0] / totalnum)*100), color=color, fontsize=8, ha="right")

    # prep axes
    ax.set_xlabel(title)
    ax.set_xscale('log')
    if title == "Out-Degree":
        ax.set_ylabel('Frequency')
    ax.set_yscale('log')

# The fourth axis only hosts the legend.
legend_ax = axes[3]
legend_elements = [Patch(facecolor=color, edgecolor=color, label='{}\nn={}'.format(name, size))
                   for name, color, size in zip(group_names, group_colors, group_sizes)]

leg = legend_ax.legend(handles=legend_elements, loc='center', title="Node Class", fontsize=6.8, title_fontsize=8, ncol=1, columnspacing=1.7, handletextpad=-0.5, labelspacing=1.7)

# Reshape the legend patches into narrow vertical bars.
for patch in leg.get_patches():
    patch.set_height(15)
    patch.set_width(5)
    patch.set_y(-5)
legend_ax.set_axis_off()

plt.savefig("degree_distributions.svg", bbox_inches="tight")

In [None]:
# Show the working directory that relative output paths (e.g. saved figures) resolve against.
os.getcwd()

In [None]:
import matplotlib as mpl

# Stand-alone horizontal colour bar ('Reds' colormap, ticks at 0, 0.5 and 1).
fig, ax = plt.subplots(figsize=(full_width*cm*0.1, 0.2*cm))
col_map = plt.get_cmap('Reds')
cbar = mpl.colorbar.ColorbarBase(ax, cmap=col_map, orientation = 'horizontal', ticks=[0,  0.5,  1])
cbar.ax.tick_params(labelsize=5)
cbar.set_label(label="Propagated $Z'$",size=6,weight='bold')
# Extra, manually placed axes with both axis lines hidden. Nothing is drawn
# into it in this cell (the colorbar call that would use it is commented out below).
c_map_ax = fig.add_axes([0.2, 0.8, 0.6, 0.02])
c_map_ax.axes.get_xaxis().set_visible(False)
c_map_ax.axes.get_yaxis().set_visible(False)
plt.tight_layout()
# Disabled: second colour bar drawn into the manually placed axes.
#mpl.colorbar.ColorbarBase(c_map_ax, cmap=col_map, orientation = 'horizontal', )
plt.savefig("colorbar.svg")

In [None]:
# "Isolated" in this cell means: out-degree of 0.
has_no_out_edge = out_degrees == 0
has_out_edge = out_degrees > 0

# Index tensors of the nodes flagged in each indicator vector.
core_nodes = coregenes.nonzero()
hsp_nodes = new_y.nonzero()

core_and_isolated = has_no_out_edge[core_nodes].sum()
hsp_and_isolated = has_no_out_edge[hsp_nodes].sum()
core_not_isolated = has_out_edge[core_nodes].sum()
hsp_not_isolated = has_out_edge[hsp_nodes].sum()

In [None]:
from scipy.stats import fisher_exact

# 2x2 contingency table: rows = HSP / core, columns = out-degree 0 / out-degree > 0.
array = np.asarray([[hsp_and_isolated, hsp_not_isolated],
                    [core_and_isolated, core_not_isolated]])

# Fisher's exact test on the table; the result is displayed as the cell output.
fisher_exact(array)

In [None]:
from scipy.stats import mannwhitneyu

# Out-degrees of the core nodes and of the HSP nodes (indexed with the
# two-dimensional result of nonzero(), so each selection keeps a trailing axis).
out_degree_core = out_degrees[coregenes.nonzero()]
out_degree_hsp = out_degrees[new_y.nonzero()]

# Compare the two out-degree samples, degree-0 nodes included.
mannwhitneyu(out_degree_core, out_degree_hsp)

In [None]:
from scipy.stats import mannwhitneyu

# Out-degrees of the core nodes and of the HSP nodes.
out_degree_core = out_degrees[coregenes.nonzero()]
out_degree_hsp = out_degrees[new_y.nonzero()]

# Same comparison as before, restricted to nodes with at least one outgoing edge.
mannwhitneyu(out_degree_core[out_degree_core > 0], out_degree_hsp[out_degree_hsp > 0])

In [None]:
# Degree -> number of nodes, over all nodes. The per-node degree tensors are
# named `out_degrees` / `in_degrees`; the singular names used here before are
# not defined in any visible cell.
out_degree_counts = Counter(out_degrees.tolist())
in_degree_counts = Counter(in_degrees.tolist())

fig, axes = plt.subplots(2,1, figsize=(3,6))

for counter, ax, title, color in zip((out_degree_counts, in_degree_counts), axes, ("Out-Degree", "In-Degree"), ("#03CAF7", "#59D52F")):
    x, y = zip(*counter.items())

    # prep axes: log-log, with 10% head room above the largest value
    ax.set_xlabel('degree')
    ax.set_xscale('log')
    ax.set_xlim(0.9, max(x) + 0.1 * max(x))

    ax.set_ylabel('frequency')
    ax.set_yscale('log')
    ax.set_ylim(0.9, max(y) + 0.1 *max(y))

    # do plot
    ax.scatter(x, y, marker='.', color=color)
    ax.set_title(title)

plt.tight_layout()
plt.show()

# HSPs from other Phenotypes

In [None]:
# HSP indicator vector: 1 for every symbol listed in hsps/uc.txt.
hsps = pd.read_csv("hsps/uc.txt", header=None, index_col=None, sep="\t")
hsp_indices = [prepro.hgnc2id[hgnc] for hgnc in hsps.iloc[:, 0]]
new_y = torch.zeros_like(dataset.data.y)
new_y[np.asarray(hsp_indices)] = 1

import seaborn as sns
from speos.visualization.settings import *
import matplotlib.pyplot as plt
from scipy.stats import mannwhitneyu

fig, axes = plt.subplots(1, 6, figsize=(full_width*cm,5*cm), sharey=False)

# Uncorrected p-values, one per panel, in panel order.
uc_pvals = []

# 1/3/5 propagation steps on the stored edge list, then on the reversed one.
layer_settings = (1, 3, 5, 1, 3, 5)
edge_settings = [edge_index_flat] * 3 + [edge_index_flat_reversed] * 3

for i, (num_layers, edges, ax) in enumerate(zip(layer_settings, edge_settings, axes)):
    model = LabelPropagation(num_layers=num_layers, alpha=0.9)
    out = model(new_y.long(), edges)

    df = pd.DataFrame()
    df["HGNC"] = list(prepro.id2hgnc.values())
    df["coregenes"] = coregenes
    df["weak_coregenes"] = coregenes_weak
    df["total_coregenes"] = coregenes_weak + coregenes
    df["hsp"] = new_y
    df["propagated"] = out[:, 1]

    # Keep rows with hsp == 0, weak_coregenes == 0 and a positive propagated score.
    keep = (df["hsp"] == 0) & (df["weak_coregenes"] == 0) & (df["propagated"] > 0)
    new_df = df[keep]
    is_core = new_df["coregenes"] == 1
    is_peripheral = new_df["coregenes"] == 0
    core_scores = new_df["propagated"][is_core]
    peripheral_scores = new_df["propagated"][is_peripheral]

    ax = sns.boxplot(new_df, x="coregenes", y="propagated", fliersize=0.3, ax=ax,
                     order=[1, 0], palette={0: "#5a5a5a", 1: "#01016f"}, linewidth=1)
    # Only the first panel carries the y-axis label.
    ax.set_ylabel("Propagated z'" if i == 0 else "")
    ax.set_xlabel("")
    ax.set_xticklabels(["Core\n(n={})".format(is_core.sum()),
                        "Peripheral\n(n={})".format(is_peripheral.sum())])
    # Cut the y-axis at the 99th percentile of the plotted scores.
    topval = np.quantile(new_df["propagated"], 0.99)
    ax.set_ylim((0, topval))

    pval = mannwhitneyu(core_scores, peripheral_scores)[1]
    uc_pvals.append(pval)

    # The marker uses the p-value multiplied by the number of panels.
    s = "n.s."
    for cutoff, label in ((0.001, "***"), (0.01, "**"), (0.05, "*")):
        if pval * 6 < cutoff:
            s = label
            break

    # Place the marker 20% above the higher of the two upper quartiles.
    y_text = max(np.quantile(core_scores, 0.75), np.quantile(peripheral_scores, 0.75)) * 1.2
    ax.text(0.5, y=y_text, s=s, fontsize=small_font, ha="center")

    ax.tick_params(axis='x', labelrotation=90)

plt.tight_layout()

plt.savefig("label_propagation_pyg_uc.pdf", bbox_inches="tight")

In [None]:
# HSP indicator vector: 1 for every symbol listed in hsps/uc.txt.
hsps = pd.read_csv("hsps/uc.txt", header=None, index_col=None, sep="\t")
hsp_indices = [prepro.hgnc2id[hgnc] for hgnc in hsps.iloc[:, 0]]
new_y = torch.zeros_like(dataset.data.y)
new_y[np.asarray(hsp_indices)] = 1

# Edge weights passed to the propagation below (also reused by later cells).
edge_weights = torch.load("edge_attributions_tensor_UC.pt")

import seaborn as sns
from speos.visualization.settings import *
import matplotlib.pyplot as plt
from scipy.stats import mannwhitneyu

fig, axes = plt.subplots(1, 6, figsize=(full_width*cm,5*cm), sharey=False)

# Uncorrected p-values of this weighted run, one per panel, in panel order.
uc_pvals = []

# 1/3/5 propagation steps on the stored edge list, then on the reversed one.
layer_settings = (1, 3, 5, 1, 3, 5)
edge_settings = [edge_index_flat] * 3 + [edge_index_flat_reversed] * 3

for i, (num_layers, edges, ax) in enumerate(zip(layer_settings, edge_settings, axes)):
    model = LabelPropagation(num_layers=num_layers, alpha=0.9)
    out = model(new_y.long(), edges, edge_weight=edge_weights)

    df = pd.DataFrame()
    df["HGNC"] = list(prepro.id2hgnc.values())
    df["coregenes"] = coregenes
    df["weak_coregenes"] = coregenes_weak
    df["total_coregenes"] = coregenes_weak + coregenes
    df["hsp"] = new_y
    df["propagated"] = out[:, 1]

    # Keep rows with hsp == 0, weak_coregenes == 0 and a positive propagated score.
    keep = (df["hsp"] == 0) & (df["weak_coregenes"] == 0) & (df["propagated"] > 0)
    new_df = df[keep]
    is_core = new_df["coregenes"] == 1
    is_peripheral = new_df["coregenes"] == 0
    core_scores = new_df["propagated"][is_core]
    peripheral_scores = new_df["propagated"][is_peripheral]

    ax = sns.boxplot(new_df, x="coregenes", y="propagated", fliersize=0.3, ax=ax,
                     order=[1, 0], palette={0: "#5a5a5a", 1: "#01016f"}, linewidth=1)
    # Only the first panel carries the y-axis label.
    ax.set_ylabel("Propagated z'" if i == 0 else "")
    ax.set_xlabel("")
    ax.set_xticklabels(["Core\n(n={})".format(is_core.sum()),
                        "Peripheral\n(n={})".format(is_peripheral.sum())])
    # The 99th percentile is still computed, but the y-limit is not applied
    # in this weighted variant.
    topval = np.quantile(new_df["propagated"], 0.99)

    pval = mannwhitneyu(core_scores, peripheral_scores)[1]
    uc_pvals.append(pval)

    # The marker uses the p-value multiplied by the number of panels.
    s = "n.s."
    for cutoff, label in ((0.001, "***"), (0.01, "**"), (0.05, "*")):
        if pval * 6 < cutoff:
            s = label
            break

    # Place the marker 20% above the higher of the two upper quartiles.
    y_text = max(np.quantile(core_scores, 0.75), np.quantile(peripheral_scores, 0.75)) * 1.2
    ax.text(0.5, y=y_text, s=s, fontsize=small_font, ha="center")

    ax.tick_params(axis='x', labelrotation=90)

plt.tight_layout()

plt.savefig("label_propagation_pyg_uc_weighted.pdf", bbox_inches="tight")

In [None]:
# HSP indicator vector: 1 for every symbol listed in hsps/cad.txt.
hsps = pd.read_csv("hsps/cad.txt", header=None, index_col=None, sep="\t")
hsp_indices = [prepro.hgnc2id[hgnc] for hgnc in hsps.iloc[:, 0]]
new_y = torch.zeros_like(dataset.data.y)
new_y[np.asarray(hsp_indices)] = 1

import seaborn as sns
from speos.visualization.settings import *
import matplotlib.pyplot as plt
from scipy.stats import mannwhitneyu

fig, axes = plt.subplots(1, 6, figsize=(full_width*cm,5*cm), sharey=False)

# Uncorrected p-values, one per panel, in panel order.
cad_pvals = []

# 1/3/5 propagation steps on the stored edge list, then on the reversed one.
layer_settings = (1, 3, 5, 1, 3, 5)
edge_settings = [edge_index_flat] * 3 + [edge_index_flat_reversed] * 3

for i, (num_layers, edges, ax) in enumerate(zip(layer_settings, edge_settings, axes)):
    model = LabelPropagation(num_layers=num_layers, alpha=0.9)
    out = model(new_y.long(), edges)

    df = pd.DataFrame()
    df["HGNC"] = list(prepro.id2hgnc.values())
    df["coregenes"] = coregenes
    df["weak_coregenes"] = coregenes_weak
    df["total_coregenes"] = coregenes_weak + coregenes
    df["hsp"] = new_y
    df["propagated"] = out[:, 1]

    # Keep rows with hsp == 0, weak_coregenes == 0 and a positive propagated score.
    keep = (df["hsp"] == 0) & (df["weak_coregenes"] == 0) & (df["propagated"] > 0)
    new_df = df[keep]
    is_core = new_df["coregenes"] == 1
    is_peripheral = new_df["coregenes"] == 0
    core_scores = new_df["propagated"][is_core]
    peripheral_scores = new_df["propagated"][is_peripheral]

    ax = sns.boxplot(new_df, x="coregenes", y="propagated", fliersize=0.3, ax=ax,
                     order=[1, 0], palette={0: "#5a5a5a", 1: "#01016f"}, linewidth=1)
    # Only the first panel carries the y-axis label.
    ax.set_ylabel("Propagated z'" if i == 0 else "")
    ax.set_xlabel("")
    ax.set_xticklabels(["Core\n(n={})".format(is_core.sum()),
                        "Peripheral\n(n={})".format(is_peripheral.sum())])
    # Cut the y-axis at the 99th percentile of the plotted scores.
    topval = np.quantile(new_df["propagated"], 0.99)
    ax.set_ylim((0, topval))

    pval = mannwhitneyu(core_scores, peripheral_scores)[1]
    cad_pvals.append(pval)

    # The marker uses the p-value multiplied by the number of panels.
    s = "n.s."
    for cutoff, label in ((0.001, "***"), (0.01, "**"), (0.05, "*")):
        if pval * 6 < cutoff:
            s = label
            break

    # Place the marker 20% above the higher of the two upper quartiles.
    y_text = max(np.quantile(core_scores, 0.75), np.quantile(peripheral_scores, 0.75)) * 1.2
    ax.text(0.5, y=y_text, s=s, fontsize=small_font, ha="center")

    ax.tick_params(axis='x', labelrotation=90)

plt.tight_layout()

plt.savefig("label_propagation_pyg_cad.pdf", bbox_inches="tight")

In [None]:
# HSP indicator vector: 1 for every symbol listed in hsps/cad.txt.
hsps = pd.read_csv("hsps/cad.txt", header=None, index_col=None, sep="\t")
hsp_indices = [prepro.hgnc2id[hgnc] for hgnc in hsps.iloc[:, 0]]
new_y = torch.zeros_like(dataset.data.y)
new_y[np.asarray(hsp_indices)] = 1

import seaborn as sns
from speos.visualization.settings import *
import matplotlib.pyplot as plt
from scipy.stats import mannwhitneyu

fig, axes = plt.subplots(1, 6, figsize=(full_width*cm,5*cm), sharey=False)

# Uncorrected p-values of this weighted run, one per panel, in panel order.
cad_pvals = []

# 1/3/5 propagation steps on the stored edge list, then on the reversed one.
layer_settings = (1, 3, 5, 1, 3, 5)
edge_settings = [edge_index_flat] * 3 + [edge_index_flat_reversed] * 3

for i, (num_layers, edges, ax) in enumerate(zip(layer_settings, edge_settings, axes)):
    model = LabelPropagation(num_layers=num_layers, alpha=0.9)
    # `edge_weights` is not loaded in this cell; it has to exist already.
    out = model(new_y.long(), edges, edge_weight=edge_weights)

    df = pd.DataFrame()
    df["HGNC"] = list(prepro.id2hgnc.values())
    df["coregenes"] = coregenes
    df["weak_coregenes"] = coregenes_weak
    df["total_coregenes"] = coregenes_weak + coregenes
    df["hsp"] = new_y
    df["propagated"] = out[:, 1]

    # Keep rows with hsp == 0, weak_coregenes == 0 and a positive propagated score.
    keep = (df["hsp"] == 0) & (df["weak_coregenes"] == 0) & (df["propagated"] > 0)
    new_df = df[keep]
    is_core = new_df["coregenes"] == 1
    is_peripheral = new_df["coregenes"] == 0
    core_scores = new_df["propagated"][is_core]
    peripheral_scores = new_df["propagated"][is_peripheral]

    ax = sns.boxplot(new_df, x="coregenes", y="propagated", fliersize=0.3, ax=ax,
                     order=[1, 0], palette={0: "#5a5a5a", 1: "#01016f"}, linewidth=1)
    # Only the first panel carries the y-axis label.
    ax.set_ylabel("Propagated z'" if i == 0 else "")
    ax.set_xlabel("")
    ax.set_xticklabels(["Core\n(n={})".format(is_core.sum()),
                        "Peripheral\n(n={})".format(is_peripheral.sum())])
    # The 99th percentile is still computed, but the y-limit is not applied
    # in this weighted variant.
    topval = np.quantile(new_df["propagated"], 0.99)

    pval = mannwhitneyu(core_scores, peripheral_scores)[1]
    cad_pvals.append(pval)

    # The marker uses the p-value multiplied by the number of panels.
    s = "n.s."
    for cutoff, label in ((0.001, "***"), (0.01, "**"), (0.05, "*")):
        if pval * 6 < cutoff:
            s = label
            break

    # Place the marker 20% above the higher of the two upper quartiles.
    y_text = max(np.quantile(core_scores, 0.75), np.quantile(peripheral_scores, 0.75)) * 1.2
    ax.text(0.5, y=y_text, s=s, fontsize=small_font, ha="center")

    ax.tick_params(axis='x', labelrotation=90)

plt.tight_layout()

plt.savefig("label_propagation_pyg_cad_weighted.pdf", bbox_inches="tight")

In [None]:

import seaborn as sns
from speos.visualization.settings import *
import matplotlib.pyplot as plt
from scipy.stats import mannwhitneyu

# HSP indicator vector for hsps/scz.txt. Unlike the other phenotype cells,
# symbols that are missing from hgnc2id are skipped instead of raising.
hsps = pd.read_csv("hsps/scz.txt", header=None, index_col=None, sep="\t")
hsp_indices = [prepro.hgnc2id[hgnc] for hgnc in hsps.iloc[:, 0] if hgnc in prepro.hgnc2id.keys()]
new_y = torch.zeros_like(dataset.data.y)
new_y[np.asarray(hsp_indices)] = 1

fig, axes = plt.subplots(1, 6, figsize=(full_width*cm,5*cm), sharey=False)

# Uncorrected p-values, one per panel, in panel order.
scz_pvals = []

# 1/3/5 propagation steps on the stored edge list, then on the reversed one.
layer_settings = (1, 3, 5, 1, 3, 5)
edge_settings = [edge_index_flat] * 3 + [edge_index_flat_reversed] * 3

for i, (num_layers, edges, ax) in enumerate(zip(layer_settings, edge_settings, axes)):
    model = LabelPropagation(num_layers=num_layers, alpha=0.9)
    out = model(new_y.long(), edges)

    df = pd.DataFrame()
    df["HGNC"] = list(prepro.id2hgnc.values())
    df["coregenes"] = coregenes
    df["weak_coregenes"] = coregenes_weak
    df["total_coregenes"] = coregenes_weak + coregenes
    df["hsp"] = new_y
    df["propagated"] = out[:, 1]

    # Keep rows with hsp == 0, weak_coregenes == 0 and a positive propagated score.
    keep = (df["hsp"] == 0) & (df["weak_coregenes"] == 0) & (df["propagated"] > 0)
    new_df = df[keep]
    is_core = new_df["coregenes"] == 1
    is_peripheral = new_df["coregenes"] == 0
    core_scores = new_df["propagated"][is_core]
    peripheral_scores = new_df["propagated"][is_peripheral]

    ax = sns.boxplot(new_df, x="coregenes", y="propagated", fliersize=0.3, ax=ax,
                     order=[1, 0], palette={0: "#5a5a5a", 1: "#01016f"}, linewidth=1)
    # Only the first panel carries the y-axis label.
    ax.set_ylabel("Propagated z'" if i == 0 else "")
    ax.set_xlabel("")
    ax.set_xticklabels(["Core\n(n={})".format(is_core.sum()),
                        "Peripheral\n(n={})".format(is_peripheral.sum())])
    # Cut the y-axis at the 99th percentile of the plotted scores.
    topval = np.quantile(new_df["propagated"], 0.99)
    ax.set_ylim((0, topval))

    pval = mannwhitneyu(core_scores, peripheral_scores)[1]
    scz_pvals.append(pval)

    # The marker uses the p-value multiplied by the number of panels.
    s = "n.s."
    for cutoff, label in ((0.001, "***"), (0.01, "**"), (0.05, "*")):
        if pval * 6 < cutoff:
            s = label
            break

    # Place the marker 20% above the higher of the two upper quartiles.
    y_text = max(np.quantile(core_scores, 0.75), np.quantile(peripheral_scores, 0.75)) * 1.2
    ax.text(0.5, y=y_text, s=s, fontsize=small_font, ha="center")

    ax.tick_params(axis='x', labelrotation=90)

plt.tight_layout()

plt.savefig("label_propagation_pyg_scz.pdf", bbox_inches="tight")

In [None]:
from random import choice, seed
from scipy.stats import mannwhitneyu
from speos.visualization.settings import *

# Random baseline (unweighted propagation): 500 random node sets of the same
# size as the UC HSP list, each run through the six propagation settings.
# The result list is created here; before, this cell appended to a name that
# is only defined in a later cell and failed on a fresh kernel.
test_df_list = []

# Loop invariants: the HSP list size and the pool of node ids to draw from.
hsps = pd.read_csv("hsps/uc.txt", header=None, index_col=None, sep="\t")
num_hsps = len(hsps)
all_node_ids = list(prepro.hgnc2id.values())

# 1/3/5 propagation steps on the stored edge list, then on the reversed one.
layer_settings = (1, 3, 5, 1, 3, 5)
edge_settings = [edge_index_flat] * 3 + [edge_index_flat_reversed] * 3

seed(1)
for _ in range(500):
    # Drawn with replacement, so a set may hold fewer than num_hsps distinct nodes.
    hsp_indices = [choice(all_node_ids) for _ in range(num_hsps)]
    new_y = torch.zeros_like(new_y)
    new_y[np.asarray(hsp_indices)] = 1

    pvals = []
    for num_layers, edges in zip(layer_settings, edge_settings):
        model = LabelPropagation(num_layers=num_layers, alpha=0.9)
        out = model(new_y.long(), edges)

        df = pd.DataFrame()
        df["HGNC"] = list(prepro.id2hgnc.values())
        df["coregenes"] = coregenes
        df["weak_coregenes"] = coregenes_weak
        df["total_coregenes"] = coregenes_weak + coregenes
        df["hsp"] = new_y
        df["propagated"] = out[:, 1]

        # Same filter as in the plotting cells.
        new_df = df[df["hsp"] == 0]
        new_df = new_df[new_df["weak_coregenes"] == 0]
        new_df = new_df[new_df["propagated"] > 0]

        pval = mannwhitneyu(new_df["propagated"][new_df["coregenes"] == 1], new_df["propagated"][new_df["coregenes"] == 0])[1]
        pvals.append(pval)
    # One row of six uncorrected p-values per random set.
    test_df_list.append(pvals)

In [None]:
from random import choice, seed
from scipy.stats import mannwhitneyu
from speos.visualization.settings import *

# Random baseline (weighted propagation). The table built from these rows is
# indexed "UC", "Random0", ..., "Random499", i.e. it needs 501 rows with the
# observed UC p-values first — so the list starts with `uc_pvals`.
# NOTE(review): this assumes `uc_pvals` currently holds the p-values of the
# weighted UC run — confirm the cell order.
test_df_list = [uc_pvals]

# Loop invariants: the HSP list size and the pool of node ids to draw from.
hsps = pd.read_csv("hsps/uc.txt", header=None, index_col=None, sep="\t")
num_hsps = len(hsps)
all_node_ids = list(prepro.hgnc2id.values())

# 1/3/5 propagation steps on the stored edge list, then on the reversed one.
layer_settings = (1, 3, 5, 1, 3, 5)
edge_settings = [edge_index_flat] * 3 + [edge_index_flat_reversed] * 3

seed(1)
for _ in range(500):
    # Drawn with replacement, so a set may hold fewer than num_hsps distinct nodes.
    hsp_indices = [choice(all_node_ids) for _ in range(num_hsps)]
    new_y = torch.zeros_like(new_y)
    new_y[np.asarray(hsp_indices)] = 1

    pvals = []
    for num_layers, edges in zip(layer_settings, edge_settings):
        model = LabelPropagation(num_layers=num_layers, alpha=0.9)
        out = model(new_y.long(), edges, edge_weight=edge_weights)

        df = pd.DataFrame()
        df["HGNC"] = list(prepro.id2hgnc.values())
        df["coregenes"] = coregenes
        df["weak_coregenes"] = coregenes_weak
        df["total_coregenes"] = coregenes_weak + coregenes
        df["hsp"] = new_y
        df["propagated"] = out[:, 1]

        # Same filter as in the plotting cells.
        new_df = df[df["hsp"] == 0]
        new_df = new_df[new_df["weak_coregenes"] == 0]
        new_df = new_df[new_df["propagated"] > 0]

        pval = mannwhitneyu(new_df["propagated"][new_df["coregenes"] == 1], new_df["propagated"][new_df["coregenes"] == 0])[1]
        pvals.append(pval)
    # One row of six uncorrected p-values per random set.
    test_df_list.append(pvals)


In [None]:
# FDR-correct every p-value of every run jointly, then restore the (runs x settings) layout.
test_df_list = np.asarray(test_df_list)
old_shape = test_df_list.shape

adjusted = fdrcorrection(test_df_list.flatten())[1]
adjusted = adjusted.reshape(old_shape)

In [None]:
# Per column (propagation setting): number of runs with an uncorrected p-value below 0.05.
(test_df_list < 0.05).sum(axis=0)

In [None]:
# Per column (propagation setting): number of runs still below 0.05 after FDR correction.
(adjusted < 0.05).sum(axis=0)

In [None]:
# Label the adjusted p-values: row 0 is named "UC", the remaining rows "Random0".."Random499".
# NOTE(review): the index has 501 entries, so `adjusted` must have 501 rows — confirm
# that row 0 really holds the UC p-values before relying on the "UC" label.
test_df = pd.DataFrame(adjusted, index=["UC"] + ["Random{}".format(i) for i in range(500)], columns=["1","3","5","1_rev", "3_rev", "5_rev"])

In [None]:
# Position of the "UC" row after sorting ascending by column "1", as a fraction of all rows.
(test_df.sort_values("1").index == "UC").nonzero()[0] / len(test_df)

In [None]:
# Position of the "UC" row after sorting ascending by column "3", as a fraction of all rows.
(test_df.sort_values("3").index == "UC").nonzero()[0] / len(test_df)

In [None]:
# Position of the "UC" row after sorting ascending by column "5", as a fraction of all rows.
(test_df.sort_values("5").index == "UC").nonzero()[0] / len(test_df)

In [None]:
# Position of the "UC" row after sorting ascending by column "1_rev", as a fraction of all rows.
(test_df.sort_values("1_rev").index == "UC").nonzero()[0] / len(test_df)

In [None]:
# Position of the "UC" row after sorting ascending by column "3_rev", as a fraction of all rows.
(test_df.sort_values("3_rev").index == "UC").nonzero()[0] / len(test_df)

In [None]:
# Position of the "UC" row after sorting ascending by column "5_rev", as a fraction of all rows.
(test_df.sort_values("5_rev").index == "UC").nonzero()[0] / len(test_df)

In [None]:
# Build the UC comparison table: reuse the random-background rows and put the
# p-values of the real UC HSPs into row 0.
# NOTE(review): relies on `cad_test_df_list`, which is only assigned in a later cell.
uc_test_df_list = cad_test_df_list.tolist()
uc_test_df_list[0] = uc_pvals
uc_test_df_list = np.asarray(uc_test_df_list)
old_shape = uc_test_df_list.shape

adjusted = fdrcorrection(uc_test_df_list.flatten())[1].reshape(old_shape)
# Fixed: row 0 holds the UC p-values, so it is labelled (and looked up) as "UC"
# rather than with the "CAD" label left over from the copied cell.
test_df = pd.DataFrame(adjusted, index=["UC"] + ["Random{}".format(i) for i in range(500)], columns=["1","3","5","1_rev", "3_rev", "5_rev"])

# Position of the UC row within each ascending-sorted column, as a fraction of all rows.
for column in ["1", "3", "5", "1_rev", "3_rev", "5_rev"]:
    print((test_df.sort_values(column).index == "UC").nonzero()[0] / len(test_df))

# For CAD

In [None]:
# Build the CAD comparison table: take the random-background p-values and put the
# p-values of the real CAD HSPs into row 0.
# Fixed: `test_df_list[:]` is only a view once `test_df_list` is a NumPy array, so the
# row-0 assignment used to overwrite the shared background as well. Copy explicitly.
cad_test_df_list = np.array(test_df_list, copy=True)
cad_test_df_list[0] = cad_pvals
old_shape = cad_test_df_list.shape

adjusted = fdrcorrection(cad_test_df_list.flatten())[1].reshape(old_shape)
test_df = pd.DataFrame(adjusted, index=["CAD"] + ["Random{}".format(i) for i in range(500)], columns=["1","3","5","1_rev", "3_rev", "5_rev"])

# Position of the CAD row within each ascending-sorted column, as a fraction of all rows.
for column in ["1", "3", "5", "1_rev", "3_rev", "5_rev"]:
    print((test_df.sort_values(column).index == "CAD").nonzero()[0] / len(test_df))

In [None]:
# Build the SCZ comparison table: take the random-background p-values and put the
# p-values of the real SCZ HSPs into row 0.
# Fixed (1): copy explicitly; `test_df_list[:]` is a view for NumPy arrays and the
#            row-0 assignment used to modify the shared background.
# Fixed (2): the array used for the correction is now the SCZ table itself; it was
#            previously rebuilt from `cad_test_df_list`.
scz_test_df_list = np.array(test_df_list, copy=True)
scz_test_df_list[0] = scz_pvals
old_shape = scz_test_df_list.shape

adjusted = fdrcorrection(scz_test_df_list.flatten())[1].reshape(old_shape)
test_df = pd.DataFrame(adjusted, index=["SCZ"] + ["Random{}".format(i) for i in range(500)], columns=["1","3","5","1_rev", "3_rev", "5_rev"])

# Position of the SCZ row within each ascending-sorted column, as a fraction of all rows.
for column in ["1", "3", "5", "1_rev", "3_rev", "5_rev"]:
    print((test_df.sort_values(column).index == "SCZ").nonzero()[0] / len(test_df))

In [None]:
# Display the most recently built table of FDR-adjusted p-values.
test_df

In [None]:
import matplotlib.patches as patches
from scipy.stats import rankdata

# Six narrow panels (one per propagation setting) showing the density of -log10 adjusted
# p-values in `test_df`, plus a seventh panel that only carries the legend.
fig, axes = plt.subplots(nrows=1, ncols=7, figsize=(full_width*cm, 5*cm))

for ax, str_ind, ind in zip(axes.tolist(), test_df.columns.tolist() + ["None"], range(7)):
    if ind < 6:
        # Fixed: a stray empty subscript (`test_df[str_ind][]`) made this cell a SyntaxError.
        ax = sns.kdeplot(y = np.log10(test_df[str_ind]) * -1, cut=0, fill="lightblue", ax=ax)
        # Reversed-edge panels (ind > 2) mark the 95th percentile, forward panels the 5th.
        if ind > 2:
            value = np.quantile(np.log10(test_df[str_ind]) * -1, 0.95)
        else:
            value = np.quantile(np.log10(test_df[str_ind]) * -1, 0.05)

        ax.hlines(value, 0, 0.05, color="gray", zorder=2)
        ax.hlines(np.log10(uc_pvals[ind]) * -1, 0, 0.1, color="red", zorder=1)
        ax.hlines(np.log10(cad_pvals[ind]) * -1, 0, 0.1, color="blue", zorder=1)
        ax.hlines(np.log10(scz_pvals[ind]) * -1, 0, 0.1, color="blue", zorder=1)
        ax.set_ylabel("")
        ax.set_title(str_ind)
        # Remember the limits so that adding the oversized rectangle does not rescale the panel.
        xlim = ax.get_xlim()
        ylim= ax.get_ylim()
        # Semi-transparent white overlay that fades out one side of the percentile line.
        rect = patches.Rectangle((0, value*0.99 if ind >2 else value*1.01), 0.3, -200 if ind > 2 else 200,  linewidth=0, facecolor='white', alpha=0.7, zorder=5)


        # Add the patch to the Axes
        ax.add_patch(rect)
        ax.set_xlim(xlim)
        ax.set_ylim(ylim)
        if ind ==0:
            ax.set_ylabel("-log(p)")
    else:
        legend_elements = [Patch(facecolor='red', edgecolor='red',
                                label='Traits Match'),
                            Patch(facecolor='blue', edgecolor='blue',
                                    label='Trait Mismatch'),
                            Patch(facecolor='gray', edgecolor='gray',
                                    label='95th Percentile')]

        leg = ax.legend(handles=legend_elements, loc='center', title="p-Values", fontsize=8, title_fontsize=8, ncol=1, columnspacing=1.7, handletextpad=-0.5, labelspacing=1.7)

        # Turn the legend handles into small squares.
        for patch in leg.get_patches():
            patch.set_height(10)
            patch.set_width(10)
            patch.set_y(-2.5)
        ax.set_axis_off()

    #ax.set_xscale("log")
plt.tight_layout()
plt.show()

In [None]:
import matplotlib.patches as patches
from scipy.stats import rankdata
import matplotlib.ticker as tck

# Six narrow panels (one per propagation setting) with the density of -log10 p-values
# read from the UC-target table, plus a seventh panel that only carries the legend.
fig, axes = plt.subplots(nrows=1, ncols=7, figsize=(full_width*cm, 5*cm))
df = pd.read_csv("random_labelprop_target_uc_film_nohetio.tsv", index_col=0, header=0, sep="\t")

for ax, str_ind, ind in zip(axes.tolist(), df.columns.tolist() + ["None"], range(7)):
    if ind < 6:
        ax = sns.kdeplot(y = np.log10(df[str_ind]) * -1, cut=0, fill="lightblue", ax=ax)
        # Empirical p-value of the UC row from its rank within the column; the last three
        # columns (ind > 2) use the rank directly, the first three use 1 - rank.
        if ind > 2:
            value = np.quantile(np.log10(df[str_ind]) * -1, 0.95)
            pval = rankdata(df[str_ind])[df[str_ind] == df.loc["UC", str_ind]] / len(df)
        else:
            value = np.quantile(np.log10(df[str_ind]) * -1, 0.05)
            pval = 1 - (rankdata(df[str_ind])[df[str_ind] == df.loc["UC", str_ind]] / len(df))

        # `.item()` requires exactly one row to match the UC value.
        text = "p={:.3f}".format(pval.item())
        ax.hlines(value, 0, 0.05, color="#5a5a5a", zorder=2)
        ax.hlines(np.log10(df.loc["UC", str_ind]) * -1, 0, 0.1, color="#d8031c", zorder=1)
        ax.hlines(np.log10(df.loc["CAD", str_ind]) * -1, 0, 0.1, color="#2b5d34", zorder=1)
        ax.hlines(np.log10(df.loc["SCZ", str_ind]) * -1, 0, 0.1, color="#2b5d34", zorder=1)
        ax.set_ylabel("")
        #ax.set_title(str_ind)
        # Capture the limits after the first set of lines, then redraw the lines scaled to the panel width.
        xlim = ax.get_xlim()
        ylim= ax.get_ylim()
        ax.hlines(value, 0, xlim[1]*0.4, color="#5a5a5a", zorder=2)
        ax.hlines(np.log10(df.loc["UC", str_ind]) * -1, 0, xlim[1]*0.8, color="#d8031c", zorder=1)
        ax.hlines(np.log10(df.loc["CAD", str_ind]) * -1, 0, xlim[1]*0.8, color="#2b5d34", zorder=1)
        ax.hlines(np.log10(df.loc["SCZ", str_ind]) * -1, 0, xlim[1]*0.8, color="#2b5d34", zorder=1)
        # Semi-transparent white overlay that fades out one side of the percentile line.
        rect = patches.Rectangle((0, value*0.99 if ind >2 else value*1.01), 0.3, -200 if ind > 2 else 200,  linewidth=0, facecolor='white', alpha=0.7, zorder=5)
        ax.text(x=np.mean(xlim), y=ylim[1] * 0.9, s=text, fontsize=5, zorder=7, ha="center")

        # Add the patch to the Axes
        ax.add_patch(rect)
        ax.set_xlim(xlim)
        ax.set_ylim((-0.05, ylim[1]))
        # Tick spacing (despite the name): roughly eight major ticks per panel.
        # Fixed: clamp to at least 1, because a panel with ylim[1] < 8 produced a spacing
        # of 0, which MultipleLocator rejects.
        nticks = max(1, int(ylim[1] / 8))
        ax.yaxis.set_major_locator(tck.MultipleLocator(nticks))
        if ind ==0:
            ax.set_ylabel("-log(p)")
    else:
        legend_elements = [Patch(facecolor='#d8031c', edgecolor='#d8031c',
                                label='UC HSPs'),
                            Patch(facecolor='#2b5d34', edgecolor='#2b5d34',
                                    label='CAD/SCZ HSPs'),
                            Patch(facecolor='#5a5a5a', edgecolor='#5a5a5a',
                                    label='5th/95th Percentile')]

        leg = ax.legend(handles=legend_elements, loc='center', title="p-Values", fontsize=8, title_fontsize=8, ncol=1, columnspacing=1.7, handletextpad=-0.2, labelspacing=1.7)

        # Turn the legend handles into small squares.
        for patch in leg.get_patches():
            patch.set_height(10)
            patch.set_width(10)
            patch.set_y(-2.5)
        ax.set_axis_off()

    #ax.set_xscale("log")
plt.tight_layout()
plt.subplots_adjust(wspace=0.4)
plt.savefig("pvals_labelprop_uc.svg")


In [None]:
import matplotlib.patches as patches

# Six narrow panels (one per propagation setting) with the density of -log10 p-values
# read from the CAD-target table, plus a seventh panel that only carries the legend.
# NOTE(review): `rankdata` and `Patch` are used but not imported in this cell; they must
# already be in the namespace.
fig, axes = plt.subplots(nrows=1, ncols=7, figsize=(full_width*cm, 5*cm))
df = pd.read_csv("random_labelprop_target_cad_really_film_nohetio.tsv", index_col=0, header=0, sep="\t")

for ax, str_ind, ind in zip(axes.tolist(), df.columns.tolist() + ["None"], range(7)):
    if ind < 6:
        ax = sns.kdeplot(y = np.log10(df[str_ind]) * -1, cut=0, fill="lightblue", ax=ax)
        # Empirical p-value of the CAD row from its rank within the column; the last three
        # columns (ind > 2) use the rank directly, the first three use 1 - rank.
        if ind > 2:
            value = np.quantile(np.log10(df[str_ind]) * -1, 0.95)
            pval = rankdata(df[str_ind])[df[str_ind] == df.loc["CAD", str_ind]] / len(df)
        else:
            value = np.quantile(np.log10(df[str_ind]) * -1, 0.05)
            pval = 1 - (rankdata(df[str_ind])[df[str_ind] == df.loc["CAD", str_ind]] / len(df))
        
        # `.item()` requires exactly one row to match the CAD value.
        text = "p={:.3f}".format(pval.item())
        # NOTE(review): the percentile line is drawn twice with identical arguments.
        ax.hlines(value, 0, 0.05, color="gray", zorder=2)
        ax.hlines(value, 0, 0.05, color="gray", zorder=2)
        ax.hlines(np.log10(df.loc["CAD", str_ind]) * -1, 0, 0.1, color="red", zorder=1)
        ax.hlines(np.log10(df.loc["UC", str_ind]) * -1, 0, 0.1, color="blue", zorder=1)
        ax.hlines(np.log10(df.loc["SCZ", str_ind]) * -1, 0, 0.1, color="blue", zorder=1)
        ax.set_ylabel("")
        ax.set_title(str_ind)
        # Create a Rectangle patch
        # Limits are captured first so the oversized rectangle cannot rescale the panel.
        xlim = ax.get_xlim()
        ylim= ax.get_ylim()
        rect = patches.Rectangle((0, value*0.99 if ind >2 else value*1.01), 1.5, -200 if ind > 2 else 200,  linewidth=0, facecolor='white', alpha=0.7, zorder=5)

        ax.text(x=np.mean(xlim), y=ylim[1] * 0.9, s=text, fontsize=5, zorder=7, ha="center")
        # Add the patch to the Axes
        ax.add_patch(rect)
        ax.set_xlim(xlim)
        ax.set_ylim(ylim)
        if ind ==0:
            ax.set_ylabel("-log(p)")
    else:
        legend_elements = [Patch(facecolor='red', edgecolor='red',
                                label='Traits Match'),
                            Patch(facecolor='blue', edgecolor='blue',
                                    label='Trait Mismatch'),
                            Patch(facecolor='gray', edgecolor='gray',
                                    label='5th/95th\nPercentile')]

        leg = ax.legend(handles=legend_elements, loc='center', title="p-Values", fontsize=8, title_fontsize=8, ncol=1, columnspacing=1.7, handletextpad=-0.5, labelspacing=1.7)

        # Turn the legend handles into small squares.
        for patch in leg.get_patches():
            patch.set_height(10)
            patch.set_width(10)
            patch.set_y(-2.5)
        ax.set_axis_off()
        
    #ax.set_xscale("log")
plt.tight_layout()
plt.show()


In [None]:
import matplotlib.patches as patches
from scipy.stats import rankdata

# Six narrow panels (one per propagation setting) with the density of -log10 p-values
# read from the SCZ-target table, plus a seventh panel that only carries the legend.
# Fixed: this cell was copied from the CAD plot and still used the "CAD" row as the
# matching trait. For the SCZ target, the empirical p-value and the red "Traits Match"
# line now use the "SCZ" row; UC and CAD are the blue mismatching traits.
fig, axes = plt.subplots(nrows=1, ncols=7, figsize=(full_width*cm, 5*cm))
df = pd.read_csv("random_labelprop_target_scz_film_nohetio.tsv", index_col=0, header=0, sep="\t")

for ax, str_ind, ind in zip(axes.tolist(), df.columns.tolist() + ["None"], range(7)):
    if ind < 6:
        ax = sns.kdeplot(y = np.log10(df[str_ind]) * -1, cut=0, fill="lightblue", ax=ax)
        # Empirical p-value of the SCZ row from its rank within the column; the last three
        # columns (ind > 2) use the rank directly, the first three use 1 - rank.
        if ind > 2:
            value = np.quantile(np.log10(df[str_ind]) * -1, 0.95)
            pval = rankdata(df[str_ind])[df[str_ind] == df.loc["SCZ", str_ind]] / len(df)
        else:
            value = np.quantile(np.log10(df[str_ind]) * -1, 0.05)
            pval = 1 - (rankdata(df[str_ind])[df[str_ind] == df.loc["SCZ", str_ind]] / len(df))

        # `.item()` requires exactly one row to match the SCZ value.
        text = "p={:.3f}".format(pval.item())
        ax.hlines(value, 0, 0.05, color="gray", zorder=2)
        ax.hlines(np.log10(df.loc["SCZ", str_ind]) * -1, 0, 0.1, color="red", zorder=1)
        ax.hlines(np.log10(df.loc["UC", str_ind]) * -1, 0, 0.1, color="blue", zorder=1)
        ax.hlines(np.log10(df.loc["CAD", str_ind]) * -1, 0, 0.1, color="blue", zorder=1)
        ax.set_ylabel("")
        ax.set_title(str_ind)
        # Limits are captured first so the oversized rectangle cannot rescale the panel.
        xlim = ax.get_xlim()
        ylim= ax.get_ylim()
        # Semi-transparent white overlay that fades out one side of the percentile line.
        rect = patches.Rectangle((0, value*0.99 if ind >2 else value*1.01), 1.5, -200 if ind > 2 else 200,  linewidth=0, facecolor='white', alpha=0.7, zorder=5)

        ax.text(x=np.mean(xlim), y=ylim[1] * 0.9, s=text, fontsize=5, zorder=7, ha="center")
        # Add the patch to the Axes
        ax.add_patch(rect)
        ax.set_xlim(xlim)
        ax.set_ylim(ylim)
        if ind ==0:
            ax.set_ylabel("-log(p)")
    else:
        legend_elements = [Patch(facecolor='red', edgecolor='red',
                                label='Traits Match'),
                            Patch(facecolor='blue', edgecolor='blue',
                                    label='Trait Mismatch'),
                            Patch(facecolor='gray', edgecolor='gray',
                                    label='5th/95th\nPercentile')]

        leg = ax.legend(handles=legend_elements, loc='center', title="p-Values", fontsize=8, title_fontsize=8, ncol=1, columnspacing=1.7, handletextpad=-0.5, labelspacing=1.7)

        # Turn the legend handles into small squares.
        for patch in leg.get_patches():
            patch.set_height(10)
            patch.set_width(10)
            patch.set_y(-2.5)
        ax.set_axis_off()

    #ax.set_xscale("log")
plt.tight_layout()
plt.show()


# Check edge types and connectivity of core genes

In [None]:
import networkx as nx

# Rebuild the label vector from the real UC HSP list.
hsps = pd.read_csv("hsps/uc.txt", header=None, index_col=None, sep="\t")
hsp_indices = [prepro.hgnc2id[hgnc] for hgnc in hsps.iloc[:, 0]]
new_y = torch.zeros_like(dataset.data.y)
new_y[np.asarray(hsp_indices)] = 1

# Degree centrality of core genes, HSPs and all remaining ("peripheral") genes.
# Nodes without a centrality entry are skipped.
centrality = nx.degree_centrality(G)
core_centrality = [
    centrality[node.item()]
    for node in coregenes.nonzero()
    if node.item() in centrality
]
hsp_centrality = [centrality[node] for node in hsp_indices if node in centrality]
peripheral_centrality = [
    centrality[node.item()]
    for node in (torch.ones_like(new_y) - coregenes - new_y - coregenes_weak).nonzero()
    if node.item() in centrality
]

fig, ax = plt.subplots()
ax.boxplot([core_centrality, hsp_centrality, peripheral_centrality], positions=[0,1,2])

mannwhitneyu(core_centrality, peripheral_centrality)


In [None]:
# Out-degree centrality of core genes, HSPs and all remaining ("peripheral") genes.
# Nodes without a centrality entry are skipped.
centrality = nx.out_degree_centrality(G)
core_centrality = [
    centrality[node.item()]
    for node in coregenes.nonzero()
    if node.item() in centrality
]
hsp_centrality = [centrality[node] for node in hsp_indices if node in centrality]
peripheral_centrality = [
    centrality[node.item()]
    for node in (torch.ones_like(new_y) - coregenes - new_y - coregenes_weak).nonzero()
    if node.item() in centrality
]

fig, ax = plt.subplots()
ax.boxplot([core_centrality, hsp_centrality, peripheral_centrality], positions=[0,1,2])

mannwhitneyu(core_centrality, peripheral_centrality)

In [None]:
# In-degree centrality of core genes, HSPs and all remaining ("peripheral") genes.
# Nodes without a centrality entry are skipped.
centrality = nx.in_degree_centrality(G)
core_centrality = [
    centrality[node.item()]
    for node in coregenes.nonzero()
    if node.item() in centrality
]
hsp_centrality = [centrality[node] for node in hsp_indices if node in centrality]
peripheral_centrality = [
    centrality[node.item()]
    for node in (torch.ones_like(new_y) - coregenes - new_y - coregenes_weak).nonzero()
    if node.item() in centrality
]

fig, ax = plt.subplots()
ax.boxplot([core_centrality, hsp_centrality, peripheral_centrality], positions=[0,1,2])

mannwhitneyu(core_centrality, peripheral_centrality)

In [None]:
# PageRank of core genes, HSPs and all remaining ("peripheral") genes.
# Nodes without a PageRank entry are skipped.
centrality = nx.pagerank(G)
core_centrality = [
    centrality[node.item()]
    for node in coregenes.nonzero()
    if node.item() in centrality
]
hsp_centrality = [centrality[node] for node in hsp_indices if node in centrality]
peripheral_centrality = [
    centrality[node.item()]
    for node in (torch.ones_like(new_y) - coregenes - new_y - coregenes_weak).nonzero()
    if node.item() in centrality
]

fig, ax = plt.subplots()
ax.boxplot([core_centrality, hsp_centrality, peripheral_centrality], positions=[0,1,2])

mannwhitneyu(core_centrality, peripheral_centrality)

In [None]:
# Betweenness centrality of core genes, HSPs and all remaining ("peripheral") genes.
# The second positional argument (20) is forwarded to networkx unchanged; no seed is set.
# Nodes without a centrality entry are skipped.
centrality = nx.betweenness_centrality(G, 20)
core_centrality = [
    centrality[node.item()]
    for node in coregenes.nonzero()
    if node.item() in centrality
]
hsp_centrality = [centrality[node] for node in hsp_indices if node in centrality]
peripheral_centrality = [
    centrality[node.item()]
    for node in (torch.ones_like(new_y) - coregenes - new_y - coregenes_weak).nonzero()
    if node.item() in centrality
]

fig, ax = plt.subplots()
ax.boxplot([core_centrality, hsp_centrality, peripheral_centrality], positions=[0,1,2])

# Pairwise Mann-Whitney U tests between the three groups.
print(mannwhitneyu(core_centrality, peripheral_centrality))
print(mannwhitneyu(core_centrality, hsp_centrality))
print(mannwhitneyu(peripheral_centrality, hsp_centrality))

In [None]:
# Load centrality of core genes, HSPs and all remaining ("peripheral") genes.
# Nodes without a centrality entry are skipped.
centrality = nx.load_centrality(G)
core_centrality = [
    centrality[node.item()]
    for node in coregenes.nonzero()
    if node.item() in centrality
]
hsp_centrality = [centrality[node] for node in hsp_indices if node in centrality]
peripheral_centrality = [
    centrality[node.item()]
    for node in (torch.ones_like(new_y) - coregenes - new_y - coregenes_weak).nonzero()
    if node.item() in centrality
]

fig, ax = plt.subplots()
ax.boxplot([core_centrality, hsp_centrality, peripheral_centrality], positions=[0,1,2])

mannwhitneyu(core_centrality, peripheral_centrality)

In [None]:
from tqdm.notebook import tqdm

# outgoing[edge type][core gene symbol] -> symbols of the genes it points to.
# Core genes without outgoing edges of a type get no entry for that type.
# Note: the loop rebinds the notebook-level names `key` and `edge_index`.
outgoing = {}
for key, edge_index in tqdm(dataset.data.edge_index_dict.items()):
    key = key[1]
    outgoing[key] = {}
    for index in coregenes.nonzero():
        # Targets of all edges whose source is this core gene.
        values = edge_index[1, edge_index[0] == index].tolist()
        if values:
            outgoing[key][prepro.id2hgnc[index.item()]] = [prepro.id2hgnc[value] for value in values]

        


In [None]:
from tqdm.notebook import tqdm

# Per edge type: number of distinct source nodes, i.e. genes with at least one outgoing edge.
outgoing_background = {}
for key, edge_index in tqdm(dataset.data.edge_index_dict.items()):
    key = key[1]
    outgoing_background[key] = edge_index[0].unique().numel()
        


In [None]:
# Per edge type: how many core genes have at least one outgoing edge.
for key, value in outgoing.items():
    print("{}: {}".format(key, len(value)))

In [None]:
# Display the number of distinct source genes per edge type.
outgoing_background

In [None]:
from scipy.stats import fisher_exact
from statsmodels.stats.multitest import fdrcorrection

# Per edge type: are core genes more likely than other genes to have at least one
# outgoing edge of that type? 2x2 Fisher's exact test, FDR-corrected across edge types.
total_n_core = coregenes.sum().item()
total_n_genes = coregenes.shape[0]

adj_fisher = {}
for key in dataset.data.edge_index_dict.keys():
    tf_and_core = len(outgoing[key[1]])
    total_tf = outgoing_background[key[1]]

    # Rows: core gene yes / no. Columns: has outgoing edge yes / no.
    core_row = [tf_and_core, total_n_core - tf_and_core]
    other_row = [total_tf - tf_and_core, total_n_genes - total_tf - total_n_core + tf_and_core]
    array = np.asarray([core_row, other_row])

    result = fisher_exact(array)
    adj_fisher[key[1]] = [array[0, 0], array[0, 1], array[1, 0], array[1, 1], result[0], result[1]]

tf_df = pd.DataFrame.from_dict(
    adj_fisher,
    orient="index",
    columns=["Core_and_Out", "Core_not_Out", "not_Core_and_Out", "not_Core_not_Out", "OR", "pval"],
)
tf_df["FDR"] = fdrcorrection(tf_df["pval"])[1]
tf_df


In [None]:
# Bar chart of the log10 odds ratio per edge type, with the edge-type name written on each bar.
fig, ax = plt.subplots(figsize=(full_width*cm*0.5, 3))

tf_df["log_OR"] = np.log10(tf_df["OR"])
ax.bar(
    range(len(tf_df["log_OR"])),
    tf_df["log_OR"],
    color="#01016f",
    zorder=3,
    width=1,
    edgecolor="white",
)
for i, label in enumerate(tf_df.index):
    ax.text(i+0.2, +0.02, pretty_names[label], va="bottom", ha="center", rotation=90, fontsize=5, color="white", zorder=5)

# Ticks sit at log10 positions but are labelled with the plain odds ratio.
ax.set_yscale("symlog")
ax.set_yticks(np.log10([0.01, 0.1, 0.5, 1, 2, 4]))
ax.set_yticklabels([0.01, 0.1, 0.5, 1, 2, 4])
ax.grid(axis="y", linestyle="--", color="lightgray", zorder=-5)
ax.set_ylabel("Odds Ratio")
ax.set_xticks([])

# Baseline at OR = 1 spanning the full width.
ax.hlines(0, -0.75, len(tf_df)-0.25, color="black", linewidth=0.5)
ax.set_xlim((-0.75, len(tf_df)-0.25))
#ax.set_yticks(ax.get_yticks())
#ax.set_yticklabels(ticklabels)
plt.savefig("edge_frequency.svg")

In [None]:
# Persist the current Fisher-test table as a tab-separated file.
tf_df.to_csv("edge_frequency_df.tsv", sep="\t")

# Try with edge frequency instead

In [None]:
from tqdm.notebook import tqdm

# Per edge type: total number of edges (one source entry per edge).
outgoing_background_frequency = {}
for key, edge_index in tqdm(dataset.data.edge_index_dict.items()):
    key = key[1]
    outgoing_background_frequency[key] = len(edge_index[0])
        


In [None]:
# `chain` is otherwise only imported in a later cell; import it here so this cell
# does not depend on execution order.
from itertools import chain

# Per edge type: are edges leaving core genes over-represented among the edges of that
# type? Counts are edges, not genes. 2x2 Fisher's exact test, FDR across edge types.

# Loop-invariant totals, computed once instead of once per edge type.
total_edges_core = np.sum(list(chain(*[[len(value) for value in values.values()] for values in outgoing.values()])))
total_edges = edge_index_flat.shape[1]

adj_fisher = {}
for key in dataset.data.edge_index_dict.keys():
    tf_and_core = np.sum([len(value) for value in outgoing[key[1]].values()])
    total_tf = outgoing_background_frequency[key[1]]
    #               edges
    #               Yes  No
    #   Core    Yes
    #           No

    array = np.asarray([[tf_and_core, total_edges_core - tf_and_core],
            [total_tf- tf_and_core, total_edges - total_tf - total_edges_core + tf_and_core]])

    result = fisher_exact(array)
    adj_fisher[key[1]] = array.flatten().tolist() + [result[0], result[1]]

# NOTE: this replaces the gene-based `tf_df`; the new table has no "log_OR" column yet.
tf_df = pd.DataFrame.from_dict(adj_fisher, orient="index", columns=["Core_and_Out", "Core_not_Out", "not_Core_and_Out", "not_Core_not_Out", "OR", "pval"])
tf_df["FDR"] = fdrcorrection(tf_df["pval"])[1]
tf_df


In [None]:
# Look up the node id of FCGR2A.
prepro.hgnc2id["FCGR2A"]

In [None]:
# Display the log10 odds ratios.
# NOTE(review): the edge-frequency `tf_df` built above has no "log_OR" column until a
# plotting cell adds it, so this lookup depends on execution order.
tf_df["log_OR"]

In [None]:
# Collect core genes with outgoing edges in any "GRNDB*" network; a gene appears once
# per network in which it has outgoing edges.
tfs = []
for key, value in outgoing.items():
    if not key.startswith("GRNDB"):
        continue
    tfs.extend(value)

In [None]:
from collections import Counter

# Per gene: number of GRNDB networks in which it has outgoing edges.
counter = Counter(tfs)
counter

In [None]:


# Histogram: how many core-gene TFs are present in exactly k GRNDB networks, k descending.
countcounter = Counter(counter.values())

fig, ax = plt.subplots(figsize=(full_width*cm*0.5, 6*cm))

labels = []
for i, (key, value) in enumerate(sorted(countcounter.items())[::-1]):
    ax.bar(i, height=value)
    labels.append(key)

# Fixed: one tick per bar. The hard-coded range(20) only matched the labels when there
# happened to be exactly 20 distinct counts; otherwise set_xticklabels raises.
ax.set_xticks(range(len(labels)))
ax.set_xticklabels(labels)
ax.set_ylabel("Number of Core Gene TFS")
ax.set_xlabel("TF in # Tissues (out of 27)")
plt.savefig("tfs_per_tissue.svg", bbox_inches="tight")

In [None]:
# Per TF: list with its number of outgoing edges in each GRNDB network where it occurs.
tfs_withedges = {}
for key, value in outgoing.items():
    if key.startswith("GRNDB"):
        for tf, targets in value.items():
            tfs_withedges.setdefault(tf, []).append(len(targets))


#edgecounter = Counter(tfs)
#edgecounter

In [None]:
# Display the per-network edge counts of every TF.
tfs_withedges

In [None]:
# Per TF: total number of outgoing edges summed over all GRNDB networks.
tfs_aggregated = {}
for tf, value in counter.items():
    # `value` (the network count) is not used here; only the TF name is needed.
    tfs_aggregated[tf] = np.sum(tfs_withedges[tf])
tfs_aggregated

In [None]:
# Total outgoing edges, grouped by the number of GRNDB networks a TF occurs in.
tfs_aggregated_by_num_adjacencies = dict.fromkeys(counter.values(), 0)
for tf, value in counter.items():
    tfs_aggregated_by_num_adjacencies[value] = tfs_aggregated_by_num_adjacencies[value] + tfs_aggregated[tf]
tfs_aggregated_by_num_adjacencies

In [None]:
# Bar chart: outgoing edges of core-gene TFs, grouped by the number of GRNDB networks
# the TF occurs in (descending). Also prints the cumulative share of edges.
fig, ax = plt.subplots(figsize=(full_width*cm*0.5, 6*cm))

labels = []
cumsum = [0]
for i, (key, value) in enumerate(sorted(tfs_aggregated_by_num_adjacencies.items())[::-1]):
    ax.bar(i, height=value)
    labels.append(key)
    cumsum.append(cumsum[i] + value)

# Fixed: one tick per bar. The hard-coded range(20) only matched the labels when there
# happened to be exactly 20 groups; otherwise set_xticklabels raises.
ax.set_xticks(range(len(labels)))
ax.set_xticklabels(labels)
ax.set_ylabel("Number of Edges outgoing\nfrom Core Gene TF")
ax.set_xlabel("TF in # Tissues (out of 27)")
print(np.asarray(cumsum) / cumsum[-1])
plt.savefig("tf_edges_per_tissue.svg", bbox_inches="tight")

In [None]:
# TFs present in more than 24 GRNDB networks are treated as tissue-unspecific.
unspecific_TFS = {key for key, value in counter.items() if value > 24}

In [None]:
# Number of TFs classified as tissue-unspecific.
len(unspecific_TFS)

In [None]:
from scipy.stats import fisher_exact
from statsmodels.stats.multitest import fdrcorrection

# Same gene-based test as before, but for "GRNDB*" networks the tissue-unspecific TFs are
# removed from the core-gene count. The background count (`total_tf`) is left unchanged.
total_n_core = coregenes.sum().item()
total_n_genes = coregenes.shape[0]

adj_fisher = {}
for key in dataset.data.edge_index_dict.keys():
    core_sources = outgoing[key[1]].keys()
    if key[1].startswith("GRNDB"):
        tf_and_core = len(set(core_sources).difference(unspecific_TFS))
    else:
        tf_and_core = len(core_sources)
    total_tf = outgoing_background[key[1]]

    # Rows: core gene yes / no. Columns: has outgoing edge yes / no.
    core_row = [tf_and_core, total_n_core - tf_and_core]
    other_row = [total_tf - tf_and_core, total_n_genes - total_tf - total_n_core + tf_and_core]
    array = np.asarray([core_row, other_row])

    result = fisher_exact(array)
    adj_fisher[key[1]] = [array[0, 0], array[0, 1], array[1, 0], array[1, 1], result[0], result[1]]

tf_df_corrected = pd.DataFrame.from_dict(
    adj_fisher,
    orient="index",
    columns=["Core_and_Out", "Core_not_Out", "not_Core_and_Out", "not_Core_not_Out", "OR", "pval"],
)
tf_df_corrected["FDR"] = fdrcorrection(tf_df_corrected["pval"])[1]
tf_df_corrected


In [None]:
# Bar chart of the log10 odds ratio per edge type after removing tissue-unspecific TFs.
# Fixed: this figure is saved as the "minus unspecific TFs" plot but was drawn from
# `tf_df`; it now uses `tf_df_corrected`, the table built for exactly that purpose.
# NOTE(review): confirm that plotting the uncorrected table here was not intentional.
fig, ax = plt.subplots(figsize=(full_width*cm*0.5, 6*cm))

tf_df_corrected["log_OR"] = np.log10(tf_df_corrected["OR"])
# Bars are blue when FDR < 0.05 and gray otherwise.
for i, (log_or, fdr) in enumerate(zip(tf_df_corrected["log_OR"], tf_df_corrected["FDR"])):
    ax.bar(i, log_or, color="#01016f" if fdr <0.05 else "gray", zorder=3, width=1, edgecolor="white")
# GRNDB labels sit above their bar when OR > 1; all other labels start just above the baseline.
for i, (label, odds_ratio) in enumerate(zip(tf_df_corrected.index, tf_df_corrected["OR"])):
    ax.text(i+0.18,np.log10(odds_ratio) + 0.02 if label.startswith("GRNDB") and odds_ratio > 1 else 0.02, label, va="bottom", ha="center", rotation=90, fontsize=6, color="black" if label.startswith("GRNDB") else "white" , zorder=5)
# Ticks sit at log10 positions but are labelled with the plain odds ratio.
ax.set_yscale("symlog")
ax.set_yticks(np.log10([0.01, 0.1, 0.5, 1, 2, 4]))
ax.set_yticklabels([0.01, 0.1, 0.5, 1, 2, 4])
ax.grid(axis="y", linestyle="--", color="lightgray", zorder=-5)
ax.set_ylabel("Odds Ratio")
ax.set_xticks([])

ax.hlines(0, -0.75, len(tf_df_corrected)-0.25, color="black", linewidth=0.5)
ax.set_xlim((-0.75, len(tf_df_corrected)-0.25))
#ax.set_yticks(ax.get_yticks())
#ax.set_yticklabels(ticklabels)
plt.savefig("edge_frequency_minus_unspec_dfs.svg")

In [None]:
# Two panels sharing the y-axis: left the odds ratios from `tf_df`, right the ones from
# `tf_df_corrected` (coloured by FDR significance).
fig, (ax1, ax2) = plt.subplots(ncols=2, figsize=(full_width*cm, 6*cm), sharey=True)

tf_df["log_OR"] = np.log10(tf_df["OR"])
ax1.bar(range(len(tf_df["log_OR"])), tf_df["log_OR"], color="#01016f", zorder=3, width=1, edgecolor="white", linewidth=0.5)
for i, label in enumerate(tf_df.index):
    ax1.text(i+0.115, +0.01, pretty_names[label], va="bottom", ha="center", rotation=90, fontsize=6, color="white", zorder=5)
# Ticks sit at log10 positions but are labelled with the plain odds ratio.
ax1.set_yscale("symlog")
ax1.set_yticks(np.log10([0.01, 0.1, 0.5, 1, 2, 4, 6]))
ax1.set_yticklabels([0.01, 0.1, 0.5, 1, 2, 4, 6])
ax1.grid(axis="y", linestyle="--", color="lightgray", zorder=-5)
ax1.set_ylabel(r"Odds Ratio of $d_{out} > 0$")
ax1.set_xticks([])
ax1.hlines(0, -0.75, len(tf_df)-0.25, color="black", linewidth=0.5)
ax1.set_xlim((-0.75, len(tf_df)-0.25))
ax1.set_xlabel("Subnetworks")

tf_df_corrected["log_OR"] = np.log10(tf_df_corrected["OR"])
# Bars are blue when FDR < 0.05 and gray otherwise.
for i, (log_or, fdr) in enumerate(zip(tf_df_corrected["log_OR"], tf_df_corrected["FDR"])):
    ax2.bar(i, log_or, color="#01016f" if fdr <0.05 else "gray", zorder=3, width=1, edgecolor="white", linewidth=0.5)
# Here `log_or` actually holds the plain odds ratio (it is zipped from the "OR" column).
for i, (label, log_or) in enumerate(zip(tf_df_corrected.index, tf_df_corrected["OR"])):
    ax2.text(i+0.115,np.log10(log_or) + 0.01 if label.startswith("GRNDB") and log_or > 1 else 0.02, pretty_names[label], va="bottom", ha="center", rotation=90, fontsize=6, color="black" if label.startswith("GRNDB") else "white" , zorder=5)
ax2.set_yscale("symlog")
ax2.set_yticks(np.log10([0.01, 0.1, 0.5, 1, 2, 4, 6]))
ax2.set_yticklabels([0.01, 0.1, 0.5, 1, 2, 4, 6])
ax2.grid(axis="y", linestyle="--", color="lightgray", zorder=-5)
ax2.set_xticks([])

# The right panel's x-extent is taken from len(tf_df), not len(tf_df_corrected).
ax2.hlines(0, -0.75, len(tf_df)-0.25, color="black", linewidth=0.5)
ax2.set_xlim((-0.75, len(tf_df)-0.25))
ax2.set_xlabel("Subnetworks")
ax1.spines['top'].set_visible(False)
ax1.spines['bottom'].set_visible(False)
ax2.spines['top'].set_visible(False)
ax2.spines['bottom'].set_visible(False)
plt.subplots_adjust(wspace=0.05)

plt.savefig("edge_frequency_both.svg", bbox_inches="tight")

In [None]:
from tqdm.notebook import tqdm
from itertools import chain
# firsthop_outgoing[edge type][gene symbol] -> symbols of the genes it points to, for
# every gene that a core gene points to in the FIRST network listed in `outgoing`.
firsthop_outgoing = {}
new_start_nodes = set()
# Only the first edge type in `outgoing` defines the set of first-hop start nodes.
key = list(outgoing.keys())[0]
new_start_nodes.update(set(chain(*list(outgoing[key].values()))))

for key, edge_index in tqdm(dataset.data.edge_index_dict.items()):
    key = key[1]
    firsthop_outgoing[key] = {}
    for index in tqdm(new_start_nodes):
        # `index` is a gene symbol here; targets of all edges starting at that gene.
        values = edge_index[1, :][edge_index[0, :] == prepro.hgnc2id[index]].tolist()
        if len(values) > 0:
            firsthop_outgoing[key][index] = [prepro.id2hgnc[value] for value in values]

        


In [None]:
import json

# Cache the first-hop adjacency on disk (hard-coded absolute path).
with open("/mnt/storage/speos/firsthop_bioplexhct_outgoing.json", "w") as outfile: 
    json.dump(firsthop_outgoing, outfile)

In [None]:
import json
from itertools import chain

# Reload the cached first-hop adjacency instead of recomputing it.
with open("/mnt/storage/speos/firsthop_bioplexhct_outgoing.json", "r") as infile:
    firsthop_outgoing = json.load(infile)

# First-hop start nodes: every target of a core gene in the first network of `outgoing`.
key = next(iter(outgoing))
new_start_nodes = set(chain.from_iterable(outgoing[key].values()))

In [None]:
# Number of distinct first-hop targets for the edge type currently stored in `key`.
len(set(chain(*list(outgoing[key].values()))))

In [None]:
# Per edge type: are first-hop neighbours more likely than other genes to have at least
# one outgoing edge of that type? 2x2 Fisher's exact test, FDR across edge types.
adj_fisher = {}

for key in dataset.data.edge_index_dict.keys():
    total_firsthop = len(new_start_nodes)
    total_n_genes = coregenes.shape[0]
    tf_and_firsthop = len(firsthop_outgoing[key[1]])
    total_tf = outgoing_background[key[1]]

    # Rows: first-hop gene yes / no. Columns: has outgoing edge yes / no.
    firsthop_row = [tf_and_firsthop, total_firsthop - tf_and_firsthop]
    other_row = [total_tf - tf_and_firsthop, total_n_genes - total_tf - total_firsthop + tf_and_firsthop]
    array = np.asarray([firsthop_row, other_row])

    # The table margins must reproduce the totals it was built from.
    assert array[0, :].sum() == total_firsthop
    assert array[1, :].sum() == total_n_genes - total_firsthop
    assert array[:, 0].sum() == total_tf
    assert array[:, 1].sum() == total_n_genes - total_tf

    result = fisher_exact(array)
    adj_fisher[key[1]] = array.flatten().tolist() + [result[0], result[1]]

firsthop_tf_df = pd.DataFrame.from_dict(
    adj_fisher,
    orient="index",
    columns=["1Hop_and_Out", "1Hop_not_Out", "not_1Hop_and_Out", "not_1Hop_not_Out", "OR", "pval"],
)
firsthop_tf_df["FDR"] = fdrcorrection(firsthop_tf_df["pval"])[1]
firsthop_tf_df


In [None]:
# Bar chart of the log10 odds ratio per edge type for first-hop neighbours.
fig, ax = plt.subplots(figsize=(full_width*cm, 3))

firsthop_tf_df["log_OR"] = np.log10(firsthop_tf_df["OR"])
# Overwrite the first row's last column (the "log_OR" just added) with log10(30).
# NOTE(review): presumably caps a value too large to plot — confirm the reason.
firsthop_tf_df.iloc[0, -1] = np.log10(30)
print(firsthop_tf_df["log_OR"])
# Bars are blue when FDR < 0.05 and gray otherwise.
for i, (log_or, fdr) in enumerate(zip(firsthop_tf_df["log_OR"], firsthop_tf_df["FDR"])):
    ax.bar(i, log_or, color="#01016f" if fdr <0.05 else "gray", zorder=3)
# Here `log_or` actually holds the plain odds ratio (it is zipped from the "OR" column).
for i, (label, log_or) in enumerate(zip(firsthop_tf_df.index, firsthop_tf_df["OR"])):
    ax.text(i+0.05,np.log10(log_or) + 0.02 if label.startswith("GRNDB") and log_or > 1 else 0.02, label, va="bottom", ha="center", rotation=90, fontsize=6, color="black" if label.startswith("GRNDB") else "white" , zorder=5)
# NOTE(review): the lower limit is set to 1 while ticks go down to log10(0.5) — confirm intent.
ax.set_ylim(top = np.log10(25))
ax.set_ylim(bottom = 1)
ax.set_yscale("symlog")
ax.set_yticks(np.log10([ 0.5, 1, 2, 4, 10, 20, ]))
ax.set_yticklabels([ 0.5, 1, 2, 4,  10, 20, ])

ax.grid(axis="y", linestyle="--", color="lightgray", zorder=-5)
ax.set_ylabel("Odds Ratio")
ax.set_xticks([])
# The x-extent is taken from len(tf_df), not len(firsthop_tf_df).
ax.hlines(0, -0.75, len(tf_df)-0.25, color="black", linewidth=0.5)
ax.set_xlim((-0.75, len(tf_df)-0.25))
#ax.set_yticks(ax.get_yticks())
#ax.set_yticklabels(ticklabels)
plt.savefig("edge_frequency_onehop.svg")

In [None]:
# Fisher's exact test per edge type on a 2x2 table of first-hop membership versus
# membership in that edge type's outgoing background set.
adj_fisher = {}

for key in dataset.data.edge_index_dict:
    total_firsthop = len(new_start_nodes)
    total_n_genes = coregenes.shape[0]
    tf_and_firsthop = len(firsthop_outgoing[key[1]].keys())
    total_tf = outgoing_background[key[1]]

    #           Transcription Factor
    #          Yes  No
    #   FH  Yes
    #       No

    array = np.asarray([
        [tf_and_firsthop, total_firsthop - tf_and_firsthop],
        [total_tf - tf_and_firsthop, total_n_genes - total_tf - total_firsthop + tf_and_firsthop],
    ])

    # The table margins have to reproduce the totals the cells were derived from.
    assert array.sum(axis=1)[0] == total_firsthop
    assert array.sum(axis=1)[1] == total_n_genes - total_firsthop
    assert array.sum(axis=0)[0] == total_tf
    assert array.sum(axis=0)[1] == total_n_genes - total_tf

    result = fisher_exact(array)
    # Row layout: the four table cells (row-major), then odds ratio and p-value.
    adj_fisher[key[1]] = [*array.ravel().tolist(), result[0], result[1]]

firsthop_tf_df = pd.DataFrame.from_dict(
    adj_fisher,
    orient="index",
    columns=["1Hop_and_Out", "1Hop_not_Out", "not_1Hop_and_Out", "not_1Hop_not_Out", "OR", "pval"],
)
# Second return value of fdrcorrection holds the corrected p-values.
firsthop_tf_df["FDR"] = fdrcorrection(firsthop_tf_df["pval"])[1]
firsthop_tf_df


In [None]:
# Number of distinct sender nodes (row 0 of the edge index) for the edge type that the
# preceding loop left in `key`.
# The tensor is looked up directly rather than by scanning the dict with a loop variable
# named `edge_index`: that loop variable overwrote the notebook-level `edge_index`
# sparse tensor, which later cells still read through `edge_index.storage`.
print(dataset.data.edge_index_dict[key][0, :].unique().numel())

In [None]:
# Background size used by the Fisher-test loop (assigned there from `coregenes.shape[0]`).
total_n_genes

In [None]:
# `outgoing_background[key[1]]` for the last edge type visited by the Fisher-test loop.
total_tf

In [None]:
# Number of first-hop nodes (assigned in the Fisher-test loop from `len(new_start_nodes)`).
total_firsthop

In [None]:
# Number of keys in `firsthop_outgoing[key[1]]` for the last edge type visited by the Fisher-test loop.
tf_and_firsthop

In [None]:
# Bottom-right cell of the 2x2 contingency table (neither first-hop nor in the
# background set), recomputed from the values left over by the Fisher-test loop.
total_n_genes - total_tf - total_firsthop + tf_and_firsthop

In [None]:
# Number of edges whose row-0 (sparse-tensor row) entry is node id 10612.
# NOTE(review): 10612 is a magic id — presumably one specific gene; confirm which.
(edge_index_flat[0, :] == 10612).sum()

In [None]:
# Look up the id stored for the symbol "PARK7" in the preprocessor's hgnc2id mapping.
prepro.hgnc2id["PARK7"]

In [None]:
# Indices of the non-zero entries of `dataset.data.y`, with singleton dimensions squeezed away.
dataset.data.y.nonzero().squeeze()

# Use Edge Attributes

In [None]:
import torch

# Collect the saved per-gene edge attributions for every gene flagged in `coregenes`.
# `genes` and `edge_attributions` stay aligned: a gene is only recorded when its file
# could be loaded and converted.
genes = []
edge_attributions = []

attribution_path = "/mnt/storage/speos/explanations/uc_film_nohetio_ig_attr_edge_total_{}.pt"
coregene_symbols = [prepro.id2hgnc[idx.item()] for idx in coregenes.nonzero()]

for gene in coregene_symbols:
    try:
        attribution = torch.load(attribution_path.format(gene)).detach().float().cpu().numpy()
    except (FileNotFoundError, RuntimeError):
        # No usable attribution file for this gene: leave it out of both lists.
        continue
    else:
        edge_attributions.append(attribution)
        genes.append(gene)

In [None]:
# Number of genes whose attribution file was loaded by the cell above.
len(edge_attributions)

In [None]:
import pandas as pd

# One row per entry of the typed sparse adjacency `edge_index`: row index ("from"),
# column index ("to") and the edge-type label decoded from the stored value.
edge_df = pd.DataFrame({
    "from": edge_index.storage.row().tolist(),
    "to": edge_index.storage.col().tolist(),
})
edge_df["type"] = encoder.inverse_transform(edge_index.storage.value().long().tolist()).tolist()

In [None]:
# Turn the list of per-gene attribution arrays into a single array and show its shape.
# Presumably (n_loaded_genes, n_edges) — TODO confirm against the displayed shape.
edge_attributions = np.asarray(edge_attributions)
edge_attributions.shape

In [None]:
# Maximum over axis 0 (the stacked per-gene entries), wrapped in a torch tensor and
# written to "edge_attributions_tensor_UC.pt" in the current working directory.
edge_attributions_tensor = torch.Tensor(edge_attributions.max(axis=0))
torch.save(edge_attributions_tensor, "edge_attributions_tensor_UC.pt")

In [None]:
# Total number of entries in the attribution array that exceed 0.01.
(edge_attributions > 0.01).sum()

In [None]:
import numpy as np

# Mean over axis 0 of the attribution array, stored as one value per row of `edge_df`.
# Assumes the attribution columns are in the same order as the rows of `edge_df` — TODO confirm.
edge_df["avg_attr"] = edge_attributions.mean(axis=0)

In [None]:
# The 100 largest average attributions, in descending order.
edge_df["avg_attr"].sort_values(ascending=False).head(100)

In [None]:
# For every loaded gene keep the positions and values of the edges whose attribution
# exceeds 0.9.
important_edges = {}

for i, gene in enumerate(genes):
    gene_attributions = edge_attributions[i, :]
    important_indices = np.flatnonzero(gene_attributions > 0.9)
    important_edges[gene] = (important_indices, gene_attributions[important_indices])

In [None]:
# Display the full gene -> (edge positions, attribution values) mapping built above.
important_edges

In [None]:
# Export the selected edges as tab-separated lines: sender symbol, receiver symbol,
# edge type, attribution value. Edges selected for several genes are written once per gene.
with open("disease_network_09.txt", "w") as handle:
    for gene, (indices, values) in important_edges.items():
        selected = zip(edge_df["from"][indices], edge_df["to"][indices], edge_df["type"][indices], values)
        for sender, receiver, edgetype, value in selected:
            handle.write("{}\t{}\t{}\t{}\n".format(prepro.id2hgnc[sender], prepro.id2hgnc[receiver], edgetype, value))

In [None]:
from collections import Counter

# Per importance level: share of each edge type among the exported edges, the number
# of (deduplicated) edges and the number of distinct genes touched by them.
count_dfs = []
total_counts = []
num_genes = []

for level in ["75", "5", "25", "1", "01"]:
    disease_edges = pd.read_csv(
        "disease_network_0{}.txt".format(level),
        sep="\t",
        index_col=False,
        header=None,
        names=["from", "to", "type", "weight"],
    )
    # Collapse repeated (from, to, type) rows, keeping the largest weight.
    disease_edges = disease_edges.groupby(["from", "to", "type"]).max().reset_index()

    counter = Counter(disease_edges["type"])
    count_df = pd.DataFrame.from_dict(counter, orient="index", columns=[level])
    count_df[level] = count_df[level] / count_df[level].sum()

    count_dfs.append(count_df)
    total_counts.append(len(disease_edges))
    num_genes.append(len(set(disease_edges["from"].tolist()) | set(disease_edges["to"].tolist())))

# One column per level; edge types missing at a level get a share of 0.
count_dfs = pd.concat(count_dfs, axis=1, join="outer").fillna(0).sort_values(by="75", ascending=True)
    

In [None]:
# Add the unthresholded network as column "0" and extend the two per-level summaries
# by the matching entry.
# The lists are rebuilt from their first `n_levels` entries instead of being appended
# to, so re-running this cell cannot make them longer than the number of columns in
# `count_dfs` (the plotting cell needs exactly one entry per column).
n_levels = sum(column != "0" for column in count_dfs.columns)

count_dfs["0"] = [dataset.data.edge_index_dict[("gene", adj, "gene")].shape[1] for adj in count_dfs.index]
total_counts = total_counts[:n_levels] + [count_dfs["0"].sum()]
# Convert the raw edge counts into shares, like the other columns.
count_dfs["0"] /= count_dfs["0"].sum()

num_genes = num_genes[:n_levels] + [edge_index_flat.flatten().unique().shape[0]]



In [None]:
# Edge-type names that index `count_dfs`; these are the keys `pretty_names` has to cover.
count_dfs.index

In [None]:
# Display labels for the edge-type names used as plot annotations.
pretty_names = {
    # protein-protein interaction networks
    "BioPlex30293T": "BioPlex 3.0 HEK293T",
    "BioPlex30HCT116": "BioPlex 3.0 HCT116",
    "HuRI": "HuRI",
    "GRNDBadrenalgland": "GRNDB Adrenal Gland",
    "GRNDBbloodvessel": "GRNDB Blood Vessel",
    # metabolic network
    "Recon3DDirected": "Recon 3D",
    # tissue-specific GRNDB networks
    "GRNDBsalivarygland": "GRNDB Salivary Gland",
    "GRNDBsmallintestine": "GRNDB Small Intestine",
    "GRNDButerus": "GRNDB Uterus",
    "GRNDBadiposetissue": "GRNDB Adipose Tissue",
    "GRNDBthyroid": "GRNDB Thyroid",
    "GRNDBstomach": "GRNDB Stomach",
    "GRNDBcolon": "GRNDB Colon",
    "GRNDBovary": "GRNDB Ovary",
    "GRNDBpituitary": "GRNDB Pituitary",
    "GRNDBesophagus": "GRNDB Esophagus",
    "GRNDBbrain": "GRNDB Brain",
    "GRNDBliver": "GRNDB Liver",
    "GRNDBprostate": "GRNDB Prostate",
    "GRNDBheart": "GRNDB Heart",
    "GRNDBmuscle": "GRNDB Muscle",
    "GRNDBkidney": "GRNDB Kidney",
    "GRNDBnerve": "GRNDB Nerve",
    "GRNDBbreast": "GRNDB Breast",
    "GRNDBpancreas": "GRNDB Pancreas",
    "GRNDBtestis": "GRNDB Testis",
    "GRNDBspleen": "GRNDB Spleen",
    "GRNDBlung": "GRNDB Lung",
    # the key carries a trailing "x" that the label drops
    "GRNDBbloodx": "GRNDB Blood",
    "GRNDBskin": "GRNDB Skin",
    "GRNDBvagina": "GRNDB Vagina",
}

class ColorCycler:
    """Hand out colors from a fixed sequence, starting over after the last one.

    Attributes are `colors` (the sequence) and `state` (index of the color the
    next call to `next` will return).
    """

    def __init__(self, colors):
        self.colors = colors
        self.state = 0

    def next(self):
        """Return the color at the current position and advance, wrapping to index 0."""
        current = self.colors[self.state]
        self.state = (self.state + 1) % len(self.colors)
        return current
    

In [None]:
# Two stacked panels sharing the x axis (one x position per column of `count_dfs`):
# top = number of genes per importance level, bottom = stacked shares of each edge type.
# NOTE(review): `full_width` and `cm` are presumably provided by this star import — confirm.
from speos.visualization.settings import *
import matplotlib.pyplot as plt


#cycler = ColorCycler(["#01016f", "#89006b", "#d00053", "#f85732", "#ffa600"])

cycler = ColorCycler(["#000066", "#640069", "#9e0061", "#cc0052", "#eb3a3e", "#fc7225", "#ffa600"])

fig, (ax0, ax) = plt.subplots(nrows=2, figsize=(full_width*cm*0.5, 10*cm), sharex=True, gridspec_kw={'height_ratios': [1, 3]})

# Running upper edge of the stacked areas; starts at zero for every x position.
running = np.zeros((len(count_dfs.columns),))

# Top panel: filled area of `num_genes`, which must hold one entry per column of `count_dfs`.
ax0.plot(range(len(running)), num_genes, color="gray")
ax0.fill_between(range(len(running)), running, running+np.asarray(num_genes), color="gray", alpha=1)
ax0.set_ylabel("Genes")
ax0.set_ylim(bottom=0)
# Secondary axis with fixed 0-1 limits labelled as percentages; nothing is plotted on it.
# NOTE(review): its ticks only match the gene counts if ax0's upper limit equals the
# total number of genes — confirm.
ax1 = ax0.twinx()
ax1.set_ylabel("% of Total\nNetwork")
ax1.set_ylim((0,1))
ax1.set_yticks((0, 0.2, 0.4, 0.6, 0.8, 1))
ax1.set_yticklabels((0, 20, 40, 60, 80, 100))
ax1.grid(axis="y", zorder=-5, linestyle=":")


# Bottom panel: one band per edge type (row of `count_dfs`), stacked in row order.
for i, (idx, row) in enumerate(count_dfs.iterrows()):
    color = cycler.next()
    #line = ax.plot(range(len(running)), running+row.values, linewidth=1, color=color)
    ax.fill_between(range(len(running)), running, running+row.values, color=color, alpha=1)
    # Accumulate in place so the next band starts where this one ended.
    running += row.values
    # Legend entry just right of the last x position (xlim ends at 5), evenly spaced
    # over the axis height rather than aligned with the band itself.
    ax.text(x=5.05, ha="left", y=i/len(count_dfs.index), s=pretty_names[idx], color=color, fontsize=5)

# Tick positions and labels are hard-coded for six columns (five thresholds plus "0").
ax.set_xticks((0,1,2,3,4, 5))
ax.set_xticklabels(["{}\n(n={})".format(importance, num_edges) for importance, num_edges in  zip((".75", ".5", ".25", ".1", ".01", "0"), total_counts)])
ax.set_xlim((0,5))
ax.set_ylim((0,1))
ax.set_yticks((0, 0.2, 0.4, 0.6, 0.8, 1))
ax.set_yticklabels((0, 20, 40, 60, 80, 100))
ax.set_xlabel("Attributed Importance (>=)\n(Number of Edges)")
ax.set_ylabel("Percentage of Subnetwork")
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)

plt.subplots_adjust(wspace=0, hspace=0.05)
plt.savefig("edge_importance_cycler2.svg", dpi=450, bbox_inches="tight")

In [None]:
# NOTE(review): `edge_t` is not assigned anywhere in the visible part of the notebook;
# this looks like a leftover scratch cell — confirm and delete if it raises NameError.
edge_t

In [None]:
# One row per (from, to, type) combination, keeping the maximum of the remaining columns.
disease_edges_grouped = disease_edges.groupby(["from", "to", "type"]).max().reset_index()

In [None]:
# Display the deduplicated edge table.
disease_edges_grouped

In [None]:
# Number of distinct genes that appear as sender or receiver of any grouped edge.
len(set(disease_edges_grouped["from"].tolist()).union(disease_edges_grouped["to"].tolist()))

In [None]:
import networkx as nx

# Directed multigraph of the grouped edges: parallel edges between the same pair of
# genes are told apart by using the edge type as the edge key.
disease_network = nx.MultiDiGraph()

# Columns are read by name (not by position within the row) because positional
# integer access on a string-labelled Series is deprecated in pandas, and the weight
# is stored as a plain number rather than wrapped in a one-element list.
edge_columns = disease_edges_grouped[["from", "to", "type", "weight"]]
for sender, receiver, edge_type, weight in edge_columns.itertuples(index=False, name=None):
    disease_network.add_edge(sender, receiver, key=edge_type, weight=weight)

In [None]:
# Inspect what the disease network stores for the node "IRF3" (raises KeyError if the node is absent).
disease_network["IRF3"]

In [None]:
# Number of connected components once edge direction is ignored.
nx.number_connected_components(nx.MultiGraph(disease_network))

In [None]:
from collections import Counter

# Number of grouped edges per edge type.
counter = Counter(disease_edges_grouped["type"])

In [None]:
# Reload the edges exported as "disease_network_075.txt" and collapse repeated
# (from, to, type) rows to their maximum weight.
disease_edges = (
    pd.read_csv(
        "disease_network_075.txt",
        sep="\t",
        index_col=False,
        header=None,
        names=["from", "to", "type", "weight"],
    )
    .groupby(["from", "to", "type"])
    .max()
    .reset_index()
)

In [None]:
# The 20 rows with the largest weight, in ascending order.
disease_edges.sort_values("weight").tail(20)

In [None]:
# Membership test: does the id stored for "LEMD3" occur among the non-zero positions of `coregenes`?
# NOTE(review): relies on `in` working element-wise on the result of `.nonzero()`
# (presumably a torch tensor of indices) — confirm the type of `coregenes`.
prepro.hgnc2id["LEMD3"] in coregenes.nonzero()