# Grapher

## Dataset

- `id`: This column represents the id of the variant in the following format: #chrom:pos:ref:alt (string).

- `#chrom`: This column represents the chromosome number where the genetic variant is located.

- `pos`: This is the position of the genetic variant on the chromosome.

- `ref`: This column represents the reference allele (or variant) at the genomic position.

- `alt`: This is the alternate allele observed at this position.

- `rsids`: This stands for reference SNP cluster ID. It's a unique identifier for each variant used in the dbSNP database.

- `nearest_genes`: This column represents the gene which is nearest to the variant.

- `pval`: This represents the p-value, which is a statistical measure for the strength of evidence against the null hypothesis.

- `mlogp`: This represents the minus log of the p-value, commonly used in genomic studies.

- `beta`: The beta coefficient represents the effect size of the variant.

- `sebeta`: This is the standard error of the beta coefficient.

- `af_alt`: This is the allele frequency of the alternate variant in the general population.

- `af_alt_cases`: This is the allele frequency of the alternate variant in the cases group.

- `af_alt_controls`: This is the allele frequency of the alternate variant in the control group.

- `finemapped`: This column represents whether the variant is included in the post-finemapped dataset (1) or not (0). 

- `trait`: This column represents the trait associated with the variant. In this dataset, it is the response to the drug paracetamol and NSAIDs.

## Load libraries

In [1]:
import sys
import os
import random
import numpy as np
from numba import jit, prange
import pandas as pd
import matplotlib
import matplotlib.pyplot as plt
import sklearn
from sklearn import preprocessing
from sklearn.model_selection import train_test_split
from sklearn.metrics import precision_recall_fscore_support, roc_auc_score, average_precision_score
from sklearn.preprocessing import RobustScaler, LabelEncoder, StandardScaler, OrdinalEncoder, OneHotEncoder, MinMaxScaler
from sklearn.pipeline import Pipeline
from sklearn.compose import ColumnTransformer
from sklearn.impute import SimpleImputer
import torch
import torch.nn.functional as F
import torch_geometric
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv, GATConv
from torch_geometric.utils import to_undirected, negative_sampling
import networkx as nx
from scipy.spatial import cKDTree
from scipy.special import expit
from typing import List, Dict
import time
import cProfile
import pstats
import io
import category_encoders as ce
import torch
import torch.nn.functional as F
from torch_geometric.nn import SAGEConv
from sklearn.metrics import precision_score, recall_score, f1_score, roc_auc_score
import copy
from torch_geometric.transforms import RandomNodeSplit
from collections import Counter



# Print versions of imported libraries
print(f"Python version: {sys.version}")
print(f"NumPy version: {np.__version__}")
print(f"Pandas version: {pd.__version__}")
print(f"Matplotlib version: {matplotlib.__version__}")
print(f"Scikit-learn version: {sklearn.__version__}")
print(f"Torch version: {torch.__version__}")
print(f"Torch Geometric version: {torch_geometric.__version__}")
print(f"NetworkX version: {nx.__version__}")

if torch.cuda.is_available():
    device = torch.device("cuda")          # Current CUDA device
    print(f"Using {torch.cuda.get_device_name()} ({device})")
    print(f"CUDA version: {torch.version.cuda}")
    print(f"Number of CUDA devices: {torch.cuda.device_count()}")
else:
    print("CUDA is not available on this device.")

Python version: 3.11.3 (tags/v3.11.3:f3909b8, Apr  4 2023, 23:49:59) [MSC v.1934 64 bit (AMD64)]
NumPy version: 1.24.3
Pandas version: 2.0.1
Matplotlib version: 3.7.1
Scikit-learn version: 1.2.2
Torch version: 2.0.0+cu118
Torch Geometric version: 2.3.1
NetworkX version: 3.0
Using NVIDIA GeForce RTX 3060 Ti (cuda)
CUDA version: 11.8
Number of CUDA devices: 1


## Load data

In [2]:
dtypes = {
    'id': 'string',
    '#chrom': 'int64',
    'pos': 'int64',
    'ref': 'string',
    'alt': 'string',
    'rsids': 'string',
    'nearest_genes': 'string',
    'pval': 'float64',
    'mlogp': 'float64',
    'beta': 'float64',
    'sebeta': 'float64',
    'af_alt': 'float64',
    'af_alt_cases': 'float64',
    'af_alt_controls': 'float64',
    'finemapped': 'int64'
}

