# Grapher

## Dataset

- `id`: This column represents the id of the variant in the following format: #chrom:pos:ref:alt (string).

- `#chrom`: This column represents the chromosome number where the genetic variant is located.

- `pos`: This is the position of the genetic variant on the chromosome.

- `ref`: This column represents the reference allele (or variant) at the genomic position.

- `alt`: This is the alternate allele observed at this position.

- `rsids`: This stands for reference SNP cluster ID. It's a unique identifier for each variant used in the dbSNP database.

- `nearest_genes`: This column represents the gene which is nearest to the variant.

- `pval`: This represents the p-value, which is a statistical measure for the strength of evidence against the null hypothesis.

- `mlogp`: This represents the minus log of the p-value, commonly used in genomic studies.

- `beta`: The beta coefficient represents the effect size of the variant.

- `sebeta`: This is the standard error of the beta coefficient.

- `af_alt`: This is the allele frequency of the alternate variant in the general population.

- `af_alt_cases`: This is the allele frequency of the alternate variant in the cases group.

- `af_alt_controls`: This is the allele frequency of the alternate variant in the control group.

- `causal`: This binary column indicates whether the variant is determined to be causal (1) or not (0).

- `LD`: This binary column indicates whether the variant is determined to be in linkage disequilibrium (1) or not (0).

- `lead`: This string column contains the id of the SNP of which the variant is in LD with (string).

- `trait`: This column represents the trait associated with the variant. In this dataset, it is the response to the drug paracetamol and NSAIDs.

## Load libraries

In [1]:
import sys
import os
import random
import numpy as np
import pandas as pd
import matplotlib
import matplotlib.pyplot as plt
import sklearn
from sklearn.model_selection import train_test_split
from sklearn.metrics import precision_recall_fscore_support, roc_auc_score, average_precision_score
from sklearn.preprocessing import LabelEncoder, StandardScaler, OrdinalEncoder, OneHotEncoder
from sklearn.pipeline import Pipeline
from sklearn.compose import ColumnTransformer
from sklearn.impute import SimpleImputer
import torch
import torch.nn.functional as F
import torch_geometric
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv, GATConv
from torch_geometric.utils import to_undirected, negative_sampling
import networkx as nx
from scipy.spatial import cKDTree
from typing import List, Dict
import time

# Print versions of imported libraries
print(f"Python version: {sys.version}")
print(f"NumPy version: {np.__version__}")
print(f"Pandas version: {pd.__version__}")
print(f"Matplotlib version: {matplotlib.__version__}")
print(f"Scikit-learn version: {sklearn.__version__}")
print(f"Torch version: {torch.__version__}")
print(f"Torch Geometric version: {torch_geometric.__version__}")
print(f"NetworkX version: {nx.__version__}")

if torch.cuda.is_available():
    device = torch.device("cuda")          # Current CUDA device
    print(f"Using {torch.cuda.get_device_name()} ({device})")
    print(f"CUDA version: {torch.version.cuda}")
    print(f"Number of CUDA devices: {torch.cuda.device_count()}")
else:
    print("CUDA is not available on this device.")

Python version: 3.11.3 (tags/v3.11.3:f3909b8, Apr  4 2023, 23:49:59) [MSC v.1934 64 bit (AMD64)]
NumPy version: 1.24.3
Pandas version: 2.0.1
Matplotlib version: 3.7.1
Scikit-learn version: 1.2.2
Torch version: 2.0.0+cu118
Torch Geometric version: 2.3.1
NetworkX version: 3.0
Using NVIDIA GeForce RTX 3060 Ti (cuda)
CUDA version: 11.8
Number of CUDA devices: 1


## Load data

In [2]:
dtypes = {
    'id': 'string',
    '#chrom': 'string',
    'pos': 'int64',
    'ref': 'string',
    'alt': 'string',
    'rsids': 'string',
    'nearest_genes': 'string',
    'pval': 'float64',
    'mlogp': 'float64',
    'beta': 'float64',
    'sebeta': 'float64',
    'af_alt': 'float64',
    'af_alt_cases': 'float64',
    'af_alt_controls': 'float64',
    'causal': 'int64',
    'LD': 'int64',
    'lead': 'string',
    'trait': 'string'
}

data = pd.read_csv('~/Desktop/gwas-graph/FinnGen/data/gwas-causal.csv', dtype=dtypes)

