In [None]:
import os

# Work from the speos checkout in the user's home directory. The chdir is kept
# ahead of the speos imports on purpose (presumably so that imports and the
# relative config path below resolve against that directory — confirm).
speos_root = os.path.expanduser("~/speos/")
os.chdir(speos_root)

from speos.preprocessing.datasets import DatasetBootstrapper
from speos.utils.config import Config

# Select the experiment configuration; swap in the commented line to use another one.
path_to_config = "config_uc_only_nohetio_film_newstorage.yaml"
# path_to_config = "config_cad_only_film_newstorage.yaml"  # change this to your config path

config = Config()
config.parse_yaml(path_to_config)

# Instantiate the dataset named in the config, then build the preprocessor's
# graph with adjacency=False.
bootstrapper = DatasetBootstrapper(name=config.name, config=config)
dataset = bootstrapper.get_dataset()
dataset.preprocessor.build_graph(adjacency=False)

In [None]:
# Sanity check: display the id the preprocessor maps the HGNC symbol "MAML2" to.
hgnc_to_id = dataset.preprocessor.hgnc2id
hgnc_to_id["MAML2"]

# Load pre-computed explanations

Set `gene` to the gene you want to inspect, and adjust the path so that it points to the explanation `.pt` files generated by the explanation script.

In [None]:
import torch

gene = "HNF4A"
num_outer = 11
num_inner = 10

# One attribution tensor is stored per (outer fold, inner fold) pair.
attribution_path_template = "/mnt/storage/speos/explanations/{}_ig_attr_self_outer{}_inner{}_{}.pt"

ig_attr_self_all = None      # running sum of the loaded attribution tensors
ig_attr_self_abs_all = None  # placeholder for the absolute-attribution variant, which this cell does not compute
num_processed = 0            # number of fold files that were actually found

for outer_fold in range(num_outer):
    for inner_fold in range(num_inner):
        # Combinations with identical indices are skipped
        # (presumably no explanation exists for them — confirm against the explanation script).
        if inner_fold == outer_fold:
            continue

        attribution_path = attribution_path_template.format(config.name, outer_fold, inner_fold, gene)
        try:
            # map_location="cpu" lets tensors that were saved from a GPU be read on a
            # machine without one; detach() drops any autograd history.
            ig_attr_self = torch.load(attribution_path, map_location="cpu").detach()
        except FileNotFoundError:
            # Missing folds are tolerated so that partial results can be inspected.
            print("Attributes for outer {} inner {} not yet calculated.".format(outer_fold, inner_fold))
            continue

        print("Loaded attributes for outer {} inner {}".format(outer_fold, inner_fold))
        num_processed += 1

        if ig_attr_self_all is None:
            ig_attr_self_all = ig_attr_self
        else:
            ig_attr_self_all += ig_attr_self

# Fail with an actionable message instead of a TypeError on `None /= int`.
if ig_attr_self_all is None:
    raise FileNotFoundError(
        "No attribution files found for gene {} matching {}".format(
            gene, attribution_path_template.format(config.name, "*", "*", gene)
        )
    )

# Average over the folds that were found.
ig_attr_self_all /= num_processed

input_attributions = ig_attr_self_all.numpy().tolist()

In [None]:
# Or easier in one step
import torch

gene = "HNF4A"

# Single file holding the combined attributions for this gene
# (presumably the average over all folds written by the explanation script — confirm).
total_attribution_path = "/mnt/storage/speos/explanations/{}_ig_attr_self_total_{}.pt".format(config.name, gene)

# map_location="cpu" lets a tensor that was saved from a GPU be read on a machine without one.
input_attributions = torch.load(total_attribution_path, map_location="cpu").detach().numpy().tolist()

# Now prepare the input data for visualization

In [None]:
features = dataset.preprocessor.get_feature_names()

raw_inputs = dataset.data.x

# Per-column reference magnitudes: with k = 1% of the rows, take the k-th largest
# and the k-th smallest value of every column.
one_percent_of_rows = int(raw_inputs.shape[0] / 100)
upper_reference = raw_inputs.topk(dim=0, k=one_percent_of_rows, sorted=True).values[-1, :]
lower_reference = raw_inputs.topk(dim=0, k=one_percent_of_rows, sorted=True, largest=False).values[-1, :]

# Scale by the upper reference, and by the sign-flipped lower reference.
normalized_data_positives = raw_inputs / upper_reference
normalized_data_negatives = (raw_inputs / lower_reference) * -1

