In [1]:
import torch, dgl, os
import numpy as np
import pandas as pd

from glob import glob
from time import time
from datetime import datetime

from dgl.data.utils import Subset
from dgl.dataloading import GraphDataLoader

from sklearn.model_selection import KFold

from data.dataset import MSPDataset

from model.model import MetabolicStabilityPrediction, MSP_MoE

from train.utils import *
from train.train import *
from train.early_stop import EarlyStopping
from train.scheduler import CosineAnnealingWarmUpRestarts

In [2]:
TORCH = format_pytorch_version(torch.__version__)
CUDA = format_cuda_version(torch.version.cuda)
DGL = format_dgl_version(dgl.__version__)
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"[INFO] DEVICE: {device}, CUDA version: {CUDA}, TORCH version: {TORCH}, DGL version: {DGL}")

[INFO] DEVICE: cuda, CUDA version: cu117, TORCH version: 2.0.1, DGL version: 1.1.1+cu117


In [3]:
hyperparameter = { 'batch_size':64, 'num_layers': 8, 'out_size': 128, 'node_dim': 62, 'pose_dim': 20, 'edge_dim': 13, 
                  'lr_min': 1e-4, 'lr_max': 1e-2, 'T_0': 20, 'T_mult':1, 'T_up':5, 'T_gamma':0.9, 'dropout_ratio':0.15,
                  'model_type': 'GatedGCNLSPE_MoE', 'loss_function': 'BCE', 'seed':42 }

In [23]:
def cross_validation_training( train_data_paths, hyperparameter, device="cuda" ):
    batch_size = hyperparameter.get('batch_size', 64)
     
    num_layers = hyperparameter.get( 'num_layers', 8 )
    out_size   = hyperparameter.get( 'out_size', 256 )
    pose_dim   = hyperparameter.get( 'pose_dim', 20 )
    node_dim   = hyperparameter.get( 'node_dim', 57 )
    edge_dim   = hyperparameter.get( 'edge_dim', 13 )
    dropout_ratio  = hyperparameter.get( 'dropout_ration', 0.15 )
    
    lr_min  = hyperparameter.get( 'lr_min',  1e-6 )
    lr_max  = hyperparameter.get( 'lr_max',  1e-2 )
    T_0     = hyperparameter.get( 'T_0',     20 )
    T_mult  = hyperparameter.get( 'T_mult',  1 )
    T_up    = hyperparameter.get( 'T_up',    5 )
    T_gamma = hyperparameter.get( 'T_gamma', 0.8 )
    seed = hyperparameter.get( 'seed', 42 )
    
    model_type    = hyperparameter.get( 'model_type', 'GatedGCN')
    loss_function = hyperparameter.get( 'loss_function', 'MSE' )

    set_random_seed(seed)
    
    g = torch.Generator()
    g.manual_seed(seed)
    
    n_splits = 5
    kf = KFold( n_splits=n_splits, shuffle=True, random_state=seed )

    train = MSPDataset( train_data_paths, pos_enc_dim=pose_dim )

    oof_preds = np.zeros((len(train_data_paths)))
    oof_trues = np.zeros((len(train_data_paths)))

    LOSS = { idx: [] for idx in range(n_splits) }

    for num, (train_id, valid_id) in enumerate( kf.split( train ) ):
        model_path = f'./save/{loss_function}_{model_type}_nl{num_layers}_od{out_size}_pd{pose_dim}_bs{batch_size}_lmax{lr_max}_lmin{lr_min}_seed{seed}'

        if not os.path.isdir( model_path ):
            os.makedirs(model_path)
        else:
            print( 'already made' )

        dt = datetime.now()
        dt = f'{model_path}/early_stop_{dt.date()}_{dt.hour:02d}-{dt.minute:02d}-{dt.second:02d}_{num}fold.pth'

        print(f"# # # # # # # # # # Start {num}-fold cross validation # # # # # # # # # # ")
        train_data = Subset( train, train_id )
        valid_data = Subset( train, valid_id )
    
        train_loader = GraphDataLoader(train_data, batch_size=batch_size, shuffle=True,  drop_last=False, generator=g)
        valid_loader = GraphDataLoader(valid_data, batch_size=batch_size, shuffle=False, drop_last=False, generator=g )

        model = MetabolicStabilityPrediction(node_dim, out_size, edge_dim, pose_dim, 1024, num_layers=num_layers).to(device)
        # model = MSP_MoE( node_dim, edge_dim, pose_dim, out_size, num_layers, 1024, num_experts=5).to(device)

        stopper = EarlyStopping( patience=100, mode="higher", metric="rmse", filename=dt)

        optimizer = torch.optim.AdamW(model.parameters(), lr=lr_min)
        scheduler = CosineAnnealingWarmUpRestarts( optimizer, T_0=T_0, T_mult=T_mult, eta_max=lr_max, T_up=T_up, gamma=T_gamma )

        each_fold_loss = { "TRAIN":[], "VALID":[] }
        for epoch in range(1, 10001):
            start = time()
            lr = optimizer.param_groups[0]["lr"]

            train_loss = train_model(model, train_loader, optimizer, scheduler, device=device)
            valid_loss, valid_pred, valid_true, valid_name = valid_model( model, valid_loader, device=device ) ## 1 fold 

            each_fold_loss["TRAIN"].append( train_loss )
            each_fold_loss["VALID"].append( valid_loss )
            
            print( f"\n[ INFO ] {num+1} Fold, Epoch: {epoch:04d}, Lr: {lr:.5f}, Time: {time() - start:.2f}, Train Loss: {train_loss:.3f}, Valid Loss: {valid_loss:.3f}" )
            
            results = analysis(valid_true, valid_pred)
            analysis_table(results)        
            
            early_stop = stopper.step(results["ROC_AUC"], model)
            if early_stop:
                model.load_state_dict(torch.load(dt)['model_state_dict'])
    
                valid_loss, valid_pred, valid_true, valid_name = valid_model( model, valid_loader, device=device ) ## 1 fold 
    
                oof_preds[valid_id] = valid_pred
                oof_trues[valid_id] = valid_true
                break
                
        LOSS[num].append( each_fold_loss )
        
    return oof_preds, oof_trues, model_path, LOSS

