In [1]:
import re, json

from typing import Union, Any

import pandas as pd
import torch as tt

from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import TrainingArguments, Trainer, DataCollatorForLanguageModeling
from transformers import PreTrainedModel, PreTrainedTokenizer, EvalPrediction
from datasets import Dataset

In [2]:
# Load the Russian GPT-2-style checkpoint and its tokenizer.
tokenizer = AutoTokenizer.from_pretrained("ai-forever/rugpt3small_based_on_gpt2")
model = AutoModelForCausalLM.from_pretrained("ai-forever/rugpt3small_based_on_gpt2")
# Reuse EOS as the padding token so batched tokenization can pad
# (presumably the checkpoint ships without a pad token — confirm).
# NOTE(review): with pad == eos, a pad-masking LM collator also masks real EOS tokens.
tokenizer.pad_token = tokenizer.eos_token

  return self.fget.__get__(instance, owner)()


In [3]:
# Sanity check: the pad token should now be the EOS token.
tokenizer.pad_token

'</s>'

In [4]:
# Load the pre-split dataset: a JSON object with "train", "val" and "test" lists.
with open("title_dataset_pretty_filtered.json", 'r', encoding="utf8") as inp:
    title_dataset = json.load(inp)

In [5]:
# Wrap each raw split (a list of example dicts) in a Hugging Face Dataset.
title_dataset_train, title_dataset_val, title_dataset_test = (
    Dataset.from_list(title_dataset[split_name])
    for split_name in ("train", "val", "test")
)

In [6]:
# Number of examples in (train, val, test).
tuple(len(split) for split in (title_dataset_train, title_dataset_val, title_dataset_test))

(4375, 219, 242)

In [7]:
# Distinct answer letters in the training split.
{example["answer"] for example in title_dataset_train}

{'A', 'B', 'C', 'D'}

In [8]:
# Distinct answer letters in the validation split.
{example["answer"] for example in title_dataset_val}

{'A', 'B', 'C', 'D'}

In [9]:
# Distinct answer letters in the test split.
{example["answer"] for example in title_dataset_test}

{'A', 'B', 'C', 'D'}

In [10]:
# Answer letter -> position of that option in the ORIGINAL ``options_ru`` list.
option_id_dict = {
    'A': 0, 'B': 1, 'C': 2, 'D': 3
}

def to_new_format(example: dict[str, Union[str, list[str]]]) -> dict[str, Union[str, list[str]]]:
    """Build the prompt/target texts for one title question.

    Args:
        example: a record with ``article_ru`` (str), ``options_ru`` (list of
            str, possibly containing empty strings) and ``answer`` (a letter
            key of ``option_id_dict``).

    Returns:
        A dict that ``Dataset.map`` merges into the example:
        - ``options_ru``: the options with empty strings removed;
        - ``inp_gen``: article + question + right answer + the header line for
          wrong options — the generation prompt;
        - ``distractors``: every non-empty option whose text differs from the
          right answer, each prefixed with ``"\\n  "``;
        - ``inp``: ``inp_gen + distractors`` — the full training text.

    The input ``example`` is not modified.
    """
    options = example["options_ru"]
    # Resolve the answer on the *unfiltered* list: the letter indexes the
    # original option positions. Filtering empty options first would shift
    # the index and pick the wrong option (or raise IndexError).
    right_answer = options[option_id_dict[example["answer"]]]
    non_empty_options = [option for option in options if option]

    inp_gen = (
        example["article_ru"] + "\n"
        + "ВОПРОС: Какое название лучше всего подойдёт для этого текста? "
        + f"ПРАВИЛЬНЫЙ ОТВЕТ: {right_answer}"
        + "\nНЕПРАВИЛЬНЫЕ ВАРИАНТЫ ОТВЕТА:"
    )
    # Compared by text, so an option duplicating the right answer is skipped too.
    distractors = "".join(
        f"\n  {option}" for option in non_empty_options if option != right_answer
    )
    return {
        "options_ru": non_empty_options,
        "inp": inp_gen + distractors,
        "distractors": distractors,
        "inp_gen": inp_gen,
    }

In [11]:
# Add the "inp" / "distractors" / "inp_gen" text columns to every split.
title_dataset_train, title_dataset_val, title_dataset_test = [
    split.map(to_new_format)
    for split in (title_dataset_train, title_dataset_val, title_dataset_test)
]

Map:   0%|          | 0/4375 [00:00<?, ? examples/s]

Map:   0%|          | 0/219 [00:00<?, ? examples/s]

Map:   0%|          | 0/242 [00:00<?, ? examples/s]