# Non-negative entries take the first scaling, negative entries the second.
# NOTE(review): the buffer comes from torch.empty, so entries matching neither
# mask (NaN) would stay uninitialised — confirm the inputs contain no NaNs.
nonnegative_mask = raw_inputs >= 0
negative_mask = raw_inputs < 0
normalized_data = torch.empty(raw_inputs.shape, dtype=torch.double)
normalized_data[nonnegative_mask] = normalized_data_positives[nonnegative_mask]
normalized_data[negative_mask] = normalized_data_negatives[negative_mask]

# Column-wise quantiles of the normalised inputs across all rows.
quantile_levels = (0.01, 0.1, 0.25, 0.5, 0.75, 0.9, 0.99)
q001, q01, q25, q50, q75, q90, q99 = [
    normalized_data.quantile(level, dim=0) for level in quantile_levels
]

# Normalised feature values of the gene under inspection.
gene_row = dataset.preprocessor.hgnc2id[gene]
input_values = normalized_data[gene_row].tolist()
#input_values = normalized_data[int(gene)].tolist()

# Sort everything by explainer attribution (ascending, so the highest attributions come last)

In [None]:
# Bundle the per-feature columns, order the rows ascending by attribution
# (ties fall through to the following columns), then split back into columns.
per_feature_columns = (input_attributions, features, input_values, q001, q01, q25, q50, q75, q90, q99)
rows_by_attribution = sorted(zip(*per_feature_columns))
(list1, list2, list3,
 q001_ordered, q01_ordered, q25_ordered, q50_ordered,
 q75_ordered, q90_ordered, q99_ordered) = zip(*rows_by_attribution)

In [None]:
# Bind `plt` to pyplot (not to the top-level package): pyplot exposes the same
# rc() / rcParams, and this matches how the plotting cell uses the alias.
import matplotlib.pyplot as plt
"""
import matplotlib.font_manager as font_manager

# uncomment this part if you also want to set the font to Helvetica. However, you must first download it and place it in your font directory, which can be tricky

from matplotlib import get_cachedir
from glob import glob

font_dirs = ['/mnt/storage/anaconda3/envs/speos/lib/python3.10/site-packages/matplotlib/mpl-data/matplotlibrc']
font_files = font_manager.findSystemFonts(fontpaths=font_dirs)

for font_file in font_files:
    font_manager.fontManager.addfont(font_file)

dir_cache = get_cachedir()
for file in glob(f'{dir_cache}/*.cache') + glob(f'{dir_cache}/font*'):
    if not os.path.isdir(file): # don't dump the tex.cache folder... because dunno why
        os.remove(file)
        print(f'Deleted font cache {file}.')

# set font
plt.rcParams['font.family'] = 'Helvetica'

"""
# NOTE(review): the disabled snippet above is an inert string. Its `font_dirs`
# entry points at a matplotlibrc path rather than a font directory — confirm
# the path before enabling it.

# Layout constants consumed by the plotting cell.
full_width = 18   # nominal figure width, multiplied by `cen` when the figure is created
cen = 1/2.54      # conversion factor applied to figure dimensions (1/2.54, i.e. centimetres to inches)
small_font = 6
medium_font = 8
large_font = 10

# Tick label sizes.
plt.rc('xtick', labelsize=small_font)
plt.rc('ytick', labelsize=small_font)

# Thin axes spines and tick marks.
plt.rcParams.update({
    'axes.linewidth': 0.4,
    'ytick.major.size': 3,
    'ytick.major.width': 0.5,
    'ytick.minor.size': 2,
    'ytick.minor.width': 0.3,
    'xtick.major.size': 2,
    'xtick.major.width': 0.3,
    'xtick.minor.size': 1,
    'xtick.minor.width': 0.1,
})




# And visualize it

In [None]:
import matplotlib.pyplot as plt
import matplotlib as mpl
import numpy as np
from matplotlib import cm
from numpy import linspace
import os

orange = '#fd5800'
gray = "#71797E"

get_top_n = 10   # number of highest-attribution features to show
get_last_n = 5   # number of lowest-attribution features to show


def _top_and_bottom(ordered):
    """Return the first `get_last_n` and the last `get_top_n` entries of an attribution-sorted tuple."""
    return ordered[:get_last_n] + ordered[-get_top_n:]


def _capitalize_parts(text, separator):
    """Upper-case the first character of every `separator`-delimited part.

    Slicing with `[:1]` (instead of indexing `[0]`) keeps empty parts, e.g.
    from doubled separators, from raising IndexError.
    """
    return separator.join(part[:1].upper() + part[1:] for part in text.split(separator))


