# GNN Modeling

## Data and Setup

In [None]:
import numpy as np
import pandas as pd
import json
import os

from tqdm import tqdm

# Seed NumPy's global RNG so that runs are repeatable. train_test_split (used to build
# the train/val/test masks) is called without `random_state`, so the splits follow
# this global seed.
# NOTE(review): torch is not seeded here (e.g. pl.seed_everything), so weight init and
# dropout are presumably not reproducible across runs — confirm whether that matters.
np.random.seed(314159) # set random seed

import torch
import pytorch_lightning as pl

from torch_geometric.data import Data

import wandb

In [None]:
# read dataset
# NOTE(review): `graph_used` is not defined in any of the cells shown here; it must be
# set (e.g. in an earlier cell) before this cell runs.
node_data_path = f'data/{graph_used}_graph/final_nodeonly_node_data_{graph_used}_v1.csv'
node_dataset = pd.read_csv(node_data_path, index_col=0)

# Every gene (index entry) must be unique, otherwise the gene -> myID mapping below is
# ambiguous. Raise explicitly rather than `assert`, which is stripped under `python -O`.
n_duplicated = int(node_dataset.index.duplicated().sum())
if n_duplicated:
    raise ValueError(f'{n_duplicated} duplicated gene IDs found in {node_data_path}')

# map each gene ID (dataset index) to a 0-based integer ID, the node-numbering scheme
# used for the graph tensors
genes = node_dataset.index.to_numpy()
gene_id_dict = {ensembl: idx for idx, ensembl in enumerate(genes)}

# replace the gene-ID index with the integer myID index
myID = node_dataset.index.map(gene_id_dict).rename('myID')
node_dataset.insert(loc=0, column='myID', value=myID)
node_dataset = node_dataset.reset_index().set_index('myID')
# assumes the CSV's index column is named 'ensembl' (it becomes a column after reset_index)
node_dataset.drop(columns=['ensembl'], inplace=True)

In [None]:
# PROCESS LABELS

# label-threshold tag; in this notebook it is only used to build the logging-project name
thres = '0,02'

# per-trial labels: one column `label_<k>` per trial, indexed by gene ID
labels = pd.read_csv('../data/final_data/training_labels_trials.csv', index_col=0)

# Remove the GDA-score columns from the node features (they are label columns, so
# keeping them as features would presumably leak the target).
label_cols_drop = [
    'gda_score',
    'gda_score_thres0.01',
    'gda_score_thres0.02',
    'gda_score_thres0.03',
    'gda_score_thres0.04',
    'gda_score_thres0.05',
]
node_dataset.drop(columns=label_cols_drop, inplace=True)

# re-index the labels by myID so they line up with the rows of node_dataset
myID = labels.index.map(gene_id_dict).rename('myID')
labels.insert(loc=0, column='myID', value=myID)
labels = labels.reset_index().set_index('myID')
labels.drop(columns=['ensembl'], inplace=True)

# Convert every trial column to pandas' nullable integer type, so missing labels stay <NA>.
num_label_trials = 100
for label_col in (f'label_{i}' for i in range(num_label_trials)):
    labels[label_col] = labels[label_col].astype('Int32')

print('\ndistribution of labels')
print(labels['label_0'].value_counts())

In [None]:
# display the processed node-feature table as the cell output
node_dataset

In [None]:
# binary node classification (negative / positive); every node_dataset column is an input feature
num_classes = 2
num_features = node_dataset.shape[1]

In [None]:
# load edge list: columns gene1, gene2 (gene IDs) followed by the edge-feature columns
edge_list_path = '../data/final_data/ls-fgin_edge_list_features.csv'

edge_list = pd.read_csv(edge_list_path)

# map both edge endpoints from gene ID to the 0-based myID node index
for endpoint in ('gene1', 'gene2'):
    edge_list[endpoint] = edge_list[endpoint].map(gene_id_dict)

# Genes absent from the node dataset map to NaN. Casting NaN to int64 when building
# edge_index would silently produce garbage node indices, so fail loudly here instead.
n_unmapped = int(edge_list[['gene1', 'gene2']].isna().any(axis=1).sum())
if n_unmapped:
    raise ValueError(f'{n_unmapped} edges in {edge_list_path} reference genes missing from the node dataset')