data = pd.read_csv('~/Desktop/gwas-graph/FinnGen/data/gwas-finemap.csv', dtype=dtypes)

# Assert column names
expected_columns = ['#chrom', 'pos', 'ref', 'alt', 'rsids', 'nearest_genes', 'pval', 'mlogp', 'beta',
                    'sebeta', 'af_alt', 'af_alt_cases', 'af_alt_controls', 'finemapped',
                    'id', 'trait']
assert set(data.columns) == set(expected_columns), "Unexpected columns in the data DataFrame."

# Assert data types
expected_dtypes = {
    'id': 'string',
    '#chrom': 'int64',
    'pos': 'int64',
    'ref': 'string',
    'alt': 'string',
    'rsids': 'string',
    'nearest_genes': 'string',
    'pval': 'float64',
    'mlogp': 'float64',
    'beta': 'float64',
    'sebeta': 'float64',
    'af_alt': 'float64',
    'af_alt_cases': 'float64',
    'af_alt_controls': 'float64',
    'finemapped': 'int64'
}

for col, expected_dtype in expected_dtypes.items():
    assert data[col].dtype == expected_dtype, f"Unexpected data type for column {col}."

In [3]:
# Check for total number of null values in each column
null_counts = data.isnull().sum()

print("Total number of null values in each column:")
print(null_counts)

Total number of null values in each column:
#chrom                   0
pos                      0
ref                      0
alt                      0
rsids              1366396
nearest_genes       727855
pval                     0
mlogp                    0
beta                     0
sebeta                   0
af_alt                   0
af_alt_cases             0
af_alt_controls          0
id                       0
finemapped               0
trait                    0
dtype: int64


## Data manipulation

In [4]:
data = data.sample(frac=0.001, random_state=42)

### Find nearest gene

In [5]:
data['nearest_genes'] = data['nearest_genes'].astype(str)

# Assert column 'nearest_genes' is a string
assert data['nearest_genes'].dtype == 'object', "Column 'nearest_genes' is not of string type."

# Get the length of the data before transformation
original_length = len(data)

# Extract the first gene name from the 'nearest_genes' column
data['nearest_genes'] = data['nearest_genes'].str.split(',').str[0]

# Reset index to have a standard index
data = data.reset_index(drop=True)

# Assert the length of the data remains the same
assert len(data) == original_length, "Length of the data has changed after transformation."

## Spec

Here's the adjusted spec in markdown format:

### Data

`data` Pandas DataFrame:

- `id`: This column represents the id of the variant in the following format: #chrom:pos:ref:alt (string).
- `#chrom`: This column represents the chromosome number where the genetic variant is located.
- `pos`: This is the position of the genetic variant on the chromosome (int: 1-200,000).
- `ref`: This column represents the reference allele (or variant) at the genomic position.
- `alt`: This is the alternate allele observed at this position.
- `rsids`: This stands for reference SNP cluster ID. It's a unique identifier for each variant used in the dbSNP database.
- `nearest_genes`: This column represents the gene which is nearest to the variant (string).
- `pval`: This represents the p-value, which is a statistical measure for the strength of evidence against the null hypothesis.
- `mlogp`: This represents the minus log of the p-value, commonly used in genomic studies.
- `beta`: The beta coefficient represents the effect size of the variant.
- `sebeta`: This is the standard error of the beta coefficient.
- `af_alt`: This is the allele frequency of the alternate variant in the general population (float: 0-1.
- `af_alt_cases`: This is the allele frequency of the alternate variant in the cases group (float: 0-1).
- `af_alt_controls`: This is the allele frequency of the alternate variant in the control group (float: 0-1).
- `finemapped`: This column represents whether the variant is included in the post-finemapped dataset (1) or not (0) (int).
- `trait`: This column represents the trait associated with the variant. In this dataset, it is the response to the drug paracetamol and NSAIDs.

### Task Overview

The objective is to design and implement a binary node classification GNN model to predict whether variants are included after post-finemapping or not based on `finemapping`.

### Nodes and Their Features

There is one type of node: SNP nodes.

- **SNP Nodes**: Each SNP Node is characterized by various features, including `id`, `nearest_genes`, `#chrom`, `pos`, `ref`, `alt`, `mlogp`, `beta`, `sebeta`,  `af_alt`, `af_alt_cases`, and `af_alt_controls` columns.

### Edges, Their Features, and Labels

Edges represent relationships between SNP nodes in the graph.

1. **Type 1 Edges: LD-based edges**

   - For each pair of SNPs (row1 and row2) that exist on the same chromosome (`#chrom`), an edge is created if the absolute difference between their positions (`pos`) is less than or equal to 500,000 and greater than 1 (no loops).
   - The weight of the edge is determined by the following formula:
     
```
    weights = (average_mlogp / (1 + pos_diff_abs * \
                      af_alt_diff_abs * \
                      af_alt_cases_diff_abs * \
                      af_alt_controls_diff_abs))
```

    - For each chromosome, standardize the edge weights between 0 and 1 after all weights have been computed.
    - Prune any edges that have a weight of less than `1e-3`

## Graph creation

In [6]:
edge_weight_cutoff = 1e-3  # set the cutoff value, this is an example and you can choose any suitable value


def get_unique_snps(data: pd.DataFrame) -> dict:
    """
    Function to create mappings for SNPs to integer indices.
    """
    return {snp: idx for idx, snp in enumerate(data['id'].unique())}


def preprocess_snp_features(data: pd.DataFrame, snp_to_idx: dict) -> pd.DataFrame:
    """
    Function to create node feature vectors for SNPs and preprocess categorical and numerical features.
    """
    # Ensure that 'id' exists in the data and 'id' and 'nearest_genes' are not null
    assert 'id' in data.columns and 'nearest_genes' in data.columns, "Columns 'id' or 'nearest_genes' do not exist in the dataframe"
    assert data[['id', 'nearest_genes']].isnull().sum().sum() == 0, "Columns 'id' or 'nearest_genes' contain null values"
    
    # Columns to be extracted from the original dataframe
    cols_to_extract = ['id', 'nearest_genes', '#chrom', 'pos', 'ref', 'alt', 'mlogp', 'beta', 'sebeta', 'af_alt', 'af_alt_cases', 'af_alt_controls']
    
    snp_features = data.loc[data['id'].isin(snp_to_idx.keys()), cols_to_extract].set_index('id').sort_index()
    scaler = RobustScaler()

    # Frequency encoding for 'nearest_genes' using LabelEncoder
    label_encoder = LabelEncoder()
    snp_features['nearest_genes'] = label_encoder.fit_transform(snp_features['nearest_genes'])
    
    categorical_cols = ['ref', 'alt']

    # BinaryEncoder for 'ref', 'alt'
    binary_encoder = ce.BinaryEncoder(cols=categorical_cols, drop_invariant=True)
    snp_features = binary_encoder.fit_transform(snp_features)

    numerical_cols = list(set(snp_features.columns) - set(categorical_cols))
    snp_features[numerical_cols] = scaler.fit_transform(snp_features[numerical_cols])
    
    # Filling 0 values for all columns
    snp_features = snp_features.fillna(0)

    return snp_features


def preprocess_positive_edges(data: pd.DataFrame, snp_to_idx: dict) -> torch.Tensor:
    scaler = RobustScaler()

    # Sort data once before grouping
    data = data.sort_values(by=['#chrom', 'pos'])

    # Create new column for SNP index
    data['snp_idx'] = data['id'].map(snp_to_idx)

    positive_edges_snp_snp = []
    snp_weights = []

    for chrom, group in data.groupby('#chrom'):
        # Skip if group is empty
        if group.empty:
            continue

        for idx, row in group.iterrows():
            pos_diff = group['pos'] - row['pos']
            mask = (pos_diff > 1) & (pos_diff <= 500000)
            filtered_group = group[mask]

            if filtered_group.empty:
                continue
            
            af_alt_diff = abs(row['af_alt'] - filtered_group['af_alt'])
            pos_diff_abs = abs(pos_diff[mask])
            average_mlogp = (row['mlogp'] + filtered_group['mlogp']) / 2
            weights = (average_mlogp / (1 + pos_diff_abs * af_alt_diff  * \
                      abs(row['af_alt_cases'] - filtered_group['af_alt_cases']) * \
                      abs(row['af_alt_controls'] - filtered_group['af_alt_controls'])))

            # Skip if weights are empty
            if len(weights) == 0:
                continue

            positive_edges_snp_snp.extend(zip([row['snp_idx']] * len(filtered_group), filtered_group['snp_idx']))
            snp_weights.extend(weights)

    # Normalize weights
    snp_weights = scaler.fit_transform(np.array(snp_weights).reshape(-1, 1)).flatten()
    mask_cutoff = snp_weights >= edge_weight_cutoff
    filtered_edges = np.array(positive_edges_snp_snp)[mask_cutoff]
    filtered_weights = snp_weights[mask_cutoff]

    return torch.tensor(filtered_edges, dtype=torch.long).t().contiguous(), torch.tensor(filtered_weights, dtype=torch.float)


def create_pytorch_graph(features: torch.Tensor, edges: torch.Tensor, edge_weights: torch.Tensor) -> Data:
    return Data(x=features, edge_index=edges, edge_attr=edge_weights)


# Add profiling
pr = cProfile.Profile()
pr.enable()

start_time = time.time()


# Main
snp_to_idx = get_unique_snps(data)
labels = data['finemapped'].map(lambda x: 1 if x > 0 else 0)
snp_features = preprocess_snp_features(data, snp_to_idx)
features = torch.tensor(snp_features.values, dtype=torch.float)

positive_edges_snp_snp, snp_weights = preprocess_positive_edges(data, snp_to_idx)
graph = create_pytorch_graph(features, positive_edges_snp_snp, snp_weights)
graph.y = torch.tensor(labels.values, dtype=torch.long)

print(f"Number of nodes: {graph.num_nodes}")
print(f"Number of edges: {graph.num_edges}")
print(f"Node feature dimension: {graph.num_node_features}")

elapsed_time = time.time() - start_time
print(f"Execution time: {elapsed_time} seconds")

pr.disable()
s = io.StringIO()
sortby = 'cumulative'
ps = pstats.Stats(pr, stream=s).sort_stats(sortby)
ps.print_stats(3)  # Only print the top 5 lines
print(s.getvalue())


Number of nodes: 20170
Number of edges: 36767
Node feature dimension: 26
Execution time: 55.78683924674988 seconds
         112926481 function calls (110616940 primitive calls) in 55.778 seconds

   Ordered by: cumulative time
   List reduced from 1279 to 3 due to restriction <3>

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
       14    0.000    0.000   55.786    3.985 C:\Users\falty\AppData\Local\Programs\Python\Python311\Lib\site-packages\IPython\core\interactiveshell.py:3472(run_code)
       14    0.000    0.000   55.786    3.985 {built-in method builtins.exec}
        1    0.750    0.750   55.512   55.512 C:\Users\falty\AppData\Local\Temp\ipykernel_2092\966270683.py:44(preprocess_positive_edges)





## Save/Load graph

In [7]:
# Save PyTorch Geometric graph

#torch.save(graph, "pytorch_geometric_graph.pt")
#graph = torch.load("pytorch_geometric_graph.pt")

## Graph stats

In [8]:
from torch_geometric.utils import degree

def print_graph_stats(graph, positive_edges_snp_snp, features_list):
    print(f"Number of nodes: {graph.num_nodes}")
    print(f"Number of edges: {graph.num_edges}")
    print(f"Node feature dimension: {graph.num_node_features}")
    print(f"Number of finemapped nodes: {data['finemapped'].sum()}")

    # Compute and print degree-related stats
    degrees = degree(graph.edge_index[0].long(), num_nodes=graph.num_nodes)
    average_degree = degrees.float().mean().item()
    median_degree = np.median(degrees.numpy())
    std_degree = degrees.float().std().item()

    print(f"Average Degree: {average_degree}")
    print(f"Median Degree: {median_degree}")
    print(f"Standard Deviation of Degree: {std_degree}")

    # Density is the ratio of actual edges to the maximum number of possible edges
    num_possible_edges = graph.num_nodes * (graph.num_nodes - 1) / 2
    density = graph.num_edges / num_possible_edges

    print(f"Density: {density:.10f}")

    # Check for NaN values in features
    nan_mask = torch.isnan(graph.x)
    nan_features = []
    for feature_idx, feature_name in enumerate(features_list):
        if nan_mask[:, feature_idx].any():
            nan_features.append(feature_name)

    print("Features with NaN values:")
    print(nan_features)

def print_edge_weight_stats(edge_weights):
    print(f"Number of edges: {edge_weights.size(0)}")
    print(f"Average edge weight: {edge_weights.float().mean().item()}")
    print(f"Median edge weight: {np.median(edge_weights.numpy())}")
    print(f"Standard deviation of edge weights: {edge_weights.float().std().item()}")
    print(f"Maximum edge weight: {edge_weights.max().item()}")
    print(f"Minimum edge weight: {edge_weights.min().item()}")

# Print graph stats
print("Graph stats:")
print_graph_stats(graph, positive_edges_snp_snp, snp_features.columns)

print("Edge weight stats:")
print_edge_weight_stats(graph.edge_attr)

Graph stats:
Number of nodes: 20170
Number of edges: 36767
Node feature dimension: 26
Number of finemapped nodes: 3020
Average Degree: 1.8228557109832764
Median Degree: 2.0
Standard Deviation of Degree: 1.7116793394088745
Density: 0.0001807582
Features with NaN values:
[]
Edge weight stats:
Number of edges: 36767
Average edge weight: 1.2632436752319336
Median edge weight: 0.9518038034439087
Standard deviation of edge weights: 1.2150322198867798
Maximum edge weight: 34.015201568603516
Minimum edge weight: 0.0010152175091207027


## Data splitting

In [9]:
# Split the data into training, validation, and test sets
transform = RandomNodeSplit(split="train_rest", num_val=0.2, num_test=0.2, key='y')
graph = transform(graph)

# Count the number of nodes per class in each set
train_class_counts = Counter(graph.y[graph.train_mask].numpy())
val_class_counts = Counter(graph.y[graph.val_mask].numpy())
test_class_counts = Counter(graph.y[graph.test_mask].numpy())

# Print the results
print("Number of nodes per class in each set:")
print("Train set:")
for class_label, count in train_class_counts.items():
    print(f"Class {class_label}: {count} nodes")
print("Validation set:")
for class_label, count in val_class_counts.items():
    print(f"Class {class_label}: {count} nodes")
print("Test set:")
for class_label, count in test_class_counts.items():
    print(f"Class {class_label}: {count} nodes")


Number of nodes per class in each set:
Train set:
Class 0: 10248 nodes
Class 1: 1854 nodes
Validation set:
Class 0: 3434 nodes
Class 1: 600 nodes
Test set:
Class 0: 3468 nodes
Class 1: 566 nodes


## Model

### GraphSAGE

In [10]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

class GraphSAGEModel(torch.nn.Module):
    def __init__(self, num_node_features, hidden_layers, num_classes):
        super(GraphSAGEModel, self).__init__()
        self.conv1 = SAGEConv(num_node_features, hidden_layers[0])
        self.conv2 = SAGEConv(hidden_layers[0], hidden_layers[1])
        self.conv3 = SAGEConv(hidden_layers[1], num_classes)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index

        x = self.conv1(x, edge_index)
        x = F.leaky_relu(x)
        x = F.dropout(x, p=0.5, training=self.training)

        x = self.conv2(x, edge_index)
        x = F.leaky_relu(x)
        x = F.dropout(x, p=0.5, training=self.training)

        x = self.conv3(x, edge_index)

        return torch.sigmoid(x.view(-1))

# Use the GraphSAGE model
hidden_layers = [64, 64]  
model = GraphSAGEModel(graph.num_node_features, hidden_layers, 1).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.2)

