In [1]:
!pip3 install -r requirements.txt


Successfully installed accelerate-0.16.0 appdirs-1.4.4 docker-pycreds-0.4.0 pathtools-0.1.2 sentry-sdk-1.15.0 setproctitle-1.3.2 urllib3-1.26.14 wandb-0.13.10
[0m

In [None]:
!git clone https://github.com/justinphan3110/SciFive.git
!cp -r SciFive/biot5x/data .
!rm -rm SciFive

In [2]:
import torch
from tqdm import tqdm

from transformers import pipeline, AutoTokenizer, DataCollatorForSeq2Seq
from datasets import load_dataset, Dataset

from trl import PPOTrainer, PPOConfig, AutoModelForSeq2SeqLMWithValueHead, set_seed
from trl.core import LengthSampler

In [3]:
# PPO hyper-parameters for fine-tuning the BioT5 ChemProt checkpoint.
# NOTE(review): `input_length`, `target_length` and `metric` may not be standard `PPOConfig`
# fields in upstream trl — confirm the installed trl version accepts them.
config = PPOConfig(
    model_name="out/chemprot_hf/biot5_pytorch_text",  # alternative checkpoint: "t5-base"
    learning_rate=1.41e-5,
    batch_size=1024,
    forward_batch_size=1024,
    input_length=256,   # max encoder tokens, default for preprocess_function
    target_length=5,    # max label tokens, default for preprocess_function
    metric='PRF1',
    ppo_epochs=1,
    init_kl_coef=0.0,   # initial KL coefficient
    log_with="tensorboard",
    remove_unused_columns=False,
    accelerator_kwargs={"logging_dir": "log"},
)

# Keyword arguments for a transformers text-classification pipeline call.
sent_kwargs = dict(
    return_all_scores=True,
    function_to_apply="none",
    batch_size=config.forward_batch_size,
)


In [None]:

# Seed all RNGs first so that the freshly initialised value heads are reproducible.
set_seed(config.seed)

# Policy and reference model are both loaded from the same checkpoint (policy first, as before),
# followed by the matching tokenizer.
model, ref_model = (
    AutoModelForSeq2SeqLMWithValueHead.from_pretrained(config.model_name),
    AutoModelForSeq2SeqLMWithValueHead.from_pretrained(config.model_name),
)
tokenizer = AutoTokenizer.from_pretrained(config.model_name)


In [None]:
def preprocess_function(examples, input_length=config.input_length, target_length=config.target_length):
    """Tokenize a batch of (input text, label text) pairs for seq2seq training.

    Args:
        examples: batch mapping as passed by ``Dataset.map(batched=True)``; must contain
            the text columns ``"inputs"`` and ``"labels"``.
        input_length: maximum token length of the encoder input; longer inputs are truncated.
        target_length: maximum token length of the label sequence; longer labels are truncated.

    Returns:
        The tokenizer encoding of ``examples["inputs"]`` (including ``input_ids``) with an
        extra ``"labels"`` entry holding the token ids of ``examples["labels"]``. Both are
        padded to the longest sequence in the batch (``padding=True``).
    """
    model_inputs = tokenizer(
        examples["inputs"], max_length=input_length, truncation=True, padding=True
    )

    # Encode the targets in target-tokenizer mode.
    with tokenizer.as_target_tokenizer():
        labels = tokenizer(
            examples["labels"], max_length=target_length, truncation=True, padding=True
        )

    # NOTE(review): label padding uses the tokenizer's pad id, not -100, so padded positions
    # would not be ignored by a cross-entropy loss — fine for decoding-based rewards.
    model_inputs['labels'] = labels['input_ids']
    return model_inputs

In [None]:
input_column = "inputs"
target_column = "labels"


def read_tsv_pairs(path):
    """Read a tab-separated file into parallel lists of input texts and target texts.

    Only the first two tab-separated fields of each line are used. Blank lines are skipped;
    a non-blank line with fewer than two fields raises ``ValueError`` naming the file and line.
    """
    inputs, targets = [], []
    with open(path, 'r', encoding='utf-8') as handle:
        for line_no, raw_line in enumerate(handle, start=1):
            stripped = raw_line.strip()
            if not stripped:
                continue
            fields = stripped.split('\t')
            if len(fields) < 2:
                raise ValueError(
                    f"{path}:{line_no}: expected at least 2 tab-separated fields, got {len(fields)}"
                )
            inputs.append(fields[0].strip())
            targets.append(fields[1].strip())
    return inputs, targets


# Tokenized, torch-formatted dataset per split: keys 'train', 'dev', 'test'.
raw_datasets = {}
for file_ in ['train', 'dev', 'test']:
    inputs, targets = read_tsv_pairs(f'data/chemprot/{file_}_blurb_text.tsv')
    dataset = Dataset.from_dict({input_column: inputs, target_column: targets})
    # preprocess_function overwrites the text `labels` column with token ids; the raw text
    # input column is dropped.
    tokenized_datasets = dataset.map(preprocess_function, batched=True, remove_columns=[input_column], num_proc=1)
    tokenized_datasets.set_format(type="torch")
    raw_datasets[file_] = tokenized_datasets

In [None]:
# Build the PPOTrainer from the policy model, the reference model, the tokenizer and a
# fixed-size subset of the training split.
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model, return_tensors="pt")

