# Demo for MoleculeSTM Downstream: Molecule Editing

## Load Packages

In [1]:
import warnings
warnings.filterwarnings('ignore')

import argparse
import math
import numpy as np
import os

import torch
from torch import optim
import torch.nn.functional as F
from tqdm import tqdm

from rdkit import Chem, RDLogger
from rdkit.Chem import AllChem, Descriptors
from rdkit import DataStructs
lg = RDLogger.logger()
lg.setLevel(RDLogger.CRITICAL)

from MoleculeSTM.utils import prepare_text_tokens
from MoleculeSTM.downstream_molecule_edit_utils import get_SMILES_list, get_description_list, load_language_molecule_and_edit_models, clip_loss_for_edit

import sys
sys.path.insert(0, "../scripts")
from downstream_02_molecule_edit_step_02_MoleculeSTM_Latent_Optimization import get_lr, mean_pooling

[2023-08-30 12:31:55,780] [INFO] [real_accelerator.py:110:get_accelerator] Setting ds_accelerator to cuda (auto detect)


## Setup Arguments

Note that this step uses only the textual branch (SciBERT) and a pretrained molecule generative model (MegaMolBART). The MoleculeSTM chemical branch (MegaMolBART or GraphMVP) is used only during the module alignment phase; it can be changed through the `MoleculeSTM_model_dir` argument.

In [2]:
def _build_arg_parser():
    """Assemble the command-line parser holding every option of this demo.

    Options are registered in a fixed order and grouped by purpose:
    runtime, editing task, data, foundation model, generation model,
    projection layers and latent optimization.
    """
    parser = argparse.ArgumentParser()

    def add_on_off_switch(on_flag, off_flag, dest, default):
        # Two flags writing to one destination: one sets it, one clears it.
        parser.add_argument(on_flag, dest=dest, action="store_true")
        parser.add_argument(off_flag, dest=dest, action="store_false")
        parser.set_defaults(**{dest: default})

    # ---- runtime ----
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--device", type=int, default=0)
    parser.add_argument("--verbose", type=int, default=1)

    # ---- editing task ----
    parser.add_argument("--input_description", type=str, default=None)
    parser.add_argument("--input_description_id", type=int, default=None)
    parser.add_argument("--input_SMILES", type=str, default="OC1C2C1CC2")
    parser.add_argument("--input_SMILES_file", type=str, default=None)
    parser.add_argument("--output_model_dir", type=str, default=None)
    parser.add_argument(
        "--mode", type=str, default="edit", choices=["edit", "free_generation"])
    add_on_off_switch(
        "--use_noise_for_init", "--no_noise_for_init", "use_noise_for_init", True)
    add_on_off_switch("--normalize", "--no_normalize", "normalize", True)

    # ---- data ----
    parser.add_argument("--dataspace_path", type=str, default="../data")
    parser.add_argument("--SSL_emb_dim", type=int, default=256)
    parser.add_argument("--max_seq_len", type=int, default=512)

    # ---- foundation model (MoleculeSTM) ----
    parser.add_argument(
        "--MoleculeSTM_model_dir", type=str, default="demo_checkpoints_SMILES")
    parser.add_argument(
        "--MoleculeSTM_molecule_type", type=str, default="SMILES",
        choices=["SMILES", "Graph"])
    parser.add_argument(
        "--vocab_path", type=str, default="../MoleculeSTM/bart_vocab.txt")

    # ---- generation model (MegaMolBART) ----
    parser.add_argument(
        "--MegaMolBART_generation_model_dir", type=str,
        default="../data/pretrained_MegaMolBART/checkpoints")

    # ---- projections between the foundation and generation spaces ----
    parser.add_argument(
        "--language_edit_model_dir", type=str, default="demo_checkpoints_SMILES")

    # ---- latent optimization ----
    parser.add_argument("--lr_rampup", type=float, default=0.05)
    parser.add_argument("--lr", type=float, default=0.1)
    parser.add_argument("--epochs", type=int, default=100)
    return parser


parser = _build_arg_parser()
# parse_known_args collects unrecognised command-line arguments in `unknown`
# instead of aborting on them.
args, unknown = parser.parse_known_args()

print(args)