# Assert column names
expected_columns = ['#chrom', 'pos', 'ref', 'alt', 'rsids', 'nearest_genes', 'pval', 'mlogp', 'beta',
                    'sebeta', 'af_alt', 'af_alt_cases', 'af_alt_controls', 'causal', 'LD', 'lead',
                    'id', 'trait']
assert set(data.columns) == set(expected_columns), "Unexpected columns in the data DataFrame."

# Assert data types
expected_dtypes = {
    'id': 'string',
    '#chrom': 'string',
    'pos': 'int64',
    'ref': 'string',
    'alt': 'string',
    'rsids': 'string',
    'nearest_genes': 'string',
    'pval': 'float64',
    'mlogp': 'float64',
    'beta': 'float64',
    'sebeta': 'float64',
    'af_alt': 'float64',
    'af_alt_cases': 'float64',
    'af_alt_controls': 'float64',
    'causal': 'int64',
    'LD': 'int64',
    'lead': 'string',
    'trait': 'string'
}

for col, expected_dtype in expected_dtypes.items():
    assert data[col].dtype == expected_dtype, f"Unexpected data type for column {col}."

In [3]:
# Check for total number of null values in each column
null_counts = data.isnull().sum()

print("Total number of null values in each column:")
print(null_counts)

Total number of null values in each column:
#chrom                    0
pos                       0
ref                       0
alt                       0
rsids               1366396
nearest_genes        727855
pval                      0
mlogp                     0
beta                      0
sebeta                    0
af_alt                    0
af_alt_cases              0
af_alt_controls           0
causal                    0
LD                        0
id                        0
lead               20168198
trait                     0
dtype: int64


## Data manipulation

### Create new rows per gene

In [4]:
data['nearest_genes'] = data['nearest_genes'].astype(str)

# Assert column 'nearest_genes' is a string
assert data['nearest_genes'].dtype == 'object', "Column 'nearest_genes' is not of string type."

# Split the gene names in the 'nearest_genes' column
split_genes = data['nearest_genes'].str.split(',')

# Flatten the list of split gene names
flat_genes = [item for sublist in split_genes for item in sublist]

# Create a new DataFrame by repeating rows and substituting the gene names
data_new = data.loc[data.index.repeat(split_genes.str.len())].copy()
data_new['nearest_genes'] = flat_genes

# Assert the shape of the new DataFrame is as expected
expected_shape = (len(flat_genes), data.shape[1])
assert data_new.shape == expected_shape, "Shape of the new DataFrame is not as expected."

# Reset index to have a standard index
data = data_new.reset_index(drop=True)

In [5]:
data

Unnamed: 0,#chrom,pos,ref,alt,rsids,nearest_genes,pval,mlogp,beta,sebeta,af_alt,af_alt_cases,af_alt_controls,causal,LD,id,lead,trait
0,1,13668,G,A,rs2691328,OR4F5,0.944365,0.024860,-0.005926,0.084918,0.005842,0.005729,0.005863,0,0,1:13668:G:A,,T2D
1,1,14773,C,T,rs878915777,OR4F5,0.844305,0.073501,0.010088,0.051369,0.013495,0.013547,0.013485,0,0,1:14773:C:T,,T2D
2,1,15585,G,A,rs533630043,OR4F5,0.841908,0.074735,0.031464,0.157751,0.001113,0.001125,0.001110,0,0,1:15585:G:A,,T2D
3,1,16549,T,C,rs1262014613,OR4F5,0.343308,0.464316,0.241377,0.254711,0.000561,0.000620,0.000550,0,0,1:16549:T:C,,T2D
4,1,16567,G,C,rs1194064194,OR4F5,0.129883,0.886447,0.130736,0.086319,0.004170,0.004250,0.004154,0,0,1:16567:G:C,,T2D
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
20565622,23,155697920,G,A,,,0.027115,1.566790,-0.013475,0.006097,0.290961,0.286054,0.291879,0,0,23:155697920:G:A,,T2D
20565623,23,155698443,C,A,,,0.178417,0.748564,-0.069907,0.051951,0.003259,0.003022,0.003304,0,0,23:155698443:C:A,,T2D
20565624,23,155698490,C,T,,,0.279640,0.553400,-0.020245,0.018725,0.024406,0.024312,0.024423,0,0,23:155698490:C:T,,T2D
20565625,23,155699751,C,T,,,0.078864,1.103120,-0.011284,0.006421,0.244829,0.241257,0.245498,0,0,23:155699751:C:T,,T2D


