# Grapher

## Data description

### [FinnGen](https://finngen.gitbook.io/documentation/)

- The FinnGen research project is an expedition to the frontier of genomics and medicine, with significant discoveries potentially arising from any one of Finland’s 500,000 biomedical pioneers.
- The project brings together a nation-wide network of Finnish biobanks, with every Finn able to participate in the study by giving biobank consent.
- As of the last update, there were 589,000 samples available, with a goal to reach 520,000 by 2023. The latest data freeze included combined genotype and health registry data from 473,681 individuals.
- The study utilizes samples collected by a nationwide network of Finnish biobanks and combines genome information with digital health care data from national health registries【8†source】.
- There's a need for samples from all over Finland as solutions in the field of personalized healthcare can be found only by looking at large populations. Every Finn can be a part of the FinnGen study by giving a biobank consent.
- The genome data produced during the project is owned by the Finnish biobanks and remains available for research purposes. The medical breakthroughs that arise from the project are expected to benefit health care systems and patients globally.
- The FinnGen research project is collaborative, involving all the same actors as drug development, with the aim to speed up the emergence of new innovations.
- The project's data freeze 9 results and summary statistics are now available, consisting of over 377,200 individuals, almost 20.2 M variants, and 2,272 disease endpoints. Results can be browsed online using the FinnGen web browser, and the summary statistics downloaded.
- The University of Helsinki is the organization responsible for the study, and the nationwide network of Finnish biobanks is participating in the study, thus covering the whole of Finland. The Helsinki Biobank coordinates the sample collection.
- For more information, the project can be contacted at finngen-info@helsinki.fi.

### Dataset

Here's the summary documentation for the DataFrame in bullet format:

- `#chrom`: This column represents the chromosome number where the genetic variant is located.

- `pos`: This is the position of the genetic variant on the chromosome.

- `ref`: This column represents the reference allele (or variant) at the genomic position.

- `alt`: This is the alternate allele observed at this position.

- `rsids`: This stands for reference SNP cluster ID. It's a unique identifier for each variant used in the dbSNP database.

- `nearest_genes`: This column represents the gene which is nearest to the variant.

- `pval`: This represents the p-value, which is a statistical measure for the strength of evidence against the null hypothesis.

- `mlogp`: This represents the minus log of the p-value, commonly used in genomic studies.

- `beta`: The beta coefficient represents the effect size of the variant.

- `sebeta`: This is the standard error of the beta coefficient.

- `af_alt`: This is the allele frequency of the alternate variant in the general population.

- `af_alt_cases`: This is the allele frequency of the alternate variant in the cases group.

- `af_alt_controls`: This is the allele frequency of the alternate variant in the control group.

- `causal`: This binary column indicates whether the variant is determined to be causal (1) or not (0).

- `trait`: This column represents the trait associated with the variant. In this dataset, it is the response to the drug paracetamol and NSAIDs.

## Load libraries

In [1]:
import os
import random


import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.metrics import precision_recall_fscore_support, roc_auc_score, average_precision_score
from sklearn.preprocessing import LabelEncoder, StandardScaler

import torch
import torch.nn.functional as F
import torch_geometric
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv, GATConv
from torch_geometric.utils import to_undirected, negative_sampling

import networkx as nx
from ogb.io import DatasetSaver
from ogb.linkproppred import LinkPropPredDataset

from scipy.spatial import cKDTree

## Perform checks

In [2]:
print(f"PyTorch version: {torch.__version__}")
print(f"PyTorch Geometric version: {torch_geometric.__version__}")

PyTorch version: 2.0.0+cu118
PyTorch Geometric version: 2.3.1


In [3]:
if torch.cuda.is_available():
    device = torch.device("cuda")          # Current CUDA device
    print(f"Using {torch.cuda.get_device_name()} ({device})")
    print(f"CUDA version: {torch.version.cuda}")
    print(f"Number of CUDA devices: {torch.cuda.device_count()}")
else:
    print("CUDA is not available on this device.")

Using NVIDIA GeForce RTX 3060 Ti (cuda)
CUDA version: 11.8
Number of CUDA devices: 1


## Load data and create new rows for each gene

In [4]:
data = pd.read_csv('~/Desktop/geometric-omics/FinnGenn/data/gwas-causal.csv', low_memory=False)

In [5]:
data['nearest_genes'] = data['nearest_genes'].astype(str)

# Assuming your DataFrame is called data and the relevant column is 'nearest_genes'
# First, let's split the gene names in the 'nearest_genes' column
split_genes = data['nearest_genes'].str.split(',')

# Flatten the list of split gene names
flat_genes = [item for sublist in split_genes for item in sublist]

# Then, we create a new DataFrame by repeating rows and substituting the gene names
data_new = (data.loc[data.index.repeat(split_genes.str.len())]
            .assign(nearest_genes=flat_genes))

# Reset index to have a standard index
data = data_new.reset_index(drop=True)

In [6]:
data

