In [1]:
######## IMPORT EXTERNAL FILES ###########
import torch
import torch.nn.functional as F
import torch.nn.utils.parametrize as parametrize
import torch.nn as nn

import torch_geometric
from torch_geometric.loader import DataLoader
from torch_geometric.transforms import RandomLinkSplit
from torch_geometric.utils import train_test_split_edges, negative_sampling

import pytorch_lightning as pl
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.loggers import WandbLogger

######### IMPORT INTERNAL FILES ###########
import sys
sys.path.append("../../src")
from GRAFF import *
from config import *

  from .autonotebook import tqdm as notebook_tqdm


Link prediction features initialized.....


In [2]:
import random
from torch_geometric.utils import to_networkx
import networkx as nx
import matplotlib.pyplot as plt

def convert_to_networkx(graph, n_sample=None):
    """Convert a PyG graph to NetworkX, optionally keeping a random node subset.

    Args:
        graph: ``torch_geometric`` ``Data`` object with node features ``x``
            and node labels ``y``.
        n_sample: if given, the number of nodes to sample uniformly (without
            replacement); the induced subgraph on those nodes is returned.

    Returns:
        ``(g, y)``: the NetworkX graph and a numpy label array whose i-th
        entry belongs to the i-th node yielded by iterating ``g`` (the order
        ``nx.draw`` uses when colouring nodes with ``node_color``).
    """
    g = to_networkx(graph, node_attrs=["x"])
    # .cpu() so that labels stored on a GPU can still be converted to numpy.
    y = graph.y.cpu().numpy()

    if n_sample is not None:
        # random.sample requires a sequence; a NodeView is not one
        # (TypeError on Python 3.11+), so materialise it first.
        sampled_nodes = random.sample(list(g.nodes), n_sample)
        g = g.subgraph(sampled_nodes)
        # A subgraph view does not iterate in sampling order, so re-index the
        # labels in the order the subgraph actually yields its nodes.
        y = y[list(g.nodes)]

    return g, y


def plot_graph(g, y):
    """Draw ``g`` with a spring layout, colouring each node by its entry in ``y``.

    A fresh 9x7-inch figure is opened, edges are drawn without arrowheads,
    and the figure is shown immediately.
    """
    plt.figure(figsize=(9, 7))
    layout = nx.spring_layout(g)
    nx.draw(g, layout, node_color=y, node_size=30, arrows=False)
    plt.show()

In [None]:
# Visualise the whole first graph of the dataset, nodes coloured by label
# (n_sample=None keeps every node).
first_graph = dataset[0]
g, y = convert_to_networkx(first_graph, n_sample=None)
plot_graph(g, y)

In [None]:

# final_dataset = train_test_split_edges(dataset[0], val_ratio = 0.1, test_ratio= 0.1)

In [None]:
# def indices(dataset, split_idx):
#     ''' According to the dataset, and the specified splitting (e.g. in Geom-GCN there are 10 splits) 
#         We identify the indices. 

#         args:
#           - dataset: torch-geometric data type,
#           - split_idx: in the Geom-GCN implementations the available splittings are from 0-9    
        
#         output:
#           - (train_indices, val_indices, test_indices):
#                  indices that correspond to the whole graph.
    
#     '''

#     train_idx = dataset.train_mask[:, split_idx]
#     val_idx = dataset.val_mask[:, split_idx]
#     test_idx = dataset.test_mask[:, split_idx]

#     train_indices = torch.nonzero(train_idx)
#     val_indices = torch.nonzero(val_idx)
#     test_indices = torch.nonzero(test_idx)

#     return train_indices.squeeze(1), val_indices.squeeze(1), test_indices.squeeze(1)


 

# final = train_test_split_edges(dataset[0])


