# Setup

In [2]:
import pandas as pd
import numpy as np
import torch
from sklearn.metrics import average_precision_score

import sys
sys.path.append("../../..")
from src.models import training_utils, sage_ones

data_folder = "../../../data/processed/graph_data_nohubs/merged_types/"
reports_folder = "../../../reports/explore_predictions/"

# Load a pretrained model

In [3]:
prediction_edge_type = ("gene_protein","gda","disease")
datasets, node_map = training_utils.load_data(data_folder+"split_dataset/seed_4/")
train_data, val_data = datasets

feature_type = "ones"
feature_dim = 10
train_data = training_utils.initialize_features(train_data, feature_type, feature_dim)
val_data = training_utils.initialize_features(val_data, feature_type, feature_dim)

weights_path = "../../../data/experiments/merged_types_experiment/sage_ones_first_negatives_exp_04_07_23__12_07.pth"
weights = torch.load(weights_path)
model = sage_ones.Model(train_data.metadata(),[("gene_protein","gda","disease")])
model.load_state_dict(weights)

node_df = pd.read_csv(data_folder+"split_dataset/seed_4/tensor_df.csv",index_col=0).set_index("node_index",drop=True)

# Get encodings

In [6]:
%%timeit
encodings_dict = training_utils.get_encodings(model,val_data)

32.7 ms ± 691 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [8]:
encodings_dict = training_utils.get_encodings(model,val_data)

# Prioritization

In [62]:
class Predictor():
    def __init__(self,node_df, encodings_dict):
        assert node_df.index.name == "node_index", f"df index must be node_index, not {node_df.index.name}."

        self.df = node_df
        self.encodings = encodings_dict
    
    def inner_product_decoder(self,x_source,x_target,apply_sigmoid=True):
        pred = (x_source * x_target).sum(dim=1)

        if apply_sigmoid:
            pred = torch.sigmoid(pred)

        return pred
    
    def prioritize_one_vs_all(self,node_index):
        source_type = self.df.loc[node_index,"node_type"]
        tensor_index = self.df.loc[node_index,"tensor_index"]

        if source_type == "disease":
            target_type = "gene_protein"

        elif source_type == "gene_protein":
            target_type = "disease"

        source_vector = self.encodings[source_type][tensor_index]
        target_matrix = self.encodings[target_type]

        predicted_edges = self.inner_product_decoder(source_vector,target_matrix)
        ranked_scores, ranked_indices = torch.sort(predicted_edges,descending=True)
        results = pd.DataFrame({"score":ranked_scores.cpu().numpy(),"tensor_index":ranked_indices.cpu().numpy()})
        results.score = results.score.round(3)

        index_map = self.df.loc[self.df.node_type == target_type,["tensor_index","node_name"]].reset_index()
        ranked_predictions = pd.merge(results,index_map,on="tensor_index")
        ranked_predictions.index.name = "rank"

        return ranked_predictions
    
    def predict_supervision_edges(self,data, edge_type, return_dataframe=True):
        """If return_dataframe_==True, returns dataframe with edges, prediction scores and labels. Else, returns predicted scores tensor"""
        src_type, trg_type = edge_type[0],edge_type[2]
        x_source = self.encodings[src_type]
        x_target = self.encodings[trg_type]

        edge_label_index = data.edge_label_index_dict[edge_type]
        source_index, target_index = edge_label_index[0], edge_label_index[1]

        emb_nodes_source = x_source[source_index]
        emb_nodes_target = x_target[target_index]

        pred = self.inner_product_decoder(emb_nodes_source, emb_nodes_target)
        if return_dataframe:
            labels = data.edge_label_dict[edge_type].numpy()
            df = pd.DataFrame({"torch_gene_protein_index":source_index, "torch_disease_index":target_index, "score":pred, "label":labels})
            return df
        else: 
            return pred

# Map datasets to index

