In [36]:
from datasets import DatasetDict, Dataset
from transformers import (
    RobertaTokenizer,
    RobertaForSequenceClassification,
    T5Tokenizer,
    T5ForConditionalGeneration,
    T5Config,
    AutoTokenizer,
    AutoModelForSequenceClassification,
    Seq2SeqTrainer,
    DataCollatorForSeq2Seq,
    GenerationConfig,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    EarlyStoppingCallback,
    pipeline,
    MarianMTModel,
    MarianTokenizer,
)
import torch
import numpy as np
import evaluate
import pandas as pd
import matplotlib.pyplot as plt
from functools import partial
from typing import Dict, Union, Optional, Tuple, List, Any
import pandas as pd
from tqdm import tqdm

In [2]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Location on disk of the raw dataset dictionary
RAW_DATASET_PATH = "../data/processed/raw_dataset.pkl"

In [51]:
# Checkpoint identifiers of the two classifiers (each is used for both the tokenizer and the model)
TOXICITY_CHECKPOINT = "SkolkovoInstitute/roberta_toxicity_classifier"
ACCEPTABILITY_CHECKPOINT = "iproskurina/tda-bert-en-cola"

# Toxicity classifier: tokenizer plus model moved to DEVICE
tokenizer_toxicity = RobertaTokenizer.from_pretrained(TOXICITY_CHECKPOINT)
model_toxicity = RobertaForSequenceClassification.from_pretrained(TOXICITY_CHECKPOINT).to(DEVICE)

# Acceptability classifier: tokenizer plus model moved to DEVICE
tokenizer_acceptability = AutoTokenizer.from_pretrained(ACCEPTABILITY_CHECKPOINT)
model_acceptability = AutoModelForSequenceClassification.from_pretrained(ACCEPTABILITY_CHECKPOINT).to(DEVICE)

Some weights of the model checkpoint at SkolkovoInstitute/roberta_toxicity_classifier were not used when initializing RobertaForSequenceClassification: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
- This IS expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [6]:
# Load the raw dataset dictionary that was saved on disk at RAW_DATASET_PATH
raw_datasets = DatasetDict.load_from_disk(RAW_DATASET_PATH)

# Evaluation Functions

In [53]:
# Module-level caches for the evaluation metrics; each starts empty and is
# filled on first use by the corresponding calc_* function below
model_bleurt = model_bertscore = model_sacrebleu = None

def calc_sacrebleu(refs, preds):
    """
    Computes the SacreBLEU score of predictions against references.

    The metric is loaded once through ``evaluate.load("sacrebleu")`` and kept
    in the module-level ``model_sacrebleu`` for later calls.

    Args:
        refs (list): List of reference sentences
        preds (list): List of predicted sentences

    Returns:
        float: the "score" value reported by the metric, divided by 100
    """
    global model_sacrebleu

    # Load the metric lazily so it is only fetched when first needed
    if model_sacrebleu is None:
        model_sacrebleu = evaluate.load("sacrebleu")

    metric_output = model_sacrebleu.compute(predictions=preds, references=refs)

    # Rescale the reported score by a factor of 100
    return metric_output["score"] / 100

def calc_bert_score(
    refs, preds, model_type="microsoft/deberta-large-mnli", output_mean=True
    ):
    """
    Computes BERTScore precision, recall and F1 of predictions against references.

    Note: https://docs.google.com/spreadsheets/d/1RKOVpselB98Nnh_EOC4A2BYn8_201tmPODpNWu4w7xI/edit#gid=0 lists the best performing models

    The metric is loaded once through ``evaluate.load("bertscore")`` and kept
    in the module-level ``model_bertscore`` for later calls.

    Args:
        refs (list): List of reference sentences.
        preds (list): List of predicted sentences.
        model_type (str): Model identifier forwarded to the metric's compute call.
        output_mean (bool): Whether to average each score over all sentences.

    Returns:
        tuple: (precision, recall, f1). Each entry is a numpy array with one
        value per sentence, or its mean when output_mean is True.
    """
    global model_bertscore

    # Load the metric lazily so it is only fetched when first needed
    if model_bertscore is None:
        model_bertscore = evaluate.load("bertscore")

    scores = model_bertscore.compute(predictions=preds, references=refs, model_type=model_type)
    precision, recall, f1 = (np.array(scores[key]) for key in ("precision", "recall", "f1"))

    if not output_mean:
        return precision, recall, f1

    return precision.mean(), recall.mean(), f1.mean()

