Download all integrated-gradients graphs as SVG files

In [None]:
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import shutil
import wandb

import os

# Gene symbols used as node labels: entry i labels node i of each run's
# adjacency matrix, so this order presumably has to match the feature order
# the runs were trained with (TODO confirm against the training pipeline).
genes = [
    "C6orf150",
    "CCL5",
    "CXCL10",
    "TMEM173",
    "CXCL9",
    "CXCL11",
    "NFKB1",
    "IKBKE",
    "IRF3",
    "TREX1",
    "ATM",
    "IL6",
    "IL8",
]


def draw_graph(feature_importance, adj_matrix, gene_names=None):
    """Draw a graph whose nodes are coloured by their importance score.

    Args:
        feature_importance: One score per node, in the same order as the
            rows/columns of ``adj_matrix`` (here: integrated-gradients
            attributions for one class).
        adj_matrix: Square adjacency matrix (NumPy array or nested list)
            passed to ``nx.from_numpy_array``.
        gene_names: Node labels in matrix order. Defaults to the
            module-level ``genes`` list.

    Returns:
        The newly created matplotlib ``Figure``. The caller owns it and
        should ``plt.close(fig)`` after saving to avoid accumulating
        open figures.
    """
    if gene_names is None:
        gene_names = genes

    scores = np.asarray(feature_importance, dtype=float)
    # Min-max scale into [0, 1] for the colormap. A constant score vector
    # has zero range; map it to 0 instead of dividing by zero (NaN colours).
    score_range = np.max(scores) - np.min(scores)
    if score_range > 0:
        normalized_data = (scores - np.min(scores)) / score_range
    else:
        normalized_data = np.zeros_like(scores)

    cmap = plt.cm.RdPu
    colors = [cmap(val) for val in normalized_data]

    fig = plt.figure(figsize=(10, 8))
    # Draw on this figure's own axes rather than whatever is "current".
    ax = fig.add_subplot(111)
    G = nx.from_numpy_array(np.asarray(adj_matrix))
    # Label only nodes that exist in the graph; a label for a node with no
    # layout position cannot be drawn.
    labels = {i: name for i, name in enumerate(gene_names) if i in G}
    pos = nx.kamada_kawai_layout(G)
    nx.draw(G, pos, ax=ax, labels=labels, with_labels=True,
            node_color=colors, node_size=500, font_size=12)
    return fig


USERNAME = 'borna-personal'
PROJECT = 'GENIE-Nextflow-v2'

BASE_PATH = 'genie-graphs'

api = wandb.Api()
runs = api.runs(f"{USERNAME}/{PROJECT}")

# Start from a clean output directory so graphs from earlier exports
# don't linger next to the new ones.
if os.path.exists(BASE_PATH):
    shutil.rmtree(BASE_PATH)
os.mkdir(BASE_PATH)

for run in runs:
    cancer = run.config.get('cancers')
    target = run.config.get('variable')
    adj_matrix = run.summary.get('adjacency_matrix')
    if adj_matrix is None:
        print(f'Skipping run {run.id}: no adjacency_matrix in summary')
        continue

    # Use THIS run's history (previously a hard-coded runs[4] was used for
    # every run, so all folders got the same attributions).
    history = run.history()
    if not {'attributions', 'class'}.issubset(history.columns):
        print(f'Skipping run {run.id}: no attributions/class logged')
        continue

    out_dir = os.path.join(BASE_PATH, str(cancer), str(target))
    # exist_ok: several runs may share the same cancer/target pair.
    os.makedirs(out_dir, exist_ok=True)

    # Last two rows that actually carry attributions — presumably the final
    # attributions, one per class (TODO confirm against the training logger).
    final_rows = history[['attributions', 'class']].dropna().iloc[-2:, :]
    for attr, _class in final_rows.values:
        fig = draw_graph(feature_importance=attr, adj_matrix=np.array(adj_matrix))
        fig.savefig(os.path.join(out_dir, f'{target}_{_class}.svg'), format='svg')
        # Release the figure; otherwise every saved graph stays in memory.
        plt.close(fig)

Create the final F1 results table

In [None]:
import wandb

USERNAME = 'borna-personal'
PROJECT = 'GENIE-Nextflow-v2'

api = wandb.Api()
runs = api.runs(f"{USERNAME}/{PROJECT}")

# Group (target variable, final F1) pairs by cancer type, in run order.
f1_table = {}
for run in runs:
    cancer = run.config.get('cancers')
    entry = (run.config.get('variable'), run.summary.get('final_f1'))
    f1_table.setdefault(cancer, []).append(entry)

# One header line per cancer, then every target that logged a final F1.
for cancer, results in f1_table.items():
    print(cancer)
    for target, f1 in results:
        if f1 is not None:
            print(f'\t{target}: {round(float(f1), 2)}')