## SAFE for Goal-directed optimization

In [1]:
# Automatically reload modified Python modules (e.g. a locally edited `safe` package) before each cell runs.
%load_ext autoreload
%autoreload 2

In [2]:
import os

# Turn off HuggingFace tokenizers parallelism; commonly done to silence the
# "process just got forked" warning once the tokenizer has been used.
os.environ.update(TOKENIZERS_PARALLELISM="false")

#### Install the key dependencies

In [3]:
# %%capture
# ! pip install pytdc
# ! pip install wandb
# ! pip install trl

In [4]:
import copy
import os
from random import choices

import datamol as dm
import numpy as np
import torch
from tdc import Evaluator
from tqdm.auto import tqdm
from trl import AutoModelForCausalLMWithValueHead, PreTrainedModelWrapper, create_reference_model

import safe as sf
from safe.converter import encode, decode, SAFEConverter
from safe.optim import REINVENTConfig, REINVENTTrainer, AutoModelForCausalLM
from safe.sample import SAFEDesign
from safe.tokenizer import SAFETokenizer
from safe.trainer.model import SAFEDoubleHeadsModel


[2024-09-19 15:49:31,613] [INFO] [real_accelerator.py:203:get_accelerator] Setting ds_accelerator to mps (auto detect)


W0919 15:49:32.087000 8193675072 torch/distributed/elastic/multiprocessing/redirects.py:28] NOTE: Redirects are currently not supported in Windows or MacOs.


### Reinvent training process for goal-directed generation

In [5]:
# This cell defines the optimization loop for goal-directed generation with the SAFE-GPT model.

