<a href="https://colab.research.google.com/github/hjori66/Kaist-AI605-2021-Spring/blob/main/KAIST_AI605_Assignment_2_20194364.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# KAIST AI605 Assignment 2: Token Classification with RNNs and Attention
Author: Minjoon Seo (minjoon@kaist.ac.kr)

TA in charge: Taehyung Kwon (taehyung.kwon@kaist.ac.kr)

**Due date**:  April 19 (Mon) 11:00pm, 2021  


Your name: Taehwan Kim

Your student ID: 20194364

Your collaborators: -

## Assignment Objectives
- Verify theoretically and empirically how Transformer's attention mechanism works for sequence modeling task.
- Implement Transformer's encoder attention layer from scratch using PyTorch.
- Design an Attention-based token classification model using PyTorch.
- Apply the token classification model to a popular machine reading comprehension task, Stanford Question Answering Dataset (SQuAD).
- (Bonus) Analyze pros and cons between using RNN + attention versus purely attention.

## Your Submission
Your submission will be a link to a Colab notebook that has all written answers and is fully executable. You will submit your assignment via KLMS. Use in-line LaTeX (see below) for mathematical expressions. Collaboration among students is allowed but it is not a group assignment so make sure your answer and code are your own. Also make sure to mention your collaborators in your assignment with their names and their student ids.

## Grading
The entire assignment is out of 100 points. There are two bonus questions with 30 points altogether. Your final score can be higher than 100 points.


## Environment
You will only use Python 3.7 and PyTorch 1.8, which is already available on Colab:

In [1]:
from platform import python_version
import torch

print("python", python_version())
print("torch", torch.__version__)

python 3.7.10
torch 1.8.1+cu101


## 1. Transformer's Attention Layer

We will first start with going over a few concepts that you learned in your high school statistics class. The variance of a random variable $X$, $\text{Var}(X)$ is defined as $\text{E}[(X-\mu)^2]$ where $\mu$ is the mean of $X$. Furthermore, given two independent random variables $X$ and $Y$ and a constant $a$,
$$ \text{Var}(X+Y) = \text{Var}(X) + \text{Var}(Y), \quad \ldots \; \text{(1)}$$ 
$$ \text{Var}(aX) = a^2\text{Var}(X), \quad \ldots \; \text{(2)}$$
$$ \text{Var}(XY) = \text{E}(X^2)\text{E}(Y^2) - [\text{E}(X)]^2[\text{E}(Y)]^2. \quad \ldots \; \text{(3)}$$

**Problem 1.1** *(10 points)* Suppose we are given two sets of $n$ random variables, $X_1 \dots X_n$ and $Y_1 \dots Y_n$, where all of these $2n$ variables are mutually independent and have a mean of $0$ and a variance of $1$. Prove that
$$\text{Var}\left(\sum_i^n X_i Y_i\right) = n.$$

There is a typo in the formula $\text{(2)}$. I changed it.

**Answer 1.1** 

$$
\begin{align}
  \text{Var}\left(\sum_i^n X_i Y_i\right)
  &= \sum_i^n \text{Var} \left(X_i Y_i\right) \quad \because \text{(1), independence} \\
  &= \sum_i^n \left[\text{E}(X_i^2)\text{E}(Y_i^2) - [\text{E}(X_i)]^2[\text{E}(Y_i)]^2 \right] \quad \because \text{(3)} \\
  &= \sum_i^n \left[\text{E}((X_i-0)^2)\text{E}((Y_i-0)^2)\right] \\
  &= \sum_i^n \left[\text{E}((X_i-[\text{E}(X_i)])^2)\text{E}((Y_i-[\text{E}(Y_i)])^2)\right] \\
  &= \sum_i^n \left[\text{Var}(X_i) \text{Var}(Y_i)\right] \\
  &= \sum_i^n \left[1\right] = n \\
\end{align}
\\
$$

In Lecture 08 and 09, we discussed how the attention is computed in Transformer via the following equation,
$$ \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right)V.$$
**Problem 1.2** *(10 points)*  Suppose $Q$ and $K$ are matrices of independent variables each of which has a mean of $0$ and a variance of $1$. Using what you learned from Problem 1.1., show that
$$\text{Var}\left(\frac{QK^\top}{\sqrt{d_k}}\right) = 1.$$

**Answer 1.2** 

$$
\begin{align}
  \text{Var}\left(\frac{QK^\top}{\sqrt{d_k}}\right)
  &= \left(\frac{1}{\sqrt{d_k}}\right)^2 \text{Var}\left(QK^\top\right)  \quad \because \text{(2)} \\
  &= \frac{1}{d_k} \text{Var}\left(QK^\top\right) \\
\end{align}
\\
\\
$$

Then, we can focus on the one element of $QK$, $\left(QK\right)_{ij}$ without loss of generality.

$$
\begin{align}
  \frac{1}{d_k} \text{Var}\left(\left(QK\right)_{ij}^\top\right)
  &= \frac{1}{d_k} \text{Var}\left(\sum_{t}^{d_k} \left(Q_{it} K_{tj}\right) \right) \\
  &= \frac{1}{d_k} \left(\sum_{t}^{d_k} \text{Var}\left(Q_{it} K_{tj}\right) \right) \quad \because \text{(1), } Q_{it} \text{ and } Y_{tj} \text{ are mutually independent} \\
  &= \frac{1}{d_k} \left(d_k\right) = 1 \quad \because \text{Problem 1.1} \\
\end{align}
\\
$$

Therefore, 

$$
\text{Var}\left(\frac{QK^\top}{\sqrt{d_k}}\right) = 1.
$$



**Problem 1.3** *(10 points)* What would happen if the assumption that the variance of $Q$ and $K$ is $1$ does not hold? Consider each case of it being higher and lower than $1$ and conjecture what it implies, respectively.

**Answer 1.3** \

If the variance of $Q_{ij}$ and $K_{ij}$ is higher than 1 for all i and j, then

$$
\text{Var}\left(\frac{QK^\top}{\sqrt{d_k}}\right) > 1.
$$

Then, the variance of the output of the decoder becomes larger.
If we use the softmax function on this output, then the final result might be "too" sharp. (Actually, this is not true, because of the residual connection) \
So, I guess that the model overfits faster than original model.

\

Otherwse, if the variance of $Q_{ij}$ and $K_{ij}$ is lower than 1 for all i and j, then

$$
\text{Var}\left(\frac{QK^\top}{\sqrt{d_k}}\right) < 1.
$$

Then, the variance of the output of the decoder becomes smaller.
If we use the softmax function on this output, then the final result might be "too" smooth. 
\
So, I guess that the early training would be more unstable than normal although we use the bigger learning rate. The training time would be longer than the original version. 

## 2. Preprocessing SQuAD

We will use `datasets` package offered by Hugging Face, which allows us to easily download various language datasets, including Stanford Question Answering Dataset (SQuAD).

First, install the package:

In [2]:
!pip install datasets



Then, download SQuAD and print the first example:

In [3]:
from datasets import load_dataset
squad_dataset = load_dataset('squad')
print(squad_dataset['train'][0])

Reusing dataset squad (/root/.cache/huggingface/datasets/squad/plain_text/1.0.0/4fffa6cf76083860f85fa83486ec3028e7e32c342c218ff2a620fc6b2868483a)


{'answers': {'answer_start': [515], 'text': ['Saint Bernadette Soubirous']}, 'context': 'Architecturally, the school has a Catholic character. Atop the Main Building\'s gold dome is a golden statue of the Virgin Mary. Immediately in front of the Main Building and facing it, is a copper statue of Christ with arms upraised with the legend "Venite Ad Me Omnes". Next to the Main Building is the Basilica of the Sacred Heart. Immediately behind the basilica is the Grotto, a Marian place of prayer and reflection. It is a replica of the grotto at Lourdes, France where the Virgin Mary reputedly appeared to Saint Bernadette Soubirous in 1858. At the end of the main drive (and in a direct line that connects through 3 statues and the Gold Dome), is a simple, modern stone statue of Mary.', 'id': '5733be284776f41900661182', 'question': 'To whom did the Virgin Mary allegedly appear in 1858 in Lourdes France?', 'title': 'University_of_Notre_Dame'}


Here, `answer_start` corresponds to the character-level start position of the answer, and `text` is the answer text itself. You will note that `answer_start` and `text` fields are given as lists but they only contain one item each. In fact, you can safely assume that this is the case for the training data. During evaluation, however, you will utilize several possible answers so that your evaluation can be compared against all of them. So your code need to handle multiple-answers case as well.

As we discussed in Lecture 05, we want to formulate this task as a token classification problem. That is, we want to find which token of the context corresponds to the start position of the answer, and which corresponds to the end.

**Problem 2.1** *(10 points)* Write `preprocess()` function that takes a SQuAD example as the input and outputs space-tokenized context and question, as well as the start and end token position of the answer if it has the answer field. That is, a pseudo code would look like:
```python
def preprocess(example):
  out = {'context': ['each', 'token'], 
         'question': ['each', 'token']}
  if 'answers' not in example:
    return out
  out['answers'] = [{'start': 3, 'end': 5}]
  return out
```
Verify that this code works by comparing between the original answer text and the concatenation of the answer tokens from start to end in training data. Report the percentage of the questions that have exact match.

**Answer 2.1**



In [4]:
# def preprocess(example):
#     def tokenizer(sentence):
#         if sentence is None:
#             return list()
#         return sentence.split()

#     out = dict()
#     context = example['context']
#     question = example['question']

#     out['context'] = tokenizer(context)
#     out['question'] = tokenizer(question)

#     answer_list = list()
#     for answer_index in range(len(example['answers']['text'])):
#         answer = example['answers']['text'][answer_index]
#         answer_start = example['answers']['answer_start'][answer_index]
#         answer_end = answer_start + len(answer)

#         if (answer_start == 0 or context[answer_start-1] == ' ') \
#             and (answer_end == len(context) or context[answer_end] == ' '):
#             n_tokens_before_answer = len(tokenizer(context[:answer_start]))
#             n_tokens_answer = len(tokenizer(answer))
#             answer_list.append({'start': n_tokens_before_answer, 'end': n_tokens_before_answer+n_tokens_answer-1})
    
#     if not answer_list:
#         out['answers'] = answer_list

#     return out

# n_exact_match = 0.0
# for i, example in enumerate(squad_dataset['train']):
#     out = preprocess(example)
#     if 'answers' in out.keys():
#         n_exact_match += 1


** Result 2.1 **


```
print("number of answers that have exact match : ", n_exact_match)
print("number of training data : ", len(squad_dataset['train']))
print("the percentage of the exact match (training) : ", n_exact_match / len(squad_dataset['train']))

# number of answers that have exact match :  41616.0
# number of training data :  87599
# the percentage of the exact match :  0.47507391636890833 -> pretty low

```