Namespace(MegaMolBART_generation_model_dir='../data/pretrained_MegaMolBART/checkpoints', MoleculeSTM_model_dir='demo_checkpoints_SMILES', MoleculeSTM_molecule_type='SMILES', SSL_emb_dim=256, dataspace_path='../data', device=0, epochs=100, input_SMILES='OC1C2C1CC2', input_SMILES_file=None, input_description=None, input_description_id=None, language_edit_model_dir='demo_checkpoints_SMILES', lr=0.1, lr_rampup=0.05, max_seq_len=512, mode='edit', normalize=True, output_model_dir=None, seed=42, use_noise_for_init=True, verbose=1, vocab_path='../MoleculeSTM/bart_vocab.txt')


## Load Models

In [3]:
# Unpack everything the loader returns: the text encoder with its tokenizer,
# the molecule encoder, the MegaMolBART wrapper, and the four projection
# modules between the text, molecule, foundation and generation spaces.
(
    text_model, text_tokenizer, text_dim,
    molecule_model, MegaMolBART_wrapper, molecule_dim,
    text2latent, mol2latent,
    generation2foundation, foundation2generation,
) = load_language_molecule_and_edit_models(args)

# Use the requested GPU when CUDA is present, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:{}".format(args.device))
else:
    device = torch.device("cpu")

# Place the modules on the selected device (MegaMolBART_wrapper is not moved here).
text_model = text_model.to(device)
molecule_model = molecule_model.to(device)
text2latent = text2latent.to(device)
mol2latent = mol2latent.to(device)
generation2foundation.to(device)
foundation2generation.to(device)

# Switch every module to evaluation mode. The last call is left as a bare
# expression, so its return value is what the cell displays.
text_model.eval()
molecule_model.eval()
text2latent.eval()
mol2latent.eval()
generation2foundation.eval()
foundation2generation.eval()

Some weights of the model checkpoint at allenai/scibert_scivocab_uncased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.bias', 'cls.seq_relationship.weight', 'cls.predictions.decoder.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Loading from demo_checkpoints_SMILES/text_model.pth...
using world size: 1 and model-parallel size: 1 
using torch.float32 for parameters ...
-------------------- arguments --------------------
  adam_beta1 ...................... 0.9
  adam_beta2 ...................... 0.999
  adam_eps ........................ 1e-08
  adlr_autoresume ................. False
  adlr_autoresume_interval ........ 1000
  apply_query_key_layer_scaling ... False
  apply_residual_connection_post_layernorm  False
  attention_dropout ............... 0.1
  attention_softmax_in_fp32 ....... False
  batch_size ...................... None
  bert_load ....................... None
  bias_dropout_fusion ............. False
  bias_gelu_fusion ................ False
  block_data_path ................. None
  checkpoint_activations .......... False
  checkpoint_in_cpu ............... False
  checkpoint_num_layers ........... 1
  clip_grad ....................... 1.0
  contigious_checkpointing ........ False
  cpu_optimize

[W ProcessGroupNCCL.cpp:1569] Rank 0 using best-guess GPU 0 to perform barrier as devices used by this process are currently unknown. This can potentially cause a hang if this rank to GPU mapping is incorrect.Specify device_ids in barrier() to force use of a particular device.


  successfully loaded ../data/pretrained_MegaMolBART/checkpoints/iter_0134000/mp_rank_00/model_optim_rng.pt
Loading from pretrained MegaMolBART (../data/pretrained_MegaMolBART/checkpoints).
Loading from demo_checkpoints_SMILES/text2latent_model.pth...
Loading from demo_checkpoints_SMILES/mol2latent_model.pth...
Loading from demo_checkpoints_SMILES/generation2foundation_model.pth...
Loading from demo_checkpoints_SMILES/foundation2generation_model.pth...


MLP(
  (layers): ModuleList(
    (0): Linear(in_features=256, out_features=256, bias=True)
    (1): Linear(in_features=256, out_features=256, bias=True)
  )
)

## Reset Seed

In [4]:
# Re-seed the NumPy and PyTorch random generators with the configured seed.
np.random.seed(args.seed)
torch.random.manual_seed(args.seed)

# Seed all CUDA generators and target the requested GPU when CUDA is present;
# otherwise run on the CPU.
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(args.seed)
    device = torch.device("cuda:{}".format(args.device))
else:
    device = torch.device("cpu")

