In [None]:
!git config --global --add safe.directory . && pip install -e .

In [None]:
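# One-time data setup (uncomment to run): copy the BLURB-format data directory from the
# SciFive repo into ./data, which the dataset-loading cell below reads from.
# !git clone https://github.com/justinphan3110/SciFive.git
# !cp -r SciFive/biot5x/data .
# !rm -r SciFive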

In [1]:
import os
import pathlib
from typing import List

import torch
import yaml
from datasets import Dataset
from sklearn.metrics import classification_report
from transformers import pipeline

import trlx
from trlx import metric as trlx_metric
from trlx.data.configs import TRLConfig

In [2]:
# Load the trlx run configuration (a plain dict) from YAML. safe_load only builds
# plain YAML types, never arbitrary Python objects. The encoding is pinned to match the
# dataset loader, so reading does not depend on the platform's default locale encoding.
config_path = pathlib.Path("configs/biot5_ppo_config.yml")
with config_path.open("r", encoding="utf-8") as config_file:
    default_config = yaml.safe_load(config_file)

In [3]:
# Task name selects the data/<task>/ sub-directory holding the *_blurb.tsv splits.
task = 'chemprot'
SPLITS = ('train', 'dev', 'test')


def load_blurb_split(task: str, split: str, data_dir: str = 'data') -> Dataset:
    """Load one ``{data_dir}/{task}/{split}_blurb.tsv`` file as a Dataset.

    Each non-blank line must contain at least two tab-separated fields: the first
    becomes the "prompts" column and the second the "labels" column (any further
    fields are ignored). Blank lines, such as a trailing newline at EOF, are skipped.

    Raises:
        ValueError: if a non-blank line has fewer than two tab-separated fields.
    """
    inputs: List[str] = []
    targets: List[str] = []
    path = pathlib.Path(data_dir) / task / f'{split}_blurb.tsv'
    with path.open('r', encoding='utf-8') as fh:
        for lineno, raw_line in enumerate(fh, start=1):
            fields = raw_line.strip().split('\t')
            if fields == ['']:
                continue  # blank / whitespace-only line
            if len(fields) < 2:
                raise ValueError(
                    f'{path}:{lineno}: expected "<input>\\t<target>", got {raw_line!r}'
                )
            inputs.append(fields[0].strip())
            targets.append(fields[1].strip())
    return Dataset.from_dict({"prompts": inputs, "labels": targets})


raw_datasets = {split: load_blurb_split(task, split) for split in SPLITS}


In [4]:
# No overrides yet: the run uses the YAML config as-is.
config = TRLConfig.update(default_config, {})

# Device index: LOCAL_RANK (presumably set by a distributed launcher) on GPU, -1 on CPU.
# NOTE(review): `device` is not referenced anywhere else in this notebook.
if torch.cuda.is_available():
    device = int(os.environ.get("LOCAL_RANK", 0))
else:
    device = -1

# Split used for evaluation during training. NOTE(review): this is the *test* split,
# while the loaded 'dev' split goes unused; set EVAL_SPLIT = 'dev' to keep test held out.
EVAL_SPLIT = 'test'

train_prompts = raw_datasets['train']['prompts']
eval_prompts = raw_datasets[EVAL_SPLIT]['prompts']
eval_labels = raw_datasets[EVAL_SPLIT]['labels']

def reward_fn(samples: List[str], **kwargs) -> List[float]:
    """Constant placeholder reward: score every generated sample as 1.0.

    Because every sample gets the same score, this gives PPO no preference signal;
    presumably it is a stand-in until a real reward is wired up. The values are
    floats to honour the ``List[float]`` annotation.

    Args:
        samples: generated texts to score.
        **kwargs: any extra fields the trainer passes; ignored.
    """
    return [1.0] * len(samples)


# Resolve the PRF1 metric function once, instead of on every evaluation call.
prf1_metric = trlx_metric.map_name_to_metric_function("PRF1")


def metric_fn(outputs: List[str], **kwargs):
    """Score the generated eval outputs against ``eval_labels`` with the PRF1 metric.

    Side effects:
      * overwrites ``log.txt`` with the raw generations, one per line;
      * prints an sklearn per-class ``classification_report`` and the PRF1 result.

    Args:
        outputs: decoded generations, assumed to be aligned position-by-position
            with ``eval_prompts`` / ``eval_labels``. NOTE(review): confirm trlx keeps
            this order (e.g. under multi-process evaluation).
        **kwargs: any other fields the trainer passes; ignored.

    Returns:
        The PRF1 result (a dict with precision / recall / F1 in the logged run).

    Raises:
        ValueError: if the number of outputs differs from the number of labels.
    """
    if len(outputs) != len(eval_labels):
        raise ValueError(
            f"metric_fn received {len(outputs)} outputs for {len(eval_labels)} eval labels"
        )

    with open('log.txt', 'w', encoding='utf-8') as log_file:
        log_file.writelines(f'{o}\n' for o in outputs)

    result = prf1_metric(targets=eval_labels, predictions=outputs)
    print(classification_report(eval_labels, outputs))
    print(result)
    # The former `assert False` debug stop here aborted training after the first
    # evaluation; the result is now returned so the trainer can log it.
    return result
    
# Launch training: rollouts come from train_prompts, while eval_prompts are generated
# during evaluation and scored by metric_fn; reward_fn scores the rollouts.
trlx.train(
    config=config,
    prompts=train_prompts,
    reward_fn=reward_fn,
    eval_prompts=eval_prompts,
    metric_fn=metric_fn,
)

[RANK 0] Initializing model: justinphan3110/biot5_chemprot
[RANK 0] Collecting rollouts
You're using a T5TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
  scores = torch.tensor(all_scores[0])
[RANK 0] Starting training
[RANK 0] Evaluating model


[generation sweep 0/1 | eval batch 0/31]:   0%|          | 0/31 [00:00<?, ?it/s]

[RANK 0] Computing rewards
[RANK 0] Computing metrics


              precision    recall  f1-score   support

           0       0.92      0.96      0.94     12315
           1       0.76      0.60      0.67       663
           2       0.83      0.75      0.79      1655
           3       0.80      0.70      0.75       178
           4       0.87      0.73      0.80       292
           5       0.67      0.51      0.58       642

    accuracy                           0.90     15745
   macro avg       0.81      0.71      0.75     15745
weighted avg       0.89      0.90      0.89     15745

{'precision': 0.7942287873582962, 'recall': 0.6740524781341107, 'F1': 0.7292225201072385}


In [None]:
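# Reference classification report pasted from a different run/checkpoint (its numbers
# differ from the In [4] output above); commented out so this cell parses.
#               precision    recall  f1-score   support
#
#            0       0.94      0.95      0.94     12315
#            1       0.75      0.75      0.75       663
#            2       0.82      0.82      0.82      1655
#            3       0.79      0.76      0.77       178
#            4       0.88      0.79      0.83       292
#            5       0.69      0.60      0.64       642
#
#     accuracy                           0.91     15745
#    macro avg       0.81      0.78      0.79     15745
# weighted avg       0.91      0.91      0.91     15745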

In [None]:
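# Copy of the classification report printed by the In [4] evaluation above;
# commented out so this cell parses.
#               precision    recall  f1-score   support
#
#            0       0.92      0.96      0.94     12315
#            1       0.76      0.60      0.67       663
#            2       0.83      0.75      0.79      1655
#            3       0.80      0.70      0.75       178
#            4       0.87      0.73      0.80       292
#            5       0.67      0.51      0.58       642
#
#     accuracy                           0.90     15745
#    macro avg       0.81      0.71      0.75     15745
# weighted avg       0.89      0.90      0.89     15745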