In [2]:
from datasets import DatasetDict, Dataset
from transformers import (
    RobertaTokenizer,
    RobertaForSequenceClassification,
    T5Tokenizer,
    T5ForConditionalGeneration,
    T5Config,
    AutoTokenizer,
    AutoModelForSequenceClassification,
    Seq2SeqTrainer,
    DataCollatorForSeq2Seq,
    GenerationConfig,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    EarlyStoppingCallback,
    pipeline,
)
from sentence_transformers import SentenceTransformer
import torch
from torch import nn
import numpy as np
import time
import gc
import GPUtil
import evaluate
from numba import cuda
import wandb
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from functools import partial
import wandb
import os
import pickle
import optuna
from typing import Dict, Union, Optional, Tuple, List, Any
import pandas as pd

In [3]:
# Random seed for reproducibility
RANDOM_SEED = 42

# Default parameters for T5 model fine-tuning
PER_DEVICE_TRAIN_BATCH_SIZE = 64
PER_DEVICE_EVAL_BATCH_SIZE = 128
LEARNING_RATE = 3e-4
NUM_TRAIN_EPOCHS = 20
EARLY_STOPPING_PATIENCE = 2
NUM_BEAMS = 4

# Include BLEURT score in evaluation
INCLUDE_BLEURT = True

# Setting the DEVICE to cuda
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Set path for profane word list
PROFANE_WORD_PATH = "../data/raw/en.txt"

# Set path for raw dataset dictionary
RAW_DATASET_PATH = "../data/processed/raw_dataset.pkl"
AUG_DATASET_ALL_FILTERS_PATH = "../data/processed/aug_datasets_all_filters"
AUG_DATASET_NO_TOXICITY_FILTER_PATH = "../data/processed/aug_datasets_no_toxicity_filter"
AUG_DATASET_NO_SIMILARITY_FILTER_PATH = "../data/processed/aug_datasets_no_similarity_filter"
AUG_DATASET_NO_ACCEPTABILITY_FILTER_PATH = "../data/processed/aug_datasets_no_acceptability_filter"

# Set path for txt file containing best model checkpoints
BEST_MODEL_CHECKPOINT_PATH = "../models/best_model_checkpoints.txt"

# Set path to save evaluation outputs to
EVAL_PREDS_PATH = "../data/final/model_preds.csv"
EVAL_METRICS_PATH = "../data/final/model_metrics.csv"

# Set maximum length for input and output
MAX_INPUT_LENGTH = 64
MAX_OUTPUT_LENGTH = 64

In [None]:
# Load tokenizers and models
tokenizer_t5_small = T5Tokenizer.from_pretrained("t5-small")
model_t5_small = T5ForConditionalGeneration.from_pretrained("t5-small").to(DEVICE)
tokenizer_toxicity = RobertaTokenizer.from_pretrained("SkolkovoInstitute/roberta_toxicity_classifier")
model_toxicity = RobertaForSequenceClassification.from_pretrained("SkolkovoInstitute/roberta_toxicity_classifier").to(DEVICE)
tokenizer_acceptability = AutoTokenizer.from_pretrained("iproskurina/tda-bert-en-cola")
model_acceptability = AutoModelForSequenceClassification.from_pretrained("iproskurina/tda-bert-en-cola").to(DEVICE)

You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thouroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565
Some weights of the model checkpoint at SkolkovoInstitute/roberta_toxicity_classifier were not used when initializing RobertaForSequenceClassification: ['roberta.pooler.dense.weight', 'roberta.pooler.dense.bias']
- This IS expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaForSequenceClassifi

In [4]:
# Load datasets
raw_datasets = DatasetDict.load_from_disk(RAW_DATASET_PATH)
aug_datasets_all_filters = DatasetDict.load_from_disk(AUG_DATASET_ALL_FILTERS_PATH)
aug_datasets_no_acceptability_filter = DatasetDict.load_from_disk(AUG_DATASET_NO_ACCEPTABILITY_FILTER_PATH)
aug_datasets_no_similarity_filter = DatasetDict.load_from_disk(AUG_DATASET_NO_SIMILARITY_FILTER_PATH)
aug_datasets_no_toxicity_filter = DatasetDict.load_from_disk(AUG_DATASET_NO_TOXICITY_FILTER_PATH)

# Functions

## Debugging

In [27]:
def measure_time(func, *args, **kwargs):
    """
    Calculates the time it takes to run a function.
    """
    start_time = time.time()
    result = func(*args, **kwargs)
    end_time = time.time()
    elapsed_time = end_time - start_time
    print(f"Function {func.__name__} took {elapsed_time:.2f} seconds to run.")
    return result

def get_gpu_memory():
    """
    Gets the GPU memory information.
    """
    gpus = GPUtil.getGPUs()
    gpu = gpus[0]
    print(f"Total GPU memory: {gpu.memoryTotal}MB")
    print(f"Free GPU memory: {gpu.memoryFree}MB")
    print(f"Used GPU memory: {gpu.memoryUsed}MB")
    return gpu.memoryUsed

def force_clear_GPU_memory():
    """
    Force clears the GPU memory.
    """
    cuda.select_device(0)
    cuda.close()

def cleanup():
    """
    Cleans up the GPU memory.
    """
    gc.collect()
    torch.cuda.empty_cache()

## Baseline Models

In [6]:
# Baseline model functions
def baseline_detoxifier(text_list, profane_word_path=PROFANE_WORD_PATH):
    """
    Returns a detoxified version of the text by replacing toxic terms with blanks

    Args:
        text_list (list): list of strings to be detoxified
        toxic_list (list): list of toxic terms to be removed from text_list

    Returns:
        detoxified_text_list (list): list of detoxified strings
    """
    # Load list of profane words
    profane_words = []
    with open(profane_word_path, "r") as f:
        for line in f:
            profane_words.append(line.strip())

    # Detoxify text
    y_pred_delete = []
    for text in text_list:
        for term in profane_words:
            text = text.replace(term, "")
        y_pred_delete.append(text)

    return y_pred_delete

