In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
# Turn off HuggingFace `tokenizers` parallelism before any tokenizer is used;
# presumably to silence its fork-related warning once generation/training starts — confirm.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [3]:
%%capture
! pip install pytdc

In [4]:
%%capture
! pip install trl

In [5]:
import safe as sf
import datamol as dm
import torch
import numpy as np
from tqdm.auto import tqdm
from safe import SAFEDesign
from tdc import Oracle
from tdc import Evaluator
from safe.trainer.model import SAFEDoubleHeadsModel
from safe.tokenizer import SAFETokenizer
from random import choices
from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer, create_reference_model

In [6]:
def train(ppo_config, generation_kwargs, model, tokenizer, oracle, prefix=None, n_episodes=100):
    """Fine-tune a SAFE language model with PPO so that it maximises ``oracle``.

    Each episode generates one batch of SAFE strings from the same prompts, decodes
    them to SMILES, rewards every decodable molecule with its oracle score (samples
    that fail to decode get a reward of 0) and performs one PPO step. Validity,
    uniqueness and diversity of the batch are added to the logged PPO stats.

    Args:
        ppo_config: keyword arguments for ``trl.PPOConfig``. Its ``batch_size`` sets
            how many molecules are generated per episode.
        generation_kwargs: keyword arguments forwarded to ``PPOTrainer.generate``.
        model: policy wrapped with a value head; query and reward tensors are placed
            on ``model.pretrained_model.device``.
        tokenizer: tokenizer matching ``model``.
        oracle: callable taking a list of SMILES and returning one score per SMILES.
        prefix: prompt string or sequence of prompt strings. ``None`` (or an empty
            value) means "start from ``tokenizer.bos_token``". If the number of
            prompts differs from the batch size, prompts are drawn with replacement
            to fill exactly one batch.
        n_episodes: number of PPO steps to run.

    Returns:
        ``(ppo_trainer, model)`` after training.
    """
    model_ref = create_reference_model(model)
    config = PPOConfig(**ppo_config)
    # Take the batch size from the PPO config rather than a notebook-global constant,
    # so the function does not depend on a cell executed after it is defined.
    batch_size = config.batch_size

    diversity_evaluator = Evaluator(name="Diversity")
    uniqueness_evaluator = Evaluator(name="Uniqueness")

    ppo_trainer = PPOTrainer(config, model, model_ref, tokenizer)
    device = model.pretrained_model.device

    if not prefix:
        prefix = tokenizer.bos_token
    prefix = [prefix] if isinstance(prefix, str) else list(prefix)
    # PPOTrainer.step only accepts exactly `batch_size` queries per call.
    if len(prefix) != batch_size:
        prefix = choices(prefix, k=batch_size)

    # The prompts never change, so tokenize them once. Tokenizing each prompt on its
    # own keeps prompts of different lengths unpadded, matching the list of 1-D
    # tensors that PPOTrainer.generate/step take.
    query_tensors = [
        tokenizer(text, return_tensors="pt", add_special_tokens=False)["input_ids"][0].to(device)
        for text in prefix
    ]

    for _ in tqdm(range(n_episodes)):
        response_tensors = ppo_trainer.generate(list(query_tensors), return_prompt=False, **generation_kwargs)
        decoded_safe_mols = tokenizer.batch_decode(response_tensors, skip_special_tokens=True)
        game_data = {"query": prefix, "response": decoded_safe_mols}

        # With ignore_errors=True, samples that cannot be decoded are returned as None.
        decoded_smiles = [
            sf.decode(
                safe_str,
                as_mol=False,
                fix=True,
                remove_added_hs=True,
                canonical=True,
                ignore_errors=True,
                remove_dummies=True,
            )
            for safe_str in decoded_safe_mols
        ]
        valid_positions = [i for i, smiles in enumerate(decoded_smiles) if smiles is not None]
        valid_smiles = [decoded_smiles[i] for i in valid_positions]

        # Invalid samples keep a reward of 0. The oracle is only queried when at least
        # one sample decoded, so an all-invalid batch yields all-zero rewards instead
        # of aborting training.
        reward_values = np.zeros(len(decoded_smiles), dtype=np.float32)
        if valid_smiles:
            reward_values[valid_positions] = oracle(valid_smiles)
        rewards = list(torch.from_numpy(reward_values).to(device))

        stats = ppo_trainer.step(list(query_tensors), list(response_tensors), rewards)
        stats["validity"] = len(valid_smiles) / len(decoded_smiles)
        # Uniqueness is undefined for an empty batch and diversity (a pairwise
        # measure) needs at least two molecules; report 0.0 in those degenerate cases.
        stats["uniqueness"] = uniqueness_evaluator(valid_smiles) if valid_smiles else 0.0
        stats["diversity"] = diversity_evaluator(valid_smiles) if len(valid_smiles) > 1 else 0.0
        ppo_trainer.log_stats(stats, game_data, rewards)
    return ppo_trainer, model