# scale edge features appropriately (they take values in the range 0-1000)
edge_feat_cols = edge_list.columns[2:].to_numpy()
edge_list[edge_feat_cols] /= 1000

edge_list

In [None]:
# Build the edge tensors in the layout PyTorch Geometric expects ([2, n_edges]).
# Convert straight to int64: going through torch.Tensor (float32) first would round
# node indices above 2**24.
edge_index = torch.tensor(edge_list[['gene1', 'gene2']].to_numpy().T, dtype=torch.int64)
print('edge_index shape: ', edge_index.shape)

# extract edge features (float32, one row per edge)
print()
print('edge feature columns: ', edge_feat_cols)
edge_attr = torch.tensor(edge_list[edge_feat_cols].to_numpy(), dtype=torch.float32)
print('edge_attr shape: ', edge_attr.shape)

In [None]:
def create_data(node_dataset, labels, label_col, test_size=0.2, val_size=0.1):
    """Build a PyTorch Geometric ``Data`` graph for one label column.

    Args:
        node_dataset (pd.DataFrame): node features indexed by myID (0..n_nodes-1).
        labels (pd.DataFrame): label columns indexed by myID; missing labels are NA.
        label_col (str): label column to use, e.g. ``'label_0'``.
        test_size (float): fraction of labelled nodes held out for testing.
        val_size (float): fraction of labelled nodes held out for validation.

    Returns:
        Data: graph with ``x``, ``y`` (-1 for unlabelled nodes), the module-level
        ``edge_index``/``edge_attr`` tensors, and boolean ``train_mask``,
        ``val_mask`` and ``test_mask``.

    Raises:
        ValueError: if the labels do not line up one-to-one with the nodes.
    """
    # attach the chosen label column to the node features
    node_labels = node_dataset.merge(labels[label_col], left_on='myID', right_on='myID')

    # Row i of X must be node i (edge_index and the masks index nodes by myID). An inner
    # merge silently drops nodes without a label row, which would misalign everything.
    if not node_labels.index.equals(node_dataset.index):
        raise ValueError(f'labels for {label_col!r} do not cover every node exactly once; '
                         'node features and masks would be misaligned')

    node_feat_cols = node_labels.columns[:-1]  # the label column is appended last
    X = torch.as_tensor(node_labels[node_feat_cols].to_numpy(), dtype=torch.float32)

    # -1 marks unlabelled nodes so y can be an integer tensor; the masks exclude them
    y = node_labels[label_col].fillna(-1).astype('int')
    y = torch.tensor(y.to_numpy(), dtype=torch.int64)

    # only nodes with a label take part in the train/val/test split
    node_data_labeled = node_labels[node_labels[label_col].notna()]
    train_mask, val_mask, test_mask = get_train_val_test_masks(
        node_data_labeled, label_col, test_size=test_size, val_size=val_size, n_nodes=len(node_labels))

    # edge_index / edge_attr are the module-level tensors built from the edge list
    data = Data(x=X, y=y, edge_index=edge_index, edge_attr=edge_attr)

    data.train_mask = train_mask
    data.val_mask = val_mask
    data.test_mask = test_mask

    return data


from sklearn.model_selection import train_test_split

