# Setup

In [16]:
import pandas as pd
import numpy as np
import torch
from sklearn.metrics import average_precision_score
from tqdm import tqdm
import matplotlib.pyplot as plt
import plotly.express as px

import sys
sys.path.append("../../..")
from src.models import training_utils, sage_ones, base_model

data_folder = "../../../data/processed/graph_data_nohubs/merged_types/"
feature_folder = "../../../data/processed/feature_data/"
reports_folder = "../../../reports/explore_predictions/"

# Load a pretrained model

In [4]:
experiment_results = pd.read_csv("../../../reports/random_walks/walk_3/random_walk_results.csv",index_col=0)

In [81]:
step = 22
model_params = experiment_results.loc[step].to_dict()

In [82]:
model_params

{'hidden_channels': 64,
 'micro_aggregation': 'mean',
 'macro_aggregation': 'mean',
 'layer_connectivity': 'skipsum',
 'L2_norm': False,
 'pre_process_layers': 0,
 'msg_passing_layers': 3,
 'post_process_layers': 2,
 'normalize_output': False,
 'jumping_knowledge': False,
 'feature_dim': 32,
 'feature_type': 'lsa_scaled',
 'conv_type': 'SAGEConv',
 'batch_norm': True,
 'dropout': 0.1,
 'weight_decay': 0.001,
 'lr': 0.001,
 'epochs': 400,
 'patience': 10,
 'delta': 0.1,
 'sample_epochs': 10,
 'supervision_types': "[('gene_protein', 'gda', 'disease')]",
 'mean_auc': 0.8858,
 'std': 0.002315167380558}

In [83]:
seed = 4
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

prediction_edge_type = ("gene_protein","gda","disease")
datasets, node_map = training_utils.load_data(data_folder+f"split_dataset/seed_{seed}/")
train_data, val_data = datasets

# feature_type = "ones"
# feature_dim = 10
feature_type = model_params["feature_type"]
feature_dim = model_params["feature_dim"]
train_data = training_utils.initialize_features(train_data, feature_type, feature_dim,feature_folder)
val_data = training_utils.initialize_features(val_data, feature_type, feature_dim,feature_folder)

# weights_path = "../../../data/experiments/merged_types_experiment/sage_ones_merged_experiment_13_06_23__15_59.pth"
# weights_path = "../../../data/experiments/merged_types_experiment/sage_ones_first_negatives_exp_04_07_23__12_07.pth"
weights_path = f"../../../reports/random_walks/walk_3/step_{step-1}_0.pth"
weights = torch.load(weights_path,map_location=device)
model = sage_ones.Model(train_data.metadata(),[("gene_protein","gda","disease")])
model = base_model.base_model(model_params,train_data.metadata(),model_params["supervision_types"])
model.load_state_dict(weights)

node_df = pd.read_csv(data_folder+"split_dataset/seed_4/tensor_df.csv",index_col=0).set_index("node_index",drop=True)

# Get encodings

In [3]:
%%timeit
encodings_dict = training_utils.get_encodings(model,val_data)

36.6 ms ± 3.85 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [85]:
encodings_dict = training_utils.get_encodings(model,val_data)

# Prioritization

