In [1]:
import nltk
nltk.download("punkt")

[nltk_data] Downloading package punkt to /home/user/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


True

In [2]:
import evaluate
import json
import pandas as pd
import torch as tt

from datasets import load_dataset, Dataset
from nltk.tokenize import sent_tokenize
from transformers import T5Tokenizer, T5ForConditionalGeneration
from transformers import DataCollatorForSeq2Seq, Seq2SeqTrainer, Seq2SeqTrainingArguments
from transformers import PreTrainedModel, PreTrainedTokenizer
from typing import Any, Dict
from tqdm import tqdm_notebook

In [3]:
# models:
tokenizer = T5Tokenizer.from_pretrained("ai-forever/ruT5-base")
model = T5ForConditionalGeneration.from_pretrained("RuT5-MuSeRC-DG/checkpoint-14500")
model = model.to(tt.device("cuda:0"))

# metrics:
bleu4 = evaluate.load("bleu")
sbleu = evaluate.load("sacrebleu")
rouge = evaluate.load("rouge")
meteor = evaluate.load("meteor")

You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
[nltk_data] Downloading package wordnet to /home/user/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package punkt to /home/user/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package omw-1.4 to /home/user/nltk_data...
[nltk_data]   Package omw-1.4 is already up-to-date!


Load dataset:

In [4]:
def to_dg_format(dataset: list[dict[str, Any]]) -> list[dict[str, Any]]:
    dataset_processed = []
    item_id = 0

    for iidx, item in enumerate(dataset):
        for question in item["passage"]["questions"]:
            new_item = {
                "item_id": item_id,
                "passage_id": item["idx"],
                "passage": item["passage"]["text"],
                "question": question["question"],
                "distractors": ';'.join(
                    [
                        f'"{answer["text"]}"' for answer in question["answers"] if answer["label"] == 0
                    ]
                ),
                "right_answer": [
                    answer["text"] for answer in question["answers"] if answer["label"] == 1
                ][0]
            }
            dataset_processed.append(new_item)
            item_id += 1

    return dataset_processed


def to_dg_format_final(dataset: list[dict[str, Any]]) -> list[dict[str, Any]]:
    new_dataset = []

    for item in dataset:
        new_item = {
            "item_id": item["item_id"],
            "passage_id": item["passage_id"],
            "inp": f'{item["passage"]} ВОПРОС: {item["question"]} ПРАВИЛЬНЫЙ ОТВЕТ: {item["right_answer"]} НЕПРАВИЛЬНЫЕ ВАРИАНТЫ ОТВЕТА: ',
            "outp": item["distractors"],
            "outp_len": len(tokenizer(item["distractors"])["input_ids"])
        }
        new_dataset.append(new_item)

    return new_dataset


muserc_train = pd.read_json("MuSeRC/train.jsonl", lines=True).to_dict(orient="records")
muserc_val = pd.read_json("MuSeRC/val.jsonl", lines=True).to_dict(orient="records")
muserc_train_dg = Dataset.from_list(to_dg_format_final(to_dg_format(muserc_train)))
muserc_val_dg = Dataset.from_list(to_dg_format_final(to_dg_format(muserc_val)))

In [5]:
pd.Series(muserc_train_dg["outp_len"]).describe()

count    2897.000000
mean       23.775285
std        12.870540
min         1.000000
25%        15.000000
50%        21.000000
75%        29.000000
max       101.000000
dtype: float64

In [6]:
MAX_LEN = int(pd.Series(muserc_train_dg["outp_len"]).quantile(0.99))
MAX_LEN

69

In [7]:
def get_metric_inputs_seq2seq(
    input_batch: list[str],
    model: PreTrainedModel, tokenizer: PreTrainedTokenizer
) -> list[str]:
    input_batch_ = tokenizer(
        input_batch,
        return_tensors="pt",
        padding=True
    )["input_ids"].to(tt.device("cuda:0"))

    with tt.no_grad():
        output_batch = model.generate(input_batch_, max_length=MAX_LEN)

    output = [
        sent.replace("<pad>", " ").replace("</s>", " ").strip() for sent in tokenizer.batch_decode(
            output_batch)
    ]
    
    del input_batch_
    del output_batch
    tt.cuda.empty_cache()

    return output