# RL fine-tuning must draw from the *train* split; sampling from `test` would leak the
# evaluation data into training.
max_train_samples = 512 * 20
n_train_samples = min(max_train_samples, len(raw_datasets["train"]))
train_datasets = raw_datasets["train"].select(range(n_train_samples))
train_datasets.set_format(type="torch")

ppo_trainer = PPOTrainer(config, model, ref_model, tokenizer, dataset=train_datasets, data_collator=data_collator)

# Device for the transformers pipeline: a CUDA ordinal or "cpu" when running in a single process.
device = ppo_trainer.accelerator.device
if ppo_trainer.accelerator.num_processes == 1:
    device = 0 if torch.cuda.is_available() else "cpu"  # to avoid a `pipeline` bug
    print("device", device)

# Local classifier checkpoint (presumably BioLinkBERT fine-tuned on ChemProt — verify).
classification_pipe = pipeline("text-classification", "out/chemprot_hf/BioLinkBERT-base", device=device)

In [None]:
# Arguments forwarded by `ppo_trainer.generate` to the policy model's `generate`.
output_min_length = 2
output_max_length = 5
# NOTE(review): `output_length_sampler` and `dataloader` are not used by the loop below
# (it iterates `ppo_trainer.dataloader`); kept so any other cell referencing them still works.
output_length_sampler = LengthSampler(output_min_length, output_max_length)
generation_kwargs = {"max_length": 5}
dataloader = torch.utils.data.DataLoader(train_datasets, collate_fn=data_collator, batch_size=config.forward_batch_size)

num_epochs = 5


def exact_match_rewards(texts, labels, device):
    """Return one scalar reward tensor per sample: 1.0 if the decoded text equals the label, else 0.0.

    Rewards are always float32 so batch statistics (e.g. a mean) can be computed even when
    every sample in a batch scores 0 — an all-integer tensor cannot be averaged.
    """
    return [
        torch.tensor(float(text == label), dtype=torch.float32).to(device)
        for text, label in zip(texts, labels)
    ]


for epoch in range(num_epochs):
    out_dir = f"out/test_trl_biot5_chemprot/checkpoint_{epoch}"
    for batch in tqdm(ppo_trainer.dataloader):
        query_tensors = batch["input_ids"]
        # The seq2seq collator pads labels with `label_pad_token_id` (-100 by default) when their
        # lengths differ; that id is not a vocabulary entry, so map it to the pad token before decoding.
        label_tensors = batch["labels"].masked_fill(
            batch["labels"] == data_collator.label_pad_token_id, tokenizer.pad_token_id
        )

        outputs = ppo_trainer.generate(query_tensors, **generation_kwargs)
        response_tensors = list(outputs)
        texts = tokenizer.batch_decode(outputs, skip_special_tokens=True)
        labels = tokenizer.batch_decode(label_tensors, skip_special_tokens=True)

        # Reward is exact string match against the gold label; `classification_pipe` is not consulted.
        rewards = exact_match_rewards(texts, labels, device)
        assert len(rewards) == len(labels) == len(texts)

        #### Run PPO step
        stats = ppo_trainer.step(list(query_tensors), response_tensors, rewards)
        ppo_trainer.log_stats(stats, batch, rewards)
    print(f"saving pretrained epoch {epoch} to {out_dir}")
    ppo_trainer._save_pretrained(out_dir)