Unnamed: 0,#chrom,pos,ref,alt,rsids,nearest_genes,pval,mlogp,beta,sebeta,af_alt,af_alt_cases,af_alt_controls,causal,LD,lead,trait
0,1,13668,G,A,rs2691328,OR4F5,0.944365,0.024860,-0.005926,0.084918,0.005842,0.005729,0.005863,0,0,,T2D
1,1,14773,C,T,rs878915777,OR4F5,0.844305,0.073501,0.010088,0.051369,0.013495,0.013547,0.013485,0,0,,T2D
2,1,15585,G,A,rs533630043,OR4F5,0.841908,0.074735,0.031464,0.157751,0.001113,0.001125,0.001110,0,0,,T2D
3,1,16549,T,C,rs1262014613,OR4F5,0.343308,0.464316,0.241377,0.254711,0.000561,0.000620,0.000550,0,0,,T2D
4,1,16567,G,C,rs1194064194,OR4F5,0.129883,0.886447,0.130736,0.086319,0.004170,0.004250,0.004154,0,0,,T2D
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
20565622,23,155697920,G,A,,,0.027115,1.566790,-0.013475,0.006097,0.290961,0.286054,0.291879,0,0,,T2D
20565623,23,155698443,C,A,,,0.178417,0.748564,-0.069907,0.051951,0.003259,0.003022,0.003304,0,0,,T2D
20565624,23,155698490,C,T,,,0.279640,0.553400,-0.020245,0.018725,0.024406,0.024312,0.024423,0,0,,T2D
20565625,23,155699751,C,T,,,0.078864,1.103120,-0.011284,0.006421,0.244829,0.241257,0.245498,0,0,,T2D


In [7]:
len(data)

20565627

## Spec

**Task Overview**
- The objective is to design and implement a multi-class link prediction model for analyzing relationships between SNP nodes and Phenotype nodes.

**Nodes and Their Features**
- There are two types of nodes: SNP Nodes and Phenotype Nodes.
- *Phenotype Nodes*: Each Phenotype Node represents a particular trait. This information comes from the `trait` column in the data.
- *SNP Nodes*: Each SNP Node is characterized by various features, including `rsids`, `nearest_genes`, `#chrom`, `pos`, `ref`, `alt`, `beta`, `sebeta`, `af_alt`, and `af_alt_cases` columns.

**Edges, Their Features, and Labels**
- Edges represent relationships between nodes. There are two types of edges: SNP-Phenotype and SNP-SNP.
- *SNP-Phenotype Edges*:
  - These edges are undirected, linking SNP Nodes and Phenotype Nodes.
  - The label for each edge is determined by the `causal` column in the data:
    - A label of +1 is assigned when `data['causal']` is 1, indicating a causal relationship.
    - A label of -1 is assigned when `data['causal']` is 0, indicating the absence of a causal relationship.
- *SNP-SNP Edges*:
  - These edges are undirected, linking an SNP Node (as identified by the `rsids` column) to another SNP Node (as identified by the `lead` column) in the same data row.
  - The label for each edge is determined by the `LD` column in the data:
    - A label of +2 is assigned when `data['LD']` is 1, signifying that the two SNPs are in linkage disequilibrium.
    - A label of -2 is assigned when `data['LD']` is 0, indicating that the two SNPs are not in linkage disequilibrium.

## Create graph

In [13]:
%%time

# Create mappings for phenotypes and SNPs to integer indices
phenotypes = data['trait'].unique()
snps = data['rsids'].unique()
phenotype_to_idx = {phenotype: idx for idx, phenotype in enumerate(phenotypes)}
snp_to_idx = {snp: idx + len(phenotypes) for idx, snp in enumerate(snps)}

# Create node feature vectors for phenotypes and SNPs
phenotype_features = data.loc[data['trait'].isin(phenotypes)][['trait']].drop_duplicates().sort_values(by='trait').reset_index(drop=True)
snp_features = data.loc[data['rsids'].isin(snps)][['rsids', 'nearest_genes', '#chrom', 'pos', 'ref', 'alt', 'beta', 'sebeta', 'af_alt', 'af_alt_cases']].drop_duplicates().sort_values(by='rsids').reset_index(drop=True)

# Create node type labels
node_types = torch.tensor([0] * len(phenotypes) + [1] * len(snps), dtype=torch.long)

# SNP-Phenotype edge creation based on 'causal' column
edges = data[['rsids', 'trait', 'causal']].drop_duplicates()

edges['snp_idx'] = edges['rsids'].map(snp_to_idx)
edges['phenotype_idx'] = edges['trait'].map(phenotype_to_idx)

# Create positive and negative edges for SNP-Phenotype
positive_edges_snp_phenotype = edges.loc[edges['causal'] == 1, ['snp_idx', 'phenotype_idx']].values
negative_edges_snp_phenotype = edges.loc[edges['causal'] == 0, ['snp_idx', 'phenotype_idx']].values

positive_edges_snp_phenotype = torch.tensor(positive_edges_snp_phenotype, dtype=torch.long).t().contiguous()
negative_edges_snp_phenotype = torch.tensor(negative_edges_snp_phenotype, dtype=torch.long).t().contiguous()

# SNP-SNP edge creation based on 'LD' column
snp_snp_edges = data[['rsids', 'lead', 'LD']].dropna().drop_duplicates()