def calc_bleurt(refs, preds, checkpoint="BLEURT-20_D12", output_mean = True):
    """
    Calculates BLEURT scores of predictions against references.

    The loaded metric is cached in the module-level ``model_bleurt``. The
    checkpoint it was loaded with is remembered on the function object, so
    asking for a different checkpoint reloads the metric instead of silently
    reusing the one loaded by an earlier call.

    Args:
        refs (list): List of reference sentences.
        preds (list): List of predicted sentences.
        checkpoint (str): BLEURT checkpoint name forwarded to ``evaluate.load``.
        output_mean (bool): Whether to return the mean of the scores instead of
            one score per sentence.

    Returns:
        Mean BLEURT score when output_mean is True, otherwise a numpy array
        with one score per sentence.
    """
    global model_bleurt

    # Reload when nothing is cached yet or when the cached metric was built
    # from a different checkpoint than the one requested now
    loaded_checkpoint = getattr(calc_bleurt, "_loaded_checkpoint", None)
    if model_bleurt is None or loaded_checkpoint != checkpoint:
        model_bleurt = evaluate.load("bleurt", module_type="metric", checkpoint=checkpoint)
        calc_bleurt._loaded_checkpoint = checkpoint

    results = np.array(model_bleurt.compute(predictions=preds, references=refs)["scores"])

    if output_mean:
        results = results.mean()

    return results

def calc_tox_acceptability(
    data,
    tokenizer,
    model,
    output_score=True,
    output_mean=True,
    batch_size=None):
    """
    Scores sentences with a sequence-classification model (used here for both
    the toxicity and the acceptability classifier).

    Args:
        data = list of strings to be evaluated (a single string is treated as one sentence)
        tokenizer = tokenizer for the model
        model = model to be used for evaluation
        output_score = whether to output the score (softmax probability of
            class index 1) or the label (argmax over the logits)
        output_mean = whether to output the mean of the scores or the scores for each sentence
        batch_size = number of sentences per forward pass; None (default)
            processes everything in a single batch

    Returns:
        numpy array with one value per sentence, or its mean when output_mean
        is True. Empty input gives an empty array (NaN when output_mean is True).

    Raises:
        ValueError: if batch_size is given and smaller than 1.
    """
    # Wrap a lone string so it is not split into characters below
    texts = [data] if isinstance(data, str) else list(data)

    if batch_size is not None and batch_size < 1:
        raise ValueError("batch_size must be a positive integer or None")

    # Nothing to score: avoid calling the tokenizer with an empty batch
    if not texts:
        return float("nan") if output_mean else torch.empty(0).numpy()

    step = len(texts) if batch_size is None else batch_size
    per_batch = []

    with torch.no_grad():
        for start in range(0, len(texts), step):
            # truncation=True keeps over-long sentences within the length the
            # tokenizer is configured for instead of failing inside the model
            inputs = tokenizer(
                texts[start:start + step],
                return_tensors="pt",
                padding=True,
                truncation=True,
            ).to(DEVICE)
            logits = model(**inputs)["logits"]
            if output_score:
                batch_result = torch.nn.functional.softmax(logits, dim=1)[:, 1]
            else:
                batch_result = logits.argmax(1)
            per_batch.append(batch_result.cpu())

    result = torch.cat(per_batch).numpy()

    if output_mean:
        result = result.mean()

    return result

def evaluate_metrics(
    refs,
    preds,
    tokenizer_toxicity=tokenizer_toxicity,
    model_toxicity=model_toxicity,
    tokenizer_acceptability=tokenizer_acceptability,
    model_acceptability=model_acceptability,
    to_neutral=True,
    weights=None,
    include_bleurt=False
):
    """
    Calculates and returns a dictionary of evaluation metrics

    Args:
        refs (list): list of strings (reference)
        preds (list): list of strings (predictions)
        tokenizer_toxicity (tokenizer): tokenizer for toxicity model
        model_toxicity (model): toxicity model
        tokenizer_acceptability (tokenizer): tokenizer for acceptability model
        model_acceptability (model): acceptability model
        to_neutral (bool): whether the goal is to transfer to neutral (True) or to toxic (False)
        weights (dict): weights of the composite score, keyed by "BLEU", "STA",
            "Acceptability" and "BERT_Score". None (default) uses
            0.2 / 0.4 / 0.2 / 0.2 respectively.
        include_bleurt (bool): whether to include BLEURT score in the output

    Returns:
        results (dict): dictionary with the keys "BLEU", "STA", "FLU", "SEM"
            and "Overall", plus "BLEURT" when include_bleurt is True. BLEURT
            is reported only; it is not part of "Overall".
    """
    # Build the default weights per call rather than sharing one dict object
    # across all calls through a mutable default argument
    if weights is None:
        weights = {
            "BLEU": 0.2,
            "STA": 0.4,
            "Acceptability": 0.2,
            "BERT_Score": 0.2
        }

    # Calculate BLEU score
    bleu = calc_sacrebleu(refs, preds)

    # Predicted toxicity label for every prediction
    tox_pred = calc_tox_acceptability(preds, tokenizer_toxicity, model_toxicity, output_score=False, output_mean=False)

    # Style transfer accuracy: share of predictions whose label matches the
    # transfer goal (label 0 when transferring to neutral, label 1 otherwise)
    sta_correct_label = 0 if to_neutral else 1
    sta_pred = (tox_pred == sta_correct_label).sum() / len(tox_pred)

    # Mean acceptability score of the predictions
    acc_pred = calc_tox_acceptability(preds, tokenizer_acceptability, model_acceptability)

    # Similarity to the references: mean BERTScore F1 (third element of the result)
    bert_score_f1 = calc_bert_score(refs, preds, model_type="distilbert-base-uncased")[2]

    # Calculate BLEURT score if include_bleurt is True
    bleurt = None
    if include_bleurt:
        bleurt = calc_bleurt(refs, preds)

    # Weighted sum of the four core metrics
    composite_score = (
        weights["BLEU"] * bleu
        + weights["STA"] * sta_pred
        + weights["Acceptability"] * acc_pred
        + weights["BERT_Score"] * bert_score_f1
    )

    # Return a dictionary of metrics
    results = {
        "BLEU": bleu,
        "STA": sta_pred,
        "FLU": acc_pred,
        "SEM": bert_score_f1,
        "Overall": composite_score,
    }
    if include_bleurt:
        results["BLEURT"] = bleurt

    return results


