# Grapher

## Load libraries

In [1]:
import sys
import os
import random
import numpy as np
import pandas as pd
import matplotlib
import matplotlib.pyplot as plt
import networkx as nx
import sklearn
from sklearn.model_selection import train_test_split
from sklearn.metrics import precision_recall_fscore_support, roc_auc_score, average_precision_score
from sklearn.preprocessing import LabelEncoder, StandardScaler, OrdinalEncoder, OneHotEncoder
from sklearn.pipeline import Pipeline
from sklearn.compose import ColumnTransformer
from sklearn.impute import SimpleImputer
import torch
import torch.nn.functional as F
import torch_geometric
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv, GATConv
from torch_geometric.utils import to_undirected, negative_sampling
import networkx as nx
from scipy.spatial import cKDTree
from typing import List, Dict
import time
import collections

# Print versions of imported libraries
print(f"Python version: {sys.version}")
print(f"NumPy version: {np.__version__}")
print(f"Pandas version: {pd.__version__}")
print(f"Matplotlib version: {matplotlib.__version__}")
print(f"Scikit-learn version: {sklearn.__version__}")
print(f"Torch version: {torch.__version__}")
print(f"Torch Geometric version: {torch_geometric.__version__}")
print(f"NetworkX version: {nx.__version__}")

if torch.cuda.is_available():
    device = torch.device("cuda")          # Current CUDA device
    print(f"Using {torch.cuda.get_device_name()} ({device})")
    print(f"CUDA version: {torch.version.cuda}")
    print(f"Number of CUDA devices: {torch.cuda.device_count()}")
else:
    print("CUDA is not available on this device.")

Python version: 3.11.3 (tags/v3.11.3:f3909b8, Apr  4 2023, 23:49:59) [MSC v.1934 64 bit (AMD64)]
NumPy version: 1.24.3
Pandas version: 2.0.1
Matplotlib version: 3.7.1
Scikit-learn version: 1.2.2
Torch version: 2.0.0+cu118
Torch Geometric version: 2.3.1
NetworkX version: 3.0
Using NVIDIA GeForce RTX 3060 Ti (cuda)
CUDA version: 11.8
Number of CUDA devices: 1


## Load data

In [2]:
dtypes = {
    'id': 'string',
    '#chrom': 'string',
    'pos': 'int64',
    'ref': 'string',
    'alt': 'string',
    'rsids': 'string',
    'nearest_genes': 'string',
    'pval': 'float64',
    'mlogp': 'float64',
    'beta': 'float64',
    'sebeta': 'float64',
    'af_alt': 'float64',
    'af_alt_cases': 'float64',
    'af_alt_controls': 'float64',
    'causal': 'int64',
    'LD': 'int64',
    'lead': 'string',
    'trait': 'string'
}

data = pd.read_csv('~/Desktop/gwas-graph/FinnGen/data/gwas-causal.csv', dtype=dtypes)

# Assert column names
expected_columns = ['#chrom', 'pos', 'ref', 'alt', 'rsids', 'nearest_genes', 'pval', 'mlogp', 'beta',
                    'sebeta', 'af_alt', 'af_alt_cases', 'af_alt_controls', 'causal', 'LD', 'lead',
                    'id', 'trait']
assert set(data.columns) == set(expected_columns), "Unexpected columns in the data DataFrame."

# Assert data types
expected_dtypes = {
    'id': 'string',
    '#chrom': 'string',
    'pos': 'int64',
    'ref': 'string',
    'alt': 'string',
    'rsids': 'string',
    'nearest_genes': 'string',
    'pval': 'float64',
    'mlogp': 'float64',
    'beta': 'float64',
    'sebeta': 'float64',
    'af_alt': 'float64',
    'af_alt_cases': 'float64',
    'af_alt_controls': 'float64',
    'causal': 'int64',
    'LD': 'int64',
    'lead': 'string',
    'trait': 'string'
}

for col, expected_dtype in expected_dtypes.items():
    assert data[col].dtype == expected_dtype, f"Unexpected data type for column {col}."

In [3]:
# Check for total number of null values in each column
null_counts = data.isnull().sum()

print("Total number of null values in each column:")
print(null_counts)

Total number of null values in each column:
#chrom                    0
pos                       0
ref                       0
alt                       0
rsids               1366396
nearest_genes        727855
pval                      0
mlogp                     0
beta                      0
sebeta                    0
af_alt                    0
af_alt_cases              0
af_alt_controls           0
causal                    0
LD                        0
id                        0
lead               20168881
trait                     0
dtype: int64