We want to maximize the percentage of the exact match. You might see a low percentage however, due to bad tokenization. For instance, such space-based tokenization will fail to separate between "world" and "!" in "hello world!". 

**Problem 2.2** *(10 points)* Write an advanced tokenization model that always separates non-alphabet characters as independent tokens. For instance, "hello1 world!!" will be tokenized into "hello", "1", "world", "!", and "!". Using this new tokenizer, re-run the `preprocess` function and report the exact match percentage. How does the ratio change?

**Answer 2.2**



In [5]:
def find_nonalpha_list(dataset):
    nonalpha_list = list()
    for example in dataset:
        for c in example['context']:
            if not c.isalpha() and c not in nonalpha_list:
                nonalpha_list.append(c)
        for c in example['question']:
            if not c.isalpha() and c not in nonalpha_list:
                nonalpha_list.append(c)
    return nonalpha_list


def preprocess(example, nonalpha_list):
    def tokenizer(sentence, nonalpha_list):
        if sentence is None:
            return list()

        for nonalpha_token in nonalpha_list:
            sentence = sentence.replace(nonalpha_token, ' ' + nonalpha_token + ' ')

        sentence = ' '.join(sentence.split())
        return sentence.split()
        
    out = dict()
    context = example['context']
    question = example['question']
    id = example['id']

    out['context'] = tokenizer(context, nonalpha_list)
    out['question'] = tokenizer(question, nonalpha_list)
    out['id'] = id

    answer_list = list()
    for answer_index in range(len(example['answers']['text'])):
        answer = example['answers']['text'][answer_index]
        answer_start = example['answers']['answer_start'][answer_index]
        answer_end = answer_start + len(answer)

        if (answer_start == 0 or context[answer_start-1] in nonalpha_list) \
            and (answer_end == len(context) or context[answer_end] in nonalpha_list):
            n_tokens_before_answer = len(tokenizer(context[:answer_start], nonalpha_list))
            n_tokens_answer = len(tokenizer(answer, nonalpha_list))
            answer_list.append({'start': n_tokens_before_answer, 'end': n_tokens_before_answer+n_tokens_answer-1})

    if answer_list:
        out['answers'] = answer_list

    return out

# dataset = squad_dataset['train']
# # dataset = squad_dataset['validation'] # Do this if you want to check the valid dataset

# nonalpha_list = find_nonalpha_list(squad_dataset['train'])
# print("nonalpha_list : ", nonalpha_list)

# n_exact_match = 0.0
# for i, example in enumerate(dataset):
#   out = preprocess(example, nonalpha_list)
#   if 'answers' in out.keys():
#     n_exact_match += 1

# print("number of answers that have exact match : ", n_exact_match)
# print("number of training data : ", len(dataset))
# print("the percentage of the exact match (training) : ", n_exact_match / len(dataset))


** Result 2.2 **


```
nonalpha_list :  [',', ' ', '.', "'", '"', '1', '8', '5', '(', '3', ')', '?', '-', '7', '6', '9', '2', '0', ';', '–', '&', '4', '%', '$', '[', ']', '/', ':', '#', '—', '!', '“', '’', '”', '<', '\u200b', '̃', '£', '½', '+', '¢', '−', '°', '>', '€', '《', '》', '±', '~', '¥', '²', '❤', '=', '\u200e', '͡', '́', '`', '्', 'ु', 'ः', 'ॊ', 'ि', 'ा', '\u200d', '\u200c', '*', '‘', '\u3000', '•', '§', '⁄', '\n', '̯', '̩', '…', '·', 'ָ', 'ִ', 'ׁ', 'ַ', 'ּ', 'ְ', 'ّ', '⟨', '◌', '⟩', '˭', '̤', '♠', '∅', '̞', '×', '̥', '′', '″', '\ufeff', '_', 'ֿ', '´', '^', '̧', '̄', '→', '‑', '，', '₹', '\u202f', '♯', '₂', '₥', '⁊', '\u2009', '{', '}', '|', '@', '̪', '‚', '›', 'ׂ', 'ֵ', 'ِ', 'ْ', 'َ', '̍', '˥', '˨', '˩', '¡', '√', '¿', 'ာ', 'း', 'ُ', '≥', '˚', '≈', '⋅', 'ี', '︘', '�', '～', '〜', '̀', 'ོ', '་', '˧', 'ಾ', 'ು', '್', 'া', '্', 'ಿ', '∗', '∈', '≡', '∖', '№', '÷', 'ٔ', '¶', 'ิ', '₤', '♆', '⅓', '∝', '¼', 'ٍ', 'ֹ', '̌', '。', '̠', '₯']
number of answers that have exact match :  87108.0
number of training data :  87599
the percentage of the exact match (training) :  0.9943949131839405 -> now, it is OK

number of answers that have exact match :  10566.0
number of validation data :  10570
the percentage of the exact match (validation) :  0.9996215704824977 -> Also, it is OK

```



## 3. LSTM Baseline for SQuAD

We will bring and reuse our model from Assignment 1. There are two key differences, however. First, we need to classify each token instead of the entire sentence. Second, we have two inputs (context and question) instead of just one.  

In [6]:
# My model from Assignment 1 is too slow to use it.
# I will use torchtext and torch.nn.lstm in the assignment 2.

import torchtext
from torchtext.legacy import data
from torchtext.legacy import datasets
from torchtext.legacy.data import BucketIterator


class SQuAD1Dataset(data.Dataset):
  """
  Defines a dataset for squad1.0.
  """
  
  @staticmethod
  def sort_key(ex):
    return data.interleave_keys(len(ex.context), len(ex.question))

  def __init__(self, data_list, fields, use_bos=True, max_length=None, **kwargs):
    if not isinstance(fields[0], (tuple, list)):
      fields = [('context', fields[0]), 
                ('question', fields[1]), 
                # ('context_question', fields[2]), # For Problem 3.2+, put the question after the context
                ('answer_start', fields[2]), 
                ('answer_end', fields[3]), 
                ('id_index', fields[4])
                ]

    examples = []
    nonalpha_list = find_nonalpha_list(data_list)

    self.id_list = list()
    self.reference = list()

    for _, example in enumerate(data_list):
        out = preprocess(example, nonalpha_list)
        # use data if the answer exists
        if max_length and max_length < max(len(out['context']), len(out['question'])):
            continue
        if 'answers' in out.keys():
            answer_start = out['answers'][0]['start'] # Use index 0 for instant valid accuracy
            answer_end = out['answers'][0]['end'] # Use index 0 for instant valid accuracy
            if use_bos:
                answer_start += 1 # for <BOS> token
                answer_end += 1 # for <BOS> token
            examples.append(data.Example.fromlist([out['context'], 
                                                   out['question'], 
                                                #    out['context'] + ["<CLS>"] + out['question'], # Use <CLS> token
                                                   answer_start,
                                                   answer_end,
                                                   len(self.id_list)], 
                                                  fields))
            self.id_list.append(out['id'])
            self.reference.append({'id':example['id'], 'answers':example['answers']})

    super(SQuAD1Dataset, self).__init__(examples, fields, **kwargs)


class SQuAD1Dataloader():
  """
  Make the dataloader for SQuAD 1.0
  """
  def __init__(self, train_data=None, valid_data=None, batch_size=64, device='cpu', 
                max_length=255, min_freq=2, fix_length=None,
                use_bos=True, use_eos=True, shuffle=True
              ):

    super(SQuAD1Dataloader, self).__init__()

    self.text = data.Field(sequential=True, use_vocab=True, batch_first=True, 
                           include_lengths=True, fix_length=fix_length, 
                           init_token='<BOS>' if use_bos else None, 
                           eos_token='<EOS>' if use_eos else None
                          )
    self.answer_start = data.Field(sequential = False, use_vocab = False)
    self.answer_end = data.Field(sequential = False, use_vocab = False)
    self.id_index = data.Field(sequential = False, use_vocab = False)
    
    train = SQuAD1Dataset(data_list=train_data, 
                          fields = [('context', self.text),
                                    ('question', self.text),
                                    # ('context_question', self.text),
                                    ('answer_start', self.answer_start),
                                    ('answer_end', self.answer_end),
                                    ('id_index', self.id_index)
                                    ], 
                          use_bos = use_bos,
                          max_length = max_length
                          )
    valid = SQuAD1Dataset(data_list=valid_data, 
                          fields = [('context', self.text),
                                    ('question', self.text),
                                    # ('context_question', self.text),
                                    ('answer_start', self.answer_start),
                                    ('answer_end', self.answer_end),
                                    ('id_index', self.id_index)
                                    ], 
                          use_bos = use_bos,
                          max_length = max_length
                          )
    self.train_id_list = train.id_list
    self.valid_id_list = valid.id_list

    self.train_reference = train.reference
    self.valid_reference = valid.reference
    
    self.train_iter = data.BucketIterator(train, batch_size=batch_size,
                                          device=device,
                                          shuffle=shuffle,
                                          sort_key=lambda x: len(x.question) + (max_length * len(x.context)), 
                                          sort_within_batch = True
                                          )
    self.valid_iter = data.BucketIterator(valid, batch_size=batch_size,
                                          device=device,
                                          shuffle=shuffle,
                                          sort_key=lambda x: len(x.question) + (max_length * len(x.context)), 
                                          sort_within_batch = True
                                          )
    
    self.text.build_vocab(train)


train_dataset = squad_dataset['train']
valid_dataset = squad_dataset['validation']

print('# of train data : {}'.format(len(train_dataset)))
print('# of vaild data : {}'.format(len(valid_dataset)))

batch_size = 128
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
max_length = 253 # 255 - 2 for <BOS> and <EOS>
min_freq = 2
use_bos = False
use_eos = False
print("device : ", device)

loader = SQuAD1Dataloader(train_dataset, valid_dataset, batch_size=batch_size, 
                          device=device, max_length=max_length, min_freq=min_freq,
                          use_bos=use_bos, use_eos=use_eos)
print('\nFinish making the dataloader')
print("batch_size : ", batch_size)
print("max_length : ", max_length)
print('number of used train data ~ {}'.format((len(loader.train_iter)) * batch_size))
print('number of used vaild data ~ {}'.format((len(loader.valid_iter)) * batch_size))

vocab = loader.text.vocab
vocab_list = list(vocab.stoi.keys())
print('number of vocab : {}'.format(len(vocab)))


# of train data : 87599
# of vaild data : 10570
device :  cuda

Finish making the dataloader
batch_size :  128
max_length :  253
number of used train data ~ 81408
number of used vaild data ~ 9856
number of vocab : 86389


** Result 3.1.1 (Preprocessing) **


```
# of train data : 87599
# of vaild data : 10570
device :  cuda

Finish making the dataloader
batch_size :  128
max_length :  253
number of used train data ~ 81408
number of used vaild data ~ 9856
number of vocab : 86389