def bart_detoxifier(text_list):
    """
    Returns a detoxified version of the text using BART

    Args:
        text_list (list): list of strings to be detoxified

    Returns:
        detoxified_text_list (list): list of detoxified strings
    """
    pipe_bart = pipeline("text2text-generation", model="s-nlp/bart-base-detox", device=DEVICE)
    y_pred_bart = pipe_bart(text_list, max_length=MAX_OUTPUT_LENGTH, truncation=True)
    y_pred_bart = [x["generated_text"] for x in y_pred_bart]
    
    return y_pred_bart

# Helper function to add metrics to the dataframe
def add_metrics_to_df(df, model_name, metrics, save_path="../data/processed/model_metrics.csv"):
    """
    Add model metrics to a pandas dataframe
    
    Args:
    - df: pandas dataframe to add metrics to
    - model_name: name of the model
    - metrics: dictionary of evaluation metrics
    
    Returns:
    - updated pandas dataframe
    """

    # Create a df if the input df is empty
    if df is None:
        df = pd.DataFrame(columns=["Model", "BLEURT", "BLEU", "STA", "FLU", "SEM", "Overall"])

    # Check if the model name already exists in the dataframe
    if model_name in df["Model"].values:
        print(f"Model {model_name} already exists in the dataframe.")
        return df
    
    # Add the new row to the dataframe
    model_metrics_df = pd.DataFrame({
        "Model": [model_name],
        "BLEURT": [metrics["BLEURT"]],
        "BLEU": [metrics["BLEU"]],
        "STA": [metrics["STA"]],
        "FLU": [metrics["FLU"]],
        "SEM": [metrics["SEM"]],
        "Overall": [metrics["Overall"]]
    })

    # Save the dataframe to a csv file
    df = pd.concat([df, model_metrics_df], ignore_index=True)
    df.to_csv(save_path, index=False)
    
    return df


## Evaluation

In [7]:
# Initialize model variables
model_bleurt = None
model_bertscore = None
model_sacrebleu = None

def calc_sacrebleu(refs, preds):
    """
    Calculates the SacreBLEU score.

    Args:
        refs (list): List of reference sentences
        preds (list): List of predicted sentences
    
    Returns:
        results (float): SacreBLEU score
    """
    global model_sacrebleu

    if model_sacrebleu is None:
        model_sacrebleu = evaluate.load("sacrebleu")

    results = model_sacrebleu.compute(predictions=preds, references=refs)["score"]
    results = results/100

    return results

def calc_bert_score(
    refs, preds, model_type="microsoft/deberta-large-mnli", output_mean=True
    ):
    """
    Calculates BERT score per line. Note: https://docs.google.com/spreadsheets/d/1RKOVpselB98Nnh_EOC4A2BYn8_201tmPODpNWu4w7xI/edit#gid=0 lists the best performing models
    Args:
        refs (list): List of reference sentences.
        y_pred (list): List of predicted sentences.
        model_type (str): Type of BERT model to use.
        output_mean (bool): Whether to output the mean of the scores.

    Returns:
        list of precision, recall, f1 scores.

    """
    global model_bertscore

    if model_bertscore is None:
        model_bertscore = evaluate.load("bertscore")
        
    results = model_bertscore.compute(predictions=preds, references=refs, model_type=model_type)
    precision = np.array(results["precision"])
    recall = np.array(results["recall"])
    f1 = np.array(results["f1"])
    
    if output_mean:
        precision = precision.mean()
        recall = recall.mean()
        f1 = f1.mean()

    return precision, recall, f1

def calc_bleurt(refs, preds, checkpoint="BLEURT-20_D12", output_mean = True):
    """
    Calculates BLEURT score per line.

    Args:
        refs (list): List of reference sentences.
        preds (list): List of predicted sentences.
        output_type (str): Type of output to return. Either 'numpy' or 'list'.

    Returns:
        list/array of BLEURT scores.
    """
    global model_bleurt

    if model_bleurt is None:
        model_bleurt = evaluate.load("bleurt", module_type="metric", checkpoint=checkpoint)

    results = np.array(model_bleurt.compute(predictions=preds, references=refs)["scores"])

    if output_mean:
        results = results.mean()

    return results

def calc_tox_acceptability(
    data,
    tokenizer,
    model,
    output_score=True,
    output_mean=True):
    """
    Calculates toxicity and acceptability scores for a given dataset.

    Args:
        data = list of strings to be evaluated
        tokenizer = tokenizer for the model
        model = model to be used for evaluation
        output_score = whether to output the score or the label
        output_mean = whether to output the mean of the scores or the scores for each sentence
    
    Returns:
        array of toxicity and acceptability scores.
    """  
    inputs = tokenizer(data, return_tensors="pt", padding=True).to(DEVICE)
    with torch.no_grad():
        logits = model(**inputs)["logits"]
        if output_score:
            result = torch.nn.functional.softmax(logits, dim=1)[:, 1]
        else:
            result = logits.argmax(1).data
        result = result.cpu().numpy()

    if output_mean:
        result = result.mean()
        
    return result