In [12]:
def preprocess_function(examples):
    """Tokenize a batch of formatted examples for causal-LM training.

    Args:
        examples: a batched mapping (as given by ``Dataset.map(..., batched=True)``)
            with string columns ``"inp"`` and ``"distractors"``.

    Returns:
        The tokenizer output for ``"inp"`` plus:
        - ``labels``: an independent per-row copy of ``input_ids``;
        - ``distractor_tokens``: token ids of the distractor text alone (in
          this notebook only used to measure target lengths).

    Uses the notebook-global ``tokenizer``.
    """
    model_inputs = tokenizer(examples["inp"])
    # ``list.copy()`` on the outer list would still share every inner id
    # list between ``input_ids`` and ``labels``; copy each row instead.
    # NOTE(review): DataCollatorForLanguageModeling(mlm=False) rebuilds labels
    # from input_ids at collation time, so this column is mostly informational.
    model_inputs["labels"] = [list(ids) for ids in model_inputs["input_ids"]]
    model_inputs["distractor_tokens"] = tokenizer(examples["distractors"])["input_ids"]
    return model_inputs

In [13]:
# Tokenize every split in batches (adds input_ids, attention_mask, labels, distractor_tokens).
title_dataset_train, title_dataset_val, title_dataset_test = [
    split.map(preprocess_function, batched=True)
    for split in (title_dataset_train, title_dataset_val, title_dataset_test)
]

Map:   0%|          | 0/4375 [00:00<?, ? examples/s]

Map:   0%|          | 0/219 [00:00<?, ? examples/s]

Map:   0%|          | 0/242 [00:00<?, ? examples/s]

In [14]:
# Token length of each training example's distractor text.
outp_lengths = pd.Series(
    [len(token_ids) for token_ids in title_dataset_train["distractor_tokens"]]
)

In [15]:
# Generation budget for distractors: the 99th percentile of distractor token
# counts in the training split. ``interpolation="higher"`` returns an observed
# length (never below the linear quantile), so it converts exactly to ``int``;
# a float here would leak into ``model.generate``'s token-count arguments.
MAX_OUTP_LEN = int(outp_lengths.quantile(0.99, interpolation="higher"))

In [16]:
# Inspect the generation budget (in tokens) computed above.
MAX_OUTP_LEN

45.0

In [17]:
def cut_last_break(input_: list[str]) -> list[str]:
    """Keep only the complete lines of each generated string.

    Everything from the last ``"\\n"`` onward is removed, so a generation cut
    off by the token budget does not contribute a half-written option. A
    string without any line break consists solely of that trailing line and
    becomes ``""``.

    Args:
        input_: decoded generations.

    Returns:
        A new list with one trimmed string per input.
    """
    # ``rpartition`` yields ("", "", text) when there is no newline; the old
    # ``text[:text.rfind("\n")]`` turned -1 into ``text[:-1]`` and silently
    # chopped only the last character.
    return [text.rpartition("\n")[0] for text in input_]

def parse_options(input_: list[str]) -> list[str]:
    """Normalise newline-separated option lists so they can be compared.

    For each string: split on newlines, strip every option, drop empty
    options and duplicates, sort alphabetically, keep at most three, and join
    back with ``"\\n"``.

    NOTE(review): truncation happens *after* sorting, so when more than three
    options are present the alphabetically-first three are kept, not the
    first three in the text.

    Args:
        input_: raw option blocks (generated or reference).

    Returns:
        One normalised string per input.
    """
    parsed = []
    for text in input_:
        unique_options = {option.strip() for option in text.split("\n")}
        # Blank lines would otherwise survive as "" and, sorting first,
        # occupy one of the three slots.
        unique_options.discard("")
        parsed.append("\n".join(sorted(unique_options)[:3]))
    return parsed

def get_metric_inputs(
    input_batch: list[str], label_batch: list[str],
    model: PreTrainedModel, tokenizer: PreTrainedTokenizer,
    max_new_tokens: Union[int, None] = None,
) -> list[str]:
    """Generate distractor options for a batch of prompts and normalise them.

    Args:
        input_batch: prompts to continue (the ``inp_gen`` texts).
        label_batch: reference distractors; accepted for interface
            compatibility but not used during generation.
        model: causal LM used for generation (inputs go to ``model.device``).
        tokenizer: tokenizer matching ``model``; must have a pad token.
        max_new_tokens: generation budget per prompt; defaults to the
            notebook-global ``MAX_OUTP_LEN``.

    Returns:
        One string per prompt: the complete generated lines (see
        ``cut_last_break``) normalised by ``parse_options``.
    """
    if not input_batch:
        return []
    if max_new_tokens is None:
        max_new_tokens = int(MAX_OUTP_LEN)

    # Decoder-only models must be left-padded for batched generation; with
    # right padding shorter prompts would be continued after their pad tokens.
    original_padding_side = tokenizer.padding_side
    tokenizer.padding_side = "left"
    try:
        encoded = tokenizer(input_batch, return_tensors="pt", padding=True).to(model.device)
    finally:
        tokenizer.padding_side = original_padding_side

    prompt_length = encoded["input_ids"].shape[-1]
    with tt.no_grad():
        # Passing the attention mask keeps pad (== EOS) positions out of attention.
        generated = model.generate(
            **encoded,
            max_new_tokens=max_new_tokens,
            pad_token_id=tokenizer.pad_token_id,
        )

    # With left padding every prompt ends at the same column, so slicing off
    # the first ``prompt_length`` tokens leaves only the continuation.
    continuations = tokenizer.batch_decode(
        generated[:, prompt_length:], skip_special_tokens=True
    )
    return parse_options(cut_last_break(continuations))

