In [1]:
!pip3 install -r requirements.txt


  Attempting uninstall: huggingface-hub
    Found existing installation: huggingface-hub 0.8.1
    Uninstalling huggingface-hub-0.8.1:
      Successfully uninstalled huggingface-hub-0.8.1
  Attempting uninstall: transformers
    Found existing installation: transformers 4.20.1
    Uninstalling transformers-4.20.1:
      Successfully uninstalled transformers-4.20.1
  Attempting uninstall: datasets
    Found existing installation: datasets 2.3.2
    Uninstalling datasets-2.3.2:
      Successfully uninstalled datasets-2.3.2
Successfully installed accelerate-0.16.0 appdirs-1.4.4 datasets-2.9.0 deepspeed-0.8.0 distlib-0.3.6 docker-pycreds-0.4.0 hjson-3.1.0 huggingface-hub-0.12.1 ninja-1.11.1 numpy-1.24.2 pathtools-0.1.2 py-cpuinfo-9.0.0 sentry-sdk-1.15.0 setproctitle-1.3.2 transformers-4.24.0 urllib3-1.26.14 wandb-0.13.10
[0m

In [None]:
# !git clone https://github.com/justinphan3110/SciFive.git
# !cp -r SciFive/biot5x/data .
# !rm -r SciFive

In [2]:
import torch
from tqdm import tqdm

from transformers import pipeline, AutoTokenizer, DataCollatorForSeq2Seq
from datasets import load_dataset, Dataset

from trl import PPOTrainer, PPOConfig, AutoModelForSeq2SeqLMWithValueHead, set_seed
from trl.core import LengthSampler

In [3]:
config = PPOConfig(
    model_name="justinphan3110/biot5_chemprot",
    # model_name="t5-base",
    learning_rate=1.41e-5,
    batch_size=1024,
    forward_batch_size=1024,
    eval_batch_size=512,
    input_length = 256,
    target_length = 5,
    metric = 'PRF1',
    ppo_epochs=1,
    init_kl_coef=0.0,
    log_with="tensorboard",
    remove_unused_columns=False,
    accelerator_kwargs={"logging_dir": "log"}
)
sent_kwargs = {"return_all_scores": True, "function_to_apply": "none", "batch_size": config.forward_batch_size}

In [4]:
# set seed before initializing value head for deterministic eval
set_seed(config.seed)

# Now let's build the model, the reference model, and the tokenizer.
model = AutoModelForSeq2SeqLMWithValueHead.from_pretrained(config.model_name)
ref_model = AutoModelForSeq2SeqLMWithValueHead.from_pretrained(config.model_name)
tokenizer = AutoTokenizer.from_pretrained(config.model_name)

In [5]:
def preprocess_function(examples, input_length=config.input_length, target_length=config.target_length):
    model_inputs = tokenizer(
        examples["inputs"], max_length=input_length, truncation=True, padding=True
    )
    
    with tokenizer.as_target_tokenizer():
      labels = tokenizer(
          examples["labels"], max_length=target_length, truncation=True, padding=True
      )

    
    model_inputs['labels'] = labels['input_ids']
    model_inputs['input_ids'] = model_inputs['input_ids']
    return model_inputs

In [6]:
input_column = "inputs"
target_column = "labels"
raw_datasets = {}
task = "chemprot"
id2label = {}
for line in open(f'data/{task}/label2id.tsv'):
    line = line.strip().split('\t')
    id_ = line[1]
    label = line[0]
    id2label[id_]=label

for file_ in ['train','test','dev']:
    with open(f"data/{task}/{file_}_blurb_text.tsv", "w") as out_file:
        with open(f"data/{task}/{file_}_blurb.tsv", "r") as file:
            for line in file:
                line = line.strip().split('\t')
                input_ = line[0]
                target = id2label[line[1]]
                out_file.write(f"{input_}\t{target}\n")
                


for file_ in ['train', 'dev', 'test']:
    inputs = []
    targets = []

    with open(f'data/{task}/{file_}_blurb_text.tsv', 'r', encoding='utf-8') as file:
        for i, line in enumerate(file):
            line = line.strip().split('\t')
            inputs.append(f'{line[0].strip()}')
            targets.append(f'{line[1].strip()}')
    
    
    dataset = Dataset.from_dict({input_column: inputs, target_column: targets})
    tokenized_datasets = dataset.map(preprocess_function, batched=True, remove_columns=['inputs'], num_proc=1)
    tokenized_datasets.set_format(type="torch")
    raw_datasets[file_] = tokenized_datasets

  0%|          | 0/19 [00:00<?, ?ba/s]



  0%|          | 0/12 [00:00<?, ?ba/s]

  0%|          | 0/16 [00:00<?, ?ba/s]

In [7]:
# We then build the PPOTrainer, passing the model, the reference model, the tokenizer
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model, return_tensors="pt")