def evaluate_metrics(
    refs,
    preds,
    tokenizer_toxicity=tokenizer_toxicity,
    model_toxicity=model_toxicity,
    tokenizer_acceptability=tokenizer_acceptability,
    model_acceptability=model_acceptability,
    to_neutral=True,
    weights={
        "BLEU": 0.2,
        "STA": 0.4,
        "Acceptability": 0.2,
        "BERT_Score": 0.2
    },
    include_bleurt=INCLUDE_BLEURT
):
    """
    Calculates and returns a dictionary of evaluation metrics

    Args:
        refs (list): list of strings (reference)
        preds (list): list of strings (predictions)
        tokenizer_toxicity (tokenizer): tokenizer for toxicity model
        model_toxicity (model): toxicity model
        tokenizer_acceptability (tokenizer): tokenizer for acceptability model
        model_acceptability (model): acceptability model
        to_neutral (bool): whether the goal is to transfer to neutral (True) or to toxic (False)
        weights (dict): dictionary of weights for each metric
        include_bleurt (bool): whether to include BLEURT score in the output

    Returns:
        results (dict): dictionary of evaluation metrics
    """
    # Calculate BLEU score
    bleu = calc_sacrebleu(refs, preds)

    # Calculate toxicity classification
    tox_pred = calc_tox_acceptability(preds, tokenizer_toxicity, model_toxicity, output_score=False, output_mean=False)

    # Calculate style transfer accuracy as proportion of sentences that were correctly classified (as non-toxic / toxic)
    if to_neutral:
        sta_correct_label = 0
    else:
        sta_correct_label = 1

    sta_pred = (tox_pred == sta_correct_label).sum() / len(tox_pred)

    # Calculate acceptability scores
    acc_pred = calc_tox_acceptability(preds, tokenizer_acceptability, model_acceptability)

    # Calculate similarity score
    bert_score_f1 = calc_bert_score(refs, preds, model_type="distilbert-base-uncased")[2]

    # Calculate BLEURT score if include_bleurt is True
    bleurt = None
    if include_bleurt:
        bleurt = calc_bleurt(refs, preds)

    # Calculate composite score
    composite_score = weights["BLEU"] * bleu + weights["STA"] * sta_pred + weights["Acceptability"] * acc_pred + weights["BERT_Score"] * bert_score_f1

    # Return a dictionary of metrics
    results = {
        "BLEU": bleu,
        "STA": sta_pred,
        "FLU": acc_pred,
        "SEM": bert_score_f1,
        "Overall": composite_score,
    }
    if include_bleurt:
        results["BLEURT"] = bleurt
        
    return results

In [8]:
def add_preds_to_df(model_name, preds, raw_datasets=raw_datasets, save_path=EVAL_PREDS_PATH, load_csv=True):
    """
    Add model predictions to a pandas dataframe

    Args:
    - model_name: name of the model
    - preds: list of predictions
    - save_path: csv file to save the dataframe to
    - load_csv: whether to load the existing csv file. If False, a new dataframe will be created.
    """

    if load_csv:
        df = pd.read_csv(save_path)
    else:
        df = pd.DataFrame({
                "source": raw_datasets["test"]["source"],
                "target": raw_datasets["test"]["target"],
            })
        
    df[f"{model_name}_preds"] = preds
    df.to_csv(save_path, index=False)

    return df

# Helper function to add metrics to the dataframe
def add_metrics_to_df(model_name, metrics, save_path=EVAL_METRICS_PATH, load_csv=True):
    """
    Add model metrics to a pandas dataframe
    
    Args:
    - df: pandas dataframe to add metrics to
    - model_name: name of the model
    - metrics: dictionary of evaluation metrics
    - save_path: csv file to save the dataframe to
    - load_csv: whether to load the existing csv file. If False, a new dataframe will be created.
    
    Returns:
    - updated pandas dataframe
    """

    # Load the existing dataframe if it exists
    if load_csv:
        df = pd.read_csv(save_path)
    else:
        df = pd.DataFrame(columns=["Model", "BLEURT", "BLEU", "STA", "FLU", "SEM", "Overall"])

    # Check if the model name already exists in the dataframe
    if model_name in df["Model"].values:
        print(f"Model {model_name} already exists in the dataframe.")
        return df
    
    # Add the new row to the dataframe
    model_metrics_df = pd.DataFrame({
        "Model": [model_name],
        "BLEURT": [metrics["BLEURT"]],
        "BLEU": [metrics["BLEU"]],
        "STA": [metrics["STA"]],
        "FLU": [metrics["FLU"]],
        "SEM": [metrics["SEM"]],
        "Overall": [metrics["Overall"]]
    })

    # Save the dataframe to a csv file
    df = pd.concat([df, model_metrics_df], ignore_index=True)
    df.to_csv(save_path, index=False)
    
    return df

## Training

In [29]:
def add_prefix(datasetdict, prefix="to_neutral: "):
    """Adds a prefix to the source sequence in the dataset."""
    datasetdict_copy = datasetdict.copy()
    datasetdict_copy["train"] = datasetdict_copy["train"].map(lambda x: {"source": prefix + x["source"]})
    datasetdict_copy["validation"] = datasetdict_copy["validation"].map(lambda x: {"source": prefix + x["source"]})
    datasetdict_copy["test"] = datasetdict_copy["test"].map(lambda x: {"source": prefix + x["source"]})
    datasetdict_copy = DatasetDict(datasetdict_copy)
    return datasetdict_copy

def create_bidirectional_dataset(datasets, shuffle=True):
    """
    Creates a bi-directional dataset from the original dataset.

    Args:
        datasets (DatasetDict): DatasetDict object containing the original dataset.
        shuffle (bool): Whether to shuffle the dataset or not.
    
    Returns:
        extended_datasets (DatasetDict): DatasetDict object containing the bi-directional dataset.
    """

    def bidirectional_extension(dataset):
        new_data = {
            "source": [],
            "target": []
        }
        for src, tgt in zip(dataset['source'], dataset['target']):
            new_data['source'].extend([f'to_neutral: {src}', f'to_toxic: {tgt}'])
            new_data['target'].extend([tgt, src])
        return new_data

    extended_train_data = bidirectional_extension(datasets["train"])
    extended_validation_data = bidirectional_extension(datasets["validation"])
    extended_test_data = bidirectional_extension(datasets["test"])

    extended_datasets = DatasetDict({
        "train": Dataset.from_dict(extended_train_data),
        "validation": Dataset.from_dict(extended_validation_data),
        "test": Dataset.from_dict(extended_test_data)
    })

    if shuffle:
        extended_datasets["train"] = extended_datasets["train"].shuffle(seed=RANDOM_SEED)
        
    return extended_datasets