def REINVENT_train(config, generation_kwargs, model, tokenizer, reward_fn, prefix=None, n_episodes=100):
    """REINVENT-style policy optimization of a SAFE-GPT model for molecule generation.

    Each episode does four things:
        1. Generates one batch of SAFE strings from the prompts.
        2. Decodes them to SMILES.
        3. Scores the valid ones with `reward_fn`; invalid ones get a reward of 0.
        4. Runs one `REINVENTTrainer.step` update.

    Args:
        config: Fine-tuning options, passed as keyword arguments to `REINVENTConfig`.
            `config["batch_size"]` (default 32) is also the minimum number of prompts per episode.
        generation_kwargs: Extra keyword arguments passed to `REINVENTTrainer.generate`.
        model: Base model to optimize. It is wrapped with `AutoModelForCausalLM` unless it
            already is a trl `PreTrainedModelWrapper`.
        tokenizer: Tokenizer used to encode the prompts and decode the generated SAFE strings.
        reward_fn: Callable mapping one SMILES string to a scalar reward.
        prefix: Prompt(s) for fragment-constrained generation:
            - `None` (default): unconstrained generation from the BOS token only;
            - `str`: a fragment SMILES, encoded to SAFE and terminated with a single ".";
            - list of str: prompts used as-is (assumed to already be SAFE strings).
            Prompt lists shorter than the batch size are resampled with replacement up to it.
        n_episodes: Number of generate / score / update episodes.

    Returns:
        The trained `REINVENTTrainer` (the fine-tuned policy is presumably its `.model`).

    Raises:
        ValueError: If `prefix` is an empty list.
    """
    # Wrap the model that was passed in, so the trainer optimizes the caller's model.
    if not isinstance(model, PreTrainedModelWrapper):
        model = AutoModelForCausalLM(model)

    # frozen copy of the starting policy, used as the reference ("prior") model
    prior = create_reference_model(model)
    reinvent_config = REINVENTConfig(**config)

    # evaluation metrics tracked alongside the training stats
    diversity_evaluator = Evaluator(name="Diversity")
    uniqueness_evaluator = Evaluator(name="Uniqueness")

    reinvent_trainer = REINVENTTrainer(reinvent_config, model, prior, tokenizer)
    # read after the trainer is built, in case trainer setup moves the model
    device = model.pretrained_model.device

    if isinstance(prefix, str):
        # encode the fragment SMILES to SAFE and make it end with exactly one "."
        encoded_fragment = SAFEConverter().encoder(
            prefix,
            canonical=False,
            randomize=True,
            constraints=None,
            allow_empty=True,
        )
        prefix = encoded_fragment.rstrip(".") + "."
    if prefix is None:
        prefix = ""
    prompts = [prefix] if isinstance(prefix, str) else list(prefix)
    if not prompts:
        raise ValueError("`prefix` must not be an empty list")

    batch_size = config.get("batch_size", 32)
    if len(prompts) < batch_size:
        # resample with replacement so every episode uses a full batch of prompts
        prompts = choices(prompts, k=batch_size)

    for _ in tqdm(range(n_episodes)):
        # each episode is a new complete sequence of actions for the agent to learn from
        game_data = {"query": prompts}
        batch = tokenizer(
            [tokenizer.bos_token + p for p in prompts],
            return_tensors="pt",
            add_special_tokens=False,
        ).to(device)
        query_tensor = batch["input_ids"]

        # generation
        response_tensor = reinvent_trainer.generate(list(query_tensor), return_prompt=False, **generation_kwargs)
        decoded_safe_mols = tokenizer.batch_decode(response_tensor, skip_special_tokens=True)
        decoded_smiles = [
            decode(
                x,
                as_mol=False,
                fix=True,
                remove_added_hs=True,
                canonical=True,
                ignore_errors=True,
                remove_dummies=True,
            )
            for x in decoded_safe_mols
        ]
        game_data["response"] = decoded_safe_mols

        # Score only decodable molecules; invalid ones (decoded to None) keep a reward of 0.
        # An all-invalid batch is handled gracefully instead of crashing on an empty unpack.
        valid_position = [i for i, smi in enumerate(decoded_smiles) if smi is not None]
        valid_smiles = [decoded_smiles[i] for i in valid_position]
        rewards = np.zeros(len(decoded_smiles), dtype=np.float32)
        if valid_position:
            rewards[valid_position] = [reward_fn(smi) for smi in valid_smiles]
        rewards = list(torch.from_numpy(rewards).to(device=device))

        # training step + statistics to track
        stats = reinvent_trainer.step(list(query_tensor), list(response_tensor), rewards)
        # fraction of the molecules actually generated this episode
        stats["validity"] = len(valid_smiles) / len(decoded_smiles)
        if valid_smiles:
            stats["uniqueness"] = uniqueness_evaluator(list(valid_smiles))
            stats["diversity"] = diversity_evaluator(list(valid_smiles))
        reinvent_trainer.log_stats(stats, game_data, rewards)
    return reinvent_trainer

### Define the SAFE model for fine-tuning

In [6]:
# Load the default pretrained SAFE-GPT designer and take out its model and tokenizer.
designer = SAFEDesign.load_default()
safe_model = designer.model
safe_tokenizer = designer.tokenizer
# tokenizer object used to batch-encode prompts and decode generations during training
tokenizer = safe_tokenizer.get_pretrained()

# No wrapping happens here: `REINVENT_train` wraps non-trl models itself. The base model is
# only flagged as having no PEFT adapters (NOTE(review): presumably read by trl — confirm).
model = safe_model
setattr(model, "is_peft_model", False)


### Define the reward function that the agent learns from
It can be a single molecular property such as `clogP`, a surrogate function of multiple molecular properties such as a `BBB score`, or a scoring function based on a `predictive model` (e.g. for potency).

In [7]:
# In this tutorial, cLogP is used for demonstration purposes.
# By default the reward is the raw cLogP value, so the agent is pushed towards higher cLogP.
# Pass `target=` (e.g. 4) to instead reward molecules whose cLogP is close to a desired value.
def clogp_reward_fn(mol: str, target=None, **kwargs):
    """Reward function for cLogP optimization.

    Args:
        mol: Molecule as a SMILES string.
        target: Optional desired cLogP. When `None` (default), the reward is the cLogP itself.
            Otherwise it is `-abs(clogp - target)`, which peaks at 0 for on-target molecules.
        **kwargs: Ignored (presumably accepted so extra keyword arguments from a caller are tolerated).

    Returns:
        The reward as a float, or -100 when `mol` cannot be parsed into a molecule.
    """
    molecule = dm.to_mol(mol)
    if molecule is None:
        return -100
    clogp = dm.descriptors.clogp(molecule)
    if target is None:
        return clogp
    return -abs(clogp - target)