snp_snp_edges['snp_idx'] = snp_snp_edges['rsids'].map(snp_to_idx)
snp_snp_edges['lead_snp_idx'] = snp_snp_edges['lead'].map(snp_to_idx)

# Create positive and negative edges for SNP-SNP
positive_edges_snp_snp = torch.tensor(positive_edges_snp_snp, dtype=torch.long).t().contiguous()

# Create a set of all SNP-SNP edges (both positive and negative)
all_snp_snp_edges = set(map(tuple, np.concatenate((positive_edges_snp_snp, negative_edges_snp_snp), axis=1)))

# Create a set of all SNPs
all_snps = set(range(len(phenotypes), len(phenotypes) + len(snps)))

# Initialize a list for negative SNP-SNP edges
negative_edges_snp_snp = []

# Create 1000 negative SNP-SNP edges for every positive edge
for _ in range(1000 * positive_edges_snp_snp.shape[1]):
    while True:
        # Randomly select two SNPs
        snp1, snp2 = random.sample(all_snps, 2)

        # If the pair does not exist in the set of all SNP-SNP edges, add it to the list of negative edges
        if (snp1, snp2) not in all_snp_snp_edges and (snp2, snp1) not in all_snp_snp_edges:
            negative_edges_snp_snp.append((snp1, snp2))
            break

negative_edges_snp_snp = torch.tensor(negative_edges_snp_snp, dtype=torch.long).t().contiguous()

# Combine SNP-Phenotype and SNP-SNP edges
edges = torch.cat([positive_edges_snp_phenotype, negative_edges_snp_phenotype, positive_edges_snp_snp, negative_edges_snp_snp], dim=1)

# Create edge attributes
edge_attr = torch.cat([torch.ones(positive_edges_snp_phenotype.size(1), dtype=torch.float),
                       -1 * torch.ones(negative_edges_snp_phenotype.size(1), dtype=torch.float),
                       2 * torch.ones(positive_edges_snp_snp.size(1), dtype=torch.float),
                       -2 * torch.ones(negative_edges_snp_snp.size(1), dtype=torch.float)])

# Combine the feature vectors
combined_features = pd.concat([phenotype_features, snp_features], ignore_index=True).drop(['trait', 'rsids'], axis=1)

# Now you can fill NaNs with 'N/A' or 0 for numerical columns
nan_replacements = {'nearest_genes': 'N/A', '#chrom': 'N/A', 'pos': 0, 'ref': 'N/A', 'alt': 'N/A', 'beta': 0, 'sebeta': 0, 'af_alt': 0, 'af_alt_cases': 0}
for col, replacement in nan_replacements.items():
    if col in combined_features:
        combined_features[col].fillna(replacement, inplace=True)

# Label encoding for categorical columns
le = LabelEncoder()
categorical_columns = ['nearest_genes', '#chrom', 'ref', 'alt']
for col in categorical_columns:
    combined_features[col] = le.fit_transform(combined_features[col].astype(str))

# Standardize numerical features
scaler = StandardScaler()
numerical_columns = ['pos', 'beta', 'sebeta', 'af_alt', 'af_alt_cases']
for col in numerical_columns:
    combined_features[col] = scaler.fit_transform(combined_features[[col]])

features = torch.tensor(combined_features.values, dtype=torch.float)

# Create the PyTorch Geometric graph
graph = Data(x=features, edge_index=edges, edge_attr=edge_attr)
graph.node_types = node_types



ValueError: all the input arrays must have same number of dimensions, but the array at index 0 has 2 dimension(s) and the array at index 1 has 1 dimension(s)

## Graph stats

In [14]:
from torch_geometric.utils import degree

def print_graph_stats(graph, positive_edges_snp_phenotype, negative_edges_snp_phenotype, positive_edges_snp_snp, negative_edges_snp_snp):
    node_types = np.unique(graph.node_types.numpy(), return_counts=True)

    print(f"Number of nodes: {graph.num_nodes}")
    for node_type, count in zip(*node_types):
        print(f"Number of {node_type} nodes: {count}")
    print(f"Number of positive edges between SNPs and phenotypes: {positive_edges_snp_phenotype.size(1)}")
    print(f"Number of negative edges between SNPs and phenotypes: {negative_edges_snp_phenotype.size(1)}")
    print(f"Number of positive edges between SNPs: {positive_edges_snp_snp.size(1)}")
    print(f"Number of negative edges between SNPs: {negative_edges_snp_snp.size(1)}")
    print(f"Number of edges: {graph.num_edges}")
    print(f"Node feature dimension: {graph.num_node_features}")

    # Compute and print degree-related stats for each node type
    for node_type in node_types[0]:
        node_indices = np.where(graph.node_types.numpy() == node_type)[0]
        degrees = degree(graph.edge_index[0], num_nodes=graph.num_nodes)[node_indices]
        average_degree = degrees.float().mean().item()
        median_degree = np.median(degrees.numpy())
        std_degree = degrees.float().std().item()

        print(f"\n{node_type} node stats:")
        print(f"Average degree: {average_degree:.2f}")
        print(f"Median degree: {median_degree:.2f}")
        print(f"Standard deviation of degree: {std_degree:.2f}")

    # Density is the ratio of actual edges to the maximum number of possible edges
    num_possible_edges = graph.num_nodes * (graph.num_nodes - 1) / 2
    density = graph.num_edges / num_possible_edges
    print(f"Density: {density:.10f}")

    # Check for NaN values in features
    nan_in_features = torch.isnan(graph.x).any().item()
    print(f"Are there any NaN values in features? {nan_in_features}")

