In [1]:
from argparse import ArgumentParser
from model.origgnn import MolecularGNN
from utils.package import plot_fit_confidence_bond
from pytorch_lightning import Trainer
import pytorch_lightning as pl
from collections import defaultdict
from torch.utils.data import DataLoader
from torch.utils.data import random_split
import torch
import numpy as np
import wandb
import time
from sklearn.metrics import r2_score
import matplotlib
import matplotlib.pyplot as plt
matplotlib.use('Agg')

In [2]:

import os

# Never commit the W&B API key to the notebook: read it from the environment.
# With key=None, wandb.login falls back to credentials already stored in
# ~/.netrc (or prompts). The key that used to be hard-coded here should be
# considered leaked and revoked.
wandb.login(key=os.environ.get('WANDB_API_KEY'), host='http://localhost:8080')


[34m[1mwandb[0m: Currently logged in as: [33mhuabei[0m. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for localhost to your netrc file: /home/huabei/.netrc


True

In [3]:
# %%wandb
def _evaluate_split(trainer, model, dataloader):
    """Run ``trainer.test`` on one dataloader and summarise the fit.

    ``model.predictions`` is cleared first so that only this split's outputs
    are scored. The model is assumed to append to ``predictions['true']`` and
    ``predictions['pred']`` during testing (TODO: confirm in MolecularGNN).

    Returns:
        (r2, fig): ``r2_score`` of the collected true vs. predicted values and
        the figure produced by ``plot_fit_confidence_bond``.
    """
    model.predictions = defaultdict(list)
    trainer.test(model, dataloaders=dataloader, verbose=False)
    y_true = np.array(model.predictions['true'])
    y_pred = np.array(model.predictions['pred'])
    r2 = r2_score(y_true, y_pred)
    return r2, plot_fit_confidence_bond(y_true, y_pred, r2, annot=False)


def main(hparams):
    """Train a MolecularGNN from parsed arguments and log the results to W&B.

    Requires an active wandb run (``wandb.log`` is called immediately) and
    finishes that run before returning.

    Args:
        hparams: argparse.Namespace combining the base args from
            ``prepare_arg``, the model-specific args and the Trainer args.
            ``hparams.checkpoint`` is either None (fresh run) or a checkpoint
            path to resume training from.
    """
    model_name = f'SMTarRNA-3a6p-dim-{hparams.dim}-hlayer-{hparams.layer_hidden}-olayer-{hparams.layer_output}-' + time.strftime("%Y%m%d_%H%M%S", time.localtime())
    wandb.log({'checkpoint': model_name})
    model = MolecularGNN(**vars(hparams))
    print(model.elements_dict)

    wandb_logger = pl.loggers.WandbLogger(save_dir='log/3a6p')
    # Stop after 20 epochs without val_loss improvement.
    early_stopping = pl.callbacks.early_stopping.EarlyStopping(monitor='val_loss', patience=20, mode='min')
    checkpoint_callback = pl.callbacks.ModelCheckpoint(monitor='val_loss', mode='min', save_last=True,
                                                       dirpath='checkpoints', filename=model_name)
    callbacks = [early_stopping, checkpoint_callback]

    if hparams.checkpoint is None:
        trainer = Trainer.from_argparse_args(hparams, logger=wandb_logger, auto_lr_find=True, callbacks=callbacks)
    else:
        # Resume with the same CLI Trainer settings (gpus, max_epochs, ...),
        # logger and checkpointing as a fresh run. A bare Trainer(...) would
        # silently drop all of them.
        trainer = Trainer.from_argparse_args(hparams, logger=wandb_logger, callbacks=callbacks,
                                             resume_from_checkpoint=hparams.checkpoint)

    trainer.fit(model)

    val_r2, val_fig = _evaluate_split(trainer, model, model.val_dataloader())
    train_r2, train_fig = _evaluate_split(trainer, model, model.train_dataloader())

    wandb.log({'train_res': train_fig, 'val_res': val_fig})
    wandb.log({'val_r2': val_r2, 'train_r2': train_r2})
    wandb.finish()
    # Under the Agg backend figures are never shown, so they are never closed
    # either. Sweep agents call main() repeatedly in one process, so release
    # them explicitly.
    plt.close('all')

In [4]:

def prepare_arg():
    """Create the base argument parser shared by single runs and sweeps.

    Defines the GNN size options (--dim, --layer_hidden, --layer_output), the
    --batch_size, and the optional --data_path / --checkpoint paths. Model and
    Trainer options are added on top of this parser by the callers.

    Returns:
        argparse.ArgumentParser with those six options registered.
    """
    # (flag, type, default) for every base option, in registration order.
    option_table = (
        ("--dim", int, 256),
        ("--layer_hidden", int, 24),
        ("--layer_output", int, 10),
        ("--batch_size", int, 128),
        ("--data_path", str, None),
        ("--checkpoint", str, None),
    )
    arg_parser = ArgumentParser()
    for flag, value_type, fallback in option_table:
        arg_parser.add_argument(flag, type=value_type, default=fallback)
    return arg_parser

In [None]:
# %%wandb

if __name__ == "__main__":
    # One-off training run outside any sweep: open a W&B run, assemble the
    # combined parser (base + model + Trainer options), then train with a
    # fixed set of command-line overrides.
    project = 'SMTarRNA-3a6p-project'
    wandb.init(project=project, dir='log/3a6p', notes='new dataset')
    dataset_path = 'data/3a6p/3a6p_exhaus_96_100K.txt'
    parser = Trainer.add_argparse_args(
        MolecularGNN.add_model_specific_args(parent_parser=prepare_arg())
    )
    args = parser.parse_args([
        '--data_path', dataset_path,
        '--learning_rate', '0.0005',
        '--gpus=1',
        '--max_epochs', '500',
    ])
    main(args)

In [None]:
# Close any W&B run that is still open. This matters when main() raised before
# reaching its own wandb.finish() call.
wandb.finish()

# W&B hyperparameter sweep

In [5]:
# Random-search sweep over optimiser and GNN-size hyperparameters. Every key
# under "parameters" is forwarded as a `--<key> <value>` CLI flag by
# arg_for_sweep(), so each key must be a valid training argument.
sweep_config = {
  "name" : "sweep",
  "method" : "random",
  # Objective used by W&B to rank runs (and required for early_terminate).
  # It matches the metric the EarlyStopping/ModelCheckpoint callbacks monitor.
  "metric": {
    "name": "val_loss",
    "goal": "minimize"
  },
  "parameters": {
    "max_epochs": {
      "value": 500
    },
    "learning_rate": {
      "distribution": "log_uniform_values",
      "min": 0.00001,
      "max": 0.1
    },
    "lr_decay": {
      # Stated explicitly; this is what W&B infers from float min/max.
      "distribution": "uniform",
      "min": 0.95,
      "max": 0.999
    },
    "dim" : {
      "distribution": "int_uniform",
      "min": 128,
      "max": 512
    },
    "layer_hidden": {
      "distribution": "int_uniform",
      "min": 8,
      "max": 32
    },
    "layer_output": {
      "distribution": "int_uniform",
      "min": 8,
      "max": 20
    }
  }
}

In [10]:

# Register a NEW sweep from `sweep_config`. Every execution of this cell
# creates another sweep.
# NOTE(review): the agent cell at the end of the notebook reassigns `sweep_id`
# to a hard-coded existing sweep path, so the sweep created here is not the one
# the agent runs unless that line is updated.
sweep_id = wandb.sweep(sweep_config, project='SMTarRNA-3a6p-project')

Create sweep with ID: u5md4p3y
Sweep URL: http://localhost:8080/huabei/SMTarRNA-3a6p-project/sweeps/u5md4p3y


In [6]:
def arg_for_sweep(config: dict):
    """Turn a sweep run's config into the parsed training arguments.

    Builds the same combined parser as a single run (base options from
    ``prepare_arg`` + model options + Trainer options). It then parses a fixed
    dataset path and ``--gpus=1``, followed by one ``--<key> <value>`` pair per
    config entry. Values are stringified so argparse applies each option's own
    ``type``.

    Args:
        config: mapping of hyperparameter name to sampled value
            (``wandb.config`` when called from ``train``).

    Returns:
        argparse.Namespace ready to pass to ``main``.
    """
    dataset_path = 'data/3a6p/3a6p_exhaus_96_100K.txt'
    parser = prepare_arg()
    parser = MolecularGNN.add_model_specific_args(parent_parser=parser)
    parser = Trainer.add_argparse_args(parser)

    hyperparameter_list = ['--data_path', dataset_path, '--gpus=1']
    for key, value in config.items():
        hyperparameter_list.extend([f'--{key}', str(value)])
    return parser.parse_args(hyperparameter_list)

def train():
    """Entry point handed to ``wandb.agent``: execute one sweep trial.

    Opens a W&B run, converts that run's ``wandb.config`` (the hyperparameters
    assigned by the sweep) into training arguments and trains on them.
    """
    with wandb.init(dir='log/3a6p'):
        main(arg_for_sweep(config=wandb.config))

In [None]:
# Attach an agent to an existing sweep and execute `count` trials, each via
# train().
# NOTE: this hard-coded "entity/project/sweep" path replaces any `sweep_id`
# returned by wandb.sweep() earlier in the notebook. Update it when starting a
# new sweep.
sweep_id, count = 'huabei/SMTarRNA-3a6p-project/u5md4p3y', 4
wandb.agent(sweep_id, function=train, count=count)