## Define Support Functions

In [5]:
def evaluate_SMILES_list(SMILES_list, description):
    """Decide whether the edited molecule satisfies the text prompt.

    Args:
        SMILES_list: SMILES strings to compare. `check_edit` passes
            [input, regenerated input, edited molecule].
        description: the text prompt. Only prompts containing "soluble"
            (and not "insoluble") are currently scored.

    Returns:
        A one-element list: `[True]` when the last valid molecule has a lower
        MolLogP than the first one, `[False]` otherwise. `[False]` is also
        returned when fewer than three entries parse into molecules, when the
        prompt matches no scoring rule, or when no descriptor was evaluated.
    """
    print("SMILES_list:", SMILES_list)

    # Keep each SMILES next to its parsed molecule so the value printed below
    # always belongs to the string it is printed with, even after invalid
    # entries have been dropped.
    valid_pairs = []
    for SMILES in SMILES_list:
        if SMILES is None:
            # Nothing to parse; treated like an invalid SMILES.
            continue
        mol = Chem.MolFromSmiles(SMILES)
        if mol is None:
            continue
        valid_pairs.append((SMILES, mol))

    if len(valid_pairs) < 3:
        return [False]

    # Verdict for prompts that none of the rules below recognises.
    answer = [False]

    if "soluble" in description and "insoluble" not in description:
        props = ["MolLogP"]
        prop_pred = [(n, func) for n, func in Descriptors.descList if n.split("_")[-1] in props]
        value_list = []
        for name, func in prop_pred:
            for idx, (SMILES, mol) in enumerate(valid_pairs):
                # Index 1 is the regenerated input when called from check_edit;
                # only the input and the edited molecule are scored.
                if idx == 1:
                    continue
                value = func(mol)
                value_list.append(value)
                print("SMILES: {}\t\t\tlogP: {:.5f}".format(SMILES, value))
        # The edit counts as a hit when logP dropped from first to last molecule.
        if value_list and value_list[0] > value_list[-1]:
            answer = [True]

    return answer