# Print
print("Graph stats:")
print_graph_stats(graph, positive_edges_snp_phenotype, negative_edges_snp_phenotype, positive_edges_snp_snp, negative_edges_snp_snp)

Graph stats:
Number of nodes: 20565628
Number of 0 nodes: 1
Number of 1 nodes: 18709437
Number of positive edges between SNPs and phenotypes: 37
Number of negative edges between SNPs and phenotypes: 18709400
Number of positive edges between SNPs: 3934


AttributeError: 'list' object has no attribute 'size'

## Data splitting

- **Importing Modules**

  The script begins by importing the necessary Python libraries. It uses `random` for shuffling data, `torch` for handling tensors, and `torch_geometric.data` for its `Data` class, which is used to represent graph data.

- **Constants**

  The ratios for splitting the data into training, validation, and testing sets are defined. Both positive and negative edges are split equally into three parts.

- **Calculating Sample Size for Each Split**

  The number of samples for each set (training, validation, testing) are calculated separately for SNP-Gene and Gene-Trait edges. This is done for both positive and negative edges.

- **Shuffling the Edges**

  The edges for both positive and negative SNP-Gene and Gene-Trait data are shuffled. This ensures that the training, validation, and testing sets get a fair representation of the entire dataset.

- **Splitting the Edges**

  Both SNP-Gene and Gene-Trait edges are split according to the previously calculated sample sizes. The split is performed separately for both positive and negative edges.

- **Combining SNP-Gene and Gene-Trait Edges**

  After the split, the SNP-Gene and Gene-Trait edges are combined back together to form the final sets of edges for the training, validation, and testing sets.

- **Converting Edges back to Tensors**

  The lists of edges are then converted back into torch tensors. This conversion prepares the data for future operations with PyTorch's machine learning functionalities.

- **Creating Graphs**

  Graphs for training, validation, and testing sets are created. These graphs are instances of the `Data` class from the `torch_geometric.data` module. The graphs contain node features, edge indices, and edge attributes.

- **Setting Node Types**

  The node types for each graph are set. The node type information can be used for tasks such as node classification.

- **Printing the Graphs**

  Finally, the script prints out the graphs for the training, validation, and testing sets. This helps to ensure that the data has been correctly processed and is ready for the machine learning task.

# TO DO

In [None]:
# Constants
pos_train_ratio = 1/6
pos_val_ratio = 2/6
pos_test_ratio = 3/6

neg_train_ratio = 1/6
neg_val_ratio = 2/6
neg_test_ratio = 3/6

# Calculate the number of samples for each split
num_positive_snp_phenotype_train = int(positive_edges_snp_phenotype.size(1) * pos_train_ratio)
num_positive_snp_phenotype_val = int(positive_edges_snp_phenotype.size(1) * pos_val_ratio)
num_positive_snp_phenotype_test = positive_edges_snp_phenotype.size(1) - num_positive_snp_phenotype_train - num_positive_snp_phenotype_val

num_negative_snp_phenotype_train = int(negative_edges_snp_phenotype.size(1) * neg_train_ratio)
num_negative_snp_phenotype_val = int(negative_edges_snp_phenotype.size(1) * neg_val_ratio)
num_negative_snp_phenotype_test = negative_edges_snp_phenotype.size(1) - num_negative_snp_phenotype_train - num_negative_snp_phenotype_val

# Similar calculations for SNP-SNP edges
num_positive_snp_snp_train = int(positive_edges_snp_snp.size(1) * pos_train_ratio)
num_positive_snp_snp_val = int(positive_edges_snp_snp.size(1) * pos_val_ratio)
num_positive_snp_snp_test = positive_edges_snp_snp.size(1) - num_positive_snp_snp_train - num_positive_snp_snp_val

num_negative_snp_snp_train = int(negative_edges_snp_snp.size(1) * neg_train_ratio)
num_negative_snp_snp_val = int(negative_edges_snp_snp.size(1) * neg_val_ratio)
num_negative_snp_snp_test = negative_edges_snp_snp.size(1) - num_negative_snp_snp_train - num_negative_snp_snp_val

# Shuffle the positive and negative edges
positive_edges_snp_phenotype = positive_edges_snp_phenotype.t().tolist()
random.shuffle(positive_edges_snp_phenotype)

negative_edges_snp_phenotype = negative_edges_snp_phenotype.t().tolist()
random.shuffle(negative_edges_snp_phenotype)

positive_edges_snp_snp = positive_edges_snp_snp.t().tolist()
random.shuffle(positive_edges_snp_snp)