In [27]:
class MappedDataset():
    def __init__(self,heterodata,node_map,prediction_edge_type):
        self.prediction_edge_type = prediction_edge_type
        self.node_map = node_map
        self.edge_dict = self._reverse_map_heterodata(heterodata)
        self.dataframe = self._edge_dict_to_dataframe()
        
    def _reverse_map_tensor(self,tensor,edge_type):
        """Maps edge dictionary from pyg Heterodata back into the original node indexes from the dataframe"""
        #Tensor to lists [sources], [targets]
        sources = tensor[0,:].tolist()
        targets = tensor[1,:].tolist()

        #Map edge list to node indexes
        src_type, dst_type = edge_type[0], edge_type[2]
        src_map,dst_map = self.node_map[src_type], self.node_map[dst_type]

        mapped_src = [src_map[n] for n in sources]
        mapped_trg = [dst_map[n] for n in targets]

        return {src_type:mapped_src, dst_type:mapped_trg, f"torch_{src_type}_index":sources, f"torch_{dst_type}_index":targets}

    def _reverse_map_heterodata(self,data):
        """Maps full edge data from pyg Heterodata back into the original node indexes from the dataframe"""
        edge_dict = {}
        for edge_type in data.edge_types:
            type_dict = {}
            edge_tensor = data[edge_type]["edge_index"]
            mapped_edge_list = self._reverse_map_tensor(edge_tensor,edge_type)

            type_dict["message_passing_edges"] = mapped_edge_list

            if "edge_label_index" in data[edge_type].keys():
                labeled_edges_tensor = data[edge_type]["edge_label_index"]
                # labeled_edges_list = tensor_to_edgelist(labeled_edges_tensor)
                mapped_labeled_edges_list = self._reverse_map_tensor(labeled_edges_tensor,edge_type)
                edge_labels = data[edge_type]["edge_label"].tolist()

                type_dict["supervision_edges"] = mapped_labeled_edges_list
                type_dict["supervision_labels"] = edge_labels
    
            edge_dict[edge_type] = type_dict
        
        return edge_dict
    
    def _edge_dict_to_dataframe(self):
        edges_df = []
        e_dict = self.edge_dict[self.prediction_edge_type]
        supervision_edges = pd.DataFrame(e_dict["supervision_edges"])

        labeled_edges = pd.concat([supervision_edges,pd.DataFrame(e_dict["supervision_labels"])],axis=1).rename(columns={0:"label"})
        msg_passing_edges = pd.DataFrame(e_dict["message_passing_edges"])

        msg_passing_edges["edge_type"] = "message_passing"
        labeled_edges["edge_type"] = "supervision"


        edges_df.append(labeled_edges)
        edges_df.append(msg_passing_edges)
        total_df = pd.concat(edges_df,axis=0)
        return total_df

mapped_val = MappedDataset(val_data,node_map,prediction_edge_type)
mapped_train = MappedDataset(train_data,node_map,prediction_edge_type)

In [63]:
predictor = Predictor(node_df,encodings_dict)
pred = predictor.predict_supervision_edges(val_data, prediction_edge_type,return_dataframe=True)
val_supervision_edges = mapped_val.dataframe[mapped_val.dataframe.edge_type == "supervision"]
pred = pd.concat([pred,val_supervision_edges[["gene_protein","disease"]]], axis=1)
pred


Unnamed: 0,torch_gene_protein_index,torch_disease_index,score,label,gene_protein,disease
0,498,1140,0.875595,1.0,3810,24466
1,2054,3088,0.160896,1.0,2884,31828
2,10133,6323,0.737463,1.0,2973,26568
3,3530,10617,0.258367,1.0,8162,26932
4,8048,6729,0.441622,1.0,12683,30601
...,...,...,...,...,...,...
16797,16657,1302,0.632007,0.0,13107,26053
16798,12001,1004,0.499077,0.0,9685,20075
16799,4783,63,0.839635,0.0,12695,22363
16800,908,1453,0.032092,0.0,171,29805


In [189]:
ap_df[ap_df.ap_at_5 < 0.5]

Unnamed: 0,disease_index,ap_at_10,ap_at_5,k
37,24180,0.477778,0.416667,29
127,32089,0.309524,0.333333,7
151,21916,0.566156,0.450000,24
152,18649,0.594048,0.477778,27
192,20160,0.416667,0.416667,4
...,...,...,...,...
5149,27120,-0.000000,-0.000000,1
5150,20166,-0.000000,-0.000000,1
5151,28984,-0.000000,-0.000000,1
5152,33298,-0.000000,-0.000000,1