def get_train_val_test_masks(node_data_labeled, label_col, test_size=0.2, val_size=0.1, n_nodes=None):
    """Split labelled nodes into shuffled, stratified train/val/test boolean node masks.

    Args:
        node_data_labeled (pd.DataFrame): labelled nodes only, indexed by myID.
        label_col (str): label column used for stratification.
        test_size (float): fraction of labelled nodes placed in the test set.
        val_size (float): fraction of labelled nodes placed in the validation set
            (rescaled internally to a fraction of the train+val remainder).
        n_nodes (int, optional): total number of nodes in the graph (the mask length).
            Defaults to ``len(node_dataset)`` of the module-level node table.

    Returns:
        (train_mask, val_mask, test_mask): ``torch.bool`` tensors of length ``n_nodes``.
    """
    if n_nodes is None:
        n_nodes = len(node_dataset)

    labelled_ids = node_data_labeled.index.to_numpy()  # myIDs of labelled nodes
    strata = node_data_labeled[label_col].to_numpy()  # for stratification

    # val_size is relative to all labelled nodes; re-express it relative to train+val
    val_size_of_rest = val_size * (1 / (1 - test_size))

    # NOTE: no random_state is passed, so the split follows NumPy's global RNG state
    ids_train_val, ids_test = train_test_split(labelled_ids, test_size=test_size, shuffle=True, stratify=strata)

    strata_train_val = node_data_labeled.loc[ids_train_val][label_col].to_numpy()
    ids_train, ids_val = train_test_split(ids_train_val, test_size=val_size_of_rest, shuffle=True,
                                          stratify=strata_train_val)

    return _ids_to_mask(ids_train, n_nodes), _ids_to_mask(ids_val, n_nodes), _ids_to_mask(ids_test, n_nodes)


def _ids_to_mask(ids, n_nodes):
    """Return a boolean tensor of length ``n_nodes`` that is True at positions ``ids``."""
    mask = np.zeros(n_nodes, dtype=bool)
    mask[ids] = True
    return torch.from_numpy(mask)

In [None]:
# graph for the first label trial (also used as LitGNN's example input)
data = create_data(node_dataset, labels, 'label_0', val_size=0.1, test_size=0.2)

## Graph Convolutional Neural Network

In [None]:
import torch.nn.functional as F
from torch_geometric.nn import GCNConv

# define GNN architecture
class GNNModel(torch.nn.Module):
    """Stack of graph-convolution layers followed by a two-layer dense head."""

    def __init__(self, in_channels, hidden_channels, out_channels, hidden_dense, 
                 GNN_conv_layer=GCNConv, dropout_rate=0.1, **kwargs):
        """
        Args:
            in_channels (int): Dimension of input features
            hidden_channels (List[int]): Output dimension of each graph conv layer, in order
            out_channels (int): Dimension of the output (number of classes).
            hidden_dense (int): number of units in hidden dense layer following convolutions.
            GNN_conv_layer: Class of the graph convolutional layer to use.
            dropout_rate (float): Dropout rate to apply throughout the network
            kwargs: Additional arguments for the graph layer (e.g. number of heads for GAT)
        """
        super().__init__()

        # channel sizes along the conv stack: in -> hidden[0] -> hidden[1] -> ...
        dims = [in_channels, *hidden_channels]
        self.convs = torch.nn.ModuleList(
            GNN_conv_layer(in_channels=c_in, out_channels=c_out, **kwargs)
            for c_in, c_out in zip(dims[:-1], dims[1:])
        )

        self.dense1 = torch.nn.Linear(hidden_channels[-1], hidden_dense)
        # Size the output layer from the constructor argument. It previously read the
        # global `num_classes`, silently ignoring `out_channels`.
        self.dense_out = torch.nn.Linear(hidden_dense, out_channels)

        self.dropout_rate = dropout_rate

    def forward(self, x, edge_index, edge_attr=None):
        """
        Args:
            x: node features (presumably [n_nodes, in_channels])
            edge_index: edge list in PyTorch Geometric format ([2, n_edges])
            edge_attr: optional edge features, passed to every conv layer as `edge_attr`

        Returns:
            raw per-node outputs (out_channels values per node)
        """
        conv_kwargs = {} if edge_attr is None else {'edge_attr': edge_attr}
        for conv in self.convs:
            x = conv(x, edge_index, **conv_kwargs)
            x = F.dropout(x.relu(), p=self.dropout_rate, training=self.training)

        x = F.dropout(self.dense1(x).relu(), p=self.dropout_rate, training=self.training)
        return self.dense_out(x)