class FocalLoss(torch.nn.Module):
    def __init__(self, alpha=None, gamma=2, reduce=True):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduce = reduce

    def forward(self, inputs, targets):
        if self.alpha is not None:
            alpha_t = self.alpha[targets.long()].view(-1, 1)
            logpt = -F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')
            logpt = logpt * alpha_t
        else:
            logpt = -F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')
        
        pt = torch.exp(logpt)
        F_loss = -((1 - pt) ** self.gamma) * logpt

        if self.reduce:
            return torch.mean(F_loss)
        else:
            return F_loss

class_counts = graph.y[graph.train_mask].bincount(minlength=2).float()
class_weights = 1. / class_counts
class_weights = class_weights / class_weights.sum()
class_weights = class_weights.detach().to(device)  # Add to(device)
print(class_weights)

loss_fn = FocalLoss(alpha=class_weights, gamma=2)

def train(data):
    data = data.to(device)  # Add this line
    model.train()
    optimizer.zero_grad()
    out = model(data)[data.train_mask].squeeze()
    loss = loss_fn(out, data.y[data.train_mask].float())
    loss.backward()
    optimizer.step()
    scheduler.step()  # Update learning rate
    return loss.item()

def evaluate(data, mask):
    data = data.to(device)  # Add this line
    model.eval()
    with torch.no_grad():
        out = model(data)[mask].squeeze()  # Predicted probabilities
        preds = (out > 0.5).long()  # Binary predictions
        print(f"Classes predicted: {preds.unique().cpu().numpy()}")  # Move to CPU before converting to NumPy
        accuracy = preds.eq(data.y[mask]).sum().item() / mask.sum().item()
        precision = precision_score(data.y[mask].cpu().numpy(), preds.cpu().numpy(), zero_division=1)  # Move to CPU
        recall = recall_score(data.y[mask].cpu().numpy(), preds.cpu().numpy(), zero_division=1)  # Move to CPU
        f1 = f1_score(data.y[mask].cpu().numpy(), preds.cpu().numpy(), zero_division=1)  # Move to CPU
        roc_auc = roc_auc_score(data.y[mask].cpu().numpy(), out.cpu().numpy())  # Use out, not preds
    return accuracy, precision, recall, f1, roc_auc