```




Before resolving these differences, you will need to define your evaluation function to correctly evaluate how well your model is doing. Note that the evaluation was very straightforward in Assignment 1's sentiment classification (it is either positive or negative) while it is a bit complicated in SQuAD. We will use the evaluation function provided by `datasets`. You can access to it via the following code.  

In [7]:
from datasets import load_metric
squad_metric = load_metric('squad')

You can also easily learn about how to use the function by simply typing the function:

In [8]:
squad_metric

Metric(name: "squad", features: {'predictions': {'id': Value(dtype='string', id=None), 'prediction_text': Value(dtype='string', id=None)}, 'references': {'id': Value(dtype='string', id=None), 'answers': Sequence(feature={'text': Value(dtype='string', id=None), 'answer_start': Value(dtype='int32', id=None)}, length=-1, id=None)}}, usage: """
Computes SQuAD scores (F1 and EM).
Args:
    predictions: List of question-answers dictionaries with the following key-values:
        - 'id': id of the question-answer pair as given in the references (see below)
        - 'prediction_text': the text of the answer
    references: List of question-answers dictionaries with the following key-values:
        - 'id': id of the question-answer pair (see above),
        - 'answers': a Dict in the SQuAD dataset format
            {
                'text': list of possible texts for the answer, as a list of strings
                'answer_start': list of start positions for the answer, as a list of ints
   

**Problem 3.1** *(10 points)* Let's resolve the first issue here. Hence, for now, assume that your only input is context and you want to obtain the answer without seeing the question. While this may seem to be a non-sense, actually it can be considered as modeling the prior $\text{Prob}(a|c)$ before observing $q$ (we ultimately want $\text{Prob}(a|q,c)$). Transform your model into a token classification model by imposing $\text{softmax}$ over the tokens instead of predefined classes. You will need to do this twice for each of start and end. Report the accuracy (using the metric above) on `squad_dataset['validation']`. 

**Answer 3.1**



In [9]:
import torch.nn as nn
from tqdm.notebook import tqdm
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

class ClassificationLSTMModel(nn.Module):
    def __init__(self, embedding_dim, hidden_dim, n_layers, n_label, emb_dropout, rnn_dropout, bidirectional, enable_layer_norm, device):
        super(ClassificationLSTMModel, self).__init__()
        self.embedding = nn.Embedding(len(vocab), embedding_dim)
        self.lstm = nn.LSTM(input_size=embedding_dim, 
                            hidden_size=hidden_dim, 
                            num_layers=n_layers, 
                            dropout=rnn_dropout, 
                            bidirectional=bidirectional)
      
        n_direction = 2 if bidirectional else 1
        self.fc_start = nn.Linear(hidden_dim*n_direction, n_label, bias=True)
        self.fc_end = nn.Linear(hidden_dim*n_direction, n_label, bias=True)

        # Layer_normalization
        self.enable_layer_norm = enable_layer_norm
        if enable_layer_norm:
            self.emb_layer_norm = nn.LayerNorm(embedding_dim)

        self.emb_dropout = nn.Dropout(emb_dropout)
        self.fc_dropout = nn.Dropout(rnn_dropout)
        self.bidirectional = bidirectional
        self.device = device

    def forward(self, input_tensor, src_seq_lens):
        emb = self.embedding(input_tensor) # emb.shape = batch * len * hidden

        # Layer_normalization
        if self.enable_layer_norm:
            emb = self.emb_layer_norm(emb)

        emb = self.emb_dropout(emb)
        emb = emb.transpose(0, 1) # emb.shape = len * batch * hidden

        # n_direction = 2 if bidirectional else 1
        # hidden = torch.zeros(n_layers*n_direction, context.shape[0], hidden_dim, requires_grad=True).to(self.device)
        # cell = torch.zeros(n_layers*n_direction, context.shape[0], hidden_dim, requires_grad=True).to(self.device)

        # nn.LSTM
        packed = pack_padded_sequence(emb, src_seq_lens.tolist(), batch_first=False)
        outs, (hidden, cell) = self.lstm(packed)
        outs, out_lens = pad_packed_sequence(outs, batch_first=False)

        if self.bidirectional:
            hidden = torch.stack([hidden[-2], hidden[-1]], dim=0)
        else:
            hidden = hidden[-1].unsqueeze(dim=0)
        hidden = hidden.transpose(0, 1)
        hidden = hidden.contiguous().view(hidden.shape[0], -1)

        hidden = self.fc_dropout(hidden)
        logits_start = self.fc_start(hidden)
        logits_end = self.fc_end(hidden)
        return (logits_start, logits_end)


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("device : ", device)

# Vocabulary : Use vocab_list

# Construct the LSTM Model
embedding_dim = 128 # usually bigger, e.g. 128
hidden_dim = 256
n_layers = 2
n_label = max_length+1 if use_bos else max_length
emb_dropout = 0.1
rnn_dropout = 0.5
bidirectional = True
enable_layer_norm = True
rnnmodel = ClassificationLSTMModel(embedding_dim, hidden_dim, n_layers, n_label, emb_dropout, rnn_dropout, bidirectional, enable_layer_norm, device).to(device)

print("batch_size : ", batch_size)
print("max_length : ", max_length)
print("embedding_dim : ", embedding_dim)
print("hidden_dim : ", hidden_dim)
print("n_layers : ", n_layers)
print("emb_dropout : ", emb_dropout)
print("rnn_dropout_and_fc_dropout : ", rnn_dropout)
if bidirectional:
    print("bidirectional : True")
else:
    print("bidirectional : False")
if enable_layer_norm:
    print("enable_layer_norm : True")
else:
    print("enable_layer_norm : False")

# Construct the data loader
train_iter = loader.train_iter
valid_iter = loader.valid_iter

train_id_list = loader.train_id_list
valid_id_list = loader.valid_id_list
train_reference = loader.train_reference
valid_reference = loader.valid_reference

# Training
learning_rate = 1e-3
print("learning_rate : ", learning_rate)

PAD_IDX = vocab.stoi['<pad>']
cel = nn.CrossEntropyLoss(ignore_index=PAD_IDX) # Ignore Padding
# optimizer = torch.optim.SGD(rnnmodel.parameters(), lr=1e-1)
optimizer = torch.optim.Adam(rnnmodel.parameters(), lr=learning_rate)

epochs = 30
max_norm = 5

# Evaluate
squad_metric = load_metric('squad') # get_tokens() in squad_metric is not exactly same as my preprocess()..


for epoch in tqdm(range(epochs)):
    train_loss = 0
    train_accuracy = 0.0
    train_data_num = 0
    train_prediction = list()
    for train_i, train_batch in enumerate(train_iter):
        context, context_length = train_batch.context
        question, question_length = train_batch.question # Unused
        answer_start = train_batch.answer_start
        answer_end = train_batch.answer_end
        train_id_index = train_batch.id_index

        logits_start, logits_end = rnnmodel(context, context_length)

        optimizer.zero_grad() # reset process
        loss = cel(logits_start, answer_start) + cel(logits_end, answer_end) # Loss, a.k.a L

        loss.backward() # compute gradients
        # print(torch.norm(rnnmodel.lstm.weight_hh_l0.grad), loss.item())
        # torch.nn.utils.clip_grad_norm_(rnnmodel.parameters(), max_norm) # gradent clipping
        optimizer.step() # update parameters
        train_loss += loss.item()
        
        _, train_start_preds = torch.max(logits_start, 1)
        _, train_end_preds = torch.max(logits_end, 1)
        # train_accuracy += ((train_start_preds == answer_start) * (train_end_preds == answer_end)).sum().float()

        train_data_num += context.shape[0]

        for train_j in range(context.shape[0]):
            pred_text = ""
            start = train_start_preds[train_j]
            end = train_end_preds[train_j]
            if start < end:
                pred_text = [vocab_list[text_id] for text_id in context[train_j][start:end+1]]
                pred_text = " ".join(pred_text)
            
            # start = answer_start[train_j]
            # end = answer_end[train_j]
            # answer_text = [vocab_list[text_id] for text_id in context[train_j][start:end+1]]
            # answer_text = " ".join(answer_text)
            # print(pred_text, answer_text, train_reference[train_id_index[train_j]], "\n")

            train_prediction.append({'id':train_id_list[train_id_index[train_j]], 'prediction_text':pred_text})

    train_result = squad_metric.compute(predictions=train_prediction, references=train_reference)
    print('train:: Epoch:', '%04d' % (epoch + 1), 
          'cost =', '{:.6f},'.format(train_loss / train_data_num), 
        #   'argmax acc =', '{:.6f}'.format(train_accuracy / train_data_num),
          'squad_metric : ', train_result
          )
        
    if (epoch + 1) % 1 == 0:
        with torch.no_grad():
            valid_loss = 0
            valid_accuracy = 0.0
            valid_data_num = 0
            valid_prediction = list()
            for valid_i, valid_batch in enumerate(valid_iter):
                context, context_length = valid_batch.context
                question, question_length = valid_batch.question # Unused
                answer_start = valid_batch.answer_start
                answer_end = valid_batch.answer_end
                valid_id_index = valid_batch.id_index

                logits_start, logits_end = rnnmodel(context, context_length)

                loss = cel(logits_start, answer_start) + cel(logits_end, answer_end) # Loss, a.k.a L
                valid_loss += loss.item()

                _, valid_start_preds = torch.max(logits_start, 1)
                _, valid_end_preds = torch.max(logits_end, 1)
                # valid_accuracy += ((valid_start_preds == answer_start) * (valid_end_preds == answer_end)).sum().float()

                valid_data_num += context.shape[0]

                for valid_j in range(context.shape[0]):
                    pred_text = ""
                    start = valid_start_preds[valid_j]
                    end = valid_end_preds[valid_j]
                    if start < end:
                        pred_text = [vocab_list[text_id] for text_id in context[valid_j][start:end+1]]
                        pred_text = " ".join(pred_text)
                    valid_prediction.append({'id':valid_id_list[valid_id_index[valid_j]], 'prediction_text':pred_text})
                
            valid_result = squad_metric.compute(predictions=valid_prediction, references=valid_reference)
            print('valid:: Epoch:', '%04d' % (epoch + 1), 
                  'cost =', '{:.6f},'.format(valid_loss / valid_data_num), 
                #   'argmax acc =', '{:.6f},'.format(valid_accuracy / valid_data_num),
                  'squad_metric : ', valid_result
                 )
            

device :  cuda
batch_size :  128
max_length :  253
embedding_dim :  128
hidden_dim :  256
n_layers :  2
emb_dropout :  0.1
rnn_dropout_and_fc_dropout :  0.5
bidirectional : True
enable_layer_norm : True
learning_rate :  0.001


HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))

train:: Epoch: 0001 cost = 0.079760, squad_metric :  {'exact_match': 0.05904059040590406, 'f1': 5.52339938822115}
valid:: Epoch: 0001 cost = 0.079491, squad_metric :  {'exact_match': 0.040679345062544496, 'f1': 4.141770941182735}
train:: Epoch: 0002 cost = 0.079107, squad_metric :  {'exact_match': 0.27921279212792127, 'f1': 4.516358336672799}
valid:: Epoch: 0002 cost = 0.079342, squad_metric :  {'exact_match': 0.3356045967659921, 'f1': 4.610341876288495}
train:: Epoch: 0003 cost = 0.078534, squad_metric :  {'exact_match': 0.5301353013530136, 'f1': 4.568132562807337}
valid:: Epoch: 0003 cost = 0.079561, squad_metric :  {'exact_match': 0.4881521407505339, 'f1': 4.664714353463255}
train:: Epoch: 0004 cost = 0.077731, squad_metric :  {'exact_match': 0.7380073800738007, 'f1': 4.899126557698399}
valid:: Epoch: 0004 cost = 0.079900, squad_metric :  {'exact_match': 0.4779823044848978, 'f1': 4.848227029812154}
train:: Epoch: 0005 cost = 0.076743, squad_metric :  {'exact_match': 0.91512915129151

** Comment 3.1 **

"squad_metric" has two metric, 'exact_match' and 'f1'.
squad_metric get the sentence with string type, not list of tokens.
Therefore, I joint the word? tokens with one space before put it in the given metric.
After that, they use simple space splitter with split(), punctuation remover and Lowercase English for numericalizer.
\
However, it is different from my numericalizer in Prob 2.2, the function 'process()'. I believe that this issue brings the harsh performance (especially, 'exact_match').
\
For example, I double checked the 'exact_match' accuracy using only the start and end position of FIRST answer. Then, accuracy value is much higher than 'exact_match' value. (argmax acc on Prob 3.2)

** Result 3.1.2 (Training & Validation) **


```
# Result ::