## Spec

`data` df:

- `id`: This column represents the id of the variant in the following format: #chrom:pos:ref:alt (string).

- `#chrom`: This column represents the chromosome number where the genetic variant is located.

- `pos`: This is the position of the genetic variant on the chromosome.

- `ref`: This column represents the reference allele (or variant) at the genomic position.

- `alt`: This is the alternate allele observed at this position.

- `rsids`: This stands for reference SNP cluster ID. It's a unique identifier for each variant used in the dbSNP database.

- `nearest_genes`: This column represents the gene which is nearest to the variant.

- `pval`: This represents the p-value, which is a statistical measure for the strength of evidence against the null hypothesis.

- `mlogp`: This represents the minus log of the p-value, commonly used in genomic studies.

- `beta`: The beta coefficient represents the effect size of the variant.

- `sebeta`: This is the standard error of the beta coefficient.

- `af_alt`: This is the allele frequency of the alternate variant in the general population.

- `af_alt_cases`: This is the allele frequency of the alternate variant in the cases group.

- `af_alt_controls`: This is the allele frequency of the alternate variant in the control group.

- `causal`: This binary column indicates whether the variant is determined to be causal (1) or not (0).

- `LD`: This binary column indicates whether the variant is determined to be in linkage disequilibrium (1) or not (0).

- `lead`: This string column contains the id of the SNP of which the variant is in LD with (string).

- `trait`: This column represents the trait associated with the variant. In this dataset, it is the response to the drug paracetamol and NSAIDs.

**Task Overview**
- The objective is to design and implement a link prediction deep neural network model for predicting Linkage Disequilibrium between SNPs.

**Nodes and Their Features**
- There is one types of node: SNP nodes.
- *SNP Nodes*: Each SNP Node is characterized by various features, including `id`, `nearest_genes`, `#chrom`, `pos`, `ref`, `alt`, `beta`, `sebeta`, `af_alt`, and `af_alt_cases` columns.

**Edges, Their Features, and Labels**
- Edges represent relationships between nodes. There is one type of edge: SNP-SNP.
- *SNP-SNP Edges*:
  - These edges are undirected, linking an SNP Node (as identified by the `id` column) to another SNP Node (as identified by the `lead` column) in the same data row:
    - An edge is created when `data['LD'] = 1`, signifying that the two SNPs are in linkage disequilibrium.

## Graph creation

In [6]:
def get_unique_snps(data: pd.DataFrame) -> Dict:
    """
    Function to create mappings for SNPs to integer indices.
    """
    assert 'id' in data.columns, "id column is missing in the data DataFrame."
    assert not data['id'].isnull().any(), "id column contains NaN values."

    return {snp: idx for idx, snp in enumerate(data['id'].unique())}


def preprocess_snp_features(data: pd.DataFrame, snp_to_idx: Dict) -> pd.DataFrame:
    """
    Function to create node feature vectors for SNPs and preprocess categorical and numerical features.
    """
    assert 'id' in data.columns, "id column is missing in the data DataFrame."
    assert not data['id'].isnull().any(), "id column contains NaN values."

    snp_features = data.loc[data['id'].isin(snp_to_idx.keys()),
                            ['id', 'nearest_genes', '#chrom', 'pos', 'ref', 'alt', 'beta', 'sebeta',
                             'af_alt', 'af_alt_cases']].drop_duplicates().set_index('id').sort_index()

    # Fill NaNs with appropriate replacements
    nan_replacements = {'nearest_genes': 'N/A', '#chrom': 'N/A', 'pos': 0, 'ref': 'N/A', 'alt': 'N/A',
                        'beta': 0, 'sebeta': 0, 'af_alt': 0, 'af_alt_cases': 0}

    for col in snp_features.columns:
        assert col in nan_replacements, f"{col} column is missing in nan_replacements dictionary."

    # Impute missing values with appropriate replacements
    imputer = SimpleImputer(missing_values=np.nan, strategy='constant', fill_value=nan_replacements)
    snp_features = pd.DataFrame(imputer.fit_transform(snp_features), columns=snp_features.columns)

    # Assert no empty or null values in features
    assert not snp_features.isnull().any().any(), "Features contain null values."
    assert not snp_features.isna().any().any(), "Features contain NaN values."
    assert snp_features.size > 0, "Features are empty."

    # Label encoding for categorical columns and standardize numerical features
    le = LabelEncoder()
    scaler = StandardScaler()
    for col in snp_features.columns:
        if col in ['ref', 'alt', 'nearest_genes', '#chrom']:
            snp_features[col] = le.fit_transform(snp_features[col].astype(str))
        else:
            snp_features[col] = pd.Series(scaler.fit_transform(snp_features[col].to_frame()).flatten())

    return snp_features