In [86]:
class Predictor():
    def __init__(self,node_df, encodings_dict):
        assert node_df.index.name == "node_index", f"df index must be node_index, not {node_df.index.name}."

        self.df = node_df
        self.encodings = encodings_dict
    
    def inner_product_decoder(self,x_source,x_target,apply_sigmoid=True):
        pred = (x_source * x_target).sum(dim=1)

        if apply_sigmoid:
            pred = torch.sigmoid(pred)

        return pred
    
    def prioritize_one_vs_all(self,node_index):
        source_type = self.df.loc[node_index,"node_type"]
        tensor_index = self.df.loc[node_index,"tensor_index"]

        if source_type == "disease":
            target_type = "gene_protein"

        elif source_type == "gene_protein":
            target_type = "disease"

        source_vector = self.encodings[source_type][tensor_index]
        target_matrix = self.encodings[target_type]

        predicted_edges = self.inner_product_decoder(source_vector,target_matrix)
        ranked_scores, ranked_indices = torch.sort(predicted_edges,descending=True)
        results = pd.DataFrame({"score":ranked_scores.cpu().numpy(),"tensor_index":ranked_indices.cpu().numpy()})
        results.score = results.score.round(3)

        index_map = self.df.loc[self.df.node_type == target_type,["tensor_index","node_name"]].reset_index()
        ranked_predictions = pd.merge(results,index_map,on="tensor_index")
        ranked_predictions.index.name = "rank"

        return ranked_predictions
    
    def predict_supervision_edges(self,data, edge_type, return_dataframe=True):
        """If return_dataframe_==True, returns dataframe with edges, prediction scores and labels. Else, returns predicted scores tensor"""
        src_type, trg_type = edge_type[0],edge_type[2]
        x_source = self.encodings[src_type]
        x_target = self.encodings[trg_type]

        edge_label_index = data.edge_label_index_dict[edge_type]
        source_index, target_index = edge_label_index[0], edge_label_index[1]

        emb_nodes_source = x_source[source_index]
        emb_nodes_target = x_target[target_index]

        pred = self.inner_product_decoder(emb_nodes_source, emb_nodes_target)
        if return_dataframe:
            labels = data.edge_label_dict[edge_type].numpy()
            df = pd.DataFrame({"torch_gene_protein_index":source_index, "torch_disease_index":target_index, "score":pred, "label":labels})
            return df
        else: 
            return pred
    
    def hits_at_k(self,node_index,mapped_train,mapped_val):
      k_list = [5,10,50,100]
      predictions = self.prioritize_one_vs_all(node_index)

      node_type = self.df.loc[node_index,"node_type"]
      y_type = "disease" if node_type == "gene_protein" else "gene_protein"

      new_edges = set(mapped_val[(mapped_val.edge_type == "supervision") & (mapped_val.label == 1) & (mapped_val[node_type] == node_index)][y_type].values)
      seen_edges = set(mapped_train[(mapped_train.label != 0) & (mapped_train[node_type] == node_index)][y_type].values)

      results = {"seen_edges":len(seen_edges),"new_edges":len(new_edges)}

      for k in k_list:
            predicted_top = set(predictions[:k]["node_index"].values)

            seen_hits = len(seen_edges.intersection(predicted_top))
            new_hits = len(new_edges.intersection(predicted_top))

            results[f"{k}_seen"] = seen_hits
            results[f"{k}_new"] = new_hits

      return results

# Map datasets to index

In [87]:
class MappedDataset():
    def __init__(self,heterodata,node_map,prediction_edge_type):
        self.prediction_edge_type = prediction_edge_type
        self.node_map = node_map
        self.edge_dict = self._reverse_map_heterodata(heterodata)
        self.dataframe = self._edge_dict_to_dataframe()
        
    def _reverse_map_tensor(self,tensor,edge_type):
        """Maps edge dictionary from pyg Heterodata back into the original node indexes from the dataframe"""
        #Tensor to lists [sources], [targets]
        sources = tensor[0,:].tolist()
        targets = tensor[1,:].tolist()

        #Map edge list to node indexes
        src_type, dst_type = edge_type[0], edge_type[2]
        src_map,dst_map = self.node_map[src_type], self.node_map[dst_type]

        mapped_src = [src_map[n] for n in sources]
        mapped_trg = [dst_map[n] for n in targets]

        return {src_type:mapped_src, dst_type:mapped_trg, f"torch_{src_type}_index":sources, f"torch_{dst_type}_index":targets}

    def _reverse_map_heterodata(self,data):
        """Maps full edge data from pyg Heterodata back into the original node indexes from the dataframe"""
        edge_dict = {}
        for edge_type in data.edge_types:
            type_dict = {}
            edge_tensor = data[edge_type]["edge_index"]
            mapped_edge_list = self._reverse_map_tensor(edge_tensor,edge_type)

            type_dict["message_passing_edges"] = mapped_edge_list

            if "edge_label_index" in data[edge_type].keys():
                labeled_edges_tensor = data[edge_type]["edge_label_index"]
                # labeled_edges_list = tensor_to_edgelist(labeled_edges_tensor)
                mapped_labeled_edges_list = self._reverse_map_tensor(labeled_edges_tensor,edge_type)
                edge_labels = data[edge_type]["edge_label"].tolist()

                type_dict["supervision_edges"] = mapped_labeled_edges_list
                type_dict["supervision_labels"] = edge_labels
    
            edge_dict[edge_type] = type_dict
        
        return edge_dict
    
    def _edge_dict_to_dataframe(self):
        edges_df = []
        e_dict = self.edge_dict[self.prediction_edge_type]
        supervision_edges = pd.DataFrame(e_dict["supervision_edges"])

        labeled_edges = pd.concat([supervision_edges,pd.DataFrame(e_dict["supervision_labels"])],axis=1).rename(columns={0:"label"})
        msg_passing_edges = pd.DataFrame(e_dict["message_passing_edges"])

        msg_passing_edges["edge_type"] = "message_passing"
        labeled_edges["edge_type"] = "supervision"


        edges_df.append(labeled_edges)
        edges_df.append(msg_passing_edges)
        total_df = pd.concat(edges_df,axis=0)
        return total_df