# Back-Translation Functions

In [None]:
def download(model_name):
    """Load the Marian tokenizer and translation model for a checkpoint.

    Args:
        model_name (str): checkpoint identifier passed to ``from_pretrained``.

    Returns:
        tuple: (tokenizer, model), in that order.
    """
    return (
        MarianTokenizer.from_pretrained(model_name),
        MarianMTModel.from_pretrained(model_name),
    )

# Checkpoints for the two translation directions
EN_TO_ROMANCE_CHECKPOINT = 'Helsinki-NLP/opus-mt-en-ROMANCE'
ROMANCE_TO_EN_CHECKPOINT = 'Helsinki-NLP/opus-mt-ROMANCE-en'

# English -> Romance: tokenizer and model for the intermediate translation
tmp_lang_tokenizer, tmp_lang_model = download(EN_TO_ROMANCE_CHECKPOINT)

# Romance -> English: tokenizer and model for translating back
src_lang_tokenizer, src_lang_model = download(ROMANCE_TO_EN_CHECKPOINT)

In [37]:
def translate(texts, model, tokenizer, language):
    """Translate texts into a target language.

    Args:
        texts (iterable of str): sentences to translate.
        model: translation model exposing ``generate``.
        tokenizer: tokenizer matching ``model``.
        language (str): target language code. "en" leaves each text as is;
            any other code is prepended to each text as a ``>>code<<`` tag.

    Returns:
        list of str: one translated text per input text; an empty list for
        empty input.
    """
    # The target-language tag is identical for every text, so build it once
    prefix = "" if language == "en" else f">>{language}<< "
    formatted_texts = [f"{prefix}{txt}" for txt in texts]

    # Nothing to translate: avoid calling the tokenizer with an empty batch
    if not formatted_texts:
        return []

    # Tokenize (text to tokens)
    tokens = tokenizer(formatted_texts, return_tensors="pt", padding=True, truncation=True)

    # Keep the inputs on the same device as the model so that generation also
    # works after the model has been moved off the CPU
    device = getattr(model, "device", None)
    if device is not None:
        tokens = tokens.to(device)

    # Translate
    translated = model.generate(**tokens)

    # Decode (tokens to text)
    return tokenizer.batch_decode(translated, skip_special_tokens=True)

def back_translate(texts, language_src, language_dst):
    """Back-translate texts through one or more intermediate languages.

    Each text is translated into an intermediate language with the
    module-level ``tmp_lang_model`` / ``tmp_lang_tokenizer`` and then
    translated back with ``src_lang_model`` / ``src_lang_tokenizer``.

    Args:
        texts (iterable of str): sentences to back-translate.
        language_src (str): language code passed to ``translate`` for the
            return translation.
        language_dst (str or list of str): intermediate language code(s).

    Returns:
        list of str: back-translated texts ordered text by text, with all
        intermediate languages of one text before the next text
        (length = number of texts * number of languages).
    """
    # Accept a single language code as well as a list of codes; iterating a
    # bare string would otherwise yield its individual characters
    if isinstance(language_dst, str):
        language_dst = [language_dst]

    all_back_translated_texts = []

    # Iterate over each text rather than each language
    for text in tqdm(texts):
        for lang in language_dst:
            with torch.no_grad():
                # Source -> intermediate language -> back to source language
                translated = translate([text], tmp_lang_model, tmp_lang_tokenizer, lang)
                back_translated = translate(translated, src_lang_model, src_lang_tokenizer, language_src)

            all_back_translated_texts.extend(back_translated)

    return all_back_translated_texts