In [None]:
class DataModuleLP(pl.LightningDataModule):
    """Lightning data module for full-graph link prediction.

    Wraps the three ``torch_geometric`` ``Data`` splits produced by a link
    split (e.g. ``RandomLinkSplit``) together with a pool of pre-sampled
    negative edges. Each split is served as a single-graph batch.

    Args:
        train_set, val_set, test_set: graph splits. ``val_set`` / ``test_set``
            must carry ``edge_label_index`` (and optionally ``edge_label``).
        neg_edges: (2, K) edge-index tensor of negative edges, as returned by
            ``negative_sampling``. Disjoint column slices of it are given to
            the training split and to the evaluation split.
        mode: ``"hp"`` (train on train, evaluate on val) or ``"test"`` (train
            on train + val positives, evaluate on test).
        batch_size: forwarded to every ``DataLoader``.
        train_ratio: fraction of the training edges used for message passing;
            the remaining edges are the positive prediction targets.

    Raises:
        ValueError: if ``mode`` is neither ``"hp"`` nor ``"test"``.
    """

    def __init__(self, train_set, val_set, test_set, neg_edges, mode, batch_size, train_ratio=0.7):
        # LightningDataModule.__init__ sets attributes the Trainer reads
        # (e.g. ``trainer``); skipping it breaks ``Trainer.fit``.
        super().__init__()
        if mode not in ("hp", "test"):
            raise ValueError(f"mode must be 'hp' or 'test', got {mode!r}")
        self.mode = mode
        self.batch_size = batch_size
        self.train_set, self.val_set, self.test_set = train_set, val_set, test_set
        self.neg_edges = neg_edges
        self.train_ratio = train_ratio
        # Guards against re-splitting / re-concatenating on repeated setup('fit').
        self._train_prepared = False

    @staticmethod
    def _positive_label_edges(data):
        """Return only the positive columns of ``data.edge_label_index``.

        RandomLinkSplit appends sampled negatives to the val/test
        ``edge_label_index`` and labels them 0 in ``edge_label``. Without an
        ``edge_label`` every column is treated as positive.
        """
        edge_label_index = data.edge_label_index
        edge_label = getattr(data, "edge_label", None)
        if edge_label is None:
            return edge_label_index
        return edge_label_index[:, edge_label > 0]

    def _eval_set(self):
        """The split used for evaluation: val in "hp" mode, test in "test" mode."""
        return self.val_set if self.mode == "hp" else self.test_set

    def _prepare_train(self):
        """Split training edges into message-passing edges and positive targets (runs once)."""
        if self._train_prepared:
            return
        if self.mode == "test":
            # After hp tuning, the validation positives join the training graph.
            # NOTE(review): edge_label_index presumably holds one direction per
            # edge, while an undirected train edge_index holds both — confirm.
            val_pos = self._positive_label_edges(self.val_set)
            self.train_set.edge_index = torch.cat((self.train_set.edge_index, val_pos), dim=-1)

        edge_index = self.train_set.edge_index
        n_forward = int(self.train_ratio * edge_index.shape[1])
        # NOTE(review): the split is positional. For undirected splits the
        # reverse edges appear to occupy the second half of edge_index, so the
        # masked tail may contain reverses of forward-pass edges (leakage).
        self.train_set.pos_forward_pass = edge_index[:, :n_forward]
        # The remaining edges (1 - train_ratio) are the positive targets.
        self.train_set.pos_masked_edges = edge_index[:, n_forward:]
        # As many negatives as masked positives, taken from the front of the pool.
        self.train_set.neg_edges = self.neg_edges[:, :self.train_set.pos_masked_edges.shape[1]]
        self._train_prepared = True

    def _prepare_eval(self):
        """Attach negatives to the evaluation split, disjoint from the training negatives."""
        # The slice offset depends on how many negatives training consumed.
        self._prepare_train()
        eval_set = self._eval_set()
        start = self.train_set.pos_masked_edges.shape[1]
        # NOTE(review): sized by every edge_label_index column; with
        # RandomLinkSplit these already include its own negatives.
        stop = start + eval_set.edge_label_index.shape[1]
        eval_set.neg_edges = self.neg_edges[:, start:stop]

    def setup(self, stage=None):
        """Lightning hook: prepare the data needed for ``stage``.

        ``'fit'`` prepares the training split and, since ``val_dataloader`` is
        consumed while fitting, the evaluation negatives too. ``'test'``
        prepares the evaluation negatives (preparing the training split first
        if needed). Other stages are no-ops.
        """
        if stage == "fit":
            self._prepare_train()
            self._prepare_eval()
        elif stage == "test":
            self._prepare_eval()

    def train_dataloader(self, *args, **kwargs):
        return DataLoader([self.train_set], batch_size=self.batch_size, shuffle=False)

    def val_dataloader(self, *args, **kwargs):
        # In "test" mode the held-out test split takes the validation slot.
        return DataLoader([self._eval_set()], batch_size=self.batch_size, shuffle=False)


In [6]:
mode = 'hp'

# Fix the RNG state so the random edge split and the negative pool are reproducible.
SPLIT_SEED = 0
pl.seed_everything(SPLIT_SEED)

# Texas is split as a directed graph; every other dataset is treated as undirected.
transform = RandomLinkSplit(is_undirected=(dataset_name != 'Texas'))

graph = dataset[0]
# Edges are divided into three sets (train / val / test).
train_data, val_data, test_data = transform(graph)

# Pool of negative (non-existent) edges, sampled against the full graph so that
# no held-out positive edge can be drawn as a negative.
negative_edges = negative_sampling(graph.edge_index, num_nodes=graph.num_nodes)

In [None]:
save = True

if save:
    # torch.save needs both an object and a destination. Persist the exact
    # splits and negative pool so later runs can reuse them instead of re-sampling.
    import os
    save_dir = dataset_name
    os.makedirs(save_dir, exist_ok=True)
    torch.save(
        {'train': train_data, 'val': val_data, 'test': test_data, 'negative_edges': negative_edges},
        os.path.join(save_dir, 'link_splits.pt'),
    )

In [None]:
# Build the link-prediction data module for the mode chosen in the split cell
# (previously hard-coded to 'hp', so changing `mode` had no effect), then
# prepare both stages. Clones keep the original splits untouched.
DM = DataModuleLP(train_data.clone(), val_data.clone(), test_data.clone(), negative_edges, mode=mode, batch_size=batch_size)
DM.setup('fit')
DM.setup('test')


In [None]:
# Sanity check: print every batch served by the training loader, then the validation loader.
for make_loader in (DM.train_dataloader, DM.val_dataloader):
    for batch in make_loader():
        print(batch)

In [None]:
# TODO: How should the experiments be repeated (number of runs, seeds)?
# TODO: Which splits are used? If I create my own splits, do I need to rerun the experiments?
# TODO: Open questions about message passing.