def compute_metrics(output: list[str], label_batch: list[str]) -> dict:
    metric_dict = {
        "bleu": bleu4.compute(predictions=output, references=[[label] for label in label_batch]),
        "sbleu": sbleu.compute(predictions=output, references=[[label] for label in label_batch]),
        "rouge": rouge.compute(predictions=output, references=label_batch),
        "meteor": meteor.compute(predictions=output, references=label_batch)
    }
    return metric_dict

In [8]:
def compute_metrics_on_dataset_seq2seq(
    dataset: Dataset, model: PreTrainedModel=model,
    tokenizer: PreTrainedTokenizer=tokenizer
) -> pd.DataFrame:
    batch_size = 1

    n_steps = (len(dataset) // batch_size) + 1
    metrics = []

    for i in tqdm_notebook(range(n_steps), total=n_steps):
        slice = dataset[i*batch_size:(i+1)*batch_size]
        if slice["inp"]:
            output = get_metric_inputs_seq2seq(slice["inp"], model, tokenizer)
            distractors = [
                item.replace('\n', '').replace('  ',' ').replace('  ',' ').strip()
                for item in slice["outp"]
            ]
            if len(distractors[0]) > 0:
                metric = compute_metrics(output, distractors)
                metrics.append({
                    "item_id": slice["item_id"][0],
                    "passage_id": slice["passage_id"][0],
                    "inp": slice["inp"][0],
                    "distractors": distractors[0],
                    "output": output[0],
        
                    "bleu": metric["bleu"]["bleu"],
                    "sbleu": metric["sbleu"]["score"],
                    "rouge1": metric["rouge"]["rouge1"],
                    "rouge2": metric["rouge"]["rouge2"],
                    "rougeL": metric["rouge"]["rougeL"],
                    "rougeLsum": metric["rouge"]["rougeLsum"],
                    "meteor": metric["meteor"]["meteor"],
                })

    return pd.DataFrame(metrics)

In [9]:
METRIC_COLS = [
    "bleu", "sbleu", "rouge1", "rouge2",
    "rougeL", "rougeLsum", "meteor"
]

In [None]:
metrics_muserc_train = compute_metrics_on_dataset_seq2seq(muserc_train_dg)

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for i in tqdm_notebook(range(n_steps), total=n_steps):


  0%|          | 0/2898 [00:00<?, ?it/s]

In [None]:
metrics_muserc_train

In [16]:
metrics_muserc_train[METRIC_COLS].describe()

Unnamed: 0,bleu,sbleu,rouge1,rouge2,rougeL,rougeLsum,meteor
count,2896.0,2896.0,2896.0,2896.0,2896.0,2896.0,2896.0
mean,0.742225,74.349,0.11141,0.049719,0.110692,0.110692,0.839152
std,0.256326,25.30154,0.285275,0.197448,0.283919,0.283919,0.18025
min,0.0,0.696273,0.0,0.0,0.0,0.0,0.105167
25%,0.576674,57.667354,0.0,0.0,0.0,0.0,0.731129
50%,0.763857,76.385667,0.0,0.0,0.0,0.0,0.88697
75%,1.0,100.0,0.0,0.0,0.0,0.0,0.9995
max,1.0,100.0,1.0,1.0,1.0,1.0,0.999995


In [None]:
metrics_muserc_train.to_csv("metrics_muserc_train.csv", sep=';')

In [None]:
metrics_muserc_val = compute_metrics_on_dataset_seq2seq(muserc_val_dg)
metrics_muserc_val.to_csv("metrics_muserc_val.csv", sep=';')

In [17]:
metrics_muserc_val[METRIC_COLS].describe()

Unnamed: 0,bleu,sbleu,rouge1,rouge2,rougeL,rougeLsum,meteor
count,528.0,528.0,528.0,528.0,528.0,528.0,528.0
mean,0.193488,20.123265,0.017784,0.006875,0.017152,0.017152,0.457789
std,0.141881,13.300166,0.100311,0.06504,0.09722,0.09722,0.163195
min,0.0,0.112528,0.0,0.0,0.0,0.0,0.11094
25%,0.101119,10.55267,0.0,0.0,0.0,0.0,0.337701
50%,0.17695,17.694975,0.0,0.0,0.0,0.0,0.44449
75%,0.274049,27.404852,0.0,0.0,0.0,0.0,0.575006
max,0.750624,75.062385,1.0,1.0,1.0,1.0,0.903748
