In [1]:
import torch
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib_venn import venn3

import sys
sys.path.append("..")
from models import training_utils, prediction_utils

In [2]:
# Folder containing the processed dataset; handed to training_utils.load_data.
# NOTE(review): absolute, machine-specific path — will not resolve on another host.
data_folder = "/biodata/nyanovsky/datasets/dti/processed/"
# Edge-type key used to index each split (e.g. datasets[1][pred_edge_type]).
pred_edge_type = ("gene","chg","chem")

In [3]:
# Load the dataset splits and the node mapping. `datasets` is iterated and
# unpacked further down as three splits (train, val, test).
# load_test=True presumably asks the loader to include the test split — confirm in training_utils.
datasets, node_map = training_utils.load_data(data_folder ,load_test=True)

In [4]:
def get_edge_sets(heterodata,node_map=node_map):
    """
    Build the mapped edge table of one split and split its edges by role.

    Args:
        heterodata: one dataset split, forwarded to prediction_utils.MappedDataset.
        node_map: node mapping forwarded to prediction_utils.MappedDataset.
            The default is the notebook-level `node_map` as it was bound when
            this cell was executed.

    Returns:
        A 3-tuple, in this order:
        1. DataFrame of edges with their nodes' original df indexes and
           heterodata node indexes, their type (message passing / supervision),
           the label if supervision, plus an added "edges" column holding
           (gene_source, chem_target) tuples.
        2. set of message-passing edges, as tuples of original node indexes.
        3. set of supervision edges, as tuples of original node indexes.
    """
    edge_table = prediction_utils.MappedDataset(
        heterodata, node_map, ("gene", "chg", "chem")
    ).dataframe

    # One hashable (gene, chem) tuple per row so edges can be stored in sets.
    endpoint_pairs = edge_table[["gene_source", "chem_target"]].values.tolist()
    edge_table["edges"] = [tuple(pair) for pair in endpoint_pairs]

    def _edges_of(kind):
        # Distinct edge tuples whose edge_type equals `kind`.
        return set(edge_table.loc[edge_table.edge_type == kind, "edges"].values)

    return edge_table, _edges_of("message_passing"), _edges_of("supervision")

In [33]:
# Display the mapped edge table of the first split.
# NOTE(review): this cell (In [33]) sits above the cell that defines `mapped_dfs` (In [5]);
# it fails on a fresh Restart & Run All unless moved below that cell.
mapped_dfs[0]

Unnamed: 0,gene_source,chem_target,torch_gene_index_source,torch_chem_index_target,label,edge_type,edges
0,7843,3796,3115,1700,1.0,supervision,"(7843, 3796)"
1,217,5546,138,2996,1.0,supervision,"(217, 5546)"
2,798,2622,419,1115,1.0,supervision,"(798, 2622)"
3,4393,4392,2250,2142,1.0,supervision,"(4393, 4392)"
4,11484,7868,5673,4750,1.0,supervision,"(11484, 7868)"
...,...,...,...,...,...,...,...
24566,11356,11658,5545,5853,,message_passing,"(11356, 11658)"
24567,11359,11658,5548,5853,,message_passing,"(11359, 11658)"
24568,11360,11658,5549,5853,,message_passing,"(11360, 11658)"
24569,11361,11658,5550,5853,,message_passing,"(11361, 11658)"


In [5]:
# Per-split accumulators, filled in the same order as `datasets` is iterated.
mapped_dfs, propagation_sets, supervision_sets = [], [], []

for split in datasets:
    # get_edge_sets returns (edge table, message-passing edges, supervision edges).
    mapped_df, propagation, supervision = get_edge_sets(split)
    mapped_dfs.append(mapped_df)
    propagation_sets.append(propagation)
    supervision_sets.append(supervision)

In [6]:
[df.shape for df in mapped_dfs] # train, val and test edges: val and test also contain negative edges.

[(30713, 7), (38391, 7), (42230, 7)]

In [13]:
# Unpack the per-split supervision edge sets (split order: train, val, test).
train_sup_set, val_sup_set, test_sup_set = supervision_sets

In [18]:
# Supervision edges shared between train and each held-out split.
print(train_sup_set.intersection(val_sup_set))
print(train_sup_set.intersection(test_sup_set))
# no leaks between train and val/test

set()
set()


In [71]:
# Number of supervision labels in the val split minus the number of distinct supervision edges;
# a non-zero value means some supervision edge appears more than once.
datasets[1][pred_edge_type].edge_label.shape[0]-len(supervision_sets[1])
# why are there repeated supervision edges?

4

In [72]:
# Same count for the test split: supervision labels minus distinct supervision edges.
datasets[2][pred_edge_type].edge_label.shape[0]-len(supervision_sets[2])
# the same happens in test

6

In [75]:
# look at the duplicated supervision edges in val
val_sup_df = mapped_dfs[1].loc[mapped_dfs[1]["edge_type"] == "supervision"]
val_sup_df.loc[val_sup_df["edges"].duplicated()].sort_values("edges")

Unnamed: 0,gene_source,chem_target,torch_gene_index_source,torch_chem_index_target,label,edge_type,edges
6369,30,2980,23,1257,0.0,supervision,"(30, 2980)"
6829,132,6598,85,3782,0.0,supervision,"(132, 6598)"
7434,246,9309,156,5809,0.0,supervision,"(246, 9309)"
6602,3403,9309,1966,5809,0.0,supervision,"(3403, 9309)"


In [74]:
# look at the duplicated supervision edges in test
test_sup_df = mapped_dfs[2].loc[mapped_dfs[2]["edge_type"] == "supervision"]
test_sup_df.loc[test_sup_df["edges"].duplicated()].sort_values("edges")

Unnamed: 0,gene_source,chem_target,torch_gene_index_source,torch_chem_index_target,label,edge_type,edges
7086,140,2416,89,983,0.0,supervision,"(140, 2416)"
4979,464,9309,276,5809,0.0,supervision,"(464, 9309)"
6570,732,9309,393,5809,0.0,supervision,"(732, 9309)"
5991,1677,9309,985,5809,0.0,supervision,"(1677, 9309)"
6684,3628,3767,2047,1679,0.0,supervision,"(3628, 3767)"
6907,3773,1930,2088,835,0.0,supervision,"(3773, 1930)"


In [50]:
val_test_leak = val_sup_set & test_sup_set # 5 edges in total OUT OF ~7670

In [53]:
mapped_dfs[1].set_index("edges").loc[list(val_test_leak)] # all of them are negatives

Unnamed: 0_level_0,gene_source,chem_target,torch_gene_index_source,torch_chem_index_target,label,edge_type
edges,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
"(3, 4608)",3,4608,1,2309,0.0,supervision
"(6180, 1930)",6180,1930,2719,835,0.0,supervision
"(84, 9309)",84,9309,60,5809,0.0,supervision
"(8982, 9309)",8982,9309,3356,5809,0.0,supervision
"(1786, 3999)",1786,3999,1013,1849,0.0,supervision