In [188]:
idx = 24180
pred[pred.disease == idx].sort_values(by="score", ascending=False)[["label","score"]]

Unnamed: 0,label,score
12204,0.0,0.862819
11590,0.0,0.840733
5900,1.0,0.834235
2905,1.0,0.822737
13573,0.0,0.809931
4318,1.0,0.784502
8483,0.0,0.780439
3288,1.0,0.769507
3027,1.0,0.728701
16321,0.0,0.710944


In [181]:
disease_index = []
ap_at_10 = []
ap_at_5 = []
k = []
for disease in pred.disease.unique():
    labels = pred[pred.disease == disease].sort_values(by="score", ascending=False)["label"].values
    scores = pred[pred.disease == disease].sort_values(by="score", ascending=False)["score"].values
    k.append(len(labels))

    if len(labels) >= 10:
        labels = labels[:10]
        scores = scores[:10]
        ap_at_10.append(average_precision_score(labels,scores))  
    else:
        ap_at_10.append(average_precision_score(labels,scores))
    
    if len(labels) >= 5:
        labels = labels[:5]
        scores = scores[:5]
        ap_at_5.append(average_precision_score(labels,scores))
    else:
        ap_at_5.append(average_precision_score(labels,scores))

    disease_index.append(disease)

ap_df = pd.DataFrame({"disease_index":disease_index,"ap_at_10":ap_at_10,"ap_at_5":ap_at_5, "k":k})
ap_df


No positive class found in y_true, recall is set to one for all thresholds.


No positive class found in y_true, recall is set to one for all thresholds.


No positive class found in y_true, recall is set to one for all thresholds.


No positive class found in y_true, recall is set to one for all thresholds.


No positive class found in y_true, recall is set to one for all thresholds.


No positive class found in y_true, recall is set to one for all thresholds.


No positive class found in y_true, recall is set to one for all thresholds.


No positive class found in y_true, recall is set to one for all thresholds.


No positive class found in y_true, recall is set to one for all thresholds.


No positive class found in y_true, recall is set to one for all thresholds.


No positive class found in y_true, recall is set to one for all thresholds.


No positive class found in y_true, recall is set to one for all thresholds.


No positive class found in y_true, recall is set to one for all

Unnamed: 0,disease_index,ap_at_10,ap_at_5,k
0,24466,1.0,1.0,19
1,31828,1.0,1.0,1
2,26568,1.0,1.0,2
3,26932,1.0,1.0,1
4,30601,1.0,1.0,1
...,...,...,...,...
5149,27120,-0.0,-0.0,1
5150,20166,-0.0,-0.0,1
5151,28984,-0.0,-0.0,1
5152,33298,-0.0,-0.0,1


In [190]:
ap_df.ap_at_5.mean().round(2)

0.51

In [183]:
ap_df[ap_df.k < 5]

Unnamed: 0,disease_index,ap_at_10,ap_at_5,k
1,31828,1.0,1.0,1
2,26568,1.0,1.0,2
3,26932,1.0,1.0,1
4,30601,1.0,1.0,1
5,22973,1.0,1.0,2
...,...,...,...,...
5149,27120,-0.0,-0.0,1
5150,20166,-0.0,-0.0,1
5151,28984,-0.0,-0.0,1
5152,33298,-0.0,-0.0,1


In [191]:
ap_df[ap_df.k >= 5].ap_at_5.mean().round(2)

0.83

In [186]:
ap_df[ap_df.k >= 10].ap_at_10.mean().round(2)

0.81

In [193]:
ap_df[ap_df.k < 5].ap_at_10.mean().round(2)

0.45

In [184]:
import plotly.express as px

aver = pd.merge(ap_df,node_df[node_df.node_type == "disease"], left_on="disease_index",right_index=True, how="left")
fig = px.scatter(aver,x="degree_gda",y="ap_at_10")
fig.show()