def preprocess_positive_edges(data: pd.DataFrame, snp_to_idx: Dict) -> torch.Tensor:
    """
    Function to create positive SNP-SNP edges and preprocess them.
    """
    assert 'id' in data.columns, "id column is missing in the data DataFrame."
    assert 'lead' in data.columns, "lead column is missing in the data DataFrame."
    assert not data['id'].isnull().any(), "id column contains NaN values."

    positive_edges_snp_snp = data.loc[(data['LD'] == 1) & (data['id'].isin(snp_to_idx)) & (data['lead'].isin(snp_to_idx)),
                                      ['id', 'lead']].drop_duplicates().applymap(snp_to_idx.get).values
    assert positive_edges_snp_snp.size > 0, "No positive SNP-SNP edges found."

    return torch.tensor(positive_edges_snp_snp, dtype=torch.long).t().contiguous()


def create_pytorch_graph(features: torch.Tensor, edges: torch.Tensor) -> Data:
    """
    Function to create the PyTorch Geometric graph.
    """
    assert isinstance(features, torch.Tensor), "features must be a torch.Tensor."
    assert isinstance(edges, torch.Tensor), "edges must be a torch.Tensor."

    # Create edge labels (+1 for positive edges)
    edge_attr = torch.ones(edges.size(1), dtype=torch.float)

    return Data(x=features, edge_index=edges, edge_attr=edge_attr)


assert 'data' in globals(), "data variable is not defined."

start_time = time.time()

snp_to_idx = get_unique_snps(data)
assert len(snp_to_idx) > 0, "No unique SNPs found."
print(f"Number of unique SNPs: {len(snp_to_idx)}")

snp_features = preprocess_snp_features(data, snp_to_idx)
features = torch.tensor(snp_features.values, dtype=torch.float)

positive_edges_snp_snp = preprocess_positive_edges(data, snp_to_idx)
assert positive_edges_snp_snp.size(1) > 0, "No positive SNP-SNP edges found."
print(f"Number of positive SNP-SNP edges: {positive_edges_snp_snp.size(1)}")

graph = create_pytorch_graph(features, positive_edges_snp_snp)

print(f"Number of nodes: {graph.num_nodes}")
print(f"Number of edges: {graph.num_edges}")
print(f"Node feature dimension: {graph.num_node_features}")

elapsed_time = time.time() - start_time
print(f"Execution time: {elapsed_time} seconds")

Number of unique SNPs: 20170006
Number of positive SNP-SNP edges: 1808
Number of nodes: 20565627
Number of edges: 1808
Node feature dimension: 9
Execution time: 129.99298095703125 seconds


## Graph stats

In [7]:
from torch_geometric.utils import degree

def print_graph_stats(graph, positive_edges_snp_snp):
    print(f"Number of nodes: {graph.num_nodes}")
    print(f"Number of positive SNP-SNP edges: {positive_edges_snp_snp.size(1)}")
    print(f"Number of edges: {graph.num_edges}")
    print(f"Node feature dimension: {graph.num_node_features}")

    # Compute and print degree-related stats
    degrees = degree(graph.edge_index[0].long(), num_nodes=graph.num_nodes)
    average_degree = degrees.float().mean().item()
    median_degree = np.median(degrees.numpy())
    std_degree = degrees.float().std().item()

    # Assert average degree, median degree, and std degree
    assert isinstance(average_degree, float), "Average degree is not a float."
    assert isinstance(std_degree, float), "Standard deviation of degree is not a float."

    # Density is the ratio of actual edges to the maximum number of possible edges
    num_possible_edges = graph.num_nodes * (graph.num_nodes - 1) / 2
    density = graph.num_edges / num_possible_edges

    # Assert density
    assert isinstance(density, float), "Density is not a float."

    print(f"Density: {density:.10f}")

    # Check for NaN values in features
    nan_mask = torch.isnan(graph.x)
    nan_features = []
    for feature_idx, feature_name in enumerate(snp_features.columns):
        if nan_mask[:, feature_idx].any():
            nan_features.append(feature_name)

    print("Features with NaN values:")
    print(nan_features)