negative_edges_snp_snp = negative_edges_snp_snp.t().tolist()
random.shuffle(negative_edges_snp_snp)

# Splitting code for SNP-Phenotype positive and negative edges remains the same.

# Split SNP-SNP positive edges
positive_snp_snp_train_edges = positive_edges_snp_snp[:num_positive_snp_snp_train]
positive_snp_snp_val_edges = positive_edges_snp_snp[num_positive_snp_snp_train:num_positive_snp_snp_train + num_positive_snp_snp_val]
positive_snp_snp_test_edges = positive_edges_snp_snp[num_positive_snp_snp_train + num_positive_snp_snp_val:]

# Split SNP-SNP negative edges
negative_snp_snp_train_edges = negative_edges_snp_snp[:num_negative_snp_snp_train]
negative_snp_snp_val_edges = negative_edges_snp_snp[num_negative_snp_snp_train:num_negative_snp_snp_train + num_negative_snp_snp_val]
negative_snp_snp_test_edges = negative_edges_snp_snp[num_negative_snp_snp_train + num_negative_snp_snp_val:]

# Convert edges back to tensors
positive_train_edges = torch.tensor(positive_snp_phenotype_train_edges + positive_snp_snp_train_edges, dtype=torch.long).t().contiguous()
positive_val_edges = torch.tensor(positive_snp_phenotype_val_edges + positive_snp_snp_val_edges, dtype=torch.long).t().contiguous()
positive_test_edges = torch.tensor(positive_snp_phenotype_test_edges + positive_snp_snp_test_edges, dtype=torch.long).t().contiguous()

negative_train_edges = torch.tensor(negative_snp_phenotype_train_edges + negative_snp_snp_train_edges, dtype=torch.long).t().contiguous()
negative_val_edges = torch.tensor(negative_snp_phenotype_val_edges + negative_snp_snp_val_edges, dtype=torch.long).t().contiguous()
negative_test_edges = torch.tensor(negative_snp_phenotype_test_edges + negative_snp_snp_test_edges, dtype=torch.long).t().contiguous()

# Create train, validation, and test graphs
graph_train = Data(x=features, edge_index=torch.cat([positive_train_edges, negative_train_edges], dim=1), edge_attr=edge_attr)
graph_val = Data(x=features, edge_index=torch.cat([positive_val_edges, negative_val_edges], dim=1), edge_attr=edge_attr)
graph_test = Data(x=features, edge_index=torch.cat([positive_test_edges, negative_test_edges], dim=1), edge_attr=edge_attr)

# Set node types for train, validation, and test graphs
graph_train.node_types = node_types
graph_val.node_types = node_types
graph_test.node_types = node_types

# Print the graphs
print("Graph Train:")
print(graph_train)
print("\nGraph Validation:")
print(graph_val)
print("\nGraph Test:")
print(graph_test)

## Models

### Define helpers

In [None]:
from torch_geometric.utils import negative_sampling

def compute_recall(preds, true_labels):
    # Count the number of positive labels
    num_pos = np.sum(true_labels == 1)
    # Rank the predictions
    sorted_preds_idx = np.argsort(preds)[::-1]
    # Consider the top-k predictions to be positive
    pos_preds_binary = np.zeros_like(preds)
    pos_preds_binary[sorted_preds_idx[:num_pos]] = 1
    # Calculate the number of true positives and false negatives
    true_positives = np.sum((pos_preds_binary == 1) & (true_labels == 1))
    false_negatives = np.sum((pos_preds_binary == 0) & (true_labels == 1))
    # Calculate recall
    recall = true_positives / (true_positives + false_negatives)
    return recall

def compute_precision(preds, true_labels):
    # Count the number of positive labels
    num_pos = np.sum(true_labels == 1)
    # Rank the predictions
    sorted_preds_idx = np.argsort(preds)[::-1]
    # Consider the top-k predictions to be positive
    pos_preds_binary = np.zeros_like(preds)
    pos_preds_binary[sorted_preds_idx[:num_pos]] = 1
    # Calculate the number of true positives and false positives
    true_positives = np.sum((pos_preds_binary == 1) & (true_labels == 1))
    false_positives = np.sum((pos_preds_binary == 1) & (true_labels == 0))
    # Calculate precision
    precision = true_positives / (true_positives + false_positives)
    return precision

def compute_mrr(preds, true_labels):
    # Find the predicted scores for positive examples
    pos_preds = preds[:len(true_labels)]
    # Rank the positive examples by predicted score in descending order
    sorted_idx = np.argsort(pos_preds)[::-1]
    # Find the rank of the first true positive
    for i, idx in enumerate(sorted_idx):
        if true_labels[idx] == 1:
            return 1.0 / (i + 1)
    return 0.0

def compute_hits_at_k(preds, true_labels, k=1):
    # Find the predicted scores for positive examples
    pos_preds = preds[:len(true_labels)]
    # Rank the positive examples by predicted score in descending order
    sorted_idx = np.argsort(pos_preds)[::-1]
    # Check if the first k predictions contain at least one true positive
    hits = 0
    for idx in sorted_idx[:k]:
        if true_labels[idx] == 1:
            hits = 1
            break
    return hits