## Data manipulation

### Create new rows per gene

In [4]:
data['nearest_genes'] = data['nearest_genes'].astype(str)

# Assert column 'nearest_genes' is a string
assert data['nearest_genes'].dtype == 'object', "Column 'nearest_genes' is not of string type."

# Split the gene names in the 'nearest_genes' column
split_genes = data['nearest_genes'].str.split(',')

# Flatten the list of split gene names
flat_genes = [item for sublist in split_genes for item in sublist]

# Create a new DataFrame by repeating rows and substituting the gene names
data_new = data.loc[data.index.repeat(split_genes.str.len())].copy()
data_new['nearest_genes'] = flat_genes

# Assert the shape of the new DataFrame is as expected
expected_shape = (len(flat_genes), data.shape[1])
assert data_new.shape == expected_shape, "Shape of the new DataFrame is not as expected."

# Reset index to have a standard index
data = data_new.reset_index(drop=True)

In [5]:
data = data.sample(frac=0.1, random_state=42)

## Spec

### Required Data:

- `id`: id of the variant in the following format: #chrom:pos:ref:alt.
- `#chrom`: chromosome on build GRCh38 (1-23)
- `pos`: position in base pairs on build GRCh38
- `ref`: reference allele
- `alt`: alternative allele (effect allele)
- `nearest_genes`: nearest gene(s) (comma separated) from variant
- `pval`: p-value
- `mlogp`: -log10(p-value)
- `beta`: effect size (log(OR) scale)
- `sebeta`: standard error of effect size
- `af_alt`: alternative (effect) allele frequency
- `af_alt_cases`: alternative (effect) allele frequency among cases
- `af_alt_controls`: alternative (effect) allele frequency among controls

### Procedure:

**1. Node Creation:** 

- Create a node for each variant in the dataset. 
- Label the node with the `id`, `#chrom`, and `pos` fields.

**2. Edge Creation:** 

- For each pair of variants within 200,000 base pairs of each other on the same `#chrom`, create an edge.
- The physical distance is calculated using the `pos` field.

**3. Edge Weighting:** 

- `finalWeight = absDiffAltAlleleFreq / (1 + exp(-distance/decayConstant))`
- Based on typical distances between SNPs in LD, a reasonable range for the decay constant could be from 10^4 to 10^6 base pairs. 


**4. Graph Refinement:** 

- Simplify the graph and emphasize important connections by removing any edge with a weight below a certain threshold.

**5. Clustering:** 

- Apply the Louvain method for community detection on the graph. 

**6. Cluster Evaluation:** 

- For each cluster, calculate the average edge weight by summing all the edge weights in the cluster and dividing by the number of edges. 
- Calculate the average physical distance in a similar manner. 

## Clustering

In [8]:
import pandas as pd
import networkit as nk
import numpy as np
from scipy.special import expit
import community as community_louvain
import itertools

def preprocess_data(data: pd.DataFrame) -> pd.DataFrame:
    print("Preprocessing data...")
    assert isinstance(data, pd.DataFrame), "Input data must be a pandas DataFrame"
    data = data.fillna({'nearest_genes': 'N/A', '#chrom': 'N/A', 'pos': 0, 'ref': 'N/A', 'alt': 'N/A',
                        'beta': 0, 'sebeta': 0, 'af_alt': 0, 'af_alt_cases': 0, 'af_alt_controls': 0})
    data['absDiffAltAlleleFreq'] = abs(data['af_alt_cases'] - data['af_alt_controls'])
    print("Data preprocessing completed!")
    return data

def create_networkit_graph(features: pd.DataFrame, decay_constant: float) -> nk.Graph:
    print("Creating Networkit graph...")
    assert isinstance(features, pd.DataFrame), "Features must be a pandas DataFrame"
    assert isinstance(decay_constant, (int, float)), "Decay constant must be a numeric value"
    
    # Group data by chromosome
    grouped = features.groupby('#chrom')

    # Initialize an undirected graph with number of nodes equal to the length of the features dataframe
    G = nk.Graph(len(features), weighted=True)

    # A dictionary to store the position of each node
    pos_dict = {}

    for index, row in features.iterrows():
        pos_dict[index] = row['pos']

    for chrom, chr_data in grouped:
        for i, row1 in chr_data.iterrows():
            for j, row2 in chr_data[i+1:].iterrows():
                if abs(row1['pos'] - row2['pos']) <= 500000:
                    weight = row1['absDiffAltAlleleFreq'] / (1 + expit(-abs(row1['pos'] - row2['pos'])/decay_constant))
                    G.addEdge(i, j, weight)
                    
    print("Networkit graph creation completed!")
    return G, pos_dict

