# Setup

In [2]:
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import copy
import plotly.express as px
import torch
from torch_geometric import seed_everything

import sys
sys.path.append("../../..")
from src.models.training_utils import NegativeSampler, load_data

# Relative path to the processed graph data that the cells below read from.
data_folder = "../../../data/processed/graph_data_nohubs/merged_types/"
# NOTE(review): this path points at `seed_0` while `seed` below is 4 — confirm the mismatch is intended.
experiments_folder = "../../../data/experiments/design_space_merged_experiment/seed_0/"

# Seed handed to seed_everything; it also selects the `split_dataset/seed_{seed}/` folder when loading data.
seed = 4
seed_everything(seed)

# Load data

In [3]:
# Loading of the raw CSV tables is currently disabled.
# node_csv = pd.read_csv(data_folder+"merged_nodes.csv", index_col="node_index")
# node_info = pd.read_csv(data_folder+"merged_node_info.csv",index_col=0)
# edge_data = pd.read_csv(data_folder+"merged_edges.csv")

# All inputs below live in the split folder that matches the configured seed.
split_folder = data_folder + f"split_dataset/seed_{seed}/"

# load_data returns a pair: the dataset splits and `node_map`.
datasets, node_map = load_data(split_folder)
(train_data, val_data) = datasets

# NOTE(review): torch.load unpickles the file — only point it at trusted files.
full_dataset = torch.load(split_folder + "full_dataset.pt")
# Per-node table; the first CSV column is used as the index.
tensor_df = pd.read_csv(split_folder + "tensor_df.csv", index_col=0)

In [4]:
# Edge type used as the prediction target.
pred_edge_type = ("gene_protein", "gda", "disease")


def _gda_degrees(node_type):
    """Return the `degree_gda` values (as an array) of the tensor_df rows with this node_type."""
    return tensor_df.loc[tensor_df.node_type == node_type, "degree_gda"].values


# Degree arrays for the source and destination node types of the target edge type.
src_degrees = _gda_degrees("gene_protein")
dst_degrees = _gda_degrees("disease")

negative_sampler = NegativeSampler(full_dataset, pred_edge_type, src_degrees, dst_degrees)

# Neighborhood sampling

Tengo que ver si me conviene aplicar primero el neighborhood loader o el negative sampler.
Si primero sampleo negativos y después hago el neighborhood loading (NL), se me va a romper la distribución deg 0.75 que armé.
En cambio, si primero sampleo vecindarios y después genero negativos a partir de esa muestra positiva, voy a conservar la distribución.

Después está el HGsampler (o algo así; posiblemente el `HGTLoader` de PyG), que no entendí del todo bien qué hace, pero sonaba bien.

Tomé como `num_neighbors` el tercer cuartil de la distribución de grado; para cada edge type consideré los enlaces correspondientes.

Dado (src, edge_type, dst):
Para cada tipo de nodo src y dst, tomo el tercer cuartil de la distribución de grado correspondiente a edge_type (considerando solo los nodos con grado distinto de cero).

In [5]:
# Relation name (middle element of each edge-type triple) -> tensor_df column
# holding the node degree used for that relation.
deg_type = {
    "pathway_protein": "degree_pp",
    "disease_disease": "degree_dd",
    "gda": "degree_gda",
    "ppi": "degree_pp",
    "form_complex": "degree_pp",
}


def _third_quartile_degree(node_type, deg_column):
    """Third quartile of `deg_column` over the `node_type` nodes with non-zero degree.

    Args:
        node_type: value matched against `tensor_df.node_type`.
        deg_column: name of the degree column of `tensor_df` to summarise.

    Returns:
        The 75th percentile, truncated to a plain Python int.

    Raises:
        ValueError: if no node of `node_type` has a non-zero `deg_column`
            (previously the NaN quantile was silently cast to a meaningless int).
    """
    selected = (tensor_df.node_type == node_type) & (tensor_df[deg_column] != 0)
    degrees = tensor_df.loc[selected, deg_column]
    if degrees.empty:
        raise ValueError(
            f"No nodes of type {node_type!r} with non-zero {deg_column!r}; "
            "cannot derive a neighbor count."
        )
    # Same value as describe()["75%"] without computing the other summary statistics.
    return int(degrees.quantile(0.75))


# NOTE(review): the loader reads each list as the number of neighbors to sample
# per hop (iteration), not as [source count, destination count]. As written, the
# first entry is used for hop 1 and the second for hop 2 — confirm this is intended.
num_neighbors_per_type = {}
for edge_type in train_data.edge_types:
    src_type, relation, dst_type = edge_type
    deg_column = deg_type[relation]
    num_neighbors_per_type[edge_type] = [
        _third_quartile_degree(src_type, deg_column),
        _third_quartile_degree(dst_type, deg_column),
    ]

In [6]:
# Show the neighbor counts assigned to each edge type.
num_neighbors_per_type

{('disease', 'gda', 'gene_protein'): [4, 9],
 ('pathway', 'pathway_protein', 'gene_protein'): [24, 19],
 ('gene_protein', 'ppi', 'gene_protein'): [19, 19],
 ('gene_protein', 'gda', 'disease'): [9, 4],
 ('gene_protein', 'pathway_protein', 'pathway'): [19, 24],
 ('disease', 'disease_disease', 'disease'): [3, 3],
 ('gene_protein', 'form_complex', 'gene_protein'): [19, 19]}

In [7]:
from torch_geometric.loader import LinkNeighborLoader

# Supervision edges and their labels for the target relation, taken from the training split.
target_store = train_data[pred_edge_type]

# Sampling / batching options, forwarded unchanged to the loader.
loader_options = dict(
    replace=True,
    batch_size=2048,
    shuffle=True,
    num_workers=8,
    drop_last=True,
)

train_loader = LinkNeighborLoader(
    data=train_data,
    num_neighbors=num_neighbors_per_type,
    edge_label_index=(pred_edge_type, target_store["edge_label_index"]),
    edge_label=target_store["edge_label"],
    **loader_options,
)

In [8]:
# Draw a single mini-batch from the training loader so it can be inspected.
sampled_data = next(iter(train_loader))

Habría que ver qué es `e_id` y todo eso ahora, y cómo se lo paso al modelo.

In [9]:
# Show the structure of the sampled mini-batch.
sampled_data

HeteroData(
  [1mgene_protein[0m={
    num_nodes=13288,
    n_id=[13288]
  },
  [1mdisease[0m={
    num_nodes=6306,
    n_id=[6306]
  },
  [1mpathway[0m={
    num_nodes=1734,
    n_id=[1734]
  },
  [1m(disease, gda, gene_protein)[0m={
    edge_index=[2, 38072],
    adj_t=[13288, 6306, nnz=38072],
    e_id=[38072]
  },
  [1m(pathway, pathway_protein, gene_protein)[0m={
    edge_index=[2, 86524],
    adj_t=[13288, 1734, nnz=86524],
    e_id=[86524]
  },
  [1m(gene_protein, ppi, gene_protein)[0m={
    edge_index=[2, 120004],
    edge_label=[17608],
    edge_label_index=[2, 17608],
    adj_t=[13288, 13288, nnz=120004],
    e_id=[120004]
  },
  [1m(gene_protein, gda, disease)[0m={
    edge_index=[2, 14374],
    edge_label=[2048],
    edge_label_index=[2, 2048],
    adj_t=[6306, 13288, nnz=14374],
    e_id=[14374],
    input_id=[2048]
  },
  [1m(gene_protein, pathway_protein, pathway)[0m={
    edge_index=[2, 29784],
    edge_label=[6823],
    edge_label_index=[2, 6823],
    a