In [7]:
# TDC oracle used as the PPO reward: `train` calls it on a list of SMILES strings.
oracle = Oracle(name="LogP")

In [8]:
# Default SAFE model location exposed by `SAFEDesign` (a private attribute, so it
# may change between `safe` releases); tokenizer and weights are read from it.
model_path = SAFEDesign._DEFAULT_MODEL_PATH
tokenizer_file = os.path.join(model_path, "tokenizer.json")
safe_tokenizer = SAFETokenizer.load(tokenizer_file)
safe_model = SAFEDoubleHeadsModel.from_pretrained(model_path)

In [9]:
# HuggingFace tokenizer object (PreTrainedTokenizerFast) derived from the SAFE tokenizer;
# it provides the pad/bos/eos ids and batch_decode used during PPO.
tokenizer = safe_tokenizer.get_pretrained()
# Wrap the SAFE policy with TRL's value head, as PPOTrainer requires.
model = AutoModelForCausalLMWithValueHead(safe_model)
# NOTE(review): the wrapper is constructed directly rather than via `from_pretrained`;
# presumably that leaves `is_peft_model` unset, so it is set by hand here — confirm.
model.is_peft_model = False

In [10]:
%%capture
! pip install wandb

In [11]:
BATCH_SIZE = 32  # molecules generated (and PPO samples) per episode
SEED = 0  # PPOTrainer seeds python/numpy/torch from PPOConfig.seed
# PPOTrainer creates the W&B run through Accelerate and passes `tracker_project_name`
# (default "trl") as the project explicitly, which overrides the WANDB_PROJECT env var.
WANDB_PROJECT_NAME = "safe-gpt2-plogp"

ppo_config = {
    "batch_size": BATCH_SIZE,
    "seed": SEED,
    "log_with": "wandb",
    "model_name": "safe-test-notebook",
    "tracker_project_name": WANDB_PROJECT_NAME,
}
generation_kwargs = {
    "min_length": -1,  # no minimum length: a sample may end at EOS immediately
    "top_k": 0,  # 0 disables top-k filtering (transformers expects an int)
    "top_p": 1.0,  # 1.0 disables nucleus filtering -> sample from the full distribution
    "do_sample": True,
    "pad_token_id": tokenizer.pad_token_id,
    "bos_token_id": tokenizer.bos_token_id,
    "eos_token_id": tokenizer.eos_token_id,
    "max_new_tokens": 100,
}


In [12]:
# Weights & Biases settings read from the environment when the W&B run is created.
# NOTE(review): trl's PPOTrainer creates the run through Accelerate and passes
# PPOConfig.tracker_project_name (default "trl") as the project, which presumably
# overrides WANDB_PROJECT below — confirm which project the run lands in.
# NOTE(review): WANDB_LOG_MODEL / WANDB_WATCH look like settings of the transformers
# `Trainer` W&B callback and may have no effect with PPOTrainer — confirm.
%env WANDB_SILENT=True
%env WANDB_LOG_MODEL=end
%env WANDB_WATCH=all
%env WANDB_PROJECT=safe-gpt2-plogp
%env WANDB_NAME=notebook-test

env: WANDB_SILENT=True
env: WANDB_LOG_MODEL=end
env: WANDB_WATCH=all
env: WANDB_PROJECT=safe-gpt2-plogp
env: WANDB_NAME=notebook-test


In [13]:
import wandb
# Close any W&B run still active from a previous execution of this notebook, so the
# training cell below (which logs via PPOTrainer with log_with="wandb") starts a fresh run.
wandb.finish()

In [14]:
%%wandb

# Run 100 PPO episodes; prefix=None makes every sample start from the BOS token.
trainer, trained_model = train(
    ppo_config,
    generation_kwargs,
    model,
    tokenizer,
    oracle,
    prefix=None,
    n_episodes=100,
)

fatal: No names found, cannot describe anything.


  0%|          | 0/100 [00:00<?, ?it/s]

You're using a PreTrainedTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
Found local copy...


### Logs

![image.png](attachment:image.png)