best_model = None
best_val_roc_auc = 0
patience = 10
epochs_no_improve = 0
early_stop_threshold = 0.0001

for epoch in range(150):
    loss = train(graph)
    train_metrics = evaluate(graph, graph.train_mask)
    val_metrics = evaluate(graph, graph.val_mask)
    test_metrics = evaluate(graph, graph.test_mask)
    print(f'Epoch: {epoch+1}, Loss: {loss:.10f}, Train ROC AUC: {train_metrics[4]:.10f}, Val ROC AUC: {val_metrics[4]:.10f}, Test ROC AUC: {test_metrics[4]:.10f}')

    # Check if the validation ROC AUC score is higher than what we've seen so far
    if val_metrics[4] > best_val_roc_auc + early_stop_threshold:
        # We've seen a model performance improvement!
        best_val_roc_auc = val_metrics[4]
        best_model = copy.deepcopy(model.state_dict())
        epochs_no_improve = 0  # Reset the count
    else:
        # We did not see any improvement this epoch
        epochs_no_improve += 1

    # If we've had too many epochs with no improvement, stop training early
    if epochs_no_improve == patience:
        print("Early stopping!")
        break

# Load the best model back in
model.load_state_dict(best_model)
test_metrics = evaluate(graph, graph.test_mask)
print(f'After early stopping, Test ROC AUC: {test_metrics[4]:.10f}')