def preprocess_dataset(dataset, tokenizer):
    """Preprocesses a dataset using a tokenizer."""
    def preprocess_function(examples, tokenizer):
        """Preprocess function for T5."""
        model_inputs = tokenizer(
            examples["source"],
            text_target=examples["target"],
            max_length=MAX_INPUT_LENGTH,
            truncation=True,
        )
        return model_inputs

    return dataset.map(
        preprocess_function,
        fn_kwargs={'tokenizer': tokenizer},
        batched=True,
        remove_columns=["source", "target"],
    )

def post_process(preds, refs, tokenizer):
    """
    Post-process function for T5.

    Args:
        preds (list): list of predicted sequences
        refs (list): list of reference sequences
        tokenizer (PreTrainedTokenizer): tokenizer to use for decoding

    Returns:
        decoded_preds (list): list of decoded predicted sequences
        decoded_refs (list): list of decoded reference sequences
    """
    # In case the model returns more than the prediction logits
    if isinstance(preds, tuple):
        preds = preds[0]

    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)

    # Replace -100s in the labels as we can't decode them
    refs = np.where(refs != -100, refs, tokenizer.pad_token_id)
    decoded_refs = tokenizer.batch_decode(refs, skip_special_tokens=True)

    # Some simple post-processing
    decoded_preds = [pred.strip() for pred in decoded_preds]
    decoded_refs = [ref.strip() for ref in decoded_refs]

    return decoded_preds, decoded_refs

def post_process_preds(preds, tokenizer):
    """
    Post-process function for T5 (only for predictions)

    Args:
        preds (list): list of predicted sequences
        tokenizer (PreTrainedTokenizer): tokenizer to use for decoding

    Returns:
        decoded_preds (list): list of decoded predicted sequences
    """
    if isinstance(preds, tuple):
        preds = preds[0]

    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    decoded_preds = [pred.strip() for pred in decoded_preds]

    return decoded_preds

def compute_metrics(eval_preds, tokenizer):
    """
    Function to calculate the metrics for trainer.evaluate().

    Args:
        tokenizer (PreTrainedTokenizer): tokenizer to use for decoding the predictions
        eval_preds (tuple): Tuple containing the predictions and references

    Returns:
        dict: Dictionary containing the metrics
    """
    preds, refs = eval_preds

    # Post-process the predictions and references
    decoded_preds, decoded_refs = post_process(preds, refs, tokenizer)
    
    # Evaluate metrics
    return evaluate_metrics(
        decoded_refs,
        decoded_preds,
        tokenizer_toxicity=tokenizer_toxicity,
        model_toxicity=model_toxicity,
        tokenizer_acceptability=tokenizer_acceptability,
        model_acceptability=model_acceptability,
        include_bleurt=INCLUDE_BLEURT
    )

def compute_metrics_bd(eval_preds, tokenizer, bd_dataset, shuffled_data=False):
    """
    Function to calculate the metrics for trainer.evaluate().
    This function is for the bi-directional model.
    
    Args:
        eval_preds (tuple): Tuple containing the predictions and references
        tokenizer (PreTrainedTokenizer): tokenizer to use for decoding the predictions
        shuffled_data (bool): Whether the data is shuffled or not
        bd_dataset (DatasetDict): Bidirectional dataset to use for testing created using create_bidirectional_datasets
                                  For example, raw_datasets_bd["validation"] or raw_datasets_bd["test"]

    Returns:
        dict: Dictionary containing the metrics
    """
    preds, refs = eval_preds

    # Post-process the predictions and references
    decoded_preds, decoded_refs = post_process(preds, refs, tokenizer)
    
    # If shuffled data is false, have to_neutral_preds and to_neutral_refs just be predictions and refs with even indices
    if not shuffled_data:
        to_neutral_preds = decoded_preds[::2]
        to_neutral_refs = decoded_refs[::2]
    # Otherwise, get the indices to use when splitting predictions and refs to to_neutral and to_toxic
    else:
        # Get the indices to use when splitting predictions and refs to to_neutral and to_toxic
        to_neutral_idx = [i for i, input_sentence in enumerate(bd_dataset['source']) if input_sentence.startswith("to_neutral")]

        # Retrieve based on the indices
        to_neutral_preds = [decoded_preds[i] for i in to_neutral_idx]
        to_neutral_refs = [decoded_refs[i] for i in to_neutral_idx]
    
    # Evaluate metrics for to_neutral
    to_neutral_metrics = evaluate_metrics(
        to_neutral_refs,
        to_neutral_preds,
        include_bleurt=INCLUDE_BLEURT
    )

    # Return dictionary of to_neutral metrics
    return to_neutral_metrics

