In [1]:
import numpy as np
import pandas as pd
import json
import os
import datetime

from tqdm import tqdm

import os
import random

import torch

# Working directory for the rest of the notebook. The local modules imported
# below (`models`, `data_utils`, `model_utils`) are presumably found here, and
# all data paths are relative to it. Override via GENEDRAGNN_MODELS_DIR when
# running on another machine; the default keeps the original location.
MODELS_DIR = os.environ.get("GENEDRAGNN_MODELS_DIR", "/home/pengq/LenskAI/geneDRAGNN/models")
os.chdir(MODELS_DIR)

# Seed every RNG that influences a run. numpy was already seeded; Python's
# `random` and torch (which seeds all devices, including CUDA) drive the
# remaining randomness, such as model weight initialisation.
SEED = 314159
np.random.seed(SEED)
random.seed(SEED)
torch.manual_seed(SEED)

import torch
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, ModelSummary
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger

from torch_geometric.data import Data
import torch_geometric.loader

import wandb

import models
from data_utils import read_data, create_data
import model_utils

In [4]:
# Root folder holding the prepared input files.
base_path = '../data/final_data/'

# Registry: model-family name -> factory function in the local `models` module.
model_creator_dict = dict(
    SGConv=models.create_SGConv_GNN,
    GraphSAGE=models.create_GraphSAGE_GNN,
    TAG=models.create_TAG_GNN,
    ClusterGCN=models.create_clusterGCN_GNN,
    MLP=models.create_MLP,
)

In [19]:
# Load node features, per-trial label columns and the edge list. Paths are
# built from `base_path` (the config cell) so the data folder is set in one place.
node_dataset, edge_list, labels = read_data(
    node_filepath=os.path.join(base_path, "node_node2vec_data.csv"),
    label_filepath=os.path.join(base_path, "training_labels_trials.csv"),
    edgelist_path=os.path.join(base_path, "ls-fgin_edge_list.edg"),
    feats_type='nodeonly',
)

In [6]:
# Two output classes; one input feature per column of the node table.
num_classes, num_features = 2, len(node_dataset.columns)

In [7]:
import os
# Tell wandb which notebook file this run comes from.
notebook_name = 'train_gnn_model.ipynb'
os.environ.update({'WANDB_NOTEBOOK_NAME': notebook_name})

In [25]:
import gc
from sklearn import metrics