In [None]:
# define Pytorch Lightning model
class LitGNN(pl.LightningModule):
    """Lightning wrapper that trains/evaluates a node classifier on a full-graph batch."""

    def __init__(self, model_name, model=None, model_type='edge', **model_kwargs):
        """
        Args:
            model_name (str): identifier stored on the module (and in the hyperparameters).
            model: optional model class, instantiated with ``model_kwargs``; GNNModel by default.
            model_type (str): how inputs are fed to the model. 'edge_attr' passes
                (x, edge_index, edge_attr), 'edge' passes (x, edge_index), and
                'baseline' passes x only.
            model_kwargs: constructor arguments for the model class.
        """
        super().__init__()

        # record constructor arguments as hyperparameters (used when reloading checkpoints)
        self.save_hyperparameters()

        self.model_name = model_name

        self.model_type = model_type
        if model_type not in ('edge', 'edge_attr', 'baseline'):
            raise TypeError("Invalid `model_type`. Must be one of ['edge', 'edge_attr', 'baseline']")

        model_cls = GNNModel if model is None else model
        self.model = model_cls(**model_kwargs)

        self.loss_module = torch.nn.CrossEntropyLoss()

        # sample input for Lightning's model summary: the module-level `data` graph
        self.example_input_array = data

    def forward(self, data, mode='train'):
        """Run the model on the whole graph and score it on one split.

        Args:
            data: graph with x, edge_index (and edge_attr), y and the split masks.
            mode (str): 'train', 'val' or 'test'; selects which mask is scored.

        Returns:
            (per-node logits, cross-entropy loss on the split, accuracy on the split)
        """
        if self.model_type == 'edge_attr':
            logits = self.model(data.x, data.edge_index, edge_attr=data.edge_attr)
        elif self.model_type == 'edge':
            logits = self.model(data.x, data.edge_index)
        else:  # 'baseline': node features only
            logits = self.model(data.x)

        # Only calculate the loss and acc on the nodes of the requested split
        for split in ('train', 'val', 'test'):
            if mode == split:
                mask = getattr(data, f'{split}_mask')
                break
        else:
            raise ValueError(f'Unknown forward mode: {mode}')

        #TODO: add other metrics like recall, precision, f1, etc...
        split_logits, split_y = logits[mask], data.y[mask]
        loss = self.loss_module(split_logits, split_y)
        acc = (split_logits.argmax(dim=-1) == split_y).sum().float() / mask.sum()

        return logits, loss, acc

    def configure_optimizers(self):
        # Adam with default settings (an SGD variant with lr=0.1, momentum=0.9,
        # weight_decay=2e-3 was the commented-out alternative)
        return torch.optim.Adam(self.parameters())

    def _log_epoch_metric(self, name, value):
        # epoch-level metric, shown in the progress bar and forwarded to the logger
        self.log(name, value, on_step=False, on_epoch=True, prog_bar=True, logger=True)

    def training_step(self, batch, batch_idx):
        _, loss, acc = self.forward(batch, mode='train')
        self._log_epoch_metric('train_loss', loss)
        self._log_epoch_metric('train_acc', acc)
        return loss

    def validation_step(self, batch, batch_idx):
        logits, loss, acc = self.forward(batch, mode="val")
        self._log_epoch_metric('val_acc', acc)
        self._log_epoch_metric('val_loss', loss)
        return logits

    def validation_epoch_end(self, validation_step_outputs):
        # histogram of every validation logit of the epoch (only when a logger is attached)
        all_logits = torch.cat(validation_step_outputs).flatten()
        if not self.logger:
            return
        self.logger.experiment.log({'val_logits': wandb.Histogram(all_logits.to('cpu')), 
                                    'global_step': self.global_step})

    def test_step(self, batch, batch_idx):
        _, _, acc = self.forward(batch, mode="test")
        self._log_epoch_metric("test_acc", acc)

In [None]:
import os
notebook_name = 'train_gnn_model.ipynb'
# expose the notebook name to wandb through the WANDB_NOTEBOOK_NAME environment variable
os.environ.update(WANDB_NOTEBOOK_NAME=notebook_name)

In [None]:
from pytorch_lightning.callbacks import ModelCheckpoint, ModelSummary
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger
import torch_geometric.loader

from torch_geometric.nn import GCNConv, SAGEConv, GATConv, GATv2Conv, ChebConv