def _prettify_label(label):
    """Turn a raw feature name into the tick label shown in the figure."""
    label = _capitalize_parts(label, " ")
    label = _capitalize_parts(label, "-")
    # Spell out the abbreviated prefixes.
    if label.startswith("ZSTAT"):
        label = "Z-Score " + label[6:]
    if label.startswith("P "):
        label = "P-Value " + label[2:]
    if label.startswith("NSNPS"):
        label = "n SNPs " + label[6:]
    # Labels with one of the rewritten prefixes are tagged as GWAS, all others as gene expression.
    if label.startswith(("Z-Score", "P-Value", "n SNPs")):
        label = label + "\n(GWAS)"
    else:
        label = label + "\n(Gene Expr.)"
    # Drop the leading 8 characters of labels that start with "Cells".
    if label.startswith("Cells"):
        label = label[8:]
    return label


labels = [_prettify_label(label) for label in _top_and_bottom(list2)]
values = _top_and_bottom(list3)

# Scale attributions to [-1, 1]; skip the division when every attribution is zero.
attributions = np.asarray(_top_and_bottom(list1), dtype=float)
max_abs_attribution = np.max(np.abs(attributions))
if max_abs_attribution > 0:
    attributions = attributions / max_abs_attribution

# Quantile columns restricted to the displayed features, ordered from lowest to highest quantile.
quantile_columns = [
    _top_and_bottom(column)
    for column in (q001_ordered, q01_ordered, q25_ordered, q50_ordered, q75_ordered, q90_ordered, q99_ordered)
]

# Semi-transparent viridis colours, one per quantile line, sampled from stop down to start.
start = 0.0
stop = 1.0
number_of_lines = 7
alpha = 0.5
cm_subsection = linspace(stop, start, number_of_lines)
colors = [cm.viridis(position)[:3] + (alpha,) for position in cm_subsection]

x = np.arange(len(labels))  # the label locations
width = 0.35  # the width of the bars

# Narrow top axis (ax2) for the colour bar, tall bottom axis (ax) for the bars.
gs_kw = dict(height_ratios=[1, 20])
fig, (ax2, ax) = plt.subplots(2, 1, figsize=((full_width/2.5)*cen, 11*cen), gridspec_kw=gs_kw)

ax.grid(axis="y", color="lightgray", linewidth=0.5)
rects2 = ax.barh(x + width/2, attributions, width, color=orange, label='Importance')
rects1 = ax.barh(x - width/2, values, width, color=gray, label='Input Value')

ax.set_yticks(x, labels)
# ax.legend(fontsize=small_font)

# Reference lines at 1 and 0. axvline's ymin/ymax are axes fractions in [0, 1],
# so the defaults already span the full axis height.
ax.axvline(x=1, linestyle=":", color="lightgray", zorder=-1)
ax.axvline(x=0, color="gray", zorder=0)

# One step line per quantile; the shaded bands below reuse their x data.
steps = [
    ax.step(x=column, y=x + 0.5, where="post", c="lightgray", zorder=-1, linewidth=0.5)[0]
    for column in quantile_columns
]

# Lower and upper y edge of each feature's row, interleaved: [x0-0.5, x0+0.5, x1-0.5, ...].
band_edges_y = np.stack([x - 0.5, x + 0.5], axis=1).ravel()

# Shade the area between each pair of neighbouring quantile lines.
for lower_step, upper_step, band_color in zip(steps[:-1], steps[1:], colors):
    ax.fill_betweenx(y=band_edges_y,
                     x1=np.repeat(lower_step.get_xdata(), 2),
                     x2=np.repeat(upper_step.get_xdata(), 2),
                     zorder=-1,
                     color=band_color,
                     linewidth=0.5)

ax.set_ylim((-0.5, len(labels) - 0.5))
ax.set_xlabel("Normalized Input Value and Importance", size=medium_font)

# Discrete colour bar that mirrors the shaded quantile bands.
cmap = (mpl.colors.ListedColormap(colors[:len(steps)-1]).with_extremes(over='white', under='white'))
bounds = [0.01, 0.1, 0.25, 0.5, 0.75, 0.9, 0.99]
norm = mpl.colors.BoundaryNorm(bounds, cmap.N)
cbar = fig.colorbar(
    mpl.cm.ScalarMappable(cmap=cmap, norm=norm),
    cax=ax2,
    extend='both',
    ticks=bounds,
    spacing='proportional',
    orientation='horizontal',
    label="Normalized Quantiles across all Genes' Input Values"
)

cbar.ax.set_xticklabels([".01", ".1", ".25", ".5", ".75", ".9", ".99"], size=small_font)
cbar.ax.set_xlabel("Normalized Quantiles across all Genes' Input Values", size=small_font)

fig.tight_layout()
fig.savefig("Input_explanation_{}.svg".format(gene), bbox_inches="tight", dpi=300)
plt.show()