device :  cuda
batch_size :  128
max_length :  253
embedding_dim :  128
hidden_dim :  256
n_layers :  2
emb_dropout :  0.1
rnn_dropout_and_fc_dropout :  0.5
bidirectional : True
enable_layer_norm : True
learning_rate :  0.001
100%
30/30 [52:55<00:00, 105.84s/it]
train:: Epoch: 0001 cost = 0.079760, squad_metric :  {'exact_match': 0.05904059040590406, 'f1': 5.52339938822115}
valid:: Epoch: 0001 cost = 0.079491, squad_metric :  {'exact_match': 0.040679345062544496, 'f1': 4.141770941182735}
train:: Epoch: 0002 cost = 0.079107, squad_metric :  {'exact_match': 0.27921279212792127, 'f1': 4.516358336672799}
valid:: Epoch: 0002 cost = 0.079342, squad_metric :  {'exact_match': 0.3356045967659921, 'f1': 4.610341876288495}
train:: Epoch: 0003 cost = 0.078534, squad_metric :  {'exact_match': 0.5301353013530136, 'f1': 4.568132562807337}
valid:: Epoch: 0003 cost = 0.079561, squad_metric :  {'exact_match': 0.4881521407505339, 'f1': 4.664714353463255}
train:: Epoch: 0004 cost = 0.077731, squad_metric :  {'exact_match': 0.7380073800738007, 'f1': 4.899126557698399}
valid:: Epoch: 0004 cost = 0.079900, squad_metric :  {'exact_match': 0.4779823044848978, 'f1': 4.848227029812154}
train:: Epoch: 0005 cost = 0.076743, squad_metric :  {'exact_match': 0.915129151291513, 'f1': 5.173145030424936}
valid:: Epoch: 0005 cost = 0.080609, squad_metric :  {'exact_match': 0.3966236143598088, 'f1': 4.724701760655256}
train:: Epoch: 0006 cost = 0.075741, squad_metric :  {'exact_match': 0.998769987699877, 'f1': 5.401154805239039}
valid:: Epoch: 0006 cost = 0.081186, squad_metric :  {'exact_match': 0.3966236143598088, 'f1': 4.607046859821393}
train:: Epoch: 0007 cost = 0.074733, squad_metric :  {'exact_match': 1.137761377613776, 'f1': 5.502323035873868}
valid:: Epoch: 0007 cost = 0.081922, squad_metric :  {'exact_match': 0.4271331231567172, 'f1': 4.551865069269158}
train:: Epoch: 0008 cost = 0.073769, squad_metric :  {'exact_match': 1.204182041820418, 'f1': 5.661646000270595}
valid:: Epoch: 0008 cost = 0.082516, squad_metric :  {'exact_match': 0.32543476050035597, 'f1': 4.222652645407895}
train:: Epoch: 0009 cost = 0.072869, squad_metric :  {'exact_match': 1.3751537515375154, 'f1': 5.777759974047754}
valid:: Epoch: 0009 cost = 0.083157, squad_metric :  {'exact_match': 0.3356045967659921, 'f1': 4.191452313944129}
train:: Epoch: 0010 cost = 0.072054, squad_metric :  {'exact_match': 1.3714637146371464, 'f1': 5.790611846923638}
valid:: Epoch: 0010 cost = 0.083795, squad_metric :  {'exact_match': 0.3152649242347198, 'f1': 4.178425942910128}
train:: Epoch: 0011 cost = 0.071280, squad_metric :  {'exact_match': 1.3936039360393604, 'f1': 5.905970686867983}
valid:: Epoch: 0011 cost = 0.084291, squad_metric :  {'exact_match': 0.3864537780941727, 'f1': 4.214191459500743}
train:: Epoch: 0012 cost = 0.070560, squad_metric :  {'exact_match': 1.4772447724477245, 'f1': 5.87064883243569}
valid:: Epoch: 0012 cost = 0.085198, squad_metric :  {'exact_match': 0.28475541543781147, 'f1': 4.112992304254845}
train:: Epoch: 0013 cost = 0.069845, squad_metric :  {'exact_match': 1.5461254612546125, 'f1': 5.881162910122298}
valid:: Epoch: 0013 cost = 0.085816, squad_metric :  {'exact_match': 0.3661141055629004, 'f1': 3.868770208199774}
train:: Epoch: 0014 cost = 0.069205, squad_metric :  {'exact_match': 1.5362853628536286, 'f1': 6.011535672205428}
valid:: Epoch: 0014 cost = 0.086096, squad_metric :  {'exact_match': 0.2644157429065392, 'f1': 3.9362743996948257}
train:: Epoch: 0015 cost = 0.068615, squad_metric :  {'exact_match': 1.5817958179581795, 'f1': 5.894389559932659}
valid:: Epoch: 0015 cost = 0.086543, squad_metric :  {'exact_match': 0.19322688904708635, 'f1': 3.936425858306148}
train:: Epoch: 0016 cost = 0.068076, squad_metric :  {'exact_match': 1.6248462484624846, 'f1': 5.984265813361757}
valid:: Epoch: 0016 cost = 0.087069, squad_metric :  {'exact_match': 0.3050950879690837, 'f1': 3.9069096913585173}
train:: Epoch: 0017 cost = 0.067535, squad_metric :  {'exact_match': 1.5621156211562115, 'f1': 5.895822398815512}
valid:: Epoch: 0017 cost = 0.087296, squad_metric :  {'exact_match': 0.20339672531272246, 'f1': 4.0214630329923375}
train:: Epoch: 0018 cost = 0.066968, squad_metric :  {'exact_match': 1.6728167281672817, 'f1': 6.115089538610587}
valid:: Epoch: 0018 cost = 0.088093, squad_metric :  {'exact_match': 0.28475541543781147, 'f1': 3.9959138206550118}
train:: Epoch: 0019 cost = 0.066491, squad_metric :  {'exact_match': 1.7404674046740467, 'f1': 6.059234042789079}
valid:: Epoch: 0019 cost = 0.087950, squad_metric :  {'exact_match': 0.2135665615783586, 'f1': 3.844752147723429}
train:: Epoch: 0020 cost = 0.066065, squad_metric :  {'exact_match': 1.7699876998769988, 'f1': 6.149869687787748}
valid:: Epoch: 0020 cost = 0.088572, squad_metric :  {'exact_match': 0.20339672531272246, 'f1': 3.84057640971435}
train:: Epoch: 0021 cost = 0.065629, squad_metric :  {'exact_match': 1.7269372693726937, 'f1': 6.077576685990024}
valid:: Epoch: 0021 cost = 0.089076, squad_metric :  {'exact_match': 0.2135665615783586, 'f1': 3.668484128969692}
train:: Epoch: 0022 cost = 0.065204, squad_metric :  {'exact_match': 1.7798277982779829, 'f1': 6.191290550684122}
valid:: Epoch: 0022 cost = 0.089141, squad_metric :  {'exact_match': 0.20339672531272246, 'f1': 3.762071948778777}
train:: Epoch: 0023 cost = 0.064740, squad_metric :  {'exact_match': 1.8757687576875768, 'f1': 6.236611392119068}
valid:: Epoch: 0023 cost = 0.089259, squad_metric :  {'exact_match': 0.2542459066409031, 'f1': 3.7709149997595928}
train:: Epoch: 0024 cost = 0.064380, squad_metric :  {'exact_match': 1.8204182041820418, 'f1': 6.217691244710026}
valid:: Epoch: 0024 cost = 0.089861, squad_metric :  {'exact_match': 0.11186819892199736, 'f1': 3.6802549974475056}
train:: Epoch: 0025 cost = 0.063955, squad_metric :  {'exact_match': 1.91389913899139, 'f1': 6.317859799188245}
valid:: Epoch: 0025 cost = 0.090195, squad_metric :  {'exact_match': 0.14237770771890573, 'f1': 3.795095819731973}
train:: Epoch: 0026 cost = 0.063569, squad_metric :  {'exact_match': 1.895448954489545, 'f1': 6.309099116230587}
valid:: Epoch: 0026 cost = 0.090376, squad_metric :  {'exact_match': 0.12203803518763348, 'f1': 3.627915130292577}
train:: Epoch: 0027 cost = 0.063159, squad_metric :  {'exact_match': 1.954489544895449, 'f1': 6.303810374911365}
valid:: Epoch: 0027 cost = 0.090593, squad_metric :  {'exact_match': 0.1322078714532696, 'f1': 3.7959551917323835}
train:: Epoch: 0028 cost = 0.062920, squad_metric :  {'exact_match': 1.980319803198032, 'f1': 6.390372746952164}
valid:: Epoch: 0028 cost = 0.090556, squad_metric :  {'exact_match': 0.15254754398454184, 'f1': 3.754198485161899}
train:: Epoch: 0029 cost = 0.062495, squad_metric :  {'exact_match': 2.097170971709717, 'f1': 6.464851302425954}
valid:: Epoch: 0029 cost = 0.091080, squad_metric :  {'exact_match': 0.12203803518763348, 'f1': 3.628710220283829}
train:: Epoch: 0030 cost = 0.062238, squad_metric :  {'exact_match': 1.96309963099631, 'f1': 6.353958784307238}
valid:: Epoch: 0030 cost = 0.090873, squad_metric :  {'exact_match': 0.14237770771890573, 'f1': 3.8562488945032305}


