In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
import time
import torch
import random
import warnings
import argparse
import numpy as np
import pandas as pd
import pickle as pk
import torch.nn as nn
import torchmetrics as tm
import pytorch_lightning as pl
import torch.nn.functional as F
from tool import METRICS
from tqdm import tqdm,trange
from model import GraphBepi
from dataset import PDB,collate_fn,chain
from collections import defaultdict
from torch.utils.data import DataLoader,Dataset
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import Callback,EarlyStopping,ModelCheckpoint
warnings.simplefilter('ignore')

In [3]:
def seed_everything(seed=2022):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
seed_everything(2022)

In [4]:
gpu=0
fold=-1
lr=1e-6
batch=4
epochs=300
root='./data/BCE_633/'
log_name=f'BCE_633_GraphBepi'

In [5]:
trainset=PDB(mode='train',fold=fold,root=root)
valset=PDB(mode='val',fold=fold,root=root)
testset=PDB(mode='test',root=root)

100%|██████████████████████████████████████| 577/577 [00:07<00:00, 82.05it/s, chain=3lh2_V]
0it [00:00, ?it/s]
100%|████████████████████████████████████████| 56/56 [00:00<00:00, 83.24it/s, chain=7ue9_C]


In [6]:
train_loader=DataLoader(trainset, batch, shuffle=True, collate_fn=collate_fn, drop_last=True)
val_loader=DataLoader(valset, batch, shuffle=False, collate_fn=collate_fn)
test_loader=DataLoader(testset, batch, shuffle=False, collate_fn=collate_fn)

In [7]:
device='cpu' if gpu==-1 else f'cuda:{gpu}'
metrics=METRICS(device)
es=EarlyStopping('val_AUPRC',patience=40,mode='max')
mc=ModelCheckpoint(
    f'./model/{log_name}/',f'model_{fold}',
    'val_AUPRC',
    mode='max',
    save_weights_only=True, 
)
logger = TensorBoardLogger(
    './log', 
    name=f'{log_name}_{fold}'
)
cb=[mc,es]
trainer = pl.Trainer(
    gpus=[gpu] if gpu!=-1 else None, 
    max_epochs=epochs, callbacks=cb,
    logger=logger,check_val_every_n_epoch=1,
)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [8]:
model=GraphBepi(
    feat_dim=2560,                     # esm2 representation dim
    hidden_dim=256,                    # hidden representation dim
    exfeat_dim=13,                     # dssp feature dim
    edge_dim=51,                       # edge feature dim
    augment_eps=0.05,                  # random noise rate
    dropout=0.2,
    lr=lr,                             # learning rate
    metrics=metrics,                   # an implement to compute performance
    result_path=f'./model/{log_name}', # path to save temporary result file of testset
)

In [9]:
trainer.fit(model, train_loader, test_loader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]

  | Name        | Type       | Params
-------------------------------------------
0 | loss_fn     | BCELoss    | 0     
1 | W_v         | Linear     | 655 K 
2 | W_u1        | AE         | 69.9 K
3 | edge_linear | Sequential | 3.3 K 
4 | gat         | EGAT       | 8.5 M 
5 | lstm1       | LSTM       | 1.2 M 
6 | lstm2       | LSTM       | 1.2 M 
7 | mlp         | Sequential | 262 K 
-------------------------------------------
11.9 M    Trainable params
0         Non-trainable params
11.9 M    Total params
47.536    Total estimated model params size (MB)


Epoch 299: 100%|██████████████████████████████████████| 158/158 
[00:39<00:00, 3.97it/s, loss=0.305, v_num=1, train_auc_step=0.775, train_prc_step=0.416, val_loss=0.264, val_AUROC=0.765, val_AUPRC=0.261, val_mcc=0.248, val_f1=0.322, train_loss=0.290, train_auc_epoch=0.794, train_prc_epoch=0.353]


In [10]:
model.load_state_dict(
    torch.load(f'./model/{log_name}/model_{fold}.ckpt')['state_dict'],
)
trainer = pl.Trainer(gpus=[gpu],logger=None)
result = trainer.test(model,test_loader)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]


Testing: 0it [00:00, ?it/s]

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       test_AUPRC            0.26434326171875
       test_AUROC            0.761267364025116
        test_bacc           0.6568726897239685
         test_f1            0.32405197620391846
        test_loss           0.2643062174320221
        test_mcc            0.24895137548446655
     test_precision         0.2569386065006256
       test_recall          0.4386216700077057
     test_threshold         0.17629370093345642
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
