In [1]:
import sys
import os
import random
import numpy as np
import pandas as pd
import matplotlib
import matplotlib.pyplot as plt
import sklearn
from sklearn import preprocessing
from sklearn.model_selection import train_test_split
from sklearn.metrics import precision_recall_fscore_support, roc_auc_score, average_precision_score
from sklearn.preprocessing import RobustScaler, LabelEncoder
from sklearn.pipeline import Pipeline
from sklearn.compose import ColumnTransformer
from sklearn.impute import SimpleImputer
import torch
import torch.nn.functional as F
import torch_geometric
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv, GATConv, SAGEConv
from torch_geometric.utils import to_undirected, negative_sampling
import networkx as nx
from scipy.spatial import cKDTree
from scipy.special import expit
from typing import List, Dict
import time
import cProfile
import pstats
import io
import category_encoders as ce
from itertools import combinations
from collections import Counter
from torch_geometric.transforms import RandomNodeSplit

# Print versions of imported libraries
print(f"Python version: {sys.version}")
print(f"NumPy version: {np.__version__}")
print(f"Pandas version: {pd.__version__}")
print(f"Matplotlib version: {matplotlib.__version__}")
print(f"Scikit-learn version: {sklearn.__version__}")
print(f"Torch version: {torch.__version__}")
print(f"Torch Geometric version: {torch_geometric.__version__}")
print(f"NetworkX version: {nx.__version__}")

if torch.cuda.is_available():
    device = torch.device("cuda")  # Current CUDA device
    print(f"Using {torch.cuda.get_device_name()} ({device})")
    print(f"CUDA version: {torch.version.cuda}")
    print(f"Number of CUDA devices: {torch.cuda.device_count()}")
else:
    print("CUDA is not available on this device.")

Python version: 3.11.5 (tags/v3.11.5:cce6ba9, Aug 24 2023, 14:38:34) [MSC v.1936 64 bit (AMD64)]
NumPy version: 1.24.1
Pandas version: 2.1.0
Matplotlib version: 3.7.2
Scikit-learn version: 1.3.0
Torch version: 2.0.1+cu117
Torch Geometric version: 2.3.1
NetworkX version: 3.0
Using NVIDIA RTX A6000 (cuda)
CUDA version: 11.7
Number of CUDA devices: 2


## Spec

### Data

`data` Pandas DataFrame:

- `#chrom`: chromosome of SNP (int).
- `id`: the ID of the variant in the following format: `#chrom:pos:ref:alt` (string).
- `pos`: position of the genetic variant on the chromosome (int).
- `ref`: reference allele (or variant) at the genomic position (string).
- `alt`: alternate allele observed at this position (string).
- `gene_0` to `gene_21`: genes which are nearest to the variant (string).
- `mlogp`: minus log of the p-value, commonly used in genomic studies (float).
- `beta`: beta coefficient represents the effect size of the variant (float).
- `sebeta`: standard error of the beta coefficient (float).
- `af_alt`: allele frequency of the alternate variant in the general population (float).
- `af_alt_cases`: allele frequency of the alternate variant in the cases group (float).
- `af_alt_controls`: allele frequency of the alternate variant in the control group (float).
- `prob`: posterior probability of association (float).
- `lead_r2`: r2 value to a lead variant (the one with maximum PIP) in a credible set (float).
- `cs_99`: credible set to which the variant belongs to (int).
- `causal`: indicates causality of variant (1) or not (0) (int). 

### Task Overview

The objective is to design and implement a binary node classification GNN model to predict whether variants are causal (`causal=1`) or not (`causal=0`).

### Nodes and Their Features

**SNP Nodes**: Each SNP Node is characterized by various features:
`['#chrom', 'pos', 'ref', 'alt', 'mlogp', 'beta', 'sebeta', 'af_alt', 'af_alt_cases', 'prob', 'lead_r2', 'cs_99']`.


### Edges