def setup_trainer(output_dir_name,
                train_dataset,
                eval_dataset,
                compute_metrics,
                model_checkpoint="t5-small",
                per_device_train_batch_size=PER_DEVICE_TRAIN_BATCH_SIZE,
                per_device_eval_batch_size=PER_DEVICE_EVAL_BATCH_SIZE,
                learning_rate=LEARNING_RATE,
                num_train_epochs=NUM_TRAIN_EPOCHS,
                max_length=MAX_OUTPUT_LENGTH,
                num_beams=NUM_BEAMS,
                early_stopping_patience=EARLY_STOPPING_PATIENCE,
                report_to="wandb",
                ):
    """
    Set up a Seq2SeqTrainer object for training a T5 model.

    Default parameters based on this: https://github.com/google-research/text-to-text-transfer-transformer/blob/main/t5/models/hf_model.py#L55

    Args:
        output_dir_name (str): What to name the model in the output directory.
        train_dataset (Dataset): Training dataset.
        eval_dataset (Dataset): Evaluation dataset.
        compute_metrics (function): Function to compute metrics. Change this to compute_metrics_bd if using a bi-directional model.
        model_checkpoint (str): Model checkpoint to use.
        per_device_train_batch_size (int): Batch size for training.
        per_device_eval_batch_size (int): Batch size for evaluation.
        learning_rate (float): Learning rate.
        num_train_epochs (int): Number of training epochs.
        max_length (int): Maximum length of the output sequence.
        num_beams (int): Number of beams for beam search.
        early_stopping_patience (int): Number of epochs to wait before early stopping.
        report_to (str): Where to report results to. Either "wandb" or "none".

    Returns:
        Seq2SeqTrainer: Trainer object for training the T5 model.
    """
    
    # Instantiate model and tokenizer
    model = T5ForConditionalGeneration.from_pretrained(model_checkpoint)
    tokenizer = T5Tokenizer.from_pretrained(model_checkpoint)

    # Define the data collator
    data_collator = DataCollatorForSeq2Seq(tokenizer, model, return_tensors="pt", padding=True)

    # Define generation config
    generation_config = GenerationConfig(
        max_length=max_length,
        num_beams=num_beams,
        early_stopping=True,
        eos_token_id=model.config.eos_token_id,
        bos_token_id=model.config.bos_token_id,
        pad_token_id=model.config.pad_token_id,
        decoder_start_token_id=model.config.pad_token_id
        )

    # Save the generation config
    gen_config_path = f"../models/{output_dir_name}/generation_config"
    generation_config.save_pretrained(gen_config_path)

    # Define the training arguments
    args = Seq2SeqTrainingArguments(
        output_dir=f'../models/{output_dir_name}',
        evaluation_strategy="epoch",
        save_strategy="epoch",
        logging_strategy="epoch",
        num_train_epochs=num_train_epochs,
        per_device_train_batch_size=per_device_train_batch_size,
        per_device_eval_batch_size=per_device_eval_batch_size,
        learning_rate=learning_rate, 
        predict_with_generate=True,
        generation_config=gen_config_path,
        fp16=True,
        report_to=report_to,
        logging_steps=100,
        load_best_model_at_end=True,
        metric_for_best_model="Overall",
        greater_is_better=True,
        generation_max_length=max_length,
    )
   
    # Instantiate the trainer
    trainer = Seq2SeqTrainer(
        model=model,
        args=args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        data_collator=data_collator,
        tokenizer=tokenizer,
        compute_metrics=partial(compute_metrics, tokenizer=tokenizer),
        callbacks=[EarlyStoppingCallback(early_stopping_patience=early_stopping_patience)]

    )

    return trainer

def setup_trainer(output_dir_name,
                train_dataset,
                eval_dataset,
                compute_metrics,
                model_checkpoint="t5-small",
                per_device_train_batch_size=PER_DEVICE_TRAIN_BATCH_SIZE,
                per_device_eval_batch_size=PER_DEVICE_EVAL_BATCH_SIZE,
                learning_rate=LEARNING_RATE,
                num_train_epochs=NUM_TRAIN_EPOCHS,
                max_length=MAX_OUTPUT_LENGTH,
                num_beams=NUM_BEAMS,
                early_stopping_patience=EARLY_STOPPING_PATIENCE,
                report_to="wandb",
                ):
    """
    Set up a Seq2SeqTrainer object for training a T5 model.

    Default parameters based on this: https://github.com/google-research/text-to-text-transfer-transformer/blob/main/t5/models/hf_model.py#L55

    Args:
        output_dir_name (str): What to name the model in the output directory.
        train_dataset (Dataset): Training dataset.
        eval_dataset (Dataset): Evaluation dataset.
        compute_metrics (function): Function to compute metrics. Change this to compute_metrics_bd if using a bi-directional model.
        model_checkpoint (str): Model checkpoint to use.
        per_device_train_batch_size (int): Batch size for training.
        per_device_eval_batch_size (int): Batch size for evaluation.
        learning_rate (float): Learning rate.
        num_train_epochs (int): Number of training epochs.
        max_length (int): Maximum length of the output sequence.
        num_beams (int): Number of beams for beam search.
        early_stopping_patience (int): Number of epochs to wait before early stopping.
        report_to (str): Where to report results to. Either "wandb" or "none".

    Returns:
        Seq2SeqTrainer: Trainer object for training the T5 model.
    """
    
    # Instantiate model and tokenizer
    model = T5ForConditionalGeneration.from_pretrained(model_checkpoint)
    tokenizer = T5Tokenizer.from_pretrained(model_checkpoint)

    # Define the data collator
    data_collator = DataCollatorForSeq2Seq(tokenizer, model, return_tensors="pt", padding=True)

    # Define generation config
    generation_config = GenerationConfig(
        max_length=max_length,
        num_beams=num_beams,
        early_stopping=True,
        eos_token_id=model.config.eos_token_id,
        bos_token_id=model.config.bos_token_id,
        pad_token_id=model.config.pad_token_id,
        decoder_start_token_id=model.config.pad_token_id
        )

    # Save the generation config
    gen_config_path = f"../models/{output_dir_name}/generation_config"
    generation_config.save_pretrained(gen_config_path)

    # Define the training arguments
    args = Seq2SeqTrainingArguments(
        output_dir=f'../models/{output_dir_name}',
        evaluation_strategy="epoch",
        save_strategy="epoch",
        logging_strategy="epoch",
        num_train_epochs=num_train_epochs,
        per_device_train_batch_size=per_device_train_batch_size,
        per_device_eval_batch_size=per_device_eval_batch_size,
        learning_rate=learning_rate, 
        predict_with_generate=True,
        generation_config=gen_config_path,
        fp16=True,
        report_to=report_to,
        logging_steps=100,
        load_best_model_at_end=True,
        metric_for_best_model="Overall",
        greater_is_better=True,
        generation_max_length=max_length,
    )
   
    # Instantiate the trainer
    trainer = Seq2SeqTrainer(
        model=model,
        args=args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        data_collator=data_collator,
        tokenizer=tokenizer,
        compute_metrics=partial(compute_metrics, tokenizer=tokenizer),
        callbacks=[EarlyStoppingCallback(early_stopping_patience=early_stopping_patience)]

    )

    return trainer