import torch.nn as nn

class FocalLoss(nn.Module):
    def __init__(self, alpha=0.9, gamma=2, reduction='mean'):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs, targets):
        BCE_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')
        pt = torch.exp(-BCE_loss)
        F_loss = self.alpha * (1-pt)**self.gamma * BCE_loss

        if self.reduction == 'mean':
            return torch.mean(F_loss)
        elif self.reduction == 'sum':
            return torch.sum(F_loss)
        else:
            return F_loss

### Logistic Regression

In [None]:
# Define the Logistic Regression model
class LogReg(torch.nn.Module):
    def __init__(self, input_dim):
        super(LogReg, self).__init__()
        self.linear = torch.nn.Linear(input_dim, 1)

    def forward(self, x):
        out = self.linear(x)
        return torch.sigmoid(out)

# Train and evaluate the model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = LogReg(input_dim=9).to(device)  # 9-dimensional edge feature vectors

# I'm assuming you have these datasets ready
graph_train = graph_train.to(device)
graph_val = graph_val.to(device)
graph_test = graph_test.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.1, weight_decay=5e-4)

def train():
    model.train()
    optimizer.zero_grad()
    z = model(graph_train.x.float())

    pos_edge_index = graph_train.edge_index
    neg_edge_index = negative_sampling(edge_index=pos_edge_index, num_nodes=z.size(0), num_neg_samples=pos_edge_index.size(1))

    pos_logits = (z[pos_edge_index[0]] * z[pos_edge_index[1]]).sum(dim=-1)
    neg_logits = (z[neg_edge_index[0]] * z[neg_edge_index[1]]).sum(dim=-1)

    logits = torch.cat([pos_logits, neg_logits], dim=0)
    targets = torch.tensor([1] * pos_edge_index.size(1) + [0] * neg_edge_index.size(1), dtype=torch.float32).to(device)

    loss = F.binary_cross_entropy_with_logits(logits, targets)
    loss.backward()
    optimizer.step()
    return loss.item()

# Evaluation function without model
def evaluate(edge_index, graph):
    model.eval()
    with torch.no_grad():
        z = model(graph.x.float())
        pos = torch.sigmoid((z[edge_index[0]] * z[edge_index[1]]).sum(dim=-1)).view(-1)
        neg_edge_index = negative_sampling(edge_index, num_nodes=graph.num_nodes, num_neg_samples=edge_index.size(1))
        neg = torch.sigmoid((z[neg_edge_index[0]] * z[neg_edge_index[1]]).sum(dim=-1)).view(-1)

        preds = np.concatenate([pos.cpu().numpy(), neg.cpu().numpy()])
        true_labels = np.concatenate([np.ones_like(pos.cpu().numpy()), np.zeros_like(neg.cpu().numpy())])

        # Compute Accuracy for positive edges
        pos_preds = preds[:len(pos)]  # positive predictions
        pos_labels = true_labels[:len(pos)]  # actual positive labels
        pos_accuracy = np.mean((pos_preds > 0.9) == pos_labels)

        # Other metrics
        roc_auc = roc_auc_score(true_labels, preds)
        mrr = compute_mrr(preds, true_labels)
        hits_at_5 = compute_hits_at_k(preds, true_labels, k=5)
        recall = compute_recall(preds, true_labels)
        precision = compute_precision(preds, true_labels)

        return roc_auc, mrr, hits_at_5, recall, precision, pos_accuracy

max_val_roc_auc = -np.inf
max_val_mrr = -np.inf
max_val_hits1 = -np.inf
max_val_recall = -np.inf
max_val_precision = -np.inf
max_val_pos_accuracy = -np.inf

max_test_roc_auc = -np.inf
max_test_mrr = -np.inf
max_test_hits1 = -np.inf
max_test_recall = -np.inf
max_test_precision = -np.inf
max_test_pos_accuracy = -np.inf

# Assuming the evaluate function is properly defined somewhere else
for epoch in range(25):
    loss = train()
    val_roc_auc, val_mrr, val_hits_at_5, val_recall, val_precision, val_pos_accuracy = evaluate(graph_val.edge_index, graph_val)
    print(f"Epoch: {epoch + 1}, Loss: {loss:.4f}, Val ROC-AUC: {val_roc_auc:.10f}, Val MRR: {val_mrr:.10f}, Val hits@1: {val_hits_at_5}, Val Recall: {val_recall:.10f}, Val Precision: {val_precision:.10f}, Val Pos Accuracy: {val_pos_accuracy:.10f}")
    max_val_roc_auc = max(max_val_roc_auc, val_roc_auc)
    max_val_mrr = max(max_val_mrr, val_mrr)
    max_val_hits1 = max(max_val_hits1, val_hits_at_5)
    max_val_recall = max(max_val_recall, val_recall)
    max_val_precision = max(max_val_precision, val_precision)
    max_val_pos_accuracy = max(max_val_pos_accuracy, val_pos_accuracy)