In [24]:
random_seeds = [42, 189, 2494, 4592, 22232, 49492, 89982, 100294, 199606, 202310]
random_seeds = [42]
random_seeds

[42]

In [25]:
oof_preds_list = []
oof_trues_list = []
loss_list = []
for seed in random_seeds:
    hyperparameter['seed'] = seed
    
    print( f'seed: {seed} start' )
    oof_preds, oof_trues, model_path, LOSS = cross_validation_training( glob("./data/train/*"), hyperparameter)
    print( f'seed: {seed} done' )
    
    oof_preds_list.append( oof_preds )
    oof_trues_list.append( oof_trues )
    loss_list.append( LOSS )

42 start
[INFO] RANDOM, DGL, NUMPY and TORCH random seed is set 42.
already made
# # # # # # # # # # Start 0-fold cross validation # # # # # # # # # # 
For metric rmse, the lower the better.

[ INFO ] 1 Fold, Epoch: 0001, Lr: 0.00010, Time: 1.60, Train Loss: 0.689, Valid Loss: 0.736
+-------+-------+--------------------+-------------+---------------+-------+-------+---------+
|  ACC  |  MCC  | Sensitivity Recall | Specificity | Precision PPV |  NPV  |   F1  | ROC_AUC |
+-------+-------+--------------------+-------------+---------------+-------+-------+---------+
| 0.540 | 0.189 |       0.758        |    0.499    |     0.224     | 0.915 | 0.345 |  0.628  |
+-------+-------+--------------------+-------------+---------------+-------+-------+---------+

[ INFO ] 1 Fold, Epoch: 0002, Lr: 0.00208, Time: 1.56, Train Loss: 0.650, Valid Loss: 0.742
+-------+-------+--------------------+-------------+---------------+-------+-------+---------+
|  ACC  |  MCC  | Sensitivity Recall | Specificity | 

In [7]:
oof_preds_tensor = torch.as_tensor( oof_preds_list )
oof_trues_tensor = torch.as_tensor( oof_trues_list )

  oof_preds_tensor = torch.as_tensor( oof_preds_list )


In [20]:
analysis_table( analysis( oof_preds_tensor.view(-1), oof_trues_tensor.view(-1)) )

+-------+-------+--------------------+-------------+---------------+-------+-------+---------+
|  ACC  |  MCC  | Sensitivity Recall | Specificity | Precision PPV |  NPV  |   F1  | ROC_AUC |
+-------+-------+--------------------+-------------+---------------+-------+-------+---------+
| 0.819 | 0.606 |       0.860        |    0.745    |     0.859     | 0.747 | 0.860 |  0.864  |
+-------+-------+--------------------+-------------+---------------+-------+-------+---------+