def training_pipeline(model_name, project_name="t5-detox", model_checkpoint="t5-small", use_validation=True, raw_datasets=raw_datasets, bidirectional=False, shuffle=False, do_train=True):
    """
    Pipeline for training a T5 model. Saves the best model checkpoint to a txt file. Can also be used for evaluating a model (use test set instead of validation set).

    Args:
        model_name (str): Name of the model to name the output directory and wandb run.
        project_name (str): Name of the wandb project.
        model_checkpoint (str): Model checkpoint to use.
        use_validation (bool): Whether to use the validation set or not.
        raw_datasets (DatasetDict): DatasetDict object containing the original dataset.
        bidirectional (bool): Whether to use a bi-directional model or not.
        shuffle (bool): Whether to shuffle the dataset or not.
        do_train (bool): Whether to train the model or not.

    Returns:
        trainer (Seq2SeqTrainer): Trainer object for training the T5 model.
    """
    
    # Preprocess dataset (add prefixes / make bidirectional)
    if bidirectional:
        raw_datasets = create_bidirectional_dataset(raw_datasets, shuffle=shuffle)
    else:
        raw_datasets = add_prefix(raw_datasets)

    # Tokenize dataset
    tokenized_datasets = preprocess_dataset(raw_datasets, tokenizer_t5_small)

    # Define compute_metrics function depending on bidirectional or not
    if bidirectional and use_validation:
        bd_dataset = raw_datasets["validation"]
    elif bidirectional and not use_validation:
        bd_dataset = raw_datasets["test"]
    else:
        bd_dataset = None

    compute_metrics_fn = partial(compute_metrics_bd, bd_dataset=bd_dataset, shuffled_data=shuffle) if bd_dataset else compute_metrics

    # Setup trainer
    trainer = setup_trainer(
        output_dir_name=model_name,
        model_checkpoint=model_checkpoint,
        train_dataset=tokenized_datasets["train"],
        eval_dataset=tokenized_datasets["validation"] if use_validation else tokenized_datasets["test"],
        compute_metrics=compute_metrics_fn
    )

    if do_train:
        # Initialize wandb
        wandb.init(project=project_name, name=model_name)
        trainer.train()
        wandb.finish()

        # Get the best checkpoint path for the model
        checkpoint_path = trainer.state.best_model_checkpoint

        # Save the checkpoint path for the best model
        with open(BEST_MODEL_CHECKPOINT_PATH, "a") as file:
            file.write(f"{model_name}: {checkpoint_path}\n")

    return trainer, tokenized_datasets

# Evaluation

# Baseline: DELETE

In [37]:
delete_preds_test = baseline_detoxifier(raw_datasets["test"]["source"])
df_test = add_preds_to_df("DELETE", delete_preds_test, load_csv=False)

delete_preds_test_metrics = evaluate_metrics(raw_datasets["test"]["target"], delete_preds_test)
df_eval = add_metrics_to_df("DELETE", delete_preds_test_metrics, load_csv=False)

# Baseline: BART

In [14]:
bart_preds_test = bart_detoxifier(raw_datasets["test"]["source"])
df_test = add_preds_to_df("BART", bart_preds_test, load_csv=True)

bart_preds_test_metrics = evaluate_metrics(raw_datasets["test"]["target"], bart_preds_test)
df_eval = add_metrics_to_df(df_eval, "BART", bart_preds_test_metrics, load_csv=True)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


# T5 Models

In [59]:
def get_model_checkpoints():
    # Get checkpoint values for the best models
    with open(BEST_MODEL_CHECKPOINT_PATH, "r") as f:
        best_model_checkpoints = f.readlines()

    # Convert to a dictionary
    best_model_checkpoints_dict = {}
    for line in best_model_checkpoints:
        model_name, checkpoint_path = line.split(": ")
        best_model_checkpoints_dict[model_name] = checkpoint_path.strip()

    return best_model_checkpoints_dict

model_checkpoints = get_model_checkpoints()

In [76]:
def get_t5_preds_metrics(model_name,
                         model_checkpoint_dict=model_checkpoints,
                         raw_datasets=raw_datasets,
                         bidirectional=False,
                         shuffle=False,
                         use_validation=False,
                         do_train=False,
                         tokenizer=tokenizer_t5_small
                         ):
    """
    Returns the predictions and metrics for a T5 model.
    """
    
    # Setup training pipeline
    trainer, trainer_tokenized_ds = training_pipeline(
        model_name="n/a",
        project_name="n/a",
        model_checkpoint=model_checkpoint_dict[model_name],
        use_validation=use_validation,
        raw_datasets=raw_datasets,
        bidirectional=bidirectional,
        shuffle=shuffle,
        do_train=do_train
    )

    # Get raw predictions
    trainer_preds_raw = trainer.predict(trainer_tokenized_ds["test"])

    # Get encoded predictions and metrics
    trainer_preds_encoded, trainer_metrics = trainer_preds_raw.predictions, trainer_preds_raw.metrics

    # Post-process predictions
    if isinstance(trainer_preds_encoded, tuple):
        trainer_preds_encoded = trainer_preds_encoded[0]

    trainer_preds_decoded = tokenizer.batch_decode(trainer_preds_encoded, skip_special_tokens=True)
    trainer_preds_decoded = [pred.strip() for pred in trainer_preds_decoded]

    #Return trainer metrics in the same format as evaluate_metrics
    trainer_metrics = {
        "BLEU": trainer_metrics["test_BLEU"],
        "BLEURT": trainer_metrics["test_BLEURT"],
        "STA": trainer_metrics["test_STA"],
        "FLU": trainer_metrics["test_FLU"],
        "SEM": trainer_metrics["test_SEM"],
        "Overall": trainer_metrics["test_Overall"]
    }
        
    # Return predictions and metrics
    return trainer_preds_decoded, trainer_metrics

In [83]:
# Loop through the model checkpoints and get predictions and metrics for each
for model_name, model_path in model_checkpoints.items():
    # Get predictions and metrics
    preds, metrics = get_t5_preds_metrics(model_name)

    # Add predictions to dataframe
    df_test = add_preds_to_df(model_name, preds)

    # Add metrics to dataframe
    df_eval = add_metrics_to_df(model_name, metrics)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


# T5 Model With Negative Lexically Constrained Decoding