In [46]:
# Set languages
languages = ['fr', 'es', 'it', 'pt', 'ro']
num_languages = len(languages)

# Work on a small sample: the first 10 pairs of the training split
raw_datasets_train_copy = raw_datasets['train'].select(range(10))
source = raw_datasets_train_copy['source']
target = raw_datasets_train_copy['target']

# Repeat every sentence once per language so the rows line up with the output
# of back_translate, which emits all languages of one text before the next text
source_replicated = [sentence for sentence in source for _ in range(num_languages)]
target_replicated = [sentence for sentence in target for _ in range(num_languages)]

# Create a pandas dataframe
df = pd.DataFrame({'source': source_replicated, 'target': target_replicated})

# Back translate the sentences and add to pandas dataframe
df['source_bt'] = back_translate(source, "en", languages)
df['target_bt'] = back_translate(target, "en", languages)

# Drop rows whose back-translation equals the original sentence, then rows
# whose back-translation duplicates another row's
print(f"Before deleting duplicates: {len(df)}")
df = df[df['source'] != df['source_bt']]
df = df[df['target'] != df['target_bt']]
df = df.drop_duplicates(subset=['source_bt'])
df = df.drop_duplicates(subset=['target_bt'])
# Take an independent copy so the column assignments below do not write into
# a slice of the unfiltered frame (avoids SettingWithCopyWarning)
df = df.copy()
print(f"After deleting duplicates: {len(df)}")

# Plain lists for the scoring functions: after filtering, the frame's index is
# no longer contiguous, so positional access on the Series is not reliable
source_bt_texts = df['source_bt'].tolist()
target_bt_texts = df['target_bt'].tolist()

# Calculate toxicity labels for the candidate sentence pairs
df['source_bt_toxicity'] = calc_tox_acceptability(source_bt_texts, tokenizer_toxicity, model_toxicity, output_score=False, output_mean=False)
df['target_bt_toxicity'] = calc_tox_acceptability(target_bt_texts, tokenizer_toxicity, model_toxicity, output_score=False, output_mean=False)

# Calculate acceptability labels for the candidate sentence pairs
df['source_bt_acceptability'] = calc_tox_acceptability(source_bt_texts, tokenizer_acceptability, model_acceptability, output_score=False, output_mean=False)
df['target_bt_acceptability'] = calc_tox_acceptability(target_bt_texts, tokenizer_acceptability, model_acceptability, output_score=False, output_mean=False)

# Calculate similarity scores for the candidate sentence pairs - keep the F1 score
df['bt_similarity'] = calc_bert_score(source_bt_texts, target_bt_texts, model_type="distilbert-base-uncased", output_mean=False)[2]

# Create filters for the candidate sentence pairs
## Filter 1: Toxicity - source labelled 1 and target labelled 0 by the toxicity classifier
df['f_toxicity'] = (df['source_bt_toxicity'] == 1) & (df['target_bt_toxicity'] == 0)

## Filter 2: Acceptability - both sides labelled 1 by the acceptability classifier
df['f_acceptability'] = (df['source_bt_acceptability'] == 1) & (df['target_bt_acceptability'] == 1)

## Filter 3: Similarity - BERTScore F1 between the two back-translations above 0.9
df['f_similarity'] = (df['bt_similarity'] > 0.9)

# Save the dataframe as a pickle file
df.to_pickle("../data/processed/back_translated_dataset.pkl")

# Create different datasets based on the filters (all three, or all but one)
df_all = df[df['f_toxicity'] & df['f_acceptability'] & df['f_similarity']]
df_nosta = df[df['f_acceptability'] & df['f_similarity']]
df_noflu = df[df['f_toxicity'] & df['f_similarity']]
df_nosem = df[df['f_toxicity'] & df['f_acceptability']]

Unnamed: 0,source,target
0,"Barnum, were right : theres a sucker born ever...","Barnum, were right : theres a baby born every ..."
1,"Barnum, were right : theres a sucker born ever...","Barnum, were right : theres a baby born every ..."
2,"Barnum, were right : theres a sucker born ever...","Barnum, were right : theres a baby born every ..."
3,"Barnum, were right : theres a sucker born ever...","Barnum, were right : theres a baby born every ..."
4,"Barnum, were right : theres a sucker born ever...","Barnum, were right : theres a baby born every ..."
5,So dont try to act like you know what the fuck...,So dont try to act like you know what is going...
6,So dont try to act like you know what the fuck...,So dont try to act like you know what is going...
7,So dont try to act like you know what the fuck...,So dont try to act like you know what is going...
8,So dont try to act like you know what the fuck...,So dont try to act like you know what is going...
9,So dont try to act like you know what the fuck...,So dont try to act like you know what is going...