mapped_val = MappedDataset(val_data,node_map,prediction_edge_type)
mapped_train = MappedDataset(train_data,node_map,prediction_edge_type)

In [91]:
predictor = Predictor(node_df,encodings_dict)
pred = predictor.predict_supervision_edges(val_data, prediction_edge_type,return_dataframe=True)
val_supervision_edges = mapped_val.dataframe[mapped_val.dataframe.edge_type == "supervision"]
pred = pd.concat([pred,val_supervision_edges[["gene_protein","disease"]]], axis=1).round(2)
pred


Unnamed: 0,torch_gene_protein_index,torch_disease_index,score,label,gene_protein,disease
0,498,1140,0.98,1.0,3810,24466
1,2054,3088,0.02,1.0,2884,31828
2,10133,6323,0.99,1.0,2973,26568
3,3530,10617,0.85,1.0,8162,26932
4,8048,6729,0.95,1.0,12683,30601
...,...,...,...,...,...,...
16797,16657,1302,0.00,0.0,13107,26053
16798,12001,1004,0.00,0.0,9685,20075
16799,4783,63,0.65,0.0,12695,22363
16800,908,1453,0.00,0.0,171,29805


In [269]:
idx = node_df[node_df.node_type == "disease"].sample().index.values[0]
print(idx)
pred[pred.disease == idx].sort_values(by="score", ascending=False)[["label","score"]]

26633


Unnamed: 0,label,score
10818,0.0,0.84
2878,1.0,0.72
11443,0.0,0.42
12859,0.0,0.2
12515,0.0,0.07
1771,1.0,0.0
11666,0.0,0.0


In [92]:
disease_index = []
ap_at_10 = []
ap_at_5 = []
k = []
for disease in pred.disease.unique():
    labels = pred[pred.disease == disease].sort_values(by="score", ascending=False)["label"].values
    scores = pred[pred.disease == disease].sort_values(by="score", ascending=False)["score"].values
    k.append(len(labels))

    if len(labels) >= 10:
        labels = labels[:10]
        scores = scores[:10]
        ap_at_10.append(average_precision_score(labels,scores))  
    else:
        ap_at_10.append(average_precision_score(labels,scores))
    
    if len(labels) >= 5:
        labels = labels[:5]
        scores = scores[:5]
        ap_at_5.append(average_precision_score(labels,scores))
    else:
        ap_at_5.append(average_precision_score(labels,scores))

    disease_index.append(disease)

ap_df = pd.DataFrame({"disease_index":disease_index,"ap_at_10":ap_at_10,"ap_at_5":ap_at_5, "k":k})
ap_df



Unnamed: 0,disease_index,ap_at_10,ap_at_5,k
0,24466,1.0,1.0,19
1,31828,1.0,1.0,1
2,26568,1.0,1.0,2
3,26932,1.0,1.0,1
4,30601,1.0,1.0,1
...,...,...,...,...
5149,27120,-0.0,-0.0,1
5150,20166,-0.0,-0.0,1
5151,28984,-0.0,-0.0,1
5152,33298,-0.0,-0.0,1