def apply_clustering(G: nk.Graph) -> nk.community.PLM:
    print("Applying PLM method for community detection...")
    assert isinstance(G, nk.Graph), "Input must be a Networkit graph"
    plm = nk.community.PLM(G, True)
    plm.run()
    print("Community detection completed!")
    return plm.getPartition()

def evaluate_clusters(G: nk.Graph, partition: nk.community.PLM, pos_dict: dict) -> pd.DataFrame:
    print("Evaluating clusters...")
    assert isinstance(G, nk.Graph), "Graph must be a Networkit graph"
    assert isinstance(partition, nk.community.PLM), "Partition must be a PLM object"
    
    clusters = pd.DataFrame([(node, cluster_id) for node, cluster_id in enumerate(partition)], columns=['node', 'cluster'])
    cluster_stats = clusters.groupby('cluster')['node'].apply(list).apply(
        lambda nodes: {'avg_weight': np.mean([G.weight(u, v) for u, v in itertools.combinations(nodes, 2) if G.hasEdge(u, v)]) 
                       if any(G.hasEdge(u, v) for u, v in itertools.combinations(nodes, 2)) else None,
                       'avg_distance': np.mean([abs(pos_dict[u] - pos_dict[v]) for u, v in itertools.combinations(nodes, 2) if G.hasEdge(u, v)]) 
                       if any(G.hasEdge(u, v) for u, v in itertools.combinations(nodes, 2)) else None}).apply(pd.Series)
    print("Cluster evaluation completed!")
    return cluster_stats

# Preprocess data
data = preprocess_data(data)

import warnings

# Step 1: Check the data
print("Minimum pos difference in data:", data['pos'].diff().abs().min())
print("Maximum pos difference in data:", data['pos'].diff().abs().max())

# Step 2: Adjust the parameters
decay_constant = 1e5  # adjust this

# Create graph
G, pos_dict = create_networkit_graph(data, decay_constant=decay_constant)

# Apply PLM method for community detection
partition = apply_clustering(G)

if partition.numberOfSubsets() > 0:
    print("Partition created with", partition.numberOfSubsets(), "clusters.")
else:
    warnings.warn("No clusters were formed. Please check the input data and parameters.")

# Evaluate clusters
cluster_stats = evaluate_clusters(G, partition, pos_dict)

if cluster_stats.empty:
    warnings.warn("Cluster statistics are empty. Please check the clustering results and graph properties.")

print(cluster_stats)

# Check for NaN, null, or empty values in the final DataFrame
print("Checking for NaN, null, or empty values in the final DataFrame...")
print("Total NaN or Null values in cluster_stats DataFrame:")
print(cluster_stats.isna().sum())
print()
print("Total Empty values in cluster_stats DataFrame:")
print(cluster_stats.eq('').sum())
print("Check completed!")

Preprocessing data...
Data preprocessing completed!
Minimum pos difference in data: 27.0
Maximum pos difference in data: 248430921.0
Creating Networkit graph...


KeyboardInterrupt: 

In [None]:
def visualize_graph(graph, chromosome: str):
    plt.figure(figsize=(25, 25))  # Increase figure size for better visibility
    
    # Create a subgraph for the specified chromosome
    nodes_chromosome = [node for node, attr in graph.nodes(data=True) if attr.get('#chrom') == chromosome]
    subgraph = graph.subgraph(nodes_chromosome)

    if subgraph.number_of_nodes() == 0:
        print(f"No nodes found for chromosome {chromosome}.")
        return

    # Get a list of unique chromosomes in the subgraph
    chroms = list(set(nx.get_node_attributes(subgraph, '#chrom').values()))
    # Assign each chromosome a color
    colormap = plt.cm.tab20
    
    # Get node colors
    node_colors = [colormap(chroms.index(subgraph.nodes[node]['#chrom'])) for node in subgraph.nodes]

    pos = nx.spring_layout(subgraph, k=0.25)  
    nx.draw(subgraph, pos, with_labels=True, node_color=node_colors, node_size=20, 
            font_size=8, edge_color='gray')  
    
    edge_labels = nx.get_edge_attributes(subgraph, 'label')
    nx.draw_networkx_edge_labels(subgraph, pos, edge_labels=edge_labels, font_size=8)  
    
    plt.show()

# To visualize a specific chromosome, pass the chromosome number as a string. For example:
visualize_graph(graph, "23")
