# Demo for MoleculeSTM Downstream: Property Prediction

## Load Packages

In [1]:
import warnings
warnings.filterwarnings('ignore')

import os
import argparse
import numpy as np
import pandas as pd
from tqdm import tqdm
from sklearn.metrics import accuracy_score, average_precision_score, roc_auc_score, mean_absolute_error, mean_squared_error

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader as torch_DataLoader
from torch_geometric.loader import DataLoader as pyg_DataLoader

from MoleculeSTM.datasets import MoleculeNetSMILESDataset, MoleculeNetGraphDataset
from MoleculeSTM.splitters import scaffold_split
from MoleculeSTM.utils import get_num_task_and_type, get_molecule_repr_MoleculeSTM
from MoleculeSTM.models.mega_molbart.mega_mol_bart import MegaMolBART
from MoleculeSTM.models import GNN, GNN_graphpred

[2023-08-30 12:23:17,997] [INFO] [real_accelerator.py:110:get_accelerator] Setting ds_accelerator to cuda (auto detect)


## Setup Arguments

In [2]:
parser = argparse.ArgumentParser()
parser.add_argument("--seed", type=int, default=42)
parser.add_argument("--device", type=int, default=0)
parser.add_argument("--training_mode", type=str, default="fine_tuning", choices=["fine_tuning", "linear_probing"])
parser.add_argument("--molecule_type", type=str, default="SMILES", choices=["SMILES", "Graph"])

########## for dataset and split ##########
parser.add_argument("--dataspace_path", type=str, default="../data")
parser.add_argument("--dataset", type=str, default="bace")
parser.add_argument("--split", type=str, default="scaffold")

########## for optimization ##########
parser.add_argument("--batch_size", type=int, default=32)
parser.add_argument("--lr", type=float, default=1e-4)
parser.add_argument("--lr_scale", type=float, default=1)
parser.add_argument("--num_workers", type=int, default=1)
parser.add_argument("--epochs", type=int, default=5)
parser.add_argument("--weight_decay", type=float, default=0)
parser.add_argument("--schedule", type=str, default="cycle")
parser.add_argument("--warm_up_steps", type=int, default=10)

########## for MegaMolBART ##########
parser.add_argument("--megamolbart_input_dir", type=str, default="../data/pretrained_MegaMolBART/checkpoints", help="This is only for MegaMolBART.")
parser.add_argument("--vocab_path", type=str, default="../MoleculeSTM/bart_vocab.txt")

########## for saver ##########
parser.add_argument("--eval_train", type=int, default=0)
parser.add_argument("--verbose", type=int, default=1)

parser.add_argument("--input_model_path", type=str, default="demo_checkpoints_SMILES/molecule_model.pth")
parser.add_argument("--output_model_dir", type=str, default=None)

args = parser.parse_args("")
print("arguments\t", args)

arguments	 Namespace(batch_size=32, dataset='bace', dataspace_path='../data', device=0, epochs=5, eval_train=0, input_model_path='demo_checkpoints_SMILES/molecule_model.pth', lr=0.0001, lr_scale=1, megamolbart_input_dir='../data/pretrained_MegaMolBART/checkpoints', molecule_type='SMILES', num_workers=1, output_model_dir=None, schedule='cycle', seed=42, split='scaffold', training_mode='fine_tuning', verbose=1, vocab_path='../MoleculeSTM/bart_vocab.txt', warm_up_steps=10, weight_decay=0)


## Setup Seed