In [11]:
ap_df[ap_df.k < 5]

Unnamed: 0,disease_index,ap_at_10,ap_at_5,k
0,25553,1.0,1.0,3
2,31007,1.0,1.0,2
4,22836,1.0,1.0,1
10,21097,1.0,1.0,1
11,30091,1.0,1.0,2
...,...,...,...,...
8276,34168,-0.0,-0.0,1
8277,27990,-0.0,-0.0,1
8278,25972,-0.0,-0.0,1
8279,18688,-0.0,-0.0,1


In [105]:
ap_df[ap_df.k >= 5].ap_at_5.mean().round(2)

0.92

In [270]:
ap_df[ap_df.k >= 10].ap_at_10.mean().round(2)

0.93

In [275]:
ap_df[(ap_df.k <= 10)&(ap_df.k >= 5)].ap_at_5.mean().round(2)

0.9

In [276]:
ap_df[(ap_df.k <= 10)&(ap_df.k >= 5)]

Unnamed: 0,disease_index,ap_at_10,ap_at_5,k
7,27030,1.000000,1.000000,6
11,20519,1.000000,1.000000,6
12,21894,1.000000,1.000000,5
40,32065,1.000000,1.000000,5
50,25072,0.833333,0.866667,6
...,...,...,...,...
2953,24064,-0.000000,-0.000000,5
3096,33375,-0.000000,-0.000000,5
3210,20839,-0.000000,-0.000000,5
3390,27076,-0.000000,-0.000000,5


# Hits

In [284]:
disease_evals = {}
for disease in tqdm(node_df[node_df.node_type == "disease"].index.values):
    predictions = predictor.prioritize_one_vs_all(disease)
    disease_evals[disease] = predictor.hits_at_k(disease,mapped_train.dataframe, mapped_val.dataframe)

100%|██████████| 16079/16079 [07:48<00:00, 34.30it/s]


In [285]:
gene_evals = {}
for gene_protein in tqdm(node_df[node_df.node_type == "gene_protein"].index.values):
    predictions = predictor.prioritize_one_vs_all(gene_protein)
    gene_evals[gene_protein] = predictor.hits_at_k(gene_protein,mapped_train.dataframe, mapped_val.dataframe)

100%|██████████| 17743/17743 [08:46<00:00, 33.72it/s]


In [286]:
df = node_df

total_disease_evals = pd.DataFrame(disease_evals).T
cols = total_disease_evals.columns
total_disease_evals = total_disease_evals.merge(df["node_name"],left_index=True,right_index=True)[["node_name",*cols]]

total_gene_evals = pd.DataFrame(gene_evals).T
cols = total_gene_evals.columns
total_gene_evals = total_gene_evals.merge(df["node_name"],left_index=True,right_index=True)[["node_name",*cols]]

k_list = [5,10,50,100]
summary_disease_evals = pd.DataFrame()
for k in k_list:
    summary_disease_evals[f"hits_{k}"] = total_disease_evals[f"{k}_new"] + total_disease_evals[f"{k}_seen"]

cols = summary_disease_evals.columns
add_cols=["node_name","degree_gda","degree_dd","comunidades_infomap","comunidades_louvain"]
summary_disease_evals = pd.merge(summary_disease_evals,df[add_cols],left_index=True,right_index=True)[[*add_cols,*cols]]

summary_gene_evals = pd.DataFrame()
for k in k_list:
    summary_gene_evals[f"hits_{k}"] = total_gene_evals[f"{k}_new"] + total_gene_evals[f"{k}_seen"]

cols = summary_gene_evals.columns
add_cols = ["node_name","degree_gda","degree_pp"]
summary_gene_evals = pd.merge(summary_gene_evals,df[add_cols],left_index=True,right_index=True)[[*add_cols,*cols]]

total_disease_evals = total_disease_evals.merge(summary_disease_evals["degree_gda"],left_index=True,right_index=True)
total_gene_evals = total_gene_evals.merge(summary_gene_evals["degree_gda"],left_index=True,right_index=True)