import datetime

In [None]:
import gc
from sklearn.metrics import classification_report
def evaluate_model(model, data, logger=None, device=None):
    '''Evaluate a trained model on the train and test splits of ``data``.

    Args:
        model: module whose ``forward(data)`` returns ``(logits, loss, acc)`` (e.g. LitGNN).
        data: graph with ``y``, ``train_mask`` and ``test_mask``; moved to ``device``.
        logger: optional Lightning logger; when truthy, final accuracies and the full
            reports are logged through ``logger.log_metrics``.
        device: device to evaluate on; defaults to CUDA when available, otherwise CPU.

    Returns:
        (train_report, test_report): sklearn ``classification_report`` dicts.
    '''
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    model.to(device=device)
    data = data.to(device=device)

    # Evaluate deterministically: dropout off (eval mode) and no autograd graph. The
    # caller reloads a checkpoint and does not call model.eval(), so without this the
    # predictions would be made with dropout active. The previous mode is restored afterwards.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            logits, _, _ = model.forward(data)
    finally:
        model.train(was_training)

    def _report(mask):
        # classification report restricted to the nodes selected by `mask`
        preds = logits[mask].argmax(dim=-1).cpu().numpy()
        y_true = data.y[mask].cpu().numpy()
        return classification_report(y_true, preds, labels=[0,1], target_names=['negative', 'positive'],
                                     output_dict=True)

    train_report = _report(data.train_mask)
    test_report = _report(data.test_mask)

    if logger:
        train_acc = train_report['accuracy']
        test_acc = test_report['accuracy']
        logger.log_metrics({'final_train_acc': train_acc, 'final_test_acc': test_acc})
        logger.log_metrics({'train_report': train_report, 'test_report': test_report})

    return train_report, test_report

def save_reports(filename, train_reports, test_reports):
    '''Save train and test reports to ``<filename>.json``.

    The JSON object has keys ``train_reports`` and ``test_reports``. Missing parent
    directories of ``filename`` are created, so a long run of trials is not lost to a
    missing output directory at the very end. The data is serialised before the file is
    opened, so an unserialisable report leaves no partial file behind.
    '''
    path = f'{filename}.json'
    save_dict = {'train_reports': train_reports, "test_reports": test_reports}
    json_string = json.dumps(save_dict)

    directory = os.path.dirname(path)
    if directory:
        os.makedirs(directory, exist_ok=True)

    # context manager closes the file even if the write fails
    with open(path, 'w') as json_file:
        json_file.write(json_string)


def get_value_from_dict(report, *keys):
    '''Return ``report[k1][k2]...[kn]`` for ``keys == (k1, ..., kn)``; with no keys, ``report`` itself.'''
    if not keys:
        return report
    first, *remaining = keys
    return get_value_from_dict(report[first], *remaining)

def average_report_val(reports, *keys):
    '''Mean, across ``reports``, of the entry found at the key path ``keys``.

    e.g. ``average_report_val(reports, 'weighted avg', 'f1-score')``.
    '''
    collected = []
    for single_report in reports:
        collected.append(get_value_from_dict(single_report, *keys))
    return np.mean(collected)


def _print_cuda_memory():
    '''Print the bytes currently allocated and reserved by the CUDA caching allocator.'''
    print('memory allocated: ', torch.cuda.memory_allocated())
    print('memory reserved: ', torch.cuda.memory_reserved())


