Lab9

---

# Contextual Question Answering
## Abstractive QA

The aim of this exercise is building a neural model able to answer contextual questions in the legal domain.

Training and Validation dataset: PoQuAD

Testing dataset: Simple Legal Questions Dataset

Resources:
https://medium.com/@ajazturki10/simplifying-language-understanding-a-beginners-guide-to-question-answering-with-t5-and-pytorch-253e0d6aac54



In [1]:
!pip install transformers evaluate rouge wandb nltk sentencepiece

Collecting evaluate
  Downloading evaluate-0.4.3-py3-none-any.whl.metadata (9.2 kB)
Collecting rouge
  Downloading rouge-1.0.1-py3-none-any.whl.metadata (4.1 kB)
Collecting datasets>=2.0.0 (from evaluate)
  Downloading datasets-3.2.0-py3-none-any.whl.metadata (20 kB)
Collecting dill (from evaluate)
  Downloading dill-0.3.9-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from evaluate)
  Downloading xxhash-3.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess (from evaluate)
  Downloading multiprocess-0.70.17-py310-none-any.whl.metadata (7.2 kB)
Collecting dill (from evaluate)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting multiprocess (from evaluate)
  Downloading multiprocess-0.70.16-py310-none-any.whl.metadata (7.2 kB)
Collecting fsspec>=2021.05.0 (from fsspec[http]>=2021.05.0->evaluate)
  Downloading fsspec-2024.9.0-py3-none-any.whl.metadata (11 kB)
