In [None]:
from argparse import ArgumentParser
from model.origgnn import MolecularGNN, LigandDataset, pad_data
from utils.package import plot_fit_confidence_bond
from pytorch_lightning import Trainer
import pytorch_lightning as pl
from collections import defaultdict
from torch.utils.data import DataLoader
from torch.utils.data import random_split
import torch
import numpy as np
import wandb
import time
from sklearn.metrics import r2_score
import matplotlib
import matplotlib.pyplot as plt
matplotlib.use('Agg')

In [None]:

def _evaluate_split(trainer, model, dataloader):
    """Run one test pass over ``dataloader`` and score the collected predictions.

    Args:
        trainer: Trainer used to drive ``trainer.test``.
        model: Model whose ``predictions`` mapping is read after the test pass
            (keys ``'true'`` and ``'pred'``).
        dataloader: Batches to evaluate.

    Returns:
        ``(r2, figure)`` where ``r2`` is ``r2_score(true, pred)`` and ``figure``
        is whatever ``plot_fit_confidence_bond`` returns for the same data.
    """
    # Start from an empty buffer so only this pass is scored.
    # Assumes the test loop appends to model.predictions -- TODO confirm in MolecularGNN.
    model.predictions = defaultdict(list)
    trainer.test(model, dataloaders=dataloader, verbose=False)
    y_true = np.array(model.predictions['true'])
    y_pred = np.array(model.predictions['pred'])
    r2 = r2_score(y_true, y_pred)
    return r2, plot_fit_confidence_bond(y_true, y_pred, r2, annot=False)


def main(hparams):
    """Train a ``MolecularGNN`` and log train/validation fit quality to W&B.

    Steps, in order: log in to and start a W&B run, build the model from
    ``vars(hparams)``, build the trainer (fresh or resumed from
    ``hparams.checkpoint``), split ``LigandDataset(hparams.data_path, ...)``
    80/20 with a fixed seed, fit, then run one test pass per split and log the
    R^2 value and fit figure for each.

    Args:
        hparams: Parsed argument namespace. This function reads ``dim``,
            ``layer_hidden``, ``layer_output``, ``checkpoint``, ``data_path``
            and ``batch_size``; the full namespace is also forwarded to
            ``MolecularGNN(**vars(hparams))`` and
            ``Trainer.from_argparse_args``.

    Environment:
        WANDB_API_KEY: API key passed to ``wandb.login``. When unset, ``None``
            is passed and wandb falls back to its own credential lookup.
        WANDB_BASE_URL: W&B server URL; defaults to ``http://localhost:8080``.

    Raises:
        Any exception raised during setup, training or evaluation is re-raised
        after the W&B run has been closed with a non-zero exit code.
    """
    import os  # local import: only needed for the credential lookup below

    project = '3dgnn'
    # Credentials come from the environment so that no secret lives in the notebook.
    wandb.login(key=os.environ.get('WANDB_API_KEY'),
                host=os.environ.get('WANDB_BASE_URL', 'http://localhost:8080'))
    wandb.init(project=project, save_code=True)

    try:
        model_name = f'3dgnn-dim-{hparams.dim}-hlayer-{hparams.layer_hidden}-olayer-{hparams.layer_output}-' + time.strftime("%Y%m%d_%H%M%S", time.localtime())
        dict_args = vars(hparams)
        model = MolecularGNN(**dict_args)

        # logger
        wandb_logger = pl.loggers.WandbLogger()
        # callbacks: stop after 20 epochs without val_loss improvement, and
        # keep the best-val_loss checkpoint plus the last one.
        early_stopping = pl.callbacks.early_stopping.EarlyStopping(monitor='val_loss', patience=20, mode='min')
        checkpoint_callback = pl.callbacks.ModelCheckpoint(monitor='val_loss', mode='min', save_last=True,
                                                           dirpath='output/model_path', filename=model_name)

        # Build the trainer the same way for fresh and resumed runs so that a
        # resumed run keeps the W&B logger, the checkpoint callback and every
        # Trainer option carried by hparams (e.g. --gpus, --max_epochs).
        trainer_kwargs = dict(logger=wandb_logger, callbacks=[early_stopping, checkpoint_callback])
        if hparams.checkpoint is None:
            # NOTE: the LR finder only runs if trainer.tune(model) is called,
            # which this function does not do.
            trainer_kwargs['auto_lr_find'] = True
        else:
            trainer_kwargs['resume_from_checkpoint'] = hparams.checkpoint
        trainer = Trainer.from_argparse_args(hparams, **trainer_kwargs)

        # manage data
        # Each previously unseen key is assigned the next integer index.
        elements_dict = defaultdict(lambda: len(elements_dict))
        dataset = LigandDataset(hparams.data_path, elements_dict)
        train_size = int(len(dataset)*0.8)
        # Fixed seed keeps the train/validation membership identical across runs.
        train_dataset, val_dataset = random_split(dataset, [train_size, len(dataset) - train_size], generator=torch.Generator().manual_seed(42))
        train_dataloader = DataLoader(train_dataset, batch_size=hparams.batch_size, collate_fn=pad_data, num_workers=10)
        val_dataloader = DataLoader(val_dataset, batch_size=hparams.batch_size, collate_fn=pad_data, num_workers=10)

        # Train
        trainer.fit(model, train_dataloader, val_dataloader)

        # Evaluate: validation split first, then the training split.
        val_r2, val_fig = _evaluate_split(trainer, model, val_dataloader)
        train_r2, train_fig = _evaluate_split(trainer, model, train_dataloader)

        wandb.log({'train_res': train_fig, 'val_res': val_fig})
        wandb.log({'val_r2': val_r2, 'train_r2':train_r2})
    except BaseException:
        # Also covers KeyboardInterrupt: close the run and mark it as failed
        # instead of leaving it open in the kernel.
        wandb.finish(exit_code=1)
        raise
    wandb.finish()


if __name__ == "__main__":
    # Run inputs. NOTE: dataset_path is an absolute, machine-specific location.
    dataset_path = '/home/huabei/Projects/molecularGNN_3Dstructure/dataset/dock_conformation_1/exhaus_96/data_train.txt'
    checkpoint = '20220618_211314.ckpt'  # defined here but not passed to the parser below

    # Script-level options: (flag, type, default).
    parser = ArgumentParser()
    for option_flag, option_type, option_default in (
        ("--dim", int, 512),
        ("--layer_hidden", int, 30),
        ("--layer_output", int, 8),
        ("--batch_size", int, 128),
        ("--data_path", str, None),
        ("--checkpoint", str, None),
    ):
        parser.add_argument(option_flag, type=option_type, default=option_default)

    # Extend with the model's own options first, then with every Trainer option.
    parser = Trainer.add_argparse_args(
        MolecularGNN.add_model_specific_args(parent_parser=parser)
    )

    # The argument list is given explicitly instead of being read from sys.argv.
    run_argv = [
        '--data_path', dataset_path,
        '--learning_rate', '0.0001',
        '--gpus=1',
        '--max_epochs', '1000',
    ]
    args = parser.parse_args(run_argv)

    main(args)