In [None]:
from argparse import ArgumentParser
from model.origgnn import MolecularGNN
from utils.package import plot_fit_confidence_bond
from pytorch_lightning import Trainer
import pytorch_lightning as pl
from collections import defaultdict
from torch.utils.data import DataLoader
from torch.utils.data import random_split
import torch
import numpy as np
import wandb
import time
from sklearn.metrics import r2_score
import matplotlib
import matplotlib.pyplot as plt
# matplotlib.use('Agg')

In [None]:
def main(hparams):
    """Evaluate a trained MolecularGNN checkpoint on its validation and training data.

    For each split the model is run through ``Trainer.test``; the values found
    afterwards in ``model.predictions['true']`` and ``model.predictions['pred']``
    are scored with ``r2_score`` and drawn with ``plot_fit_confidence_bond``.
    The figures are written to ``val_fig.png`` and ``train_fig.png`` in the
    current working directory.

    Args:
        hparams: Parsed arguments. Only ``hparams.checkpoint`` (path of the
            checkpoint file to restore) is read here.

    Returns:
        dict: R^2 score per split, keyed by ``'val'`` and ``'train'``.
    """
    # load_from_checkpoint is a classmethod that returns a *new* model, so it
    # is called on the class. Building an instance first only produced a
    # throwaway object whose constructor arguments never reached the model
    # that was actually evaluated.
    # NOTE(review): the remaining hparams fields (dim, data_path, ...) are
    # therefore not applied to the restored model, exactly as before. If they
    # are meant to override what the checkpoint stores, pass them as keyword
    # arguments to load_from_checkpoint -- confirm intent. Likewise Trainer()
    # is built with defaults and ignores any Trainer flags on hparams.
    model = MolecularGNN.load_from_checkpoint(checkpoint_path=hparams.checkpoint)
    trainer = Trainer()

    # (split name, dataloader factory, output image). The factories are called
    # lazily so each dataloader is only created right before it is used.
    splits = (
        ('val', model.val_dataloader, 'val_fig.png'),
        ('train', model.train_dataloader, 'train_fig.png'),
    )

    scores = {}
    for split_name, make_loader, figure_path in splits:
        # Start from an empty store so one split's results never leak into
        # the next (presumably the model's test hooks append to it -- verify).
        model.predictions = defaultdict(list)
        trainer.test(model, dataloaders=make_loader(), verbose=False)

        true_values = np.array(model.predictions['true'])
        predicted_values = np.array(model.predictions['pred'])
        split_r2 = r2_score(true_values, predicted_values)

        figure = plot_fit_confidence_bond(true_values, predicted_values, split_r2, annot=False)
        figure.savefig(figure_path)
        scores[split_name] = split_r2

    return scores
def prepare_arg():
    """Create the argument parser holding the model and data options.

    Registered options (all optional):
        --dim (int, default 256)
        --layer_hidden (int, default 16)
        --layer_output (int, default 8)
        --batch_size (int, default 128)
        --data_path (str, default None)
        --checkpoint (str, default None)

    Returns:
        ArgumentParser: the parser, not yet parsed, so callers can register
        further options on it.
    """
    # (flag, value type, default) -- kept in declaration order so the
    # generated --help output lists the options in the same sequence.
    option_table = (
        ("--dim", int, 256),
        ("--layer_hidden", int, 16),
        ("--layer_output", int, 8),
        ("--batch_size", int, 128),
        ("--data_path", str, None),
        ("--checkpoint", str, None),
    )

    cli = ArgumentParser()
    for flag, value_type, fallback in option_table:
        cli.add_argument(flag, type=value_type, default=fallback)
    return cli

In [None]:
if __name__ == "__main__":
    # Files to evaluate: the dataset and the trained model checkpoint.
    # NOTE(review): both are absolute paths on one machine; adjust before
    # running elsewhere.
    dataset_path = '/home/huabei/Projects/SMTarRNA/project/data/in_man_exhaustiveness_96_orig_conformation.txt'
    checkpoint = '/home/huabei/Projects/SMTarRNA/project/checkpoints/SMTarRNA-dim-256-hlayer-16-olayer-8-20220630_135942.ckpt'

    # Model/data options from prepare_arg(), extended with the Trainer's own
    # command-line arguments.
    parser = Trainer.add_argparse_args(prepare_arg())

    # An explicit token list is parsed, so the process's real command line
    # (the notebook kernel's argv) is never consulted.
    cli_tokens = ['--data_path', dataset_path, '--checkpoint', checkpoint]
    args = parser.parse_args(cli_tokens)

    main(args)