Edges are created between 2 variants if either 1) the SNPs share the same nearest gene value (i.e., they share a value across the `gene_0` to `gene_21` columns, OR 2) the SNPs are on the same `#chrom` and share the same `cs_99` value and both have `lead_r2 > 0.8`.

## Graph Creation

In [2]:
# Load and prepare the data
data = pd.read_parquet('gwas_fm_t2d.parquet')

# Process only chromosome 10 and 3
chroms = [3, 10, 12]

data = data[data['#chrom'].isin(chroms)]

In [3]:
def get_unique_snps(data: pd.DataFrame) -> dict:
    return {snp: idx for idx, snp in enumerate(data['id'].unique())}

def preprocess_snp_features(data: pd.DataFrame, snp_to_idx: dict) -> pd.DataFrame:
    cols_to_extract = ['id', '#chrom', 'pos', 'ref', 'alt', 'mlogp', 'beta', 'sebeta', 
                       'af_alt', 'af_alt_cases', 'af_alt_controls', 'prob']
    snp_features = data.loc[data['id'].isin(snp_to_idx.keys()), cols_to_extract].set_index('id').sort_index()
    
    categorical_cols = ['ref', 'alt']
    binary_encoder = ce.BinaryEncoder(cols=categorical_cols)
    snp_features = binary_encoder.fit_transform(snp_features)

    numerical_cols = list(set(snp_features.columns) - set(categorical_cols))

    snp_features = snp_features.fillna(0)

    return snp_features

def preprocess_edges(data: pd.DataFrame, snp_to_idx: dict) -> torch.Tensor:
    data['snp_idx'] = data['id'].map(snp_to_idx)

    # Create a mapping from gene values to node indices for each gene column
    gene_cols = [f'gene_{i}' for i in range(22)]
    melted_df = data.melt(id_vars='snp_idx', value_vars=gene_cols)
    grouped_df = melted_df[melted_df['value'].notnull() & (melted_df['value'] != 0)].groupby('value')['snp_idx'].apply(list)

    # Create edges based on shared nearest gene values
    edge_list = [edge for node_indices in grouped_df for edge in combinations(node_indices, 2)]

    # Filter data outside the loop to reduce the DataFrame size
    filtered_df = data[data['lead_r2'] > 0.8][['id', 'cs_99', '#chrom']]
    filtered_data_list = filtered_df.to_records(index=False).tolist()

    # Create edges based on lead_r2 values and same chromosome
    for i, (id_i, cs_99_i, chrom_i) in enumerate(filtered_data_list):
        for id_j, cs_99_j, chrom_j in filtered_data_list[i + 1:]:
            if cs_99_i == cs_99_j and chrom_i == chrom_j:
                edge_list.append((snp_to_idx[id_i], snp_to_idx[id_j]))

    # Remove duplicate edges
    edge_list = list(set(edge_list))
    edge_tensor = torch.tensor(edge_list, dtype=torch.long).t().contiguous()

    return edge_tensor


def create_pytorch_graph(features: torch.Tensor, edges: torch.Tensor, edge_attr: torch.Tensor) -> Data:
    return Data(x=features, edge_index=edges, edge_attr=edge_attr)

# Create a profiler object
pr = cProfile.Profile()
pr.enable()

start_time = time.time()

snp_to_idx = get_unique_snps(data)
snp_features = preprocess_snp_features(data, snp_to_idx)
features = torch.tensor(snp_features.values, dtype=torch.float)

edges = preprocess_edges(data, snp_to_idx)
graph = create_pytorch_graph(features, edges, None)  # Assuming no edge attributes for simplicity
graph.y = torch.tensor(data['causal'].values, dtype=torch.long)

print(f"Number of nodes: {graph.num_nodes}")
print(f"Number of edges: {graph.num_edges}")
print(f"Node feature dimension: {graph.num_node_features}")

# Calculate elapsed time
elapsed_time = time.time() - start_time
print(f"Execution time: {elapsed_time} seconds")

