# Grapher

## Dataset

- `id`: Variant identifier in the format `#chrom:pos:ref:alt` (string).

- `#chrom`: Chromosome on which the variant is located.

- `pos`: Position of the variant on the chromosome.

- `ref`: Reference allele at this position.

- `alt`: Alternate allele observed at this position.

- `rsids`: Reference SNP cluster ID (rsID), the identifier assigned to the variant in the dbSNP database. Missing for many variants.

- `nearest_genes`: Gene or genes nearest to the variant. Multiple genes are comma-separated, and the value may be missing.

- `pval`: P-value of the association, a statistical measure of the strength of evidence against the null hypothesis.

- `mlogp`: Negative base-10 logarithm of the p-value, commonly used in genomic studies.

- `beta`: Beta coefficient, the effect size of the variant.

- `sebeta`: Standard error of the beta coefficient.

- `af_alt`: Frequency of the alternate allele across all samples.

- `af_alt_cases`: Frequency of the alternate allele in the cases group.

- `af_alt_controls`: Frequency of the alternate allele in the control group.

- `finemapped`: Whether the variant is included in the post-finemapped dataset (1) or not (0).

- `trait`: Trait associated with the variant. In this dataset, it is the response to paracetamol and NSAIDs.
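To make the relationship between the derived and raw columns concrete, here is a minimal sketch that rebuilds `id` and `mlogp` for two made-up variants. The values are hypothetical, and the base-10 logarithm is an assumption based on the usual GWAS convention.

```python
import numpy as np
import pandas as pd

# Two made-up variants; none of these values come from the real dataset.
toy = pd.DataFrame({
    "#chrom": ["1", "X"],
    "pos": [12345, 67890],
    "ref": ["A", "G"],
    "alt": ["G", "T"],
    "pval": [1e-8, 0.05],
})

# `id` is the four columns joined as #chrom:pos:ref:alt.
toy["id"] = (toy["#chrom"] + ":" + toy["pos"].astype(str)
             + ":" + toy["ref"] + ":" + toy["alt"])

# Assumption: `mlogp` uses the base-10 logarithm.
toy["mlogp"] = -np.log10(toy["pval"])

print(toy[["id", "pval", "mlogp"]])
```

A check of this kind on the loaded data can catch rows where `id` or `mlogp` disagree with the columns they are derived from.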

## Load libraries

In [1]:
import sys
import os
import random
import numpy as np
from numba import jit, prange
import pandas as pd
import matplotlib
import matplotlib.pyplot as plt
import sklearn
from sklearn.model_selection import train_test_split
from sklearn.metrics import precision_recall_fscore_support, roc_auc_score, average_precision_score
from sklearn.preprocessing import LabelEncoder, StandardScaler, OrdinalEncoder, OneHotEncoder
from sklearn.pipeline import Pipeline
from sklearn.compose import ColumnTransformer
from sklearn.impute import SimpleImputer
import torch
import torch.nn.functional as F
import torch_geometric
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv, GATConv
from torch_geometric.utils import to_undirected, negative_sampling
import networkx as nx
from scipy.spatial import cKDTree
from scipy.special import expit
from typing import List, Dict
import time
import cProfile
import pstats
import io


# Print versions of imported libraries
print(f"Python version: {sys.version}")
print(f"NumPy version: {np.__version__}")
print(f"Pandas version: {pd.__version__}")
print(f"Matplotlib version: {matplotlib.__version__}")
print(f"Scikit-learn version: {sklearn.__version__}")
print(f"Torch version: {torch.__version__}")
print(f"Torch Geometric version: {torch_geometric.__version__}")
print(f"NetworkX version: {nx.__version__}")

if torch.cuda.is_available():
    device = torch.device("cuda")          # Current CUDA device
    print(f"Using {torch.cuda.get_device_name()} ({device})")
    print(f"CUDA version: {torch.version.cuda}")
    print(f"Number of CUDA devices: {torch.cuda.device_count()}")
else:
    device = torch.device("cpu")           # Fall back to CPU so `device` is always defined
    print("CUDA is not available on this device.")

Python version: 3.11.3 (tags/v3.11.3:f3909b8, Apr  4 2023, 23:49:59) [MSC v.1934 64 bit (AMD64)]
NumPy version: 1.24.3
Pandas version: 2.0.1
Matplotlib version: 3.7.1
Scikit-learn version: 1.2.2
Torch version: 2.0.0+cu118
Torch Geometric version: 2.3.1
NetworkX version: 3.0
Using NVIDIA GeForce RTX 3060 Ti (cuda)
CUDA version: 11.8
Number of CUDA devices: 1


## Load data

In [2]:
dtypes = {
    'id': 'string',
    '#chrom': 'string',
    'pos': 'int64',
    'ref': 'string',
    'alt': 'string',
    'rsids': 'string',
    'nearest_genes': 'string',
    'pval': 'float64',
    'mlogp': 'float64',
    'beta': 'float64',
    'sebeta': 'float64',
    'af_alt': 'float64',
    'af_alt_cases': 'float64',
    'af_alt_controls': 'float64',
    'finemapped': 'int64'
}

data = pd.read_csv('~/Desktop/gwas-graph/FinnGen/data/gwas-finemap.csv', dtype=dtypes)

# Assert column names
expected_columns = ['#chrom', 'pos', 'ref', 'alt', 'rsids', 'nearest_genes', 'pval', 'mlogp', 'beta',
                    'sebeta', 'af_alt', 'af_alt_cases', 'af_alt_controls', 'finemapped',
                    'id', 'trait']
assert set(data.columns) == set(expected_columns), "Unexpected columns in the data DataFrame."

# Assert data types (reuse the dtype mapping passed to read_csv instead of repeating it)
for col, expected_dtype in dtypes.items():
    assert data[col].dtype == expected_dtype, f"Unexpected data type for column {col}."

In [3]:
# Check for total number of null values in each column
null_counts = data.isnull().sum()

print("Total number of null values in each column:")
print(null_counts)

Total number of null values in each column:
#chrom                   0
pos                      0
ref                      0
alt                      0
rsids              1366396
nearest_genes       727855
pval                     0
mlogp                    0
beta                     0
sebeta                   0
af_alt                   0
af_alt_cases             0
af_alt_controls          0
id                       0
finemapped               0
trait                    0
dtype: int64


## Data manipulation

### Create new rows per gene

In [4]:
data['nearest_genes'] = data['nearest_genes'].astype(str)

# Assert column 'nearest_genes' is a string
assert data['nearest_genes'].dtype == 'object', "Column 'nearest_genes' is not of string type."

# Split the gene names in the 'nearest_genes' column
split_genes = data['nearest_genes'].str.split(',')

# Flatten the list of split gene names
flat_genes = [item for sublist in split_genes for item in sublist]

# Create a new DataFrame by repeating rows and substituting the gene names
data_new = data.loc[data.index.repeat(split_genes.str.len())].copy()
data_new['nearest_genes'] = flat_genes

# Assert the shape of the new DataFrame is as expected
expected_shape = (len(flat_genes), data.shape[1])
assert data_new.shape == expected_shape, "Shape of the new DataFrame is not as expected."

# Reset index to have a standard index
data = data_new.reset_index(drop=True)
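
The split-and-repeat step in this cell can also be written with pandas' built-in `DataFrame.explode`. The sketch below runs on a made-up three-row frame (the ids and gene names are hypothetical) and is only meant to show the shape of the result, one row per variant and gene.

```python
import pandas as pd

# Toy frame with made-up values; the second variant lists two genes.
toy = pd.DataFrame({
    "id": ["1:100:A:G", "1:200:C:T", "2:300:G:A"],
    "nearest_genes": ["GENE1", "GENE2,GENE3", "GENE4"],
})

# Split into lists, then emit one row per list element.
exploded = (
    toy.assign(nearest_genes=toy["nearest_genes"].str.split(","))
       .explode("nearest_genes")
       .reset_index(drop=True)
)

print(exploded)
```

All other columns are repeated for each gene, so a variant with two genes appears twice with the same `id`.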

In [5]:
# Down-sample to 0.1% of the rows (fixed seed for reproducibility)
data = data.sample(frac=0.001, random_state=42)

## Spec

`data` df: columns as described in the Dataset section above.

**Task Overview**
- The objective is to design and implement a binary node classification GNN model to predict whether variants are included after post-finemapping or not.

**Nodes and Their Features**
- There is one type of node: SNP nodes.
- *SNP Nodes*: Each SNP node is identified by its `id` and described by features taken from the `nearest_genes`, `#chrom`, `pos`, `ref`, `alt`, `beta`, `sebeta`, `af_alt`, and `af_alt_cases` columns. The `finemapped` column is the classification label, not an input feature.
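
As a rough illustration of how such a node feature matrix could be assembled, the sketch below integer-codes the categorical columns and standardizes a numerical one on a made-up frame. The column subset and values are hypothetical. The z-score uses the population standard deviation, the same convention as scikit-learn's `StandardScaler`.

```python
import numpy as np
import pandas as pd

# Made-up SNP rows; values are illustrative only.
toy = pd.DataFrame({
    "id": ["1:100:A:G", "1:250:C:T", "2:300:G:A"],
    "ref": ["A", "C", "G"],
    "alt": ["G", "T", "A"],
    "beta": [0.10, -0.30, 0.20],
    "finemapped": [0, 1, 0],
}).set_index("id")

categorical = ["ref", "alt"]
numerical = ["beta"]

features = pd.DataFrame(index=toy.index)
for col in categorical:
    # Integer codes for each category, in sorted category order.
    features[col] = toy[col].astype("category").cat.codes
for col in numerical:
    # Z-score with the population standard deviation (ddof=0).
    values = toy[col].to_numpy()
    features[col] = (values - values.mean()) / values.std()

# The label is kept apart from the inputs.
x = features.to_numpy(dtype=np.float32)
y = toy["finemapped"].to_numpy()
print(x.shape, y.shape)
```

Keeping `features` and `toy` on the same `id` index means row `i` of `x` and entry `i` of `y` describe the same SNP.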

**Edges, Their Features, and Labels**
- Edges represent relationships between SNP nodes and are created based on the proximity of SNPs on a chromosome.
- Specifically, for each pair of SNPs (row1 and row2) on the same chromosome (`#chrom`), an edge is created if the absolute difference between their positions (`pos`) is less than or equal to 200,000.
- The weight of an edge is a function of the absolute difference in alternate allele frequency (`absDiffAltAlleleFreq`) and the absolute difference in position (`pos`) between the two SNPs. Specifically, the weight is given by:

- `weight = absDiffAltAlleleFreq / (1 + expit(-abs(row1['pos'] - row2['pos']) / decay_constant))`, where `expit` is the logistic sigmoid function, `absDiffAltAlleleFreq` is the absolute difference between the `af_alt` values of the two SNPs, and `decay_constant` sets the distance scale of the position term. As written, the denominator falls from 1.5 at zero distance toward 1 at large distances, so for a fixed allele-frequency difference the weight increases slightly with distance rather than decaying.
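
The edge rule and weight formula can be tried out on a pair of made-up SNPs. This is a minimal sketch: the helper names `sigmoid` and `edge_weight` are hypothetical, and `sigmoid` is a NumPy-only stand-in for `scipy.special.expit`.

```python
import numpy as np

def sigmoid(x):
    # Same function as scipy.special.expit, written with NumPy only.
    return 1.0 / (1.0 + np.exp(-x))

def edge_weight(af_alt_1, af_alt_2, pos_1, pos_2, decay_constant=1e5):
    """Weight of the edge between two SNPs on the same chromosome."""
    abs_diff_af = abs(af_alt_1 - af_alt_2)
    abs_diff_pos = abs(pos_1 - pos_2)
    return abs_diff_af / (1.0 + sigmoid(-abs_diff_pos / decay_constant))

max_distance = 200_000

# Two made-up SNPs, 50,000 positions apart, with allele frequencies 0.30 and 0.10.
pos_1, pos_2 = 1_000_000, 1_050_000
connected = abs(pos_1 - pos_2) <= max_distance
w_near = edge_weight(0.30, 0.10, pos_1, pos_2)

# Same allele frequencies at the maximum distance of 200,000.
w_far = edge_weight(0.30, 0.10, 1_000_000, 1_200_000)

print(connected, round(w_near, 4), round(w_far, 4))
```

Printing both weights makes it easy to see how the position term behaves at different distances for the same allele-frequency difference.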

## Graph creation

In [6]:
import cProfile
import pstats
import io
from scipy.special import expit
from sklearn.preprocessing import LabelEncoder, StandardScaler
import torch
from torch_geometric.data import Data
import pandas as pd
import time

decay_constant = 1e5  # Define decay constant as per your requirements

def get_unique_snps(data: pd.DataFrame) -> dict:
    """
    Function to create mappings for SNPs to integer indices.
    """
    return {snp: idx for idx, snp in enumerate(data['id'].unique())}

def preprocess_snp_features(data: pd.DataFrame, snp_to_idx: dict) -> pd.DataFrame:
    """
    Function to create node feature vectors for SNPs and preprocess categorical and numerical features.
    """
    categorical_cols = ['nearest_genes', '#chrom', 'ref', 'alt']
    numerical_cols = ['pos', 'beta', 'sebeta', 'af_alt', 'af_alt_cases']

    # Keep the row order of `data` (no sort) so features stay aligned with the labels built from data['finemapped']
    snp_features = data.loc[data['id'].isin(snp_to_idx.keys()), ['id', 'nearest_genes', '#chrom', 'pos', 'ref', 'alt', 'beta', 'sebeta', 'af_alt', 'af_alt_cases']].set_index('id')

    for col in categorical_cols:
        # Fill missing values before encoding, and fit a fresh encoder per column
        values = snp_features[col].astype(object).fillna('N/A').astype(str)
        snp_features[col] = LabelEncoder().fit_transform(values)

    for col in numerical_cols:
        # Assign the raw array: wrapping it in a Series with a fresh 0..n-1 index
        # would not align with the 'id' index and would turn the column into NaN
        values = snp_features[[col]].fillna(0)
        snp_features[col] = StandardScaler().fit_transform(values).ravel()

    return snp_features

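def preprocess_positive_edges(data: pd.DataFrame, snp_to_idx: dict) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Function to create positive SNP-SNP edges and preprocess them.
    Returns the edge index (shape [2, num_edges]) and the edge weights.
    Each SNP pair is emitted once, from the lower to the higher position.
    """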
    positive_edges_snp_snp = []
    snp_weights = []

    for _, group in data.groupby('#chrom'):
        sorted_group = group.sort_values(by='pos')
        for idx, row in sorted_group.iterrows():
            pos_diff = sorted_group['pos'] - row['pos']
            mask = (pos_diff > 0) & (pos_diff <= 200000)
            filtered_group = sorted_group[mask]

            ids = filtered_group['id'].map(snp_to_idx)
            af_alt_diff = abs(row['af_alt'] - filtered_group['af_alt'])
            pos_diff_abs = abs(row['pos'] - filtered_group['pos'])

            weights = af_alt_diff / (1 + expit(-pos_diff_abs / decay_constant))

            positive_edges_snp_snp.extend(zip([snp_to_idx[row['id']]] * len(ids), ids))
            snp_weights.extend(weights)

    return torch.tensor(positive_edges_snp_snp, dtype=torch.long).t().contiguous(), torch.tensor(snp_weights, dtype=torch.float)

def create_pytorch_graph(features: torch.Tensor, edges: torch.Tensor, edge_weights: torch.Tensor) -> Data:
    """
    Function to create the PyTorch Geometric graph.
    """
    return Data(x=features, edge_index=edges, edge_attr=edge_weights)

# Add profiling
pr = cProfile.Profile()
pr.enable()

start_time = time.time()

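# Assumption: keep the first row per SNP so that node indices from snp_to_idx
# line up with the rows of the feature and label tensors
data = data.drop_duplicates(subset='id')
snp_to_idx = get_unique_snps(data)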

labels = data['finemapped'].map(lambda x: 1 if x > 0 else 0)

snp_features = preprocess_snp_features(data, snp_to_idx)
features = torch.tensor(snp_features.values, dtype=torch.float)

positive_edges_snp_snp, edge_weights = preprocess_positive_edges(data, snp_to_idx)
graph = create_pytorch_graph(features, positive_edges_snp_snp, edge_weights)

# Set labels here after creating the graph
graph.y = torch.tensor(labels.values, dtype=torch.long)

print(f"Number of nodes: {graph.num_nodes}")
print(f"Number of edges: {graph.num_edges}")
print(f"Node feature dimension: {graph.num_node_features}")

elapsed_time = time.time() - start_time
print(f"Execution time: {elapsed_time} seconds")

pr.disable()
s = io.StringIO()
sortby = 'cumulative'
ps = pstats.Stats(pr, stream=s).sort_stats(sortby)
ps.print_stats(5)  # Only print the top 5 lines
print(s.getvalue())

Number of nodes: 20566
Number of edges: 31243
Node feature dimension: 9
Execution time: 181.7657985687256 seconds
         104269693 function calls (102256695 primitive calls) in 181.764 seconds

   Ordered by: cumulative time
   List reduced from 970 to 5 due to restriction <5>

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
       14    0.000    0.000  181.763   12.983 C:\Users\falty\AppData\Local\Programs\Python\Python311\Lib\site-packages\IPython\core\interactiveshell.py:3472(run_code)
       14    0.000    0.000  181.763   12.983 {built-in method builtins.exec}
        1    0.002    0.002  181.630  181.630 C:\Users\falty\AppData\Local\Temp\ipykernel_18580\1596713975.py:1(<module>)
        1    1.039    1.039  181.628  181.628 C:\Users\falty\AppData\Local\Temp\ipykernel_18580\1596713975.py:39(preprocess_positive_edges)
    20567    4.946    0.000  130.980    0.006 C:\Users\falty\AppData\Local\Programs\Python\Python311\Lib\site-packages\pandas\core\series.py
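
The profile above attributes roughly 181 s to `preprocess_positive_edges`, which compares every SNP against its whole chromosome inside a Python loop. One possible way to avoid that loop, sketched here for the positions of a single chromosome, is to sort them and locate each 200,000-wide window with `np.searchsorted`. The function name `window_pairs` is hypothetical and not part of the notebook.

```python
import numpy as np

def window_pairs(pos, max_distance=200_000):
    """Return (source, target) index pairs with 0 < pos[target] - pos[source] <= max_distance.

    `pos` holds the positions of SNPs on a single chromosome, in any order.
    The returned indices refer to the original order of `pos`.
    """
    pos = np.asarray(pos)
    order = np.argsort(pos, kind="stable")
    sorted_pos = pos[order]

    # For each SNP, the later SNPs within range form a contiguous slice of the sorted array.
    start = np.searchsorted(sorted_pos, sorted_pos, side="right")
    stop = np.searchsorted(sorted_pos, sorted_pos + max_distance, side="right")
    counts = stop - start

    # Expand the slices into flat source/target arrays without a Python loop.
    source = np.repeat(np.arange(len(pos)), counts)
    offsets = np.arange(counts.sum()) - np.repeat(np.cumsum(counts) - counts, counts)
    target = np.repeat(start, counts) + offsets
    return order[source], order[target]

# Made-up positions on one chromosome.
positions = np.array([100, 150_000, 250_000, 600_000, 100])
src, dst = window_pairs(positions)
print(list(zip(src.tolist(), dst.tolist())))
```

The pair arrays can then index into the `af_alt` and `pos` columns to compute all weights for a chromosome in one vectorized expression.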

In [7]:
# Save PyTorch Geometric graph
#torch.save(graph, "pytorch_geometric_graph.pt")

## Graph stats

In [8]:
from torch_geometric.utils import degree

def print_graph_stats(graph, positive_edges_snp_snp, features_list):
    print(f"Number of nodes: {graph.num_nodes}")
    print(f"Number of positive SNP-SNP edges: {positive_edges_snp_snp.size(1)}")
    print(f"Number of edges: {graph.num_edges}")
    print(f"Node feature dimension: {graph.num_node_features}")

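    # Compute and print degree-related stats.
    # Each SNP pair is stored once (lower position -> higher position), so this counts outgoing edges only.
    degrees = degree(graph.edge_index[0].long(), num_nodes=graph.num_nodes)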
    average_degree = degrees.float().mean().item()
    median_degree = np.median(degrees.numpy())
    std_degree = degrees.float().std().item()

    print(f"Average Degree: {average_degree}")
    print(f"Median Degree: {median_degree}")
    print(f"Standard Deviation of Degree: {std_degree}")

    # Density is the ratio of actual edges to the maximum number of possible edges
    num_possible_edges = graph.num_nodes * (graph.num_nodes - 1) / 2
    density = graph.num_edges / num_possible_edges

    print(f"Density: {density:.10f}")

    # Check for NaN values in features
    nan_mask = torch.isnan(graph.x)
    nan_features = []
    for feature_idx, feature_name in enumerate(features_list):
        if nan_mask[:, feature_idx].any():
            nan_features.append(feature_name)

    print("Features with NaN values:")
    print(nan_features)

def print_edge_weight_stats(edge_weights):
    print(f"Number of edges: {edge_weights.size(0)}")
    print(f"Average edge weight: {edge_weights.float().mean().item()}")
    print(f"Median edge weight: {np.median(edge_weights.numpy())}")
    print(f"Standard deviation of edge weights: {edge_weights.float().std().item()}")
    print(f"Maximum edge weight: {edge_weights.max().item()}")
    print(f"Minimum edge weight: {edge_weights.min().item()}")

# Print graph stats
print("Graph stats:")
print_graph_stats(graph, positive_edges_snp_snp, snp_features.columns)

print("Edge weight stats:")
print_edge_weight_stats(graph.edge_attr)

Graph stats:
Number of nodes: 20566
Number of positive SNP-SNP edges: 31243
Number of edges: 31243
Node feature dimension: 9
Average Degree: 1.519157886505127
Median Degree: 1.0
Standard Deviation of Degree: 1.329643964767456
Density: 0.0001477421
Features with NaN values:
[]
Edge weight stats:
Number of edges: 31243
Average edge weight: 0.14570282399654388
Median edge weight: 0.0459057055413723
Standard deviation of edge weights: 0.1975155621767044
Maximum edge weight: 0.8884702324867249
Minimum edge weight: 0.0


## Data splitting

In [9]:
from torch_geometric.transforms import RandomNodeSplit

# Split the nodes: 1000 training nodes per class, then 20% of all nodes each for
# validation and test; the remaining nodes are left unassigned
transform = RandomNodeSplit(split="random", num_train_per_class=1000, num_val=0.2, num_test=0.2, key='y')
graph = transform(graph)

print(f"Train nodes: {graph.train_mask.sum().item()}")
print(f"Validation nodes: {graph.val_mask.sum().item()}")
print(f"Test nodes: {graph.test_mask.sum().item()}")

Train nodes: 2000
Validation nodes: 4113
Test nodes: 4113
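
The counts above follow from the split parameters: 1000 training nodes per class, and 20% of the 20,566 nodes for validation and for test. The NumPy sketch below mimics this kind of split on hypothetical labels. It is not the library's own code, and the helper name `random_node_split` and the roughly 12% positive rate are assumptions for illustration.

```python
import numpy as np

def random_node_split(y, num_train_per_class, val_frac, test_frac, seed=0):
    """Return boolean train/val/test masks for a label vector `y`."""
    rng = np.random.default_rng(seed)
    num_nodes = len(y)
    train_mask = np.zeros(num_nodes, dtype=bool)

    # Draw the same number of training nodes from every class.
    for c in np.unique(y):
        idx = rng.permutation(np.flatnonzero(y == c))
        train_mask[idx[:num_train_per_class]] = True

    # Fractions are taken of all nodes, then filled from the non-training nodes.
    num_val = round(num_nodes * val_frac)
    num_test = round(num_nodes * test_frac)
    remaining = rng.permutation(np.flatnonzero(~train_mask))

    val_mask = np.zeros(num_nodes, dtype=bool)
    test_mask = np.zeros(num_nodes, dtype=bool)
    val_mask[remaining[:num_val]] = True
    test_mask[remaining[num_val:num_val + num_test]] = True
    return train_mask, val_mask, test_mask

# Hypothetical labels: 20,566 nodes with roughly 12% positives.
rng = np.random.default_rng(42)
y = (rng.random(20_566) < 0.12).astype(int)

train_mask, val_mask, test_mask = random_node_split(y, 1000, 0.2, 0.2)
print(train_mask.sum(), val_mask.sum(), test_mask.sum())
```

With a split like this the training set is class-balanced, while validation and test follow the class distribution of the remaining nodes, so accuracy on the training set and on the other two sets is not directly comparable.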


## Models

### GCN

In [10]:
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv

class GCNModel(torch.nn.Module):
    def __init__(self, num_node_features, num_classes):
        super(GCNModel, self).__init__()
        self.conv1 = GCNConv(num_node_features, 64)
        self.conv2 = GCNConv(64, num_classes)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index

        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, training=self.training)

        x = self.conv2(x, edge_index)

        return torch.sigmoid(x)

model = GCNModel(graph.num_node_features, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

def train(data):
    model.train()
    optimizer.zero_grad()
    out = model(data)[data.train_mask].squeeze()
    loss = F.binary_cross_entropy(out, data.y[data.train_mask].float())
    loss.backward()
    optimizer.step()
    return loss.item()

def evaluate(data, mask):
    model.eval()
    with torch.no_grad():
        preds = model(data)[mask].squeeze()
        preds = (preds > 0.5).long()
        acc = preds.eq(data.y[mask]).sum().item() / mask.sum().item()
    return acc

for epoch in range(200):
    loss = train(graph)
    train_acc = evaluate(graph, graph.train_mask)
    val_acc = evaluate(graph, graph.val_mask)
    test_acc = evaluate(graph, graph.test_mask)
    print(f'Epoch: {epoch+1}, Loss: {loss:.10f}, Train Acc: {train_acc:.10f}, Val Acc: {val_acc:.10f}, Test Acc: {test_acc:.10f}')


Epoch: 1, Loss: 45.8633766174, Train Acc: 0.4995000000, Val Acc: 0.8653051301, Test Acc: 0.8743009968
Epoch: 2, Loss: 47.1373519897, Train Acc: 0.4975000000, Val Acc: 0.8584974471, Test Acc: 0.8672501823
Epoch: 3, Loss: 44.6423950195, Train Acc: 0.5000000000, Val Acc: 0.1232676878, Test Acc: 0.1184050571
Epoch: 4, Loss: 46.7130393982, Train Acc: 0.5000000000, Val Acc: 0.1232676878, Test Acc: 0.1184050571
Epoch: 5, Loss: 48.3965263367, Train Acc: 0.5000000000, Val Acc: 0.1232676878, Test Acc: 0.1184050571
Epoch: 6, Loss: 49.7029418945, Train Acc: 0.5000000000, Val Acc: 0.1232676878, Test Acc: 0.1184050571
Epoch: 7, Loss: 50.1112785339, Train Acc: 0.5000000000, Val Acc: 0.1232676878, Test Acc: 0.1184050571
Epoch: 8, Loss: 49.8173828125, Train Acc: 0.5000000000, Val Acc: 0.1232676878, Test Acc: 0.1184050571
Epoch: 9, Loss: 50.0652313232, Train Acc: 0.5000000000, Val Acc: 0.1232676878, Test Acc: 0.1184050571
Epoch: 10, Loss: 49.9016990662, Train Acc: 0.5000000000, Val Acc: 0.1232676878, Te