In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from dgllife.utils import EarlyStopping
from configure import get_exp_configure
from utils import set_random_seed, load_dataset, collate, load_model
from main import run_a_train_epoch, run_an_eval_epoch, update_msg_from_scores

In [None]:
# Pick one predefined ACNN / PDBBind experiment configuration by index and
# merge its settings into `args`, which the rest of the notebook reads from.
idx = 0 # select configuration
args = dict()
choices=['ACNN_PDBBind_core_pocket_random', 'ACNN_PDBBind_core_pocket_scaffold',
         'ACNN_PDBBind_core_pocket_stratified', 'ACNN_PDBBind_core_pocket_temporal',
         'ACNN_PDBBind_refined_pocket_random', 'ACNN_PDBBind_refined_pocket_scaffold',
         'ACNN_PDBBind_refined_pocket_stratified', 'ACNN_PDBBind_refined_pocket_temporal']
args.update(get_exp_configure(choices[idx]))
# Device strings must not contain whitespace: recent PyTorch rejects "cuda: 0"
# with "Invalid device string", so the canonical "cuda:0" form is required.
args['device'] = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
# Seed early so dataset splitting, weight init and shuffling are reproducible.
set_random_seed(args['random_seed'])

# Load the dataset with its train/test splits and wrap each split in a DataLoader.
dataset, train_set, test_set = load_dataset(args)
# Training-label statistics, moved to the target device. Presumably used by the
# train/eval helpers to (de)normalize targets — TODO confirm in main.py.
args['train_mean'] = train_set.labels_mean.to(args['device'])
args['train_std'] = train_set.labels_std.to(args['device'])
# Shuffle the training data so each epoch sees batches in a different order.
train_loader = DataLoader(dataset=train_set,
                          batch_size=args['batch_size'],
                          shuffle=True,
                          collate_fn=collate)
# Evaluation does not need shuffling; a fixed order keeps runs comparable.
test_loader = DataLoader(dataset=test_set,
                         batch_size=args['batch_size'],
                         shuffle=False,
                         collate_fn=collate)

# Build the model and move it to the target device *before* creating the
# optimizer. PyTorch recommends this so the optimizer holds the final parameters.
model = load_model(args)
model.to(args['device'])
loss_fn = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=args['lr'])
# Early-stopping helper; it writes model checkpoints to `<config name>_model.h5`.
stopper = EarlyStopping(mode=args['mode'],
                        patience=args['patience'],
                        filename=choices[idx]+'_model.h5')
if args['load_checkpoint']:
    print('Loading checkpoint...')
    # load_state_dict copies into the existing (already on-device) parameters.
    stopper.load_checkpoint(model)

# Train for up to `num_epochs`, evaluating after every epoch and stopping early
# when the monitored MAE stops improving for `patience` epochs.
# NOTE(review): load_dataset returns no validation split here, so early stopping
# monitors the *test* MAE. The reported test scores are therefore optimistically
# biased. `args['mode']` comes from the config; confirm it is 'lower' for MAE.
for epoch in range(args['num_epochs']):
    run_a_train_epoch(args, epoch, model, train_loader, loss_fn, optimizer)
    test_scores = run_an_eval_epoch(args, model, test_loader)
    test_msg = update_msg_from_scores('test results', test_scores)
    early_stop = stopper.step(test_scores['mae'], model)
    print(test_msg)

    if early_stop:
        print('Early stopping')
        break

# After the loop, `model` holds the last-epoch weights. Restore the best
# checkpoint written by the stopper so the notebook ends with the selected model.
# The guard ensures at least one stopper.step() ran, so a checkpoint file exists.
if args['num_epochs'] > 0:
    stopper.load_checkpoint(model)
    best_scores = run_an_eval_epoch(args, model, test_loader)
    print(update_msg_from_scores('best checkpoint test results', best_scores))