# Grapher

## Data description

### [FinnGen](https://finngen.gitbook.io/documentation/)

Any large Canadian GWAS-related clinical trials?

- The FinnGen research project is an expedition to the frontier of genomics and medicine, with significant discoveries potentially arising from any one of Finland’s 500,000 biomedical pioneers.
- The project brings together a nation-wide network of Finnish biobanks, with every Finn able to participate in the study by giving biobank consent.
- As of the last update, 589,000 samples were available, exceeding the goal of 520,000 by 2023. The latest data freeze included combined genotype and health registry data from 473,681 individuals.
- The study utilizes samples collected by a nationwide network of Finnish biobanks and combines genome information with digital health care data from national health registries.
- Samples are needed from all over Finland, because solutions in personalized healthcare can be found only by looking at large populations.
- The genome data produced during the project is owned by the Finnish biobanks and remains available for research purposes. The medical breakthroughs that arise from the project are expected to benefit health care systems and patients globally.
- The FinnGen research project is collaborative, involving all the same actors as drug development, with the aim to speed up the emergence of new innovations.
- The project's data freeze 9 results and summary statistics are now available, consisting of over 377,200 individuals, almost 20.2 M variants, and 2,272 disease endpoints. Results can be browsed online using the FinnGen web browser, and the summary statistics downloaded.
- The University of Helsinki is the organization responsible for the study, and the nationwide network of Finnish biobanks is participating in the study, thus covering the whole of Finland. The Helsinki Biobank coordinates the sample collection.
- For more information, the project can be contacted at finngen-info@helsinki.fi.

### Dataset

The DataFrame used in this notebook has the following columns:

- `#chrom`: This column represents the chromosome number where the genetic variant is located.

- `pos`: This is the position of the genetic variant on the chromosome.

- `ref`: This column represents the reference allele (or variant) at the genomic position.

- `alt`: This is the alternate allele observed at this position.

- `rsids`: This stands for reference SNP cluster ID. It's a unique identifier for each variant used in the dbSNP database.

- `nearest_genes`: The gene or genes nearest to the variant. Multiple genes are stored as a comma-separated string.

- `pval`: This represents the p-value, which is a statistical measure for the strength of evidence against the null hypothesis.

- `mlogp`: The negative base-10 logarithm of the p-value, i.e. -log10(`pval`), commonly used in genomic studies.

- `beta`: The beta coefficient represents the effect size of the variant.

- `sebeta`: This is the standard error of the beta coefficient.

- `af_alt`: This is the allele frequency of the alternate variant across the full study sample (cases and controls combined).

- `af_alt_cases`: This is the allele frequency of the alternate variant in the cases group.

- `af_alt_controls`: This is the allele frequency of the alternate variant in the control group.

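- `causal`: This binary column indicates whether the variant is determined to be causal (1) or not (0).

- `LD`: This binary column indicates whether the variant is in linkage disequilibrium with the SNP named in `lead` (1) or not (0).

- `lead`: The rsID of the lead SNP paired with this variant. It is missing when the variant has no lead SNP.

- `trait`: This column represents the trait associated with the variant. In the file loaded in this notebook, the only value is `T2D`.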
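As a quick check on the `pval`, `mlogp`, `beta`, and `sebeta` definitions, the sketch below recomputes `mlogp` from `pval` for the first two rows of the table printed later in this notebook. The z-score at the end is an illustrative extra and is not a column in the dataset.

```python
import numpy as np

# Values copied from the first two rows of the loaded table (rs2691328, rs878915777)
pval = np.array([0.944365, 0.844305])
mlogp = np.array([0.024860, 0.073501])
beta = np.array([-0.005926, 0.010088])
sebeta = np.array([0.084918, 0.051369])

# mlogp is the negative base-10 logarithm of the p-value
recomputed = -np.log10(pval)
print(np.allclose(recomputed, mlogp, atol=1e-4))

# Wald z-score: effect size divided by its standard error (not a dataset column)
z = beta / sebeta
print(z.round(3))
```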

## Load libraries

In [1]:
import os
import random


import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.metrics import precision_recall_fscore_support, roc_auc_score, average_precision_score
from sklearn.preprocessing import LabelEncoder, StandardScaler

import torch
import torch.nn.functional as F
import torch_geometric
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv, GATConv
from torch_geometric.utils import to_undirected, negative_sampling

import networkx as nx
from ogb.io import DatasetSaver
from ogb.linkproppred import LinkPropPredDataset

from scipy.spatial import cKDTree

## Perform checks

In [2]:
print(f"PyTorch version: {torch.__version__}")
print(f"PyTorch Geometric version: {torch_geometric.__version__}")

PyTorch version: 2.0.0+cu118
PyTorch Geometric version: 2.3.1


In [3]:
if torch.cuda.is_available():
    device = torch.device("cuda")          # Current CUDA device
    print(f"Using {torch.cuda.get_device_name()} ({device})")
    print(f"CUDA version: {torch.version.cuda}")
    print(f"Number of CUDA devices: {torch.cuda.device_count()}")
else:
    device = torch.device("cpu")           # Fall back to CPU so `device` is always defined
    print("CUDA is not available on this device; using CPU.")

Using NVIDIA GeForce RTX 3060 Ti (cuda)
CUDA version: 11.8
Number of CUDA devices: 1


## Load data and create new rows for each gene

In [4]:
dtypes={
    '#chrom': 'int64',
    'pos': 'int64',
    'ref': 'string',
    'alt': 'string',
    'rsids': 'string',
    'nearest_genes': 'string',
    'pval': 'float64',
    'mlogp': 'float64',
    'beta': 'float64',
    'sebeta': 'float64',
    'af_alt': 'float64',
    'af_alt_cases': 'float64',
    'af_alt_controls': 'float64',
    'causal':'int64',
    'LD': 'int64',
    'lead': 'string',
    'trait': 'string'
}
data = pd.read_csv('~/Desktop/gwas-graph/FinnGen/data/gwas-causal.csv', dtype=dtypes)

In [5]:
data['nearest_genes'] = data['nearest_genes'].astype(str)

# Split the comma-separated gene names in 'nearest_genes' into one list per row
split_genes = data['nearest_genes'].str.split(',')

# Flatten the list of split gene names
flat_genes = [item for sublist in split_genes for item in sublist]

# Then, we create a new DataFrame by repeating rows and substituting the gene names
data_new = (data.loc[data.index.repeat(split_genes.str.len())]
            .assign(nearest_genes=flat_genes))

# Reset index to have a standard index
data = data_new.reset_index(drop=True)
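The repeat-and-assign pattern above can also be written with `DataFrame.explode`. A minimal sketch on a toy frame follows; the gene names `GENE_A` and `GENE_B` are made up for illustration.

```python
import pandas as pd

# Toy frame: the second variant lists two nearest genes
toy = pd.DataFrame({
    "rsids": ["rs1", "rs2"],
    "nearest_genes": ["OR4F5", "GENE_A,GENE_B"],  # hypothetical gene names for rs2
})

# Split each string into a list, then emit one row per list element
exploded = (
    toy.assign(nearest_genes=toy["nearest_genes"].str.split(","))
       .explode("nearest_genes")
       .reset_index(drop=True)
)
print(exploded)
```

Every other column is repeated for each gene, which is the same result as the index-repeat approach.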

In [6]:
data

Unnamed: 0,#chrom,pos,ref,alt,rsids,nearest_genes,pval,mlogp,beta,sebeta,af_alt,af_alt_cases,af_alt_controls,causal,LD,lead,trait
0,1,13668,G,A,rs2691328,OR4F5,0.944365,0.024860,-0.005926,0.084918,0.005842,0.005729,0.005863,0,0,,T2D
1,1,14773,C,T,rs878915777,OR4F5,0.844305,0.073501,0.010088,0.051369,0.013495,0.013547,0.013485,0,0,,T2D
2,1,15585,G,A,rs533630043,OR4F5,0.841908,0.074735,0.031464,0.157751,0.001113,0.001125,0.001110,0,0,,T2D
3,1,16549,T,C,rs1262014613,OR4F5,0.343308,0.464316,0.241377,0.254711,0.000561,0.000620,0.000550,0,0,,T2D
4,1,16567,G,C,rs1194064194,OR4F5,0.129883,0.886447,0.130736,0.086319,0.004170,0.004250,0.004154,0,0,,T2D
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
20565622,23,155697920,G,A,,,0.027115,1.566790,-0.013475,0.006097,0.290961,0.286054,0.291879,0,0,,T2D
20565623,23,155698443,C,A,,,0.178417,0.748564,-0.069907,0.051951,0.003259,0.003022,0.003304,0,0,,T2D
20565624,23,155698490,C,T,,,0.279640,0.553400,-0.020245,0.018725,0.024406,0.024312,0.024423,0,0,,T2D
20565625,23,155699751,C,T,,,0.078864,1.103120,-0.011284,0.006421,0.244829,0.241257,0.245498,0,0,,T2D


In [7]:
# Drop rows with NaN values in 'rsids' column
data.dropna(subset=['rsids'], inplace=True)

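# Drop every row whose rsid occurs more than once. keep=False removes all copies,
# so variants that were split into several gene rows above are removed entirely.
data = data[~data.duplicated(subset='rsids', keep=False)]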

# Adjust the index if necessary
data.reset_index(drop=True, inplace=True)
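Because the earlier cell creates one row per gene, a variant with several nearest genes ends up with a repeated rsid. The toy sketch below contrasts `duplicated(keep=False)`, which removes all copies, with `drop_duplicates(keep="first")`, which keeps one. Which behavior is intended is not stated in the notebook; the gene names are hypothetical.

```python
import pandas as pd

toy = pd.DataFrame({
    "rsids": ["rs1", "rs2", "rs2", "rs3"],
    "nearest_genes": ["OR4F5", "GENE_A", "GENE_B", "GENE_C"],  # hypothetical genes
})

# keep=False flags every copy of a repeated rsid, so rs2 disappears entirely
dropped_all = toy[~toy.duplicated(subset="rsids", keep=False)]

# drop_duplicates keeps the first copy of each rsid instead
kept_first = toy.drop_duplicates(subset="rsids", keep="first")

print(dropped_all["rsids"].tolist())
print(kept_first["rsids"].tolist())
```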

In [8]:
# Count the number of unique values in the 'trait' column
num_unique_traits = data['trait'].nunique()
print("Number of unique values in trait column:", num_unique_traits)

Number of unique values in trait column: 1


In [9]:
# Get unique elements of the 'trait' column
unique_traits = data['trait'].unique()

# Print the unique elements
for trait in unique_traits:
    print(trait)

T2D


In [10]:
# Count the number of rows where data['trait'] is <NA>
na_count = data['trait'].isna().sum()

print("Number of rows where data['trait'] is <NA>: ", na_count)

Number of rows where data['trait'] is <NA>:  0


In [11]:
len(data)

18295392

## Spec

**Task Overview**
- The objective is to design and implement a multi-class link prediction model for analyzing relationships between SNP nodes and Phenotype nodes.

**Nodes and Their Features**
- There are two types of nodes: SNP Nodes and Phenotype Nodes.
- *Phenotype Nodes*: Each Phenotype Node represents a particular trait. This information comes from the `trait` column in the data.
- *SNP Nodes*: Each SNP Node is characterized by various features, including `rsids`, `nearest_genes`, `#chrom`, `pos`, `ref`, `alt`, `beta`, `sebeta`, `af_alt`, and `af_alt_cases` columns.

**Edges, Their Features, and Labels**
- Edges represent relationships between nodes. There are two types of edges: SNP-Phenotype and SNP-SNP.
- *SNP-Phenotype Edges*:
  - These edges are undirected, linking SNP Nodes and Phenotype Nodes.
  - The label for each edge is determined by the `causal` column in the data:
    - A label of 0 is assigned when `data['causal']` is 1, indicating a causal relationship.
    - A label of 1 is assigned when `data['causal']` is 0, indicating the absence of a causal relationship.
- *SNP-SNP Edges*:
  - These edges are undirected, linking an SNP Node (as identified by the `rsids` column) to another SNP Node (as identified by the `lead` column) in the same data row.
  - The label for each edge is determined by the `LD` column in the data:
    - A label of 2 is assigned when `data['LD']` is 1, signifying that the two SNPs are in linkage disequilibrium.
    - A label of 3 is assigned when `data['LD']` is 0, indicating that the two SNPs are not in linkage disequilibrium.
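The label scheme in this spec can be sketched as two vectorized mappings. This is a toy illustration of the rules above, not the notebook's edge-building code.

```python
import numpy as np
import pandas as pd

toy = pd.DataFrame({"causal": [1, 0, 0], "LD": [0, 1, 0]})

# SNP-Phenotype edges: 0 = causal, 1 = not causal
snp_pheno_label = np.where(toy["causal"] == 1, 0, 1)

# SNP-SNP edges: 2 = in linkage disequilibrium, 3 = not in linkage disequilibrium
snp_snp_label = np.where(toy["LD"] == 1, 2, 3)

print(snp_pheno_label, snp_snp_label)
```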

## Create graph

### Data Preprocessing Functions

In [12]:
def fill_na(df, fill_values):
    return df.fillna(fill_values)

def encode_columns(df, columns):
    le = LabelEncoder()
    return df[columns].apply(lambda x: le.fit_transform(x))

def scale_columns(df, columns):
    scaler = StandardScaler()
    return df[columns].apply(lambda x: pd.Series(scaler.fit_transform(x.values.reshape(-1, 1)).flatten(), index=x.index))

### Node Data Preprocessing 

In [13]:
# Deep copy the dataframe to avoid modifying the original data
data_copy = data.copy()

# Convert columns to the categorical dtype to save memory and to get integer codes
categorical_columns = ['trait', 'rsids', 'lead']
for col in categorical_columns:
    data_copy[col] = data_copy[col].astype('category')
    if 'N/A' not in data_copy[col].cat.categories:
        data_copy[col] = data_copy[col].cat.add_categories('N/A')

# Fill NaNs for 'lead' column
data_copy['lead'] = data_copy['lead'].fillna('N/A')

# Mapping traits and rsids to integers
data_copy['trait_codes'] = data_copy['trait'].cat.codes
data_copy['rsids_codes'] = data_copy['rsids'].cat.codes 

# Get unique traits and rsids
unique_traits = data_copy['trait'].unique()
unique_rsids = data_copy['rsids'].unique()

assert len(unique_traits) == data_copy['trait'].nunique(), "Mismatch in the number of unique traits"
assert len(unique_rsids) == data_copy['rsids'].nunique(), "Mismatch in the number of unique rsids"

# Define node features and types
node_features_columns = ['trait', 'rsids', 'nearest_genes', '#chrom', 'pos', 'ref', 'alt', 'beta', 'sebeta', 'af_alt', 'af_alt_cases']
node_features = data_copy[node_features_columns]

# This filter is a no-op here: every row's trait is in unique_traits
node_features = node_features[node_features['trait'].isin(unique_traits) | node_features['rsids'].isin(unique_rsids)]

# Append one placeholder row for the phenotype node. Its 'trait' cell holds the
# literal string 'trait' (not the trait name); all other columns are NaN.
new_row = pd.DataFrame([['trait'] + [np.nan] * (node_features.shape[1] - 1)], columns=node_features.columns)
node_features = pd.concat([node_features, new_row], ignore_index=True)

print("Node feature shape", node_features.shape[0])
print("Number of unique traits:", len(unique_traits))
print("Number of unique rsids:", len(unique_rsids))

# Expect exactly one row per phenotype node plus one row per SNP node
assert node_features.shape[0] == len(unique_traits) + len(unique_rsids), "Mismatch in the number of nodes and unique traits/rsids"

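# Define node types: 0 = phenotype (trait) node, 1 = SNP node.
# NOTE: phenotype nodes are listed first here, whereas the phenotype row was appended
# last in node_features above, so the two orderings do not line up.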
trait_nodes = len(unique_traits)
rsids_nodes = len(unique_rsids)
total_nodes = trait_nodes + rsids_nodes
node_types = torch.tensor([0] * trait_nodes + [1] * rsids_nodes, dtype=torch.long)

print("Total number of nodes:", total_nodes)
print("Number of node types:", len(node_types))

assert len(node_types) == total_nodes, "Mismatch in the number of nodes and node types"

Node feature shape 18295393
Number of unique traits: 1
Number of unique rsids: 18295392
Total number of nodes: 18295393
Number of node types: 18295393


### Edge Data Processing

In [14]:
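# Processing edge data
data_copy['lead'] = data_copy['lead'].fillna('N/A')  # Fill NaNs for 'lead' column
# NOTE: lead_codes index the categories of 'lead', not those of 'rsids',
# so the same integer can refer to different SNPs in the two columns.
data_copy['lead_codes'] = data_copy['lead'].cat.codes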

edge_columns = ['rsids_codes', 'trait_codes', 'lead_codes', 'causal', 'LD']
edges = data_copy[edge_columns]

# Create edges based on 'causal' and 'LD' conditions
positive_edges_snp_phenotype = edges[edges['causal'] == 1][['rsids_codes', 'trait_codes']].values.T
negative_edges_snp_phenotype = edges[edges['causal'] == 0][['rsids_codes', 'trait_codes']].values.T
positive_edges_snp_snp = edges[edges['LD'] == 1][['rsids_codes', 'lead_codes']].values.T
negative_edges_snp_snp = edges[edges['LD'] == 0][['rsids_codes', 'lead_codes']].values.T

# Combine all edges and their labels
edges = np.concatenate([positive_edges_snp_phenotype, negative_edges_snp_phenotype, positive_edges_snp_snp, negative_edges_snp_snp], axis=1)
edge_labels = np.concatenate([
    np.full(positive_edges_snp_phenotype.shape[1], 0), # Label 0 for causal relationships in SNP-Phenotype edges
    np.full(negative_edges_snp_phenotype.shape[1], 1), # Label 1 for non-causal relationships in SNP-Phenotype edges
    np.full(positive_edges_snp_snp.shape[1], 2),       # Label 2 for linkage disequilibrium in SNP-SNP edges
    np.full(negative_edges_snp_snp.shape[1], 3)        # Label 3 for non-linkage disequilibrium in SNP-SNP edges
])

edges = torch.from_numpy(edges)
edge_labels = torch.from_numpy(edge_labels)

# Print relevant information
print("Positive SNP-Phenotype edges:", positive_edges_snp_phenotype.shape[1])
print("Negative SNP-Phenotype edges:", negative_edges_snp_phenotype.shape[1])
print("Positive SNP-SNP edges:", positive_edges_snp_snp.shape[1])
print("Negative SNP-SNP edges:", negative_edges_snp_snp.shape[1])
print("Final edges tensor:", edges.size())
print("Edge labels tensor:", edge_labels.size())

# Assertions
assert edges.shape[0] == 2, "Edges tensor should have shape (2, num_edges)"
assert edge_labels.ndim == 1, "Edge labels tensor should be 1-dimensional"
assert edges.shape[1] == edge_labels.shape[0], "Edges and edge_labels should have the same number of columns"

Positive SNP-Phenotype edges: 37
Negative SNP-Phenotype edges: 18295355
Positive SNP-SNP edges: 3787
Negative SNP-SNP edges: 18291605
Final edges tensor: torch.Size([2, 36590784])
Edge labels tensor: torch.Size([36590784])
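In the cell above, `rsids_codes`, `trait_codes`, and `lead_codes` each start at 0 in their own category space. Used directly as node indices they can collide: trait code 0 and rsid code 0 point at the same node, and rows without a lead SNP all link to the `'N/A'` code. One possible remedy, sketched on a toy frame under the assumption that SNPs are numbered first and phenotypes are offset by the SNP count, is to map every column through a shared lookup and skip SNP-SNP edges for rows with no lead:

```python
import numpy as np
import pandas as pd

toy = pd.DataFrame({
    "rsids": ["rs1", "rs2", "rs3"],
    "trait": ["T2D", "T2D", "T2D"],
    "lead":  [None, "rs1", None],   # rs2 is paired with lead SNP rs1
    "causal": [1, 0, 0],
    "LD":     [0, 1, 0],
})

# One index space: SNPs first, then phenotypes offset by the SNP count
snp_ids = pd.Index(toy["rsids"].unique())
trait_ids = pd.Index(toy["trait"].unique())
n_snp = len(snp_ids)

src = snp_ids.get_indexer(toy["rsids"])
pheno_dst = trait_ids.get_indexer(toy["trait"]) + n_snp

# SNP-phenotype edges: one per row, label 0 (causal) or 1 (not causal)
snp_pheno_edges = np.stack([src, pheno_dst])
snp_pheno_labels = np.where(toy["causal"] == 1, 0, 1)

# SNP-SNP edges: only rows whose lead SNP exists as a node (-1 means not found)
lead_dst = snp_ids.get_indexer(toy["lead"].fillna(""))
has_lead = lead_dst >= 0
snp_snp_edges = np.stack([src[has_lead], lead_dst[has_lead]])
snp_snp_labels = np.where(toy.loc[has_lead, "LD"] == 1, 2, 3)

edge_index = np.concatenate([snp_pheno_edges, snp_snp_edges], axis=1)
edge_labels = np.concatenate([snp_pheno_labels, snp_snp_labels])
print(edge_index)
print(edge_labels)
```

Dropping rows without a lead SNP is a design choice made for this sketch; it changes the number of label-3 edges compared with the counts printed above.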


### Edge Attributes and Node Feature Preprocessing

In [15]:
from sklearn.preprocessing import StandardScaler

# Function to encode and scale columns
def process_columns(df, encode_cols, scale_cols):
    # Encode columns
    for col in encode_cols:
        df[col] = df[col].astype('category').cat.codes
    # Scale columns
    scaler = StandardScaler()
    for col in scale_cols:
        df[col] = scaler.fit_transform(df[col].values.reshape(-1, 1))
    return df

# Preprocessing node features
fill_values = {'nearest_genes': 'N/A', '#chrom': 'N/A', 'pos': 0, 'ref': 'N/A', 'alt': 'N/A', 'beta': 0, 'sebeta': 0, 'af_alt': 0, 'af_alt_cases': 0}
node_features = fill_na(node_features, fill_values)

# Define columns to be encoded and scaled
categorical_columns = ['trait', 'rsids', 'nearest_genes', '#chrom', 'ref', 'alt']
numerical_columns = ['pos', 'beta', 'sebeta', 'af_alt', 'af_alt_cases']

# Process node features
node_features = process_columns(node_features, categorical_columns, numerical_columns)

# Convert to tensor
node_features = torch.tensor(node_features.values, dtype=torch.float)

# Edge attributes: the edge-type code (0-3), identical to edge_labels.
# NOTE: because this duplicates the label, it must not be fed to a model as an input feature.
edge_attr = edge_labels.to(torch.float)

# Print relevant information
print("Edge attribute tensor:")
print(edge_attr)
print()
print("Node features after processing:")
print(node_features)

# Assertions
assert node_features.ndim == 2, "Node features tensor should have 2 dimensions"
assert edge_attr.ndim == 1, "Edge attribute tensor should be 1-dimensional"

Edge attribute tensor:
tensor([0., 0., 0.,  ..., 3., 3., 3.])

Node features after processing:
tensor([[ 0.0000e+00,  9.3201e+06,  1.0744e+04,  ..., -6.4647e-02,
         -5.2159e-01, -5.2213e-01],
        [ 0.0000e+00,  1.7153e+07,  1.0744e+04,  ..., -3.0032e-01,
         -4.8745e-01, -4.8724e-01],
        [ 0.0000e+00,  1.1052e+07,  1.0744e+04,  ...,  4.4701e-01,
         -5.4269e-01, -5.4267e-01],
        ...,
        [ 0.0000e+00,  2.4669e+06,  1.2527e+04,  ..., -1.8717e-01,
         -5.3397e-01, -5.3461e-01],
        [ 0.0000e+00,  3.5383e+06,  1.2527e+04,  ...,  4.1164e-01,
         -5.4174e-01, -5.4195e-01],
        [ 1.0000e+00, -1.0000e+00,  9.7710e+03,  ..., -6.6119e-01,
         -5.4766e-01, -5.4769e-01]])


### Construct Graph Data Object

In [16]:
# Convert node features to a tensor if they are still a DataFrame
if isinstance(node_features, pd.DataFrame):
    node_features_tensor = torch.tensor(node_features.values, dtype=torch.float)  # .values is an attribute, not a method
else:
    node_features_tensor = node_features

# Create the PyTorch Geometric data object
graph = Data(x=node_features_tensor, edge_index=edges, edge_attr=edge_attr)

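# Use the labels built alongside the edges so that y[i] belongs to edge_index[:, i].
# (Rebuilding them from data['causal'] and data['LD'] in row order would not match
# the grouped edge order used when the edges were concatenated.)
graph.y = edge_labels.to(torch.long)

# Attach the node type vector (0 = phenotype, 1 = SNP)
graph.node_types = node_types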

# Assertions
assert graph.x.ndim == 2, "Node features tensor should have 2 dimensions"
assert graph.edge_index.ndim == 2, "Edge index tensor should have 2 dimensions"
assert graph.edge_attr.ndim == 1, "Edge attribute tensor should be 1-dimensional"
assert graph.num_nodes == total_nodes, "Number of nodes in the graph should match the total number of nodes"
assert graph.num_node_features == node_features_tensor.shape[1], "Number of node features in the graph should match the number of columns in node_features"
assert graph.edge_index.size(1) == edges.shape[1], "Number of edges in the graph should match the number of columns in edges"
assert graph.edge_attr.size(0) == graph.edge_index.size(1), "Number of edge attributes should match the number of edges"

print(graph)

Data(x=[18295393, 11], edge_index=[2, 36590784], edge_attr=[36590784], y=[36590784], node_types=[18295393])
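The assertions above check shapes but not whether the indices are valid. A small sketch of an additional consistency check follows; `check_edge_index` is a hypothetical helper, not part of PyTorch Geometric. It is useful here because pandas `.cat.codes` yields -1 for missing values, which would be an invalid node index.

```python
import numpy as np

def check_edge_index(edge_index, edge_labels, num_nodes):
    """Return a list of problems found; empty when the arrays are consistent."""
    problems = []
    if edge_index.ndim != 2 or edge_index.shape[0] != 2:
        problems.append("edge_index must have shape (2, num_edges)")
        return problems
    if edge_index.shape[1] != edge_labels.shape[0]:
        problems.append("one label is required per edge")
    if edge_index.size and (edge_index.min() < 0 or edge_index.max() >= num_nodes):
        problems.append("edge_index refers to a node outside [0, num_nodes)")
    return problems

good = np.array([[0, 1], [2, 2]])
bad = np.array([[0, -1], [2, 5]])   # -1 (missing code) and 5 (out of range)
print(check_edge_index(good, np.array([0, 1]), num_nodes=3))
print(check_edge_index(bad, np.array([0, 1]), num_nodes=3))
```

With the tensors in this notebook, the same function could be called on `graph.edge_index.numpy()` and `graph.y.numpy()`.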


## Graph stats

In [17]:
from torch_geometric.utils import degree

def print_graph_stats(graph, positive_edges_snp_phenotype, negative_edges_snp_phenotype, positive_edges_snp_snp, negative_edges_snp_snp):
    node_types = np.unique(graph.node_types.numpy(), return_counts=True)

    print(f"Number of nodes: {graph.num_nodes}")
    for node_type, count in zip(*node_types):
        print(f"Number of {node_type} nodes: {count}")
    print(f"Number of positive edges between SNPs and phenotypes: {positive_edges_snp_phenotype.shape[1]}")
    print(f"Number of negative edges between SNPs and phenotypes: {negative_edges_snp_phenotype.shape[1]}")
    print(f"Number of positive edges between SNPs: {positive_edges_snp_snp.shape[1]}")
    print(f"Number of negative edges between SNPs: {negative_edges_snp_snp.shape[1]}")
    print(f"Number of edges: {graph.num_edges}")
    print(f"Node feature dimension: {graph.num_node_features}")

    # Compute and print degree-related stats for each node type.
    # NOTE: only source endpoints (edge_index[0]) are counted, so this is an out-degree.
    for node_type in node_types[0]:
        node_indices = np.where(graph.node_types.numpy() == node_type)[0]
        degrees = degree(graph.edge_index[0].long(), num_nodes=graph.num_nodes)[node_indices]
        average_degree = degrees.float().mean().item()
        median_degree = np.median(degrees.numpy())
        std_degree = degrees.float().std().item()

        print(f"\n{node_type} node stats:")
        print(f"Average degree: {average_degree:.2f}")
        print(f"Median degree: {median_degree:.2f}")
        print(f"Standard deviation of degree: {std_degree:.2f}")

    # Density is the ratio of actual edges to the maximum number of possible edges
    num_possible_edges = graph.num_nodes * (graph.num_nodes - 1) / 2
    density = graph.num_edges / num_possible_edges
    print(f"Density: {density:.10f}")

    # Check for NaN values in features
    nan_in_features = torch.isnan(graph.x).any().item()
    print(f"Are there any NaN values in features? {nan_in_features}")

# Print
print("Graph stats:")
print_graph_stats(graph, positive_edges_snp_phenotype, negative_edges_snp_phenotype, positive_edges_snp_snp, negative_edges_snp_snp)

Graph stats:
Number of nodes: 18295393
Number of 0 nodes: 1
Number of 1 nodes: 18295392
Number of positive edges between SNPs and phenotypes: 37
Number of negative edges between SNPs and phenotypes: 18295355
Number of positive edges between SNPs: 3787
Number of negative edges between SNPs: 18291605
Number of edges: 36590784
Node feature dimension: 11

0 node stats:
Average degree: 2.00
Median degree: 2.00
Standard deviation of degree: nan

1 node stats:
Average degree: 2.00
Median degree: 2.00
Standard deviation of degree: 0.00
Density: 0.0000002186
Are there any NaN values in features? False
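The degrees reported above count each node only where it appears as an edge source. For an undirected graph stored with one column per edge, both endpoints would normally be counted. A minimal NumPy sketch of the difference, on a toy edge list:

```python
import numpy as np

# Toy edge list: SNP nodes 0-2, phenotype node 3
edge_index = np.array([[0, 1, 2, 1],
                       [3, 3, 3, 0]])
num_nodes = 4

# Counting only the source row misses every edge that ends at a node
out_degree = np.bincount(edge_index[0], minlength=num_nodes)

# Counting both rows gives the undirected degree
undirected_degree = np.bincount(edge_index.ravel(), minlength=num_nodes)

print(out_degree)
print(undirected_degree)
```

In this toy graph the phenotype node has out-degree 0 but undirected degree 3, which is why a source-only count can hide hub nodes.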


## Data splitting

In [22]:
def shuffle_graph_data(data, seed=None):
    """Shuffle the edges of a graph along with their corresponding attributes."""
    if seed is not None:
        torch.manual_seed(seed)
    perm = torch.randperm(data.num_edges)
    print(f"Permutation vector: {perm}")

    # Shuffle edge_index
    data.edge_index = data.edge_index[:, perm]

    # Shuffle edge attributes
    if data.edge_attr is not None:
        data.edge_attr = data.edge_attr[perm]

    # Shuffle edge labels
    data.y = data.y[perm]

    return data

# Shuffle the graph data
graph = shuffle_graph_data(graph, seed=12345)

# Compute the sizes of training, validation, and test sets
train_size = int(graph.num_edges * 0.6)
val_size = (graph.num_edges - train_size) // 2
test_size = graph.num_edges - train_size - val_size

print(f"Train size: {train_size}, Validation size: {val_size}, Test size: {test_size}")
assert train_size + val_size + test_size == graph.num_edges, "Sizes do not add up to total number of edges."

# Create index for each set
train_index = torch.arange(train_size)
val_index = torch.arange(train_size, train_size + val_size)
test_index = torch.arange(train_size + val_size, graph.num_edges)

# Split edge_index
train_edge_index = graph.edge_index[:, train_index]
val_edge_index = graph.edge_index[:, val_index]
test_edge_index = graph.edge_index[:, test_index]

# Check edge indices
print(f"Train edge index shape: {train_edge_index.shape}")
print(f"Validation edge index shape: {val_edge_index.shape}")
print(f"Test edge index shape: {test_edge_index.shape}")

# Split edge attributes
if graph.edge_attr is not None:
    train_edge_attr = graph.edge_attr[train_index]
    val_edge_attr = graph.edge_attr[val_index]
    test_edge_attr = graph.edge_attr[test_index]
else:
    train_edge_attr, val_edge_attr, test_edge_attr = None, None, None

# Split edge labels
train_y = graph.y[train_index]
val_y = graph.y[val_index]
test_y = graph.y[test_index]

# Check labels
print(f"Train labels shape: {train_y.shape}")
print(f"Validation labels shape: {val_y.shape}")
print(f"Test labels shape: {test_y.shape}")

# Create the data for each set
train_data = Data(x=graph.x, edge_index=train_edge_index, edge_attr=train_edge_attr, y=train_y)
val_data = Data(x=graph.x, edge_index=val_edge_index, edge_attr=val_edge_attr, y=val_y)
test_data = Data(x=graph.x, edge_index=test_edge_index, edge_attr=test_edge_attr, y=test_y)

print(f"Number of training edges: {train_data.num_edges}")
print(f"Number of validation edges: {val_data.num_edges}")
print(f"Number of test edges: {test_data.num_edges}")

# Check that the sum of the edges in the split datasets equals the total number of edges
assert train_data.num_edges + val_data.num_edges + test_data.num_edges == graph.num_edges, "Split datasets do not add up to total number of edges."

# Check the number of unique classes in the original graph data
num_classes = torch.unique(graph.y).size(0)
assert num_classes == 4, "The original graph data does not have 4 classes."

# Check the number of unique classes in the split datasets
train_num_classes = torch.unique(train_data.y).size(0)
val_num_classes = torch.unique(val_data.y).size(0)
test_num_classes = torch.unique(test_data.y).size(0)
assert train_num_classes == 4, "The training data does not have 4 classes."
assert val_num_classes == 4, "The validation data does not have 4 classes."
assert test_num_classes == 4, "The test data does not have 4 classes."


Permutation vector: tensor([ 4275234, 17743950,  4963753,  ..., 35549552,  9354397,  5111534])
Train size: 21954470, Validation size: 7318157, Test size: 7318157
Train edge index shape: torch.Size([2, 21954470])
Validation edge index shape: torch.Size([2, 7318157])
Test edge index shape: torch.Size([2, 7318157])
Train labels shape: torch.Size([21954470])
Validation labels shape: torch.Size([7318157])
Test labels shape: torch.Size([7318157])
Number of training edges: 21954470
Number of validation edges: 7318157
Number of test edges: 7318157
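With only 37 edges in class 0 out of roughly 36.6 million, a purely random split relies on chance to place rare classes in every subset. A stratified split makes that explicit. The sketch below uses `train_test_split` (already imported in this notebook) on toy labels with a similar imbalance; the class counts are invented for illustration.

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Toy labels with a rare class 0, mimicking the imbalance above
labels = np.concatenate([np.full(10, 0), np.full(1000, 1),
                         np.full(50, 2), np.full(1000, 3)])
edge_ids = np.arange(labels.size)

# 60% train, then split the remaining 40% evenly into validation and test
train_ids, rest_ids = train_test_split(
    edge_ids, test_size=0.4, stratify=labels, random_state=12345)
val_ids, test_ids = train_test_split(
    rest_ids, test_size=0.5, stratify=labels[rest_ids], random_state=12345)

for name, ids in [("train", train_ids), ("val", val_ids), ("test", test_ids)]:
    print(name, np.bincount(labels[ids], minlength=4))
```

The resulting index arrays could then be used in place of `train_index`, `val_index`, and `test_index` to slice `edge_index`, `edge_attr`, and `y`.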


## Models

### GCN

In [27]:
import torch
from torch.nn import functional as F
from torch_geometric.nn import GCNConv

class GCN(torch.nn.Module):
    def __init__(self, num_features, hidden_channels, num_classes):
        super(GCN, self).__init__()
        self.conv1 = GCNConv(num_features, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.out = torch.nn.Linear(hidden_channels, num_classes)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, training=self.training)
        x = self.conv2(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, training=self.training)
        x = self.out(x)
        return x

def train_model(model, data, optimizer, criterion):
    model.train()
    optimizer.zero_grad()
    out = model(data.x, data.edge_index)
    loss = criterion(out, data.y)
    loss.backward()
    optimizer.step()
    return loss.item()

def evaluate_model(model, data):
    model.eval()
    with torch.no_grad():
        logits = model(data.x, data.edge_index)
        preds = logits.argmax(dim=1)
        correct = preds.eq(data.y).sum().item()
        acc = correct / data.y.size(0)  # Accuracy over the labelled items, not the node count
    return acc


# train_data, val_data, and test_data come from the data-splitting cell above
# Initialize the model and optimizer
num_classes = torch.unique(train_data.y).size(0)
model = GCN(num_features=train_data.num_node_features, hidden_channels=4, num_classes=num_classes)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = torch.nn.CrossEntropyLoss()

# Train the model for a number of epochs
for epoch in range(50):
    loss = train_model(model, train_data, optimizer, criterion)
    train_acc = evaluate_model(model, train_data)
    val_acc = evaluate_model(model, val_data)
    print(f'Epoch: {epoch+1}, Loss: {loss:.4f}, Train Acc: {train_acc:.4f}, Val Acc: {val_acc:.4f}')

# After training, evaluate the model on the test set
test_acc = evaluate_model(model, test_data)
print(f'Test Acc: {test_acc:.4f}')


ValueError: Expected input batch_size (18295393) to match target batch_size (21954470).
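The error arises because `GCN.forward` returns one row of logits per node (18,295,393 rows), while `train_data.y` holds one label per edge (21,954,470 labels). Classifying edges needs one prediction per edge. A minimal sketch of that idea follows, assuming the two endpoint embeddings are concatenated and passed to a linear head. `EdgeClassifier` is a hypothetical name, and the node encoder uses plain `Linear` layers so the example runs with PyTorch alone; `GCNConv` layers could take their place. The inverse-frequency class weights are also an assumption, added because the classes are heavily imbalanced.

```python
import torch
import torch.nn.functional as F

class EdgeClassifier(torch.nn.Module):
    """Hypothetical edge-level classifier: encode nodes, then score each edge."""
    def __init__(self, num_features, hidden_channels, num_classes):
        super().__init__()
        # Stand-in node encoder; GCNConv layers could be substituted here
        self.lin1 = torch.nn.Linear(num_features, hidden_channels)
        self.lin2 = torch.nn.Linear(hidden_channels, hidden_channels)
        self.out = torch.nn.Linear(2 * hidden_channels, num_classes)

    def forward(self, x, edge_index):
        h = F.relu(self.lin1(x))
        h = F.relu(self.lin2(h))
        src, dst = edge_index
        # One row per edge: concatenate the two endpoint embeddings
        edge_repr = torch.cat([h[src], h[dst]], dim=1)
        return self.out(edge_repr)

torch.manual_seed(0)
x = torch.randn(5, 11)                                   # 5 toy nodes, 11 features
edge_index = torch.tensor([[0, 1, 2, 3], [4, 4, 4, 0]])  # 4 toy edges
y = torch.tensor([0, 1, 1, 2])                           # one label per edge

# Inverse-frequency class weights (an assumption, to counter class imbalance);
# classes absent from y get weight 0
counts = torch.bincount(y, minlength=4).float()
weights = torch.where(counts > 0, counts.sum() / counts.clamp(min=1),
                      torch.zeros_like(counts))

model = EdgeClassifier(num_features=11, hidden_channels=8, num_classes=4)
criterion = torch.nn.CrossEntropyLoss(weight=weights)
logits = model(x, edge_index)   # shape: (num_edges, num_classes)
loss = criterion(logits, y)
loss.backward()
print(logits.shape, loss.item())
```

Concatenation treats the two endpoints asymmetrically; for undirected edges, a symmetric combination such as an element-wise product or sum of the endpoint embeddings is a common alternative.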