# Curriculum Learning for Natural Language Understanding

The following is an attempt at reproducing the results presented in the [Curriculum Learning for Natural Language Understanding paper](https://aclanthology.org/2020.acl-main.542/). It aims to demonstrate the effect of curriculum learning on the performance of the BERT language model in machine reading comprehension using the SQuAD 2.0 dataset. The implementation of the **difficulty evaluation** step of the curriculum learning framework trains 10 (N = 10) models on 10 distinct portions of the training set. Then, **curriculum arrangement** sorts the 10 splits of the training set by difficulty (F1 as the golden metric), trains BERT through the 10 difficulty stages for 1 epoch each, and concludes by training it on the original distribution in the train set until it converges.

## Preliminaries

In [47]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import os

import numpy as np
from numpy import unravel_index
import pandas as pd
import math

import random
import sys
from IPython.display import Image
import time
from transformers import BertTokenizerFast, BertTokenizer, BertForQuestionAnswering

from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler

import re
import string

from tqdm import tqdm
import json

!CUBLAS_WORKSPACE_CONFIG=:4096:2

!pip install transformers

# reproducibility
def set_seed(seed = 1234):
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.enabled = False
    torch.backends.cudnn.benchmark = False
    # torch.use_deterministic_algorithms(False)
    os.environ['PYTHONHASHSEED'] = str(seed)

set_seed()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print('device:', device)

device: cuda


In [48]:
!pip install datasets
from datasets import load_dataset



## SQuAD 2.0 Dataset


In [50]:
squad_dataset = load_dataset("squad_v2")

In [93]:
train_squad = squad_dataset["train"].shuffle(seed=42).select(range(1000))

In [95]:
validation_squad = squad_dataset["validation"].shuffle(seed=42).select(range(1000))

In [96]:
train_squad

Dataset({
    features: ['id', 'title', 'context', 'question', 'answers'],
    num_rows: 1000
})

In [97]:
train_squad[0]

{'id': '56e0f3907aa994140058e80a',
 'title': 'Canon_law',
 'context': 'The Roman Catholic Church canon law also includes the main five rites (groups) of churches which are in full union with the Roman Catholic Church and the Supreme Pontiff:',
 'question': 'What term characterizes the intersection of the rites with the Roman Catholic Church?',
 'answers': {'text': ['full union'], 'answer_start': [104]}}

In [98]:
validation_squad

Dataset({
    features: ['id', 'title', 'context', 'question', 'answers'],
    num_rows: 1000
})

In [99]:
validation_squad[0]

{'id': '5733ea04d058e614000b6595',
 'title': 'French_and_Indian_War',
 'context': "In the spring of 1753, Paul Marin de la Malgue was given command of a 2,000-man force of Troupes de la Marine and Indians. His orders were to protect the King's land in the Ohio Valley from the British. Marin followed the route that Céloron had mapped out four years earlier, but where Céloron had limited the record of French claims to the burial of lead plates, Marin constructed and garrisoned forts. He first constructed Fort Presque Isle (near present-day Erie, Pennsylvania) on Lake Erie's south shore. He had a road built to the headwaters of LeBoeuf Creek. Marin constructed a second fort at Fort Le Boeuf (present-day Waterford, Pennsylvania), designed to guard the headwaters of LeBoeuf Creek. As he moved south, he drove off or captured British traders, alarming both the British and the Iroquois. Tanaghrisson, a chief of the Mingo, who were remnants of Iroquois and other tribes who had been driven west 

## Preprocessing

#### Train

In [100]:
train_squad[1]

{'id': '571adcf932177014007e9f56',
 'title': 'Athanasius_of_Alexandria',
 'context': "Alexandria was the most important trade center in the whole empire during Athanasius's boyhood. Intellectually, morally, and politically—it epitomized the ethnically diverse Graeco-Roman world, even more than Rome or Constantinople, Antioch or Marseilles. Its famous catechetical school, while sacrificing none of its famous passion for orthodoxy since the days of Pantaenus, Clement of Alexandria, Origen of Alexandria, Dionysius and Theognostus, had begun to take on an almost secular character in the comprehensiveness of its interests, and had counted influential pagans among its serious auditors.",
 'question': 'What was Alexandria known for?',
 'answers': {'text': ['important trade center'], 'answer_start': [24]}}

In [101]:
def find_end(example):

    if (len(example['answers']['text']) != 0):
        context = example['context']
        text = example['answers']['text'][0]
        start_idx = example['answers']['answer_start'][0]

        end_idx = start_idx + len(text)

        temp = example['answers'] # to change the value
        temp['answer_end']=end_idx
        temp['answer_start'] = start_idx # [num]->num
        temp['text'] = text # ['text']->text

    else:
        temp = example['answers']
        temp['answer_end'] = 0 # []->0
        temp['answer_start'] = 0 # []->0
        temp['text'] = "" # []->""

    return example

train_squad = train_squad.map(find_end)

In [102]:
train_squad[1]

{'id': '571adcf932177014007e9f56',
 'title': 'Athanasius_of_Alexandria',
 'context': "Alexandria was the most important trade center in the whole empire during Athanasius's boyhood. Intellectually, morally, and politically—it epitomized the ethnically diverse Graeco-Roman world, even more than Rome or Constantinople, Antioch or Marseilles. Its famous catechetical school, while sacrificing none of its famous passion for orthodoxy since the days of Pantaenus, Clement of Alexandria, Origen of Alexandria, Dionysius and Theognostus, had begun to take on an almost secular character in the comprehensiveness of its interests, and had counted influential pagans among its serious auditors.",
 'question': 'What was Alexandria known for?',
 'answers': {'answer_end': 46,
  'answer_start': 24,
  'text': 'important trade center'}}

In [103]:
train_squad[10]

{'id': '56fb84ebb28b3419009f1de7',
 'title': 'Middle_Ages',
 'context': 'During this period the practice of manuscript illumination gradually passed from monasteries to lay workshops, so that according to Janetta Benton "by 1300 most monks bought their books in shops", and the book of hours developed as a form of devotional book for lay-people. Metalwork continued to be the most prestigious form of art, with Limoges enamel a popular and relatively affordable option for objects such as reliquaries and crosses. In Italy the innovations of Cimabue and Duccio, followed by the Trecento master Giotto (d. 1337), greatly increased the sophistication and status of panel painting and fresco. Increasing prosperity during the 12th century resulted in greater production of secular art; many carved ivory objects such as gaming-pieces, combs, and small religious figures have survived.',
 'question': 'What were many pieces of secular art carved from in this period?',
 'answers': {'answer_end': 728, 'a

In [104]:
train_squad[-10]

{'id': '5aceab0f32bba1001ae4af93',
 'title': 'Jews',
 'context': 'For centuries, Jews worldwide have spoken the local or dominant languages of the regions they migrated to, often developing distinctive dialectal forms or branches that became independent languages. Yiddish is the Judæo-German language developed by Ashkenazi Jews who migrated to Central Europe. Ladino is the Judæo-Spanish language developed by Sephardic Jews who migrated to the Iberian peninsula. Due to many factors, including the impact of the Holocaust on European Jewry, the Jewish exodus from Arab and Muslim countries, and widespread emigration from other Jewish communities around the world, ancient and distinct Jewish languages of several communities, including Judæo-Georgian, Judæo-Arabic, Judæo-Berber, Krymchak, Judæo-Malayalam and many others, have largely fallen out of use.',
 'question': 'For how long have Jews not spoken the local languages of the regions they migrated to?',
 'answers': {'answer_end': 0, 'answe

In [105]:
tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')

In [106]:
def process_train(train_squad):

  tokenized_train = tokenizer(train_squad['context'], train_squad['question'], truncation=True, padding=True)

  def find_token_indexes(tokenized, dataset):
      start_token_list = []
      end_token_list = []
      answers = dataset['answers']
      for i in range(len(answers)):
          if (answers[i]['text'] != ''):
              start_token = tokenized.char_to_token(i, answers[i]['answer_start'])
              end_token = tokenized.char_to_token(i, answers[i]['answer_end'] - 1)

              if start_token is None:
                  start_token = tokenizer.model_max_length
              if end_token is None:
                  end_token = tokenizer.model_max_length
          else:
              start_token = 0
              end_token = 0

          start_token_list.append(start_token)
          end_token_list.append(end_token)

      return start_token_list, end_token_list

  s, e = find_token_indexes(tokenized_train, train_squad)
  train_squad = train_squad.add_column("start_position", s)
  train_squad = train_squad.add_column("end_position", e)

  return (tokenized_train, train_squad)

In [107]:
tokenized_train, processed_train_squad = process_train(train_squad)

In [108]:
processed_train_squad

Dataset({
    features: ['id', 'title', 'context', 'question', 'answers', 'start_position', 'end_position'],
    num_rows: 1000
})

In [109]:
batch_size = 8

In [110]:
def get_train_dataloader(tokenized_train, train_squad, batch_size):
  train_data = TensorDataset(torch.tensor(tokenized_train['input_ids'], dtype=torch.int64),
                            torch.tensor(tokenized_train['token_type_ids'], dtype=torch.int64),
                            torch.tensor(tokenized_train['attention_mask'], dtype=torch.float),
                            torch.tensor(train_squad['start_position'], dtype=torch.int64),
                            torch.tensor(train_squad['start_position'], dtype=torch.int64))
  train_sampler = RandomSampler(train_data)
  train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=batch_size)
  return train_dataloader

In [111]:
train_dataloader = get_train_dataloader(tokenized_train, processed_train_squad, batch_size)

#### Validation

In [112]:
def get_val_dataloader(validation_squad, batch_size):
  tokenized_validation = tokenizer(validation_squad['context'], validation_squad['question'], truncation=True, padding=True, return_offsets_mapping=True)
  val_data = TensorDataset(torch.tensor(tokenized_validation['input_ids'], dtype=torch.int64),
                          torch.tensor(tokenized_validation['token_type_ids'], dtype=torch.int64),
                          torch.tensor(tokenized_validation['attention_mask'], dtype=torch.float))
  val_sampler = SequentialSampler(val_data)
  val_dataloader = DataLoader(val_data, sampler=val_sampler, batch_size=batch_size)
  return (tokenized_validation, val_dataloader)

In [113]:
tokenized_validation, val_dataloader = get_val_dataloader(validation_squad, batch_size)

## Evaluation requirements

In [114]:
!wget https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v2.0.json -O dev-v2.0.json

--2023-12-15 23:43:15--  https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v2.0.json
Resolving rajpurkar.github.io (rajpurkar.github.io)... 185.199.111.153, 185.199.109.153, 185.199.108.153, ...
Connecting to rajpurkar.github.io (rajpurkar.github.io)|185.199.111.153|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 4370528 (4.2M) [application/json]
Saving to: ‘dev-v2.0.json’


2023-12-15 23:43:15 (100 MB/s) - ‘dev-v2.0.json’ saved [4370528/4370528]



In [115]:
def calculate_accuracy(tokenized_validation, validation_squad, val_dataloader):
  threshold = 1.0
  epoch_i = 0
  correct = 0
  pred_dict = {}
  na_prob_dict = {}

  model.eval()
  correct = 0
  batch_val_losses = []
  row = 0
  for test_batch in tqdm(val_dataloader):
      input_ids, segment_ids, masks = tuple(t.to(device) for t in test_batch)

      with torch.no_grad():
          start_logits, end_logits = model(input_ids=input_ids,
                                          token_type_ids=segment_ids,
                                          attention_mask=masks,
                                          return_dict=False)

      start_logits = start_logits.detach().cpu()
      end_logits = end_logits.detach().cpu()

      for bidx in range(len(start_logits)):
          start_scores = np.array(F.softmax(start_logits[bidx], dim = 0))
          end_scores = np.array(F.softmax(end_logits[bidx], dim = 0))

          size = len(start_scores)
          scores = np.zeros((size, size))

          for j in range(size):
              for i in range(j+1):
                  scores[i,j] = start_scores[i] + end_scores[j]

          start_pred, end_pred = unravel_index(scores.argmax(), scores.shape)
          answer_pred = ""
          if (scores[start_pred, end_pred] > scores[0,0]+threshold):

              offsets = tokenized_validation.offset_mapping[row]
              pred_char_start = offsets[start_pred][0]

              if end_pred < len(offsets):
                  pred_char_end = offsets[end_pred][1]
                  answer_pred = validation_squad[row]['context'][pred_char_start:pred_char_end]
              else:
                  answer_pred = validation_squad[row]['context'][pred_char_start:]

              if answer_pred in validation_squad[row]['answers']['text']:
                  correct += 1

          else:
              if (len(validation_squad[row]['answers']['text']) ==0):
                  correct += 1

          pred_dict[validation_squad[row]['id']] = answer_pred
          na_prob_dict[validation_squad[row]['id']] = scores[0,0]

          row+=1


  accuracy = correct/validation_squad.num_rows
  print("\n","accuracy is: ", accuracy)
  with open("pred.json", "w") as outfile:
    json.dump(pred_dict, outfile)
  with open("na_prob.json", "w") as outfile:
    json.dump(na_prob_dict, outfile)

In [116]:
"""
modified official evaluation script for SQuAD version 2
"""
import argparse
import collections
import json
import os
import re
import string
import sys
import numpy as np

ARTICLES_REGEX = re.compile(r"\b(a|an|the)\b", re.UNICODE)

OPTS = None

def make_qid_to_has_ans(dataset):
    qid_to_has_ans = {}
    for article in dataset:
        for p in article["paragraphs"]:
            for qa in p["qas"]:
                qid_to_has_ans[qa["id"]] = bool(qa["answers"][0]["text"] if qa["answers"] else "")
                #qid_to_has_ans[qa["id"]] = bool(qa["answers"]["text"])
    return qid_to_has_ans

def normalize_answer(s):
    """Lower text and remove punctuation, articles and extra whitespace."""

    def remove_articles(text):
        return ARTICLES_REGEX.sub(" ", text)

    def white_space_fix(text):
        return " ".join(text.split())

    def remove_punc(text):
        exclude = set(string.punctuation)
        return "".join(ch for ch in text if ch not in exclude)

    def lower(text):
        return text.lower()

    return white_space_fix(remove_articles(remove_punc(lower(s))))

def get_tokens(s):
    if not s:
        return []
    return normalize_answer(s).split()

def compute_exact(a_gold, a_pred):
    return int(normalize_answer(a_gold) == normalize_answer(a_pred))

def compute_f1(a_gold, a_pred):
    gold_toks = get_tokens(a_gold)
    pred_toks = get_tokens(a_pred)
    common = collections.Counter(gold_toks) & collections.Counter(pred_toks)
    num_same = sum(common.values())
    if len(gold_toks) == 0 or len(pred_toks) == 0:
        # If either is no-answer, then F1 is 1 if they agree, 0 otherwise
        return int(gold_toks == pred_toks)
    if num_same == 0:
        return 0
    precision = 1.0 * num_same / len(pred_toks)
    recall = 1.0 * num_same / len(gold_toks)
    f1 = (2 * precision * recall) / (precision + recall)
    return f1

def get_raw_scores(dataset, preds):
    exact_scores = {}
    f1_scores = {}
    for article in dataset:
        for p in article["paragraphs"]:
            for qa in p["qas"]:
                qid = qa["id"]
                gold_answers = [normalize_answer(answer["text"]) for answer in qa["answers"] if answer.get("text")]
                #gold_answers = [t for t in qa["answers"]["text"] if normalize_answer(t)]
                if not gold_answers:
                    # For unanswerable questions, only correct answer is empty string
                    gold_answers = [""]
                if qid not in preds:
                    #print(f"Missing prediction for {qid}")
                    continue
                a_pred = preds[qid]
                # Take max over all gold answers
                exact_scores[qid] = max(compute_exact(a, a_pred) for a in gold_answers)
                f1_scores[qid] = max(compute_f1(a, a_pred) for a in gold_answers)
    return exact_scores, f1_scores

def apply_no_ans_threshold(scores, na_probs, qid_to_has_ans, na_prob_thresh):
    new_scores = {}
    for qid, s in scores.items():
        pred_na = na_probs[qid] > na_prob_thresh
        if pred_na:
            new_scores[qid] = float(not qid_to_has_ans[qid])
        else:
            new_scores[qid] = s
    return new_scores

def make_eval_dict(exact_scores, f1_scores, qid_list=None):
    if not qid_list:
        total = len(exact_scores)
        return collections.OrderedDict(
            [
                ("exact", 100.0 * sum(exact_scores.values()) / total),
                ("f1", 100.0 * sum(f1_scores.values()) / total),
                ("total", total),
            ]
        )
    else:
        total = len(qid_list)
        return collections.OrderedDict(
            [
                ("exact", 100.0 * sum(exact_scores.get(k, 0) for k in qid_list) / total),
                ("f1", 100.0 * sum(f1_scores.get(k, 0) for k in qid_list) / total),
                ("total", total),
            ]
        )

def merge_eval(main_eval, new_eval, prefix):
    for k in new_eval:
        main_eval[f"{prefix}_{k}"] = new_eval[k]

def plot_pr_curve(precisions, recalls, out_image, title):
    plt.step(recalls, precisions, color="b", alpha=0.2, where="post")
    plt.fill_between(recalls, precisions, step="post", alpha=0.2, color="b")
    plt.xlabel("Recall")
    plt.ylabel("Precision")
    plt.xlim([0.0, 1.05])
    plt.ylim([0.0, 1.05])
    plt.title(title)
    plt.savefig(out_image)
    plt.clf()

def make_precision_recall_eval(scores, na_probs, num_true_pos, qid_to_has_ans, out_image=None, title=None):
    qid_list = sorted(na_probs, key=lambda k: na_probs[k])
    true_pos = 0.0
    cur_p = 1.0
    cur_r = 0.0
    precisions = [1.0]
    recalls = [0.0]
    avg_prec = 0.0
    for i, qid in enumerate(qid_list):
        if qid_to_has_ans[qid]:
            true_pos += scores[qid]
        cur_p = true_pos / float(i + 1)
        cur_r = true_pos / float(num_true_pos)
        if i == len(qid_list) - 1 or na_probs[qid] != na_probs[qid_list[i + 1]]:
            # i.e., if we can put a threshold after this point
            avg_prec += cur_p * (cur_r - recalls[-1])
            precisions.append(cur_p)
            recalls.append(cur_r)
    if out_image:
        plot_pr_curve(precisions, recalls, out_image, title)
    return {"ap": 100.0 * avg_prec}

def run_precision_recall_analysis(main_eval, exact_raw, f1_raw, na_probs, qid_to_has_ans, out_image_dir):
    if out_image_dir and not os.path.exists(out_image_dir):
        os.makedirs(out_image_dir)
    num_true_pos = sum(1 for v in qid_to_has_ans.values() if v)
    if num_true_pos == 0:
        return
    pr_exact = make_precision_recall_eval(
        exact_raw,
        na_probs,
        num_true_pos,
        qid_to_has_ans,
        out_image=os.path.join(out_image_dir, "pr_exact.png"),
        title="Precision-Recall curve for Exact Match score",
    )
    pr_f1 = make_precision_recall_eval(
        f1_raw,
        na_probs,
        num_true_pos,
        qid_to_has_ans,
        out_image=os.path.join(out_image_dir, "pr_f1.png"),
        title="Precision-Recall curve for F1 score",
    )
    oracle_scores = {k: float(v) for k, v in qid_to_has_ans.items()}
    pr_oracle = make_precision_recall_eval(
        oracle_scores,
        na_probs,
        num_true_pos,
        qid_to_has_ans,
        out_image=os.path.join(out_image_dir, "pr_oracle.png"),
        title="Oracle Precision-Recall curve (binary task of HasAns vs. NoAns)",
    )
    merge_eval(main_eval, pr_exact, "pr_exact")
    merge_eval(main_eval, pr_f1, "pr_f1")
    merge_eval(main_eval, pr_oracle, "pr_oracle")

def histogram_na_prob(na_probs, qid_list, image_dir, name):
    if not qid_list:
        return
    x = [na_probs.get(k, 0.0) for k in qid_list]
    #x = [na_probs[k] for k in qid_list]
    weights = np.ones_like(x) / float(len(x))
    plt.hist(x, weights=weights, bins=20, range=(0.0, 1.0))
    plt.xlabel("Model probability of no-answer")
    plt.ylabel("Proportion of dataset")
    plt.title(f"Histogram of no-answer probability: {name}")
    plt.savefig(os.path.join(image_dir, f"na_prob_hist_{name}.png"))
    plt.clf()

def find_best_thresh(preds, scores, na_probs, qid_to_has_ans):
    num_no_ans = sum(1 for k in qid_to_has_ans if not qid_to_has_ans[k])
    cur_score = num_no_ans
    best_score = cur_score
    best_thresh = 0.0
    qid_list = sorted(na_probs, key=lambda k: na_probs[k])
    for i, qid in enumerate(qid_list):
        if qid not in scores:
            continue
        if qid_to_has_ans[qid]:
            diff = scores[qid]
        else:
            if preds[qid]:
                diff = -1
            else:
                diff = 0
        cur_score += diff
        if cur_score > best_score:
            best_score = cur_score
            best_thresh = na_probs[qid]
    return 100.0 * best_score / len(scores), best_thresh

def find_all_best_thresh(main_eval, preds, exact_raw, f1_raw, na_probs, qid_to_has_ans):
    best_exact, exact_thresh = find_best_thresh(preds, exact_raw, na_probs, qid_to_has_ans)
    best_f1, f1_thresh = find_best_thresh(preds, f1_raw, na_probs, qid_to_has_ans)
    main_eval["best_exact"] = best_exact
    main_eval["best_exact_thresh"] = exact_thresh
    main_eval["best_f1"] = best_f1
    main_eval["best_f1_thresh"] = f1_thresh

def evaluate():
    with open(OPTS.data_file) as f:
        dataset_json = json.load(f)
        dataset = dataset_json["data"]
    with open(OPTS.pred_file) as f:
        preds = json.load(f)
    if OPTS.na_prob_file:
        with open(OPTS.na_prob_file) as f:
            na_probs = json.load(f)
    else:
        na_probs = {k: 0.0 for k in preds}
    qid_to_has_ans = make_qid_to_has_ans(dataset)  # maps qid to True/False
    has_ans_qids = [k for k, v in qid_to_has_ans.items() if v]
    no_ans_qids = [k for k, v in qid_to_has_ans.items() if not v]
    exact_raw, f1_raw = get_raw_scores(dataset, preds)
    exact_thresh = apply_no_ans_threshold(exact_raw, na_probs, qid_to_has_ans, OPTS.na_prob_thresh)
    f1_thresh = apply_no_ans_threshold(f1_raw, na_probs, qid_to_has_ans, OPTS.na_prob_thresh)
    out_eval = make_eval_dict(exact_thresh, f1_thresh)
    if has_ans_qids:
        has_ans_eval = make_eval_dict(exact_thresh, f1_thresh, qid_list=has_ans_qids)
        merge_eval(out_eval, has_ans_eval, "HasAns")
    if no_ans_qids:
        no_ans_eval = make_eval_dict(exact_thresh, f1_thresh, qid_list=no_ans_qids)
        merge_eval(out_eval, no_ans_eval, "NoAns")
    if OPTS.na_prob_file:
        find_all_best_thresh(out_eval, preds, exact_raw, f1_raw, na_probs, qid_to_has_ans)
    if OPTS.na_prob_file and OPTS.out_image_dir:
        run_precision_recall_analysis(out_eval, exact_raw, f1_raw, na_probs, qid_to_has_ans, OPTS.out_image_dir)
        histogram_na_prob(na_probs, has_ans_qids, OPTS.out_image_dir, "hasAns")
        histogram_na_prob(na_probs, no_ans_qids, OPTS.out_image_dir, "noAns")
    if OPTS.out_file:
        with open(OPTS.out_file, "w") as f:
            json.dump(out_eval, f)
    else:
        return json.dumps(out_eval, indent=2)

OPTS = argparse.Namespace(
  data_file="dev-v2.0.json",  # Specify the path to your data file
  pred_file="pred.json",       # Specify the path to your predictions file
  out_file=None,               # Set to None or specify the path for the output file
  na_prob_file="na_prob.json",  # Specify the path to your NA probability file
  na_prob_thresh=1.0,
  out_image_dir="./",           # Specify the directory for saving images
  verbose=False
)

if OPTS.out_image_dir:
    import matplotlib
    matplotlib.use("Agg")
    import matplotlib.pyplot as plt

## Fine-tuning

In [117]:
model = BertForQuestionAnswering.from_pretrained('bert-base-uncased')
epochs = 10
model.to(device)
optimizer = optim.AdamW(model.parameters(), lr=1e-5)

Some weights of BertForQuestionAnswering were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['qa_outputs.weight', 'qa_outputs.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [118]:
def train(train_dataloader):

  for epoch in range(epochs):
      epoch_loss = []
      validation_loss = []

      total_loss = 0
      model.train()

      count=-1
      progress_bar = tqdm(train_dataloader, leave=True, position=0)
      progress_bar.set_description(f"Epoch {epoch+1}")
      for batch in progress_bar:
          count+=1
          input_ids, segment_ids, mask, start, end  = tuple(t.to(device) for t in batch)

          model.zero_grad()
          loss, start_logits, end_logits = model(input_ids = input_ids,
                                                  token_type_ids = segment_ids,
                                                  attention_mask = mask,
                                                  start_positions = start,
                                                  end_positions = end,
                                                  return_dict = False)

          total_loss += loss.item()
          loss.backward()
          torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
          optimizer.step()

          if (count % 20 == 0 and count != 0):
              avg = total_loss/count
              progress_bar.set_postfix(Loss=avg)

      torch.save(model.state_dict(), "./bert2_" + str(epoch) + ".h5")
      avg_train_loss = total_loss / len(train_dataloader)
      epoch_loss.append(avg_train_loss)
      print(f"Epoch {epoch} Loss: {avg_train_loss}\n")

In [119]:
train(train_dataloader)

Epoch 1: 100%|██████████| 125/125 [01:29<00:00,  1.40it/s, Loss=4.46]


Epoch 0 Loss: 4.3984032783508304



Epoch 2: 100%|██████████| 125/125 [01:32<00:00,  1.34it/s, Loss=2.89]


Epoch 1 Loss: 2.8474354438781737



Epoch 3: 100%|██████████| 125/125 [01:35<00:00,  1.30it/s, Loss=2.19]


Epoch 2 Loss: 2.1738715052604674



Epoch 4: 100%|██████████| 125/125 [01:36<00:00,  1.29it/s, Loss=1.64]


Epoch 3 Loss: 1.6258160290718078



Epoch 5: 100%|██████████| 125/125 [01:34<00:00,  1.32it/s, Loss=1.18]


Epoch 4 Loss: 1.1767065677642823



Epoch 6: 100%|██████████| 125/125 [01:35<00:00,  1.31it/s, Loss=0.839]


Epoch 5 Loss: 0.8262568665742874



Epoch 7: 100%|██████████| 125/125 [01:34<00:00,  1.32it/s, Loss=0.593]


Epoch 6 Loss: 0.5842469004988671



Epoch 8: 100%|██████████| 125/125 [01:35<00:00,  1.31it/s, Loss=0.384]


Epoch 7 Loss: 0.3772523908019066



Epoch 9: 100%|██████████| 125/125 [01:35<00:00,  1.31it/s, Loss=0.276]


Epoch 8 Loss: 0.27345848065614703



Epoch 10: 100%|██████████| 125/125 [01:34<00:00,  1.32it/s, Loss=0.228]


Epoch 9 Loss: 0.22338134651631117



## Evaluating

In [120]:
calculate_accuracy(tokenized_validation, validation_squad, val_dataloader)

100%|██████████| 125/125 [01:34<00:00,  1.32it/s]


 accuracy is:  0.403





In [121]:
evaluate()

'{\n  "exact": 41.4,\n  "f1": 43.29857142857144,\n  "total": 1000,\n  "HasAns_exact": 0.30364372469635625,\n  "HasAns_f1": 0.6239155581260845,\n  "HasAns_total": 5928,\n  "NoAns_exact": 6.661059714045416,\n  "NoAns_f1": 6.661059714045416,\n  "NoAns_total": 5945,\n  "best_exact": 594.8,\n  "best_exact_thresh": 0.001625850098207593,\n  "best_f1": 594.9316666666666,\n  "best_f1_thresh": 0.0021326299756765366,\n  "pr_exact_ap": 0.05518820667522107,\n  "pr_f1_ap": 0.13709594709576495,\n  "pr_oracle_ap": 4.6467076804282\n}'

In [122]:
evaluation = json.loads(evaluate())
print(f'F1 score = {evaluation["f1"]}')

F1 score = 43.29857142857144


In [123]:
print(f'Exact Match score = {evaluation["exact"]}')

Exact Match score = 41.4


## Curriculum Learning

#### Difficulty Evaluation

In [124]:
N = 10

In [125]:
subset_size = len(train_squad) // N
shuffled_train_squad = train_squad.shuffle(seed=42)
train_splits = [shuffled_train_squad.select(indices=range(i * subset_size, (i + 1) * subset_size)) for i in range(N)]

In [126]:
subset_size = len(validation_squad) // N
shuffled_validation_squad = validation_squad.shuffle(seed=42)
validation_splits = [shuffled_validation_squad.select(indices=range(i * subset_size, (i + 1) * subset_size)) for i in range(N)]

In [127]:
difficulty_scores = []

for i, (train_split, validation_split) in enumerate(zip(train_splits, validation_splits), 1):

    print(f"Share {i}, size = {len(train_split)}")

    tokenized_train_split, processed_train_split = process_train(train_split)
    train_split_dataloader = get_train_dataloader(tokenized_train_split, processed_train_split, batch_size)

    tokenized_split_validation, val_split_dataloader = get_val_dataloader(validation_split, batch_size)

    #model
    model = BertForQuestionAnswering.from_pretrained('bert-base-uncased')
    epochs = 10
    model.to(device)
    optimizer = optim.AdamW(model.parameters(), lr=1e-5)

    #calculate the difficulty score of each training example
    train(train_split_dataloader)

    #evaluate
    calculate_accuracy(tokenized_split_validation, validation_split, val_split_dataloader)
    eval = json.loads(evaluate())
    difficulty_scores.append(eval["f1"])

Share 1, size = 100


Some weights of BertForQuestionAnswering were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['qa_outputs.weight', 'qa_outputs.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Epoch 1: 100%|██████████| 13/13 [00:07<00:00,  1.64it/s]


Epoch 0 Loss: 5.851925776554988



Epoch 2: 100%|██████████| 13/13 [00:07<00:00,  1.64it/s]


Epoch 1 Loss: 5.238157492417556



Epoch 3: 100%|██████████| 13/13 [00:07<00:00,  1.64it/s]


Epoch 2 Loss: 4.463050768925593



Epoch 4: 100%|██████████| 13/13 [00:07<00:00,  1.63it/s]


Epoch 3 Loss: 3.751704766200139



Epoch 5: 100%|██████████| 13/13 [00:08<00:00,  1.62it/s]


Epoch 4 Loss: 3.3096489356114316



Epoch 6: 100%|██████████| 13/13 [00:08<00:00,  1.62it/s]


Epoch 5 Loss: 2.9934418568244348



Epoch 7: 100%|██████████| 13/13 [00:07<00:00,  1.63it/s]


Epoch 6 Loss: 2.6615422322199893



Epoch 8: 100%|██████████| 13/13 [00:07<00:00,  1.63it/s]


Epoch 7 Loss: 2.323361974496108



Epoch 9: 100%|██████████| 13/13 [00:07<00:00,  1.64it/s]


Epoch 8 Loss: 1.9843842433049128



Epoch 10: 100%|██████████| 13/13 [00:07<00:00,  1.63it/s]


Epoch 9 Loss: 1.6055041276491606



100%|██████████| 13/13 [00:05<00:00,  2.57it/s]



 accuracy is:  0.58
Share 2, size = 100


Some weights of BertForQuestionAnswering were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['qa_outputs.weight', 'qa_outputs.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Epoch 1: 100%|██████████| 13/13 [00:08<00:00,  1.59it/s]


Epoch 0 Loss: 5.9117934520428



Epoch 2: 100%|██████████| 13/13 [00:08<00:00,  1.58it/s]


Epoch 1 Loss: 5.350228126232441



Epoch 3: 100%|██████████| 13/13 [00:08<00:00,  1.59it/s]


Epoch 2 Loss: 4.685936450958252



Epoch 4: 100%|██████████| 13/13 [00:08<00:00,  1.58it/s]


Epoch 3 Loss: 3.9848101689265323



Epoch 5: 100%|██████████| 13/13 [00:08<00:00,  1.59it/s]


Epoch 4 Loss: 3.479357884480403



Epoch 6: 100%|██████████| 13/13 [00:08<00:00,  1.58it/s]


Epoch 5 Loss: 3.04250876720135



Epoch 7: 100%|██████████| 13/13 [00:08<00:00,  1.58it/s]


Epoch 6 Loss: 2.652655784900372



Epoch 8: 100%|██████████| 13/13 [00:08<00:00,  1.58it/s]


Epoch 7 Loss: 2.2694855745022116



Epoch 9: 100%|██████████| 13/13 [00:08<00:00,  1.57it/s]


Epoch 8 Loss: 1.8073397874832153



Epoch 10: 100%|██████████| 13/13 [00:08<00:00,  1.57it/s]


Epoch 9 Loss: 1.3752739429473877



100%|██████████| 13/13 [00:08<00:00,  1.55it/s]



 accuracy is:  0.49
Share 3, size = 100


Some weights of BertForQuestionAnswering were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['qa_outputs.weight', 'qa_outputs.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Epoch 1: 100%|██████████| 13/13 [00:09<00:00,  1.43it/s]


Epoch 0 Loss: 6.090338963728684



Epoch 2: 100%|██████████| 13/13 [00:09<00:00,  1.42it/s]


Epoch 1 Loss: 5.494831231924204



Epoch 3: 100%|██████████| 13/13 [00:09<00:00,  1.43it/s]


Epoch 2 Loss: 4.932100259340727



Epoch 4: 100%|██████████| 13/13 [00:09<00:00,  1.43it/s]


Epoch 3 Loss: 4.319653731126052



Epoch 5: 100%|██████████| 13/13 [00:09<00:00,  1.43it/s]


Epoch 4 Loss: 3.8421063789954553



Epoch 6: 100%|██████████| 13/13 [00:09<00:00,  1.42it/s]


Epoch 5 Loss: 3.4133441631610575



Epoch 7: 100%|██████████| 13/13 [00:09<00:00,  1.42it/s]


Epoch 6 Loss: 2.986253628363976



Epoch 8: 100%|██████████| 13/13 [00:09<00:00,  1.43it/s]


Epoch 7 Loss: 2.55928020293896



Epoch 9: 100%|██████████| 13/13 [00:09<00:00,  1.43it/s]


Epoch 8 Loss: 2.157136531976553



Epoch 10: 100%|██████████| 13/13 [00:09<00:00,  1.43it/s]


Epoch 9 Loss: 1.727669330743643



100%|██████████| 13/13 [00:09<00:00,  1.43it/s]



 accuracy is:  0.44
Share 4, size = 100


Some weights of BertForQuestionAnswering were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['qa_outputs.weight', 'qa_outputs.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Epoch 1: 100%|██████████| 13/13 [00:09<00:00,  1.43it/s]


Epoch 0 Loss: 6.056454658508301



Epoch 2: 100%|██████████| 13/13 [00:09<00:00,  1.44it/s]


Epoch 1 Loss: 5.289409197293795



Epoch 3: 100%|██████████| 13/13 [00:09<00:00,  1.43it/s]


Epoch 2 Loss: 4.565073270064134



Epoch 4: 100%|██████████| 13/13 [00:09<00:00,  1.43it/s]


Epoch 3 Loss: 3.962950798181387



Epoch 5: 100%|██████████| 13/13 [00:09<00:00,  1.42it/s]


Epoch 4 Loss: 3.601930783345149



Epoch 6: 100%|██████████| 13/13 [00:09<00:00,  1.42it/s]


Epoch 5 Loss: 3.177178841370803



Epoch 7: 100%|██████████| 13/13 [00:09<00:00,  1.41it/s]


Epoch 6 Loss: 2.8550300964942346



Epoch 8: 100%|██████████| 13/13 [00:09<00:00,  1.42it/s]


Epoch 7 Loss: 2.4049819157673764



Epoch 9: 100%|██████████| 13/13 [00:09<00:00,  1.42it/s]


Epoch 8 Loss: 2.0517704028349657



Epoch 10: 100%|██████████| 13/13 [00:09<00:00,  1.42it/s]


Epoch 9 Loss: 1.5729075532693129



100%|██████████| 13/13 [00:08<00:00,  1.59it/s]



 accuracy is:  0.54
Share 5, size = 100


Some weights of BertForQuestionAnswering were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['qa_outputs.weight', 'qa_outputs.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Epoch 1: 100%|██████████| 13/13 [00:07<00:00,  1.66it/s]


Epoch 0 Loss: 5.662888343517597



Epoch 2: 100%|██████████| 13/13 [00:07<00:00,  1.66it/s]


Epoch 1 Loss: 5.125863735492413



Epoch 3: 100%|██████████| 13/13 [00:07<00:00,  1.65it/s]


Epoch 2 Loss: 4.56474205163809



Epoch 4: 100%|██████████| 13/13 [00:07<00:00,  1.65it/s]


Epoch 3 Loss: 4.03410423718966



Epoch 5: 100%|██████████| 13/13 [00:07<00:00,  1.65it/s]


Epoch 4 Loss: 3.3980347376603346



Epoch 6: 100%|██████████| 13/13 [00:07<00:00,  1.66it/s]


Epoch 5 Loss: 2.9706924511836124



Epoch 7: 100%|██████████| 13/13 [00:07<00:00,  1.65it/s]


Epoch 6 Loss: 2.5455264311570387



Epoch 8: 100%|██████████| 13/13 [00:07<00:00,  1.65it/s]


Epoch 7 Loss: 2.0545077874110294



Epoch 9: 100%|██████████| 13/13 [00:07<00:00,  1.65it/s]


Epoch 8 Loss: 1.5214072740994966



Epoch 10: 100%|██████████| 13/13 [00:07<00:00,  1.65it/s]


Epoch 9 Loss: 1.0611734252709608



100%|██████████| 13/13 [00:09<00:00,  1.31it/s]



 accuracy is:  0.5
Share 6, size = 100


Some weights of BertForQuestionAnswering were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['qa_outputs.weight', 'qa_outputs.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Epoch 1: 100%|██████████| 13/13 [00:09<00:00,  1.42it/s]


Epoch 0 Loss: 5.975341430077186



Epoch 2: 100%|██████████| 13/13 [00:09<00:00,  1.43it/s]


Epoch 1 Loss: 5.3234316018911505



Epoch 3: 100%|██████████| 13/13 [00:09<00:00,  1.43it/s]


Epoch 2 Loss: 4.598797174600454



Epoch 4: 100%|██████████| 13/13 [00:09<00:00,  1.43it/s]


Epoch 3 Loss: 3.968320901577289



Epoch 5: 100%|██████████| 13/13 [00:09<00:00,  1.43it/s]


Epoch 4 Loss: 3.452544010602511



Epoch 6: 100%|██████████| 13/13 [00:09<00:00,  1.43it/s]


Epoch 5 Loss: 2.917267927756676



Epoch 7: 100%|██████████| 13/13 [00:09<00:00,  1.42it/s]


Epoch 6 Loss: 2.5034271570352407



Epoch 8: 100%|██████████| 13/13 [00:09<00:00,  1.40it/s]


Epoch 7 Loss: 2.1135324148031382



Epoch 9: 100%|██████████| 13/13 [00:09<00:00,  1.41it/s]


Epoch 8 Loss: 1.6134525354091938



Epoch 10: 100%|██████████| 13/13 [00:09<00:00,  1.42it/s]


Epoch 9 Loss: 1.1893563316418574



100%|██████████| 13/13 [00:09<00:00,  1.33it/s]



 accuracy is:  0.52
Share 7, size = 100


Some weights of BertForQuestionAnswering were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['qa_outputs.weight', 'qa_outputs.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Epoch 1: 100%|██████████| 13/13 [00:06<00:00,  1.89it/s]


Epoch 0 Loss: 5.829115134019118



Epoch 2: 100%|██████████| 13/13 [00:06<00:00,  1.90it/s]


Epoch 1 Loss: 5.228814345139724



Epoch 3: 100%|██████████| 13/13 [00:06<00:00,  1.89it/s]


Epoch 2 Loss: 4.590365923368013



Epoch 4: 100%|██████████| 13/13 [00:06<00:00,  1.89it/s]


Epoch 3 Loss: 3.9741879059718204



Epoch 5: 100%|██████████| 13/13 [00:06<00:00,  1.89it/s]


Epoch 4 Loss: 3.56036010155311



Epoch 6: 100%|██████████| 13/13 [00:06<00:00,  1.89it/s]


Epoch 5 Loss: 3.230900232608502



Epoch 7: 100%|██████████| 13/13 [00:06<00:00,  1.89it/s]


Epoch 6 Loss: 2.7936522227067213



Epoch 8: 100%|██████████| 13/13 [00:06<00:00,  1.89it/s]


Epoch 7 Loss: 2.486556053161621



Epoch 9: 100%|██████████| 13/13 [00:06<00:00,  1.89it/s]


Epoch 8 Loss: 2.032072351529048



Epoch 10: 100%|██████████| 13/13 [00:06<00:00,  1.89it/s]


Epoch 9 Loss: 1.6245467754510732



100%|██████████| 13/13 [00:09<00:00,  1.35it/s]



 accuracy is:  0.43
Share 8, size = 100


Some weights of BertForQuestionAnswering were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['qa_outputs.weight', 'qa_outputs.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Epoch 1: 100%|██████████| 13/13 [00:07<00:00,  1.81it/s]


Epoch 0 Loss: 5.936285128960242



Epoch 2: 100%|██████████| 13/13 [00:07<00:00,  1.80it/s]


Epoch 1 Loss: 5.50904629780696



Epoch 3: 100%|██████████| 13/13 [00:07<00:00,  1.80it/s]


Epoch 2 Loss: 5.067943939795861



Epoch 4: 100%|██████████| 13/13 [00:07<00:00,  1.81it/s]


Epoch 3 Loss: 4.527000775704017



Epoch 5: 100%|██████████| 13/13 [00:07<00:00,  1.81it/s]


Epoch 4 Loss: 4.003667152844942



Epoch 6: 100%|██████████| 13/13 [00:07<00:00,  1.80it/s]


Epoch 5 Loss: 3.4480403936826267



Epoch 7: 100%|██████████| 13/13 [00:07<00:00,  1.79it/s]


Epoch 6 Loss: 2.9645719894996057



Epoch 8: 100%|██████████| 13/13 [00:07<00:00,  1.78it/s]


Epoch 7 Loss: 2.550121307373047



Epoch 9: 100%|██████████| 13/13 [00:07<00:00,  1.79it/s]


Epoch 8 Loss: 1.9661319164129405



Epoch 10: 100%|██████████| 13/13 [00:07<00:00,  1.80it/s]


Epoch 9 Loss: 1.4823351456568792



100%|██████████| 13/13 [00:07<00:00,  1.76it/s]



 accuracy is:  0.52
Share 9, size = 100


Some weights of BertForQuestionAnswering were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['qa_outputs.weight', 'qa_outputs.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Epoch 1: 100%|██████████| 13/13 [00:08<00:00,  1.61it/s]


Epoch 0 Loss: 5.693298523242657



Epoch 2: 100%|██████████| 13/13 [00:08<00:00,  1.61it/s]


Epoch 1 Loss: 5.027247062096229



Epoch 3: 100%|██████████| 13/13 [00:08<00:00,  1.61it/s]


Epoch 2 Loss: 4.312583501522358



Epoch 4: 100%|██████████| 13/13 [00:08<00:00,  1.60it/s]


Epoch 3 Loss: 3.7060553293961744



Epoch 5: 100%|██████████| 13/13 [00:08<00:00,  1.59it/s]


Epoch 4 Loss: 3.354290943879348



Epoch 6: 100%|██████████| 13/13 [00:08<00:00,  1.59it/s]


Epoch 5 Loss: 2.938144188660842



Epoch 7: 100%|██████████| 13/13 [00:08<00:00,  1.58it/s]


Epoch 6 Loss: 2.6681164044600267



Epoch 8: 100%|██████████| 13/13 [00:08<00:00,  1.60it/s]


Epoch 7 Loss: 2.247289648422828



Epoch 9: 100%|██████████| 13/13 [00:08<00:00,  1.61it/s]


Epoch 8 Loss: 1.956817076756404



Epoch 10: 100%|██████████| 13/13 [00:08<00:00,  1.60it/s]


Epoch 9 Loss: 1.577721077662248



100%|██████████| 13/13 [00:04<00:00,  3.13it/s]



 accuracy is:  0.44
Share 10, size = 100


Some weights of BertForQuestionAnswering were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['qa_outputs.weight', 'qa_outputs.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Epoch 1: 100%|██████████| 13/13 [00:09<00:00,  1.43it/s]


Epoch 0 Loss: 6.007621398338904



Epoch 2: 100%|██████████| 13/13 [00:09<00:00,  1.43it/s]


Epoch 1 Loss: 5.502606061788706



Epoch 3: 100%|██████████| 13/13 [00:09<00:00,  1.43it/s]


Epoch 2 Loss: 4.930734671079195



Epoch 4: 100%|██████████| 13/13 [00:09<00:00,  1.42it/s]


Epoch 3 Loss: 4.358292377912081



Epoch 5: 100%|██████████| 13/13 [00:09<00:00,  1.43it/s]


Epoch 4 Loss: 3.902817671115582



Epoch 6: 100%|██████████| 13/13 [00:09<00:00,  1.43it/s]


Epoch 5 Loss: 3.5056703640864444



Epoch 7: 100%|██████████| 13/13 [00:09<00:00,  1.43it/s]


Epoch 6 Loss: 3.162917595643264



Epoch 8: 100%|██████████| 13/13 [00:09<00:00,  1.42it/s]


Epoch 7 Loss: 2.7393771501687856



Epoch 9: 100%|██████████| 13/13 [00:09<00:00,  1.42it/s]


Epoch 8 Loss: 2.218337526688209



Epoch 10: 100%|██████████| 13/13 [00:09<00:00,  1.42it/s]


Epoch 9 Loss: 1.7829309243422289



100%|██████████| 13/13 [00:09<00:00,  1.31it/s]



 accuracy is:  0.61


In [128]:
difficulty_scores

[58.0, 49.0, 44.0, 54.0, 50.0, 52.0, 43.0, 52.0, 44.0, 61.0]

#### Curriculum Arrangement

In [129]:
# model
model = BertForQuestionAnswering.from_pretrained('bert-base-uncased')
model.to(device)
optimizer = optim.AdamW(model.parameters(), lr=1e-5)

Some weights of BertForQuestionAnswering were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['qa_outputs.weight', 'qa_outputs.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [130]:
stages = list(zip(train_splits, difficulty_scores))

sorted_stages = sorted(stages, key=lambda x: x[1])

epochs = 1

for i, (train_split, difficulty_score) in enumerate(sorted_stages, 1):
  print(f"Stage {i}, difficulty score = {difficulty_score}")

  tokenized_train_split, processed_train_split = process_train(train_split)
  train_split_dataloader = get_train_dataloader(tokenized_train_split, processed_train_split, batch_size)

  train(train_split_dataloader)

Stage 1, difficulty score = 43.0


Epoch 1: 100%|██████████| 13/13 [00:06<00:00,  1.88it/s]


Epoch 0 Loss: 5.66234592291025

Stage 2, difficulty score = 44.0


Epoch 1: 100%|██████████| 13/13 [00:09<00:00,  1.42it/s]


Epoch 0 Loss: 5.533884525299072

Stage 3, difficulty score = 44.0


Epoch 1: 100%|██████████| 13/13 [00:08<00:00,  1.60it/s]


Epoch 0 Loss: 4.825092975909893

Stage 4, difficulty score = 49.0


Epoch 1: 100%|██████████| 13/13 [00:08<00:00,  1.58it/s]


Epoch 0 Loss: 4.287556391495925

Stage 5, difficulty score = 50.0


Epoch 1: 100%|██████████| 13/13 [00:07<00:00,  1.65it/s]


Epoch 0 Loss: 4.109731912612915

Stage 6, difficulty score = 52.0


Epoch 1: 100%|██████████| 13/13 [00:09<00:00,  1.42it/s]


Epoch 0 Loss: 3.5619541314932017

Stage 7, difficulty score = 52.0


Epoch 1: 100%|██████████| 13/13 [00:07<00:00,  1.80it/s]


Epoch 0 Loss: 3.9690938546107364

Stage 8, difficulty score = 54.0


Epoch 1: 100%|██████████| 13/13 [00:09<00:00,  1.43it/s]


Epoch 0 Loss: 3.6029193584735575

Stage 9, difficulty score = 58.0


Epoch 1: 100%|██████████| 13/13 [00:07<00:00,  1.63it/s]


Epoch 0 Loss: 3.2861941044147196

Stage 10, difficulty score = 61.0


Epoch 1: 100%|██████████| 13/13 [00:09<00:00,  1.42it/s]


Epoch 0 Loss: 3.652097463607788



In [131]:
epochs = 10
train(train_dataloader)

Epoch 1: 100%|██████████| 125/125 [01:34<00:00,  1.32it/s, Loss=2.83]


Epoch 0 Loss: 2.7785153646469114



Epoch 2: 100%|██████████| 125/125 [01:36<00:00,  1.30it/s, Loss=2.1]


Epoch 1 Loss: 2.074807716846466



Epoch 3: 100%|██████████| 125/125 [01:35<00:00,  1.30it/s, Loss=1.55]


Epoch 2 Loss: 1.5493745169639588



Epoch 4: 100%|██████████| 125/125 [01:36<00:00,  1.30it/s, Loss=1.07]


Epoch 3 Loss: 1.048546949148178



Epoch 5: 100%|██████████| 125/125 [01:34<00:00,  1.32it/s, Loss=0.736]


Epoch 4 Loss: 0.7400834363698959



Epoch 6: 100%|██████████| 125/125 [01:34<00:00,  1.32it/s, Loss=0.503]


Epoch 5 Loss: 0.4972257369160652



Epoch 7: 100%|██████████| 125/125 [01:35<00:00,  1.31it/s, Loss=0.348]


Epoch 6 Loss: 0.34840413030982015



Epoch 8: 100%|██████████| 125/125 [01:35<00:00,  1.30it/s, Loss=0.222]


Epoch 7 Loss: 0.2210002940967679



Epoch 9: 100%|██████████| 125/125 [01:35<00:00,  1.30it/s, Loss=0.154]


Epoch 8 Loss: 0.16077140829712153



Epoch 10: 100%|██████████| 125/125 [01:35<00:00,  1.31it/s, Loss=0.139]


Epoch 9 Loss: 0.1337820711405948



In [132]:
calculate_accuracy(tokenized_validation, validation_squad, val_dataloader)

100%|██████████| 125/125 [01:35<00:00,  1.31it/s]


 accuracy is:  0.44





In [133]:
evaluate()

'{\n  "exact": 44.2,\n  "f1": 46.12190476190477,\n  "total": 1000,\n  "HasAns_exact": 0.33738191632928477,\n  "HasAns_f1": 0.661589872116188,\n  "HasAns_total": 5928,\n  "NoAns_exact": 7.098402018502943,\n  "NoAns_f1": 7.098402018502943,\n  "NoAns_total": 5945,\n  "best_exact": 594.5,\n  "best_exact_thresh": 0.0,\n  "best_f1": 594.5,\n  "best_f1_thresh": 0.0,\n  "pr_exact_ap": 0.04978528437952521,\n  "pr_f1_ap": 0.1562556621535656,\n  "pr_oracle_ap": 4.698402914437673\n}'

In [134]:
evaluation = json.loads(evaluate())
print(f'F1 score = {evaluation["f1"]}')

F1 score = 46.12190476190477


In [135]:
print(f'Exact Match score = {evaluation["exact"]}')

Exact Match score = 44.2


## Results and Future Work

In [141]:
model1 = {
    "exact_match": 43.3,
    "f1": 41.4,
}

cl_model = {
    "exact_match": evaluation["exact"],
    "f1": evaluation["f1"],
}

df = pd.DataFrame({
    "Metric": ["EM", "F1"],
    "No CL": [model1["exact_match"], model1["f1"]],
    "CL": [cl_model["exact_match"], cl_model["f1"]],
})

print(df)

  Metric  No CL         CL
0     EM   43.3  44.200000
1     F1   41.4  46.121905


*  The result of this experiment is inline with the results described in the paper. Curriculum learning yielded a better performance on the machine reading comprehension task with higher F1 and exact match scores.
*   Only a simple aspect of the results in the paper is reproduced in this experiment. So, with enough time and compute, future work can cover:
  - Meticulously reproducing all the results from the paper spanning different NLU tasks, with the same exact details.
  - Experimenting with variables and hyperparameters like the number of training stages N and the number of epochs to extend the results presented in the paper.
  - Incorporate and merge different training paradigms into the curriculum learning framework.