```




**Problem 3.2** *(10 points)*  Now let's resolve the second issue, by simply concatenating the two inputs into one sequence. The simplest way would be to append the the question at the start *OR* the end of the context. If you put it at the start, you will need to shift the start and the end positions of the answer accordingly. If you put it at the end, it will be necesary to use bidirectional LSTM for the context to be aware of what is ahead (though it is recommended to use bidirectional LSTM even if the question is appended at the start). Whichever you choose, carry it out and report the accuracy. How does it differ from 3.1?

In [11]:
# Put the question sentence before the context sentence (for Prob 3.2+)
# Use <CLS> token to separate two sentences

import torchtext
from torchtext.legacy import data
from torchtext.legacy import datasets
from torchtext.legacy.data import BucketIterator


class SQuAD2Dataset(data.Dataset):
  """
  Defines a dataset for squad1.0.
  """
  
  @staticmethod
  def sort_key(ex):
    return data.interleave_keys(len(ex.context_question))

  def __init__(self, data_list, fields, use_bos=True, max_length=None, **kwargs):
    if not isinstance(fields[0], (tuple, list)):
      fields = [
                # ('context', fields[0]), 
                # ('question', fields[1]), 
                ('context_question', fields[0]), # For Problem 3.2+, put the question after the context
                ('answer_start', fields[1]), 
                ('answer_end', fields[2]), 
                ('id_index', fields[3])
                ]

    examples = []
    nonalpha_list = find_nonalpha_list(data_list)

    self.id_list = list()
    self.reference = list()

    for _, example in enumerate(data_list):
        out = preprocess(example, nonalpha_list)
        # use data if the answer exists
        if max_length and max_length < max(len(out['context']), len(out['question'])):
            continue
        if 'answers' in out.keys():
            answer_start = out['answers'][0]['start'] # Use index 0 for instant valid accuracy
            answer_end = out['answers'][0]['end'] # Use index 0 for instant valid accuracy
            if use_bos:
                answer_start += 1 # for <BOS> token
                answer_end += 1 # for <BOS> token
            examples.append(data.Example.fromlist([
                                                #    out['context'], 
                                                #    out['question'], 
                                                   out['context'] + ["<CLS>"] + out['question'], # Use <CLS> token
                                                   answer_start,
                                                   answer_end,
                                                   len(self.id_list)], 
                                                  fields))
            self.id_list.append(out['id'])
            self.reference.append({'id':example['id'], 'answers':example['answers']})

    super(SQuAD2Dataset, self).__init__(examples, fields, **kwargs)


class SQuAD2Dataloader():
  """
  Make the dataloader for SQuAD 1.0
  """
  def __init__(self, train_data=None, valid_data=None, batch_size=64, device='cpu', 
                max_length=255, min_freq=2, fix_length=None,
                use_bos=True, use_eos=True, shuffle=True
              ):

    super(SQuAD2Dataloader, self).__init__()

    self.text = data.Field(sequential=True, use_vocab=True, batch_first=True, 
                           include_lengths=True, fix_length=fix_length, 
                           init_token='<BOS>' if use_bos else None, 
                           eos_token='<EOS>' if use_eos else None
                          )
    self.answer_start = data.Field(sequential = False, use_vocab = False)
    self.answer_end = data.Field(sequential = False, use_vocab = False)
    self.id_index = data.Field(sequential = False, use_vocab = False)
    
    train = SQuAD2Dataset(data_list=train_data, 
                          fields = [
                                    # ('context', self.text),
                                    # ('question', self.text),
                                    ('context_question', self.text),
                                    ('answer_start', self.answer_start),
                                    ('answer_end', self.answer_end),
                                    ('id_index', self.id_index)
                                    ], 
                          use_bos = use_bos,
                          max_length = max_length
                          )
    valid = SQuAD2Dataset(data_list=valid_data, 
                          fields = [
                                    # ('context', self.text),
                                    # ('question', self.text),
                                    ('context_question', self.text),
                                    ('answer_start', self.answer_start),
                                    ('answer_end', self.answer_end),
                                    ('id_index', self.id_index)
                                    ], 
                          use_bos = use_bos,
                          max_length = max_length
                          )
    self.train_id_list = train.id_list
    self.valid_id_list = valid.id_list

    self.train_reference = train.reference
    self.valid_reference = valid.reference
    
    self.train_iter = data.BucketIterator(train, batch_size=batch_size,
                                          device=device,
                                          shuffle=shuffle,
                                          sort_key=lambda x: len(x.context_question), 
                                          sort_within_batch = True
                                          )
    self.valid_iter = data.BucketIterator(valid, batch_size=batch_size,
                                          device=device,
                                          shuffle=shuffle,
                                          sort_key=lambda x: len(x.context_question),
                                          sort_within_batch = True
                                          )
    
    self.text.build_vocab(train)


train_dataset = squad_dataset['train']
valid_dataset = squad_dataset['validation']

print('# of train data : {}'.format(len(train_dataset)))
print('# of vaild data : {}'.format(len(valid_dataset)))

batch_size = 128
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
max_length = 253 # 255 - 2 for <BOS> and <EOS>
min_freq = 2
use_bos = False
use_eos = False
print("device : ", device)

loader = SQuAD2Dataloader(train_dataset, valid_dataset, batch_size=batch_size, 
                          device=device, max_length=max_length, min_freq=min_freq,
                          use_bos=use_bos, use_eos=use_eos)
print('\nFinish making the dataloader')
print("batch_size : ", batch_size)
print("max_length : ", max_length)
print('# of used train data ~ {}'.format((len(loader.train_iter)) * batch_size))
print('# of used vaild data ~ {}'.format((len(loader.valid_iter)) * batch_size))

vocab = loader.text.vocab
vocab_list = list(vocab.stoi.keys())
print('# of vocab : {}'.format(len(vocab_list)))


# of train data : 87599
# of vaild data : 10570
device :  cuda

Finish making the dataloader
batch_size :  128
max_length :  253
# of used train data ~ 81408
# of used vaild data ~ 9856
# of vocab : 86390



** Result 3.2.1 (Preprocessing) **

```
# of train data : 87599
# of vaild data : 10570
device :  cuda

Finish making the dataloader
batch_size :  128
max_length :  253
# of used train data ~ 81408
# of used vaild data ~ 9856
# of vocab : 86390