Downloading evaluate-0.4.3-py3-none-any.whl (84 kB)
[2

In [2]:
import torch
import json
from tqdm import tqdm
import torch.nn as nn
from torch.optim import Adam
import spacy
import string
import evaluate  # Bleu
from torch.utils.data import Dataset, DataLoader, RandomSampler
import pandas as pd
import numpy as np
import transformers
from transformers import T5Tokenizer, T5Model, T5ForConditionalGeneration, T5TokenizerFast

import warnings
warnings.filterwarnings("ignore")

In [12]:
TOKENIZER = T5TokenizerFast.from_pretrained("allegro/plt5-base")
MODEL = T5ForConditionalGeneration.from_pretrained("allegro/plt5-base", return_dict=True)
OPTIMIZER = Adam(MODEL.parameters(), lr=0.00001)
Q_LEN = 256   # Question Length
T_LEN = 32    # Target Length
BATCH_SIZE = 16
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer_config.json:   0%|          | 0.00/141 [00:00<?, ?B/s]

spiece.model:   0%|          | 0.00/1.12M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/65.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/658 [00:00<?, ?B/s]

You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565


pytorch_model.bin:   0%|          | 0.00/1.10G [00:00<?, ?B/s]

## Download PoQuAD Dataset (Training and Validation)

In [3]:
#for AQA
!wget https://huggingface.co/datasets/clarin-pl/poquad/resolve/main/poquad-train.json
!wget https://huggingface.co/datasets/clarin-pl/poquad/resolve/main/poquad-dev.json

--2025-01-04 22:51:57--  https://huggingface.co/datasets/clarin-pl/poquad/resolve/main/poquad-train.json
Resolving huggingface.co (huggingface.co)... 65.8.243.46, 65.8.243.16, 65.8.243.90, ...
Connecting to huggingface.co (huggingface.co)|65.8.243.46|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://cdn-lfs.hf.co/repos/18/de/18ded45e8046dd5f58b7365947f5a4298433a0e7710248308670e8cf26059c20/b1ac3acabb49fedb7bb7db0de0690ddb22585d6419321589cc1bb0a8068a4ff9?response-content-disposition=inline%3B+filename*%3DUTF-8%27%27poquad-train.json%3B+filename%3D%22poquad-train.json%22%3B&response-content-type=application%2Fjson&Expires=1736290317&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTczNjI5MDMxN319LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy5oZi5jby9yZXBvcy8xOC9kZS8xOGRlZDQ1ZTgwNDZkZDVmNThiNzM2NTk0N2Y1YTQyOTg0MzNhMGU3NzEwMjQ4MzA4NjcwZThjZjI2MDU5YzIwL2IxYWMzYWNhYmI0OWZlZGI3YmI3ZGIwZGUwNjkwZGRiMjI1ODVkNjQxOTMyMTU4OWNjMWJiMG

In [4]:
with open('poquad-train.json') as f:
    train_data = json.load(f)
with open('poquad-dev.json') as f:
    validation_data = json.load(f)

In [5]:
print("CONTEXT\n", train_data['data'][0]['paragraphs'][0]['context'])
print("QESTION\n", train_data['data'][0]['paragraphs'][0]['qas'][0]['question'])
print("ANSWER\n", train_data['data'][0]['paragraphs'][0]['qas'][0]['answers'][0]['generative_answer'])

CONTEXT
 Projekty konfederacji zaczęły się załamywać 5 sierpnia 1942. Ponownie wróciła kwestia monachijska, co uaktywniło się wymianą listów Ripka – Stroński. Natomiast 17 sierpnia 1942 doszło do spotkania E. Beneša i J. Masaryka z jednej a Wł. Sikorskiego i E. Raczyńskiego z drugiej strony. Polscy dyplomaci zaproponowali podpisanie układu konfederacyjnego. W następnym miesiącu, tj. 24 września, strona polska przesłała na ręce J. Masaryka projekt deklaracji o przyszłej konfederacji obu państw. Strona czechosłowacka projekt przyjęła, lecz już w listopadzie 1942 E. Beneš podważył ideę konfederacji. W zamian zaproponowano zawarcie układu sojuszniczego z Polską na 20 lat (formalnie nastąpiło to 20 listopada 1942).
QESTION
 Co było powodem powrócenia konceptu porozumieniu monachijskiego?
ANSWER
 wymiana listów Ripka – Stroński


### Preprocessing

In [6]:
# Extracting context, question, and answers from the dataset

def prepare_data(data):
    articles = []

    for article in data["data"]:
        for paragraph in article["paragraphs"]:
            for qa in paragraph["qas"]:
                question = qa["question"]

                if not qa["is_impossible"]:
                  answer = qa["answers"][0]["text"]

                inputs = {"context": paragraph["context"], "question": question, "answer": answer}


                articles.append(inputs)

    return articles

train_data = prepare_data(train_data)
validation_data = prepare_data(validation_data)

# Create a Dataframe
train_data = pd.DataFrame(train_data)
validation_data = pd.DataFrame(validation_data)

In [None]:
class QA_Dataset(Dataset):
    def __init__(self, tokenizer, dataframe, q_len, t_len):
        self.tokenizer = tokenizer
        self.q_len = q_len
        self.t_len = t_len
        self.data = dataframe
        self.questions = self.data["question"]
        self.context = self.data["context"]
        self.answer = self.data['answer']

    def __len__(self):
        return len(self.questions)

    def __getitem__(self, idx):
        question = self.questions[idx]
        context = self.context[idx]
        answer = self.answer[idx]

        question_tokenized = self.tokenizer(question, context, max_length=self.q_len, padding="max_length",
                                                    truncation=True, pad_to_max_length=True, add_special_tokens=True)
        answer_tokenized = self.tokenizer(answer, max_length=self.t_len, padding="max_length",
                                          truncation=True, pad_to_max_length=True, add_special_tokens=True)

        labels = torch.tensor(answer_tokenized["input_ids"], dtype=torch.long)
        labels[labels == 0] = TOKENIZER.pad_token_id

        return {
            "input_ids": torch.tensor(question_tokenized["input_ids"], dtype=torch.long),
            "attention_mask": torch.tensor(question_tokenized["attention_mask"], dtype=torch.long),
            "labels": labels,
            "decoder_attention_mask": torch.tensor(answer_tokenized["attention_mask"], dtype=torch.long)
        }

In [None]:
# Dataloader

train_sampler = RandomSampler(train_data.index)
val_sampler = RandomSampler(validation_data.index)

data =  pd.concat([train_data, validation_data], axis=0, ignore_index=True)
qa_dataset = QA_Dataset(TOKENIZER, data, Q_LEN, T_LEN)

train_loader = DataLoader(qa_dataset, batch_size=BATCH_SIZE, sampler=train_sampler)
val_loader = DataLoader(qa_dataset, batch_size=BATCH_SIZE, sampler=val_sampler)

In [None]:
print("Training batches: ", len(train_loader))
print("Validation batches: ", len(val_loader))

Training batches:  7078
Validation batches:  883


# Training

In [None]:
import wandb
wandb.init(project="PJN9", name="athena_plt5-base_poquad")

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: W&B API key is configured. Use [1m`wandb login --relogin`[0m to force relogin


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011112338169995281, max=1.0…

In [None]:
train_loss = 0.0
val_loss = 0.0
train_bleu_score = 0.0
val_bleu_score = 0.0
train_batch_count = 0
val_batch_count = 0
save_step = 300
log_step = 100
bleu_step = 100

bleu_metric = evaluate.load("google_bleu")
MODEL.to(DEVICE)

for epoch in range(3):
    MODEL.train()
    for batch in tqdm(train_loader, desc="Training batches"):
        input_ids = batch["input_ids"].to(DEVICE)
        attention_mask = batch["attention_mask"].to(DEVICE)
        labels = batch["labels"].to(DEVICE)
        decoder_attention_mask = batch["decoder_attention_mask"].to(DEVICE)

        outputs = MODEL(
                          input_ids=input_ids,
                          attention_mask=attention_mask,
                          labels=labels,
                          decoder_attention_mask=decoder_attention_mask
                        )

        OPTIMIZER.zero_grad()
        outputs.loss.backward()
        OPTIMIZER.step()
        train_loss += outputs.loss.item()
        train_batch_count += 1

        if train_batch_count  % bleu_step == 0:
          with torch.no_grad():
            outputs = MODEL.generate(input_ids=input_ids, attention_mask=attention_mask)

            # Convert outputs and labels to strings
            predictions = [TOKENIZER.decode(pred, skip_special_tokens=True) for pred in outputs]
            references = [TOKENIZER.decode(label, skip_special_tokens=True) for label in labels]

            # Compute BLEU scores
            for pred, ref in zip(predictions, references):
                bleu_score = bleu_metric.compute(predictions=[pred], references=[[ref]])
                train_bleu_score += bleu_score["google_bleu"]

        if train_batch_count % save_step == 0:
          MODEL.save_pretrained("qa_model")
          TOKENIZER.save_pretrained("qa_tokenizer")

        if train_batch_count  % log_step == 0 :
          print(f"{epoch+1}/{2} -> Train loss: {train_loss }")
          wandb.log({"train_loss": train_loss / train_batch_count, "train_bleu": train_bleu_score / train_batch_count})

    #Evaluation
    MODEL.eval()
    for batch in tqdm(val_loader, desc="Validation batches"):
        input_ids = batch["input_ids"].to(DEVICE)
        attention_mask = batch["attention_mask"].to(DEVICE)
        labels = batch["labels"].to(DEVICE)
        decoder_attention_mask = batch["decoder_attention_mask"].to(DEVICE)

        outputs = MODEL(
                          input_ids=input_ids,
                          attention_mask=attention_mask,
                          labels=labels,
                          decoder_attention_mask=decoder_attention_mask
                        )

        OPTIMIZER.zero_grad()
        outputs.loss.backward()
        OPTIMIZER.step()
        val_loss += outputs.loss.item()
        val_batch_count += 1

        if val_batch_count  % bleu_step == 0:
          with torch.no_grad():
            outputs = MODEL.generate(input_ids=input_ids, attention_mask=attention_mask)

            # Convert outputs and labels to strings
            predictions = [TOKENIZER.decode(pred, skip_special_tokens=True) for pred in outputs]
            references = [TOKENIZER.decode(label, skip_special_tokens=True) for label in labels]

            # Compute BLEU scores
            for pred, ref in zip(predictions, references):
                bleu_score = bleu_metric.compute(predictions=[pred], references=[[ref]])
                val_bleu_score += bleu_score["google_bleu"]

        if val_batch_count  % log_step == 0 :
          print(f"{epoch+1}/{2} -> Validation loss: {val_loss / val_batch_count }")
          wandb.log({"val_loss": val_loss / val_batch_count, "val_bleu": val_bleu_score / val_batch_count})


Training batches:   0%|          | 0/7078 [00:00<?, ?it/s]Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.
Training batches:   1%|▏         | 100/7078 [00:27<4:37:20,  2.38s/it]

1/2 -> Train loss: 12360.009284973145


Training batches:   3%|▎         | 201/7078 [00:43<30:46,  3.73it/s]  

1/2 -> Train loss: 22501.7628326416


Training batches:   4%|▍         | 301/7078 [01:01<1:12:49,  1.55it/s]

1/2 -> Train loss: 30858.819995880127


Training batches:   6%|▌         | 401/7078 [01:18<30:38,  3.63it/s]  

1/2 -> Train loss: 36379.84896850586


Training batches:   7%|▋         | 501/7078 [01:34<28:49,  3.80it/s]

1/2 -> Train loss: 39176.98579788208


Training batches:   8%|▊         | 601/7078 [01:52<59:20,  1.82it/s]  

1/2 -> Train loss: 40994.02756404877


Training batches:  10%|▉         | 701/7078 [02:08<29:10,  3.64it/s]

1/2 -> Train loss: 42384.30526924133


Training batches:  11%|█▏        | 801/7078 [02:24<28:52,  3.62it/s]

1/2 -> Train loss: 43556.978652477264


Training batches:  13%|█▎        | 901/7078 [02:42<1:01:44,  1.67it/s]

1/2 -> Train loss: 44598.7671456337


Training batches:  14%|█▍        | 1001/7078 [02:58<29:19,  3.45it/s] 

1/2 -> Train loss: 45507.84151315689


Training batches:  16%|█▌        | 1101/7078 [03:15<25:30,  3.91it/s]

1/2 -> Train loss: 46318.82102918625


Training batches:  17%|█▋        | 1200/7078 [03:32<1:12:10,  1.36it/s]

1/2 -> Train loss: 47092.749212265015


Training batches:  18%|█▊        | 1301/7078 [03:49<26:31,  3.63it/s]  

1/2 -> Train loss: 47792.74526309967


Training batches:  20%|█▉        | 1401/7078 [04:05<24:06,  3.92it/s]

1/2 -> Train loss: 48438.66938519478


Training batches:  21%|██        | 1500/7078 [04:22<1:09:42,  1.33it/s]

1/2 -> Train loss: 49054.55450749397


Training batches:  23%|██▎       | 1601/7078 [04:39<25:09,  3.63it/s]  

1/2 -> Train loss: 49611.786410331726


Training batches:  24%|██▍       | 1701/7078 [04:56<25:51,  3.47it/s]

1/2 -> Train loss: 50127.705444574356


Training batches:  25%|██▌       | 1801/7078 [05:13<45:40,  1.93it/s]

1/2 -> Train loss: 50648.53631043434


Training batches:  27%|██▋       | 1901/7078 [05:30<20:48,  4.15it/s]

1/2 -> Train loss: 51129.20186448097


Training batches:  28%|██▊       | 2001/7078 [05:46<19:23,  4.37it/s]

1/2 -> Train loss: 51619.69003653526


Training batches:  30%|██▉       | 2101/7078 [06:04<45:26,  1.83it/s]

1/2 -> Train loss: 52092.32055425644


Training batches:  31%|███       | 2201/7078 [06:20<18:17,  4.44it/s]

1/2 -> Train loss: 52531.72650337219


Training batches:  33%|███▎      | 2301/7078 [06:36<17:34,  4.53it/s]

1/2 -> Train loss: 52987.62722468376


Training batches:  34%|███▍      | 2401/7078 [06:54<44:11,  1.76it/s]

1/2 -> Train loss: 53419.18631219864


Training batches:  35%|███▌      | 2501/7078 [07:10<17:19,  4.40it/s]

1/2 -> Train loss: 53849.79552030563


Training batches:  37%|███▋      | 2601/7078 [07:26<17:36,  4.24it/s]

1/2 -> Train loss: 54269.05127596855


Training batches:  38%|███▊      | 2701/7078 [07:43<37:43,  1.93it/s]

1/2 -> Train loss: 54676.21487545967


Training batches:  40%|███▉      | 2801/7078 [07:59<16:08,  4.42it/s]

1/2 -> Train loss: 55108.47615027428


Training batches:  41%|████      | 2901/7078 [08:15<15:07,  4.60it/s]

1/2 -> Train loss: 55527.54719042778


Training batches:  42%|████▏     | 3001/7078 [08:33<35:46,  1.90it/s]

1/2 -> Train loss: 55939.8009288311


Training batches:  44%|████▍     | 3101/7078 [08:49<15:55,  4.16it/s]

1/2 -> Train loss: 56335.158019542694


Training batches:  45%|████▌     | 3201/7078 [09:05<14:55,  4.33it/s]

1/2 -> Train loss: 56718.80309057236


Training batches:  47%|████▋     | 3301/7078 [09:23<34:06,  1.85it/s]

1/2 -> Train loss: 57112.55438184738


Training batches:  48%|████▊     | 3401/7078 [09:39<14:37,  4.19it/s]

1/2 -> Train loss: 57497.07667803764


Training batches:  49%|████▉     | 3501/7078 [09:55<14:06,  4.23it/s]

1/2 -> Train loss: 57872.2218644619


Training batches:  51%|█████     | 3601/7078 [10:12<31:34,  1.84it/s]

1/2 -> Train loss: 58242.42957901955


Training batches:  52%|█████▏    | 3701/7078 [10:28<12:00,  4.69it/s]

1/2 -> Train loss: 58628.35886144638


Training batches:  54%|█████▎    | 3801/7078 [10:44<11:54,  4.58it/s]

1/2 -> Train loss: 58998.90818977356


Training batches:  55%|█████▌    | 3901/7078 [11:02<31:15,  1.69it/s]

1/2 -> Train loss: 59385.09986281395


Training batches:  57%|█████▋    | 4001/7078 [11:18<11:06,  4.62it/s]

1/2 -> Train loss: 59750.95061349869


Training batches:  58%|█████▊    | 4101/7078 [11:34<10:58,  4.52it/s]

1/2 -> Train loss: 60117.61939287186


Training batches:  59%|█████▉    | 4201/7078 [11:52<25:34,  1.87it/s]

1/2 -> Train loss: 60471.57260274887


Training batches:  61%|██████    | 4301/7078 [12:08<10:11,  4.54it/s]

1/2 -> Train loss: 60830.04697394371


Training batches:  62%|██████▏   | 4401/7078 [12:24<09:36,  4.64it/s]

1/2 -> Train loss: 61194.779997348785


Training batches:  64%|██████▎   | 4501/7078 [12:42<23:07,  1.86it/s]

1/2 -> Train loss: 61547.52303647995


Training batches:  65%|██████▌   | 4601/7078 [12:58<11:17,  3.66it/s]

1/2 -> Train loss: 61878.03069615364


Training batches:  66%|██████▋   | 4701/7078 [13:14<09:15,  4.28it/s]

1/2 -> Train loss: 62222.68729805946


Training batches:  68%|██████▊   | 4801/7078 [13:32<22:35,  1.68it/s]

1/2 -> Train loss: 62561.17247104645


Training batches:  69%|██████▉   | 4901/7078 [13:48<10:12,  3.56it/s]

1/2 -> Train loss: 62895.273767232895


Training batches:  71%|███████   | 5001/7078 [14:04<08:40,  3.99it/s]

1/2 -> Train loss: 63242.32701730728


Training batches:  72%|███████▏  | 5101/7078 [14:22<18:15,  1.80it/s]

1/2 -> Train loss: 63563.302941441536


Training batches:  73%|███████▎  | 5201/7078 [14:38<06:48,  4.60it/s]

1/2 -> Train loss: 63908.90073955059


Training batches:  75%|███████▍  | 5301/7078 [14:54<06:28,  4.58it/s]

1/2 -> Train loss: 64239.017127633095


Training batches:  76%|███████▋  | 5401/7078 [15:12<16:58,  1.65it/s]

1/2 -> Train loss: 64578.66784989834


Training batches:  78%|███████▊  | 5501/7078 [15:28<07:25,  3.54it/s]

1/2 -> Train loss: 64907.089725613594


Training batches:  79%|███████▉  | 5601/7078 [15:44<05:32,  4.45it/s]

1/2 -> Train loss: 65228.54045855999


Training batches:  81%|████████  | 5701/7078 [16:02<12:43,  1.80it/s]

1/2 -> Train loss: 65567.7711662054


Training batches:  82%|████████▏ | 5801/7078 [16:18<04:41,  4.53it/s]

1/2 -> Train loss: 65885.74466705322


Training batches:  83%|████████▎ | 5901/7078 [16:34<04:16,  4.58it/s]

1/2 -> Train loss: 66203.06707894802


Training batches:  85%|████████▍ | 6001/7078 [16:51<09:44,  1.84it/s]

1/2 -> Train loss: 66506.49312055111


Training batches:  86%|████████▌ | 6101/7078 [17:08<04:25,  3.69it/s]

1/2 -> Train loss: 66811.35331249237


Training batches:  88%|████████▊ | 6201/7078 [17:24<03:18,  4.42it/s]

1/2 -> Train loss: 67125.4141882658


Training batches:  89%|████████▉ | 6301/7078 [17:42<07:48,  1.66it/s]

1/2 -> Train loss: 67448.36952114105


Training batches:  90%|█████████ | 6401/7078 [17:58<02:23,  4.72it/s]

1/2 -> Train loss: 67769.67134714127


Training batches:  92%|█████████▏| 6501/7078 [18:14<02:33,  3.76it/s]

1/2 -> Train loss: 68082.98286867142


Training batches:  93%|█████████▎| 6601/7078 [18:31<04:15,  1.87it/s]

1/2 -> Train loss: 68398.34423136711


Training batches:  95%|█████████▍| 6701/7078 [18:48<01:42,  3.67it/s]

1/2 -> Train loss: 68713.2045674324


Training batches:  96%|█████████▌| 6801/7078 [19:04<00:58,  4.74it/s]

1/2 -> Train loss: 69017.13620448112


Training batches:  97%|█████████▋| 6901/7078 [19:21<01:41,  1.74it/s]

1/2 -> Train loss: 69318.46833920479


Training batches:  99%|█████████▉| 7001/7078 [19:38<00:21,  3.54it/s]

1/2 -> Train loss: 69616.67630600929


Training batches: 100%|██████████| 7078/7078 [19:50<00:00,  5.95it/s]
Validation batches:  11%|█▏        | 101/883 [00:15<02:45,  4.72it/s]

1/2 -> Validation loss: 2.516189423799515


Validation batches:  23%|██▎       | 201/883 [00:31<02:21,  4.83it/s]

1/2 -> Validation loss: 2.3767869064211844


Validation batches:  34%|███▍      | 301/883 [00:47<02:02,  4.75it/s]

1/2 -> Validation loss: 2.3932476498683295


Validation batches:  45%|████▌     | 401/883 [01:02<01:36,  5.00it/s]

1/2 -> Validation loss: 2.424078317731619


Validation batches:  57%|█████▋    | 501/883 [01:18<01:18,  4.85it/s]

1/2 -> Validation loss: 2.4005789204835892


Validation batches:  68%|██████▊   | 601/883 [01:34<01:03,  4.47it/s]

1/2 -> Validation loss: 2.387656969924768


Validation batches:  79%|███████▉  | 701/883 [01:50<00:39,  4.60it/s]

1/2 -> Validation loss: 2.349882140159607


Validation batches:  91%|█████████ | 801/883 [02:05<00:15,  5.18it/s]

1/2 -> Validation loss: 2.3418093636631965


Validation batches: 100%|██████████| 883/883 [02:18<00:00,  6.39it/s]
Training batches:   0%|          | 23/7078 [00:04<31:54,  3.69it/s]

2/2 -> Train loss: 69923.64677083492


Training batches:   2%|▏         | 123/7078 [00:21<1:02:19,  1.86it/s]

2/2 -> Train loss: 70202.355250597


Training batches:   3%|▎         | 223/7078 [00:37<23:27,  4.87it/s]  

2/2 -> Train loss: 70483.08365058899


Training batches:   5%|▍         | 323/7078 [00:53<23:30,  4.79it/s]

2/2 -> Train loss: 70763.16078531742


Training batches:   6%|▌         | 423/7078 [01:11<1:03:37,  1.74it/s]

2/2 -> Train loss: 71046.32493805885


Training batches:   7%|▋         | 523/7078 [01:27<23:36,  4.63it/s]  

2/2 -> Train loss: 71312.99698960781


Training batches:   9%|▉         | 623/7078 [01:43<22:23,  4.80it/s]

2/2 -> Train loss: 71572.02531635761


Training batches:  10%|█         | 723/7078 [02:00<56:01,  1.89it/s]  

2/2 -> Train loss: 71823.60099160671


Training batches:  12%|█▏        | 823/7078 [02:16<23:15,  4.48it/s]

2/2 -> Train loss: 72086.97299718857


Training batches:  13%|█▎        | 923/7078 [02:32<22:26,  4.57it/s]

2/2 -> Train loss: 72354.5524879694


Training batches:  14%|█▍        | 1023/7078 [02:50<53:02,  1.90it/s]  

2/2 -> Train loss: 72616.33982014656


Training batches:  16%|█▌        | 1123/7078 [03:06<22:19,  4.45it/s]

2/2 -> Train loss: 72884.52959918976


Training batches:  17%|█▋        | 1223/7078 [03:22<23:57,  4.07it/s]

2/2 -> Train loss: 73147.71038103104


Training batches:  19%|█▊        | 1323/7078 [03:39<53:41,  1.79it/s]  

2/2 -> Train loss: 73407.55228424072


Training batches:  20%|██        | 1423/7078 [03:55<19:28,  4.84it/s]

2/2 -> Train loss: 73677.0383630991


Training batches:  22%|██▏       | 1523/7078 [04:12<27:34,  3.36it/s]

2/2 -> Train loss: 73932.41837358475


Training batches:  23%|██▎       | 1623/7078 [04:29<51:11,  1.78it/s]  

2/2 -> Train loss: 74187.44520938396


Training batches:  24%|██▍       | 1723/7078 [04:45<19:35,  4.56it/s]

2/2 -> Train loss: 74436.0386582017


Training batches:  26%|██▌       | 1823/7078 [05:01<19:02,  4.60it/s]

2/2 -> Train loss: 74679.29565763474


Training batches:  27%|██▋       | 1923/7078 [05:19<47:21,  1.81it/s]  

2/2 -> Train loss: 74917.74951648712


Training batches:  29%|██▊       | 2023/7078 [05:35<18:58,  4.44it/s]

2/2 -> Train loss: 75146.71625220776


Training batches:  30%|██▉       | 2123/7078 [05:51<21:36,  3.82it/s]

2/2 -> Train loss: 75380.35171318054


Training batches:  31%|███▏      | 2223/7078 [06:08<43:48,  1.85it/s]

2/2 -> Train loss: 75600.42535352707


Training batches:  33%|███▎      | 2323/7078 [06:24<17:19,  4.57it/s]

2/2 -> Train loss: 75808.33392053843


Training batches:  34%|███▍      | 2423/7078 [06:40<17:31,  4.42it/s]

2/2 -> Train loss: 76010.29871690273


Training batches:  36%|███▌      | 2523/7078 [06:58<43:53,  1.73it/s]

2/2 -> Train loss: 76191.7491748929


Training batches:  37%|███▋      | 2623/7078 [07:14<18:50,  3.94it/s]

2/2 -> Train loss: 76372.49791532755


Training batches:  38%|███▊      | 2723/7078 [07:31<19:32,  3.71it/s]

2/2 -> Train loss: 76540.55536651611


Training batches:  40%|███▉      | 2822/7078 [07:48<49:53,  1.42it/s]

2/2 -> Train loss: 76706.82022154331


Training batches:  41%|████▏     | 2923/7078 [08:04<15:19,  4.52it/s]

2/2 -> Train loss: 76863.11125218868


Training batches:  43%|████▎     | 3023/7078 [08:20<14:01,  4.82it/s]

2/2 -> Train loss: 77005.15353757143


Training batches:  44%|████▍     | 3123/7078 [08:38<33:53,  1.94it/s]

2/2 -> Train loss: 77147.40700930357


Training batches:  46%|████▌     | 3223/7078 [08:54<14:02,  4.58it/s]

2/2 -> Train loss: 77289.58970218897


Training batches:  47%|████▋     | 3323/7078 [09:10<13:32,  4.62it/s]

2/2 -> Train loss: 77421.70805644989


Training batches:  48%|████▊     | 3423/7078 [09:27<34:16,  1.78it/s]

2/2 -> Train loss: 77554.02546286583


Training batches:  50%|████▉     | 3523/7078 [09:43<12:45,  4.65it/s]

2/2 -> Train loss: 77671.60230481625


Training batches:  51%|█████     | 3623/7078 [09:59<12:27,  4.62it/s]

2/2 -> Train loss: 77789.03451931477


Training batches:  53%|█████▎    | 3723/7078 [10:17<32:35,  1.72it/s]

2/2 -> Train loss: 77908.91206598282


Training batches:  54%|█████▍    | 3823/7078 [10:33<13:23,  4.05it/s]

2/2 -> Train loss: 78022.8169580698


Training batches:  55%|█████▌    | 3923/7078 [10:49<12:01,  4.38it/s]

2/2 -> Train loss: 78137.73376142979


Training batches:  57%|█████▋    | 4023/7078 [11:07<30:20,  1.68it/s]

2/2 -> Train loss: 78244.16376525164


Training batches:  58%|█████▊    | 4123/7078 [11:25<10:08,  4.85it/s]

2/2 -> Train loss: 78353.95898449421


Training batches:  60%|█████▉    | 4223/7078 [11:41<12:41,  3.75it/s]

2/2 -> Train loss: 78454.4160298109


Training batches:  61%|██████    | 4323/7078 [11:59<27:01,  1.70it/s]

2/2 -> Train loss: 78556.18433302641


Training batches:  62%|██████▏   | 4423/7078 [12:15<10:50,  4.08it/s]

2/2 -> Train loss: 78652.84200799465


Training batches:  64%|██████▍   | 4523/7078 [12:31<12:28,  3.42it/s]

2/2 -> Train loss: 78747.34113729


Training batches:  65%|██████▌   | 4623/7078 [12:49<24:07,  1.70it/s]

2/2 -> Train loss: 78845.8622764349


Training batches:  67%|██████▋   | 4723/7078 [13:05<11:32,  3.40it/s]

2/2 -> Train loss: 78940.19147872925


Training batches:  68%|██████▊   | 4823/7078 [13:21<08:35,  4.37it/s]

2/2 -> Train loss: 79026.55370128155


Training batches:  70%|██████▉   | 4923/7078 [13:39<21:46,  1.65it/s]

2/2 -> Train loss: 79120.3425540328


Training batches:  71%|███████   | 5023/7078 [13:55<08:54,  3.85it/s]

2/2 -> Train loss: 79207.00697916746


Training batches:  72%|███████▏  | 5123/7078 [14:11<07:17,  4.47it/s]

2/2 -> Train loss: 79299.5887926817


Training batches:  74%|███████▍  | 5223/7078 [14:29<17:44,  1.74it/s]

2/2 -> Train loss: 79386.44330826402


Training batches:  75%|███████▌  | 5323/7078 [14:45<07:41,  3.80it/s]

2/2 -> Train loss: 79470.96654039621


Training batches:  77%|███████▋  | 5423/7078 [15:01<06:27,  4.27it/s]

2/2 -> Train loss: 79561.03385573626


Training batches:  78%|███████▊  | 5523/7078 [15:19<14:21,  1.81it/s]

2/2 -> Train loss: 79640.45470386744


Training batches:  79%|███████▉  | 5623/7078 [15:35<06:24,  3.79it/s]

2/2 -> Train loss: 79722.68061083555


Training batches:  81%|████████  | 5723/7078 [15:51<06:03,  3.73it/s]

2/2 -> Train loss: 79802.37908872962


Training batches:  82%|████████▏ | 5823/7078 [16:09<12:16,  1.71it/s]

2/2 -> Train loss: 79880.55460980535


Training batches:  84%|████████▎ | 5923/7078 [16:25<04:50,  3.98it/s]

2/2 -> Train loss: 79959.0523237288


Training batches:  85%|████████▌ | 6023/7078 [16:41<04:14,  4.15it/s]

2/2 -> Train loss: 80040.91224822402


Training batches:  87%|████████▋ | 6123/7078 [16:58<09:46,  1.63it/s]

2/2 -> Train loss: 80114.89337599277


Training batches:  88%|████████▊ | 6223/7078 [17:15<03:41,  3.85it/s]

2/2 -> Train loss: 80191.08510255814


Training batches:  89%|████████▉ | 6323/7078 [17:31<03:42,  3.40it/s]

2/2 -> Train loss: 80267.66514313221


Training batches:  91%|█████████ | 6423/7078 [17:50<10:20,  1.06it/s]

2/2 -> Train loss: 80336.97483560443


Training batches:  92%|█████████▏| 6523/7078 [18:07<02:29,  3.72it/s]

2/2 -> Train loss: 80412.47339352965


Training batches:  94%|█████████▎| 6623/7078 [18:23<01:53,  4.01it/s]

2/2 -> Train loss: 80487.5395680368


Training batches:  95%|█████████▍| 6723/7078 [18:41<03:24,  1.74it/s]

2/2 -> Train loss: 80558.13764417171


Training batches:  96%|█████████▋| 6823/7078 [18:57<00:57,  4.43it/s]

2/2 -> Train loss: 80636.48424261808


Training batches:  98%|█████████▊| 6923/7078 [19:13<00:42,  3.69it/s]

2/2 -> Train loss: 80708.10680276155


Training batches:  99%|█████████▉| 7023/7078 [19:31<00:33,  1.65it/s]

2/2 -> Train loss: 80779.1449765861


Training batches: 100%|██████████| 7078/7078 [19:39<00:00,  6.00it/s]
Validation batches:   2%|▏         | 18/883 [00:03<03:38,  3.96it/s]

2/2 -> Validation loss: 2.2972139370772573


Validation batches:  13%|█▎        | 118/883 [00:18<02:57,  4.32it/s]

2/2 -> Validation loss: 2.115534742027521


Validation batches:  25%|██▍       | 218/883 [00:34<02:56,  3.77it/s]

2/2 -> Validation loss: 1.9653044247085398


Validation batches:  36%|███▌      | 318/883 [00:50<02:04,  4.56it/s]

2/2 -> Validation loss: 1.8379453493158022


Validation batches:  47%|████▋     | 418/883 [01:06<02:01,  3.82it/s]

2/2 -> Validation loss: 1.7269989161766492


Validation batches:  59%|█████▊    | 518/883 [01:22<01:31,  3.97it/s]

2/2 -> Validation loss: 1.6337031136346716


Validation batches:  70%|██████▉   | 618/883 [01:38<01:08,  3.85it/s]

2/2 -> Validation loss: 1.5537616607844829


Validation batches:  81%|████████▏ | 718/883 [01:53<00:37,  4.35it/s]

2/2 -> Validation loss: 1.4812875142041593


Validation batches:  93%|█████████▎| 818/883 [02:09<00:15,  4.32it/s]

2/2 -> Validation loss: 1.4171891065380153


Validation batches: 100%|██████████| 883/883 [02:19<00:00,  6.33it/s]
Training batches:   1%|          | 45/7078 [00:07<32:19,  3.63it/s]

3/2 -> Train loss: 80846.65499278903


Training batches:   2%|▏         | 145/7078 [00:23<31:27,  3.67it/s]

3/2 -> Train loss: 80920.45653566718


Training batches:   3%|▎         | 245/7078 [00:41<1:09:47,  1.63it/s]

3/2 -> Train loss: 80989.03704714775


Training batches:   5%|▍         | 345/7078 [00:57<27:27,  4.09it/s]  

3/2 -> Train loss: 81054.28622248769


Training batches:   6%|▋         | 445/7078 [01:13<27:16,  4.05it/s]

3/2 -> Train loss: 81117.96612977982


Training batches:   8%|▊         | 545/7078 [01:31<1:14:26,  1.46it/s]

3/2 -> Train loss: 81182.74523377419


Training batches:   9%|▉         | 645/7078 [01:48<26:59,  3.97it/s]  

3/2 -> Train loss: 81248.02528071404


Training batches:  11%|█         | 745/7078 [02:04<26:27,  3.99it/s]

3/2 -> Train loss: 81312.54811540246


Training batches:  12%|█▏        | 845/7078 [02:21<58:53,  1.76it/s]  

3/2 -> Train loss: 81375.56127434969


Training batches:  13%|█▎        | 945/7078 [02:37<26:33,  3.85it/s]

3/2 -> Train loss: 81439.4739895165


Training batches:  15%|█▍        | 1045/7078 [02:54<26:31,  3.79it/s]

3/2 -> Train loss: 81502.12039297819


Training batches:  16%|█▌        | 1145/7078 [03:11<1:01:48,  1.60it/s]

3/2 -> Train loss: 81564.64668104053


Training batches:  18%|█▊        | 1245/7078 [03:28<26:03,  3.73it/s]  

3/2 -> Train loss: 81629.9521804452


Training batches:  19%|█▉        | 1345/7078 [03:44<24:26,  3.91it/s]

3/2 -> Train loss: 81690.46054214239


Training batches:  20%|██        | 1445/7078 [04:02<58:48,  1.60it/s]  

3/2 -> Train loss: 81753.63490504026


Training batches:  22%|██▏       | 1545/7078 [04:18<24:51,  3.71it/s]

3/2 -> Train loss: 81816.12073603272


Training batches:  23%|██▎       | 1645/7078 [04:34<22:13,  4.07it/s]

3/2 -> Train loss: 81873.4615727961


Training batches:  25%|██▍       | 1745/7078 [04:52<54:33,  1.63it/s]  

3/2 -> Train loss: 81938.89228978753


Training batches:  26%|██▌       | 1845/7078 [05:08<19:39,  4.44it/s]

3/2 -> Train loss: 81999.65078613162


Training batches:  27%|██▋       | 1945/7078 [05:24<22:54,  3.73it/s]

3/2 -> Train loss: 82060.21382656693


Training batches:  29%|██▉       | 2045/7078 [05:42<51:31,  1.63it/s]  

3/2 -> Train loss: 82122.14215505123


Training batches:  30%|███       | 2145/7078 [06:01<1:03:54,  1.29it/s]

3/2 -> Train loss: 82180.87777584791


Training batches:  32%|███▏      | 2245/7078 [06:23<1:57:58,  1.46s/it]

3/2 -> Train loss: 82237.10917189717


Training batches:  33%|███▎      | 2345/7078 [06:41<48:44,  1.62it/s]  

3/2 -> Train loss: 82294.218524158


Training batches:  35%|███▍      | 2445/7078 [06:57<20:24,  3.78it/s]

3/2 -> Train loss: 82358.23810473084


Training batches:  36%|███▌      | 2545/7078 [07:13<20:18,  3.72it/s]

3/2 -> Train loss: 82419.92773115635


Training batches:  37%|███▋      | 2644/7078 [07:38<3:48:38,  3.09s/it]

3/2 -> Train loss: 82478.58764833212


Training batches:  39%|███▉      | 2745/7078 [07:55<18:36,  3.88it/s]  

3/2 -> Train loss: 82535.50272414088


Training batches:  40%|████      | 2845/7078 [08:11<18:05,  3.90it/s]

3/2 -> Train loss: 82589.679466784


Training batches:  42%|████▏     | 2945/7078 [08:29<42:17,  1.63it/s]

3/2 -> Train loss: 82643.3537362814


Training batches:  43%|████▎     | 3045/7078 [08:45<18:14,  3.68it/s]

3/2 -> Train loss: 82702.78142407537


Training batches:  44%|████▍     | 3145/7078 [09:01<16:39,  3.94it/s]

3/2 -> Train loss: 82758.40924584866


Training batches:  46%|████▌     | 3245/7078 [09:19<36:54,  1.73it/s]

3/2 -> Train loss: 82812.97612795234


Training batches:  47%|████▋     | 3345/7078 [09:35<16:59,  3.66it/s]

3/2 -> Train loss: 82868.82514980435


Training batches:  49%|████▊     | 3445/7078 [09:51<17:04,  3.55it/s]

3/2 -> Train loss: 82922.24834933877


Training batches:  50%|█████     | 3545/7078 [10:09<34:41,  1.70it/s]

3/2 -> Train loss: 82977.37394094467


Training batches:  51%|█████▏    | 3645/7078 [10:25<13:41,  4.18it/s]

3/2 -> Train loss: 83037.17138186097


Training batches:  53%|█████▎    | 3745/7078 [10:43<14:11,  3.91it/s]

3/2 -> Train loss: 83091.53247836232


Training batches:  54%|█████▍    | 3845/7078 [11:01<32:54,  1.64it/s]

3/2 -> Train loss: 83143.87741217017


Training batches:  56%|█████▌    | 3945/7078 [11:17<14:10,  3.68it/s]

3/2 -> Train loss: 83198.00682583451


Training batches:  57%|█████▋    | 4045/7078 [11:33<12:30,  4.04it/s]

3/2 -> Train loss: 83250.84979119897


Training batches:  59%|█████▊    | 4145/7078 [11:51<28:47,  1.70it/s]

3/2 -> Train loss: 83304.08578261733


Training batches:  60%|█████▉    | 4245/7078 [12:07<10:52,  4.34it/s]

3/2 -> Train loss: 83358.8458174169


Training batches:  61%|██████▏   | 4345/7078 [12:23<12:16,  3.71it/s]

3/2 -> Train loss: 83409.92642009258


Training batches:  63%|██████▎   | 4445/7078 [12:41<27:15,  1.61it/s]

3/2 -> Train loss: 83464.8101350367


Training batches:  64%|██████▍   | 4545/7078 [12:57<11:38,  3.62it/s]

3/2 -> Train loss: 83518.0173791945


Training batches:  66%|██████▌   | 4645/7078 [13:13<12:05,  3.35it/s]

3/2 -> Train loss: 83572.69579336047


Training batches:  67%|██████▋   | 4745/7078 [13:31<24:14,  1.60it/s]

3/2 -> Train loss: 83624.4655534327


Training batches:  68%|██████▊   | 4845/7078 [13:47<10:00,  3.72it/s]

3/2 -> Train loss: 83678.155556947


Training batches:  70%|██████▉   | 4945/7078 [14:04<09:38,  3.69it/s]

3/2 -> Train loss: 83732.41226640344


Training batches:  71%|███████▏  | 5045/7078 [14:21<19:26,  1.74it/s]

3/2 -> Train loss: 83784.78215107322


Training batches:  73%|███████▎  | 5145/7078 [14:37<09:05,  3.54it/s]

3/2 -> Train loss: 83837.27628317475


Training batches:  74%|███████▍  | 5245/7078 [14:54<08:47,  3.47it/s]

3/2 -> Train loss: 83892.44186770916


Training batches:  76%|███████▌  | 5345/7078 [15:11<17:33,  1.65it/s]

3/2 -> Train loss: 83946.04416623712


Training batches:  77%|███████▋  | 5445/7078 [15:28<07:22,  3.69it/s]

3/2 -> Train loss: 83996.46429869533


Training batches:  78%|███████▊  | 5545/7078 [15:44<06:13,  4.10it/s]

3/2 -> Train loss: 84047.11432942748


Training batches:  80%|███████▉  | 5645/7078 [16:02<14:15,  1.67it/s]

3/2 -> Train loss: 84099.79309102893


Training batches:  81%|████████  | 5745/7078 [16:18<06:16,  3.54it/s]

3/2 -> Train loss: 84154.09299057722


Training batches:  83%|████████▎ | 5845/7078 [16:34<05:23,  3.81it/s]

3/2 -> Train loss: 84203.59773889184


Training batches:  84%|████████▍ | 5945/7078 [16:52<11:45,  1.61it/s]

3/2 -> Train loss: 84252.26049607992


Training batches:  85%|████████▌ | 6045/7078 [17:09<09:20,  1.84it/s]

3/2 -> Train loss: 84306.23607617617


Training batches:  87%|████████▋ | 6145/7078 [17:25<04:17,  3.62it/s]

3/2 -> Train loss: 84356.71544349194


Training batches:  88%|████████▊ | 6245/7078 [17:43<08:31,  1.63it/s]

3/2 -> Train loss: 84412.8974519968


Training batches:  90%|████████▉ | 6345/7078 [18:00<03:22,  3.62it/s]

3/2 -> Train loss: 84461.79625481367


Training batches:  91%|█████████ | 6445/7078 [18:16<02:52,  3.66it/s]

3/2 -> Train loss: 84509.36339354515


Training batches:  92%|█████████▏| 6545/7078 [18:34<05:37,  1.58it/s]

3/2 -> Train loss: 84558.93475824594


Training batches:  94%|█████████▍| 6645/7078 [18:50<01:56,  3.71it/s]

3/2 -> Train loss: 84609.22217062116


Training batches:  95%|█████████▌| 6745/7078 [19:07<01:29,  3.72it/s]

3/2 -> Train loss: 84657.41616207361


Training batches:  97%|█████████▋| 6845/7078 [19:24<02:20,  1.65it/s]

3/2 -> Train loss: 84708.13770273328


Training batches:  98%|█████████▊| 6945/7078 [19:42<01:12,  1.85it/s]

3/2 -> Train loss: 84756.74005168676


Training batches: 100%|█████████▉| 7045/7078 [19:58<00:09,  3.62it/s]

3/2 -> Train loss: 84807.13763532043


Training batches: 100%|██████████| 7078/7078 [20:03<00:00,  5.88it/s]
Validation batches:   4%|▍         | 35/883 [00:05<03:19,  4.25it/s]

3/2 -> Validation loss: 1.3598390261001057


Validation batches:  15%|█▌        | 135/883 [00:22<03:43,  3.34it/s]

3/2 -> Validation loss: 1.3079193010220402


Validation batches:  27%|██▋       | 235/883 [00:38<02:53,  3.73it/s]

3/2 -> Validation loss: 1.2593272935152053


Validation batches:  38%|███▊      | 335/883 [00:54<02:34,  3.54it/s]

3/2 -> Validation loss: 1.216939185957114


Validation batches:  49%|████▉     | 435/883 [01:09<01:48,  4.13it/s]

3/2 -> Validation loss: 1.1762588285451585


Validation batches:  61%|██████    | 535/883 [01:25<01:33,  3.70it/s]

3/2 -> Validation loss: 1.140374582319156


Validation batches:  72%|███████▏  | 635/883 [01:41<01:06,  3.75it/s]

3/2 -> Validation loss: 1.1078671198524535


Validation batches:  83%|████████▎ | 735/883 [01:57<00:37,  3.93it/s]

3/2 -> Validation loss: 1.0774873459994794


Validation batches:  95%|█████████▍| 835/883 [02:13<00:12,  3.71it/s]

3/2 -> Validation loss: 1.048431235546103


Validation batches: 100%|██████████| 883/883 [02:20<00:00,  6.27it/s]


# Example predictions

In [18]:
TOKENIZER = T5TokenizerFast.from_pretrained("qa_tokenizer")
MODEL = T5ForConditionalGeneration.from_pretrained("qa_model")
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

In [14]:
def predict_answer(context, question, ref_answer=None):
    inputs = TOKENIZER(question, context, max_length=Q_LEN, padding="max_length", truncation=True, add_special_tokens=True)

    input_ids = torch.tensor(inputs["input_ids"], dtype=torch.long).to(DEVICE).unsqueeze(0)
    attention_mask = torch.tensor(inputs["attention_mask"], dtype=torch.long).to(DEVICE).unsqueeze(0)

    outputs = MODEL.generate(input_ids=input_ids, attention_mask=attention_mask)

    predicted_answer = TOKENIZER.decode(outputs.flatten(), skip_special_tokens=True)

    if ref_answer and ref_answer != '':
        # Load the Bleu metric
        bleu = evaluate.load("google_bleu")
        score = bleu.compute(predictions=[predicted_answer],
                            references=[ref_answer])
        return {
            "ref_answer": ref_answer,
            "predicted_answer": predicted_answer,
            "bleu_score": score['google_bleu']
        }
    else:
       return {
            "ref_answer": None,
            "predicted_answer": predicted_answer,
            "bleu_score": None
        }
def print_prediction(context, question, pred):
  print("Context: \n", context)
  print("Question: \n", question)
  print("\nReference Answer: ", pred['ref_answer'])
  print("Predicted Answer: ", pred['predicted_answer'])
  print("BLEU Score: ", round(pred['bleu_score'], 3))

In [17]:
!cp /content/drive/MyDrive/poquad/model.safetensors qa_model/

#### PoQuAD Train Set

In [19]:
context = train_data.iloc[20]["context"]
question = train_data.iloc[20]["question"]
ref_answer = train_data.iloc[20]["answer"]
pred = predict_answer(context, question, ref_answer)
print_prediction(context, question, pred)
print("\n\n")

context = train_data.iloc[40]["context"]
question = train_data.iloc[40]["question"]
ref_answer = train_data.iloc[40]["answer"]
pred = predict_answer(context, question, ref_answer)
print_prediction(context, question, pred)
print("\n\n")

context = train_data.iloc[60]["context"]
question = train_data.iloc[60]["question"]
ref_answer = train_data.iloc[60]["answer"]
pred = predict_answer(context, question, ref_answer)
print_prediction(context, question, pred)
print("\n\n")

Context: 
 W sezonie 1959 nowym szkoleniowcem Górnika został węgierski trener János Steiner, z którym Kowal współpracował w Legii. Kowal wystąpił we wszystkich dwudziestu dwóch meczach ligowych i zdobył sześć bramek, spędzając na boisku 1935 minut. Jedynym spotkaniem, którego nie dograł w pełnym wymiarze czasowym był mecz przeciwko Ruchowi Chorzów (2:2, 25 października 1959 roku), kiedy to został zmieniony przez Manfreda Fojcika. Kowal był odpowiedzialny za wykonywanie rzutów karnych, które wykorzystał w spotkaniach z Legią (1:2, 10 maja 1959 roku) oraz Lechią Gdańsk (3:2, 1 listopada 1959 roku). Bramkarza Lechii pokonał nie biorąc rozbiegu w kierunku piłki. Przed strzałem wykonał zwód, po którym Henryk Gronowski rzucił się w jeden róg bramki, a piłka potoczyła się w drugi. Górnik zapewnił sobie tytuł mistrzowski na trzy kolejki przez zakończeniem rozgrywek, natomiast Kowal był uważany za piłkarza w szczytowej formie oraz inteligentnego konstruktora akcji.
Question: 
 Z jakiego kraju p

#### PoQuAD Validation Set

In [20]:
context = validation_data.iloc[20]["context"]
question = validation_data.iloc[20]["question"]
ref_answer = validation_data.iloc[20]["answer"]
pred = predict_answer(context, question, ref_answer)
print_prediction(context, question, pred)
print("\n\n")

context = validation_data.iloc[40]["context"]
question = validation_data.iloc[40]["question"]
ref_answer = validation_data.iloc[40]["answer"]
pred = predict_answer(context, question, ref_answer)
print_prediction(context, question, pred)
print("\n\n")

context = validation_data.iloc[60]["context"]
question = validation_data.iloc[60]["question"]
ref_answer = validation_data.iloc[60]["answer"]
pred = predict_answer(context, question, ref_answer)
print_prediction(context, question, pred)
print("\n\n")

Context: 
 Od 2014 roku w Superpucharze, z inicjatywy prezesa PZPN Zbigniewa Bońka, nastąpiła zmiana, gdyż po 8 latach przerwy związek podjął na swoje barki organizację rozgrywek o Superpuchar. O trofeum walczyć będą Mistrz Polski oraz zdobywca Pucharu Polski sezonu zakończonego w roku rozgrywania Superpucharu. Tym samym powrócono do nazwy Superpuchar Polski. Mecz rozgrywany będzie na stadionie Mistrza Polski, na około tydzień przed startem nowego sezonu Ekstraklasy. W sytuacji, gdy ten sam klub sięgnie po mistrzostwo i Puchar Polski, jego rywalem w meczu o trofeum będzie finalista ostatniej edycji Pucharu Polski.
Question: 
 Kto piastował stanowisko prezesa Polskiego Związku Piłki Nożnej poczynając od 2014 roku?

Reference Answer:  pączki, faworki i bliny
Predicted Answer:  na stadionie Mistrza Polski
BLEU Score:  0.0



Context: 
 Eminem używa różnych tożsamości w swoich piosenkach, by korzystać z różnych stylów rapowania i różnych podmiotów. Jego najbardziej znane i popularne alter 

## Evaluation on Simple Legal Questions Dataset

### Loading dataset from json files

In [None]:
import json

def load_jsonl(file_path):
    data = []
    with open(file_path, 'r') as f:
        for line in f:
            data.append(json.loads(line))
    return data

answers = load_jsonl('answers.jl')
questions = load_jsonl('questions.jl')
passages = load_jsonl('passages.jl')
relevant = load_jsonl('relevant.jl')

In [None]:
print(answers[0])

{'score': '1', 'question-id': '1', 'answer': 'Tak, podlega karze aresztu wojskowego albo pozbawienia wolności do lat 3.'}


In [None]:
test_set = []

for row in relevant:
  context_id = row["passage-id"]
  question_id = row["question-id"]

  for passage in passages:
    if passage["_id"] == context_id:
      context_text = passage["text"]
      break
  question_text = questions[int(question_id)-1]["text"]

  for answer in answers:
    if answer["question-id"] == row["question-id"]:
      answer_text = answer["answer"]
      break

  test_set.append({
      "context": context_text,
      "question": question_text,
      "answer": answer_text
  })


In [None]:
print("Simple Legal Ouestions Test Set Example\n")
print("context: ", test_set[0]['context'])
print("question: ",test_set[0]['question'])
print("answer: ", test_set[0]['answer'])

Simple Legal Ouestions Test Set Example

context:  Art. 345. § 1. Żołnierz, który dopuszcza się czynnej napaści na przełożonego, podlega karze aresztu wojskowego albo pozbawienia wolności do lat 3. § 2. Jeżeli sprawca dopuszcza się czynnej napaści w związku z pełnieniem przez przełożonego obowiązków służbowych albo wspólnie z innymi żołnierzami lub w obecności zebranych żołnierzy, podlega karze pozbawienia wolności od 6 miesięcy do lat 8. § 3. Jeżeli sprawca czynu określonego w § 1 lub 2 używa broni, noża lub innego podobnie niebezpiecznego przedmiotu, podlega karze pozbawienia wolności od roku do lat 10. § 4. Karze przewidzianej w § 3 podlega sprawca czynu określonego w § 1 lub 2, jeżeli jego następstwem jest skutek określony w art. 156 lub 157 § 1.
question:  Czy żołnierz, który dopuszcza się czynnej napaści na przełożonego podlega karze pozbawienia wolności?
answer:  Tak, podlega karze aresztu wojskowego albo pozbawienia wolności do lat 3.


### Example Predictions with BLEU score

#### Simple Legal Questions Dataset

In [None]:
context = test_set[20]["context"]
question = test_set[20]["question"]
ref_answer = test_set[20]["answer"]
pred = predict_answer(context, question, ref_answer)
print_prediction(context, question, pred)
print("\n\n")

context = test_set[40]["context"]
question = test_set[40]["question"]
ref_answer = test_set[40]["answer"]
pred = predict_answer(context, question, ref_answer)
print_prediction(context, question, pred)
print("\n\n")

context = test_set[60]["context"]
question = test_set[60]["question"]
ref_answer = test_set[60]["answer"]
pred = predict_answer(context, question, ref_answer)
print_prediction(context, question, pred)
print("\n\n")


Context: 
 Art. 209. § 1. Kto uporczywie uchyla się od wykonania ciążącego na nim z mocy ustawy lub orzeczenia sądowego obowiązku opieki przez niełożenie na utrzymanie osoby najbliższej lub innej osoby i przez to naraża ją na niemożność zaspokojenia podstawowych potrzeb życiowych, podlega grzywnie, karze ograniczenia wolności albo pozbawienia wolności do lat 2. § 2. Ściganie następuje na wniosek pokrzywdzonego, organu opieki społecznej lub właściwej instytucji. § 3. Jeżeli pokrzywdzonemu przyznano świadczenie z funduszu alimentacyjnego, ściganie odbywa się z urzędu.


Question: 
 Jak ściga się świadczenia w ramach funduszu alimentacyjnego?

Reference Answer:  Ściganie następuje na wniosek pokrzywdzonego, organu opieki społecznej lub właściwej instytucji.
Predicted Answer:  z urzędu
BLEU Score:  0.0



Context: 
 Art. 58. 1. Kontrolujący w toku kontroli może również dokonać przeszukania pomieszczeń lub rzeczy, za zgodą sądu antymonopolowego, udzieloną na wniosek Prezesa Urzędu. Przy pr

### Mean BLEU for whole testing dataset

In [None]:
scores = []
mean_bleu = 0.0

for i in range(len(test_set)):
  context = test_set[i]["context"]
  question = test_set[i]["question"]
  ref_answer = test_set[i]["answer"]
  pred = predict_answer(context, question, ref_answer)
  if pred['bleu_score'] is not None:
    scores.append(pred['bleu_score'])

mean_bleu = sum(scores) / len(scores)
print("Mean BLEU Score: ", round(mean_bleu, 3))


Mean BLEU Score:  0.056


# Summary

> Does the performance on the validation dataset reflects the performance on your test set?

They are both bad...

>What are the outcomes of the model on your test questions? Are they satisfying? If not, what might be the reason for that?

Most of the answers seem to be to abstracted from context. Either halucinations or inability to work with eg. numbers. Both questions ans the ground truth answers might not be clear or correct even. So some model predictions might be correct but not similar to the ground truth in BLEU measure.


>Why extractive question answering is not well suited for inflectional languages?

The answers to the questions may lack context and might not be correctly lemmatized. The syntax of extracted answer could contain errors.