In [None]:
# For each epoch
val_roc_auc, val_mrr, val_hits1, val_recall, val_precision, val_pos_accuracy = evaluate(graph_val.edge_index, graph_val)
test_roc_auc, test_mrr, test_hits1, test_recall, test_precision, test_pos_accuracy = evaluate(graph_test.edge_index, graph_test)

max_val_roc_auc = max(max_val_roc_auc, val_roc_auc)
max_val_mrr = max(max_val_mrr, val_mrr)
max_val_hits1 = max(max_val_hits1, val_hits1)
max_val_recall = max(max_val_recall, val_recall)
max_val_precision = max(max_val_precision, val_precision)
max_val_pos_accuracy = max(max_val_pos_accuracy, val_pos_accuracy) # Add this line

max_test_roc_auc = max(max_test_roc_auc, test_roc_auc)
max_test_mrr = max(max_test_mrr, test_mrr)
max_test_hits1 = max(max_test_hits1, test_hits1)
max_test_recall = max(max_test_recall, test_recall)
max_test_precision = max(max_test_precision, test_precision)
max_test_pos_accuracy = max(max_test_pos_accuracy, test_pos_accuracy) # Add this line

# Print the maximum scores for each metric
print(f"Maximum Validation ROC-AUC: {max_val_roc_auc:.10f}")
print(f"Maximum Validation MRR: {max_val_mrr:.10f}")
print(f"Maximum Validation hits@1: {max_val_hits1:.10f}")
print(f"Maximum Validation Recall: {max_val_recall:.10f}")
print(f"Maximum Validation Precision: {max_val_precision:.10f}")
print(f"Maximum Validation Positive Edge Accuracy: {max_val_pos_accuracy:.10f}") # Add this line

print(f"Maximum Test ROC-AUC: {max_test_roc_auc:.10f}")
print(f"Maximum Test MRR: {max_test_mrr:.10f}")
print(f"Maximum Test hits@1: {max_test_hits1:.10f}")
print(f"Maximum Test Recall: {max_test_recall:.10f}")
print(f"Maximum Test Precision: {max_test_precision:.10f}")
print(f"Maximum Test Positive Edge Accuracy: {max_test_pos_accuracy:.10f}") # Add this line

### GCN

- The GCN (Graph Convolutional Network) model used in this script is a simple 2-layer GCN. It transforms the original 9-dimensional node feature vectors into 2-dimensional hidden representations, using the adjacency matrix (encoded by `edge_index`) to propagate information across the graph.
- The model is trained using Focal Loss, which is designed to address class imbalance problems. The training function computes the Focal Loss between the model's predictions on positive and negative edge examples, and the true edge labels. Negative edges are generated using a negative sampling method.
- During evaluation, the model's embeddings are used to predict whether an edge exists between each pair of nodes, and these predictions are compared to the actual edges in the validation or test graph. Several evaluation metrics are computed, including ROC AUC, Mean Reciprocal Rank (MRR), hits@1, Recall, and Precision.
- The training process iterates for 100 epochs. In each epoch, the model parameters are updated to minimize the Focal Loss on the training data, and the model's performance is evaluated on the validation data. The best validation scores on the ROC AUC, MRR, hits@1, Recall, and Precision metrics are tracked throughout training.
- After training, the model can be used to predict whether causal edges exist between nodes in a graph. This makes it suitable for tasks like link prediction in biological networks, where the nodes represent entities like genes or phenotypes and the edges represent relationships between them.

In [None]:
# Task: Link prediction: does a causal edge exist between two nodes?
# Node Types: 0 = phenotypes, 1 = snps
# Node Feature Vector: 10-dimensional

torch.cuda.empty_cache()

# Define the GCN model
class GCN(torch.nn.Module):
    def __init__(self, hidden_channels):
        super(GCN, self).__init__()
        self.conv1 = GCNConv(9, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.conv2(x, edge_index)
        return x

# Train and evaluate the model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = GCN(hidden_channels=2).to(device)

graph_train = graph_train.to(device)
graph_val = graph_val.to(device)
graph_test = graph_test.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)

# Instantiate the loss function
focal_loss = FocalLoss(alpha=0.9, gamma=2.0).to(device)

def train():
    model.train()
    optimizer.zero_grad()
    z = model(graph_train.x.float(), graph_train.edge_index)

    # Only consider positive edges for the positive score calculation
    pos_edge_index = graph_train.edge_index
    pos = (z[pos_edge_index[0]] * z[pos_edge_index[1]]).sum(dim=-1)

    # Use negative_sampling to generate negative edges
    neg_edge_index = negative_sampling(edge_index=pos_edge_index, num_nodes=z.size(0), num_neg_samples=pos_edge_index.size(1))
    neg = (z[neg_edge_index[0]] * z[neg_edge_index[1]]).sum(dim=-1)

    logits = torch.cat([pos, neg], dim=0)
    targets = torch.tensor([1] * pos.size(0) + [0] * neg.size(0), dtype=torch.float32).to(device)

    loss = focal_loss(logits, targets) # replace BCE with focal loss
    loss.backward()
    optimizer.step()
    return loss.item()