In [3]:
torch.manual_seed(args.seed)
np.random.seed(args.seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(args.seed)
device = torch.device("cuda:" + str(args.device)) \
    if torch.cuda.is_available() else torch.device("cpu")

## Setup Dataset and Dataloader

In [4]:
num_tasks, task_mode = get_num_task_and_type(args.dataset)
dataset_folder = os.path.join(args.dataspace_path, "MoleculeNet_data", args.dataset)


dataset = MoleculeNetSMILESDataset(dataset_folder)
dataloader_class = torch_DataLoader
use_pyg_dataset = False

smiles_list = pd.read_csv(
    dataset_folder + "/processed/smiles.csv", header=None)[0].tolist()
train_dataset, valid_dataset, test_dataset = scaffold_split(
    dataset, smiles_list, null_value=0, frac_train=0.8,
    frac_valid=0.1, frac_test=0.1, pyg_dataset=use_pyg_dataset)


train_loader = dataloader_class(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers)
val_loader = dataloader_class(valid_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers)
test_loader = dataloader_class(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers)

1513 	 (1513, 1)


## Initialize and Load Model

In [5]:
if args.megamolbart_input_dir is not None:
    # This is loading from the pretarined_MegaMolBART
    # --megamolbart_input_dir=../../Datasets/pretrained_MegaMolBART/checkpoints
    # TODO: or maybe --input_model_path=../../Datasets/pretrained_MegaMolBART/checkpoints/iter_0134000/mp_rank_00/model_optim_rng.pt
    MegaMolBART_wrapper = MegaMolBART(vocab_path=args.vocab_path, input_dir=args.megamolbart_input_dir, output_dir=None)
    print("Start from pretrained MegaMolBART using MLM.")
else:
    # This is starting from scratch
    MegaMolBART_wrapper = MegaMolBART(input_dir=None, output_dir=None)
    print("Start from randomly initialized MegaMolBART.")

model = MegaMolBART_wrapper.model
print("Update MegaMolBART with pretrained MoleculeSTM. Loading from {}...".format(args.input_model_path))
state_dict = torch.load(args.input_model_path, map_location='cpu')
model.load_state_dict(state_dict)
molecule_dim = 256


model = model.to(device)
linear_model = nn.Linear(molecule_dim, num_tasks).to(device)

# Rewrite the seed by MegaMolBART
torch.manual_seed(args.seed)
np.random.seed(args.seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(args.seed)

using world size: 1 and model-parallel size: 1 
using torch.float32 for parameters ...
-------------------- arguments --------------------
  adam_beta1 ...................... 0.9
  adam_beta2 ...................... 0.999
  adam_eps ........................ 1e-08
  adlr_autoresume ................. False
  adlr_autoresume_interval ........ 1000
  apply_query_key_layer_scaling ... False
  apply_residual_connection_post_layernorm  False
  attention_dropout ............... 0.1
  attention_softmax_in_fp32 ....... False
  batch_size ...................... None
  bert_load ....................... None
  bias_dropout_fusion ............. False
  bias_gelu_fusion ................ False
  block_data_path ................. None
  checkpoint_activations .......... False
  checkpoint_in_cpu ............... False
  checkpoint_num_layers ........... 1
  clip_grad ....................... 1.0
  contigious_checkpointing ........ False
  cpu_optimizer ................... False
  cpu_torch_adam ..........

[W ProcessGroupNCCL.cpp:1569] Rank 0 using best-guess GPU 0 to perform barrier as devices used by this process are currently unknown. This can potentially cause a hang if this rank to GPU mapping is incorrect.Specify device_ids in barrier() to force use of a particular device.


  successfully loaded ../data/pretrained_MegaMolBART/checkpoints/iter_0134000/mp_rank_00/model_optim_rng.pt
Start from pretrained MegaMolBART using MLM.
Update MegaMolBART with pretrained MoleculeSTM. Loading from demo_checkpoints_SMILES/molecule_model.pth...


## Setup Optimizer

In [6]:
if args.training_mode == "fine_tuning":
    model_param_group = [
        {"params": model.parameters()},
        {"params": linear_model.parameters(), 'lr': args.lr * args.lr_scale}
    ]
else:
    model_param_group = [
        {"params": linear_model.parameters(), 'lr': args.lr * args.lr_scale}
    ]
optimizer = optim.Adam(model_param_group, lr=args.lr, weight_decay=args.weight_decay)

## Define Support Functions

In [7]:
def train_classification(model, device, loader, optimizer):
    if args.training_mode == "fine_tuning":
        model.train()
    else:
        model.eval()
    linear_model.train()
    total_loss = 0

    if args.verbose:
        L = tqdm(loader)
    else:
        L = loader
    for step, batch in enumerate(L):
        SMILES_list, y = batch
        SMILES_list = list(SMILES_list)
        molecule_repr = get_molecule_repr_MoleculeSTM(
            SMILES_list, mol2latent=None,
            molecule_type="SMILES", MegaMolBART_wrapper=MegaMolBART_wrapper)
        pred = linear_model(molecule_repr)
        pred = pred.float()
        y = y.to(device).float()
            
        is_valid = y ** 2 > 0
        loss_mat = criterion(pred, (y + 1) / 2)
        loss_mat = torch.where(
            is_valid, loss_mat,
            torch.zeros(loss_mat.shape).to(device).to(loss_mat.dtype))

        optimizer.zero_grad()
        loss = torch.sum(loss_mat) / torch.sum(is_valid)
        loss.backward()
        optimizer.step()
        total_loss += loss.detach().item()

    return total_loss / len(loader)


@torch.no_grad()
def eval_classification(model, device, loader):
    model.eval()
    linear_model.eval()
    y_true, y_scores = [], []

    if args.verbose:
        L = tqdm(loader)
    else:
        L = loader
    for step, batch in enumerate(L):
        SMILES_list, y = batch
        SMILES_list = list(SMILES_list)
        molecule_repr = get_molecule_repr_MoleculeSTM(
            SMILES_list, mol2latent=None,
            molecule_type="SMILES", MegaMolBART_wrapper=MegaMolBART_wrapper)
        pred = linear_model(molecule_repr)
        pred = pred.float()
        y = y.to(device).float()
        
        y_true.append(y)
        y_scores.append(pred)

    y_true = torch.cat(y_true, dim=0).cpu().numpy()
    y_scores = torch.cat(y_scores, dim=0).cpu().numpy()

    roc_list = []
    for i in range(y_true.shape[1]):
        # AUC is only defined when there is at least one positive data.
        if np.sum(y_true[:, i] == 1) > 0 and np.sum(y_true[:, i] == -1) > 0:
            is_valid = y_true[:, i] ** 2 > 0
            roc_list.append(roc_auc_score((y_true[is_valid, i] + 1) / 2, y_scores[is_valid, i]))
        else:
            print("{} is invalid".format(i))

    if len(roc_list) < y_true.shape[1]:
        print(len(roc_list))
        print("Some target is missing!")
        print("Missing ratio: %f" %(1 - float(len(roc_list)) / y_true.shape[1]))

    return sum(roc_list) / len(roc_list), 0, y_true, y_scores

## Start Training

In [8]:
train_func = train_classification
eval_func = eval_classification

train_roc_list, val_roc_list, test_roc_list = [], [], []
train_acc_list, val_acc_list, test_acc_list = [], [], []
best_val_roc, best_val_idx = -1, 0
criterion = nn.BCEWithLogitsLoss(reduction="none")

for epoch in range(1, args.epochs + 1):
    loss_acc = train_func(model, device, train_loader, optimizer)
    print("Epoch: {}\nLoss: {}".format(epoch, loss_acc))

    if args.eval_train:
        train_roc, train_acc, train_target, train_pred = eval_func(model, device, train_loader)
    else:
        train_roc = train_acc = 0
    val_roc, val_acc, val_target, val_pred = eval_func(model, device, val_loader)
    test_roc, test_acc, test_target, test_pred = eval_func(model, device, test_loader)

    train_roc_list.append(train_roc)
    train_acc_list.append(train_acc)
    val_roc_list.append(val_roc)
    val_acc_list.append(val_acc)
    test_roc_list.append(test_roc)
    test_acc_list.append(test_acc)
    print("train: {:.6f}\tval: {:.6f}\ttest: {:.6f}".format(train_roc, val_roc, test_roc))
    print()

print("best train: {:.6f}\tval: {:.6f}\ttest: {:.6f}".format(train_roc_list[best_val_idx], val_roc_list[best_val_idx], test_roc_list[best_val_idx]))

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████| 38/38 [00:01<00:00, 20.98it/s]


Epoch: 1
Loss: 0.6168293129456671


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00,  5.69it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00,  6.62it/s]


train: 0.000000	val: 0.716484	test: 0.721788



100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████| 38/38 [00:01<00:00, 25.27it/s]


Epoch: 2
Loss: 0.4680136606881493


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00,  5.49it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00,  5.08it/s]


train: 0.000000	val: 0.759707	test: 0.791167



100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████| 38/38 [00:01<00:00, 23.34it/s]


Epoch: 3
Loss: 0.4001527561953193


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00,  5.11it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00,  5.06it/s]


train: 0.000000	val: 0.763736	test: 0.785950



100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████| 38/38 [00:01<00:00, 23.44it/s]


Epoch: 4
Loss: 0.35615202117907374


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00,  6.30it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00,  5.01it/s]


train: 0.000000	val: 0.759341	test: 0.796035



100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████| 38/38 [00:01<00:00, 23.84it/s]


Epoch: 5
Loss: 0.31917470811229004


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00,  5.10it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00,  5.08it/s]

train: 0.000000	val: 0.766300	test: 0.786646

best train: 0.000000	val: 0.716484	test: 0.721788