## IMPORTANT: Need to be a multiple of batch size
train_datasets =  Dataset.from_dict(raw_datasets['train'][:512*20])

train_datasets.set_format(type="torch")
ppo_trainer = PPOTrainer(config, model, ref_model, tokenizer, dataset=train_datasets, data_collator=data_collator)

device = ppo_trainer.accelerator.device
if ppo_trainer.accelerator.num_processes == 1:
    device = 0 if torch.cuda.is_available() else "cpu"  # to avoid a `pipeline` bug
    
    
pipe = pipeline("text-classification")
classification_pipe = pipeline("text-classification", "justinphan3110/biolinkbert_chemprot", device=device)

No model was supplied, defaulted to distilbert-base-uncased-finetuned-sst-2-english and revision af0f99b (https://huggingface.co/distilbert-base-uncased-finetuned-sst-2-english).
Using a pipeline without specifying a model name and revision in production is not recommended.
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at justinphan3110/biolinkbert_chemprot and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [8]:
ppo_trainer.evaluate(raw_datasets['test'])

  0%|          | 0/31 [00:00<?, ?it/s]You're using a T5TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
100%|██████████| 31/31 [00:54<00:00,  1.77s/it]

{'precision': 78.2702, 'recall': 74.1399, 'F1': 76.1491}





In [None]:
# We then define the arguments to pass to the `generate` function. These arguments
# are passed to the `generate` function of the PPOTrainer, which is a wrapper around
# the `generate` function of the trained model.
output_min_length = 2
output_max_length = config.target_length
output_length_sampler = LengthSampler(output_min_length, output_max_length)
generation_kwargs = { "max_length": config.target_length}
dataloader = torch.utils.data.DataLoader(train_datasets, collate_fn=data_collator, batch_size=config.forward_batch_size)


for epoch in range(3):
    out_dir = f"out/test_trl_biot5_{task}/checkpoint_{epoch}"
    for batch in tqdm(ppo_trainer.dataloader):
        query_tensors = batch["input_ids"]
        label_tensors = batch["labels"]
        
        outputs = ppo_trainer.generate(query_tensors, **generation_kwargs)
        response_tensors = list(outputs)
        texts = tokenizer.batch_decode(outputs, skip_special_tokens=True)
        labels = tokenizer.batch_decode(label_tensors, skip_special_tokens=True)

        sent_kwargs['function_to_apply'] = 'sigmoid'
        sent_kwargs['return_all_scores'] = True
        pipe_outputs = classification_pipe(texts, **sent_kwargs)
        
        rewards = []
        for t,output, label in zip(texts, pipe_outputs, labels):
            if label == t:
                if label == '0':
                    reward = 0.0
                else:
                    reward = 1.0
            else:
                if label == '0':
                    reward = 0.0
                else: 
                    reward = 0.0
            rewards.append(torch.tensor(reward).to(device))
        
        assert len(rewards) == len(labels) == len(texts)
        
        # print(rewards)
        #### Run PPO step
        stats = ppo_trainer.step(list(query_tensors), response_tensors, rewards)
        ppo_trainer.log_stats(stats, batch, rewards)
        ppo_trainer.evaluate(raw_datasets['test'])
    # print(f"saving pretrained epoch {epoch} to {out_dir}")
    # ppo_trainer._save_pretrained(out_dir)
   


  0%|          | 0/31 [00:00<?, ?it/s][A
  3%|▎         | 1/31 [00:01<00:43,  1.46s/it][A
  6%|▋         | 2/31 [00:02<00:42,  1.46s/it][A
 10%|▉         | 3/31 [00:04<00:40,  1.44s/it][A
 13%|█▎        | 4/31 [00:06<00:45,  1.68s/it][A
 16%|█▌        | 5/31 [00:08<00:47,  1.82s/it][A
 19%|█▉        | 6/31 [00:10<00:47,  1.90s/it][A
 23%|██▎       | 7/31 [00:11<00:39,  1.63s/it][A
 26%|██▌       | 8/31 [00:13<00:40,  1.77s/it][A
 29%|██▉       | 9/31 [00:15<00:40,  1.86s/it][A
 32%|███▏      | 10/31 [00:17<00:40,  1.93s/it][A
 35%|███▌      | 11/31 [00:19<00:39,  1.98s/it][A
 39%|███▊      | 12/31 [00:21<00:38,  2.01s/it][A
 42%|████▏     | 13/31 [00:23<00:34,  1.93s/it][A
 45%|████▌     | 14/31 [00:25<00:31,  1.88s/it][A
 48%|████▊     | 15/31 [00:26<00:26,  1.69s/it][A
 52%|█████▏    | 16/31 [00:28<00:25,  1.73s/it][A
 55%|█████▍    | 17/31 [00:30<00:24,  1.76s/it][A
 58%|█████▊    | 18/31 [00:32<00:23,  1.78s/it][A
 61%|██████▏   | 19/31 [00:33<00:20,  1.68s/it]

{'precision': 78.2622, 'recall': 74.3149, 'F1': 76.2375}



  0%|          | 0/31 [00:00<?, ?it/s][A
  3%|▎         | 1/31 [00:01<00:43,  1.46s/it][A
  6%|▋         | 2/31 [00:02<00:42,  1.47s/it][A
 10%|▉         | 3/31 [00:04<00:40,  1.44s/it][A
 13%|█▎        | 4/31 [00:06<00:45,  1.69s/it][A
 16%|█▌        | 5/31 [00:08<00:47,  1.83s/it][A
 19%|█▉        | 6/31 [00:10<00:49,  1.98s/it][A
 23%|██▎       | 7/31 [00:11<00:40,  1.68s/it][A
 26%|██▌       | 8/31 [00:13<00:41,  1.80s/it][A
 29%|██▉       | 9/31 [00:15<00:41,  1.89s/it][A
 32%|███▏      | 10/31 [00:18<00:40,  1.95s/it][A
 35%|███▌      | 11/31 [00:20<00:39,  2.00s/it][A
 39%|███▊      | 12/31 [00:22<00:38,  2.02s/it][A
 42%|████▏     | 13/31 [00:23<00:34,  1.94s/it][A
 45%|████▌     | 14/31 [00:25<00:32,  1.88s/it][A
 48%|████▊     | 15/31 [00:26<00:27,  1.69s/it][A
 52%|█████▏    | 16/31 [00:28<00:25,  1.73s/it][A
 55%|█████▍    | 17/31 [00:30<00:24,  1.76s/it][A
 58%|█████▊    | 18/31 [00:32<00:23,  1.78s/it][A
 61%|██████▏   | 19/31 [00:33<00:20,  1.68s/it]

{'precision': 77.9945, 'recall': 74.6064, 'F1': 76.2629}



  0%|          | 0/31 [00:00<?, ?it/s][A
  3%|▎         | 1/31 [00:01<00:43,  1.46s/it][A
  6%|▋         | 2/31 [00:02<00:42,  1.45s/it][A
 10%|▉         | 3/31 [00:04<00:40,  1.44s/it][A
 13%|█▎        | 4/31 [00:06<00:45,  1.69s/it][A
 16%|█▌        | 5/31 [00:08<00:47,  1.83s/it][A
 19%|█▉        | 6/31 [00:10<00:47,  1.92s/it][A
 23%|██▎       | 7/31 [00:11<00:39,  1.65s/it][A
 26%|██▌       | 8/31 [00:13<00:41,  1.79s/it][A
 29%|██▉       | 9/31 [00:15<00:41,  1.88s/it][A
 32%|███▏      | 10/31 [00:17<00:41,  1.95s/it][A
 35%|███▌      | 11/31 [00:20<00:40,  2.00s/it][A
 39%|███▊      | 12/31 [00:22<00:38,  2.03s/it][A
 42%|████▏     | 13/31 [00:23<00:35,  1.95s/it][A
 45%|████▌     | 14/31 [00:25<00:32,  1.90s/it][A
 48%|████▊     | 15/31 [00:26<00:27,  1.71s/it][A
 52%|█████▏    | 16/31 [00:28<00:26,  1.75s/it][A
 55%|█████▍    | 17/31 [00:30<00:24,  1.78s/it][A
 58%|█████▊    | 18/31 [00:32<00:23,  1.80s/it][A
 61%|██████▏   | 19/31 [00:33<00:20,  1.70s/it]

{'precision': 77.6834, 'recall': 74.6939, 'F1': 76.1593}