def run_trials(create_model, start_trial=0, end_trial=100, n_epochs=500, log=False, log_project=None):
    """Train and evaluate a fresh model for every trial in [start_trial, end_trial].

    For each trial index ``t`` (both ends inclusive), this:

    1. Builds a train/val/test split from the globals ``node_dataset``,
       ``edge_list`` and ``labels``, using label column ``f'label_{t}'``.
    2. Trains a new model with PyTorch Lightning.
    3. Reloads the checkpoint with the best ``val_acc``.
    4. Evaluates that checkpoint and computes its ROC curve on the test nodes.

    Args:
        create_model: zero-argument callable returning a new, untrained model.
        start_trial: first trial index (inclusive).
        end_trial: last trial index (inclusive).
        n_epochs: maximum number of training epochs per trial.
        log: if True, log every trial to Weights & Biases.
        log_project: wandb project name. When ``log`` is True and this is
            None, the name is read interactively with ``input()``.

    Returns:
        ``(train_reports, test_reports, roc_data)``:

        - ``train_reports`` and ``test_reports`` hold one report per trial
          from ``model_utils.evaluate_model``.
        - ``roc_data`` holds one dict per trial with keys ``'fpr'``, ``'tpr'``,
          ``'thresholds'`` (also stored under the legacy misspelled key
          ``'threholds'``) and ``'auc'``.
    """
    if log and log_project is None:
        print('Enter the name of the log project: ')
        log_project = input()

    # One throwaway instance, used only to summarise the architecture. The
    # summary is printed once and attached to every logged trial.
    model_summary = pl.utilities.model_summary.summarize(create_model(), max_depth=4)
    model_summary_str = str(model_summary)
    num_trainable_params = model_summary.trainable_parameters
    print(model_summary_str)

    # Choose the device once. Hard-coding accelerator="gpu" / .to('cuda')
    # breaks on machines without CUDA (devices would be 0).
    use_gpu = torch.cuda.is_available()
    device = torch.device('cuda' if use_gpu else 'cpu')

    # Zero-pad trial numbers in run names: at least 2 digits (the previous
    # fixed width), and wide enough for end_trial so names sort correctly.
    n_zfills = max(2, len(str(end_trial)))

    train_reports, test_reports, roc_data = [], [], []

    for trial in tqdm(range(start_trial, end_trial + 1)):
        print(f'running trial {trial}')
        data = create_data(node_dataset, edge_list, labels, f'label_{trial}', test_size=0.2, val_size=0.1)
        model = create_model()

        if log:
            log_name = f'{log_project}_trial{str(trial).zfill(n_zfills)}'
            logger = _make_trial_logger(log_project, log_name, model_summary_str, num_trainable_params, data)
        else:
            logger = False

        # The whole graph is a single in-memory sample, so load it in the main
        # process. Worker processes would only add start-up cost on every epoch.
        data_loader = torch_geometric.loader.DataLoader([data], batch_size=1, num_workers=0)

        trainer = pl.Trainer(
            callbacks=[ModelCheckpoint(save_weights_only=False, mode="max", monitor='val_acc')],
            accelerator="gpu" if use_gpu else "cpu",
            devices=1,
            max_epochs=n_epochs,
            logger=logger,
            enable_model_summary=False,
        )
        # The same loader serves training and validation. The model presumably
        # selects nodes via data.train_mask / data.val_mask internally.
        trainer.fit(model, data_loader, data_loader)

        # Continue with the checkpoint that had the best validation accuracy.
        model = models.LitGNN.load_from_checkpoint(trainer.checkpoint_callback.best_model_path)

        train_report, test_report = model_utils.evaluate_model(model, data, logger=logger)
        train_reports.append(train_report)
        test_reports.append(test_report)

        fpr, tpr, thresholds, auc_score = _test_roc(model, data, device)
        # 'threholds' is the original misspelled key, kept so existing readers
        # of saved results keep working.
        roc_data.append({'fpr': fpr, 'tpr': tpr, 'thresholds': thresholds,
                         'threholds': thresholds, 'auc': auc_score})

        if log:
            logger.log_metrics({'auc_test': auc_score, 'fpr_test': fpr,
                                'tpr_test': tpr, 'roc_thres': thresholds})
            # Save the notebook named by WANDB_NOTEBOOK_NAME, falling back to the
            # previously hard-coded filename when the variable is unset.
            wandb.save(os.environ.get('WANDB_NOTEBOOK_NAME', 'modeling_gnn.ipynb'))
            wandb.finish(quiet=True)

        # Drop this trial's objects and hand cached CUDA memory back to the driver.
        del model, data_loader, trainer, data
        gc.collect()
        _print_cuda_memory()
        torch.cuda.empty_cache()
        _print_cuda_memory('\nafter empty_cache:')

    return train_reports, test_reports, roc_data


def _make_trial_logger(log_project, log_name, model_summary_str, num_trainable_params, data):
    """Create the WandbLogger for one trial and log the model summary and data split."""
    # log_model="all" uploads every checkpoint. The previous literal "\all\\"
    # evaluated to "\x07ll\\" because of the "\a" escape.
    logger = WandbLogger(name=log_name, project=log_project, log_model="all", save_dir='wandb_projects')
    logger.log_metrics({'model_summary_str': model_summary_str,
                        'num_trainable_params': num_trainable_params})
    # log random train-val-test split
    logger.log_metrics({'train_mask': data.train_mask, 'val_mask': data.val_mask, 'test_mask': data.test_mask})
    return logger


def _test_roc(model, data, device):
    """Return (fpr, tpr, thresholds, auc) for class-1 scores on the test nodes.

    The score is the softmax probability of class 1. This is correct whether
    the model emits raw logits or log-probabilities, because
    softmax(log_softmax(z)) == softmax(z). Ranking by the raw class-1 logit
    alone ignores the other class's logit and can misorder nodes.
    """
    model.to(device)
    # Evaluation mode (e.g. disables dropout, if the model has any) and no
    # autograd graph.
    model.eval()
    with torch.no_grad():
        logits, _, _ = model(data.to(device))
    scores = torch.softmax(logits, dim=-1)[data.test_mask, 1].cpu().numpy()
    y = data.y[data.test_mask].cpu().numpy()
    fpr, tpr, thresholds = metrics.roc_curve(y, scores)
    return fpr, tpr, thresholds, metrics.roc_auc_score(y, scores)


def _print_cuda_memory(header=None):
    """Print the currently allocated and reserved CUDA memory in bytes."""
    if header is not None:
        print(header)
    print('memory allocated: ', torch.cuda.memory_allocated())
    print('memory reserved: ', torch.cuda.memory_reserved())