pr.disable()
s = io.StringIO()
sortby = 'cumulative'
ps = pstats.Stats(pr, stream=s).sort_stats(sortby)
ps.print_stats(3)  # Only print the top 3 lines
print(s.getvalue())


  elif pd.api.types.is_categorical_dtype(cols):
  return pd.api.types.is_categorical_dtype(dtype)
  return pd.api.types.is_categorical_dtype(dtype)
  return pd.api.types.is_categorical_dtype(dtype)
  return pd.api.types.is_categorical_dtype(dtype)
  return pd.api.types.is_categorical_dtype(dtype)
  return pd.api.types.is_categorical_dtype(dtype)
  return pd.api.types.is_categorical_dtype(dtype)
  return pd.api.types.is_categorical_dtype(dtype)


Number of nodes: 15075
Number of edges: 2547943
Node feature dimension: 26
Execution time: 1.3853271007537842 seconds
         90051 function calls (88153 primitive calls) in 1.207 seconds

   Ordered by: cumulative time
   List reduced from 1079 to 3 due to restriction <3>

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
       13    0.000    0.000    1.392    0.107 C:\Users\Windows\AppData\Local\Programs\Python\Python311\Lib\site-packages\IPython\core\interactiveshell.py:3472(run_code)
       13    0.000    0.000    1.392    0.107 {built-in method builtins.exec}
        1    0.342    0.342    1.143    1.143 C:\Users\Windows\AppData\Local\Temp\ipykernel_1684\124732808.py:19(preprocess_edges)





In [4]:
from collections import Counter
import numpy as np

# Calculate new graph statistics:

# Node degrees:
degrees = Counter(graph.edge_index[0].numpy())
avg_degree = sum(degrees.values()) / len(degrees)

# Minimum, Maximum, and Median Degree:
min_degree = min(degrees.values())
max_degree_node = max(degrees, key=degrees.get)
median_degree = np.median(list(degrees.values()))

print(f"Minimum node degree: {min_degree}")
print(f"Average node degree: {avg_degree:.2f}")
print(f"Median node degree: {median_degree}")
print(f"Node with the highest degree ({degrees[max_degree_node]}): {max_degree_node}")

# Degree Distribution:
degree_values = list(degrees.values())
degree_count = Counter(degree_values)
print(f"Degree distribution (Degree: Count) -> {degree_count}")

# Degree Variance:
degree_variance = np.var(degree_values)
print(f"Node degree variance: {degree_variance:.2f}")

# Nodes with Degree Above/Below Average:
above_avg_count = sum(1 for degree in degree_values if degree > avg_degree)
below_avg_count = sum(1 for degree in degree_values if degree < avg_degree)
print(f"Number of nodes with a degree above average: {above_avg_count}")
print(f"Number of nodes with a degree below average: {below_avg_count}")

# Number of Isolated Nodes:
isolated_nodes = graph.num_nodes - len(degrees)
print(f"Number of isolated nodes: {isolated_nodes}")

# Edge statistics:

# Total Edge Count:
print(f"Total number of edges: {graph.num_edges}")

# Unique Edges:
unique_edges = set(tuple(edge) for edge in graph.edge_index.t().numpy())
print(f"Number of unique edges: {len(unique_edges)}")

# Self-loops:
self_loops = (graph.edge_index[0] == graph.edge_index[1]).sum().item()
print(f"Number of self-loops: {self_loops}")

# Multiple Edges:
multiple_edges = graph.num_edges - len(unique_edges) + self_loops
print(f"Number of multiple edges (duplicates): {multiple_edges}")

# Graph Density:
max_possible_edges = graph.num_nodes * (graph.num_nodes - 1)
density = graph.num_edges / max_possible_edges
print(f"Graph density: {density:.4f}")