In [71]:
tokenizer_toxicity = RobertaTokenizer.from_pretrained("SkolkovoInstitute/roberta_toxicity_classifier")
model_toxicity = RobertaForSequenceClassification.from_pretrained("SkolkovoInstitute/roberta_toxicity_classifier")

Some weights of the model checkpoint at SkolkovoInstitute/roberta_toxicity_classifier were not used when initializing RobertaForSequenceClassification: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
- This IS expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [94]:
for sentence in raw_datasets["test"]["source"][:5]:
    print()
    print(f"Sentence: {sentence}")
    
    # Tokenize sentence
    inputs = tokenizer_toxicity(sentence, return_tensors="pt")
    input_ids = inputs["input_ids"]
    print(f"Input IDs shape: {input_ids.shape}")
    print(f"Input IDs: {input_ids}")

    # Get attention scores
    attention = model_toxicity(input_ids, output_attentions=True)['attentions']
    
    # Get the last 3 layer attention scores
    attention = attention[-3:]

    # Get the mean attention scores for the last 3 layers
    attention = torch.stack(attention).mean(0)
    print(f"Attention shape after averaging across 3 layers: {attention.shape}")

    # Average across each head
    attention = attention.mean(1)
    print(f"Attention shape after averaging across each head: {attention.shape}")

    # Sum each row to get the attention score for each token
    attention = attention.mean(1)
    print(f"Attention shape after averaging across each row: {attention.shape}")
    print(f"Attention: {attention}")

    # Get the indices of the top 3 tokens with the highest attention scores
    top_indices = attention.topk(4).indices.squeeze().tolist()
    print(f"Top indices: {top_indices}")

    # Filter input_ids to only include the top indices
    filtered_input_ids = input_ids[:, top_indices].squeeze().tolist()
    print(f"Filtered input IDs: {filtered_input_ids}")

    # Decode the filtered input IDs, skipping special tokens and outputting as a list
    bad_words = []
    for input_id in filtered_input_ids:
        bad_words.append(tokenizer_toxicity.decode(input_id, skip_special_tokens=True))
    
    print(f"Bad words: {bad_words}")


Sentence: My thoughts are that our new guy might be an even bigger dirt bag.
Input IDs shape: torch.Size([1, 17])
Input IDs: tensor([[    0,  2387,  4312,    32,    14,    84,    92,  2173,   429,    28,
            41,   190,  2671, 10667,  3298,     4,     2]])
Attention shape after averaging across 3 layers: torch.Size([1, 12, 17, 17])
Attention shape after averaging across each head: torch.Size([1, 17, 17])
Attention shape after averaging across each row: torch.Size([1, 17])
Attention: tensor([[0.0245, 0.0356, 0.0252, 0.0242, 0.0346, 0.0566, 0.0571, 0.0656, 0.0579,
         0.0584, 0.0474, 0.0600, 0.0601, 0.0947, 0.0917, 0.1030, 0.1034]],
       grad_fn=<MeanBackward1>)
Top indices: [16, 15, 13, 14]
Filtered input IDs: [2, 4, 10667, 3298]
Bad words: ['</s>', '.', ' dirt', ' bag']