def check_edit(SMILES, text, device):
    """Edit one molecule towards a text prompt by optimizing its latent code.

    The molecule is encoded with `MegaMolBART_wrapper`, its latent code is
    optimized with Adam so that its projection through `generation2foundation`
    matches the prompt embedding (`clip_loss_for_edit`) while staying close to
    the initial code (L2 penalty), and the result is decoded back to SMILES.

    Uses the notebook globals `args`, `text_tokenizer`, `text_model`,
    `text2latent`, `MegaMolBART_wrapper` and `generation2foundation`.

    Args:
        SMILES: SMILES string of the molecule to edit.
        text: the text prompt describing the desired property.
        device: torch device used for the text tokens and the initial noise.

    Returns:
        A pair `(result_SMILES_list_one_pair, result_eval_list_one_pair)`:
        the first holds one row per L2 weight,
        `[text, input SMILES, regenerated SMILES, edited SMILES, l2_lambda as str]`;
        the second is a NumPy boolean array, the element-wise "any" over all
        L2 weights of the verdicts returned by `evaluate_SMILES_list`.
    """
    text_list = [text]
    text_tokens_ids, text_masks = prepare_text_tokens(
        device=device, description=text_list, tokenizer=text_tokenizer, max_seq_len=args.max_seq_len)
    # The prompt embedding is a fixed target: only `latent` is optimized. Building
    # it without autograd avoids back-propagating through the text encoder on
    # every step and removes the need to retain the graph between steps.
    with torch.no_grad():
        text_output = text_model(input_ids=text_tokens_ids, attention_mask=text_masks)
        text_repr = text_output["pooler_output"]
        text_repr = text2latent(text_repr)

    first_and_second_SMILES_list = []

    latent_code_init, pad_mask_init = MegaMolBART_wrapper.smileslist2embedding([SMILES])  # [pad, B, d], [pad, B]
    first_and_second_SMILES_list.append(SMILES)

    # Keep the mask on the same device as the latent code it is decoded with,
    # rather than forcing CUDA.
    regenerated_mols = MegaMolBART_wrapper.inverse_transform(
        [latent_code_init], pad_mask_init.bool().to(latent_code_init.device), k=1, sanitize=True)
    first_and_second_SMILES_list.append(regenerated_mols[0])

    l2_lambda_list = [1e0]
    result_SMILES_list_one_pair, result_eval_list_one_pair = [], []

    # The L2 penalty only needs the values of the initial code, not its history.
    latent_anchor = latent_code_init.detach()

    if args.use_noise_for_init:
        print("Use random noise for init")
        random_noise = torch.randn(latent_code_init.size()).to(device)

    for l2_lambda in l2_lambda_list:
        print("l2 lambda: {}".format(l2_lambda))
        current_SMILES_list = [first_and_second_SMILES_list[0]] + [first_and_second_SMILES_list[1]]
        if args.use_noise_for_init:
            print("Use random noise for init")
            latent = latent_code_init.detach().clone() + random_noise
        else:
            print("No random noise for init")
            latent = latent_code_init.detach().clone()
        pad_mask = pad_mask_init.detach().clone()
        latent.requires_grad = True
        optimizer = optim.Adam([latent], lr=args.lr)

        if args.verbose:
            epoch_iter = tqdm(range(args.epochs))
        else:
            epoch_iter = range(args.epochs)

        # Stay None when args.epochs == 0 so the summary print below is skipped
        # instead of failing on unbound names.
        clip_loss_ = l2_loss_ = None

        for i in epoch_iter:
            # Fraction of the optimization completed, fed to the schedule helper.
            t = i / args.epochs
            lr = get_lr(t, args.lr)
            optimizer.param_groups[0]["lr"] = lr

            molecule_repr_generation = mean_pooling(latent, pad_mask) # [B, d]
            if args.normalize:
                molecule_repr_generation = F.normalize(molecule_repr_generation, dim=-1)
            molecule_repr_foundation = generation2foundation(molecule_repr_generation)

            clip_loss_ = clip_loss_for_edit(molecule_repr_foundation, text_repr)
            l2_loss_ = l2_lambda * ((latent_anchor - latent) ** 2).mean()

            loss = clip_loss_ + l2_loss_

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        if clip_loss_ is not None:
            print("clip loss: {:.5f}\tL2 loss: {:.5f}".format(clip_loss_.item(), l2_loss_.item()))

        generated_mols = MegaMolBART_wrapper.inverse_transform(
            [latent], pad_mask.bool().to(latent.device), k=1, sanitize=True)
        current_SMILES_list.append(generated_mols[0])
        result_SMILES_list_one_pair.append([text] + current_SMILES_list + ['{}'.format(l2_lambda)])

        current_result_list = evaluate_SMILES_list(current_SMILES_list, text)
        result_eval_list_one_pair.append(current_result_list)
        print()

    # A molecule counts as successfully edited if any L2 weight produced a hit.
    result_eval_list_one_pair = np.array(result_eval_list_one_pair)
    result_eval_list_one_pair = np.any(result_eval_list_one_pair, axis=0, keepdims=True)
    return result_SMILES_list_one_pair, result_eval_list_one_pair


## Start Molecule Editing

In [6]:
print("start editing\n\n\n")

# Molecules to edit, as selected by the parsed arguments.
source_SMILES_list = get_SMILES_list(args)

# Text prompt that drives the edit.
description = "This molecule is soluble in water."


print("===== for text prompt: {} =====".format(description))
result_SMILES_list, result_acc_list = [], []

for SMILES in source_SMILES_list:
    print("===== for SMILES {} =====".format(SMILES))
    result_SMILES_list_, result_acc_list_ = check_edit(SMILES, description, device)
    # Keep every molecule's outcome; previously these accumulators were created
    # but never filled, so the returned results were lost after each iteration.
    result_SMILES_list.extend(result_SMILES_list_)
    result_acc_list.append(result_acc_list_)

start editing



===== for text prompt: This molecule is soluble in water. =====
===== for SMILES OC1C2C1CC2 =====




Use random noise for init
l2 lambda: 1.0
Use random noise for init


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:02<00:00, 38.36it/s]


clip loss: -0.92124	L2 loss: 0.33059
SMILES_list: ['OC1C2C1CC2', 'OC12CC1C2', 'OC1CC2CC(O)(C1)C2']
SMILES: OC1C2C1CC2			logP: 0.38710
SMILES: OC1CC2CC(O)(C1)C2			logP: 0.28220