Minimum node degree: 1
Average node degree: 171.22
Median node degree: 85.0
Node with the highest degree (1411): 8726
Degree distribution (Degree: Count) -> Counter({1: 188, 2: 174, 3: 158, 4: 149, 5: 145, 6: 138, 8: 134, 7: 134, 9: 129, 10: 127, 11: 125, 12: 124, 13: 122, 15: 121, 14: 120, 17: 118, 16: 118, 18: 115, 20: 113, 19: 111, 22: 108, 21: 108, 25: 104, 24: 104, 23: 104, 27: 99, 26: 98, 28: 94, 29: 92, 31: 92, 30: 92, 38: 88, 32: 88, 35: 87, 34: 87, 33: 87, 39: 87, 37: 86, 41: 85, 36: 82, 40: 82, 43: 79, 45: 78, 44: 78, 42: 78, 46: 76, 49: 75, 47: 75, 50: 74, 48: 74, 51: 73, 52: 72, 55: 71, 54: 71, 53: 71, 56: 70, 57: 68, 58: 67, 59: 66, 62: 65, 60: 64, 61: 64, 63: 64, 64: 63, 65: 61, 70: 61, 66: 60, 71: 60, 68: 60, 73: 60, 72: 60, 67: 60, 69: 59, 77: 59, 78: 59, 75: 57, 76: 57, 80: 57, 74: 57, 81: 56, 84: 55, 82: 55, 79: 54, 83: 54, 85: 53, 86: 52, 89: 51, 92: 51, 87: 51, 95: 51, 100: 50, 101: 50, 93: 50, 90: 50, 94: 50, 104: 50, 88: 50, 97: 49, 103: 49, 91: 49, 96: 49, 98: 49

## Data Splitting

In [5]:
# Split the data into training, validation, and test sets

seed_value = 0
torch.manual_seed(seed_value)
np.random.seed(seed_value)

if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed_value)
    
transform = RandomNodeSplit(split="train_rest", num_val=0.35, num_test=0.0, key='y')
graph = transform(graph)

# Count the number of nodes per class in each set
train_class_counts = Counter(graph.y[graph.train_mask].numpy())
val_class_counts = Counter(graph.y[graph.val_mask].numpy())
test_class_counts = Counter(graph.y[graph.test_mask].numpy())

# Print the results
print("Number of nodes per class in each set:")
print("Train set:")
for class_label, count in train_class_counts.items():
    print(f"Class {class_label}: {count} nodes")
print("Validation set:")
for class_label, count in val_class_counts.items():
    print(f"Class {class_label}: {count} nodes")
print("Test set:")
#for class_label, count in test_class_counts.items():
#    print(f"Class {class_label}: {count} nodes")

# Calculate and print the percentage of class 1 vs. class 0 in the test set
#total_test_nodes = sum(test_class_counts.values())
#class_0_nodes = test_class_counts[0]
#class_1_nodes = test_class_counts[1]
#class_0_percentage = (class_0_nodes / total_test_nodes) * 100
#class_1_percentage = (class_1_nodes / total_test_nodes) * 100
#print(f"Percentage of Class 0 in test set: {class_0_percentage:.2f}%")
#print(f"Percentage of Class 1 in test set: {class_1_percentage:.2f}%")

Number of nodes per class in each set:
Train set:
Class 0: 9788 nodes
Class 1: 11 nodes
Validation set:
Class 0: 5271 nodes
Class 1: 5 nodes
Test set:


In [7]:
from sklearn.utils.class_weight import compute_class_weight
from torch.nn.functional import binary_cross_entropy, dropout, leaky_relu
from sklearn.metrics import precision_recall_curve, auc, roc_auc_score, f1_score
from torch_geometric.nn import global_add_pool, global_mean_pool, global_max_pool, GlobalAttention

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# suppose y is your target vector
y = graph.y.cpu().numpy()

# Count number of occurrences of each class
class_counts = np.bincount(y)

# Compute class weights
class_weights = 1. / class_counts

class_weights = torch.tensor(class_weights, dtype=torch.float).to(device)
class_weights = class_weights / class_weights.sum()


class GraphSAGEModel(torch.nn.Module):
    def __init__(self, num_node_features, hidden_layers, num_classes, dropout_rate=0.5):
        super(GraphSAGEModel, self).__init__()
        self.layers = torch.nn.ModuleList()
        self.layers.append(GATConv(num_node_features, hidden_layers[0], heads=1, concat=True))
        for i in range(1, len(hidden_layers)):
            self.layers.append(GATConv(hidden_layers[i - 1], hidden_layers[i], heads=1, concat=True))
        self.layers.append(GATConv(hidden_layers[-1], num_classes, heads=1, concat=False))  # For the last layer, we don't concatenate
        self.dropout_rate = dropout_rate
        self.bn_layers = torch.nn.ModuleList([torch.nn.BatchNorm1d(size) for size in hidden_layers])

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        for i, conv in enumerate(self.layers[:-1]):
            x = conv(x, edge_index)
            x = self.bn_layers[i](x)
            x = torch.nn.functional.leaky_relu(x)
            x = torch.nn.functional.dropout(x, p=self.dropout_rate, training=self.training)
        x = self.layers[-1](x, edge_index)
        return torch.sigmoid(x.view(-1))
    
    
class FocalLoss(torch.nn.Module):
    def __init__(self, alpha=None, gamma=100, reduce=True):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduce = reduce

    def forward(self, inputs, targets):
        if self.alpha is not None:
            alpha_t = self.alpha[targets.long()].view(-1, 1)
            logpt = -binary_cross_entropy(inputs, targets, reduction='none')
            logpt = logpt * alpha_t
        else:
            logpt = -binary_cross_entropy(inputs, targets, reduction='none')
        pt = torch.exp(logpt)
        F_loss = -((1 - pt) ** self.gamma) * logpt
        if self.reduce:
            return torch.mean(F_loss)
        else:
            return F_loss

        
class EarlyStopping:
    def __init__(self, patience=150, min_delta=0):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.best_score = None
        self.early_stop = False

    def step(self, val_loss):
        score = -val_loss
        if self.best_score is None:
            self.best_score = score
        elif score < self.best_score + self.min_delta:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.counter = 0\
            
            
def train(model, data, loss_fn, optimizer):
    model.train()
    optimizer.zero_grad()
    out = model(data)[data.train_mask].squeeze()
    loss = loss_fn(out, data.y[data.train_mask].float())
    loss.backward()
    optimizer.step()
    return loss.item()


def evaluate(model, data, mask, loss_fn, best_threshold, return_preds_outs=False):
    model.eval()
    with torch.no_grad():
        out = model(data)[mask].squeeze()
        preds = (out > best_threshold).long()
        loss = loss_fn(out, data.y[mask].float())
        accuracy = preds.eq(data.y[mask]).sum().item() / mask.sum().item()
        precision, recall, _ = precision_recall_curve(data.y[mask].cpu(), out.cpu())
        auprc = auc(recall, precision)
        roc_auc = roc_auc_score(data.y[mask].cpu(), out.cpu())

        if return_preds_outs:
            return loss.item(), accuracy, auprc, roc_auc, preds, out
        else:
            return loss.item(), accuracy, auprc, roc_auc


def main():
    # initialization
    hidden_layers = [512, 512, 512, 512] 
    model = GraphSAGEModel(graph.num_node_features, hidden_layers, 1, dropout_rate=0.1).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-5)
    scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer, base_lr=1e-5, max_lr=1e-3, step_size_up=100, cycle_momentum=False)
    loss_fn = FocalLoss(alpha=class_weights, gamma=10)

    early_stopping = EarlyStopping(patience=150, min_delta=0.0001)

    # Move your graph data to device
    graph.x = graph.x.to(device)
    graph.edge_index = graph.edge_index.to(device)
    graph.y = graph.y.to(device)
    graph.train_mask = graph.train_mask.to(device)
    graph.val_mask = graph.val_mask.to(device)
    graph.test_mask = graph.test_mask.to(device)

    best_threshold = 0.5
    max_f1 = 0
    max_auprc = 0
    
    for epoch in range(1000):
        train_loss = train(model, graph, loss_fn, optimizer)
        _, train_acc, _, _ = evaluate(model, graph, graph.train_mask, loss_fn, best_threshold)
        val_loss, val_acc, val_auprc, val_roc_auc, preds, outs = evaluate(model, graph, graph.val_mask, loss_fn, best_threshold, return_preds_outs=True)
        
        # Determine best threshold on validation set
        thresholds = np.arange(0, 1, 0.01)
        f1_scores = [f1_score(graph.y[graph.val_mask].cpu(), (outs > thresh).long().cpu()) for thresh in thresholds]
        best_threshold = thresholds[np.argmax(f1_scores)]
        max_f1 = max(f1_scores)
        
        # Update max_auprc if needed
        if val_auprc > max_auprc:
            max_auprc = val_auprc
    
        print(f'Epoch: {epoch+1}, Loss: {train_loss:.10f}, Train Acc: {train_acc:.10f}, Val Loss: {val_loss:.10f}, Val Acc: {val_acc:.10f}, Val AUPRC: {val_auprc:.10f}, Val ROC-AUC: {val_roc_auc:.10f}, Best Threshold: {best_threshold:.2f}, Max F1: {max_f1:.4f}')
    
        # Update cyclic learning rate
        scheduler.step()
        
        # Early stopping
        early_stopping.step(-val_auprc)  # Pass -val_auprc because we want to maximize it
        if early_stopping.early_stop:
            print("Early stopping!")
            print(f'Max AUPRC: {max_auprc:.10f}')
            
            # Final evaluation on the test set
            #test_loss, test_acc, test_auprc, test_roc_auc = evaluate(model, graph, graph.test_mask, loss_fn)
            #print(f'Test Loss: {test_loss:.10f}, Test Acc: {test_acc:.10f}, Test AUPRC: {test_auprc:.10f}, Test ROC-AUC: {test_roc_auc:.10f}')
            break

        