Sentence: Getting stronger every month yeah i fucking wish.
Input IDs shape: torch.Size([1, 11])
Input IDs: tensor([[    0, 28750,  3651,   358,   353, 11380,   939, 23523,  2813,     4,
             2]]

In [5]:
# Use the RoBERTa toxicity classifier on each sentence of the test set to identify attention weights for each word
tokenized_inputs = tokenizer_toxicity(raw_datasets["test"]["source"], return_tensors="pt", padding=True, truncation=True).to(DEVICE)
encoded_outputs = model_toxicity(**tokenized_inputs, output_attentions=True)

In [6]:
tokenized_inputs["input_ids"][:5]

tensor([[    0,  2387,  4312,    32,    14,    84,    92,  2173,   429,    28,
            41,   190,  2671, 10667,  3298,     4,     2,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1],
        [    0, 28750,  3651,   358,   353, 11380,   939, 23523,  2813,     4,
             2,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1],
        [    0,  1711,  2173,   341,     7,    28,    10,   588, 38594,     4,
             2,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1],
        [    0,  7608,     5, 26536,   473,   961,  4157, 13550,  6374,    98,
           203,   116,     4,     2,     1,     1,     1,     1,     1,     1,
             1,  

In [7]:
# Decode first 5 inputs
tokenizer_toxicity.batch_decode(tokenized_inputs["input_ids"][:5], skip_special_tokens=True)

['My thoughts are that our new guy might be an even bigger dirt bag.',
 'Getting stronger every month yeah i fucking wish.',
 'That guy used to be a real dick.',
 'Why the fuck does everyone hate syria so much?.',
 'I think that s noble as fuck.']

In [11]:
# Get the attention weights for each word in the sentence
attentions = encoded_outputs["attentions"]
attentions = torch.cat(attentions, dim=0).view(-1, attentions[0].size(-1))
attentions[:1]

tensor([[0.8803, 0.0160, 0.0112, 0.0088, 0.0060, 0.0055, 0.0050, 0.0052, 0.0085,
         0.0085, 0.0068, 0.0078, 0.0052, 0.0076, 0.0044, 0.0050, 0.0081, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]],
       device='cuda:0', grad_fn=<SliceBackward0>)

In [12]:
# Remove the first token from each attention weight
attentions_nostart = attentions[:, 1:]
attentions_nostart[:1]

tensor([[0.0160, 0.0112, 0.0088, 0.0060, 0.0055, 0.0050, 0.0052, 0.0085, 0.0085,
         0.0068, 0.0078, 0.0052, 0.0076, 0.0044, 0.0050, 0.0081, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]], device='cuda:0',
       grad_fn=<SliceBackward0>)

In [13]:
# For the first 5 sentences, print raw sentence, tokenized sentence, and attention weights
for i in range(3):
    print()
    print(f"RAW SENTENCE: {raw_datasets['test']['source'][i]}")
    print(f"TOKENIZED SENTENCE: {tokenized_inputs['input_ids'][i][1:]}")
    print(f"ATTENTION WEIGHTS: {attentions[i]}")


RAW SENTENCE: My thoughts are that our new guy might be an even bigger dirt bag.
TOKENIZED SENTENCE: tensor([ 2387,  4312,    32,    14,    84,    92,  2173,   429,    28,    41,
          190,  2671, 10667,  3298,     4,     2,     1,     1,     1,     1,
            1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
            1,     1,     1], device='cuda:0')
ATTENTION WEIGHTS: tensor([0.8803, 0.0160, 0.0112, 0.0088, 0.0060, 0.0055, 0.0050, 0.0052, 0.0085,
        0.0085, 0.0068, 0.0078, 0.0052, 0.0076, 0.0044, 0.0050, 0.0081, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
       device='cuda:0', grad_fn=<SelectBackward0>)

RAW SENTENCE: Getting stronger every month yeah i fucking wish.
TOKENIZED SENTENCE: tensor([28750,  3651,   358,   353, 11380,   939, 23523,  2813,     4,     2,
            1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
 

## Approach A: Set top N tokens by attention

In [17]:
# Get the indices of the top 3 words with the highest attention weights, ignoring the first token
top_indices = torch.topk(attentions_nostart, 5, dim=1)
top_indices

torch.return_types.topk(
values=tensor([[0.0160, 0.0112, 0.0088, 0.0085, 0.0085],
        [0.0924, 0.0472, 0.0274, 0.0123, 0.0078],
        [0.4354, 0.1660, 0.1048, 0.0288, 0.0231],
        ...,
        [0.3704, 0.1812, 0.0534, 0.0470, 0.0442],
        [0.3704, 0.1812, 0.0534, 0.0470, 0.0442],
        [0.3704, 0.1812, 0.0534, 0.0470, 0.0442]], device='cuda:0',
       grad_fn=<TopkBackward0>),
indices=tensor([[ 0,  1,  2,  7,  8],
        [ 0,  1,  2,  7,  6],
        [ 0,  1,  2,  7,  9],
        ...,
        [14,  7,  2,  6,  9],
        [14,  7,  2,  6,  9],
        [14,  7,  2,  6,  9]], device='cuda:0'))

In [18]:
# Add 1 to each index to account for the first token
top_indices = top_indices.indices + 1
top_indices

tensor([[ 1,  2,  3,  8,  9],
        [ 1,  2,  3,  8,  7],
        [ 1,  2,  3,  8, 10],
        ...,
        [15,  8,  3,  7, 10],
        [15,  8,  3,  7, 10],
        [15,  8,  3,  7, 10]], device='cuda:0')

In [21]:
top_k_words = []

for ids, top_indices in zip(tokenized_inputs["input_ids"], top_indices):
    top_k_words.append([tokenizer_toxicity.decode(ids[idx]) for idx in top_indices])

top_k_words

[['My', ' thoughts', ' are', ' might', ' be'],
 ['Getting', ' stronger', ' every', ' wish', ' fucking'],
 ['That', ' guy', ' used', ' dick', '</s>'],
 ['Why', ' the', ' fuck', ' does', 'ria'],
 [' think', 'I', ' that', ' s', ' noble'],
 [' still', ' wife', ' gives', 'My', ' me'],
 [' to', ' refreshing', ' see', ' s', 'It'],
 [' state', ' thinks', 'His', ' of', ' secretary'],
 [' of', ' its', ' cheap', ' but', ' such'],
 [' about', ' mean', ' actually', 'is', ' the'],
 [' big', ' stick', ' a', ' to', ' smack'],
 ['?', ' suck', '</s>', '<pad>', '<pad>'],
 ['<pad>', '<pad>', '<pad>', '</s>', '<pad>'],
 ['ry', ' har', ' and', ' ll', 'oyd'],
 [' played', 'd', ' i', 'yr', ' with'],
 [' than', ',', ' open', ' fool', ' it'],
 ['</s>', 'ria', '<pad>', '.', ' sy'],
 [' they', 'Well', '<pad>', '.', '<pad>'],
 [' other', 'In', ',', ' actually', ' justice'],
 [' the', 'Get', '<pad>', ' do', '<pad>'],
 ['ont', 'D', '<pad>', ' to', '<pad>'],
 [' kind', 'What', '.', ' put', ' first'],
 [' comment', 'M

In [20]:
import torch

def get_top_k_words_per_sentence(input_ids, attentions, tokenizer, k=3):
    top_k_words = []

    # Iterate over each sentence and its corresponding attention scores
    for ids, attention in zip(input_ids, attentions):
        # Skip special tokens
        token_attention_pairs = [(tok, att) for tok, att in zip(ids, attention) if tok not in [tokenizer.pad_token_id, tokenizer.cls_token_id, tokenizer.sep_token_id]]

        # Sort by attention score in descending order and select top k
        top_tokens = sorted(token_attention_pairs, key=lambda x: x[1], reverse=True)[:k]

        # Decode tokens to words
        top_words = [tokenizer.decode([tok]) for tok, _ in top_tokens]
        top_k_words.append(top_words)

    return top_k_words

# Example usage
# Assuming you have your input_ids, attentions from the model, and tokenizer
top_k_words = get_top_k_words_per_sentence(input_ids, attentions, tokenizer, k=3)


tensor([1, 2, 3, 8, 9], device='cuda:0')
tensor([1, 2, 3, 8, 7], device='cuda:0')
tensor([ 1,  2,  3,  8, 10], device='cuda:0')
tensor([1, 2, 3, 4, 8], device='cuda:0')
tensor([2, 1, 3, 4, 5], device='cuda:0')


In [None]:
# For each line of tokenized_inputs["input_ids"], get the top 3 words with the highest attention weights
top_words = []
for i in range(len(tokenized_inputs["input_ids"])):
    top_words.append(tokenizer_toxicity.convert_ids_to_tokens(top_indices[i]))