def compute_metric_values(output: list[str], label_batch: list[str]) -> dict[str, Any]:
    """Score generated option blocks against reference option blocks.

    Args:
        output: generated strings, one per example.
        label_batch: reference strings aligned with ``output``.

    Returns:
        ``{"bleu", "sbleu", "rouge", "meteor"}`` mapped to each metric's raw
        ``compute(...)`` result.

    NOTE(review): relies on globals ``bleu4``, ``sbleu``, ``rouge`` and
    ``meteor``, none of which is defined in the visible notebook (presumably
    ``evaluate.load(...)`` objects) — define them before calling.
    """
    # The BLEU variants get each reference wrapped in its own one-element list
    # (presumably their list-of-references-per-prediction format).
    wrapped_refs = [[label] for label in label_batch]
    return {
        "bleu": bleu4.compute(predictions=output, references=wrapped_refs),
        "sbleu": sbleu.compute(predictions=output, references=wrapped_refs),
        "rouge": rouge.compute(predictions=output, references=label_batch),
        "meteor": meteor.compute(predictions=output, references=label_batch),
    }

In [18]:
def compute_metrics(eval_preds: EvalPrediction) -> dict[str, Any]:
    """Generation-based metrics on the validation split, as a ``Trainer`` hook.

    ``eval_preds`` is ignored: the scores compare free-running generations
    with the reference distractors, which teacher-forced logits cannot give.
    With ``prediction_loss_only=True`` (the current training args) the
    Trainer does not call this hook — evaluation reports only ``eval_loss``.

    Relies on notebook globals: ``title_dataset_val``, ``BATCH_SIZE``,
    ``model``, ``tokenizer``, ``get_metric_inputs``, ``parse_options`` and
    ``compute_metric_values`` (which in turn needs its metric objects).

    Returns:
        Flat dict of scalar scores: bleu, sbleu, rouge1/2/L/Lsum, meteor.
    """
    outputs: list[str] = []
    references: list[str] = []
    for start in range(0, len(title_dataset_val), BATCH_SIZE):
        batch = title_dataset_val[start:start + BATCH_SIZE]
        # Prompt with ``inp_gen`` (article + question + right answer). ``inp``
        # already ends with the reference distractors, so generating from it
        # would score whatever follows the targets.
        outputs += get_metric_inputs(batch["inp_gen"], batch["distractors"], model, tokenizer)
        references += parse_options(batch["distractors"])

    metrics = compute_metric_values(outputs, references)
    return {
        "bleu": metrics["bleu"]["bleu"],
        "sbleu": metrics["sbleu"]["score"],
        "rouge1": metrics["rouge"]["rouge1"],
        "rouge2": metrics["rouge"]["rouge2"],
        "rougeL": metrics["rouge"]["rougeL"],
        "rougeLsum": metrics["rouge"]["rougeLsum"],
        "meteor": metrics["meteor"]["meteor"],
    }

In [19]:
NUM_TRAIN_EPOCHS = 20
BATCH_SIZE = 1

training_args = TrainingArguments(
    output_dir="./RuGPT3-RuRACE-1",
    logging_dir="./rugpt3-rurace-1-title-logs",
    # Optimisation schedule.
    num_train_epochs=NUM_TRAIN_EPOCHS,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    learning_rate=5e-5,
    weight_decay=0.01,
    # Evaluate and checkpoint once per epoch, keep at most three checkpoints,
    # and restore the best one (by eval loss unless configured otherwise).
    # NOTE(review): newer transformers releases rename this to `eval_strategy`.
    evaluation_strategy="epoch",
    save_strategy="epoch",
    save_total_limit=3,
    load_best_model_at_end=True,
    # Only the loss is gathered during evaluation, so the Trainer never
    # calls `compute_metrics` with these settings.
    prediction_loss_only=True,
    eval_accumulation_steps=1,
    # Memory savers.
    gradient_checkpointing=True,
    fp16=True,
)

# Causal-LM collation (mlm=False): labels come from input_ids with pad
# positions masked; since pad == eos here, EOS positions are masked as well.
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=title_dataset_train,
    eval_dataset=title_dataset_val,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

In [21]:
# Baseline evaluation of the untrained model; only eval_loss (plus timing)
# is reported because prediction_loss_only=True.
trainer.evaluate()

{'eval_loss': 3.277505397796631,
 'eval_runtime': 4.9842,
 'eval_samples_per_second': 43.939,
 'eval_steps_per_second': 43.939}

In [22]:
# Debug check: the Trainer holds the compute_metrics hook defined earlier.
# NOTE(review): leftover debugging cell; safe to remove.
trainer.compute_metrics

<function __main__.compute_metrics(eval_preds: transformers.trainer_utils.EvalPrediction) -> dict[str, typing.Any]>

In [23]:
# Debug check: the notebook-level compute_metrics function object.
# NOTE(review): leftover debugging cell; safe to remove.
compute_metrics

<function __main__.compute_metrics(eval_preds: transformers.trainer_utils.EvalPrediction) -> dict[str, typing.Any]>