# Print graph stats
print("Graph stats:")
print_graph_stats(graph, positive_edges_snp_snp)

Graph stats:
Number of nodes: 20565627
Number of positive SNP-SNP edges: 1808
Number of edges: 1808
Node feature dimension: 9
Density: 0.0000000000
Features with NaN values:
[]


## Data splitting

In [8]:
from torch_geometric.transforms import RandomLinkSplit

transform = RandomLinkSplit(neg_sampling_ratio=1000, num_val=0.25, num_test=0.25, is_undirected=True, split_labels=True)

graph_train, graph_val, graph_test = transform(graph)

# Assert graph_train
assert isinstance(graph_train, Data), "graph_train is not an instance of torch_geometric.data.Data."
assert graph_train.num_nodes == graph.num_nodes, "Number of nodes in graph_train does not match the original graph."

# Assert graph_val
assert isinstance(graph_val, Data), "graph_val is not an instance of torch_geometric.data.Data."

# Assert graph_test
assert isinstance(graph_test, Data), "graph_test is not an instance of torch_geometric.data.Data."

print(graph_train)
print(graph_val)
print(graph_test)

Data(x=[20565627, 9], edge_index=[2, 788], edge_attr=[788], pos_edge_label=[394], pos_edge_label_index=[2, 394], neg_edge_label=[394000], neg_edge_label_index=[2, 394000])
Data(x=[20565627, 9], edge_index=[2, 788], edge_attr=[788], pos_edge_label=[196], pos_edge_label_index=[2, 196], neg_edge_label=[196000], neg_edge_label_index=[2, 196000])
Data(x=[20565627, 9], edge_index=[2, 1180], edge_attr=[1180], pos_edge_label=[196], pos_edge_label_index=[2, 196], neg_edge_label=[196000], neg_edge_label_index=[2, 196000])


## Models

### GCN

In [9]:
import torch.nn as nn
from torch_geometric.nn import SAGEConv, GINConv

torch.cuda.empty_cache()

# Define the GCN model
class GCN(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels):
        super(GCN, self).__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)

    def forward(self, x, edge_index):
        # First conv layer
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, p=0.75, training=self.training)

        # Second conv layer
        x = self.conv2(x, edge_index)
        x = F.relu(x)

        return x


# Train and evaluate the model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = GCN(in_channels=9, hidden_channels=4).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.1, weight_decay=5e-4)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)

graph_train = graph_train.to(device)
graph_val = graph_val.to(device)
graph_test = graph_test.to(device)


from torch.nn import BCEWithLogitsLoss

def compute_alpha_beta(targets):
    num_positives = targets.sum()
    num_samples = len(targets)
    beta = num_positives / num_samples  # proportion of positive instances
    alpha = 1 - beta  # proportion of negative instances
    return alpha, beta

def compute_adaptive_gamma(val_roc_auc, prev_val_roc_auc, gamma, increment_factor=0.1):
    if val_roc_auc < prev_val_roc_auc:  # If the performance decreases, increase gamma
        gamma += increment_factor
    return gamma

# Train function
def train(prev_val_roc_auc, gamma):
    model.train()
    optimizer.zero_grad()
    z = model(graph_train.x.float(), graph_train.edge_index)

    # Only consider positive edges for the positive score calculation
    pos_edge_index = graph_train.pos_edge_label_index
    pos = (z[pos_edge_index[0]] * z[pos_edge_index[1]]).sum(dim=-1)

    # Use the pre-generated negative labels
    neg_edge_index = graph_train.neg_edge_label_index
    neg = (z[neg_edge_index[0]] * z[neg_edge_index[1]]).sum(dim=-1)

    logits = torch.cat([pos, neg], dim=0)
    targets = torch.tensor([1] * pos.size(0) + [0] * neg.size(0), dtype=torch.float32).to(device)

    # Compute dynamic alpha and beta
    alpha, beta = compute_alpha_beta(targets)

    # Apply Focal Loss
    bce_loss = BCEWithLogitsLoss(reduction='none')
    bce = bce_loss(logits, targets)
    pt = torch.exp(-bce)  # Compute the focal term
    focal_loss = (alpha * (1 - pt) ** gamma * bce).mean()  # Compute the focal loss

    focal_loss.backward()
    
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

    optimizer.step()
    return focal_loss.item()