In [8]:
oof_preds_tensor_binary = torch.where( oof_preds_tensor > 0.5, 1, 0 ).float()
oof_trues_tensor_binary = torch.where( oof_trues_tensor > 0.5, 1, 0 ).float()

In [9]:
oof_preds_mean = oof_preds_tensor_binary.mean(dim=0)
oof_trues_mean = oof_trues_tensor_binary.mean(dim=0)

In [10]:
analysis_table( analysis( oof_preds_mean, oof_trues_mean ) )

+-------+-------+--------------------+-------------+---------------+-------+-------+---------+
|  ACC  |  MCC  | Sensitivity Recall | Specificity | Precision PPV |  NPV  |   F1  | ROC_AUC |
+-------+-------+--------------------+-------------+---------------+-------+-------+---------+
| 0.819 | 0.606 |       0.860        |    0.745    |     0.859     | 0.747 | 0.860 |  0.803  |
+-------+-------+--------------------+-------------+---------------+-------+-------+---------+


In [44]:
def inference( dataset, model_paths, grid ):
    print(f"How Many Models?: {len(model_paths)}")
    loader = GraphDataLoader(dataset, batch_size=512, shuffle=False, drop_last=False)
    preds = []
    trues = []
    names = []
    for idx, model_dict in enumerate(model_paths):
        model.load_state_dict( torch.load( model_dict )['model_state_dict'] )
        loss, pred, true, name = valid_model( model, loader, device=device ) ## 1 fold 

        preds.append(pred)
        trues.append(true)
        
        if idx == 0:
            names.append(name)
            
    preds = torch.as_tensor(preds)
    trues = torch.as_tensor(trues)

    preds = torch.where( preds >= grid, 1, 0 ).float()
    preds = preds.mean(dim=0)
    trues = torch.where( trues >= grid, 1, 0 ).float()
    trues = trues.mean(dim=0)
    
    return preds, trues, names

In [45]:
model = MetabolicStabilityPrediction(hyperparameter['node_dim'], hyperparameter['out_size'], hyperparameter['edge_dim'], hyperparameter['pose_dim'], hyperparameter['num_layers']).to(device)
test  = MSPDataset(  sorted(glob("./data/CMMS-GCL/test/*")),  pos_dim=20 )

test_pred, test_true, test_name = inference( test, glob( f'./save/seed42/*/*' ), 0.5)
analysis_table( analysis( test_pred, test_true ) )

How Many Models?: 10
+-------+--------+--------------------+-------------+---------------+-------+-------+---------+
|  ACC  |  MCC   | Sensitivity Recall | Specificity | Precision PPV |  NPV  |   F1  | ROC_AUC |
+-------+--------+--------------------+-------------+---------------+-------+-------+---------+
| 0.649 | -0.012 |       0.817        |    0.172    |     0.736     | 0.250 | 0.775 |  0.530  |
+-------+--------+--------------------+-------------+---------------+-------+-------+---------+


In [25]:
test_pred3

tensor([0.9000, 0.9000, 1.0000, 1.0000, 1.0000, 1.0000, 0.9000, 1.0000, 1.0000,
        1.0000, 0.9000, 0.9000, 0.9000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        0.9000, 1.0000, 1.0000, 1.0000, 1.0000, 0.4000, 1.0000, 0.9000, 0.9000,
        0.8000, 0.0000, 0.9000, 0.0000, 0.0000, 0.1000, 1.0000, 0.6000, 0.0000,
        0.1000, 0.1000, 0.9000, 1.0000, 1.0000, 0.8000, 0.4000, 0.9000, 1.0000,
        1.0000, 0.1000, 1.0000, 1.0000, 0.8000, 1.0000, 1.0000, 0.5000, 1.0000,
        1.0000, 0.6000, 0.9000, 1.0000, 0.6000, 0.9000, 1.0000, 1.0000, 1.0000,
        0.7000, 0.7000, 1.0000, 0.2000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 0.0000, 1.0000, 1.0000, 0.2000, 1.0000, 0.2000, 0.8000,
        1.0000, 0.2000, 0.7000, 0.1000, 0.1000, 1.0000, 0.8000, 0.5000, 1.0000,
        1.0000, 0.1000, 0.9000, 0.9000, 1.0000, 0.9000, 1.0000, 1.0000, 1.0000,
        1.0000, 0.9000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 0.8000, 1.0000,
        1.0000, 1.0000, 1.0000])