```



**Answer 3.2**



In [12]:
import torch.nn as nn
from tqdm.notebook import tqdm
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("device : ", device)

# Vocabulary : Use vocab_list

# Construct the LSTM Model
embedding_dim = 128 # usually bigger, e.g. 128
hidden_dim = 256
n_layers = 2
n_label = max_length+1 if use_bos else max_length
emb_dropout = 0.5
rnn_dropout = 0.5
bidirectional = True
enable_layer_norm = True
rnnmodel = ClassificationLSTMModel(embedding_dim, hidden_dim, n_layers, n_label, emb_dropout, rnn_dropout, bidirectional, enable_layer_norm, device).to(device)

print("batch_size : ", batch_size)
print("max_length : ", max_length)
print("embedding_dim : ", embedding_dim)
print("hidden_dim : ", hidden_dim)
print("n_layers : ", n_layers)
print("emb_dropout : ", emb_dropout)
print("rnn_dropout_and_fc_dropout : ", rnn_dropout)
if bidirectional:
    print("bidirectional : True")
else:
    print("bidirectional : False")
if enable_layer_norm:
    print("enable_layer_norm : True")
else:
    print("enable_layer_norm : False")

# Construct the data loader
train_iter = loader.train_iter
valid_iter = loader.valid_iter

train_id_list = loader.train_id_list
valid_id_list = loader.valid_id_list
train_reference = loader.train_reference
valid_reference = loader.valid_reference

# Training
learning_rate = 5e-4
print("learning_rate : ", learning_rate)

PAD_IDX = vocab.stoi['<pad>']
cel = nn.CrossEntropyLoss(ignore_index=PAD_IDX) # Ignore Padding
# optimizer = torch.optim.SGD(rnnmodel.parameters(), lr=1e-1)
optimizer = torch.optim.Adam(rnnmodel.parameters(), lr=learning_rate)

epochs = 30
max_norm = 5

# Evaluate
squad_metric = load_metric('squad')


for epoch in tqdm(range(epochs)):
    train_loss = 0
    train_accuracy = 0.0
    train_data_num = 0
    train_prediction = list()
    for train_i, train_batch in enumerate(train_iter):
        context, context_length = train_batch.context_question
        answer_start = train_batch.answer_start
        answer_end = train_batch.answer_end
        train_id_index = train_batch.id_index

        logits_start, logits_end = rnnmodel(context, context_length)

        optimizer.zero_grad() # reset process
        loss = cel(logits_start, answer_start) + cel(logits_end, answer_end) # Loss, a.k.a L
        loss.backward() # compute gradients
        # print(torch.norm(rnnmodel.lstm.weight_hh_l0.grad), loss.item())
        # torch.nn.utils.clip_grad_norm_(rnnmodel.parameters(), max_norm) # gradent clipping
        optimizer.step() # update parameters
        train_loss += loss.item()
        
        _, train_start_preds = torch.max(logits_start, 1)
        _, train_end_preds = torch.max(logits_end, 1)
        train_accuracy += ((train_start_preds == answer_start) * (train_end_preds == answer_end)).sum().float()

        train_data_num += context.shape[0]

        for train_j in range(context.shape[0]):
            pred_text = ""
            start = train_start_preds[train_j]
            end = train_end_preds[train_j]
            if start < end:
                pred_text = [vocab_list[text_id] for text_id in context[train_j][start:end+1]]
                pred_text = " ".join(pred_text)

            # start = answer_start[train_j]
            # end = answer_end[train_j]
            # answer_text = [vocab_list[text_id] for text_id in context[train_j][start:end+1]]
            # answer_text = " ".join(answer_text)
            # print(pred_text, answer_text, train_reference[train_id_index[train_j]], "\n")

            train_prediction.append({'id':train_id_list[train_id_index[train_j]], 'prediction_text':pred_text})

    train_result = squad_metric.compute(predictions=train_prediction, references=train_reference)
    print('train:: Epoch:', '%04d' % (epoch + 1), 
          'cost =', '{:.6f},'.format(train_loss / train_data_num), 
          'my exact_match =', '{:.6f}'.format(train_accuracy / train_data_num),
          'other_squad_metric : ', train_result)
        
    if (epoch + 1) % 1 == 0:
        with torch.no_grad():
            valid_loss = 0
            valid_accuracy = 0.0
            valid_data_num = 0
            valid_prediction = list()
            for valid_i, valid_batch in enumerate(valid_iter):
                context, context_length = valid_batch.context_question
                answer_start = valid_batch.answer_start
                answer_end = valid_batch.answer_end
                valid_id_index = valid_batch.id_index

                logits_start, logits_end = rnnmodel(context, context_length)

                loss = cel(logits_start, answer_start) + cel(logits_end, answer_end) # Loss, a.k.a L
                valid_loss += loss.item()

                _, valid_start_preds = torch.max(logits_start, 1)
                _, valid_end_preds = torch.max(logits_end, 1)
                valid_accuracy += ((valid_start_preds == answer_start) * (valid_end_preds == answer_end)).sum().float()

                valid_data_num += context.shape[0]

                for valid_j in range(context.shape[0]):
                    pred_text = ""
                    start = valid_start_preds[valid_j]
                    end = valid_end_preds[valid_j]
                    if start < end:
                        pred_text = [vocab_list[text_id] for text_id in context[valid_j][start:end+1]]
                        pred_text = " ".join(pred_text)
                    valid_prediction.append({'id':valid_id_list[valid_id_index[valid_j]], 'prediction_text':pred_text})
                
            valid_result = squad_metric.compute(predictions=valid_prediction, references=valid_reference)
            print('valid:: Epoch:', '%04d' % (epoch + 1), 
                  'cost =', '{:.6f},'.format(valid_loss / valid_data_num), 
                  'my exact_match =', '{:.6f},'.format(valid_accuracy / valid_data_num),
                  'other_squad_metric : ', valid_result)
            


device :  cuda
batch_size :  128
max_length :  253
embedding_dim :  128
hidden_dim :  256
n_layers :  2
emb_dropout :  0.5
rnn_dropout_and_fc_dropout :  0.5
bidirectional : True
enable_layer_norm : True
learning_rate :  0.0005


HBox(children=(FloatProgress(value=0.0, max=50.0), HTML(value='')))

train:: Epoch: 0001 cost = 0.079870, argmax acc = 0.000480 other_squad_metric :  {'exact_match': 0.05289052890528905, 'f1': 5.989309692302495}
valid:: Epoch: 0001 cost = 0.079658, argmax acc = 0.001119, other_squad_metric :  {'exact_match': 0.010169836265636124, 'f1': 6.076181438808642}
train:: Epoch: 0002 cost = 0.079319, argmax acc = 0.002030 other_squad_metric :  {'exact_match': 0.05166051660516605, 'f1': 5.409903198798298}
valid:: Epoch: 0002 cost = 0.079311, argmax acc = 0.003458, other_squad_metric :  {'exact_match': 0.06101901759381674, 'f1': 5.7264634760190045}
train:: Epoch: 0003 cost = 0.079070, argmax acc = 0.004600 other_squad_metric :  {'exact_match': 0.12915129151291513, 'f1': 4.623574903221184}
valid:: Epoch: 0003 cost = 0.079317, argmax acc = 0.004271, other_squad_metric :  {'exact_match': 0.1830570527814502, 'f1': 4.96396233745034}
train:: Epoch: 0004 cost = 0.078806, argmax acc = 0.007085 other_squad_metric :  {'exact_match': 0.3124231242312423, 'f1': 4.48285814919620

** Comment 3.2 **

Still bad performance.

** Result 3.2.2 (Training & Validation) **



```
device :  cuda
batch_size :  128
max_length :  253
embedding_dim :  128
hidden_dim :  256
n_layers :  2
emb_dropout :  0.5
rnn_dropout_and_fc_dropout :  0.5
bidirectional : True
enable_layer_norm : True
learning_rate :  0.0005
100%
50/50 [1:29:20<00:00, 107.20s/it]
train:: Epoch: 0001 cost = 0.079870, argmax acc = 0.000480 other_squad_metric :  {'exact_match': 0.05289052890528905, 'f1': 5.989309692302495}
valid:: Epoch: 0001 cost = 0.079658, argmax acc = 0.001119, other_squad_metric :  {'exact_match': 0.010169836265636124, 'f1': 6.076181438808642}
train:: Epoch: 0002 cost = 0.079319, argmax acc = 0.002030 other_squad_metric :  {'exact_match': 0.05166051660516605, 'f1': 5.409903198798298}
valid:: Epoch: 0002 cost = 0.079311, argmax acc = 0.003458, other_squad_metric :  {'exact_match': 0.06101901759381674, 'f1': 5.7264634760190045}
train:: Epoch: 0003 cost = 0.079070, argmax acc = 0.004600 other_squad_metric :  {'exact_match': 0.12915129151291513, 'f1': 4.623574903221184}
valid:: Epoch: 0003 cost = 0.079317, argmax acc = 0.004271, other_squad_metric :  {'exact_match': 0.1830570527814502, 'f1': 4.96396233745034}
train:: Epoch: 0004 cost = 0.078806, argmax acc = 0.007085 other_squad_metric :  {'exact_match': 0.3124231242312423, 'f1': 4.482858149196208}
valid:: Epoch: 0004 cost = 0.079252, argmax acc = 0.007221, other_squad_metric :  {'exact_match': 0.3356045967659921, 'f1': 4.793506052164472}
train:: Epoch: 0005 cost = 0.078508, argmax acc = 0.008708 other_squad_metric :  {'exact_match': 0.42066420664206644, 'f1': 4.502234244770526}
valid:: Epoch: 0005 cost = 0.079273, argmax acc = 0.007627, other_squad_metric :  {'exact_match': 0.3966236143598088, 'f1': 5.125529379910425}
train:: Epoch: 0006 cost = 0.078141, argmax acc = 0.010185 other_squad_metric :  {'exact_match': 0.5768757687576875, 'f1': 4.7975583007366245}
valid:: Epoch: 0006 cost = 0.079312, argmax acc = 0.009051, other_squad_metric :  {'exact_match': 0.46781246821926165, 'f1': 4.773259851830603}
train:: Epoch: 0007 cost = 0.077821, argmax acc = 0.011021 other_squad_metric :  {'exact_match': 0.5953259532595326, 'f1': 4.83455868092752}
valid:: Epoch: 0007 cost = 0.079348, argmax acc = 0.009356, other_squad_metric :  {'exact_match': 0.41696328689108103, 'f1': 5.091698404526903}
train:: Epoch: 0008 cost = 0.077476, argmax acc = 0.012362 other_squad_metric :  {'exact_match': 0.6629766297662977, 'f1': 4.984677232806258}
valid:: Epoch: 0008 cost = 0.079552, argmax acc = 0.008034, other_squad_metric :  {'exact_match': 0.5390013220787145, 'f1': 4.841777080462419}
train:: Epoch: 0009 cost = 0.077151, argmax acc = 0.013542 other_squad_metric :  {'exact_match': 0.7429274292742928, 'f1': 5.147279034148527}
valid:: Epoch: 0009 cost = 0.079495, argmax acc = 0.008949, other_squad_metric :  {'exact_match': 0.4474727956879894, 'f1': 4.919948065225899}
train:: Epoch: 0010 cost = 0.076794, argmax acc = 0.013887 other_squad_metric :  {'exact_match': 0.7761377613776138, 'f1': 5.2560421593567135}
valid:: Epoch: 0010 cost = 0.079807, argmax acc = 0.009051, other_squad_metric :  {'exact_match': 0.5186616495474423, 'f1': 5.023962187600448}
train:: Epoch: 0011 cost = 0.076487, argmax acc = 0.014772 other_squad_metric :  {'exact_match': 0.8597785977859779, 'f1': 5.355466359956851}
valid:: Epoch: 0011 cost = 0.080142, argmax acc = 0.008238, other_squad_metric :  {'exact_match': 0.49832197701617004, 'f1': 4.92387781952113}
train:: Epoch: 0012 cost = 0.076147, argmax acc = 0.015166 other_squad_metric :  {'exact_match': 0.9298892988929889, 'f1': 5.58156635329186}
valid:: Epoch: 0012 cost = 0.080227, argmax acc = 0.010475, other_squad_metric :  {'exact_match': 0.5898505034068952, 'f1': 4.9997467309597}
train:: Epoch: 0013 cost = 0.075850, argmax acc = 0.016396 other_squad_metric :  {'exact_match': 0.980319803198032, 'f1': 5.709361305923348}
valid:: Epoch: 0013 cost = 0.080358, argmax acc = 0.008339, other_squad_metric :  {'exact_match': 0.5084918132818061, 'f1': 4.906783610764015}
train:: Epoch: 0014 cost = 0.075527, argmax acc = 0.016002 other_squad_metric :  {'exact_match': 1.030750307503075, 'f1': 5.7895020394261865}
valid:: Epoch: 0014 cost = 0.080466, argmax acc = 0.009356, other_squad_metric :  {'exact_match': 0.5186616495474423, 'f1': 4.8639803711206095}
train:: Epoch: 0015 cost = 0.075258, argmax acc = 0.017122 other_squad_metric :  {'exact_match': 1.039360393603936, 'f1': 5.929675999299098}
valid:: Epoch: 0015 cost = 0.080841, argmax acc = 0.009763, other_squad_metric :  {'exact_match': 0.4474727956879894, 'f1': 4.720280108404927}
train:: Epoch: 0016 cost = 0.074940, argmax acc = 0.017294 other_squad_metric :  {'exact_match': 1.0922509225092252, 'f1': 5.9386564306645875}
valid:: Epoch: 0016 cost = 0.080967, argmax acc = 0.009661, other_squad_metric :  {'exact_match': 0.5695108308756229, 'f1': 4.737350504812802}
train:: Epoch: 0017 cost = 0.074684, argmax acc = 0.018192 other_squad_metric :  {'exact_match': 1.118081180811808, 'f1': 6.067873580600416}
valid:: Epoch: 0017 cost = 0.080978, argmax acc = 0.009356, other_squad_metric :  {'exact_match': 0.4779823044848978, 'f1': 4.737894245953374}
train:: Epoch: 0018 cost = 0.074393, argmax acc = 0.018893 other_squad_metric :  {'exact_match': 1.153751537515375, 'f1': 6.085247510336405}
valid:: Epoch: 0018 cost = 0.081367, argmax acc = 0.008034, other_squad_metric :  {'exact_match': 0.4067934506254449, 'f1': 4.725415697406601}
train:: Epoch: 0019 cost = 0.074138, argmax acc = 0.019200 other_squad_metric :  {'exact_match': 1.2263222632226323, 'f1': 6.192327102928718}
valid:: Epoch: 0019 cost = 0.081337, argmax acc = 0.009356, other_squad_metric :  {'exact_match': 0.4779823044848978, 'f1': 4.761990326948866}
train:: Epoch: 0020 cost = 0.073868, argmax acc = 0.019828 other_squad_metric :  {'exact_match': 1.2484624846248462, 'f1': 6.259461564994696}
valid:: Epoch: 0020 cost = 0.081493, argmax acc = 0.009966, other_squad_metric :  {'exact_match': 0.5390013220787145, 'f1': 4.915876268127727}
train:: Epoch: 0021 cost = 0.073591, argmax acc = 0.020234 other_squad_metric :  {'exact_match': 1.2226322263222633, 'f1': 6.320693636443308}
valid:: Epoch: 0021 cost = 0.081802, argmax acc = 0.008238, other_squad_metric :  {'exact_match': 0.5593409946099868, 'f1': 4.731472551932656}
train:: Epoch: 0022 cost = 0.073366, argmax acc = 0.020677 other_squad_metric :  {'exact_match': 1.3665436654366543, 'f1': 6.410087668775246}
valid:: Epoch: 0022 cost = 0.082137, argmax acc = 0.008441, other_squad_metric :  {'exact_match': 0.4881521407505339, 'f1': 4.748025437134684}
train:: Epoch: 0023 cost = 0.073149, argmax acc = 0.020480 other_squad_metric :  {'exact_match': 1.2878228782287824, 'f1': 6.436317355530415}
valid:: Epoch: 0023 cost = 0.082268, argmax acc = 0.008238, other_squad_metric :  {'exact_match': 0.5084918132818061, 'f1': 4.577179652162081}
train:: Epoch: 0024 cost = 0.072879, argmax acc = 0.021611 other_squad_metric :  {'exact_match': 1.3284132841328413, 'f1': 6.534731076891177}
valid:: Epoch: 0024 cost = 0.082333, argmax acc = 0.008746, other_squad_metric :  {'exact_match': 0.49832197701617004, 'f1': 4.854091224465016}
train:: Epoch: 0025 cost = 0.072581, argmax acc = 0.021771 other_squad_metric :  {'exact_match': 1.3468634686346863, 'f1': 6.566252875913047}
valid:: Epoch: 0025 cost = 0.082600, argmax acc = 0.008136, other_squad_metric :  {'exact_match': 0.6203600122038035, 'f1': 4.427058829862266}
train:: Epoch: 0026 cost = 0.072373, argmax acc = 0.022226 other_squad_metric :  {'exact_match': 1.3825338253382533, 'f1': 6.691890118043898}
valid:: Epoch: 0026 cost = 0.082687, argmax acc = 0.009661, other_squad_metric :  {'exact_match': 0.5084918132818061, 'f1': 4.726996743387878}
train:: Epoch: 0027 cost = 0.072126, argmax acc = 0.022128 other_squad_metric :  {'exact_match': 1.3800738007380073, 'f1': 6.704347870959806}
valid:: Epoch: 0027 cost = 0.083415, argmax acc = 0.007119, other_squad_metric :  {'exact_match': 0.3559442692972643, 'f1': 4.3227320610159}
train:: Epoch: 0028 cost = 0.071940, argmax acc = 0.022964 other_squad_metric :  {'exact_match': 1.4858548585485856, 'f1': 6.8166554814508356}
valid:: Epoch: 0028 cost = 0.082823, argmax acc = 0.009458, other_squad_metric :  {'exact_match': 0.5695108308756229, 'f1': 4.664080646130991}
train:: Epoch: 0029 cost = 0.071652, argmax acc = 0.023149 other_squad_metric :  {'exact_match': 1.4710947109471095, 'f1': 6.8521612048196765}
valid:: Epoch: 0029 cost = 0.083159, argmax acc = 0.007526, other_squad_metric :  {'exact_match': 0.4373029594223533, 'f1': 4.5875795966660275}
train:: Epoch: 0030 cost = 0.071447, argmax acc = 0.024428 other_squad_metric :  {'exact_match': 1.4920049200492005, 'f1': 6.911899083776001}
valid:: Epoch: 0030 cost = 0.083746, argmax acc = 0.006814, other_squad_metric :  {'exact_match': 0.4271331231567172, 'f1': 4.501357764974634}
train:: Epoch: 0031 cost = 0.071172, argmax acc = 0.024945 other_squad_metric :  {'exact_match': 1.5891758917589176, 'f1': 7.14310523963246}
valid:: Epoch: 0031 cost = 0.083821, argmax acc = 0.007424, other_squad_metric :  {'exact_match': 0.4474727956879894, 'f1': 4.426732009177166}
train:: Epoch: 0032 cost = 0.070933, argmax acc = 0.025178 other_squad_metric :  {'exact_match': 1.5707257072570726, 'f1': 7.100587584299848}
valid:: Epoch: 0032 cost = 0.084167, argmax acc = 0.007322, other_squad_metric :  {'exact_match': 0.5288314858130784, 'f1': 4.336364783978499}
train:: Epoch: 0033 cost = 0.070716, argmax acc = 0.025646 other_squad_metric :  {'exact_match': 1.6408364083640836, 'f1': 7.230598951569533}
valid:: Epoch: 0033 cost = 0.084090, argmax acc = 0.007322, other_squad_metric :  {'exact_match': 0.5084918132818061, 'f1': 4.543467719005027}
train:: Epoch: 0034 cost = 0.070453, argmax acc = 0.025806 other_squad_metric :  {'exact_match': 1.6186961869618697, 'f1': 7.191681394622516}
valid:: Epoch: 0034 cost = 0.084361, argmax acc = 0.007526, other_squad_metric :  {'exact_match': 0.46781246821926165, 'f1': 4.736551729213181}
train:: Epoch: 0035 cost = 0.070197, argmax acc = 0.026556 other_squad_metric :  {'exact_match': 1.6137761377613775, 'f1': 7.2650431586045725}
valid:: Epoch: 0035 cost = 0.084460, argmax acc = 0.008034, other_squad_metric :  {'exact_match': 0.5084918132818061, 'f1': 4.496105225498365}
train:: Epoch: 0036 cost = 0.070010, argmax acc = 0.027011 other_squad_metric :  {'exact_match': 1.6863468634686347, 'f1': 7.352410576934097}
valid:: Epoch: 0036 cost = 0.084535, argmax acc = 0.007831, other_squad_metric :  {'exact_match': 0.4779823044848978, 'f1': 4.4192404175252715}
train:: Epoch: 0037 cost = 0.069749, argmax acc = 0.027835 other_squad_metric :  {'exact_match': 1.7392373923739237, 'f1': 7.513770905912509}
valid:: Epoch: 0037 cost = 0.084879, argmax acc = 0.006814, other_squad_metric :  {'exact_match': 0.3864537780941727, 'f1': 4.509360088792529}
train:: Epoch: 0038 cost = 0.069538, argmax acc = 0.028352 other_squad_metric :  {'exact_match': 1.795817958179582, 'f1': 7.567768556050303}
valid:: Epoch: 0038 cost = 0.085208, argmax acc = 0.007119, other_squad_metric :  {'exact_match': 0.3152649242347198, 'f1': 4.320164108583346}
train:: Epoch: 0039 cost = 0.069376, argmax acc = 0.028893 other_squad_metric :  {'exact_match': 1.7761377613776137, 'f1': 7.477864285397355}
valid:: Epoch: 0039 cost = 0.085120, argmax acc = 0.007119, other_squad_metric :  {'exact_match': 0.3762839418285366, 'f1': 4.465381510521311}
train:: Epoch: 0040 cost = 0.069104, argmax acc = 0.028647 other_squad_metric :  {'exact_match': 1.7416974169741697, 'f1': 7.64480520077062}
valid:: Epoch: 0040 cost = 0.085419, argmax acc = 0.006509, other_squad_metric :  {'exact_match': 0.3661141055629004, 'f1': 4.3012572132869}
train:: Epoch: 0041 cost = 0.068922, argmax acc = 0.029397 other_squad_metric :  {'exact_match': 1.897908979089791, 'f1': 7.657625834154736}
valid:: Epoch: 0041 cost = 0.085548, argmax acc = 0.007017, other_squad_metric :  {'exact_match': 0.4373029594223533, 'f1': 4.419533677236369}
train:: Epoch: 0042 cost = 0.068677, argmax acc = 0.030074 other_squad_metric :  {'exact_match': 1.8560885608856088, 'f1': 7.782342928965346}
valid:: Epoch: 0042 cost = 0.085584, argmax acc = 0.007119, other_squad_metric :  {'exact_match': 0.4373029594223533, 'f1': 4.279498076008183}
train:: Epoch: 0043 cost = 0.068484, argmax acc = 0.030578 other_squad_metric :  {'exact_match': 1.8929889298892988, 'f1': 7.918387526381218}
valid:: Epoch: 0043 cost = 0.086364, argmax acc = 0.005492, other_squad_metric :  {'exact_match': 0.4373029594223533, 'f1': 4.302321630297892}
train:: Epoch: 0044 cost = 0.068248, argmax acc = 0.030664 other_squad_metric :  {'exact_match': 1.952029520295203, 'f1': 7.9328011644837195}
valid:: Epoch: 0044 cost = 0.085853, argmax acc = 0.007322, other_squad_metric :  {'exact_match': 0.4779823044848978, 'f1': 4.6060141788937745}
train:: Epoch: 0045 cost = 0.068045, argmax acc = 0.030812 other_squad_metric :  {'exact_match': 1.944649446494465, 'f1': 8.078085092699313}
valid:: Epoch: 0045 cost = 0.086303, argmax acc = 0.006509, other_squad_metric :  {'exact_match': 0.5084918132818061, 'f1': 4.511649043562105}
train:: Epoch: 0046 cost = 0.067739, argmax acc = 0.032706 other_squad_metric :  {'exact_match': 2.025830258302583, 'f1': 8.129184878529529}
valid:: Epoch: 0046 cost = 0.086710, argmax acc = 0.004882, other_squad_metric :  {'exact_match': 0.4373029594223533, 'f1': 4.139784476826585}
train:: Epoch: 0047 cost = 0.067588, argmax acc = 0.031587 other_squad_metric :  {'exact_match': 1.985239852398524, 'f1': 8.122182574030209}
valid:: Epoch: 0047 cost = 0.086525, argmax acc = 0.005390, other_squad_metric :  {'exact_match': 0.41696328689108103, 'f1': 4.51468017892951}
train:: Epoch: 0048 cost = 0.067337, argmax acc = 0.032386 other_squad_metric :  {'exact_match': 2.054120541205412, 'f1': 8.30270446700247}
valid:: Epoch: 0048 cost = 0.087089, argmax acc = 0.005492, other_squad_metric :  {'exact_match': 0.3356045967659921, 'f1': 4.134395096402048}
train:: Epoch: 0049 cost = 0.067101, argmax acc = 0.033432 other_squad_metric :  {'exact_match': 2.1316113161131613, 'f1': 8.364209134811277}
valid:: Epoch: 0049 cost = 0.087063, argmax acc = 0.005593, other_squad_metric :  {'exact_match': 0.41696328689108103, 'f1': 4.447391790002926}
train:: Epoch: 0050 cost = 0.066838, argmax acc = 0.033137 other_squad_metric :  {'exact_match': 2.2189421894218944, 'f1': 8.515114560353371}
valid:: Epoch: 0050 cost = 0.087535, argmax acc = 0.005288, other_squad_metric :  {'exact_match': 0.32543476050035597, 'f1': 4.027121477755915}