In [None]:
# Coerce the first two edge-list columns (the edge endpoints) to integers.
# Values that cannot be parsed become NaN and those rows are dropped; report
# how many edges are lost instead of discarding them silently.
n_edges_raw = len(edge_list)
edge_endpoints = edge_list.iloc[:, :2].apply(pd.to_numeric, errors='coerce')
edge_list = (
    pd.concat([edge_endpoints, edge_list.iloc[:, 2:]], axis=1)
    .dropna()
    .astype(int)
)
print(f'kept {len(edge_list)} of {n_edges_raw} edges')

# Transpose to shape (2, num_edges): the edge_index layout used by torch_geometric.
edge_index = torch.tensor(edge_list.iloc[:, :2].to_numpy().T, dtype=torch.int64)
edge_list.head()


       0     1
0    839  8988
1   4109  8988
2  12757  8988
3   9850  8988
4   6690  8988
       0     1
0    839  8988
1   4109  8988
2  12757  8988
3   9850  8988
4   6690  8988
0    int64
1    int64
dtype: object


In [26]:
## TRAIN AND EVALUATE MODEL

model_name = 'SGConv'
create_model = model_creator_dict[model_name]

# wandb project name (used only when log=True) and prefix of the saved report files.
log_project_name = model_name

# Output folder for the reports; create it so the saves below work on a fresh checkout.
reports_dir = 'project_reports'
os.makedirs(reports_dir, exist_ok=True)

# Run the trials. start/end are inclusive, so end_trial=0 runs exactly one trial (label_0).
train_reports, test_reports, roc_data = run_trials(
    lambda: create_model(model_name, num_features, num_classes),
    start_trial=0, end_trial=0, n_epochs=250,
    log=False, log_project=log_project_name,
)

# Save the trial reports to JSON. The list of ROC dicts is stored by np.save
# as an object array, so loading it needs np.load(..., allow_pickle=True).
model_utils.save_reports(f'{reports_dir}/{log_project_name}_reports', train_reports, test_reports)
np.save(f'{reports_dir}/{log_project_name}_roc', roc_data)

   | Name                      | Type             | Params | Mode 
------------------------------------------------------------------------
0  | model                     | GNNModel         | 162 K  | train
1  | model.convs               | ModuleList       | 145 K  | train
2  | model.convs.0             | SGConv           | 13.6 K | train
3  | model.convs.0.aggr_module | SumAggregation   | 0      | train
4  | model.convs.0.lin         | Linear           | 13.6 K | train
5  | model.convs.1             | SGConv           | 33.0 K | train
6  | model.convs.1.aggr_module | SumAggregation   | 0      | train
7  | model.convs.1.lin         | Linear           | 33.0 K | train
8  | model.convs.2             | SGConv           | 65.8 K | train
9  | model.convs.2.aggr_module | SumAggregation   | 0      | train
10 | model.convs.2.lin         | Linear           | 65.8 K | train
11 | model.convs.3             | SGConv           | 32.9 K | train
12 | model.convs.3.aggr_module | SumAggregation   | 0   

/home/pengq/miniconda3/envs/LenskAI/lib/python3.10/site-packages/lightning_fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /home/pengq/miniconda3/envs/LenskAI/lib/python3.10/s ...
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA RTX A6000') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


running trial 0


Sanity Checking: |                                                                                            …

/home/pengq/miniconda3/envs/LenskAI/lib/python3.10/site-packages/pytorch_lightning/core/module.py:512: You called `self.log('val_acc', ..., logger=True)` but have no logger configured. You can enable one by doing `Trainer(logger=ALogger(...))`
/home/pengq/miniconda3/envs/LenskAI/lib/python3.10/site-packages/pytorch_lightning/utilities/data.py:79: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 14854. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.
/home/pengq/miniconda3/envs/LenskAI/lib/python3.10/site-packages/pytorch_lightning/core/module.py:512: You called `self.log('val_loss', ..., logger=True)` but have no logger configured. You can enable one by doing `Trainer(logger=ALogger(...))`


Training: |                                                                                                   …

/home/pengq/miniconda3/envs/LenskAI/lib/python3.10/site-packages/pytorch_lightning/core/module.py:512: You called `self.log('train_loss', ..., logger=True)` but have no logger configured. You can enable one by doing `Trainer(logger=ALogger(...))`
/home/pengq/miniconda3/envs/LenskAI/lib/python3.10/site-packages/pytorch_lightning/core/module.py:512: You called `self.log('train_acc', ..., logger=True)` but have no logger configured. You can enable one by doing `Trainer(logger=ALogger(...))`


Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …


Detected KeyboardInterrupt, attempting graceful shutdown ...
  0%|                                                                                                                                      | 0/1 [06:22<?, ?it/s]


NameError: name 'exit' is not defined