## Purpose 
I want to sanity check that the forward and reverse edges produced by [RandomLinkSplit](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.transforms.RandomLinkSplit.html) are in line with what is discussed in [this post](https://github.com/pyg-team/pytorch_geometric/issues/3668).

In [18]:

import pandas as pd
import torch 
from torch_geometric.data import HeteroData
import torch_geometric.transforms as T
from torch_geometric.loader import LinkNeighborLoader
from torch_geometric.nn import SAGEConv, to_hetero
import torch.nn.functional as F
from torch_geometric.typing import Tensor
import tqdm
from sklearn.metrics import roc_auc_score, roc_curve, RocCurveDisplay, average_precision_score, precision_recall_curve, PrecisionRecallDisplay
import os
import argparse
from src.gnn_model import *
from src.utils import *
import numpy as np

## load data / make data splits

In [35]:
## first want to load the graph 
# data = torch.load('kg/msk_impact_unknown-free_gene-oncogenic-no-reverse_edges.pt')
data = torch.load('kg/msk_impact_unknown-free_gene-oncogenic.pt')
## want to drop the reverse edges
# Check if it's indeed a HeteroData object
assert isinstance(data, HeteroData), "Loaded data is not a HeteroData object."
## augment data for link prediction
data = augment_graph(data = data, use_dummy_features = False, num_dummy_features = 100, random_dummy_features = False)
## set up train test split transform
transform = T.RandomLinkSplit(
    num_val=0.1,
    num_test=0.1,
    disjoint_train_ratio=0,
    ## this 

    ## before total = 9711, train = 7769, val = 7769, test = 8740
    add_negative_train_samples=False,
    # edge_types=forward_edges_list,
    # rev_edge_types=reverse_edges_list,
    is_undirected = True,
    edge_types=('patient', 'treated_with', 'drug'), 
    # rev_edge_types = ('drug', 'treating', 'patient')
    rev_edge_types = ('drug', 'treating', 'patient')
)
train_data, val_data, test_data = transform(data)

## general check 
Let's make sure the overall numbers of forward and reverse edges are equal.

In [36]:
num_total_rev_edges = data[('drug', 'treating', 'patient')]['edge_index'].shape[1]
num_total_forward_edges = data[('patient', 'treated_with', 'drug')]['edge_index'].shape[1]
print(num_total_rev_edges == num_total_forward_edges)

True


Let's make sure that swapping the rows of the reverse edge_index makes it equal to the forward edge_index.

In [37]:
total_rev_edges = data[('drug', 'treating', 'patient')]['edge_index']
swapped_total_rev_edges =total_rev_edges[[1,0]]
total_forward_edges = data[('patient', 'treated_with', 'drug')]['edge_index']
print(torch.all(swapped_total_rev_edges.eq(total_forward_edges)))

tensor(True)
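
The count check and the row-swap check above can be folded into one helper. The following is only a minimal sketch on a made-up toy graph; `is_reverse_of` is a hypothetical name introduced here, not a PyG function.

```python
import torch

def is_reverse_of(forward_edge_index: torch.Tensor, reverse_edge_index: torch.Tensor) -> bool:
    """Return True if reverse_edge_index is forward_edge_index with its two rows swapped."""
    if forward_edge_index.shape != reverse_edge_index.shape:
        return False
    # flip(0) reverses the row order, i.e. swaps the source and destination rows
    return torch.equal(reverse_edge_index.flip(0), forward_edge_index)

# Toy example (made-up indices): three patient -> drug edges and their reverses
toy_forward = torch.tensor([[0, 1, 2], [5, 5, 7]])
toy_reverse = torch.tensor([[5, 5, 7], [0, 1, 2]])
toy_check = is_reverse_of(toy_forward, toy_reverse)
```

Because `torch.equal` also compares shapes, a mismatch in edge counts would be reported as `False` rather than raising a broadcasting error.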


Let's check whether the indices of all edges are unique by comparing the total number of forward edges with the number of unique columns.

In [38]:
total_forward_edges = data[('patient', 'treated_with', 'drug')]['edge_index']
unique_forward_edges = np.unique(total_forward_edges,axis=1)
print(total_forward_edges.shape[1], unique_forward_edges.shape[1])

9711 3689
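
The output above suggests that many (patient, drug) pairs occur more than once. To see which pairs repeat, one option is `torch.unique` with `return_counts=True`. This is a sketch on a made-up toy tensor, not on the real graph.

```python
import torch

# Toy edge_index (made-up indices) in which the edge (0, 5) appears three times
toy_edge_index = torch.tensor([[0, 0, 1, 0, 2],
                               [5, 5, 6, 5, 7]])

# dim=1 treats each column (one edge) as a unit when looking for duplicates
unique_edges, counts = torch.unique(toy_edge_index, dim=1, return_counts=True)

num_total = toy_edge_index.shape[1]
num_unique = unique_edges.shape[1]
# Columns that occur more than once in the original tensor
repeated_edges = unique_edges[:, counts > 1]
```

On the real data the same call could be applied to `total_forward_edges` to list the repeated pairs and how often each occurs.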


## Verify the training set 
First, as a sanity check, let's confirm that edge_index and edge_label_index are the same for the forward edges.

In [39]:
train_data_edge_index = train_data[('patient', 'treated_with', 'drug')]['edge_index']
train_data_edge_label_index = train_data[('patient', 'treated_with', 'drug')]['edge_label_index']
print(torch.all(train_data_edge_index.eq(train_data_edge_label_index)))

tensor(True)


Now let's make sure that the number of training edges is equal to 80% of the total number of edges (since no negative samples have been added at this point).

In [40]:
num_train_edges = train_data[('patient', 'treated_with', 'drug')]['edge_label_index'].shape[1]
num_total_forward_edges = data[('patient', 'treated_with', 'drug')]['edge_index'].shape[1]
print(num_train_edges)
print(num_total_forward_edges)
print(num_train_edges == round(num_total_forward_edges*0.8))

7769
9711
True
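
The counts above are consistent with a sizing rule in which the fractional validation and test sizes are truncated to integers and training receives whatever is left. That rule is an assumption here, chosen because it reproduces the printed numbers. `expected_split_sizes` is a hypothetical helper, not a PyG function.

```python
def expected_split_sizes(num_edges: int, num_val: float, num_test: float):
    """Assumed sizing rule: truncate the val/test sizes, training gets the remainder."""
    n_val = int(num_val * num_edges)
    n_test = int(num_test * num_edges)
    n_train = num_edges - n_val - n_test
    return n_train, n_val, n_test

n_train, n_val, n_test = expected_split_sizes(9711, 0.1, 0.1)
print(n_train, n_val, n_test)  # 7769 971 971

# The remainder and round(0.8 * n) need not agree for every edge count
other_train, _, _ = expected_split_sizes(9715, 0.1, 0.1)
print(other_train, round(9715 * 0.8))  # 7773 7772
```

Under this assumption, comparing against `num_total - n_val - n_test` would be a slightly more robust check than `round(num_total * 0.8)`.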


Now confirm that swapping the first and second rows of the training data's reverse edge_index gives the training data's forward edge_index.

In [41]:
train_data_edge_index = train_data[('patient', 'treated_with', 'drug')]['edge_index']
train_data_rev_edge_index = train_data[('drug', 'treating', 'patient')]['edge_index']
row_swapped_train_data_rev_edge_index = train_data_rev_edge_index[[1, 0]]
torch.all(row_swapped_train_data_rev_edge_index.eq(train_data_edge_index))

tensor(True)

This shows that the reverse edges in the training set are exactly the row-swapped versions of the edges used as labels in the training data, which is what we would expect.

## Verify the validation data 

The validation set's edge_index should be the same 80% of the edges that the training data uses. We can confirm this by checking that the training and validation edge_index tensors are equal in both the forward and reverse directions.

In [42]:
train_data_edge_index = train_data[('patient', 'treated_with', 'drug')]['edge_index']
val_data_edge_index = val_data[('patient', 'treated_with', 'drug')]['edge_index']
rev_train_data_edge_index = train_data[('drug', 'treating', 'patient')]['edge_index']
rev_val_data_edge_index = val_data[('drug', 'treating', 'patient')]['edge_index']
print(torch.all(val_data_edge_index.eq(train_data_edge_index)))
torch.all(rev_val_data_edge_index.eq(rev_train_data_edge_index))

tensor(True)


tensor(True)

Now we want to confirm that the number of positive validation labels is 10% of the overall data, and that the validation edge_index plus the positive validation labels adds up to 90% of the overall data.

In [44]:
num_val_data_edge_index = val_data[('patient', 'treated_with', 'drug')]['edge_index'].shape[1]
num_val_data_edge_label_index = val_data[('patient', 'treated_with', 'drug')]['edge_label'][val_data[('patient', 'treated_with', 'drug')]['edge_label']==1].shape[0]
num_total_forward_edges = data[('patient', 'treated_with', 'drug')]['edge_index'].shape[1]
print(num_total_forward_edges)
print(num_total_forward_edges*0.1)
print(num_val_data_edge_label_index)
print(num_val_data_edge_label_index == round(num_total_forward_edges*0.1))
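print(num_val_data_edge_index + num_val_data_edge_label_index)
## the comparison has to sit inside print(); otherwise print's return value (None) is what gets compared
print((num_val_data_edge_index + num_val_data_edge_label_index) == round(num_total_forward_edges*0.9))

9711
971.1
971
True
8740
True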

Now we want to confirm that concatenating the validation set's edge_index with the positive columns of its edge_label_index gives the test set's edge_index (which should contain the edges used in training and validation).

In [45]:
val_data_edge_index = val_data[('patient', 'treated_with', 'drug')]['edge_index']
val_data_postive_edge_label_index = val_data[('patient', 'treated_with', 'drug')]['edge_label_index'][:,val_data[('patient', 'treated_with', 'drug')]['edge_label']==1]
test_data_positive_edge_index = test_data[('patient', 'treated_with', 'drug')]['edge_index']
expected_test_data_positive_edge_index = torch.cat([val_data_edge_index, val_data_postive_edge_label_index], dim = 1)
torch.all(expected_test_data_positive_edge_index.eq(test_data_positive_edge_index))

tensor(True)
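
The expression `edge_label_index[:, edge_label == 1]` is repeated in several cells, since the validation and test splits may also contain sampled negative edges. A small helper keeps the intent readable. This is a sketch; `positive_label_edges` is a hypothetical name and the tensors are made up.

```python
import torch

def positive_label_edges(edge_label_index: torch.Tensor, edge_label: torch.Tensor) -> torch.Tensor:
    """Keep only the columns of edge_label_index whose label equals 1."""
    return edge_label_index[:, edge_label == 1]

# Toy supervision edges (made-up): two positives followed by two negatives
toy_label_index = torch.tensor([[0, 1, 2, 3],
                                [5, 6, 5, 7]])
toy_label = torch.tensor([1.0, 1.0, 0.0, 0.0])
toy_positives = positive_label_edges(toy_label_index, toy_label)
```

With a store such as `val_data[('patient', 'treated_with', 'drug')]`, the call would take that store's `edge_label_index` and `edge_label` as its two arguments.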

Now we want to confirm the same thing for the reverse edges 

In [46]:
val_data_rev_edge_index = val_data[('drug', 'treating', 'patient')]['edge_index']
val_data_positive_edge_label_index = val_data[('patient', 'treated_with', 'drug')]['edge_label_index'][:,val_data[('patient', 'treated_with', 'drug')]['edge_label']==1]
row_swapped_val_data_positive_edge_label_index = val_data_positive_edge_label_index[[1, 0]]
test_data_rev_positive_edge_index = test_data[('drug', 'treating', 'patient')]['edge_index']
expected_test_rev_data_positive_edge_index = torch.cat([val_data_rev_edge_index, row_swapped_val_data_positive_edge_label_index], dim = 1)
torch.all(expected_test_rev_data_positive_edge_index.eq(test_data_rev_positive_edge_index))

tensor(True)
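
Taken together, the training and validation checks are consistent with the following layout of message-passing edges (edge_index) and positive supervision edges (edge_label_index) per split. The sketch assumes `disjoint_train_ratio=0`, as in the transform above, and `expected_layout` is a hypothetical helper used only for illustration.

```python
def expected_layout(n_train: int, n_val: int, n_test: int) -> dict:
    """Map each split to (message-passing edges, positive supervision edges).

    Assumes disjoint_train_ratio=0, so training supervises the same edges
    it passes messages over.
    """
    return {
        "train": (n_train, n_train),
        "val": (n_train, n_val),           # validation reuses the training edges for messages
        "test": (n_train + n_val, n_test), # test also sees the validation positives
    }

layout = expected_layout(7769, 971, 971)
for split, (num_message, num_supervision) in layout.items():
    print(split, num_message, num_supervision)
```

The message-passing counts 7769, 7769 and 8740 match the sizes recorded in the comment inside the transform cell.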

## Verify the test set

We already showed that the edge_index in the test set behaves as expected, so we just need to show that the edge labels do too. We can first do this by checking the counts.

In [47]:
num_postive_edge_labels_test = test_data[('patient', 'treated_with', 'drug')]['edge_label_index'][:,test_data[('patient', 'treated_with', 'drug')]['edge_label']==1].shape[1]
num_total_postive_edges = data[('patient', 'treated_with', 'drug')]['edge_index'].shape[1]
num_edge_index_test = test_data[('patient', 'treated_with', 'drug')]['edge_index'].shape[1]

(num_edge_index_test + num_postive_edge_labels_test) == num_total_postive_edges





True

Now we want to verify that adding the positive test samples to the edges used for message passing recovers the set of all original edges. We can confirm this by checking that the test data's edge_index, concatenated with the positive columns of its edge_label_index, is equal to the edge_index of the unsplit dataset (after sorting).

In [48]:
postive_edge_labels_test = test_data[('patient', 'treated_with', 'drug')]['edge_label_index'][:,test_data[('patient', 'treated_with', 'drug')]['edge_label']==1]
edge_index_test = test_data[('patient', 'treated_with', 'drug')]['edge_index']
expected_total_edge_index = torch.cat((postive_edge_labels_test,edge_index_test),axis=1)
total_edge_index = data[('patient', 'treated_with', 'drug')]['edge_index']

sorted_expected_total_edge_index = torch.sort(expected_total_edge_index,1).values
sorted_total_edge_index = torch.sort(total_edge_index,1).values
torch.all(sorted_expected_total_edge_index.eq(sorted_total_edge_index))


tensor(True)
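
One caveat: `torch.sort(x, 1)` sorts each row independently, so the check above compares the source indices and the destination indices separately rather than as (source, destination) pairs. A stricter comparison keeps each column together. This is a sketch on made-up tensors; `same_edge_multiset` is a hypothetical helper.

```python
import torch

def same_edge_multiset(a: torch.Tensor, b: torch.Tensor) -> bool:
    """Compare two [2, N] edge_index tensors as multisets of (src, dst) pairs."""
    # Each column becomes one [src, dst] pair; sorting the pairs removes the ordering
    return sorted(a.t().tolist()) == sorted(b.t().tolist())

# Made-up example in which the per-row sort cannot tell the edge sets apart
a = torch.tensor([[0, 1], [1, 0]])  # edges (0, 1) and (1, 0)
b = torch.tensor([[0, 1], [0, 1]])  # edges (0, 0) and (1, 1)

row_sort_equal = torch.equal(torch.sort(a, 1).values, torch.sort(b, 1).values)
pair_equal = same_edge_multiset(a, b)
print(row_sort_equal, pair_equal)  # True False
```

Applying `same_edge_multiset` to `expected_total_edge_index` and `total_edge_index` would give a pairwise version of the same check. The result on the real data is not shown here.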

Now we want to verify the same thing for the reverse edges.

In [49]:
postive_edge_labels_test = test_data[('patient', 'treated_with', 'drug')]['edge_label_index'][:,test_data[('patient', 'treated_with', 'drug')]['edge_label']==1]
row_swapped_postive_edge_labes_test = postive_edge_labels_test[[1,0]]
rev_edge_index_test = test_data[('drug', 'treating', 'patient')]['edge_index']
expected_total_rev_edge_index = torch.cat((rev_edge_index_test,row_swapped_postive_edge_labes_test),axis=1)
rev_total_edge_index = data[('drug', 'treating', 'patient')]['edge_index']

sorted_expected_total_rev_edge_index = torch.sort(expected_total_rev_edge_index,1).values
sorted_total_rev_edge_index = torch.sort(rev_total_edge_index,1).values
torch.all(sorted_expected_total_rev_edge_index.eq(sorted_total_rev_edge_index))


tensor(True)