# Call the main function
main()


Epoch: 1, Loss: 0.0000158245, Train Acc: 0.9988774365, Val Loss: 0.0000002245, Val Acc: 0.9990523124, Val AUPRC: 0.0007486354, Val ROC-AUC: 0.4077404667, Best Threshold: 0.13, Max F1: 0.0021
Epoch: 2, Loss: 0.0000159673, Train Acc: 0.0153076845, Val Loss: 0.0000001576, Val Acc: 0.0199014405, Val AUPRC: 0.0007485864, Val ROC-AUC: 0.4077214950, Best Threshold: 0.15, Max F1: 0.0020
Epoch: 3, Loss: 0.0000569152, Train Acc: 0.0875599551, Val Loss: 0.0000001741, Val Acc: 0.0951478393, Val AUPRC: 0.0007486354, Val ROC-AUC: 0.4077404667, Best Threshold: 0.15, Max F1: 0.0021
Epoch: 4, Loss: 0.0000037838, Train Acc: 0.0925604654, Val Loss: 0.0000002590, Val Acc: 0.1014025777, Val AUPRC: 0.0007487342, Val ROC-AUC: 0.4077973819, Best Threshold: 0.13, Max F1: 0.0021
Epoch: 5, Loss: 0.0000079225, Train Acc: 0.0925604654, Val Loss: 0.0000003593, Val Acc: 0.1014025777, Val AUPRC: 0.0007485619, Val ROC-AUC: 0.4077025232, Best Threshold: 0.10, Max F1: 0.0020
Epoch: 6, Loss: 0.0000206439, Train Acc: 0.06

In [None]:
0.2171029356 / (5 / 5276)

In [None]:
(5 / 5276)