```



## 4. LSTM + Attention for SQuAD

**Problem 4.1** *(20 points)* Here, we will be appending an attention layer on top of LSTM outputs. We will use a single-head attention sublayer from Transformer. That is, you will implement 
$$ \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^\top}{\sqrt{d}}\right)V,$$
where $Q, K, V$ is obtained by the linear transformation of the hidden states of the LSTM outputs $H$, i.e. $Q = HW^Q, K=HW^K, V=HW^V$ ($W^Q, W^K, W^V \in \mathbb{R}^{d \times d}$ are trainable weights). Note that the output of $\text{Attention}$ layer has the same dimension as $H$, so you can directly append your token classification layer on top of it. Report the accuracy and compare it with 3.2.



**Answer 4.1**



In [1]:
import torch.nn as nn
from tqdm.notebook import tqdm
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence


class AttentionLSTMModel(nn.Module):
    def __init__(self, embedding_dim, hidden_dim, n_layers, n_label, emb_dropout, rnn_dropout, bidirectional, enable_layer_norm, device):
        super(ClassificationLSTMModel, self).__init__()
        self.embedding = nn.Embedding(len(vocab), embedding_dim)
        self.lstm = nn.LSTM(input_size=embedding_dim, 
                            hidden_size=hidden_dim, 
                            num_layers=n_layers, 
                            dropout=rnn_dropout, 
                            bidirectional=bidirectional)
      
        n_direction = 2 if bidirectional else 1
        self.fc_start = nn.Linear(hidden_dim*n_direction, n_label, bias=True)
        self.fc_end = nn.Linear(hidden_dim*n_direction, n_label, bias=True)

        # Layer_normalization
        self.enable_layer_norm = enable_layer_norm
        if enable_layer_norm:
            self.norm1 = nn.LayerNorm(embedding_dim)
            self.norm2 = nn.LayerNorm(hidden_dim*n_direction)

        self.emb_dropout = nn.Dropout(emb_dropout)
        self.fc_dropout = nn.Dropout(rnn_dropout)
        self.bidirectional = bidirectional
        self.device = device

    def forward(self, input_tensor, src_seq_lens):
        emb = self.embedding(input_tensor) # emb.shape = batch * len * hidden

        # Layer_normalization
        if self.enable_layer_norm:
            emb = self.norm1(emb)

        emb = self.emb_dropout(emb)
        emb = emb.transpose(0, 1) # emb.shape = len * batch * hidden

        # n_direction = 2 if bidirectional else 1
        # hidden = torch.zeros(n_layers*n_direction, context.shape[0], hidden_dim, requires_grad=True).to(self.device)
        # cell = torch.zeros(n_layers*n_direction, context.shape[0], hidden_dim, requires_grad=True).to(self.device)

        # nn.LSTM
        packed = pack_padded_sequence(emb, src_seq_lens.tolist(), batch_first=False)
        outs, (hidden, cell) = self.lstm(packed)
        outs, out_lens = pad_packed_sequence(outs, batch_first=False)

        if self.bidirectional:
            hidden = torch.stack([hidden[-2], hidden[-1]], dim=0)
        else:
            hidden = hidden[-1].unsqueeze(dim=0)
        print(hidden.shape)
        hidden = hidden.transpose(0, 1)
        hidden = hidden.contiguous().view(hidden.shape[0], -1)
        print(hidden.shape)

        # # Layer_normalization
        # if self.enable_layer_norm:
        #     hidden = self.norm2(hidden)

        hidden = self.fc_dropout(hidden)
        logits_start = self.fc_start(hidden)
        logits_end = self.fc_end(hidden)
        return (logits_start, logits_end)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("device : ", device)

# Vocabulary : Use vocab_list

# Construct the LSTM Model
embedding_dim = 128 # usually bigger, e.g. 128
hidden_dim = 256
n_layers = 2
n_label = max_length+1 if use_bos else max_length
emb_dropout = 0.5
rnn_dropout = 0.5
bidirectional = True
enable_layer_norm = True
rnnmodel = AttentionLSTMModel(embedding_dim, hidden_dim, n_layers, n_label, emb_dropout, rnn_dropout, bidirectional, enable_layer_norm, device).to(device)

print("batch_size : ", batch_size)
print("max_length : ", max_length)
print("embedding_dim : ", embedding_dim)
print("hidden_dim : ", hidden_dim)
print("n_layers : ", n_layers)
print("emb_dropout : ", emb_dropout)
print("rnn_dropout_and_fc_dropout : ", rnn_dropout)
if bidirectional:
    print("bidirectional : True")
else:
    print("bidirectional : False")
if enable_layer_norm:
    print("enable_layer_norm : True")
else:
    print("enable_layer_norm : False")

# Construct the data loader
train_iter = loader.train_iter
valid_iter = loader.valid_iter

train_id_list = loader.train_id_list
valid_id_list = loader.valid_id_list
train_reference = loader.train_reference
valid_reference = loader.valid_reference

# Training
learning_rate = 5e-4
print("learning_rate : ", learning_rate)

PAD_IDX = vocab.stoi['<pad>']
cel = nn.CrossEntropyLoss(ignore_index=PAD_IDX) # Ignore Padding
# optimizer = torch.optim.SGD(rnnmodel.parameters(), lr=1e-1)
optimizer = torch.optim.Adam(rnnmodel.parameters(), lr=learning_rate)

epochs = 30
max_norm = 5

# Evaluate
squad_metric = load_metric('squad')


for epoch in tqdm(range(epochs)):
    train_loss = 0
    train_accuracy = 0.0
    train_data_num = 0
    train_prediction = list()
    for train_i, train_batch in enumerate(train_iter):
        context, context_length = train_batch.context_question
        answer_start = train_batch.answer_start
        answer_end = train_batch.answer_end
        train_id_index = train_batch.id_index

        logits_start, logits_end = rnnmodel(context, context_length)

        optimizer.zero_grad() # reset process
        loss = cel(logits_start, answer_start) + cel(logits_end, answer_end) # Loss, a.k.a L
        loss.backward() # compute gradients
        # print(torch.norm(rnnmodel.lstm.weight_hh_l0.grad), loss.item())
        # torch.nn.utils.clip_grad_norm_(rnnmodel.parameters(), max_norm) # gradent clipping
        optimizer.step() # update parameters
        train_loss += loss.item()
        
        _, train_start_preds = torch.max(logits_start, 1)
        _, train_end_preds = torch.max(logits_end, 1)
        train_accuracy += ((train_start_preds == answer_start) * (train_end_preds == answer_end)).sum().float()

        train_data_num += context.shape[0]

        for train_j in range(context.shape[0]):
            pred_text = ""
            start = train_start_preds[train_j]
            end = train_end_preds[train_j]
            if start < end:
                pred_text = [vocab_list[text_id] for text_id in context[train_j][start:end+1]]
                pred_text = " ".join(pred_text)

            # start = answer_start[train_j]
            # end = answer_end[train_j]
            # answer_text = [vocab_list[text_id] for text_id in context[train_j][start:end+1]]
            # answer_text = " ".join(answer_text)
            # print(pred_text, answer_text, train_reference[train_id_index[train_j]], "\n")

            train_prediction.append({'id':train_id_list[train_id_index[train_j]], 'prediction_text':pred_text})

    train_result = squad_metric.compute(predictions=train_prediction, references=train_reference)
    print('train:: Epoch:', '%04d' % (epoch + 1), 
          'cost =', '{:.6f},'.format(train_loss / train_data_num), 
          'argmax acc =', '{:.6f}'.format(train_accuracy / train_data_num),
          'other_squad_metric : ', train_result)
        
    if (epoch + 1) % 1 == 0:
        with torch.no_grad():
            valid_loss = 0
            valid_accuracy = 0.0
            valid_data_num = 0
            valid_prediction = list()
            for valid_i, valid_batch in enumerate(valid_iter):
                context, context_length = valid_batch.context_question
                answer_start = valid_batch.answer_start
                answer_end = valid_batch.answer_end
                valid_id_index = valid_batch.id_index

                logits_start, logits_end = rnnmodel(context, context_length)

                loss = cel(logits_start, answer_start) + cel(logits_end, answer_end) # Loss, a.k.a L
                valid_loss += loss.item()

                _, valid_start_preds = torch.max(logits_start, 1)
                _, valid_end_preds = torch.max(logits_end, 1)
                valid_accuracy += ((valid_start_preds == answer_start) * (valid_end_preds == answer_end)).sum().float()

                valid_data_num += context.shape[0]

                for valid_j in range(context.shape[0]):
                    pred_text = ""
                    start = valid_start_preds[valid_j]
                    end = valid_end_preds[valid_j]
                    if start < end:
                        pred_text = [vocab_list[text_id] for text_id in context[valid_j][start:end+1]]
                        pred_text = " ".join(pred_text)
                    valid_prediction.append({'id':valid_id_list[valid_id_index[valid_j]], 'prediction_text':pred_text})
                
            valid_result = squad_metric.compute(predictions=valid_prediction, references=valid_reference)
            print('valid:: Epoch:', '%04d' % (epoch + 1), 
                  'cost =', '{:.6f},'.format(valid_loss / valid_data_num), 
                  'argmax acc =', '{:.6f},'.format(valid_accuracy / valid_data_num),
                  'other_squad_metric : ', valid_result)
            

NameError: ignored

**Problem 4.2** *(10 points)* On top of the attention layer, let's add another layer of (bi-directional) LSTM. So this will look like a *sandwich* where the LSTM is bread and the attention is ham. How does it affect the accuracy? Explain why do you think this happens. 

**Answer 4.2**



## 5. Attention is All You Need

**Problem 5.1 (bonus)** *(20 points)*  Implement full Transformer encoder to entirely replace LSTMs. You are allowed to copy and paste code from [*Annotated Transformer*](https://nlp.seas.harvard.edu/2018/04/03/attention.html) (but nowhere else). Report the accuracy and explain what seems to happening with attetion-only model compared to LSTM+Attention model(s). 

**Answer 5.1**


**Problem 5.2 (bonus)** *(10 points)* Replace Transformer's sinusoidal position encoding with a fixed-length (of 256) position embedding. That is, you will create a 256-by-$d$ trainable parameter matrix for the position encoding that replaces the variable-length sinusoidal encoding. What is the clear disdvantage of this approach? Report the accuracy and compare it with 5.1. Note that this also has a clear advantage, as we will see in our future lecture on Pretrained Language Model, and more specifically, BERT (Devlin et al., 2018).

**Answer 5.2**