tensor([0.1532, 0.8468], device='cuda:0')
Classes predicted: [0 1]
Classes predicted: [0 1]
Classes predicted: [0 1]
Epoch: 1, Loss: 0.0403727069, Train ROC AUC: 0.4882356870, Val ROC AUC: 0.5003069792, Test ROC AUC: 0.4829738121
Classes predicted: [0 1]
Classes predicted: [0 1]
Classes predicted: [0 1]
Epoch: 2, Loss: 0.0402203985, Train ROC AUC: 0.4884991636, Val ROC AUC: 0.5001739953, Test ROC AUC: 0.4835459282
Classes predicted: [0 1]
Classes predicted: [0 1]
Classes predicted: [0 1]
Epoch: 3, Loss: 0.0400897264, Train ROC AUC: 0.4887786403, Val ROC AUC: 0.5000400408, Test ROC AUC: 0.4842693521
Classes predicted: [0 1]
Classes predicted: [0 1]
Classes predicted: [0 1]
Epoch: 4, Loss: 0.0399211049, Train ROC AUC: 0.4891442496, Val ROC AUC: 0.5000550864, Test ROC AUC: 0.4848957251
Classes predicted: [0 1]
Classes predicted: [0 1]
Classes predicted: [0 1]
Epoch: 5, Loss: 0.0398151912, Train ROC AUC: 0.4895148063, Val ROC AUC: 0.5000342167, Test ROC AUC: 0.4853789416
Classes predicted: