[![Google Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/giordamaug/EG-identification---Data-Science-in-App-Springer/blob/main/notebook/EssentialGenes_GNN.ipynb)
[![Binder](https://mybinder.org/badge_logo.svg)](https://mybinder.org/v2/gh/giordamaug/EG-identification---Data-Science-in-App-Springer/main?filepath=notebook%2FEssentialGenes_GNN.ipynb)

# Loading required libraries

### Install PyTorch libraries

In [1]:
import sys
IN_COLAB = 'google.colab' in sys.modules
if not IN_COLAB:
    !pip install -q pandas
    !pip install -q scikit-learn        # the 'sklearn' PyPI package is deprecated
    !pip install -q imbalanced-learn    # provides the imblearn package
    !pip install -q xgboost
    !pip install -q tqdm
    !conda install -y pytorch torchvision -c pytorch

import torch

def format_pytorch_version(version):
  return version.split('+')[0]          # e.g. '1.11.0+cu113' -> '1.11.0'

TORCH_version = torch.__version__
TORCH = format_pytorch_version(TORCH_version)

def format_cuda_version(version):
  # e.g. '11.3' -> 'cu113'; CPU-only PyTorch builds report torch.version.cuda as None
  return 'cu' + version.replace('.', '') if version else 'cpu'

CUDA_version = torch.version.cuda
CUDA = format_cuda_version(CUDA_version)

!pip install torch-scatter     -f https://pytorch-geometric.com/whl/torch-{TORCH}+{CUDA}.html
!pip install torch-sparse      -f https://pytorch-geometric.com/whl/torch-{TORCH}+{CUDA}.html
!pip install torch-cluster     -f https://pytorch-geometric.com/whl/torch-{TORCH}+{CUDA}.html
!pip install torch-spline-conv -f https://pytorch-geometric.com/whl/torch-{TORCH}+{CUDA}.html
!pip install torch-geometric 

Looking in links: https://pytorch-geometric.com/whl/torch-1.11.0+cu113.html
Collecting torch-scatter
  Downloading https://data.pyg.org/whl/torch-1.11.0%2Bcu113/torch_scatter-2.0.9-cp37-cp37m-linux_x86_64.whl (7.9 MB)
Installing collected packages: torch-scatter
Successfully installed torch-scatter-2.0.9
Looking in links: https://pytorch-geometric.com/whl/torch-1.11.0+cu113.html
Collecting torch-sparse
  Downloading https://data.pyg.org/whl/torch-1.11.0%2Bcu113/torch_sparse-0.6.13-cp37-cp37m-linux_x86_64.whl (3.5 MB)
Installing collected packages: torch-sparse
Successfully installed torch-sparse-0.6.13
Looking in links: https://pytorch-geometric.com/whl/torch-1.11.0+cu113.html
Collecting torch-cluster
  Downloading https://data.pyg.org/whl/torch-1.11.0%2Bcu113/torch_cluster-1.6.0-cp37-cp37m-linux_x86_64.whl (2.5 MB)

In [2]:
import warnings
warnings.filterwarnings('ignore')
import random
import numpy as np
import pandas as pd
import torch
def set_seed(seed=1):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)

# Download the dataset from GitHub

In [3]:
!wget https://raw.githubusercontent.com/giordamaug/EG-identification---Data-Science-in-App-Springer/main/data/ppi.csv
!wget https://raw.githubusercontent.com/giordamaug/EG-identification---Data-Science-in-App-Springer/main/data/labels.csv
!wget https://raw.githubusercontent.com/giordamaug/EG-identification---Data-Science-in-App-Springer/main/data/bio_attributes.csv
!wget https://raw.githubusercontent.com/giordamaug/EG-identification---Data-Science-in-App-Springer/main/data/net_attributes.csv
!wget https://raw.githubusercontent.com/giordamaug/EG-identification---Data-Science-in-App-Springer/main/data/gtex_attributes.csv

--2022-05-04 13:15:21--  https://raw.githubusercontent.com/giordamaug/EG-identification---Data-Science-in-App-Springer/main/data/ppi.csv
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 3966521 (3.8M) [text/plain]
Saving to: ‘ppi.csv’


2022-05-04 13:15:21 (49.7 MB/s) - ‘ppi.csv’ saved [3966521/3966521]

--2022-05-04 13:15:21--  https://raw.githubusercontent.com/giordamaug/EG-identification---Data-Science-in-App-Springer/main/data/labels.csv
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.110.133, 185.199.111.133, 185.199.109.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.110.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 237495 (232K) [text/plain]
Saving to: ‘labe

# Load the labels
Only a subset of genes is selected for classification:
+ genes belonging to the CS0 group, labeled as Essential (E);
+ genes belonging to the CS6, CS7, ..., CS9 groups, labeled as Not-Essential (NE).

All remaining genes belong to the intermediate groups (CS1-CS5); they are considered undetermined (label ND) and are excluded.
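The labeling rule above can be sketched as a small function. This is only an illustration: `cs_group_to_label` is a hypothetical helper, and the group names are assumed to be strings of the form `CS0` ... `CS9`; in this notebook the labels come precomputed in the `CS0_vs_CS6-9` column of `labels.csv`.

```python
def cs_group_to_label(group: str) -> str:
    """Hypothetical helper: map a CS group name ('CS0'..'CS9') to E, NE or ND."""
    if not group.startswith("CS"):
        raise ValueError(f"unexpected group name: {group!r}")
    k = int(group[2:])
    if k == 0:
        return "E"      # CS0 -> Essential
    if 6 <= k <= 9:
        return "NE"     # CS6..CS9 -> Not-Essential
    if 1 <= k <= 5:
        return "ND"     # intermediate groups -> undetermined, excluded from classification
    raise ValueError(f"group index out of range: {k}")

print([cs_group_to_label(f"CS{k}") for k in range(10)])
# ['E', 'ND', 'ND', 'ND', 'ND', 'ND', 'NE', 'NE', 'NE', 'NE']
```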

In [4]:
labels = pd.read_csv("labels.csv", index_col='name')
labels = labels[labels["CS0_vs_CS6-9"].isin(['E', 'NE'])]               # drop any gene with undefined (ND) label
genes = labels.index.values                                             # get genes with defined labels (E or NE)
print(f'Selected {len(genes)} genes')

Selected 3814 genes


## Encode the labels
The string labels E and NE are encoded as 0 and 1, respectively.
The array `y` contains the numeric labels of the genes.

In [5]:
from sklearn import preprocessing
from collections import Counter
encoder = preprocessing.LabelEncoder()
y = encoder.fit_transform(labels['CS0_vs_CS6-9'].values)  
classes_mapping = dict(zip(encoder.classes_, encoder.transform(encoder.classes_)))
print(classes_mapping, Counter(y))

{'E': 0, 'NE': 1} Counter({1: 3069, 0: 745})
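`LabelEncoder` assigns integer codes in the sorted order of the class names, which is why `E` becomes 0 and `NE` becomes 1. The same mapping can be sketched with NumPy alone on made-up labels, using `np.unique(..., return_inverse=True)` in place of scikit-learn:

```python
import numpy as np
from collections import Counter

raw = np.array(["NE", "E", "NE", "NE", "E"])            # toy string labels, not the real data
# np.unique returns the sorted classes and, with return_inverse=True, each label's index into them
classes, y_toy = np.unique(raw, return_inverse=True)
mapping = {str(c): i for i, c in enumerate(classes)}
print(mapping, Counter(y_toy.tolist()))
# {'E': 0, 'NE': 1} Counter({1: 3, 0: 2})
```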


# Load attributes to be used
We identified three sets of attributes:
1. bio attributes, related to gene information (e.g., expression);
2. net attributes, derived from the role of the gene/node in the network (e.g., degree, centrality);
3. GTEX-* attributes, providing additional biological information on the genes.

Depending on the user's selection, the chosen attribute sets are concatenated column-wise into a single attribute matrix (`x`).

The attribute matrix `x` may contain NaN or infinite values, which are corrected as follows:
+ a NaN is replaced by the mean of its attribute (column); a column with no values at all is dropped;
+ an infinite value is replaced by the maximum finite value of its attribute.

After fixing NaN and infinite values, the attributes are normalized with either z-score or min-max normalization.

Finally, only the nodes (genes) labeled E or NE are kept for classification.
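A pandas sketch of these cleaning steps on made-up data. Handling only `+inf` (not `-inf`), replacing infinities before computing the means, and dropping any column without a finite value are assumptions of this sketch; `clean_and_normalize` is a hypothetical helper, not a function used elsewhere in the notebook.

```python
import numpy as np
import pandas as pd

def clean_and_normalize(df: pd.DataFrame, method: str = "zscore") -> pd.DataFrame:
    """Hypothetical helper: +inf -> column max, NaN -> column mean, then normalize."""
    df = df.copy()
    for col in df.columns:
        finite = df[col][np.isfinite(df[col])]
        if finite.empty:                                   # no finite value at all: drop the column
            df = df.drop(columns=col)
            continue
        df[col] = df[col].replace(np.inf, finite.max())    # +inf -> maximum finite value
        df[col] = df[col].fillna(df[col].mean())           # NaN -> column mean
    if method == "minmax":
        df = (df - df.min()) / (df.max() - df.min())
    elif method == "zscore":
        df = (df - df.mean()) / df.std()                   # pandas std uses ddof=1
    return df

toy_attrs = pd.DataFrame({"a": [1.0, np.nan, 3.0],
                          "b": [0.0, np.inf, 2.0],
                          "c": [np.nan] * 3})
print(clean_and_normalize(toy_attrs, "minmax"))
```

Here column `a` becomes `[1, 2, 3]` before scaling, `b` becomes `[0, 2, 2]`, and the all-NaN column `c` is dropped.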

In [6]:
#@title Choose attributes { form-width: "20%" }
normalize_node = "zscore" #@param ["", "zscore", "minmax"]
bio = True #@param {type:"boolean"}
gtex = True #@param {type:"boolean"}
net = True #@param {type:"boolean"}
bio_df = pd.read_csv("bio_attributes.csv", index_col='name') if bio else pd.DataFrame()
gtex_df = pd.read_csv("gtex_attributes.csv", index_col='name') if gtex else pd.DataFrame()
net_df = pd.read_csv("net_attributes.csv", index_col='name') if net else pd.DataFrame()
x = pd.concat([bio_df, gtex_df, net_df], axis=1)
print(f'Found {x.isnull().sum().sum()} NaN values and {np.isinf(x).values.sum()} Infinite values')
# Replace +inf with the maximum finite value of the same column (done first, so it cannot distort the means)
for col in x.columns[(x == np.inf).any()].tolist():
  finite_max = x.loc[np.isfinite(x[col]), col].max()
  x[col] = x[col].replace(np.inf, finite_max)
# Replace NaNs with the mean of the same column; drop the column if its mean is NaN (no values at all)
for col in x.columns[x.isna().any()].tolist():
  mean_value = x[col].mean()
  if not pd.isna(mean_value):
    x[col] = x[col].fillna(mean_value)
  else:
    x = x.drop(columns=col)
if normalize_node == 'minmax':
  print("X attributes normalization (minmax)...")
  x = (x-x.min())/(x.max()-x.min())
elif normalize_node == 'zscore':
  print("X attributes normalization (zscore)...")
  x = (x-x.mean())/x.std()
x = x.loc[genes]
print(f'New attribute matrix x{x.shape}')

Found 15919 NaN values and 0 Infinite values
X attributes normalization (zscore)...
New attribute matrix x(3814, 119)


# Load the PPI+MET network
The PPI+MET network is loaded from a CSV file in which
*   `A` is the column name for the edge source (gene name)
*   `B` is the column name for the edge target (gene name)
*   `weight` is the column name for the edge weight

Only some methods use the PPI network, for example all GCN methods and Node2Vec.

The PPI+MET network is reduced by removing genes with undetermined labels: only edges whose two endpoints are both labeled E or NE are kept.
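A toy pandas/NumPy sketch of this reduction and of how gene names become the integer node indices of a `[2, num_edges]` edge array. Gene names and weights below are made up; each gene's index is assumed to be its row position in the label table, as in the cell that follows.

```python
import numpy as np
import pandas as pd

toy_genes = np.array(["g1", "g2", "g3"])                  # genes with an E or NE label
toy_ppi = pd.DataFrame({"A": ["g1", "g2", "g4", "g3"],
                        "B": ["g2", "g3", "g1", "g5"],
                        "weight": [0.5, 0.2, 0.9, 0.1]})

# Keep only edges whose two endpoints are both labeled genes (g4 and g5 are not)
toy_ppi = toy_ppi[toy_ppi["A"].isin(toy_genes) & toy_ppi["B"].isin(toy_genes)]

# Gene name -> row position in the label/attribute matrices
gene_to_idx = {g: i for i, g in enumerate(toy_genes)}
toy_edge_index = np.vstack([toy_ppi["A"].map(gene_to_idx).to_numpy(),
                            toy_ppi["B"].map(gene_to_idx).to_numpy()])
print(toy_edge_index)    # row 0 = sources, row 1 = targets
```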

In [7]:
ppi = pd.read_csv('ppi.csv')                                               # read PPI+MET network from CSV file
ppi = ppi.loc[((ppi['A'].isin(genes)) & (ppi['B'].isin(genes)))]           # reduce network only to selected nodes/genes
map_gene_to_idx = {gene: i for i, gene in enumerate(genes)}                  # gene name -> row index in x and y
edges_index = torch.from_numpy(np.vstack([ppi['A'].map(map_gene_to_idx).to_numpy(),
                                          ppi['B'].map(map_gene_to_idx).to_numpy()]))   # shape [2, num_edges]

## Normalize edge weights

In [8]:
#@title Edge normalization { form-width: "30%" }
normalize_edge = "minmax" #@param ["", "zscore", "minmax"]
if normalize_edge == 'minmax':
    finite = np.isfinite(ppi['weight'])                          # excludes both inf and NaN
    maximum = ppi.loc[finite, 'weight'].max()                    # max over finite weights
    minimum = ppi.loc[finite, 'weight'].min()                    # min over finite weights
    ppi['weight'] = ppi['weight'].replace(np.inf, maximum)       # replace infinity with max
    ppi['weight'] = ppi['weight'].fillna(minimum)                # replace NaN with min
    ppi['weight'] = (ppi['weight'] - minimum) / (maximum - minimum)
elif normalize_edge == 'zscore':
    ppi['weight'] = (ppi['weight'] - ppi['weight'].mean()) / ppi['weight'].std()    

# Build PyG storage for network

In [9]:
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
from collections import Counter

import numpy as np
import torch
import torch.nn.functional as F                 # also used by the GNN models below
from torch_geometric.data import Data
from torch_geometric.transforms import RandomNodeSplit

classes_mapping = dict(zip(encoder.classes_, encoder.transform(encoder.classes_)))
data = Data(x=torch.from_numpy(x.to_numpy()).float(),
            edge_index=edges_index,
            edge_attr=torch.from_numpy(ppi['weight'].values).float(),
            y=torch.from_numpy(y))
data.num_classes = len(np.unique(y))
data = RandomNodeSplit()(data)                  # default split: 500 val nodes, 1000 test nodes, the rest for training

# Convert the boolean masks into index tensors
train_indices = np.arange(0, len(data.x))
data.train_idx = torch.tensor(train_indices[data.train_mask], dtype=torch.long)
data.val_idx = torch.tensor(train_indices[data.val_mask], dtype=torch.long)
data.test_idx = torch.tensor(train_indices[data.test_mask], dtype=torch.long)

print()
print(data)
print('===========================================================================================================')

# Gather some statistics about the graph.
print(f'Number of nodes: {data.num_nodes}')
print(f'Number of node features: {data.num_features}')
print(f'Number of {classes_mapping} classes: {data.num_classes}')
print(f'Class distributions: {Counter(data.y.numpy())}')
print(f'Number of edges: {data.num_edges}')
print(f'Node indices: {train_indices}')
print(f'Average node degree: {data.num_edges / data.num_nodes:.2f}')
print(f'Number of training nodes: {data.train_mask.sum()}')
print(f'Training node indices: {data.train_idx}')
print(f'Training node label rate: {int(data.train_mask.sum()) / data.num_nodes:.2f}')
print(f'Has isolated nodes: {data.has_isolated_nodes()}')
print(f'Edge weights: {data.edge_attr}')
print(f'Has self-loops: {data.has_self_loops()}')
print(f'Is directed: {data.is_directed()}')


Data(x=[3814, 119], edge_index=[2, 107513], edge_attr=[107513], y=[3814], num_classes=2, train_mask=[3814], val_mask=[3814], test_mask=[3814], train_idx=[2314], val_idx=[500], test_idx=[1000])
Number of nodes: 3814
Number of node features: 119
Number of {'E': 0, 'NE': 1} classes: 2
Class distributions: Counter({1: 3069, 0: 745})
Number of edges: 107513
Node indices: [   0    1    2 ... 3811 3812 3813]
Average node degree: 28.19
Number of training nodes: 2314
Training node indices: tensor([   0,    1,    2,  ..., 3809, 3812, 3813])
Training node label rate: 0.61
Has isolated nodes: True
Edge weights: tensor([0.0413, 0.0338, 0.0395,  ..., 0.0003, 0.0066, 0.0047])
Has self-loops: False
Is directed: True
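
The split sizes above are consistent with the defaults of PyG's `RandomNodeSplit` (`split="train_rest"`, `num_val=500`, `num_test=1000`): 3814 - 500 - 1000 = 2314 training nodes. A NumPy sketch of this kind of split, including the mask-to-index conversion done in the cell; `random_node_split` is a hypothetical helper, not PyG's implementation.

```python
import numpy as np

def random_node_split(num_nodes, num_val=500, num_test=1000, seed=0):
    """Hypothetical helper: random val/test nodes, every remaining node is used for training."""
    rng = np.random.default_rng(seed)
    perm = rng.permutation(num_nodes)
    val_mask = np.zeros(num_nodes, dtype=bool)
    test_mask = np.zeros(num_nodes, dtype=bool)
    val_mask[perm[:num_val]] = True
    test_mask[perm[num_val:num_val + num_test]] = True
    train_mask = ~(val_mask | test_mask)
    return train_mask, val_mask, test_mask

train_m, val_m, test_m = random_node_split(3814)
node_ids = np.arange(3814)
train_ids = node_ids[train_m]        # boolean mask -> sorted index array, like data.train_idx
print(train_m.sum(), val_m.sum(), test_m.sum())   # 2314 500 1000
```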


# The Trainer class

In [10]:
import copy
from dataclasses import dataclass
import os

import torch
import torch.nn as nn
from torch.optim import Adam, lr_scheduler
from tqdm import tqdm

from sklearn import metrics

@dataclass
class RunConfig:  # default parameters from the paper and official implementation
    learning_rate: float = 0.01
    num_epochs: int = 200
    weight_decay: float = 5e-4
    num_warmup_steps: int = 0
    save_each_epoch: bool = False
    output_dir: str = "."

class Trainer:
    def __init__(self, model):
        self.model = model

    def train(self, features, train_labels, val_labels, edge_index, edge_weights, device, run_config, log=True, eval_train=False):
        self.model = self.model.to(device)
        features = features.to(device)
        train_labels = train_labels.to(device)
        val_labels = val_labels.to(device)
        edge_index = edge_index.to(device)  # edge list and weights
        edge_weights = edge_weights.to(device)

        optimizer = Adam(self.model.parameters(), lr=run_config.learning_rate, weight_decay=run_config.weight_decay)

        # https://huggingface.co/transformers/_modules/transformers/optimization.html#get_linear_schedule_with_warmup
        def lr_lambda(current_step: int):
            if current_step < run_config.num_warmup_steps:
                return float(current_step) / float(max(1, run_config.num_warmup_steps))
            return max(0.0, float(run_config.num_epochs - current_step) /
                       float(max(1, run_config.num_epochs - run_config.num_warmup_steps)))

        scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda)

        if log:
            print("Training started:")
            print(f"\tNum Epochs = {run_config.num_epochs}")

        best_loss, best_model_accuracy = float("inf"), 0
        best_model_state_dict = None
        if log: train_iterator = tqdm(range(0, int(run_config.num_epochs)))
        else: train_iterator = range(0, int(run_config.num_epochs))
        train_logs = {'train loss' : [], 'train acc' : [], 'train mcc' : [], 'val loss' : [],  'val acc' : [], 'val mcc' : []}
        for epoch in train_iterator:
            self.model.train()
            outputs = self.model(features, edge_index, edge_weights, train_labels)
            loss = outputs[1]

            self.model.zero_grad()
            loss.backward()
            optimizer.step()
            scheduler.step()

            if eval_train: train_loss, train_accuracy, train_mcc, _ , _ = self.evaluate(features, train_labels, edge_index, edge_weights, device)
            val_loss, val_accuracy, val_mcc, _ , _= self.evaluate(features, val_labels, edge_index, edge_weights, device)
            if log: train_iterator.set_description(f"Training loss = {loss.item():.4f}, " f"val loss = {val_loss:.4f}, val acc = {val_accuracy:.2f}, val mcc = {val_mcc:.2f}")

            save_best_model = val_loss < best_loss
            if save_best_model:
                best_loss = val_loss
                best_model_accuracy = val_accuracy
                best_model_state_dict = copy.deepcopy(self.model.state_dict())
            if save_best_model or run_config.save_each_epoch or epoch + 1 == run_config.num_epochs:
                output_dir = os.path.join(run_config.output_dir, f"Epoch_{epoch + 1}")
                #self.save(output_dir)
            if eval_train:
                train_logs['train loss'].append(train_loss)
                train_logs['train acc'].append(train_accuracy)
                train_logs['train mcc'].append(train_mcc)
            train_logs['val loss'].append(val_loss)
            train_logs['val acc'].append(val_accuracy)
            train_logs['val mcc'].append(val_mcc)
        if log:
            print(f"Best model val CE loss = {best_loss:.4f}, best model val accuracy = {best_model_accuracy:.2f}")
        # reloads the best model state dict, bit hacky :P
        self.model.load_state_dict(best_model_state_dict)
        return train_logs 

    def evaluate(self, features, test_labels, edge_index, edge_weights, device):
        features = features.to(device)
        test_labels = test_labels.to(device)
        edge_index = edge_index.to(device)  # edge list and weights
        edge_weights = edge_weights.to(device)

        self.model.eval()
        with torch.no_grad():  # gradients are not needed for evaluation
            outputs = self.model(features, edge_index, edge_weights, test_labels)
        ce_loss = outputs[1].item()

        # nodes outside the evaluated split carry the ignore label (-100) and are skipped
        ignore_label = nn.CrossEntropyLoss().ignore_index
        keep = test_labels != ignore_label
        predicted_label = outputs[0].argmax(dim=1)[keep].cpu()
        true_label = test_labels[keep].cpu()
        accuracy = (true_label == predicted_label).float().mean().item()
        # scikit-learn needs CPU data; labels=[0, 1] keeps the confusion matrix 2x2 even if a class is absent
        mcc = metrics.matthews_corrcoef(true_label.numpy(), predicted_label.numpy())
        cm = metrics.confusion_matrix(true_label.numpy(), predicted_label.numpy(), labels=[0, 1])
        return ce_loss, accuracy, mcc, cm, predicted_label

    def save(self, output_dir):
        if not os.path.isdir(output_dir):
            os.makedirs(output_dir)

        model_path = os.path.join(output_dir, "model.pth")
        torch.save(self.model.state_dict(), model_path)
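
With `LambdaLR`, the learning rate at each scheduler step is `learning_rate` multiplied by the value returned by `lr_lambda`. A standalone copy of that multiplier (the name `linear_warmup_decay` is ours) makes the schedule easy to inspect: with the default `num_warmup_steps=0`, the rate starts at `learning_rate` and decays linearly to 0 at `num_epochs`.

```python
def linear_warmup_decay(step, num_epochs, num_warmup_steps=0):
    """Same multiplier as lr_lambda in Trainer.train: linear warm-up, then linear decay to 0."""
    if step < num_warmup_steps:
        return step / max(1, num_warmup_steps)
    return max(0.0, (num_epochs - step) / max(1, num_epochs - num_warmup_steps))

# With the RunConfig defaults (learning_rate=0.01, num_epochs=200, no warm-up)
for step in (0, 50, 100, 199, 200):
    print(step, 0.01 * linear_warmup_decay(step, num_epochs=200))
```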

# The GNN models

In [11]:
from torch_geometric.nn import GCNConv, GATConv, ChebConv
import torch.nn as nn
from torch_geometric.nn import GraphUNet
from torch_geometric.utils import dropout_adj

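# NOTE: with use_gdc=True the GCN layers below get normalize=False, so GCNConv skips its own
# normalization step (symmetric normalization and added self-loops); no GDC transform is applied here
use_gdc = True
# Class weights for the CE loss: class E (0) is weighted by the NE frequency and class NE (1) by the
# E frequency, so the minority class E weighs more
class_counts = Counter(data.y.numpy())
weights = torch.tensor([round(class_counts[classes_mapping['NE']] / len(data.y), 2),
                        round(class_counts[classes_mapping['E']] / len(data.y), 2)])
weights = weights.to("cuda" if torch.cuda.is_available() else "cpu")   # same device as the model outputs
#weights = None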

def reset_weights(m):
  '''
    Try resetting model weights to avoid
    weight leakage.
  '''
  for layer in m.children():
   if hasattr(layer, 'reset_parameters'):
    #print(f'Reset trainable parameters of layer = {layer}')
    layer.reset_parameters()

class ChebNetGCN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size, num_hidden_layers=0, dropout=0.1, residual=False, k=2):
        super(ChebNetGCN, self).__init__()

        self.dropout = dropout
        self.residual = residual

        self.input_conv = ChebConv(input_size, hidden_size, k)
        self.hidden_convs = nn.ModuleList([ChebConv(hidden_size, hidden_size, k) for _ in range(num_hidden_layers)])
        self.output_conv = ChebConv(hidden_size, output_size, k)

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor, edge_weight: torch.Tensor, labels: torch.Tensor = None):
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = F.relu(self.input_conv(x, edge_index, edge_weight))
        for conv in self.hidden_convs:
            if self.residual:
                x = F.relu(conv(x, edge_index, edge_weight)) + x
            else:
                x = F.relu(conv(x, edge_index, edge_weight))
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.output_conv(x, edge_index, edge_weight)

        if labels is None:
            return x

        loss = nn.CrossEntropyLoss()(x, labels)
        return x, loss

class GAT(nn.Module):
    def __init__(self, input_size, hidden_size, output_size, dropout=0.1):
        super(GAT, self).__init__()
        torch.manual_seed(42)

        self.conv1 = GATConv(input_size, hidden_size, heads=hidden_size, edge_dim=1, dropout=0.6)
        # On the Pubmed dataset, use heads=8 in conv2.
        self.conv2 = GATConv(hidden_size * hidden_size, output_size, heads=1, edge_dim=1, concat=False,
                             dropout=0.6)
        self.elu = nn.ELU()
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor, edge_weight: torch.Tensor, labels: torch.Tensor = None):
        x = self.dropout(x)
        x = self.conv1(x, edge_index, edge_weight)
        x = self.elu(x)
        x = self.dropout(x)
        x = self.conv2(x, edge_index, edge_weight)
        if labels is None:
            return x

        loss = nn.CrossEntropyLoss(weight=weights)(x, labels)
        return x, loss

class OneLayerGCN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size, dropout=0.1):
        super(OneLayerGCN, self).__init__()
        torch.manual_seed(42)
        self.conv = GCNConv(input_size, output_size, cached=True, improved=True, add_self_loops=True, normalize=not use_gdc)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor, edge_weight: torch.Tensor, labels: torch.Tensor = None):
        x = self.dropout(x)
        x = self.conv(x, edge_index, edge_weight)
        if labels is None:
            return x

        loss = nn.CrossEntropyLoss(weight=weights)(x, labels)
        return x, loss

class TwoLayerGCN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size, dropout=0.1):
        super(TwoLayerGCN, self).__init__()
        torch.manual_seed(42)
        self.conv1 = GCNConv(input_size, hidden_size, cached=True, improved=True, add_self_loops=True, normalize=not use_gdc)
        self.conv2 = GCNConv(hidden_size, output_size, cached=True, improved=True, add_self_loops=True, normalize=not use_gdc)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor, edge_weight: torch.Tensor, labels: torch.Tensor = None):
        x = self.dropout(x)
        x = self.conv1(x, edge_index, edge_weight)
        x = self.relu(x)
        x = self.dropout(x)
        x = self.conv2(x, edge_index, edge_weight)
        if labels is None:
            return x

        loss = nn.CrossEntropyLoss(weight=weights)(x, labels)
        return x, loss

class ThreeLayerGCN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size, dropout=0.1):
        super(ThreeLayerGCN, self).__init__()
        torch.manual_seed(42)
        self.conv1 = GCNConv(input_size, hidden_size, cached=True, improved=True, add_self_loops=True, normalize=not use_gdc)
        self.conv2 = GCNConv(hidden_size, hidden_size, cached=True, improved=True, add_self_loops=True, normalize=not use_gdc)
        self.conv3 = GCNConv(hidden_size, output_size, cached=True, improved=True, add_self_loops=True, normalize=not use_gdc)
        self.relu1 = nn.ReLU()
        self.relu2 = nn.ReLU()
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor, edge_weight: torch.Tensor, labels: torch.Tensor = None):
        x = self.dropout(x)
        x = self.conv1(x, edge_index, edge_weight)
        x = self.relu1(x)
        x = self.dropout(x)
        x = self.conv2(x, edge_index, edge_weight)
        x = self.relu2(x)
        x = self.dropout(x)
        x = self.conv3(x, edge_index, edge_weight)
        if labels is None:
            return x

        loss = nn.CrossEntropyLoss(weight=weights)(x, labels)
        return x, loss

class GUNet(torch.nn.Module):
    def __init__(self, input_size, hidden_size, output_size, dropout=0.6):
        super().__init__()
        pool_ratios = [2000 / data.num_nodes, 0.5]
        self.unet = GraphUNet(input_size, hidden_size, output_size,
                              depth=3, pool_ratios=pool_ratios)

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor, edge_weight: torch.Tensor, labels: torch.Tensor = None):
        edge_index, _ = dropout_adj(edge_index, edge_attr=edge_weight, p=0.2,
                                    force_undirected=True,
                                    num_nodes=x.size(0),
                                    training=self.training)
        x = F.dropout(x, p=0.92, training=self.training)   # use the features passed in, not the global data.x

        x = self.unet(x, edge_index)
        if labels is None:
            return x

        loss = nn.CrossEntropyLoss(weight=weights)(x, labels)
        return x, loss
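
Each model's loss is `nn.CrossEntropyLoss(weight=weights)`. Two properties of PyTorch's cross-entropy matter here: targets equal to `ignore_index` (-100 by default) contribute nothing, which is how the k-fold loop restricts the loss to one split; and with the default `reduction='mean'`, the weighted sum of the per-node losses is divided by the sum of the weights of the non-ignored targets. A NumPy re-implementation for checking these semantics on toy numbers (our own sketch, not PyTorch code):

```python
import numpy as np

def weighted_ce(logits, targets, class_weights=None, ignore_index=-100):
    """Mean weighted cross-entropy that skips targets equal to ignore_index."""
    logits = np.asarray(logits, dtype=float)
    targets = np.asarray(targets)
    keep = targets != ignore_index
    z, t = logits[keep], targets[keep]
    z = z - z.max(axis=1, keepdims=True)                       # numerically stable log-softmax
    log_probs = z - np.log(np.exp(z).sum(axis=1, keepdims=True))
    nll = -log_probs[np.arange(len(t)), t]                     # per-node negative log-likelihood
    w = np.ones(logits.shape[1]) if class_weights is None else np.asarray(class_weights, dtype=float)
    return float((w[t] * nll).sum() / w[t].sum())              # normalized by the weights actually used

# The third node is ignored; the two kept rows have uniform logits, so the result is log(2)
print(weighted_ce([[0.0, 0.0], [0.0, 0.0], [2.0, 0.0]], [0, 1, -100], [0.8, 0.2]))
```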

# k-fold validation

In [None]:
#@title Choose GNN { form-width: "20%" }
netmodel = "OneLayerGCN" #@param ["OneLayerGCN", "TwoLayerGCN", "ChebNetGCN"]
epochs = 1000 #@param {type:"slider", min:10, max:1000, step:10}

def set_labels(initial_labels, set_mask, ignore_label):
    # keep the labels of the nodes in set_mask; every other node gets ignore_label
    initial_labels[~set_mask] = ignore_label
    return initial_labels

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
set_seed(1)

from sklearn.model_selection import KFold, train_test_split
from sklearn.metrics import accuracy_score, balanced_accuracy_score, matthews_corrcoef
from tqdm import tqdm

nfolds = 5
kf = KFold(n_splits=nfolds)
accuracies = []
mccs = []
cma = np.array([[0, 0], [0, 0]])        # confusion matrix accumulated over the folds
columns_names = ["Accuracy", "BA", "Sensitivity", "Specificity", "MCC", "CM"]
fold_scores = []
mm = np.array([], dtype=int)            # test indices, in the same order as the stored predictions
predictions = np.array([])
for fold, (train_index, test_index) in enumerate(tqdm(kf.split(np.arange(data.num_nodes)), total=kf.get_n_splits(), desc=f"{nfolds}-fold")):
    mm = np.concatenate((mm, test_index))
    # hold out 2.5% of the training nodes (stratified by label) for validation
    train_index, val_index = train_test_split(train_index, test_size=0.025, stratify=y[train_index])
    train_mask = torch.zeros(data.num_nodes, dtype=torch.bool)
    val_mask = torch.zeros(data.num_nodes, dtype=torch.bool)
    test_mask = torch.zeros(data.num_nodes, dtype=torch.bool)
    train_mask[torch.tensor(train_index)] = True
    val_mask[torch.tensor(val_index)] = True
    test_mask[torch.tensor(test_index)] = True

    ignore_index = nn.CrossEntropyLoss().ignore_index  # = -100, labels ignored by the CE loss
    train_labels = set_labels(data.y.clone(), train_mask, ignore_index)
    val_labels = set_labels(data.y.clone(), val_mask, ignore_index)
    test_labels = set_labels(data.y.clone(), test_mask, ignore_index)
    # training parameters
    run_config = RunConfig(learning_rate=0.01, num_epochs=epochs, weight_decay=5e-4)
    model = globals()[netmodel](input_size=data.x.shape[1], hidden_size=16, output_size=data.num_classes, dropout=0)

    # training
    trainer = Trainer(model)
    trainer.train(data.x, train_labels, val_labels, data.edge_index, data.edge_attr, device, run_config, log=False)

    # evaluating
    ce_loss, accuracy, mcc, cm, preds = trainer.evaluate(data.x, test_labels, data.edge_index, data.edge_attr, device)
    accuracies.append(accuracy)
    mccs.append(mcc)
    predictions = np.concatenate((predictions, preds.cpu().numpy().ravel()))
    cma += cm
    t = test_labels[test_labels != ignore_index].cpu().numpy()
    p_fold = preds.cpu().numpy()
    fold_scores.append(pd.DataFrame([[accuracy_score(t, p_fold), balanced_accuracy_score(t, p_fold),
        cm[0,0]/(cm[0,0]+cm[0,1]), cm[1,1]/(cm[1,0]+cm[1,1]),
        matthews_corrcoef(t, p_fold), cm]], columns=columns_names, index=[fold]))
scores = pd.concat(fold_scores)                                      # DataFrame.append was removed in pandas 2.0
df_scores = pd.DataFrame(scores.drop(columns='CM').mean(axis=0)).T   # average the numeric scores over folds
df_scores.index = [f'{netmodel}']
df_scores['CM'] = [cma]
df_scores

5-fold:  20%|██        | 1/5 [00:26<01:47, 26.80s/it]
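
In the score table, class 0 (`E`) plays the role of the positive class: `Sensitivity` is the recall of E, `cm[0,0]/(cm[0,0]+cm[0,1])`, and `Specificity` is the recall of NE, `cm[1,1]/(cm[1,0]+cm[1,1])`, since rows of scikit-learn's confusion matrix are true labels and columns are predictions. For two classes, balanced accuracy is the mean of these two recalls. A sketch computing the scores from a 2x2 matrix; the counts are made up and `scores_from_cm` is a hypothetical helper.

```python
import math
import numpy as np

def scores_from_cm(cm):
    """Hypothetical helper: scores from a 2x2 confusion matrix (rows = true, cols = predicted), E = positive."""
    tp, fn = cm[0, 0], cm[0, 1]      # true E predicted as E / as NE
    fp, tn = cm[1, 0], cm[1, 1]      # true NE predicted as E / as NE
    sensitivity = tp / (tp + fn)
    specificity = tn / (tn + fp)
    accuracy = (tp + tn) / cm.sum()
    balanced_accuracy = (sensitivity + specificity) / 2
    denom = math.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
    mcc = (tp * tn - fp * fn) / denom if denom > 0 else 0.0   # 0.0 when undefined (a convention)
    return {"accuracy": accuracy, "ba": balanced_accuracy,
            "sensitivity": sensitivity, "specificity": specificity, "mcc": mcc}

toy_cm = np.array([[90, 60], [30, 580]])    # made-up counts
print(scores_from_cm(toy_cm))
```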

# Print predictions

In [20]:
p = np.zeros(len(y), dtype=int)
p[mm] = predictions                                    # put each out-of-fold prediction back on its gene's row
labels['predictions'] = encoder.inverse_transform(p)   # 0 -> 'E', 1 -> 'NE'
labels

| name            | CS0_vs_CS6-9 | predictions |
|-----------------|--------------|-------------|
| ENSG00000001036 | NE           | NE          |
| ENSG00000001461 | NE           | NE          |
| ENSG00000001561 | NE           | NE          |
| ENSG00000001630 | NE           | NE          |
| ENSG00000001631 | NE           | NE          |
| ...             | ...          | ...         |
| ENSG00000288257 | NE           | NE          |
| ENSG00000288283 | NE           | NE          |
| ENSG00000288359 | NE           | NE          |
| ENSG00000288407 | NE           | NE          |
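
The k-fold loop stores the test indices of each fold in `mm`, in the same order as the concatenated `predictions`, so `p[mm] = predictions` scatters every out-of-fold prediction back to its gene's row. A toy NumPy sketch of that reassembly, with made-up fold indices and predictions:

```python
import numpy as np

num_genes_toy = 6
folds = [np.array([0, 3]), np.array([1, 4]), np.array([2, 5])]       # made-up test indices per fold
fold_preds = [np.array([1, 0]), np.array([1, 1]), np.array([0, 1])]  # made-up predictions per fold

order = np.concatenate(folds)            # plays the role of mm
stacked = np.concatenate(fold_preds)     # plays the role of predictions
oof = np.zeros(num_genes_toy, dtype=int)
oof[order] = stacked                     # each prediction lands on its own row
names = np.array(["E", "NE"])[oof]       # decode 0 -> E, 1 -> NE
print(oof, names)
```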