# Evaluation function
def evaluate(edge_index, graph):
    model.eval()
    with torch.no_grad():
        z = model(graph.x.float(), graph.edge_index)
        pos = torch.sigmoid((z[edge_index[0]] * z[edge_index[1]]).sum(dim=-1)).view(-1)
        neg_edge_index = negative_sampling(edge_index, num_nodes=graph.num_nodes, num_neg_samples=edge_index.size(1))
        neg = torch.sigmoid((z[neg_edge_index[0]] * z[neg_edge_index[1]]).sum(dim=-1)).view(-1)

        preds = np.concatenate([pos.cpu().numpy(), neg.cpu().numpy()])
        true_labels = np.concatenate([np.ones_like(pos.cpu().numpy()), np.zeros_like(neg.cpu().numpy())])

        # Compute Accuracy for positive edges
        pos_preds = preds[:len(pos)]  # positive predictions
        pos_labels = true_labels[:len(pos)]  # actual positive labels
        pos_accuracy = np.mean((pos_preds > 0.9) == pos_labels)

        # Other metrics
        roc_auc = roc_auc_score(true_labels, preds)
        mrr = compute_mrr(preds, true_labels)
        hits_at_5 = compute_hits_at_k(preds, true_labels, k=5)
        recall = compute_recall(preds, true_labels)
        precision = compute_precision(preds, true_labels)

        return roc_auc, mrr, hits_at_5, recall, precision, pos_accuracy


max_val_roc_auc = -np.inf
max_val_mrr = -np.inf
max_val_hits1 = -np.inf
max_val_recall = -np.inf
max_val_precision = -np.inf
max_val_pos_accuracy = -np.inf

max_test_roc_auc = -np.inf
max_test_mrr = -np.inf
max_test_hits1 = -np.inf
max_test_recall = -np.inf
max_test_precision = -np.inf
max_test_pos_accuracy = -np.inf

for epoch in range(25):
    loss = train()
    val_roc_auc, val_mrr, val_hits_at_5, val_recall, val_precision, val_pos_accuracy = evaluate(graph_val.edge_index, graph_val)
    print(f"Epoch: {epoch + 1}, Loss: {loss:.4f}, Val ROC-AUC: {val_roc_auc:.10f}, Val MRR: {val_mrr:.10f}, Val hits@1: {val_hits_at_5}, Val Recall: {val_recall:.10f}, Val Precision: {val_precision:.10f}, Val Pos Accuracy: {val_pos_accuracy:.10f}")
    max_val_roc_auc = max(max_val_roc_auc, val_roc_auc)
    max_val_mrr = max(max_val_mrr, val_mrr)
    max_val_hits1 = max(max_val_hits1, val_hits_at_5)
    max_val_recall = max(max_val_recall, val_recall)
    max_val_precision = max(max_val_precision, val_precision)
    max_val_pos_accuracy = max(max_val_pos_accuracy, val_pos_accuracy)

In [None]:
# For each epoch
val_roc_auc, val_mrr, val_hits1, val_recall, val_precision, val_pos_accuracy = evaluate(graph_val.edge_index, graph_val)
test_roc_auc, test_mrr, test_hits1, test_recall, test_precision, test_pos_accuracy = evaluate(graph_test.edge_index, graph_test)

max_val_roc_auc = max(max_val_roc_auc, val_roc_auc)
max_val_mrr = max(max_val_mrr, val_mrr)
max_val_hits1 = max(max_val_hits1, val_hits1)
max_val_recall = max(max_val_recall, val_recall)
max_val_precision = max(max_val_precision, val_precision)
max_val_pos_accuracy = max(max_val_pos_accuracy, val_pos_accuracy) # Add this line

max_test_roc_auc = max(max_test_roc_auc, test_roc_auc)
max_test_mrr = max(max_test_mrr, test_mrr)
max_test_hits1 = max(max_test_hits1, test_hits1)
max_test_recall = max(max_test_recall, test_recall)
max_test_precision = max(max_test_precision, test_precision)
max_test_pos_accuracy = max(max_test_pos_accuracy, test_pos_accuracy) # Add this line

# Print the maximum scores for each metric
print(f"Maximum Validation ROC-AUC: {max_val_roc_auc:.10f}")
print(f"Maximum Validation MRR: {max_val_mrr:.10f}")
print(f"Maximum Validation hits@1: {max_val_hits1:.10f}")
print(f"Maximum Validation Recall: {max_val_recall:.10f}")
print(f"Maximum Validation Precision: {max_val_precision:.10f}")
print(f"Maximum Validation Positive Edge Accuracy: {max_val_pos_accuracy:.10f}") # Add this line

print(f"Maximum Test ROC-AUC: {max_test_roc_auc:.10f}")
print(f"Maximum Test MRR: {max_test_mrr:.10f}")
print(f"Maximum Test hits@1: {max_test_hits1:.10f}")
print(f"Maximum Test Recall: {max_test_recall:.10f}")
print(f"Maximum Test Precision: {max_test_precision:.10f}")
print(f"Maximum Test Positive Edge Accuracy: {max_test_pos_accuracy:.10f}") # Add this line

In [None]:
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f'The model has {count_parameters(model)} parameters')