# T5 tuning & inference using Reinforcement Learning




All the fine-tuning so far was supervised: we tuned the models to produce a response close to the one we provided. The problem is that this target paraphrase was itself the output of another paraphraser, which might hide the true potential of the model under consideration.
Another problem is that the loss function knows nothing about the task we are trying to solve: it simply punishes the model for producing an output that differs from the expected one. It would be nice to tell the model directly what we want to get.

Both these problems might be solved by an unusual method -- Reinforcement Learning. Thanks to the library [Transformer Reinforcement Learning](https://huggingface.co/docs/trl/index), I don't have to implement it from scratch. 

On top of that, if one decides to deploy this algorithm to production, using RL would make it possible to implement Reinforcement Learning from [Human Feedback](https://huggingface.co/blog/rlhf).

To implement this approach, let's take the T5-small model once again and play around.

In [1]:
!pip install -q datasets transformers[sentencepiece] sacrebleu
!pip install -q evaluate
!pip install -q langchain
!pip install -q sentence-transformers
# Note that the code does not work properly with up-to-date version, see: https://github.com/huggingface/trl/issues/679
!pip install -q git+http://github.com/huggingface/trl.git@b6b593d61af124bfda111c1cb1370d9166236ea5

In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
# Necessary imports
import warnings

import gc
from datasets import load_dataset, load_metric
import transformers
import datasets
import random
import pandas as pd
import numpy as np
from IPython.display import display, HTML
from tqdm import tqdm
from tqdm.auto import trange
tqdm.pandas()

import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForSequenceClassification, \
                         DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer, \
                         RobertaTokenizer, RobertaForSequenceClassification
import wandb
from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead, AutoModelForSeq2SeqLMWithValueHead, set_seed

import evaluate
warnings.filterwarnings('ignore')
transformers.logging.set_verbosity_warning()
transformers.logging.set_verbosity_error()



A very convenient dashboard:

In [6]:
wandb.init()

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
[34m[1mwandb[0m: Paste an API key from your profile and hit enter, or press ctrl+c to quit:

  ········································


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


# Part 1. Labelled tuning

## Loading the dataset

In [7]:
transformers.set_seed(42)
dataset = datasets.load_dataset("domrachev03/toxic_comments_subset")
metric = load_metric("sacrebleu")


In [8]:
n_train = 10000
n_val = 1000
ds_subset = dataset['train'].select(range(n_train + n_val))

dataset['train'] = ds_subset.select(range(n_train)) 
dataset['val'] = ds_subset.select(range(n_train, n_train+n_val))

In [9]:
max_input_length = 64
max_target_length = 64
batch_size = 128

In [10]:
# simple postprocessing for text
def postprocess_text(preds, labels):
    preds = [pred.strip() for pred in preds]
    labels = [[label.strip()] for label in labels]

    return preds, labels


In [40]:
def cleanup():
    '''Clean the RAM and VRAM from the trash'''

    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


def get_toxicity(preds, batch_size=1, device=None):
    '''Calculates the non-toxicity score, 1 - P(toxic), of each text using RoBERTa fine-tuned on toxicity classification.'''

    results = []

    model_name = 'SkolkovoInstitute/roberta_toxicity_classifier'

    tokenizer = RobertaTokenizer.from_pretrained(model_name)
    model = RobertaForSequenceClassification.from_pretrained(model_name)
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model.to(device)

    model.eval()
    for i in tqdm(range(0, len(preds), batch_size)):
        batch = tokenizer(preds[i:i + batch_size], return_tensors='pt', max_length=-1, padding=True).to(device)

        with torch.no_grad():
            logits = model(**batch).logits
            out = torch.softmax(logits, -1)[:, 1].cpu().numpy()
            results.append(out)
    return 1 - np.concatenate(results)


def get_sacrebleu(inputs, preds):
    '''Calculates sacrebleu score for the inputs and predictions'''

    metric = evaluate.load("sacrebleu")

    result = metric.compute(predictions=preds, references=inputs)
    return result['score']


def get_fluency(preds, soft=False, batch_size=1, device=None):
    '''Calculates fluency of the corpus using RoBerta finetuned on CoLa dataset.'''

    model_name = 'cointegrated/roberta-large-cola-krishna2020'

    model = RobertaForSequenceClassification.from_pretrained(model_name)
    tokenizer = RobertaTokenizer.from_pretrained(model_name)
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model.to(device)

    results = []
    for i in trange(0, len(preds), batch_size):
        batch = [t for t in preds[i: i + batch_size]]
        inputs = tokenizer(batch, max_length=-1, padding=True, return_tensors='pt').to(device)
        with torch.no_grad():
            out = torch.softmax(model(**inputs).logits, -1)[:, 0].cpu().numpy()
            results.append(out)
    return np.concatenate(results)


def compute_metrics(eval_preds, tokenizer=None, print_results=False, batch_size=1, device='cuda', model_name=""):
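    '''Compute metrics for the given data.

    Parameters:
    eval_preds=(preds, labels) -- tuple with predictions and labels
    tokenizer -- the tokenizer for the sequences. Default: None, in which case the sequences are treated as already decoded
    print_results -- whether to print a summary table. Default: False
    batch_size, device -- forwarded to the toxicity and fluency classifiers
    model_name -- optional model name for fancy output printing. Default: empty'''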

    preds, labels = eval_preds

    if tokenizer is not None:
        detokenized_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
        filtered_labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
        detokenized_labels = tokenizer.batch_decode(filtered_labels, skip_special_tokens=True)
    else:
        detokenized_preds = preds
        detokenized_labels = labels

    results = {}
    results['toxic'] = get_toxicity(detokenized_preds, batch_size=batch_size, device=device)
    results['avg_toxic'] = sum(results['toxic']) / len(results['toxic'])
    cleanup()

    results['bleu'] = get_sacrebleu(detokenized_labels, detokenized_preds) / 100
    cleanup()

    results['fluency'] = get_fluency(detokenized_preds, batch_size=batch_size, device=device)
    results['avg_fluency'] = sum(results['fluency']) / len(results['fluency'])
    cleanup()

    # count metrics
    results['joint'] = sum(results['toxic'] * results['bleu'] * results['fluency']) / len(preds)
    if print_results:
        if model_name != "":
            print("--------------")
            print(model_name)
        print("--------------")
        print("Metric   | Value")
        print("--------------")
        print(f"toxic    | {results['avg_toxic']:.2f}")
        print(f"bleu (n) | {results['bleu']:.2f}")
        print(f"fluency  | {results['avg_fluency']:.2f}")
        print("===============")
        print(f"Total    | {results['joint']:.2f}")
        print("--------------")
    return results


100%|██████████| 3/3 [00:00<00:00, 84.16it/s]


  0%|          | 0/3 [00:00<?, ?it/s]

--------------
Metric   | Value
--------------
toxic    | 0.33
bleu (n) | 0.09
fluency  | 0.89
Total    | 0.03
--------------


# T5-small

## Selecting the model
The baseline model is the [`t5-small`](https://huggingface.co/docs/transformers/model_doc/t5) transformer. This architecture seems suitable for our task, since: "*T5 works well on a variety of tasks out-of-the-box by prepending a different prefix to the input corresponding to each task, e.g., for translation: translate English to German: ..., for summarization: summarize: ...*"

In [12]:
# selecting model checkpoint
model_type = "t5-small"

The model uses a pretrained tokenizer:

In [13]:
t5_tokenizer = AutoTokenizer.from_pretrained(model_type)

def t5_wrapper(text):
    # Wrap the sentence into a task prompt; the single quotes are used
    # later to recover the original text from the decoded query
    wrapped = f"Make the following sentence non-toxic: '{text}'"
    return wrapped

def t5_preprocess_function(examples):
    inputs = t5_wrapper(examples["reference"])

    # PPO training only needs the tokenized queries, so the target
    # paraphrase ("translation") is not tokenized here
    model_inputs = t5_tokenizer(inputs, max_length=max_input_length, truncation=True)

    return model_inputs



In [14]:
t5_ds_train = dataset['train'].map(
    t5_preprocess_function, 
    batched=False,
    load_from_cache_file=False,
)
t5_ds_train.set_format(type="torch")

  0%|          | 0/10000 [00:00<?, ?ex/s]
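
The prompt built by `t5_wrapper` is undone later in the training loop, which recovers the original sentence by slicing between the first and the last single quote of the decoded query. The sketch below reproduces that round trip on plain strings; `unwrap` is a hypothetical name for the slicing expression, and the sample sentence is made up.

```python
def t5_wrapper(text):
    # Same prompt template as in the notebook
    return f"Make the following sentence non-toxic: '{text}'"


def unwrap(query):
    # Same slicing as in the training loop: keep what lies between
    # the first and the last single quote
    return query[query.find("'") + 1: query.rfind("'")]


sentence = "I don't care"
query = t5_wrapper(sentence)
print(unwrap(query))           # I don't care
print(unwrap(query + "</s>"))  # I don't care (a trailing end-of-sequence marker does not matter)
```

Apostrophes inside the sentence survive because only the outermost quotes are used. One caveat, which follows from the slicing logic and was not checked against the tokenizer here: if truncation to `max_input_length` cuts off the closing quote, the slice ends at the last inner apostrophe instead, or is empty when no other quote remains.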

## T5 reinforcement learning

Hugging Face provides a library for [Transformer Reinforcement Learning (TRL)](https://huggingface.co/docs/trl/index). However, this approach seems to be largely unexplored (I managed to find only a few systematic reviews in the area). Hence, I decided to give it a try on the given task.

## The metric

The reward function was built from different combinations of the following components:
1. *Toxicity*. Evaluated using RoBERTa fine-tuned for toxicity classification ([link](https://huggingface.co/s-nlp/roberta_toxicity_classifier_v1)).
2. *Similarity*. Evaluated using either the [SacreBLEU](https://huggingface.co/spaces/evaluate-metric/sacrebleu) metric or cosine similarity.
3. *Fluency*. Evaluated using RoBERTa fine-tuned to assess language fluency, provided by the authors of the paper ([link](https://huggingface.co/cointegrated/roberta-large-cola-krishna2020)).

I tried many different reward functions; this notebook implements
$$ R(p, l) =
  \left\{
  \begin{aligned}
    \text{toxic}(p) + \text{fluency}(p) - 20, \qquad& \text{bleu}(p, l) < 0.2, \\
    \text{toxic}(p) + \text{fluency}(p), \qquad& \text{otherwise},
  \end{aligned}
  \right.
$$
where $p$ is the prediction and $l$ is the text it is compared against. This reward is used in all further evaluations.
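
As a reading aid, the piecewise formula can be written as a few lines of plain Python operating on scalars. This is only a sketch: `piecewise_reward` and its parameter names are hypothetical, and the inputs are made-up numbers rather than classifier outputs.

```python
def piecewise_reward(toxic, fluency, bleu, threshold=0.2, penalty=20.0):
    """R(p, l) from the formula: toxic + fluency, minus a penalty for low BLEU."""
    reward = toxic + fluency
    if bleu < threshold:
        # The prediction has too little overlap with the text it is compared against
        reward -= penalty
    return reward


print(piecewise_reward(toxic=1.5, fluency=0.5, bleu=0.6))  # 2.0
print(piecewise_reward(toxic=1.5, fluency=0.5, bleu=0.1))  # -18.0
```

The constant penalty is large compared with these toy values, so a low-overlap output scores far below an acceptable one regardless of how non-toxic or fluent it is.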

In [15]:
from sentence_transformers import util

def get_toxicity_reward(preds, batch_size=1):
    '''Returns the raw logit of class 0 (non-toxic) for each prediction; higher means less toxic.'''
    results = []

    model_name = 'SkolkovoInstitute/roberta_toxicity_classifier'

    tokenizer = RobertaTokenizer.from_pretrained(model_name)
    model = RobertaForSequenceClassification.from_pretrained(model_name)
    device = 'cuda'
    model.to(device)

    model.eval()
    for i in tqdm(range(0, len(preds), batch_size)):
        batch = tokenizer(preds[i:i + batch_size], return_tensors='pt', max_length=-1, padding=True).to(device)

        with torch.no_grad():
            logits = model(**batch).logits
            out = logits[:, 0]
            results.append(out)
    return torch.concatenate(results)

# def get_cosine_similarity_reward(inputs, preds):
#     return np.array([util.pytorch_cos_sim(in_i.type(torch.float64), pred_i.type(torch.float64)) for in_i, pred_i in zip(inputs, preds)])

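def get_sacrebleu_reward(inputs, preds):
    '''Per-sentence sacrebleu scores (0-100 scale), returned as a tensor.

    Named differently from the corpus-level `get_sacrebleu` used in `compute_metrics`,
    so that running the notebook top to bottom does not overwrite that helper.'''
    metric = evaluate.load("sacrebleu")

    results = []
    for i in range(len(inputs)):
        results.append(metric.compute(predictions=preds[i:i+1], references=inputs[i:i+1])['score'])
    return torch.tensor(results).to('cuda')


def get_fluency_reward(preds, batch_size=1):
    '''Returns the raw logit of class 0 for each prediction (the class that `get_fluency` treats as fluent).'''
    path = 'cointegrated/roberta-large-cola-krishna2020'

    model = RobertaForSequenceClassification.from_pretrained(path)
    tokenizer = AutoTokenizer.from_pretrained(path)
    device = 'cuda'
    model.to(device)

    results = []
    for i in trange(0, len(preds), batch_size):
        batch = [t for t in preds[i: i + batch_size]]
        inputs = tokenizer(batch, max_length=-1, padding=True, return_tensors='pt').to(device)
        with torch.no_grad():
            out = model(**inputs).logits[:, 0]
            results.append(out)
    return torch.concatenate(results)


def compute_reward(preds, labels, batch_size=1):
    '''Reward = toxicity logit + fluency logit, minus 20 when the sentence-level BLEU is below the threshold.'''
    toxicity = get_toxicity_reward(preds, batch_size=batch_size)
    cleanup()

    # NOTE: sacrebleu reports scores on a 0-100 scale, so the 0.2 threshold only
    # penalizes predictions with almost no n-gram overlap with the label
    bleu = torch.where(get_sacrebleu_reward(labels, preds) < 0.2, -20, 0)
    cleanup()

    fluency = get_fluency_reward(preds, batch_size=batch_size)
    cleanup()
    return toxicity + bleu + fluency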


In [16]:
compute_reward(
    ['fuck you, bitch, I love you!', '<extra_id_0>. Retest:</s>'], 
    ['fuck you, bitch, I hate you!', 'I Love you so much I could not express it!'], 
    batch_size=2
)

tensor([ -2.0337, -14.2195], device='cuda:0')

## The train process

In [17]:
config = PPOConfig(
    model_name="t5-small",
    learning_rate=1.41e-5,
    ppo_epochs=4,
    batch_size=128,
    log_with='wandb',
)

set_seed(42)


In [18]:
def collator(data):
    return dict((key, [d[key] for d in data]) for key in data[0])
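
The collator turns a list of per-example dictionaries into one dictionary of lists and leaves the sequences unpadded. A toy illustration with made-up token ids:

```python
def collator(data):
    # Same function as in the cell above
    return dict((key, [d[key] for d in data]) for key in data[0])


batch = collator([
    {"input_ids": [1, 2, 3], "attention_mask": [1, 1, 1]},
    {"input_ids": [4, 5], "attention_mask": [1, 1]},
])
print(batch)
# {'input_ids': [[1, 2, 3], [4, 5]], 'attention_mask': [[1, 1, 1], [1, 1]]}
```

Sequences of different lengths stay as separate list items, which fits the training loop in this notebook, where each query is passed to `generate` one at a time.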

In [19]:
t5_ds_train

Dataset({
    features: ['reference', 'translation', 'similarity', 'lenght_diff', 'ref_tox', 'trn_tox', 'input_ids', 'attention_mask'],
    num_rows: 10000
})

In [20]:
cols_to_remove = ['reference', 'translation', 'similarity', 'lenght_diff', 'ref_tox', 'trn_tox']

t5_clean_ds = t5_ds_train.remove_columns(cols_to_remove)

t5_clean_ds

Dataset({
    features: ['input_ids', 'attention_mask'],
    num_rows: 10000
})

In [21]:
model = AutoModelForSeq2SeqLMWithValueHead.from_pretrained(config.model_name)
ref_model = AutoModelForSeq2SeqLMWithValueHead.from_pretrained(config.model_name)
tokenizer = AutoTokenizer.from_pretrained(config.model_name)



In [22]:
ppo_trainer = PPOTrainer(config, model, ref_model, tokenizer, dataset=t5_clean_ds, data_collator=collator)

In [23]:
generation_kwargs = {
    "top_k": 0.0,
    "top_p": 1.0,
    "do_sample": True,
    "min_new_tokens": 4
}

stat_logs = []
batch_log = []
rewards_log = []
for epoch, batch in tqdm(enumerate(ppo_trainer.dataloader), total=len(ppo_trainer.dataloader)):
    query_tensors = batch["input_ids"]
    #### Get response from t5
    response_tensors = []
    for query in query_tensors:
        response = ppo_trainer.generate(query, **generation_kwargs)
        response_tensors.append(response.squeeze())
    batch["response"] = [tokenizer.decode(r[1:].squeeze()) for r in response_tensors]
    
    # response = ppo_trainer.generate(query_tensors, **generation_kwargs)  
    # batch["response"] = [tokenizer.decode(r[1:]) for r in response]
    batch['query'] = [tokenizer.decode(inp) for inp in batch['input_ids']]
    batch['query'] = [q[q.find('\'')+1: q.rfind('\'')] for q in batch['query']]
    rewards = compute_reward(batch['response'], batch['query'], batch_size=batch_size)
    rewards = [torch.tensor([reward]) for reward in rewards]

    stats = ppo_trainer.step(query_tensors, response_tensors, rewards)
    ppo_trainer.log_stats(stats, batch, rewards)

  0%|          | 0/79 [00:00<?, ?it/s]
  1%|▏         | 1/79 [01:08<1:29:12, 68.62s/it]
...
 99%|█████████▊| 78/79 [1:16:44<00:54, 54.87s/it]
 99%|█████████▊| 78/79 [1:16:53<00:59, 59.15s/it]


ValueError: Batch size (128) does not match number of examples - but got 16 for: queries
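
The error is consistent with the dataset size: 10000 examples do not divide evenly into batches of 128, so the 79th batch holds only 16 queries. One possible workaround, sketched here as an assumption rather than the fix used by the author, is to trim the training set to a multiple of the batch size before building the trainer. `usable_size` is a hypothetical helper.

```python
def usable_size(n_examples, batch_size):
    """Largest number of examples that splits into full batches."""
    return n_examples - n_examples % batch_size


n_examples, batch_size = 10000, 128
full = usable_size(n_examples, batch_size)
print(full, n_examples - full)  # 9984 16

# With a Hugging Face dataset, the trimmed subset could then be taken with
# `select`, as done earlier in this notebook, e.g.:
# t5_clean_ds = t5_clean_ds.select(range(full))
```

Another option would be to skip a batch inside the loop when its length differs from `config.batch_size`.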

In [34]:
model.save_pretrained("t5_rl_best", push_to_hub=False)

## Evaluation

In [38]:
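device = 'cuda' if torch.cuda.is_available() else 'cpu'

# `from_pretrained` returns a new model, so the result has to be assigned
model = AutoModelForSeq2SeqLMWithValueHead.from_pretrained('domrachev03/t5_rl_detox').to(device)
model.eval()

results = []
msgs = [t5_wrapper(text) for text in dataset["test"]["reference"]]
for i in trange(0, len(msgs), batch_size):
    batch = msgs[i: i+batch_size]
    tokenized_batch = tokenizer(batch, max_length=128, padding=True, return_tensors='pt').to(device)
    with torch.no_grad():
        # pass the attention mask so that padding tokens are ignored during generation
        output = model.generate(input_ids=tokenized_batch['input_ids'], attention_mask=tokenized_batch['attention_mask'])
        # `temperature` is a generation argument, not a decoding one, so it is dropped here
        result = [tokenizer.decode(out_i, skip_special_tokens=True) for out_i in output]
    results.extend(result)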

  0%|          | 0/136 [00:00<?, ?it/s]

In [41]:
compute_metrics((results, dataset['test']['translation']), print_results=True, batch_size=batch_size)

100%|██████████| 136/136 [00:14<00:00,  9.27it/s]


  0%|          | 0/136 [00:00<?, ?it/s]

--------------
Metric   | Value
--------------
toxic    | 0.93
bleu (n) | 0.03
fluency  | 0.49
Total    | 0.01
--------------


{'toxic': array([0.99982816, 0.99994403, 0.9999064 , ..., 0.99994403, 0.99994403,
        0.99982816], dtype=float32),
 'avg_toxic': 0.9251211043690382,
 'bleu': 0.02733905992001097,
 'fluency': array([0.06736284, 0.9749654 , 0.06025813, ..., 0.9749654 , 0.9749654 ,
        0.06736284], dtype=float32),
 'avg_fluency': 0.48509223179666894,
 'joint': 0.011723955617528401}

## Saving

In [None]:
model.save_pretrained('domrachev03/t5_rl_detox')

In [42]:
import pandas as pd

preds_dict = pd.DataFrame([[orig_i, pred_i, label_i] for orig_i, pred_i, label_i in zip(dataset['test']['reference'], results, dataset['test']['translation'])], columns=['Query', 'Predictions', 'Labels'])

preds_dict.head(10)

| | Query | Predictions | Labels |
|---|---|---|---|
| 0 | If they say they don't want your help, you can... | the following sentence non-toxic. | if they say they don't need help, you can bet ... |
| 1 | When you go home, they'il call you "Mr. No-Leg... | non-toxic. | when you get home, they'll call you Mr. |
| 2 | Mrs. Sanders ain't goingfor that late shit no ... | the sentence non-toxic. | Mrs. Sanders is not curious about being late. |
| 3 | I want every bloody one of them to hear. | the sentence non-toxic. | I want everyone to hear it. |
| 4 | That shit hurt. | . | and that really hurts. |
| 5 | How do I know if I tell you what you need to k... | non-toxic. | how do I know if I tell you what you want to k... |
| 6 | This guy ain't going nowhere. Oh, shit! | the following sentence non-toxic! | this guy's not going anywhere. |
| 7 | Fucked if I can remember. | . | if only I could remember. |
| 8 | I fucked one stitch and he looks at me now as ... | 'I made the following sentence non-toxic: I ma... | I mess up a suture, and now he's looking at me... |
| 9 | Earth! Meet my lovely assistant... Tiny tits. | the following sentence non-toxic. | please welcome my beautiful assistant, Maloprs... |


In [43]:
preds_dict.to_csv('t5_rl_test.csv')