def run_trials(create_model, data, start_trial=0, end_trial=100, n_epochs=500, log=False, log_project=None):
    '''Train and evaluate a fresh model for every label trial in ``start_trial..end_trial``.

    For each trial the graph is rebuilt from the module-level ``node_dataset``/``labels``
    using column ``label_<trial>``. The model is trained, the checkpoint with the best
    ``val_acc`` is reloaded with ``LitGNN.load_from_checkpoint``, and that checkpoint
    is evaluated.

    Args:
        create_model: zero-argument factory returning a new (presumably LitGNN) model.
        data: unused; every trial builds its own graph (kept for call compatibility).
        start_trial, end_trial: trial indices; BOTH ends are inclusive.
            NOTE(review): the default ``end_trial=100`` requests ``label_100``, which does
            not exist when there are 100 label trials (``label_0``..``label_99``).
        n_epochs: maximum number of training epochs per trial.
        log: if True, log each trial to Weights & Biases.
        log_project: W&B project name; prompted for when ``log`` is set and it is None.

    Returns:
        (train_reports, test_reports): lists of classification-report dicts, one per trial.
    '''
    if log and log_project is None:
        print('Enter the name of the log project: ')
        log_project = input()

    # model info (printed once, logged with every trial)
    model = create_model()
    model_summary = pl.utilities.model_summary.summarize(model, max_depth=4)
    model_summary_str = str(model_summary)
    num_trainable_params = model_summary.trainable_parameters
    print(model_summary_str)
    del model

    # loop invariants
    avail_gpus = min(1, torch.cuda.device_count())
    # zero-padding width for trial numbers in run names
    # NOTE(review): 100 presumably mirrors the number of label trials
    n_zfills = int(np.ceil(np.log10(100)))

    train_reports = []
    test_reports = []

    for trial in tqdm(range(start_trial, end_trial + 1)):

        print(f'running trial {trial}')
        data = create_data(node_dataset, labels, f'label_{trial}', test_size=0.2, val_size=0.1)

        model = create_model()

        if log:
            log_name = f'{log_project}_trial{str(trial).zfill(n_zfills)}'
            logger = WandbLogger(name=log_name, project=log_project, log_model="all", save_dir='wandb_projects')
            logger.log_metrics({'model_summary_str': model_summary_str,
                                'num_trainable_params': num_trainable_params})
        else:
            logger = False

        # the whole graph is a single batch, used for both training and validation
        data_loader = torch_geometric.loader.DataLoader([data], batch_size=1, num_workers=os.cpu_count())

        trainer = pl.Trainer(
            callbacks=[ModelCheckpoint(save_weights_only=False, mode="max", monitor="val_acc")],
            gpus=avail_gpus,
            max_epochs=n_epochs,
            logger=logger,
            enable_model_summary=False,
        )

        trainer.fit(model, data_loader, data_loader)

        # evaluate the checkpoint with the highest val_acc, not the last epoch
        model = LitGNN.load_from_checkpoint(trainer.checkpoint_callback.best_model_path)

        train_report, test_report = evaluate_model(model, data, logger=logger)

        train_reports.append(train_report)
        test_reports.append(test_report)

        if log:
            # NOTE(review): this differs from notebook_name ('train_gnn_model.ipynb'); confirm which file is meant
            wandb.save('modeling_gnn.ipynb')
            wandb.finish(quiet=True)

        # release this trial's objects (and cached GPU memory) before the next trial
        del model, data_loader, trainer, data
        gc.collect()

        _print_cuda_memory()
        torch.cuda.empty_cache()
        print('\nafter empty_cache:')
        _print_cuda_memory()

    return train_reports, test_reports

In [None]:
## TRAIN AND EVALUATE MODEL


from models import create_SGConv_GNN # choose model to train
model_name = 'SGConv_GNN'
create_model = create_SGConv_GNN  # factory for the model chosen by the import above

# W&B project name, also used as the report file name: <model>_<graph>_<threshold tag>
log_project_name = '{}_{}_{}'.format(model_name, graph_used, thres)

# run label trials 65..99 (both ends inclusive)
train_reports, test_reports = run_trials(
    create_model,
    data,
    start_trial=65,
    end_trial=99,
    n_epochs=250,
    log=True,
    log_project=log_project_name,
)

# persist the per-trial classification reports as JSON
save_reports('project_reports/' + log_project_name, train_reports, test_reports)

In [None]:
# summarise test performance averaged over all trials run above
avg_acc = average_report_val(test_reports, 'accuracy')
print('Test Accuracy: ', avg_acc)

# 'weighted avg' is sklearn's support-weighted average of the per-class f1-scores
avg_f1 = average_report_val(test_reports, 'weighted avg', 'f1-score')
print('Test f1-score: ', avg_f1)