In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
import time
import torch
import random
import warnings
import argparse
import numpy as np
import pandas as pd
import pickle as pk
import torch.nn as nn
import torchmetrics as tm
import pytorch_lightning as pl
import torch.nn.functional as F
from tool import METRICS
from tqdm import tqdm,trange
from model import GraphBepi
from dataset import PDB,collate_fn,chain
from collections import defaultdict
from torch.utils.data import DataLoader,Dataset
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import Callback,EarlyStopping,ModelCheckpoint
warnings.simplefilter('ignore')

In [3]:
def seed_everything(seed=2022):
    """Seed every random number generator this notebook relies on.

    Args:
        seed: Integer seed applied to Python's ``random`` module, NumPy and
            PyTorch (CPU and every CUDA device).
    """
    random.seed(seed)
    # Only interpreters started after this point (e.g. subprocesses) pick this
    # up; the hash seed of the already-running kernel cannot be changed.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed all CUDA devices rather than only the current one.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # cuDNN autotuning may select different kernels from run to run; keep it off.
    torch.backends.cudnn.benchmark = False
seed_everything(2022)

In [4]:
# Run configuration: every tunable setting for this experiment lives here.
gpu = 0        # CUDA device index; -1 selects the CPU
fold = 0       # fold index passed to the PDB datasets and used in output names
lr = 1e-6      # learning rate handed to GraphBepi
batch = 4      # batch size shared by the train/val/test DataLoaders
epochs = 300   # max_epochs for the Lightning Trainer
root = './data/BCE_633/'          # dataset root directory
log_name = 'BCE_633_GraphBepi'    # prefix of the log and checkpoint directories

In [5]:
# Build the three splits of the selected fold, in train -> val -> test order.
trainset, valset, testset = (
    PDB(mode=split, fold=fold, root=root)
    for split in ('train', 'val', 'test')
)

100%|████████████████████| 520/520 [00:05<00:00, 89.66it/s, chain=3lh2_V]
100%|█████████████████████| 57/57 [00:00<00:00, 105.67it/s, chain=2qqn_A]
100%|██████████████████████| 56/56 [00:00<00:00, 70.71it/s, chain=7ue9_C]


In [6]:
# Only the training loader shuffles and drops the last incomplete batch;
# validation and test loaders keep dataset order.
train_loader = DataLoader(
    trainset,
    batch_size=batch,
    shuffle=True,
    drop_last=True,
    collate_fn=collate_fn,
)
val_loader = DataLoader(
    valset,
    batch_size=batch,
    shuffle=False,
    collate_fn=collate_fn,
)
test_loader = DataLoader(
    testset,
    batch_size=batch,
    shuffle=False,
    collate_fn=collate_fn,
)

In [7]:
# Device string for the metric helper: CPU when gpu is -1, otherwise cuda:<gpu>.
if gpu == -1:
    device = 'cpu'
else:
    device = f'cuda:{gpu}'
metrics = METRICS(device)

# Both callbacks monitor val_AUPRC, where larger is better.
es = EarlyStopping(monitor='val_AUPRC', patience=40, mode='max')
mc = ModelCheckpoint(
    dirpath=f'./model/{log_name}/',
    filename=f'model_{fold}',
    monitor='val_AUPRC',
    mode='max',
    save_weights_only=True,
)

# TensorBoard logs go under ./log/<log_name>_<fold>.
logger = TensorBoardLogger(save_dir='./log', name=f'{log_name}_{fold}')

cb = [mc, es]
trainer = pl.Trainer(
    gpus=None if gpu == -1 else [gpu],
    max_epochs=epochs,
    callbacks=cb,
    logger=logger,
    check_val_every_n_epoch=1,
)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [8]:
model = GraphBepi(
    # --- input feature sizes ---
    # ESM-2 representation dimension
    feat_dim=2560,
    # DSSP feature dimension
    exfeat_dim=13,
    # edge feature dimension
    edge_dim=51,
    # --- network / regularisation ---
    # hidden representation dimension
    hidden_dim=256,
    dropout=0.2,
    # random noise rate
    augment_eps=0.05,
    # --- optimisation and evaluation ---
    # learning rate
    lr=lr,
    # helper used to compute performance metrics
    metrics=metrics,
    # where the temporary result file of the test set is saved
    result_path=f'./model/{log_name}',
)

In [None]:
# Train on the training split, validating on the validation split.
trainer.fit(model, train_loader, val_loader)

# Reload the checkpoint written for this fold under ./model/<log_name>/.
# Settings come from the notebook variables `fold` and `gpu`; this notebook
# defines no `args` namespace.
ckpt_path = f'./model/{log_name}/model_{fold}.ckpt'
model.load_state_dict(
    torch.load(ckpt_path)['state_dict'],
)

# Fresh trainer without a logger for the test run; gpu == -1 means CPU,
# matching the convention used when the training trainer was built.
trainer = pl.Trainer(gpus=[gpu] if gpu != -1 else None, logger=None)
result = trainer.test(model, test_loader)