# Install necessary packages

In [1]:
!pip install datasets
!pip install evaluate



# About this project
- This project was the final project of Natural Language Processing course (CSCI 5832) at University of Colorado Boulder
- The main purpose of the final project was to solve a subset of one of [SemEval-2024 Tasks](https://semeval.github.io/SemEval2024/tasks.html).
- This project is in particular focusing on the subtask 1 of [Task 3: The Competition of Multimodal Emotion Cause Analysis in Conversations](https://nustm.github.io/SemEval-2024_ECAC/).
- The project team members are:
  1. Jooseok Lee
  2. Seungwook Lee
- Simplified from the original project

# Introduction
- In this project, we aimed to solve the problem of textual Emotion-Cause Pair Extraction (ECPE), which is the first sub-task of SemEval-2024 Task 3,  using text classification and question answering framework.
- The main purpose of textual ECPE is to find all sets of emotion-cause pairs where each utterance (i.e., small subset of a conversation) is matched with a single or multiple textual cause span(s) along with its emotional category.


<div align="center">
    <img src="img/ECPE_overview.jpg" alt="Overview of ECPE" width="500">
</div>

# Approach
- While the original paper solved this problem using a single solution, in this project we utilized two separate natural language processing (NLP) frameworks to solve it; text classification and question answering.
- That is, we splitted the original problem into two separate sub problems and solved them independently.
- In our approach, the text classification model is responsible for determining the emotional category of a given utterance.
- A single utterance is given to a classification model and it predicts the six emotional categories of the utterance (i.e., Anger, Disgust, Fear, Joy, Sadness and Surprise).
- We fine-tuned the publicly available text classification large language model (LLM) (i.e., BERT classification model) to solve this sub-problem.
- Then, we utilized a question answering model to find textual cause span(s) of a given utterance.
- In particular, we changed our data to Stanford Question Answering Dataset (SQuAD) format to fine-tune publicly available question answering LLM (i.e., DistilBERT question answering model).

<div align="center">
    <img src="img/Approach.jpg" alt="Overview of ECPE" width="500">
</div>

# Data Load
Load the original json data

In [2]:
# Generalized code for handling file path
try:
    import google.colab
    IN_COLAB = True

    import sys
    sys.path.append('/content/drive/My Drive/PersonalPage/ECPE-with-BERT')
except ImportError:
    IN_COLAB = False

if IN_COLAB:
    # Mount Google Drive (optional, if you need to access files there)
    from google.colab import drive
    drive.mount('/content/drive', force_remount=True)
    # Define the data path (e.g., in a specific folder in Google Drive)
    data_path = "/content/drive/My Drive/PersonalPage/ECPE-with-BERT/data"
else:
    # Define the local data path
    data_path = "/data"

Mounted at /content/drive


In [3]:
import json

file_path = data_path + '/Subtask_1_train.json'

with open(file_path, 'r') as file:
    data = json.load(file)

# Question Answering
- In this project, we viewed the sub-task of finding textual cause span(s) in ECPE as a question-answering problem.
- That is, we treated a given utterance as a question to be answered.
- The previous utterances including the given utterance itself can be viewed as a context where the answer (i.e., the textual cause span(s) of the given utterance) lies.
- We only consider that only previous utterances (including itself) can be regarded as possible emotional causes.

## Data Preprocessing
- In a single sample of original data, there are multiple utterances (i.e., utterance ID and the corresponding text), their emotions, and speakers.
- Also, the sample has ‘emotion-cause_pairs’ that the model should make a prediction on.
- The ‘emotion-cause_pairs’ is a set of emotion-cause pairs that matches the emotion of an utterance (e.g., ‘3_surprise’) and the corresponding causes (e.g., ‘1_I realize I am totally naked .’ and ‘3_Then I look down , and I realize there is a phone ... there .’).
- As mentioned before, we considered a given utterance with emotion as a question.
- Also, we treated previous utterances and the given utterance itself as a context where the corresponding answer (i.e., the textual cause span) could lie.
<div align="center">
    <img src="img/Data_transformation.jpg" alt="Overview of ECPE" width="500">
</div>

### Transform the original data into SQuAD
- sample_to_SQuAD(sample): Tramsform a single sample into SQuAD
- SQuAD_format_transformation(origin_data, random_state=42): Transform the whole data into SQuAD

In [4]:
from sklearn.utils import shuffle
from sklearn.model_selection import train_test_split
from sklearn.utils import shuffle
import numpy as np
from utils.preprocess import sample_to_SQuAD, SQuAD_format_transformation # Custom library to transform the original data into SQuAD

In [26]:
# Split the data into train/valid/test set
train, test = train_test_split(data, test_size=0.3)
valid , test = train_test_split(test, test_size=0.5)

print('Number of traning dataset: ' + str(len(train)))
print('Number of validation dataset: ' + str(len(valid)))
print('Number of test dataset: ' + str(len(test)))

Number of traning dataset: 961
Number of validation dataset: 206
Number of test dataset: 207


In [27]:
# Change the original data into SQuAD format
train_contexts, train_questions, train_answers = SQuAD_format_transformation(train)
val_contexts, val_questions, val_answers = SQuAD_format_transformation(valid)
test_contexts, test_questions, test_answers = SQuAD_format_transformation(test)

### Tokenization
- Tokenize the SQuAD-formatted data to prepare input for BERT
- If the context data exceeds BERT's maximum input length, create multiple training features from a single dataset sample by applying a sliding window between them, following this [method](https://huggingface.co/learn/nlp-course/chapter7/7).



In [28]:
# Change the SQuAD data into Dataset class for later process
from datasets import Dataset

train_ids = list(range(len(train_contexts)))
train_dic = {'id': train_ids, 'context': train_contexts, 'question': train_questions, 'answers': train_answers}
val_ids = list(range(len(val_contexts)))
val_dic = {'id': val_ids, 'context': val_contexts, 'question': val_questions, 'answers': val_answers}

train_dataset = Dataset.from_dict(train_dic)
val_dataset = Dataset.from_dict(val_dic)

In [29]:
# Initialize Tokenizer
from transformers import AutoTokenizer

model_checkpoint = "bert-base-cased"
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

In [30]:
from utils.tokenize import tokenize_with_split

# Check the estimated maximum length of context
max_length = 384
stride = 128

train_dataset = train_dataset.map(
    lambda x: tokenize_with_split(x, tokenizer, max_length, stride),
    batched=True,
    remove_columns=train_dataset.column_names,
)

train_dataset = train_dataset.remove_columns(['offset_mapping', 'example_id'])

val_dataset = val_dataset.map(
    lambda x: tokenize_with_split(x, tokenizer, max_length, stride),
    batched=True,
    remove_columns=val_dataset.column_names,
)

Map:   0%|          | 0/6335 [00:00<?, ? examples/s]

Map:   0%|          | 0/1341 [00:00<?, ? examples/s]

## Training

In [10]:
from torch.utils.data import DataLoader
from transformers import default_data_collator

train_dataset.set_format("torch")
validation_set = val_dataset.remove_columns(["example_id", "offset_mapping"])
validation_set.set_format("torch")

train_dataloader = DataLoader(
    train_dataset,
    shuffle=True,
    collate_fn=default_data_collator,
    batch_size=8,
)
eval_dataloader = DataLoader(
    validation_set, collate_fn=default_data_collator, batch_size=8
)

In [11]:
# Referenced https://huggingface.co/learn/nlp-course/chapter7/7, modified accordingly
from torch.optim import AdamW
from transformers import get_scheduler
from tqdm.auto import tqdm
import torch

from torch.optim import AdamW
from transformers import get_scheduler
from transformers import AutoModelForQuestionAnswering

from utils.metrics import compute_metrics

import evaluate
from tqdm.auto import tqdm
import collections

metric = evaluate.load("squad")

n_best = 20
max_answer_length = 200

num_train_epochs = 10
num_update_steps_per_epoch = len(train_dataloader)
num_training_steps = num_train_epochs * num_update_steps_per_epoch

learning_rate_lst = []
no_epochs = []
val_results = []
val_exactmatch = []

lr_lst = [2e-3, 8e-4, 2e-4, 8e-5, 2e-5, 2e-7]
models = []

for lr in lr_lst:
  model = AutoModelForQuestionAnswering.from_pretrained("distilbert-base-cased-distilled-squad")
  models.append(model)

  optimizer = AdamW(model.parameters(), lr=lr)

  lr_scheduler = get_scheduler(
      "linear",
      optimizer=optimizer,
      num_warmup_steps=0,
      num_training_steps=num_training_steps,
  )

  progress_bar = tqdm(range(num_training_steps))

  device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
  model.to(device)

  for epoch in range(num_train_epochs):
      # Training
      model.train()
      epoch_loss = 0.0
      running_loss = 0.0
      print("Training!")

      for step, batch in enumerate(train_dataloader):
          input_ids = batch['input_ids'].to(device)
          attention_mask = batch['attention_mask'].to(device)
          start_positions = batch['start_positions'].to(device)
          end_positions = batch['end_positions'].to(device)
          outputs = model(input_ids, attention_mask=attention_mask, start_positions=start_positions, end_positions=end_positions)
          loss = outputs.loss
          loss.backward()

          optimizer.step()
          lr_scheduler.step()
          optimizer.zero_grad()
          progress_bar.update(1)

      # Evaluation
      model.eval()
      start_logits = []
      end_logits = []
      print("Evaluation!")
      for batch in tqdm(eval_dataloader):
          with torch.no_grad():
              input_ids = batch['input_ids'].to(device)
              attention_mask = batch['attention_mask'].to(device)
              outputs = model(input_ids, attention_mask=attention_mask)

          start_logits.append(outputs.start_logits.cpu().numpy())
          end_logits.append(outputs.end_logits.cpu().numpy())

      start_logits = np.concatenate(start_logits)
      end_logits = np.concatenate(end_logits)
      start_logits = start_logits[: len(val_dataset)]
      end_logits = end_logits[: len(val_dataset)]

      metrics = compute_metrics(
          start_logits, end_logits, val_dataset, Dataset.from_dict(val_dic),
          metric, n_best, max_answer_length
      )
      learning_rate_lst.append(lr)
      no_epochs.append(epoch)
      val_results.append(metrics['f1'])
      val_exactmatch.append(metrics['exact_match'])

      print(f"epoch {epoch}, lr {lr} (validation):", metrics)

  0%|          | 0/7730 [00:00<?, ?it/s]

Training!
Evaluation!


  0%|          | 0/177 [00:00<?, ?it/s]

  0%|          | 0/1410 [00:00<?, ?it/s]

epoch 0, lr 0.002 (validation): {'exact_match': 0.2127659574468085, 'f1': 9.452594009547498}
Training!
Evaluation!


  0%|          | 0/177 [00:00<?, ?it/s]

  0%|          | 0/1410 [00:00<?, ?it/s]

epoch 1, lr 0.002 (validation): {'exact_match': 0.07092198581560284, 'f1': 12.570553037394365}
Training!
Evaluation!


  0%|          | 0/177 [00:00<?, ?it/s]

  0%|          | 0/1410 [00:00<?, ?it/s]

epoch 2, lr 0.002 (validation): {'exact_match': 0.07092198581560284, 'f1': 10.916339857386529}
Training!
Evaluation!


  0%|          | 0/177 [00:00<?, ?it/s]

  0%|          | 0/1410 [00:00<?, ?it/s]

epoch 3, lr 0.002 (validation): {'exact_match': 0.2127659574468085, 'f1': 9.363568451979411}
Training!
Evaluation!


  0%|          | 0/177 [00:00<?, ?it/s]

  0%|          | 0/1410 [00:00<?, ?it/s]

epoch 4, lr 0.002 (validation): {'exact_match': 0.14184397163120568, 'f1': 11.165833659973023}
Training!
Evaluation!


  0%|          | 0/177 [00:00<?, ?it/s]

  0%|          | 0/1410 [00:00<?, ?it/s]

epoch 5, lr 0.002 (validation): {'exact_match': 0.2127659574468085, 'f1': 3.4516460733353034}
Training!
Evaluation!


  0%|          | 0/177 [00:00<?, ?it/s]

  0%|          | 0/1410 [00:00<?, ?it/s]

epoch 6, lr 0.002 (validation): {'exact_match': 0.14184397163120568, 'f1': 5.3232351616007}
Training!
Evaluation!


  0%|          | 0/177 [00:00<?, ?it/s]

  0%|          | 0/1410 [00:00<?, ?it/s]

epoch 7, lr 0.002 (validation): {'exact_match': 0.14184397163120568, 'f1': 8.472988697908065}
Training!
Evaluation!


  0%|          | 0/177 [00:00<?, ?it/s]

  0%|          | 0/1410 [00:00<?, ?it/s]

epoch 8, lr 0.002 (validation): {'exact_match': 0.0, 'f1': 7.917780481704791}
Training!
Evaluation!


  0%|          | 0/177 [00:00<?, ?it/s]

  0%|          | 0/1410 [00:00<?, ?it/s]

epoch 9, lr 0.002 (validation): {'exact_match': 0.28368794326241137, 'f1': 7.5743368259180786}


  0%|          | 0/7730 [00:00<?, ?it/s]

Training!
Evaluation!


  0%|          | 0/177 [00:00<?, ?it/s]

  0%|          | 0/1410 [00:00<?, ?it/s]

epoch 0, lr 0.0008 (validation): {'exact_match': 0.3546099290780142, 'f1': 14.582921453634267}
Training!
Evaluation!


  0%|          | 0/177 [00:00<?, ?it/s]

  0%|          | 0/1410 [00:00<?, ?it/s]

epoch 1, lr 0.0008 (validation): {'exact_match': 0.3546099290780142, 'f1': 15.053520189136337}
Training!
Evaluation!


  0%|          | 0/177 [00:00<?, ?it/s]

  0%|          | 0/1410 [00:00<?, ?it/s]

epoch 2, lr 0.0008 (validation): {'exact_match': 0.28368794326241137, 'f1': 12.2419619273242}
Training!
Evaluation!


  0%|          | 0/177 [00:00<?, ?it/s]

  0%|          | 0/1410 [00:00<?, ?it/s]

epoch 3, lr 0.0008 (validation): {'exact_match': 0.2127659574468085, 'f1': 9.726309768945354}
Training!
Evaluation!


  0%|          | 0/177 [00:00<?, ?it/s]

  0%|          | 0/1410 [00:00<?, ?it/s]

epoch 4, lr 0.0008 (validation): {'exact_match': 0.07092198581560284, 'f1': 12.38863395594629}
Training!
Evaluation!


  0%|          | 0/177 [00:00<?, ?it/s]

  0%|          | 0/1410 [00:00<?, ?it/s]

epoch 5, lr 0.0008 (validation): {'exact_match': 0.28368794326241137, 'f1': 9.63460921749853}
Training!
Evaluation!


  0%|          | 0/177 [00:00<?, ?it/s]

  0%|          | 0/1410 [00:00<?, ?it/s]

epoch 6, lr 0.0008 (validation): {'exact_match': 0.14184397163120568, 'f1': 10.282060314890643}
Training!
Evaluation!


  0%|          | 0/177 [00:00<?, ?it/s]

  0%|          | 0/1410 [00:00<?, ?it/s]

epoch 7, lr 0.0008 (validation): {'exact_match': 0.14184397163120568, 'f1': 10.079984056283378}
Training!
Evaluation!


  0%|          | 0/177 [00:00<?, ?it/s]

  0%|          | 0/1410 [00:00<?, ?it/s]

epoch 8, lr 0.0008 (validation): {'exact_match': 0.07092198581560284, 'f1': 9.095616406564025}
Training!
Evaluation!


  0%|          | 0/177 [00:00<?, ?it/s]

  0%|          | 0/1410 [00:00<?, ?it/s]

epoch 9, lr 0.0008 (validation): {'exact_match': 0.28368794326241137, 'f1': 10.485377045489484}


  0%|          | 0/7730 [00:00<?, ?it/s]

Training!
Evaluation!


  0%|          | 0/177 [00:00<?, ?it/s]

  0%|          | 0/1410 [00:00<?, ?it/s]

epoch 0, lr 0.0002 (validation): {'exact_match': 31.06382978723404, 'f1': 56.41467848270696}
Training!
Evaluation!


  0%|          | 0/177 [00:00<?, ?it/s]

  0%|          | 0/1410 [00:00<?, ?it/s]

epoch 1, lr 0.0002 (validation): {'exact_match': 29.858156028368793, 'f1': 54.94450604274593}
Training!
Evaluation!


  0%|          | 0/177 [00:00<?, ?it/s]

  0%|          | 0/1410 [00:00<?, ?it/s]

epoch 2, lr 0.0002 (validation): {'exact_match': 31.70212765957447, 'f1': 55.40325643548228}
Training!
Evaluation!


  0%|          | 0/177 [00:00<?, ?it/s]

  0%|          | 0/1410 [00:00<?, ?it/s]

epoch 3, lr 0.0002 (validation): {'exact_match': 31.98581560283688, 'f1': 55.69437666275839}
Training!
Evaluation!


  0%|          | 0/177 [00:00<?, ?it/s]

  0%|          | 0/1410 [00:00<?, ?it/s]

epoch 4, lr 0.0002 (validation): {'exact_match': 33.333333333333336, 'f1': 56.94402420325109}
Training!
Evaluation!


  0%|          | 0/177 [00:00<?, ?it/s]

  0%|          | 0/1410 [00:00<?, ?it/s]

epoch 5, lr 0.0002 (validation): {'exact_match': 32.340425531914896, 'f1': 55.91026673203707}
Training!
Evaluation!


  0%|          | 0/177 [00:00<?, ?it/s]

  0%|          | 0/1410 [00:00<?, ?it/s]

epoch 6, lr 0.0002 (validation): {'exact_match': 30.70921985815603, 'f1': 55.37125984088266}
Training!
Evaluation!


  0%|          | 0/177 [00:00<?, ?it/s]

  0%|          | 0/1410 [00:00<?, ?it/s]

epoch 7, lr 0.0002 (validation): {'exact_match': 31.70212765957447, 'f1': 54.1895816264299}
Training!
Evaluation!


  0%|          | 0/177 [00:00<?, ?it/s]

  0%|          | 0/1410 [00:00<?, ?it/s]

epoch 8, lr 0.0002 (validation): {'exact_match': 31.843971631205672, 'f1': 55.22811073989654}
Training!
Evaluation!


  0%|          | 0/177 [00:00<?, ?it/s]

  0%|          | 0/1410 [00:00<?, ?it/s]

epoch 9, lr 0.0002 (validation): {'exact_match': 31.70212765957447, 'f1': 55.39029187123442}


  0%|          | 0/7730 [00:00<?, ?it/s]

Training!
Evaluation!


  0%|          | 0/177 [00:00<?, ?it/s]

  0%|          | 0/1410 [00:00<?, ?it/s]

epoch 0, lr 8e-05 (validation): {'exact_match': 37.02127659574468, 'f1': 58.297238143693065}
Training!
Evaluation!


  0%|          | 0/177 [00:00<?, ?it/s]

  0%|          | 0/1410 [00:00<?, ?it/s]

epoch 1, lr 8e-05 (validation): {'exact_match': 38.08510638297872, 'f1': 58.88400603113104}
Training!
Evaluation!


  0%|          | 0/177 [00:00<?, ?it/s]

  0%|          | 0/1410 [00:00<?, ?it/s]

epoch 2, lr 8e-05 (validation): {'exact_match': 36.95035460992908, 'f1': 58.507905608093914}
Training!
Evaluation!


  0%|          | 0/177 [00:00<?, ?it/s]

  0%|          | 0/1410 [00:00<?, ?it/s]

epoch 3, lr 8e-05 (validation): {'exact_match': 35.744680851063826, 'f1': 56.01552920148649}
Training!
Evaluation!


  0%|          | 0/177 [00:00<?, ?it/s]

  0%|          | 0/1410 [00:00<?, ?it/s]

epoch 4, lr 8e-05 (validation): {'exact_match': 35.46099290780142, 'f1': 56.70501392194672}
Training!
Evaluation!


  0%|          | 0/177 [00:00<?, ?it/s]

  0%|          | 0/1410 [00:00<?, ?it/s]

epoch 5, lr 8e-05 (validation): {'exact_match': 36.38297872340426, 'f1': 57.14676083450835}
Training!
Evaluation!


  0%|          | 0/177 [00:00<?, ?it/s]

  0%|          | 0/1410 [00:00<?, ?it/s]

epoch 6, lr 8e-05 (validation): {'exact_match': 36.666666666666664, 'f1': 57.77731946972824}
Training!
Evaluation!


  0%|          | 0/177 [00:00<?, ?it/s]

  0%|          | 0/1410 [00:00<?, ?it/s]

epoch 7, lr 8e-05 (validation): {'exact_match': 35.39007092198582, 'f1': 57.183930938792074}
Training!
Evaluation!


  0%|          | 0/177 [00:00<?, ?it/s]

  0%|          | 0/1410 [00:00<?, ?it/s]

epoch 8, lr 8e-05 (validation): {'exact_match': 35.39007092198582, 'f1': 56.78580433422605}
Training!
Evaluation!


  0%|          | 0/177 [00:00<?, ?it/s]

  0%|          | 0/1410 [00:00<?, ?it/s]

epoch 9, lr 8e-05 (validation): {'exact_match': 35.46099290780142, 'f1': 56.63388517096691}


  0%|          | 0/7730 [00:00<?, ?it/s]

Training!
Evaluation!


  0%|          | 0/177 [00:00<?, ?it/s]

  0%|          | 0/1410 [00:00<?, ?it/s]

epoch 0, lr 2e-05 (validation): {'exact_match': 34.60992907801418, 'f1': 57.755692495380195}
Training!
Evaluation!


  0%|          | 0/177 [00:00<?, ?it/s]

  0%|          | 0/1410 [00:00<?, ?it/s]

epoch 1, lr 2e-05 (validation): {'exact_match': 37.234042553191486, 'f1': 58.796459289770254}
Training!
Evaluation!


  0%|          | 0/177 [00:00<?, ?it/s]

  0%|          | 0/1410 [00:00<?, ?it/s]

epoch 2, lr 2e-05 (validation): {'exact_match': 35.1063829787234, 'f1': 57.32457803853952}
Training!
Evaluation!


  0%|          | 0/177 [00:00<?, ?it/s]

  0%|          | 0/1410 [00:00<?, ?it/s]

epoch 3, lr 2e-05 (validation): {'exact_match': 37.5177304964539, 'f1': 58.38868629593834}
Training!
Evaluation!


  0%|          | 0/177 [00:00<?, ?it/s]

  0%|          | 0/1410 [00:00<?, ?it/s]

epoch 4, lr 2e-05 (validation): {'exact_match': 34.11347517730496, 'f1': 56.04195624674421}
Training!
Evaluation!


  0%|          | 0/177 [00:00<?, ?it/s]

  0%|          | 0/1410 [00:00<?, ?it/s]

epoch 5, lr 2e-05 (validation): {'exact_match': 36.59574468085106, 'f1': 57.971238881393056}
Training!
Evaluation!


  0%|          | 0/177 [00:00<?, ?it/s]

  0%|          | 0/1410 [00:00<?, ?it/s]

epoch 6, lr 2e-05 (validation): {'exact_match': 37.3758865248227, 'f1': 58.31952188941519}
Training!
Evaluation!


  0%|          | 0/177 [00:00<?, ?it/s]

  0%|          | 0/1410 [00:00<?, ?it/s]

epoch 7, lr 2e-05 (validation): {'exact_match': 35.0354609929078, 'f1': 56.73602160219768}
Training!
Evaluation!


  0%|          | 0/177 [00:00<?, ?it/s]

  0%|          | 0/1410 [00:00<?, ?it/s]

epoch 8, lr 2e-05 (validation): {'exact_match': 35.673758865248224, 'f1': 57.002494027838175}
Training!
Evaluation!


  0%|          | 0/177 [00:00<?, ?it/s]

  0%|          | 0/1410 [00:00<?, ?it/s]

epoch 9, lr 2e-05 (validation): {'exact_match': 35.673758865248224, 'f1': 56.90483380520841}


  0%|          | 0/7730 [00:00<?, ?it/s]

Training!
Evaluation!


  0%|          | 0/177 [00:00<?, ?it/s]

  0%|          | 0/1410 [00:00<?, ?it/s]

epoch 0, lr 2e-07 (validation): {'exact_match': 17.588652482269502, 'f1': 45.472595849025176}
Training!
Evaluation!


  0%|          | 0/177 [00:00<?, ?it/s]

  0%|          | 0/1410 [00:00<?, ?it/s]

epoch 1, lr 2e-07 (validation): {'exact_match': 21.843971631205672, 'f1': 49.97791632823507}
Training!
Evaluation!


  0%|          | 0/177 [00:00<?, ?it/s]

  0%|          | 0/1410 [00:00<?, ?it/s]

epoch 2, lr 2e-07 (validation): {'exact_match': 22.97872340425532, 'f1': 51.816792287209346}
Training!
Evaluation!


  0%|          | 0/177 [00:00<?, ?it/s]

  0%|          | 0/1410 [00:00<?, ?it/s]

epoch 3, lr 2e-07 (validation): {'exact_match': 25.24822695035461, 'f1': 52.94587377182774}
Training!
Evaluation!


  0%|          | 0/177 [00:00<?, ?it/s]

  0%|          | 0/1410 [00:00<?, ?it/s]

epoch 4, lr 2e-07 (validation): {'exact_match': 26.95035460992908, 'f1': 53.46182484638973}
Training!
Evaluation!


  0%|          | 0/177 [00:00<?, ?it/s]

  0%|          | 0/1410 [00:00<?, ?it/s]

epoch 5, lr 2e-07 (validation): {'exact_match': 27.872340425531913, 'f1': 53.85376029017353}
Training!
Evaluation!


  0%|          | 0/177 [00:00<?, ?it/s]

  0%|          | 0/1410 [00:00<?, ?it/s]

epoch 6, lr 2e-07 (validation): {'exact_match': 28.72340425531915, 'f1': 54.14821647215626}
Training!
Evaluation!


  0%|          | 0/177 [00:00<?, ?it/s]

  0%|          | 0/1410 [00:00<?, ?it/s]

epoch 7, lr 2e-07 (validation): {'exact_match': 28.79432624113475, 'f1': 54.0543999158537}
Training!
Evaluation!


  0%|          | 0/177 [00:00<?, ?it/s]

  0%|          | 0/1410 [00:00<?, ?it/s]

epoch 8, lr 2e-07 (validation): {'exact_match': 28.79432624113475, 'f1': 54.00838140506781}
Training!
Evaluation!


  0%|          | 0/177 [00:00<?, ?it/s]

  0%|          | 0/1410 [00:00<?, ?it/s]

epoch 9, lr 2e-07 (validation): {'exact_match': 28.865248226950353, 'f1': 53.96806906758961}


## Prediction

In [32]:
best_model_index = int(np.argmax(val_results)/num_train_epochs)
best_epoch = np.argmax(val_results)%num_train_epochs
best_model = models[best_model_index]

test_ids = list(range(len(test_contexts)))
test_dic = {'id': test_ids, 'context': test_contexts, 'question': test_questions, 'answers': test_answers}

test_dataset = Dataset.from_dict(test_dic)

test_dataset = test_dataset.map(
    lambda x: tokenize_with_split(x, tokenizer, max_length, stride),
    batched=True,
    remove_columns=test_dataset.column_names,
)

test_set = test_dataset.remove_columns(["example_id", "offset_mapping"])
test_set.set_format("torch")

test_dataloader = DataLoader(
    test_set, collate_fn=default_data_collator, batch_size=8
)

Map:   0%|          | 0/1311 [00:00<?, ? examples/s]

In [33]:
best_model.eval()
start_logits = []
end_logits = []
print("Prediction!")
for batch in tqdm(test_dataloader):
    with torch.no_grad():
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        outputs = best_model(input_ids, attention_mask=attention_mask)

    start_logits.append(outputs.start_logits.cpu().numpy())
    end_logits.append(outputs.end_logits.cpu().numpy())

start_logits = np.concatenate(start_logits)
end_logits = np.concatenate(end_logits)
start_logits = start_logits[: len(val_dataset)]
end_logits = end_logits[: len(val_dataset)]

metrics = compute_metrics(
    start_logits, end_logits, test_dataset, Dataset.from_dict(test_dic),
    metric, n_best, max_answer_length
)

Prediction!


  0%|          | 0/164 [00:00<?, ?it/s]

  0%|          | 0/1311 [00:00<?, ?it/s]

In [39]:
print(f"Final test F1 score: {metrics['f1']:.2f}")

Final test F1 score: 68.32
