# Fine-Tune FLAN-T5 with Reinforcement Learning (PPO) and PEFT to Generate Less-Toxic Summaries

Fine-tune a FLAN-T5 model to generate less toxic content, using Facebook's hate-speech reward model to score its outputs. The reward model is a binary classifier that predicts either "not hate" or "hate" for a given text. Proximal Policy Optimization (PPO) is used to fine-tune the model and reduce its toxicity.

In [1]:
from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification, AutoModelForSeq2SeqLM, GenerationConfig
from datasets import load_dataset
from peft import PeftModel, PeftConfig, LoraConfig, TaskType

# trl: Transformer Reinforcement Learning library
from trl import PPOTrainer, PPOConfig, AutoModelForSeq2SeqLMWithValueHead
from trl import create_reference_model
from trl.core import LengthSampler

import torch
import evaluate

import numpy as np
import pandas as pd

from tqdm import tqdm
# Register tqdm's progress-bar helpers (e.g. `progress_apply`) on pandas objects
# so later pandas operations can show a progress bar.
tqdm.pandas()

  from .autonotebook import tqdm as notebook_tqdm
  warn(


In [2]:
# Hugging Face Hub dataset of dialogues (with reference summaries) used throughout.
hf_dataset_name = "knkarthick/dialogsum"
# Base checkpoint that will be fine-tuned.
model_name = "google/flan-t5-base"

# No `split=` argument, so every available split is returned as a DatasetDict;
# printing it shows the columns and row count of each split.
dataset_original = load_dataset(path=hf_dataset_name)
print(dataset_original)

DatasetDict({
    train: Dataset({
        features: ['id', 'dialogue', 'summary', 'topic'],
        num_rows: 12460
    })
    validation: Dataset({
        features: ['id', 'dialogue', 'summary', 'topic'],
        num_rows: 500
    })
    test: Dataset({
        features: ['id', 'dialogue', 'summary', 'topic'],
        num_rows: 1500
    })
})


In [None]:
def build_dataset(model_name, dataset_name, input_min_text_length, input_max_text_length):
    """Load the ``train`` split of a Hub dataset and keep only mid-length dialogues.

    Parameters
    ----------
    model_name : str
        Currently unused by this function; kept so existing calls keep working.
    dataset_name : str
        Dataset identifier passed to ``load_dataset`` (e.g. ``"knkarthick/dialogsum"``).
    input_min_text_length : int
        Exclusive lower bound on the ``dialogue`` length, in characters.
    input_max_text_length : int
        Inclusive upper bound on the ``dialogue`` length, in characters.

    Returns
    -------
    The ``train`` split restricted to rows whose ``dialogue`` length ``n``
    satisfies ``input_min_text_length < n <= input_max_text_length``.

    Raises
    ------
    ValueError
        If ``input_min_text_length >= input_max_text_length``. The length window
        would be empty and the filter would silently drop every row.
    """
    # Fail fast, before downloading/loading anything, on a window that cannot match.
    if input_min_text_length >= input_max_text_length:
        raise ValueError(
            f"input_min_text_length ({input_min_text_length}) must be smaller than "
            f"input_max_text_length ({input_max_text_length})"
        )

    dataset = load_dataset(dataset_name, split="train")

    def _within_length_window(batch):
        # With batched=True the predicate receives a dict of column lists and must
        # return one boolean per row in the batch.
        return [
            input_min_text_length < len(dialogue) <= input_max_text_length
            for dialogue in batch["dialogue"]
        ]

    return dataset.filter(_within_length_window, batched=True)

# Keep training dialogues longer than 200 and at most 1000 characters.
dataset = build_dataset(
    model_name,
    hf_dataset_name,
    input_min_text_length=200,
    input_max_text_length=1000,
)

print(dataset)

Filter: 100%|██████████| 12460/12460 [00:00<00:00, 154732.41 examples/s]

Dataset({
    features: ['id', 'dialogue', 'summary', 'topic'],
    num_rows: 10022
})