In [292]:
def save_hits_df(model_name,desc,disease_total,disease_summary,gene_total,gene_summary,reports_folder=reports_folder):
    disease_total.to_csv(reports_folder+model_name+"_total_disease.csv")
    disease_summary.to_csv(reports_folder+model_name+"_summary_disease.csv")
    gene_summary.to_csv(reports_folder+model_name+"_summary_gene.csv")
    gene_total.to_csv(reports_folder+model_name+"_total_gene.csv")

    with open(reports_folder+model_name+"_desc.txt", "w") as f:
        f.write(desc)

def load_hits_df(model_name,reports_folder=reports_folder):
    fnames = ["_total_disease.csv","_summary_disease.csv","_summary_gene.csv","_total_gene.csv"]
    dfs = []
    for fname in fnames:
        dfs.append(pd.read_csv(reports_folder+model_name+fname,index_col=0))
    return dfs

def group_by_range(data_df,group_column,ranges,inplace=True):
    if not inplace:
        df = data_df.copy()
        bins = np.digitize(df[group_column].values, ranges)
        df["bins"] = bins
        return df
    else:
        df = data_df
        bins = np.digitize(df[group_column].values, ranges)
        df["bins"] = bins

def plot_box(data_df,value_cols,title,range_text,y_top):
    melted_df = data_df[["bins",*value_cols]].melt("bins").rename(columns={"value":"hits"})
    melted_df["Nivel de Evidencia"] = melted_df.bins.apply(lambda x: range_text[x])
    fig = px.box(melted_df.sort_values(by="bins"),y="hits",x="Nivel de Evidencia",color="variable",title=title,width=900,height=450,labels={"hits":"Hits"})
    fig.update_yaxes(range=[-0.5, y_top])
    fig.show()

In [289]:
model_name = f"step_{step-1}_0.pth"
desc = "El mejor modelo de la caminata con 0.88 auc"

save_hits_df(model_name,desc,total_disease_evals,summary_disease_evals,total_gene_evals,summary_gene_evals)

In [294]:
# model_name = "sage_ones_no_sampling"
# model_name = "sage_ones_first_negatives_exp_04_07_23__12_07"
model_name = f"step_{step-1}_0.pth"
hits_df = load_hits_df(model_name)

disease_ranges = np.array([10,50,100,hits_df[0].degree_gda.max()+1]).astype(int)
gene_ranges = np.array([5,20,50,100,hits_df[3].degree_gda.max()+1]).astype(int)
range_text = ["< 10","10-50","50-100","100 +"]

for i,data in enumerate(hits_df):
    if i<2:
        data = group_by_range(data[data.degree_gda != 0],"degree_gda",disease_ranges,inplace=False)
        hits_df[i] = data
    else:
        data = group_by_range(data[data.degree_gda != 0],"degree_gda",gene_ranges,inplace=False)
        hits_df[i] = data       

value_pairs = [["5_seen","5_new"],["10_seen","10_new"],["50_seen","50_new"]]
for pair in value_pairs:
    plot_box(hits_df[0],pair,"Evaluación Enfermedades",range_text,10)

In [63]:
# model_name = "sage_ones_no_sampling"
model_name = "sage_ones_first_negatives_exp_04_07_23__12_07"
hits_df = load_hits_df(model_name)

disease_ranges = np.array([10,50,100,hits_df[0].degree_gda.max()+1]).astype(int)
gene_ranges = np.array([5,20,50,100,hits_df[3].degree_gda.max()+1]).astype(int)
range_text = ["< 10","10-50","50-100","100 +"]

for i,data in enumerate(hits_df):
    if i<2:
        data = group_by_range(data[data.degree_gda != 0],"degree_gda",disease_ranges,inplace=False)
        hits_df[i] = data
    else:
        data = group_by_range(data[data.degree_gda != 0],"degree_gda",gene_ranges,inplace=False)
        hits_df[i] = data       

value_pairs = [["5_seen","5_new"],["10_seen","10_new"],["50_seen","50_new"]]
for pair in value_pairs:
    plot_box(hits_df[0],pair,"Evaluación Enfermedades",range_text)