def pos_edge_accuracy(preds, true_labels):
    # Get predictions and true labels for positive edges
    pos_preds = preds[:len(true_labels)//2]
    pos_labels = true_labels[:len(true_labels)//2]

    # Convert predictions to binary labels
    bin_preds = (pos_preds > 0.5).astype(int)

    # Compute the accuracy
    accuracy = (bin_preds == pos_labels).mean()

    return accuracy


# Evaluation function
def evaluate(graph):
    model.eval()
    with torch.no_grad():
        z = model(graph.x.float(), graph.edge_index)

        pos_edge_index = graph.pos_edge_label_index
        pos = torch.sigmoid((z[pos_edge_index[0]] * z[pos_edge_index[1]]).sum(dim=-1)).view(-1)

        neg_edge_index = graph.neg_edge_label_index
        neg = torch.sigmoid((z[neg_edge_index[0]] * z[neg_edge_index[1]]).sum(dim=-1)).view(-1)

        preds = torch.cat([pos, neg]).cpu().numpy()
        true_labels = np.concatenate([np.ones_like(pos.cpu().numpy()), np.zeros_like(neg.cpu().numpy())])

        # Handle NaN values
        preds = np.nan_to_num(preds, nan=0.5)

        roc_auc = roc_auc_score(true_labels, preds)
        pos_edge_acc = pos_edge_accuracy(preds, true_labels)

        return roc_auc, pos_edge_acc


gamma = 2.0  # Starting value for gamma
prev_val_roc_auc = 0.0  # Initial performance value
best_val_roc_auc = 0.0  # Best validation ROC-AUC score
early_stop = 10  # Number of epochs to continue without improvement before stopping
epochs_no_improve = 0  # Counter for epochs without improvement

for epoch in range(150):
    loss = train(prev_val_roc_auc, gamma)
    torch.cuda.empty_cache()  # Release unnecessary GPU memory
    scheduler.step()
    with torch.no_grad():  # Use torch.no_grad() during training
        val_roc_auc, val_pos_edge_acc = evaluate(graph_val)
        gamma = compute_adaptive_gamma(val_roc_auc, prev_val_roc_auc, gamma)
        prev_val_roc_auc = val_roc_auc
    print(f"Epoch: {epoch + 1}, Loss: {loss:.4f}, Val ROC-AUC: {val_roc_auc:.10f}, Val Pos Edge Acc: {val_pos_edge_acc:.10f}, Gamma: {gamma:.2f}")
    if val_roc_auc > best_val_roc_auc:
        best_val_roc_auc = val_roc_auc
        epochs_no_improve = 0
    else:
        epochs_no_improve += 1

    if epochs_no_improve == early_stop:
        print("Early stopping!")
        break

Epoch: 1, Loss: 17929294.0000, Val ROC-AUC: 0.5370016920, Val Pos Edge Acc: 0.7406369141, Gamma: 2.00
Epoch: 2, Loss: 291161.2188, Val ROC-AUC: 0.5075635933, Val Pos Edge Acc: 0.3049705397, Gamma: 2.10
Epoch: 3, Loss: 8829.8379, Val ROC-AUC: 0.5056755388, Val Pos Edge Acc: 0.0354033721, Gamma: 2.20
Epoch: 4, Loss: 1.4206, Val ROC-AUC: 0.5000000000, Val Pos Edge Acc: 0.9980019980, Gamma: 2.30
Epoch: 5, Loss: 0.1406, Val ROC-AUC: 0.5000000000, Val Pos Edge Acc: 0.9980019980, Gamma: 2.30
Epoch: 6, Loss: 0.1406, Val ROC-AUC: 0.5000000000, Val Pos Edge Acc: 0.9980019980, Gamma: 2.30
Epoch: 7, Loss: 0.1406, Val ROC-AUC: 0.5294697782, Val Pos Edge Acc: 0.1944993782, Gamma: 2.30
Epoch: 8, Loss: 5.6884, Val ROC-AUC: 0.5000000000, Val Pos Edge Acc: 0.0019980020, Gamma: 2.40
Epoch: 9, Loss: 0.1315, Val ROC-AUC: 0.5000000000, Val Pos Edge Acc: 0.9980019980, Gamma: 2.40
Epoch: 10, Loss: 0.1312, Val ROC-AUC: 0.5000000000, Val Pos Edge Acc: 0.9980019980, Gamma: 2.40
Epoch: 11, Loss: 0.1312, Val ROC-A