# Grapher

## Data description

### [FinnGen](https://finngen.gitbook.io/documentation/)

- The FinnGen research project is an expedition to the frontier of genomics and medicine, with significant discoveries potentially arising from any one of Finland’s 500,000 biomedical pioneers.
- The project brings together a nation-wide network of Finnish biobanks, with every Finn able to participate in the study by giving biobank consent.
- As of the last update, there were 589,000 samples available, with a goal to reach 520,000 by 2023. The latest data freeze included combined genotype and health registry data from 473,681 individuals.
- The study uses samples collected by a nationwide network of Finnish biobanks and combines genome information with digital health care data from national health registries.
- Samples are needed from all over Finland, because solutions in personalized healthcare can be found only by studying large populations.
- The genome data produced during the project is owned by the Finnish biobanks and remains available for research purposes. The medical breakthroughs that arise from the project are expected to benefit health care systems and patients globally.
- The FinnGen research project is collaborative, involving all the same actors as drug development, with the aim to speed up the emergence of new innovations.
- Results and summary statistics from data freeze 9 are now available, covering over 377,200 individuals, almost 20.2 million variants, and 2,272 disease endpoints. The results can be browsed online in the FinnGen web browser, and the summary statistics can be downloaded.
- The University of Helsinki is the organization responsible for the study, and the nationwide network of Finnish biobanks is participating in the study, thus covering the whole of Finland. The Helsinki Biobank coordinates the sample collection.
- For more information, the project can be contacted at finngen-info@helsinki.fi.

### Dataset

The DataFrame has the following columns:

- `#chrom`: This column represents the chromosome number where the genetic variant is located.

- `pos`: This is the position of the genetic variant on the chromosome.

- `ref`: This column represents the reference allele (or variant) at the genomic position.

- `alt`: This is the alternate allele observed at this position.

- `rsids`: This stands for reference SNP cluster ID. It's a unique identifier for each variant used in the dbSNP database.

- `nearest_genes`: This column represents the gene which is nearest to the variant.

- `pval`: This represents the p-value, which is a statistical measure for the strength of evidence against the null hypothesis.

- `mlogp`: This is the negative base-10 logarithm of the p-value (`-log10(pval)`), commonly used in genomic studies.

- `beta`: The beta coefficient represents the effect size of the variant.

- `sebeta`: This is the standard error of the beta coefficient.

- `af_alt`: This is the frequency of the alternate allele across all samples in the study (cases and controls combined).

- `af_alt_cases`: This is the allele frequency of the alternate variant in the cases group.

- `af_alt_controls`: This is the allele frequency of the alternate variant in the control group.

- `causal`: This binary column indicates whether the variant is determined to be causal (1) or not (0).

- `trait`: This column represents the trait associated with the variant. In this dataset, it is the response to the drug paracetamol and NSAIDs.
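
As a quick sanity check on how these columns relate, the sketch below recomputes `mlogp` from `pval` for the first three rows of the data preview printed later in this notebook. The `z_score` helper (`beta / sebeta`) is an added illustration of a common convention for summarizing effect size relative to its uncertainty; it is not a column of the dataset, and both helper names are hypothetical.

```python
import math

# (pval, mlogp, beta, sebeta) copied from the first three rows of the data preview
rows = [
    (0.618421, 0.208716, -0.039244, 0.078788),
    (0.852267, 0.069424, -0.009138, 0.049071),
    (0.687459, 0.162753, 0.060661, 0.150784),
]

def neg_log10(p):
    """Negative base-10 logarithm, which is what `mlogp` stores."""
    return -math.log10(p)

def z_score(beta, sebeta):
    """Effect size divided by its standard error (illustrative helper)."""
    return beta / sebeta

for pval, mlogp, beta, sebeta in rows:
    print(
        f"pval={pval:.6f}  -log10(pval)={neg_log10(pval):.6f}  "
        f"mlogp={mlogp:.6f}  z={z_score(beta, sebeta):+.3f}"
    )
```

The recomputed values agree with the stored `mlogp` column to within rounding of the printed digits.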

## Load libraries

In [1]:
import os
import random


import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.metrics import precision_recall_fscore_support, roc_auc_score, average_precision_score
from sklearn.preprocessing import LabelEncoder, StandardScaler

import torch
import torch.nn.functional as F
import torch_geometric
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv, GATConv
from torch_geometric.utils import to_undirected, negative_sampling

import networkx as nx
from ogb.io import DatasetSaver
from ogb.linkproppred import LinkPropPredDataset

from scipy.spatial import cKDTree

## Perform checks

In [2]:
print(f"PyTorch version: {torch.__version__}")
print(f"PyTorch Geometric version: {torch_geometric.__version__}")

PyTorch version: 2.0.0+cu118
PyTorch Geometric version: 2.3.1


In [3]:
if torch.cuda.is_available():
    device = torch.device("cuda")          # Current CUDA device
    print(f"Using {torch.cuda.get_device_name()} ({device})")
    print(f"CUDA version: {torch.version.cuda}")
    print(f"Number of CUDA devices: {torch.cuda.device_count()}")
else:
    device = torch.device("cpu")           # Fall back to CPU so `device` is always defined
    print("CUDA is not available on this device.")

Using NVIDIA GeForce RTX 3060 Ti (cuda)
CUDA version: 11.8
Number of CUDA devices: 1


## Load data

In [4]:
data = pd.read_csv('~/Desktop/geometric-omics/FinnGenn/data/gwas-causal.csv', low_memory=False)

In [5]:
data.head()

Unnamed: 0,#chrom,pos,ref,alt,rsids,nearest_genes,pval,mlogp,beta,sebeta,af_alt,af_alt_cases,af_alt_controls,causal,trait
0,1,13668,G,A,rs2691328,OR4F5,0.618421,0.208716,-0.039244,0.078788,0.005845,0.005842,0.005858,0,RX_PARACETAMOL_NSAID
1,1,14773,C,T,rs878915777,OR4F5,0.852267,0.069424,-0.009138,0.049071,0.0135,0.013499,0.013502,0,RX_PARACETAMOL_NSAID
2,1,15585,G,A,rs533630043,OR4F5,0.687459,0.162753,0.060661,0.150784,0.001112,0.001115,0.001095,0,RX_PARACETAMOL_NSAID
3,1,16549,T,C,rs1262014613,OR4F5,0.798057,0.097966,0.062197,0.243085,0.000563,0.000561,0.000572,0,RX_PARACETAMOL_NSAID
4,1,16567,G,C,rs1194064194,OR4F5,0.363171,0.439889,-0.073028,0.080309,0.004192,0.004172,0.004295,0,RX_PARACETAMOL_NSAID


## Spec

Task
- causal link prediction 
- can a GNN learn how to predict the causal SNP?

Phenotype node features:
- `trait` column
- `#chrom` column


Gene node features:
- `nearest_genes` column
- `#chrom` column

SNP node features:
- `rsids` column
- `#chrom` column
- `pos` column
- `ref` column
- `alt` column
- `beta` column
- `sebeta` column
- `af_alt` column
- `af_alt_cases` column
- `af_alt_controls` column

Edge features:
- undirected
- positive if `data['causal'] == 1`
- negative if `data['causal'] == 0`
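
The spec places three node types in one shared integer index space and labels each edge by the `causal` flag. A minimal sketch of that layout on a toy table is shown below; the identifiers `rsA`, `GENE1`, `TRAIT` and the helper names are made up for illustration and do not come from the dataset.

```python
def build_index_maps(phenotypes, genes, snps):
    """Assign contiguous integer ids: phenotypes first, then genes, then SNPs."""
    phenotype_to_idx = {p: i for i, p in enumerate(phenotypes)}
    gene_to_idx = {g: i + len(phenotypes) for i, g in enumerate(genes)}
    snp_to_idx = {s: i + len(phenotypes) + len(genes) for i, s in enumerate(snps)}
    return phenotype_to_idx, gene_to_idx, snp_to_idx

def rows_to_edges(rows, maps):
    """Turn (rsid, gene, trait, causal) rows into labeled SNP-gene and gene-phenotype edges."""
    phenotype_to_idx, gene_to_idx, snp_to_idx = maps
    snp_gene, gene_phenotype = [], []
    for rsid, gene, trait, causal in rows:
        label = 1 if causal == 1 else 0
        snp_gene.append((snp_to_idx[rsid], gene_to_idx[gene], label))
        gene_phenotype.append((gene_to_idx[gene], phenotype_to_idx[trait], label))
    return snp_gene, gene_phenotype

# Toy rows with hypothetical identifiers: (rsid, nearest gene, trait, causal flag)
toy_rows = [
    ("rsA", "GENE1", "TRAIT", 1),
    ("rsB", "GENE1", "TRAIT", 0),
    ("rsC", "GENE2", "TRAIT", 0),
]

# dict.fromkeys keeps first-seen order, similar to pandas' unique()
phenotypes = list(dict.fromkeys(r[2] for r in toy_rows))
genes = list(dict.fromkeys(r[1] for r in toy_rows))
snps = list(dict.fromkeys(r[0] for r in toy_rows))

maps = build_index_maps(phenotypes, genes, snps)
snp_gene, gene_phenotype = rows_to_edges(toy_rows, maps)
print("SNP-gene edges:      ", snp_gene)
print("Gene-phenotype edges:", gene_phenotype)
```

In this toy example the gene-phenotype pair for `GENE1` receives both a positive and a negative label, because the label is inherited from each SNP row. Whether that is desired depends on how gene-level causality is meant to be defined.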

## Create graph

In [6]:
%%time

# Create mappings for phenotypes, genes, and SNPs to integer indices
phenotypes = data['trait'].unique()
genes = data['nearest_genes'].unique()
snps = data['rsids'].unique()
phenotype_to_idx = {phenotype: idx for idx, phenotype in enumerate(phenotypes)}
gene_to_idx = {gene: idx + len(phenotypes) for idx, gene in enumerate(genes)}
snp_to_idx = {snp: idx + len(phenotypes) + len(genes) for idx, snp in enumerate(snps)}

# Create node feature vectors for phenotypes, genes, and SNPs
phenotype_features = data.loc[data['trait'].isin(phenotypes)][['trait', '#chrom']].drop_duplicates().sort_values(by='trait').reset_index(drop=True)
gene_features = data.loc[data['nearest_genes'].isin(genes)][['nearest_genes', '#chrom']].drop_duplicates().sort_values(by='nearest_genes').reset_index(drop=True)
snp_features = data.loc[data['rsids'].isin(snps)][['rsids', '#chrom', 'pos', 'ref', 'alt', 'beta', 'sebeta', 'af_alt', 'af_alt_cases', 'af_alt_controls']].drop_duplicates().sort_values(by='rsids').reset_index(drop=True)

# Create node type labels
node_types = torch.tensor([0] * len(phenotypes) + [1] * len(genes) + [2] * len(snps), dtype=torch.long)

# Edge creation based on 'causal' column
edges = data[['rsids', 'nearest_genes', 'trait', 'causal']].drop_duplicates()

edges['snp_idx'] = edges['rsids'].map(snp_to_idx)
edges['gene_idx'] = edges['nearest_genes'].map(gene_to_idx)
edges['phenotype_idx'] = edges['trait'].map(phenotype_to_idx)

# Create positive and negative edges for SNP-Gene and Gene-Phenotype
positive_edges_snp_gene = edges.loc[edges['causal'] == 1, ['snp_idx', 'gene_idx']].values
positive_edges_gene_phenotype = edges.loc[edges['causal'] == 1, ['gene_idx', 'phenotype_idx']].values

negative_edges_snp_gene = edges.loc[edges['causal'] == 0, ['snp_idx', 'gene_idx']].values
negative_edges_gene_phenotype = edges.loc[edges['causal'] == 0, ['gene_idx', 'phenotype_idx']].values

positive_edges_snp_gene = torch.tensor(positive_edges_snp_gene, dtype=torch.long).t().contiguous()
positive_edges_gene_phenotype = torch.tensor(positive_edges_gene_phenotype, dtype=torch.long).t().contiguous()
negative_edges_snp_gene = torch.tensor(negative_edges_snp_gene, dtype=torch.long).t().contiguous()
negative_edges_gene_phenotype = torch.tensor(negative_edges_gene_phenotype, dtype=torch.long).t().contiguous()

# Combine edges
edges = torch.cat([positive_edges_snp_gene, positive_edges_gene_phenotype, negative_edges_snp_gene, negative_edges_gene_phenotype], dim=1)

# Create edge attributes
edge_attr = torch.ones(edges.size(1), dtype=torch.float)
edge_attr[len(positive_edges_snp_gene[0])+len(positive_edges_gene_phenotype[0]):] *= -1  # Make non-causal edges negative

# Combine the feature vectors
combined_features = pd.concat([phenotype_features, gene_features, snp_features], ignore_index=True).drop(['trait', 'nearest_genes', 'rsids'], axis=1)

# Fill missing values: 'N/A' for categorical columns, 0 for numerical columns
nan_replacements = {'#chrom': 'N/A', 'pos': 0, 'ref': 'N/A', 'alt': 'N/A', 'beta': 0, 'sebeta': 0, 'af_alt': 0, 'af_alt_cases': 0, 'af_alt_controls': 0}
for col, replacement in nan_replacements.items():
    if col in combined_features:
        if combined_features[col].dtype.name == 'category' and replacement not in combined_features[col].cat.categories:
            combined_features[col] = combined_features[col].cat.add_categories([replacement])
        # Assign the result back instead of relying on in-place modification of a column view
        combined_features[col] = combined_features[col].fillna(replacement)

# Label encoding for categorical columns
le = LabelEncoder()
combined_features = combined_features.apply(lambda col: le.fit_transform(col.astype(str)) if col.dtype == 'object' else col)

# Standardize numerical features
scaler = StandardScaler()
numerical_columns = ['pos', 'beta', 'sebeta', 'af_alt', 'af_alt_cases', 'af_alt_controls']
categorical_columns = ['#chrom', 'ref', 'alt']
for col in numerical_columns:
    combined_features[col] = scaler.fit_transform(combined_features[[col]])

for col in categorical_columns:
    combined_features[col] = combined_features[col].astype('category').cat.codes

features = torch.tensor(combined_features.values, dtype=torch.float)

# Create the PyTorch Geometric graph
graph = Data(x=features, edge_index=edges, edge_attr=edge_attr)
graph.node_types = node_types

CPU times: total: 1min 16s
Wall time: 2min 18s
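
The split graphs printed later in this notebook report `x=[20192197, 9]` but `node_types=[18731560]`, which suggests the feature table does not hold exactly one row per node. Two properties of the cell above could contribute: `drop_duplicates()` over several columns keeps more than one row when an identifier appears with different values, and `sort_values()` orders rows differently from the `unique()` order used for the index maps. The sketch below shows one way to keep a single row per identifier in index order on a toy table. The helper name `one_row_per_node` and the toy values are made up, and keeping the first occurrence is an assumption rather than the notebook's method.

```python
import pandas as pd

# Toy table with made-up values: gene G2 appears with two different '#chrom' entries
toy = pd.DataFrame({
    "nearest_genes": ["G2", "G1", "G2", "G3"],
    "#chrom": ["1", "1", "2", "3"],
})

def one_row_per_node(df, key, columns):
    """Return one feature row per unique key, ordered like df[key].unique()."""
    order = df[key].unique()                               # order used for the index map
    first = df.drop_duplicates(subset=key).set_index(key)  # keep first occurrence per key
    return first.loc[order, columns].reset_index()

gene_to_idx = {g: i for i, g in enumerate(toy["nearest_genes"].unique())}
gene_features = one_row_per_node(toy, "nearest_genes", ["#chrom"])
print(gene_features)
```

With this layout, row `i` of `gene_features` describes the gene whose index is `i`, so the feature matrix and the index maps stay aligned.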


## Graph stats

In [15]:
from torch_geometric.utils import degree

def print_graph_stats(graph, positive_edges_snp_gene, positive_edges_gene_phenotype, negative_edges_snp_gene, negative_edges_gene_phenotype):
    print(f"Number of nodes: {graph.num_nodes}")
    # Edge tensors have shape [2, num_edges], so count columns rather than using len()
    print(f"Number of positive edges between SNPs and genes: {positive_edges_snp_gene.size(1)}")
    print(f"Number of positive edges between genes and phenotypes: {positive_edges_gene_phenotype.size(1)}")
    print(f"Number of negative edges for SNPs and genes: {negative_edges_snp_gene.size(1)}")
    print(f"Number of negative edges for genes and phenotypes: {negative_edges_gene_phenotype.size(1)}")
    print(f"Number of edges: {graph.num_edges}")
    print(f"Node feature dimension: {graph.num_node_features}")
    print(f"Node types: {graph.node_types}")

    # Computing degree-related stats
    degrees = degree(graph.edge_index[0]).numpy()  # for undirected graphs, use edge_index[0]
    average_degree = np.mean(degrees)
    median_degree = np.median(degrees)
    std_degree = np.std(degrees)

    print(f"Average degree: {average_degree:.2f}")
    print(f"Median degree: {median_degree}")
    print(f"Standard deviation of degree: {std_degree:.2f}")
    
    # Density is the ratio of actual edges to the maximum number of possible edges
    density = graph.num_edges / (graph.num_nodes * (graph.num_nodes - 1))
    print(f"Density: {density:.10f}")

    # Check for NaN values in features
    nan_in_features = torch.isnan(graph.x).any().item()
    print(f"Are there any NaN values in features? {nan_in_features}")

# Print
print("Graph stats:")
print_graph_stats(graph, positive_edges_snp_gene, positive_edges_gene_phenotype, negative_edges_snp_gene, negative_edges_gene_phenotype)

Graph stats:
Number of nodes: 20192197
Number of positive edges between SNPs and genes: 18
Number of positive edges between genes and phenotypes: 18
Number of negative edges for SNPs and genes: 18728888
Number of negative edges for genes and phenotypes: 18728888
Number of edges: 37457812
Node feature dimension: 9
Node types: tensor([0, 1, 1,  ..., 2, 2, 2])
Average degree: 2.00
Median degree: 1.0
Standard deviation of degree: 72.59
Density: 0.0000000919
Are there any NaN values in features? False
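
The `degree` call in the cell above is applied to `edge_index[0]` only, so it counts each edge at its source endpoint. Because the edge list is never symmetrized, an alternative is to count both endpoints of every edge. The NumPy sketch below illustrates that on a toy edge list and also reproduces the density figure from the node and edge counts printed above; the function names are hypothetical.

```python
import numpy as np

def degree_stats(edge_index, num_nodes):
    """Per-node degree counting both endpoints of every edge, plus mean, median, std."""
    edge_index = np.asarray(edge_index)
    deg = np.bincount(edge_index.ravel(), minlength=num_nodes)
    return deg, deg.mean(), np.median(deg), deg.std()

def directed_density(num_edges, num_nodes):
    """Edges divided by the number of ordered node pairs, as in the cell above."""
    return num_edges / (num_nodes * (num_nodes - 1))

# Toy graph: node 0 = phenotype, nodes 1-2 = genes, nodes 3-5 = SNPs
toy_edge_index = [[3, 4, 5, 1, 2],
                  [1, 1, 2, 0, 0]]
deg, mean_deg, median_deg, std_deg = degree_stats(toy_edge_index, num_nodes=6)
print("degrees:", deg.tolist())
print(f"mean={mean_deg:.2f} median={median_deg} std={std_deg:.2f}")

# Density recomputed from the counts reported in the graph stats output
print(f"density={directed_density(37457812, 20192197):.10f}")
```

Counting both endpoints makes the degrees sum to twice the number of edges, which is a convenient consistency check.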


## Data splitting

- **Importing Modules**

  The script begins by importing the necessary Python libraries. It uses `random` for shuffling data, `torch` for handling tensors, and `torch_geometric.data` for its `Data` class, which is used to represent graph data.

- **Constants**

  The ratios for splitting the data into training, validation, and testing sets are defined. Both positive and negative edges are split equally into three parts.

- **Calculating Sample Size for Each Split**

  The number of samples for each set (training, validation, testing) are calculated separately for SNP-Gene and Gene-Trait edges. This is done for both positive and negative edges.

- **Shuffling the Edges**

  The edges for both positive and negative SNP-Gene and Gene-Trait data are shuffled. This ensures that the training, validation, and testing sets get a fair representation of the entire dataset.

- **Splitting the Edges**

  Both SNP-Gene and Gene-Trait edges are split according to the previously calculated sample sizes. The split is performed separately for both positive and negative edges.

- **Combining SNP-Gene and Gene-Trait Edges**

  After the split, the SNP-Gene and Gene-Trait edges are combined back together to form the final sets of edges for the training, validation, and testing sets.

- **Converting Edges back to Tensors**

  The lists of edges are then converted back into torch tensors. This conversion prepares the data for future operations with PyTorch's machine learning functionalities.

- **Creating Graphs**

  Graphs for training, validation, and testing sets are created. These graphs are instances of the `Data` class from the `torch_geometric.data` module. The graphs contain node features, edge indices, and edge attributes.

- **Setting Node Types**

  The node types for each graph are set. The node type information can be used for tasks such as node classification.

- **Printing the Graphs**

  Finally, the script prints out the graphs for the training, validation, and testing sets. This helps to ensure that the data has been correctly processed and is ready for the machine learning task.
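
The shuffle-and-slice steps described above can be condensed into one helper. This is a minimal sketch under stated assumptions: `split_edges` is a hypothetical name, edges are plain Python tuples, and a fixed `seed` is added for reproducibility, which the notebook cell itself does not set.

```python
import random

def split_edges(edges, train_ratio=1/3, val_ratio=1/3, seed=0):
    """Shuffle a list of edges and slice it into train, validation, and test parts."""
    rng = random.Random(seed)          # local generator so the global state is untouched
    shuffled = list(edges)
    rng.shuffle(shuffled)
    n_train = int(len(shuffled) * train_ratio)
    n_val = int(len(shuffled) * val_ratio)
    train = shuffled[:n_train]
    val = shuffled[n_train:n_train + n_val]
    test = shuffled[n_train + n_val:]  # the test part absorbs any rounding remainder
    return train, val, test

# Toy example: 18 made-up (snp_idx, gene_idx) pairs, matching the number of causal edges
toy_edges = [(100 + i, 1) for i in range(18)]
train, val, test = split_edges(toy_edges)
print(len(train), len(val), len(test))
```

Applying the same helper to each of the four edge groups (positive and negative, SNP-Gene and Gene-Trait) keeps every split stratified by edge type and label.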

In [8]:
import random
import torch
from torch_geometric.data import Data

# Constants
pos_train_ratio = 1/3
pos_val_ratio = 1/3
pos_test_ratio = 1/3

neg_train_ratio = 1/3
neg_val_ratio = 1/3
neg_test_ratio = 1/3

# Calculate the number of samples for each split
num_positive_snp_gene_train = int(positive_edges_snp_gene.size(1) * pos_train_ratio)
num_positive_snp_gene_val = int(positive_edges_snp_gene.size(1) * pos_val_ratio)
num_positive_snp_gene_test = positive_edges_snp_gene.size(1) - num_positive_snp_gene_train - num_positive_snp_gene_val

num_positive_gene_trait_train = int(positive_edges_gene_phenotype.size(1) * pos_train_ratio)
num_positive_gene_trait_val = int(positive_edges_gene_phenotype.size(1) * pos_val_ratio)
num_positive_gene_trait_test = positive_edges_gene_phenotype.size(1) - num_positive_gene_trait_train - num_positive_gene_trait_val

num_negative_snp_gene_train = int(negative_edges_snp_gene.size(1) * neg_train_ratio)
num_negative_snp_gene_val = int(negative_edges_snp_gene.size(1) * neg_val_ratio)
num_negative_snp_gene_test = negative_edges_snp_gene.size(1) - num_negative_snp_gene_train - num_negative_snp_gene_val

num_negative_gene_trait_train = int(negative_edges_gene_phenotype.size(1) * neg_train_ratio)
num_negative_gene_trait_val = int(negative_edges_gene_phenotype.size(1) * neg_val_ratio)
num_negative_gene_trait_test = negative_edges_gene_phenotype.size(1) - num_negative_gene_trait_train - num_negative_gene_trait_val

# Shuffle the positive and negative edges
positive_edges_snp_gene = positive_edges_snp_gene.t().tolist()
positive_edges_gene_phenotype = positive_edges_gene_phenotype.t().tolist()
random.shuffle(positive_edges_snp_gene)
random.shuffle(positive_edges_gene_phenotype)

negative_edges_snp_gene = negative_edges_snp_gene.t().tolist()
negative_edges_gene_phenotype = negative_edges_gene_phenotype.t().tolist()
random.shuffle(negative_edges_snp_gene)
random.shuffle(negative_edges_gene_phenotype)

# Split SNP-Gene positive edges
positive_snp_gene_train_edges = positive_edges_snp_gene[:num_positive_snp_gene_train]
positive_snp_gene_val_edges = positive_edges_snp_gene[num_positive_snp_gene_train:num_positive_snp_gene_train + num_positive_snp_gene_val]
positive_snp_gene_test_edges = positive_edges_snp_gene[num_positive_snp_gene_train + num_positive_snp_gene_val:]

# Split Gene-Trait positive edges
positive_gene_trait_train_edges = positive_edges_gene_phenotype[:num_positive_gene_trait_train]
positive_gene_trait_val_edges = positive_edges_gene_phenotype[num_positive_gene_trait_train:num_positive_gene_trait_train + num_positive_gene_trait_val]
positive_gene_trait_test_edges = positive_edges_gene_phenotype[num_positive_gene_trait_train + num_positive_gene_trait_val:]

# Split SNP-Gene negative edges
negative_snp_gene_train_edges = negative_edges_snp_gene[:num_negative_snp_gene_train]
negative_snp_gene_val_edges = negative_edges_snp_gene[num_negative_snp_gene_train:num_negative_snp_gene_train + num_negative_snp_gene_val]
negative_snp_gene_test_edges = negative_edges_snp_gene[num_negative_snp_gene_train + num_negative_snp_gene_val:]

# Split Gene-Trait negative edges
negative_gene_trait_train_edges = negative_edges_gene_phenotype[:num_negative_gene_trait_train]
negative_gene_trait_val_edges = negative_edges_gene_phenotype[num_negative_gene_trait_train:num_negative_gene_trait_train + num_negative_gene_trait_val]
negative_gene_trait_test_edges = negative_edges_gene_phenotype[num_negative_gene_trait_train + num_negative_gene_trait_val:]

# Combine SNP-Gene and Gene-Trait edges
positive_train_edges = positive_snp_gene_train_edges + positive_gene_trait_train_edges
positive_val_edges = positive_snp_gene_val_edges + positive_gene_trait_val_edges
positive_test_edges = positive_snp_gene_test_edges + positive_gene_trait_test_edges

negative_train_edges = negative_snp_gene_train_edges + negative_gene_trait_train_edges
negative_val_edges = negative_snp_gene_val_edges + negative_gene_trait_val_edges
negative_test_edges = negative_snp_gene_test_edges + negative_gene_trait_test_edges

# Convert edges back to tensors
positive_train_edges = torch.tensor(positive_train_edges, dtype=torch.long).t().contiguous()
positive_val_edges = torch.tensor(positive_val_edges, dtype=torch.long).t().contiguous()
positive_test_edges = torch.tensor(positive_test_edges, dtype=torch.long).t().contiguous()

negative_train_edges = torch.tensor(negative_train_edges, dtype=torch.long).t().contiguous()
negative_val_edges = torch.tensor(negative_val_edges, dtype=torch.long).t().contiguous()
negative_test_edges = torch.tensor(negative_test_edges, dtype=torch.long).t().contiguous()

# Create train, validation, and test graphs
graph_train = Data(x=features, edge_index=torch.cat([positive_train_edges, negative_train_edges], dim=1), edge_attr=edge_attr)
graph_val = Data(x=features, edge_index=torch.cat([positive_val_edges, negative_val_edges], dim=1), edge_attr=edge_attr)
graph_test = Data(x=features, edge_index=torch.cat([positive_test_edges, negative_test_edges], dim=1), edge_attr=edge_attr)

# Set node types for train, validation, and test graphs
graph_train.node_types = node_types
graph_val.node_types = node_types
graph_test.node_types = node_types

# Print the graphs
print("Graph Train:")
print(graph_train)
print("\nGraph Validation:")
print(graph_val)
print("\nGraph Test:")
print(graph_test)

Graph Train:
Data(x=[20192197, 9], edge_index=[2, 12485936], edge_attr=[37457812], node_types=[18731560])

Graph Validation:
Data(x=[20192197, 9], edge_index=[2, 12485936], edge_attr=[37457812], node_types=[18731560])

Graph Test:
Data(x=[20192197, 9], edge_index=[2, 12485940], edge_attr=[37457812], node_types=[18731560])
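
In the output above, each split has roughly 12.5 million edges while `edge_attr` still has 37,457,812 entries, because the full-graph attribute vector is reused for every split. If per-edge attributes are needed, one option is to rebuild them alongside each split. The sketch below uses NumPy for brevity (the same pattern applies with `torch.cat`, `torch.ones`, and `torch.full`); `build_split` is a hypothetical helper, not part of the notebook.

```python
import numpy as np

def build_split(pos_edges, neg_edges):
    """Concatenate [2, P] positive and [2, N] negative edges with matching +1/-1 attributes."""
    pos_edges = np.asarray(pos_edges).reshape(2, -1)
    neg_edges = np.asarray(neg_edges).reshape(2, -1)
    edge_index = np.concatenate([pos_edges, neg_edges], axis=1)
    edge_attr = np.concatenate([
        np.ones(pos_edges.shape[1]),     # +1 for causal edges
        -np.ones(neg_edges.shape[1]),    # -1 for non-causal edges
    ])
    return edge_index, edge_attr

# Toy split with made-up indices: one causal edge and two non-causal edges
edge_index, edge_attr = build_split([[3], [1]], [[4, 5], [1, 2]])
print(edge_index.shape, edge_attr.tolist())
```

The attribute vector then always has one entry per column of `edge_index`, in the same order.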


In [9]:
import torch.nn as nn

class FocalLoss(nn.Module):
    def __init__(self, alpha=0.75, gamma=2, reduction='mean'):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs, targets):
        BCE_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')
        pt = torch.exp(-BCE_loss)
        F_loss = self.alpha * (1-pt)**self.gamma * BCE_loss

        if self.reduction == 'mean':
            return torch.mean(F_loss)
        elif self.reduction == 'sum':
            return torch.sum(F_loss)
        else:
            return F_loss
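
To see what the `(1 - pt) ** gamma` factor in `FocalLoss` does, the sketch below evaluates the same formula for single logits in plain Python. It mirrors the class above (binary cross-entropy on logits, then `alpha * (1 - pt) ** gamma * BCE`), but the scalar helpers `bce_with_logits` and `focal` are written here only for illustration.

```python
import math

def bce_with_logits(logit, target):
    """Numerically stable binary cross-entropy for a single logit and a 0/1 target."""
    return max(logit, 0.0) - logit * target + math.log1p(math.exp(-abs(logit)))

def focal(logit, target, alpha=0.75, gamma=2.0):
    """Focal loss for a single example, using the same formula as FocalLoss.forward."""
    bce = bce_with_logits(logit, target)
    pt = math.exp(-bce)                 # probability the model assigns to the true class
    return alpha * (1.0 - pt) ** gamma * bce

# Compare an easy positive (confidently correct) with a hard positive (confidently wrong)
for name, logit in [("easy positive", 4.0), ("hard positive", -4.0)]:
    bce = bce_with_logits(logit, 1.0)
    fl = focal(logit, 1.0)
    print(f"{name}: BCE={bce:.5f}  focal={fl:.7f}  focal/BCE={fl / bce:.5f}")
```

The well-classified example is scaled down by several orders of magnitude relative to its cross-entropy, while the misclassified example keeps most of its loss, which is the intended effect under class imbalance.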

## Create model

- The GCN (Graph Convolutional Network) model used in this script is a simple 2-layer GCN. It transforms the original 9-dimensional node feature vectors into 2-dimensional hidden representations, using the adjacency matrix (encoded by `edge_index`) to propagate information across the graph.
- The model is trained using Focal Loss, which is designed to address class imbalance problems. The training function computes the Focal Loss between the model's predictions on positive and negative edge examples, and the true edge labels. Negative edges are generated using a negative sampling method.
- During evaluation, the model's embeddings are used to score the edges of the validation or test graph together with an equal number of randomly sampled node pairs that are not edges. Each score is the sigmoid of the dot product of the two endpoint embeddings. Several evaluation metrics are computed from these scores, including ROC AUC, Mean Reciprocal Rank (MRR), Hits@5, Recall, and Precision.
- The training process iterates for 100 epochs. In each epoch, the model parameters are updated to minimize the Focal Loss on the training data, and the model's performance is evaluated on the validation data. The best validation scores on the ROC AUC, MRR, Hits@5, Recall, and Precision metrics are tracked throughout training.
- After training, the model can be used to predict whether causal edges exist between nodes in a graph. This makes it suitable for tasks like link prediction in biological networks, where the nodes represent entities like genes or phenotypes and the edges represent relationships between them.
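
The scoring step described above can be sketched in NumPy on toy embeddings. `edge_scores` implements the dot-product decoder used in this notebook. `sample_non_edges` is a simplified rejection sampler written for illustration; it stands in for `torch_geometric.utils.negative_sampling` and is not that function's implementation. All values are made up.

```python
import numpy as np

def edge_scores(z, edge_index):
    """Sigmoid of the dot product between the embeddings of each edge's endpoints."""
    src, dst = edge_index
    logits = (z[src] * z[dst]).sum(axis=-1)
    return 1.0 / (1.0 + np.exp(-logits))

def sample_non_edges(edge_index, num_nodes, num_samples, seed=0):
    """Draw distinct node pairs (u, v), u != v, that do not appear in edge_index."""
    rng = np.random.default_rng(seed)
    taken = set(zip(edge_index[0].tolist(), edge_index[1].tolist()))
    sampled = []
    while len(sampled) < num_samples:
        u, v = (int(i) for i in rng.integers(0, num_nodes, size=2))
        if u != v and (u, v) not in taken:
            taken.add((u, v))
            sampled.append((u, v))
    return np.array(sampled).T

# Toy 2-dimensional embeddings for four nodes and two existing edges
z = np.array([[1.0, 0.0], [1.0, 1.0], [0.0, 2.0], [-1.0, 0.0]])
edge_index = np.array([[0, 1], [1, 2]])

pos = edge_scores(z, edge_index)
neg_edge_index = sample_non_edges(edge_index, num_nodes=4, num_samples=2)
neg = edge_scores(z, neg_edge_index)
print("positive scores:", np.round(pos, 4))
print("sampled non-edges:", neg_edge_index.T.tolist())
print("negative scores:", np.round(neg, 4))
```

With only two embedding dimensions, as in the model configured here, the decoder has limited capacity to separate edges from non-edges, so `hidden_channels` is a natural hyperparameter to vary.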

In [10]:
# Task: Link prediction: does a causal edge exist between two nodes?
# Node Types: 0 = phenotypes, 1 = gene, 2 = snps
# Node Feature Vector: 9-dimensional

torch.cuda.empty_cache()

# Define the GCN model
class GCN(torch.nn.Module):
    def __init__(self, hidden_channels):
        super(GCN, self).__init__()
        self.conv1 = GCNConv(9, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, p=0.75, training=self.training)
        x = self.conv2(x, edge_index)
        return x

# Train and evaluate the model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = GCN(hidden_channels=2).to(device)

graph_train = graph_train.to(device)
graph_val = graph_val.to(device)
graph_test = graph_test.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)

# Train function
from torch_geometric.utils import negative_sampling

# Instantiate the loss function
focal_loss = FocalLoss(alpha=0.75, gamma=2.0).to(device)

def train():
    model.train()
    optimizer.zero_grad()
    z = model(graph_train.x.float(), graph_train.edge_index)

    # Score every edge stored in the training graph. Note that graph_train.edge_index
    # holds both causal and non-causal edges, so all of them are labeled 1 below.
    pos_edge_index = graph_train.edge_index
    pos = (z[pos_edge_index[0]] * z[pos_edge_index[1]]).sum(dim=-1)

    # Use negative_sampling to generate negative edges
    neg_edge_index = negative_sampling(edge_index=pos_edge_index, num_nodes=z.size(0), num_neg_samples=pos_edge_index.size(1))
    neg = (z[neg_edge_index[0]] * z[neg_edge_index[1]]).sum(dim=-1)

    logits = torch.cat([pos, neg], dim=0)
    # Build targets directly on the device instead of from a large Python list
    targets = torch.cat([torch.ones_like(pos), torch.zeros_like(neg)], dim=0)

    loss = focal_loss(logits, targets)  # Focal loss instead of plain BCE
    loss.backward()
    optimizer.step()
    return loss.item()



# Evaluation function
def evaluate(edge_index, graph):
    model.eval()
    with torch.no_grad():
        z = model(graph.x.float(), graph.edge_index)
        pos = torch.sigmoid((z[edge_index[0]] * z[edge_index[1]]).sum(dim=-1)).view(-1)
        neg_edge_index = negative_sampling(edge_index, num_nodes=graph.num_nodes, num_neg_samples=edge_index.size(1))
        neg = torch.sigmoid((z[neg_edge_index[0]] * z[neg_edge_index[1]]).sum(dim=-1)).view(-1)

        preds = np.concatenate([pos.cpu().numpy(), neg.cpu().numpy()])
        true_labels = np.concatenate([np.ones_like(pos.cpu().numpy()), np.zeros_like(neg.cpu().numpy())])

        roc_auc = roc_auc_score(true_labels, preds)
        mrr = compute_mrr(preds, true_labels)
        hits_at_5 = compute_hits_at_k(preds, true_labels, k=5)
        recall = compute_recall(preds, true_labels)
        precision = compute_precision(preds, true_labels)

        return roc_auc, mrr, hits_at_5, recall, precision

def compute_recall(preds, true_labels):
    # Count the number of positive labels
    num_pos = np.sum(true_labels == 1)
    # Rank the predictions
    sorted_preds_idx = np.argsort(preds)[::-1]
    # Consider the top-k predictions to be positive
    pos_preds_binary = np.zeros_like(preds)
    pos_preds_binary[sorted_preds_idx[:num_pos]] = 1
    # Calculate the number of true positives and false negatives
    true_positives = np.sum((pos_preds_binary == 1) & (true_labels == 1))
    false_negatives = np.sum((pos_preds_binary == 0) & (true_labels == 1))
    # Calculate recall
    recall = true_positives / (true_positives + false_negatives)
    return recall

def compute_precision(preds, true_labels):
    # Count the number of positive labels
    num_pos = np.sum(true_labels == 1)
    # Rank the predictions
    sorted_preds_idx = np.argsort(preds)[::-1]
    # Consider the top-k predictions to be positive
    pos_preds_binary = np.zeros_like(preds)
    pos_preds_binary[sorted_preds_idx[:num_pos]] = 1
    # Calculate the number of true positives and false positives
    true_positives = np.sum((pos_preds_binary == 1) & (true_labels == 1))
    false_positives = np.sum((pos_preds_binary == 1) & (true_labels == 0))
    # Calculate precision
    precision = true_positives / (true_positives + false_positives)
    return precision

def compute_mrr(preds, true_labels):
    # Rank all examples (positive and negative) by predicted score, highest first
    sorted_idx = np.argsort(preds)[::-1]
    # Return the reciprocal rank of the highest-ranked true positive
    for i, idx in enumerate(sorted_idx):
        if true_labels[idx] == 1:
            return 1.0 / (i + 1)
    return 0.0

def compute_hits_at_k(preds, true_labels, k=5):
    # Rank all examples (positive and negative) by predicted score, highest first
    sorted_idx = np.argsort(preds)[::-1]
    # Return 1 if any of the top-k ranked examples is a true positive, else 0
    hits = 0
    for idx in sorted_idx[:k]:
        if true_labels[idx] == 1:
            hits = 1
            break
    return hits

max_val_roc_auc = -np.inf
max_val_mrr = -np.inf
max_val_hits5 = -np.inf
max_val_recall = -np.inf
max_val_precision = -np.inf

max_test_roc_auc = -np.inf
max_test_mrr = -np.inf
max_test_hits5 = -np.inf
max_test_recall = -np.inf
max_test_precision = -np.inf

for epoch in range(100):
    loss = train()
    val_roc_auc, val_mrr, val_hits_at_5, val_recall, val_precision = evaluate(graph_val.edge_index, graph_val)
    print(f"Epoch: {epoch + 1}, Loss: {loss:.4f}, Val ROC-AUC: {val_roc_auc:.10f}, Val MRR: {val_mrr:.10f}, Val Hits@5: {val_hits_at_5}, Val Recall: {val_recall:.10f}, Val Precision: {val_precision:.10f}")
    max_val_roc_auc = max(max_val_roc_auc, val_roc_auc)
    max_val_mrr = max(max_val_mrr, val_mrr)
    max_val_hits5 = max(max_val_hits5, val_hits_at_5)
    max_val_recall = max(max_val_recall, val_recall)
    max_val_precision = max(max_val_precision, val_precision)

Epoch: 1, Loss: 126468728.0000, Val ROC-AUC: 0.5000002002, Val MRR: 0.5000000000, Val Hits@5: 1, Val Recall: 0.9999996796, Val Precision: 0.9999996796
Epoch: 2, Loss: 116658720.0000, Val ROC-AUC: 0.5000003204, Val MRR: 0.5000000000, Val Hits@5: 1, Val Recall: 0.9999997597, Val Precision: 0.9999997597
Epoch: 3, Loss: 106547888.0000, Val ROC-AUC: 0.5000002403, Val MRR: 0.5000000000, Val Hits@5: 1, Val Recall: 0.9999997597, Val Precision: 0.9999997597
Epoch: 4, Loss: 97265416.0000, Val ROC-AUC: 0.5000004005, Val MRR: 0.5000000000, Val Hits@5: 1, Val Recall: 0.9999998398, Val Precision: 0.9999998398
Epoch: 5, Loss: 88689480.0000, Val ROC-AUC: 0.5000000000, Val MRR: 0.5000000000, Val Hits@5: 1, Val Recall: 0.9999998398, Val Precision: 0.9999998398
Epoch: 6, Loss: 80992152.0000, Val ROC-AUC: 0.5000017219, Val MRR: 0.5000000000, Val Hits@5: 1, Val Recall: 0.9999987986, Val Precision: 0.9999987986
Epoch: 7, Loss: 73650456.0000, Val ROC-AUC: 0.5000012414, Val MRR: 0.5000000000, Val Hits@5: 1, V

## Evaluate model

In [11]:
val_roc_auc, val_mrr, val_hits5, val_recall, val_precision = evaluate(graph_val.edge_index, graph_val)
test_roc_auc, test_mrr, test_hits5, test_recall, test_precision = evaluate(graph_test.edge_index, graph_test)

max_val_roc_auc = max(max_val_roc_auc, val_roc_auc)
max_val_mrr = max(max_val_mrr, val_mrr)
max_val_hits5 = max(max_val_hits5, val_hits5)
max_val_recall = max(max_val_recall, val_recall)
max_val_precision = max(max_val_precision, val_precision)

max_test_roc_auc = max(max_test_roc_auc, test_roc_auc)
max_test_mrr = max(max_test_mrr, test_mrr)
max_test_hits5 = max(max_test_hits5, test_hits5)
max_test_recall = max(max_test_recall, test_recall)
max_test_precision = max(max_test_precision, test_precision)

print(f"Maximum Validation ROC-AUC: {max_val_roc_auc:.10f}")
print(f"Maximum Validation MRR: {max_val_mrr:.10f}")
print(f"Maximum Validation Hits@5: {max_val_hits5:.10f}")
print(f"Maximum Validation Recall: {max_val_recall:.10f}")
print(f"Maximum Validation Precision: {max_val_precision:.10f}")

print(f"Maximum Test ROC-AUC: {max_test_roc_auc:.10f}")
print(f"Maximum Test MRR: {max_test_mrr:.10f}")
print(f"Maximum Test Hits@5: {max_test_hits5:.10f}")
print(f"Maximum Test Recall: {max_test_recall:.10f}")
print(f"Maximum Test Precision: {max_test_precision:.10f}")

Maximum Validation ROC-AUC: 0.8804117969
Maximum Validation MRR: 1.0000000000
Maximum Validation Hits@5: 1.0000000000
Maximum Validation Recall: 0.9999998398
Maximum Validation Precision: 0.9999998398
Maximum Test ROC-AUC: 0.8804110212
Maximum Test MRR: 1.0000000000
Maximum Test Hits@5: 1.0000000000
Maximum Test Recall: 0.8147133496
Maximum Test Precision: 0.8147133496
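
Recall and precision are identical on every line of the output above. That follows from how `compute_recall` and `compute_precision` threshold the scores: both mark the top `k` predictions as positive with `k` equal to the number of true positives, so both denominators equal `k`. The sketch below demonstrates this on made-up scores; `precision_recall_at_k` is a hypothetical helper that exposes `k` as a parameter.

```python
import numpy as np

def precision_recall_at_k(preds, labels, k):
    """Mark the k highest-scoring examples as positive and compute precision and recall."""
    preds = np.asarray(preds)
    labels = np.asarray(labels)
    order = np.argsort(preds)[::-1]
    top = np.zeros(len(preds), dtype=bool)
    top[order[:k]] = True
    true_positives = np.sum(top & (labels == 1))
    precision = true_positives / k                 # denominator: predicted positives
    recall = true_positives / np.sum(labels == 1)  # denominator: actual positives
    return float(precision), float(recall)

# Made-up scores and labels: three positives among six examples
preds = [0.9, 0.2, 0.8, 0.4, 0.7, 0.1]
labels = [1, 1, 0, 1, 0, 0]

num_pos = sum(labels)
print("k = num_pos:", precision_recall_at_k(preds, labels, k=num_pos))
print("k = 1:      ", precision_recall_at_k(preds, labels, k=1))
```

Reporting the two metrics at a different cutoff, or at a fixed score threshold, would let them diverge and carry separate information.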