#### Start REINVENT training and track it on Weights & Biases (wandb)

In [8]:
import os

# Weights & Biases settings. `setdefault` keeps any value already exported in the shell, so
# readers can log to their own entity without editing this cell.
os.environ.setdefault("WANDB_SILENT", "False")
os.environ.setdefault("WANDB_LOG_MODEL", "end")
os.environ.setdefault("WANDB_WATCH", "all")
os.environ.setdefault("WANDB_ENTITY", "valencelabs")  # replace with your own W&B entity

In [9]:
# define REINVENT config
n_episodes = 25  # a small number of episodes for testing purposes in this tutorial
# scaffold = "[*:2]N1CCN(CC1)CCCCN[*:1]"  # uncomment for fragment-constrained generation
scaffold = None  # None -> unconstrained (de novo) generation
strategies = ["dap", "sdap", "mauli", "mascof"]

# generation config, shared by all strategies
# see more at https://huggingface.co/docs/transformers/en/main_classes/text_generation#transformers.GenerationConfig
generation_kwargs = {
    "min_length": -1,
    "do_sample": True,
    "pad_token_id": tokenizer.pad_token_id,
    "bos_token_id": tokenizer.bos_token_id,
    "eos_token_id": tokenizer.eos_token_id,
    "max_new_tokens": 100,
}

trainer_map = {}
for strategy in tqdm(strategies):
    config = {
        "batch_size": 32,
        "mini_batch_size": 32,
        "log_with": "wandb",
        "exp_name": f"safe-gpt-reinvent-cLogP_{strategy}",
        "tracker_project_name": "safe-reinvent-tutorial",
        "reward_model": "cLogP",
        "sigma": 100,
        "steps": n_episodes,
        "reinvent_epochs": 2,
        "max_buffer_size": 512,
        "strategy": strategy,
    }
    # Start every strategy from a fresh copy of the base model. `model` is not a trl wrapper,
    # so REINVENT_train wraps it and (presumably) updates its weights in place. Passing `model`
    # directly would make each strategy continue from the previous strategy's fine-tuned weights.
    trainer_map[strategy] = REINVENT_train(
        config,
        generation_kwargs,
        copy.deepcopy(model),
        tokenizer,
        clogp_reward_fn,
        prefix=scaffold,
        n_episodes=n_episodes,
    )

  0%|          | 0/4 [00:00<?, ?it/s]

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mmaclandrol[0m ([33mvalencelabs[0m). Use [1m`wandb login --relogin`[0m to force relogin


  0%|          | 0/25 [00:00<?, ?it/s]

KeyboardInterrupt: 

Now you can use the fine-tuned model for goal-directed generation.
Below, we use de novo generation as an example.

In [None]:
# Use one of the fine-tuned policies trained above (any key of `trainer_map` works).
eval_strategy = "dap"
trainer = trainer_map[eval_strategy]
tuned_designer = SAFEDesign(model=trainer.model, tokenizer=safe_tokenizer)
generated = tuned_designer.de_novo_generation(n_samples_per_trial=10)

# keep only generated SMILES that exist and can be parsed into molecules (parsed once, reused below)
candidate_mols = (dm.to_mol(smi) for smi in generated if smi is not None)
generated_mols = [mol for mol in candidate_mols if mol is not None]
assert generated_mols, "the fine-tuned model did not generate any valid molecule"

mol_prop = [dm.descriptors.clogp(mol) for mol in generated_mols]
print(f"mean clogP: {np.mean(mol_prop):.2f}")
dm.to_image(generated_mols, legends=[f"clogP: {x:.2f}" for x in mol_prop])