<a href="https://colab.research.google.com/github/graviraja/100-Days-of-NLP/blob/applications%2Fquestion-answering/applications/question-answering/Question%20Answering%20using%20Double-Cross-Attention.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

### Initial Setup

In [1]:
from google.colab import drive
drive.mount('/content/drive')

Go to this URL in a browser: https://accounts.google.com/o/oauth2/auth?client_id=947318989803-6bn6qk8qdgf4n4g3pfee6491hc0brc4i.apps.googleusercontent.com&redirect_uri=urn%3aietf%3awg%3aoauth%3a2.0%3aoob&scope=email%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdocs.test%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive.photos.readonly%20https%3a%2f%2fwww.googleapis.com%2fauth%2fpeopleapi.readonly&response_type=code

Enter your authorization code:
··········
Mounted at /content/drive


In [2]:
train_file = "/content/drive/My Drive/train-v1.1.json"
dev_file = "/content/drive/My Drive/dev-v1.1.json"

### Imports

In [3]:
import time
import json
import nltk
from tqdm import tqdm
from collections import Counter
from sklearn import metrics

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.optim.lr_scheduler import ReduceLROnPlateau

from torch.utils.data import Dataset, DataLoader

In [4]:
nltk.download('punkt')

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.


True

In [5]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

### Parsing Data

In [6]:
def tokenize(sequence):
    tokens = [token.replace("``", '"').replace("''", '"').lower() for token in nltk.word_tokenize(sequence)]
    return tokens

In [7]:
def get_char_word_loc_mapping(context, context_tokens):
    acc = ""
    current_token_idx = 0
    mapping = dict()
    
    for char_idx, char in enumerate(context):
        if char != " " and char != '\n':
            acc += char
            context_token = context_tokens[current_token_idx]
            if acc == context_token:
                syn_start = char_idx - len(acc) + 1
                for char_loc in range(syn_start, char_idx + 1):
                    mapping[char_loc] = (acc, current_token_idx)
                acc = ""
                current_token_idx += 1
    
    if current_token_idx != len(context_tokens):
        return None
    else:
        return mapping

In [8]:
def parse_data(filepath):
    with open(filepath, 'r') as f:
        dataset = json.load(f)
    
    examples = []
    not_matching_answer = 0
    not_matching_ans_tokens = 0
    num_mapping_prob = 0

    for articles_id in range(len(dataset['data'])):
        article_paragraphs = dataset['data'][articles_id]['paragraphs']
        for pid in range(len(article_paragraphs)):
            context = article_paragraphs[pid]['context']
            context = context.replace("''", '" ')
            context = context.replace("``", '" ')
            context = context.lower()
            context_tokens = tokenize(context)
            
            qas = article_paragraphs[pid]['qas']

            char_to_wordloc = get_char_word_loc_mapping(context, context_tokens)
            if char_to_wordloc is None:
                num_mapping_prob += len(qas)
                continue
            
            for qa in qas:
                question = qa['question']
                question_tokens = tokenize(question)

                ans_text = qa['answers'][0]['text']
                ans_text = ans_text.replace("''", '" ')
                ans_text = ans_text.replace("``", '" ')
                ans_text = ans_text.lower()
                ans_start_charloc = qa['answers'][0]['answer_start']
                ans_end_charloc = ans_start_charloc + len(ans_text)

                if context[ans_start_charloc: ans_end_charloc] != ans_text:
                    not_matching_answer += 1
                    continue
                
                ans_start_wordloc = char_to_wordloc[ans_start_charloc][1]
                ans_end_wordloc = char_to_wordloc[ans_end_charloc - 1][1]
                assert ans_start_wordloc <= ans_end_wordloc, "Answer indices are not correct"

                ans_tokens = context_tokens[ans_start_wordloc: ans_end_wordloc + 1]
                if "".join(ans_tokens) != "".join(ans_text.lower().split()):
                    not_matching_ans_tokens += 1
                    continue
                
                examples.append([context_tokens, question_tokens, ans_tokens, ans_start_wordloc, ans_end_wordloc])

    print(f"Number of  (context, question, answer) triples discarded due to char -> token mapping problems: {num_mapping_prob}")
    print(f"Number of  (context, question, answer) triples discarded because character-based answer span is unaligned with tokenization: {not_matching_ans_tokens}")
    print(f"Number of  (context, question, answer) triples discarded due answer span alignment problems: {not_matching_answer}")

    return examples


In [9]:
%%time
train_examples = parse_data(train_file)
dev_examples = parse_data(dev_file)

Number of  (context, question, answer) triples discarded due to char -> token mapping problems: 97
Number of  (context, question, answer) triples discarded because character-based answer span is unaligned with tokenization: 2521
Number of  (context, question, answer) triples discarded due answer span alignment problems: 23
Number of  (context, question, answer) triples discarded due to char -> token mapping problems: 0
Number of  (context, question, answer) triples discarded because character-based answer span is unaligned with tokenization: 331
Number of  (context, question, answer) triples discarded due answer span alignment problems: 0
CPU times: user 42.9 s, sys: 261 ms, total: 43.1 s
Wall time: 45 s


In [10]:
len(train_examples), len(dev_examples)

(84958, 10239)

### Vocabulary

In [11]:
class Vocabulary(object):
    def __init__(self):
        self.word2idx = {}
        self.idx2word = {}
        self.idx = 0

    def add_word(self, word):
        if not word in self.word2idx:
            self.word2idx[word] = self.idx
            self.idx2word[self.idx] = word
            self.idx += 1

    def __call__(self, word):
        if not word in self.word2idx:
            return self.word2idx['<unk>']
        return self.word2idx[word]

    def __len__(self):
        return len(self.word2idx)

In [12]:
def build_vocab(sentences, threshold=50):
    """Build a simple vocabulary wrapper."""
    counter = Counter()
    for i, sent in enumerate(sentences):
        counter.update(sent[0])
        counter.update(sent[1])

    # If the word frequency is less than 'threshold', then the word is discarded.
    words = [word for word, cnt in counter.items() if cnt >= threshold]

    # Create a vocab wrapper and add some special tokens.
    vocab = Vocabulary()
    vocab.add_word('<pad>')
    vocab.add_word('<start>')
    vocab.add_word('<end>')
    vocab.add_word('<unk>')

    # Add the words to the vocabulary.
    for i, word in enumerate(words):
        vocab.add_word(word)
    return vocab

In [13]:
word_vocab = build_vocab(train_examples)
print(f"Length of vocab: {len(word_vocab)}")

Length of vocab: 14160


### Dataset Wrapper

In [14]:
class SquadDataset(Dataset):
    def __init__(self, examples):
        self.examples = examples
    
    def __len__(self):
        return len(self.examples)
    
    def __getitem__(self, item):
        example = self.examples[item]

        context_tokens = example[0]
        question_tokens = example[1]
        answer_tokens = example[2]
        ans_start_idx = example[3]
        ans_end_idx = example[4]
        ans_span = [ans_start_idx, ans_end_idx]

        context_tokens = ['<start>'] + context_tokens + ['<end>']
        question_tokens = ['<start>'] + question_tokens + ['<end>']
        context_ids = [word_vocab(tok) for tok in context_tokens]
        question_ids = [word_vocab(tok) for tok in question_tokens]

        return (
            torch.LongTensor(context_ids),
            context_tokens,
            torch.LongTensor(question_ids),
            question_tokens,
            torch.LongTensor(ans_span),
            answer_tokens
        )

In [15]:
train_dataset = SquadDataset(train_examples)
dev_dataset = SquadDataset(dev_examples)

In [16]:
train_dataset[0]

(tensor([ 1,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13,  6, 14, 15, 16, 17, 18, 19,
          9, 20, 21, 22,  6, 23, 24, 12, 25, 26, 27, 22,  6, 14, 15, 28, 29, 30,
          5, 19,  9, 31, 21, 22, 32, 33, 34,  3, 33,  6, 35, 36,  3, 37, 38,  3,
         36, 12, 39, 40,  6, 14, 15, 19,  6, 41, 22,  6, 42, 43, 12, 25, 44,  6,
         41, 19,  6,  3,  5,  9, 45, 46, 22, 47, 28, 48, 12, 30, 19,  9, 49, 22,
          6,  3, 50,  3,  5, 51, 52,  6, 23, 24,  3, 53, 40, 54,  3,  3, 26,  3,
         50,  6, 55, 22,  6, 14, 56, 57, 28, 26,  9, 58, 59, 60, 61, 62, 63, 64,
         28,  6, 17, 18, 65,  5, 19,  9, 66,  5, 67, 68, 21, 22, 24, 12,  2]),
 ['<start>',
  'architecturally',
  ',',
  'the',
  'school',
  'has',
  'a',
  'catholic',
  'character',
  '.',
  'atop',
  'the',
  'main',
  'building',
  "'s",
  'gold',
  'dome',
  'is',
  'a',
  'golden',
  'statue',
  'of',
  'the',
  'virgin',
  'mary',
  '.',
  'immediately',
  'in',
  'front',
  'of',
  'the',
  'main',
  'building',
  'and'

### DataLoaders

In [17]:
def collate_fn(data):
    data.sort(key=lambda x: len(x[2]), reverse=True)
    context_ids, context_tokens, question_ids, question_tokens, answer_span, answer_tokens = zip(*data)

    question_lengths = [len(ques) for ques in question_ids]
    context_lengths = [len(context) for context in context_ids]

    context_ids_padded = torch.zeros((len(context_ids), max(context_lengths)), dtype=torch.long)
    question_ids_padded = torch.zeros((len(question_ids), max(question_lengths)), dtype=torch.long)

    for i, sent in enumerate(context_ids):
        end = context_lengths[i]
        context_ids_padded[i, :end] = sent[:end]
    
    for i, sent in enumerate(question_ids):
        end = question_lengths[i]
        question_ids_padded[i, :end] = sent[:end]

    questions_mask = (question_ids_padded != 0).long()
    contexts_mask = (context_ids_padded != 0).long()
    return {
        "question_ids": question_ids_padded,
        "question_masks": questions_mask,
        "question_tokens": question_tokens,
        "context_ids": context_ids_padded,
        "context_masks": contexts_mask,
        "context_tokens": context_tokens,
        "answer_spans": torch.stack(answer_span),
        "answer_tokens": answer_tokens
    }

In [18]:
BATCH_SIZE = 128

train_data_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, collate_fn=collate_fn, shuffle=True)
dev_data_loader = DataLoader(dev_dataset, batch_size=BATCH_SIZE, collate_fn=collate_fn)

In [19]:
sample = next(iter(train_data_loader))
sample['question_ids'].shape, sample['question_masks'].shape, sample['context_ids'].shape, sample['context_masks'].shape, sample['answer_spans'].shape

(torch.Size([128, 41]),
 torch.Size([128, 41]),
 torch.Size([128, 398]),
 torch.Size([128, 398]),
 torch.Size([128, 2]))

## Double Cross Attention Network

### Encoder

In [20]:
class Encoder(nn.Module):
    def __init__(self, input_dim, emb_dim, hidden_dim, dropout):
        super().__init__()

        self.emb_dim = emb_dim
        self.hidden_dim = hidden_dim

        self.embedding = nn.Embedding(input_dim, emb_dim)
        self.encoder = nn.LSTM(emb_dim, hidden_dim, batch_first=True, num_layers=1, bidirectional=True)

        self.dropout = nn.Dropout(dropout)
    
    def forward(self, input, input_mask):
        # input => [batch_size, seq_len]
        # input_mask => [batch_size, seq_len]

        input_lengths = torch.sum(input_mask, 1)
        sorted_lengths, sorted_lengths_index = torch.sort(input_lengths, 0, True)
        _, original_index = torch.sort(sorted_lengths_index, 0)

        # arrange input according to descending length
        input_sorted = torch.index_select(input, 0, sorted_lengths_index)
        input_embed = self.embedding(input_sorted)
        input_embed = self.dropout(input_embed)
        packed_input = nn.utils.rnn.pack_padded_sequence(input_embed, sorted_lengths, batch_first=True)
        output, _ = self.encoder(packed_input)
        output, _ = nn.utils.rnn.pad_packed_sequence(output, batch_first=True)
        output = output.contiguous()

        # rearrange the output to it's original order
        output = torch.index_select(output, 0, original_index)
        output = self.dropout(output)

        return output

### CoAttention Encoder

In [32]:
class CoAttentionEncoder(nn.Module):
    def __init__(self, input_dim, emb_dim, hidden_dim, dropout):
        super().__init__()

        self.hidden_dim = hidden_dim
        self.encoder_dq = Encoder(input_dim, emb_dim, hidden_dim, dropout)
        
        self.q_proj = nn.Linear(2 * hidden_dim, 2 * hidden_dim)
        self.fusion_lstm = nn.LSTM(6 * hidden_dim, hidden_dim, num_layers=1, batch_first=True, bidirectional=True)

        self.dropout = nn.Dropout(dropout)
    
    def forward(self, question, question_mask, document, document_mask):
        # question => [batch_size, ques_len]
        # question_mask => [batch_size, ques_len]
        # document => [batch_size, doc_len]
        # document_mask => [batch_size, doc_len]

        Q = self.encoder_dq(question, question_mask)
        # Q => [batch_size, ques_len, hidden_dim * 2]
        C = self.encoder_dq(document, document_mask)
        # C => [batch_size, doc_len, hidden_dim * 2]

        Q = torch.tanh(self.q_proj(Q.view(-1, 2 * hidden_dim)).view(Q.size()))
        # Q => [batch_size, ques_len, hidden_dim * 2]

        S = torch.bmm(Q, C.transpose(1, 2))
        # Q   => [batch_size, ques_len, hid_dim * 2]
        # C_t => [batch_size, hid_dim * 2, doc_len]
        # S   => [batch_size, ques_len, doc_len]

        A_Q = F.softmax(S, dim=1)
        # A_Q => [batch_size, ques_len, doc_len]
        C2Q = torch.bmm(A_Q.transpose(1, 2), Q)
        # A_Q_t => [batch_size, doc_len, ques_len]
        # Q     => [batch_size, ques_len, hid_dim * 2]
        # C2Q   => [batch_size, doc_len, hid_dim * 2]

        A_C = F.softmax(S, dim=2)
        # A_C => [batch_size, ques_len, doc_len]
        Q2C = torch.bmm(A_C, C)
        # A_C => [batch_size, ques_len, doc_len]
        # C   => [batch_size, doc_len, hid_dim * 2]
        # Q2C => [batch_size, ques_len, hid_dim * 2]

        R = torch.bmm(Q2C, C2Q.transpose(1, 2))
        # Q2C   => [batch_size, ques_len, hid_dim * 2]
        # C2Q_t => [batch_size, hid_dim * 2, doc_len]
        # R     => [batch_size, ques_len, doc_len]

        gamma = F.softmax(R, dim=1)
        # gamma => [batch_size, ques_len, doc_len]
        CA2QA = torch.bmm(gamma.transpose(1, 2), Q2C)
        # gamma_t => [batch_size, doc_len, ques_len]
        # Q2C     => [batch_size, ques_len, hid_dim * 2]
        # CA2QA   => [batch_size, doc_len, hid_dim * 2]

        input_bilstm = torch.cat((C, C2Q, CA2QA), dim=2)
        input_bilstm = self.dropout(input_bilstm)
        # input_bilstm => [batch_size, doc_len, hid_dim * 6]

        doc_lengths = torch.sum(document_mask, 1)
        sorted_doc_lengths, sorted_doc_lengths_index = torch.sort(doc_lengths, descending=True)
        _, doc_original_index = torch.sort(sorted_doc_lengths_index)
        sorted_docs = torch.index_select(input_bilstm, 0, sorted_doc_lengths_index)
        packed_input = nn.utils.rnn.pack_padded_sequence(sorted_docs, sorted_doc_lengths, batch_first=True)
        output, _ = self.fusion_lstm(packed_input)
        output, _ = nn.utils.rnn.pad_packed_sequence(output, batch_first=True)
        output = output.contiguous()
        output = torch.index_select(output, 0, doc_original_index)
        output = self.dropout(output)
        # output => [batch_size, doc_len, hid_dim * 2]

        return output

### Highway Maxout

In [22]:
class HighwayMaxoutModel(nn.Module):
    def __init__(self, hidden_dim, pool_size):
        super().__init__()

        self.hidden_dim = hidden_dim
        self.pool_size = pool_size

        self.f_r = nn.Linear(5 * hidden_dim, hidden_dim, bias=False)
        self.f_m_1 = nn.Linear(3 * hidden_dim, pool_size * hidden_dim)
        self.f_m_2 = nn.Linear(hidden_dim, pool_size * hidden_dim)
        self.f_final = nn.Linear(2 * hidden_dim, pool_size)
        self.loss = nn.CrossEntropyLoss(reduction='none')
        
    def forward(self, doc_encoded, doc_mask, loss_mask, old_idx, hidden_state, u_s, u_e, target=None):
        # doc_encoded => [batch_size, doc_len, hid_dim * 2]
        # doc_mask => [batch_size, doc_len]
        # loss_mask => [batch_size, doc_len]
        # old_idx => [batch_size]
        # hidden_state => [batch_size, hid_dim]
        # u_s, u_e => [batch_size, hid_dim * 2]
        # target => [batch_size]

        batch_size, doc_len, _ = list(doc_encoded.size())

        r = torch.tanh(self.f_r(torch.cat((hidden_state, u_s, u_e), dim=1)))
        # r => [batch_size, hid_dim]
        r = r.unsqueeze(1)
        # r => [batch_size, 1, hid_dim]
        r = r.expand(batch_size, doc_len, -1).contiguous()
        # r => [batch_size, doc_len, hid_dim]

        highway_input = torch.cat((doc_encoded, r), dim=2).view(-1, 3 * self.hidden_dim)
        # highway_input => [batch_size * doc_len, 3 * hidden_dim]
        m_1 = self.f_m_1(highway_input)
        # m_1 => [batch_size * doc_len, hidden_dim * pool_size]
        m_1 = m_1.view(batch_size, doc_len, self.pool_size, self.hidden_dim)
        # m_1 => [batch_size, doc_len, pool_size, hidden_dim]
        m_1, _ = torch.max(m_1, 2)
        # m_1 => [batch_size, doc_len, hidden_dim]

        m_2 = self.f_m_2(m_1.view(-1, self.hidden_dim))
        # m_2 => [batch_size * doc_len, hidden_dim * pool_size]
        m_2 = m_2.view(batch_size, doc_len, self.pool_size, self.hidden_dim)
        m_2, _ = torch.max(m_2, 2)
        # m_2 => [batch_size, doc_len, hidden_dim]

        final_input = torch.cat((m_1, m_2), dim=2)
        # final_input => [batch_size, doc_len, hid_dim * 2]
        final_input = final_input.view(-1, self.hidden_dim * 2)
        # final_input => [batch_size * doc_len, hid_dim * 2]
        output = self.f_final(final_input)
        # output => [batch_size * doc_len, pool_size]
        output = output.view(batch_size, doc_len, self.pool_size)
        # output => [batch_size, doc_len, pool_size]
        output, _ = torch.max(output, 2)
        # output => [batch_size, doc_len]

        output = output + doc_mask
        # output => [batch_size, doc_len]
        _, idx_output = torch.max(output, 1)
        # idx_output => [batch_size]

        # Eliminate unnecessary loss values
        if loss_mask is None:
            loss_mask = (idx_output == idx_output)
        else:
            old_idx_ = old_idx * loss_mask.long()
            idx_output_ = idx_output * loss_mask.long()
            loss_mask = (old_idx_ != idx_output_)
        
        loss = None
        # Calculate the loss
        if target is not None:
            scores = F.log_softmax(output, 1)
            loss = self.loss(scores, target)
            loss = loss * loss_mask.float()
        
        return idx_output, loss_mask, loss

### Dynamic Decoder

In [23]:
class DynamicDecoder(nn.Module):
    def __init__(self, hidden_dim, pool_size, max_iter, dropout):
        super().__init__()

        self.hidden_dim = hidden_dim
        self.max_iter = max_iter
        
        self.decoder = nn.LSTM(4 * hidden_dim, hidden_dim, num_layers=1, batch_first=True)
        self.hmn_s = HighwayMaxoutModel(hidden_dim, pool_size)
        self.hmn_e = HighwayMaxoutModel(hidden_dim, pool_size)
    
    def forward(self, doc_encoded, doc_mask, ans_span):
        # doc_encoded => [batch_size, doc_len, hid_dim * 2]
        # doc_mask => [batch_size, doc_len]
        # ans_span => [batch_size, 2]

        batch_size, doc_len, _ = list(doc_encoded.size())

        # Initialize start to be the first word and end to be the last word
        s = torch.zeros(batch_size).long().to(device)
        e = torch.sum(doc_mask, 1) - 1
        e = e.to(device)

        indices = torch.arange(0, batch_size).long().to(device)
        # To make choosing impossible in HighwayMaxout
        mask_hmn = (1 - doc_mask).float() * -1e15

        target_s, target_e = None, None
        if ans_span is not None:
            target_s = ans_span[:, 0]
            target_e = ans_span[:, 1]
            # target_* => [batch_size]
        
        lstm_states = None
        losses = []

        for _ in range(self.max_iter):
            u_s = doc_encoded[indices, s, :]
            u_e = doc_encoded[indices, e, :]
            # u_* => [batch_size, hid_dim * 2]

            combined_input = torch.cat((u_s, u_e), dim=1)
            # combined_input => [batch_size, hid_dim * 4]
            combined_input = combined_input.unsqueeze(1)
            # combined_input => [batch_size, 1, hid_dim * 4]

            _, lstm_states = self.decoder(combined_input, lstm_states)
            hidden_state, _ = lstm_states
            # hidden_state => [num_layers * num_dir, batch_size, hid_dim]
            hidden_state = hidden_state.view(-1, self.hidden_dim)
            # hidden_state => [batch_size, hid_dim]

            loss_mask_s, loss_mask_e = None, None

            s_new, loss_mask_s, loss_s = self.hmn_s(doc_encoded, mask_hmn, loss_mask_s, s, hidden_state, u_s, u_e, target_s)
            e_new, loss_mask_e, loss_e = self.hmn_e(doc_encoded, mask_hmn, loss_mask_e, e, hidden_state, u_s, u_e, target_e)

            if ans_span is not None:
                losses.append(loss_s + loss_e)
            
            if torch.sum(s_new != s).item() == 0 and torch.sum(e_new != e).item() == 0:
                s = s_new
                e = e_new
                break
            
            s = s_new
            e = e_new
        
        cumulative_loss = None

        if ans_span is not None:
            cumulative_loss = torch.sum(torch.stack(losses, 1), 1)
            cumulative_loss = cumulative_loss / self.max_iter
            cumulative_loss = torch.mean(cumulative_loss)
        
        return s, e, cumulative_loss

### DCA

In [33]:
class DCAModel(nn.Module):
    def __init__(self, input_dim, emb_dim, hidden_dim, pool_size, max_iter, dropout=0.4):
        super().__init__()

        self.encoder = CoAttentionEncoder(input_dim, emb_dim, hidden_dim, dropout)
        self.decoder = DynamicDecoder(hidden_dim, pool_size, max_iter, dropout)
    
    def forward(self, question, question_mask, document, document_mask, ans_span=None):
        U = self.encoder(question, question_mask, document, document_mask)
        s, e, loss = self.decoder(U, document_mask, ans_span)
        if ans_span is not None:
            return loss, s, e
        else:
            return s, e


### Model

In [34]:
input_dim = len(word_vocab)
emb_dim = 100
hidden_dim = 200
pool_size = 16
max_dec_steps = 4

model = DCAModel(input_dim, emb_dim, hidden_dim, pool_size, max_dec_steps)
model = model.to(device)

In [35]:
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"The model has {count_parameters(model):,} trainable parameters")

The model has 10,650,032 trainable parameters


### Optimizer

In [27]:
lr = 1e-3
min_lr = 3e-5
lr_decay=0.5
lr_patience=2

optimizer = optim.Adam(model.parameters(), lr=lr)
scheduler = ReduceLROnPlateau(optimizer, 'min', lr_decay, lr_patience, verbose=True, min_lr=min_lr)

### Training Method

In [28]:
def train(iterator, clip=2.0):
    epoch_loss = 0
    model.train()
    for batch in iterator:
        question_ids = batch['question_ids'].to(device)
        question_masks = batch['question_masks'].to(device)
        document_ids = batch['context_ids'].to(device)
        document_masks = batch['context_masks'].to(device)
        ans_span = batch['answer_spans'].to(device)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + loss
        loss, start_logits, end_logits = model(question_ids, question_masks, document_ids, document_masks, ans_span)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()

        epoch_loss += loss.item()
    return epoch_loss / len(iterator)

### Evaluation Method

In [29]:
def evaluate(iterator):
    epoch_loss = 0
    model.eval()
    epoch_f1, epoch_em = 0, 0

    with torch.no_grad():
        for batch in iterator:
            question_ids = batch['question_ids'].to(device)
            question_masks = batch['question_masks'].to(device)
            document_ids = batch['context_ids'].to(device)
            document_masks = batch['context_masks'].to(device)
            ans_span = batch['answer_spans'].to(device)

            # forward  + loss
            loss, start_logits, end_logits = model(question_ids, question_masks, document_ids, document_masks, ans_span)
            epoch_loss += loss.item()

            # f1 + em
            # TODO
    return epoch_loss / len(iterator) #, epoch_f1 / len(iterator), epoch_em / len(iterator)

In [30]:
def epoch_time(start_time, end_time):
    elapsed_time = end_time - start_time
    elapsed_mins = int(elapsed_time / 60)
    elapsed_secs = elapsed_time - (elapsed_mins * 60)
    return elapsed_mins, elapsed_secs

### Training

In [None]:
NUM_EPOCHS = 2
best_valid_loss = float('inf')
for epoch in range(NUM_EPOCHS):
    start_time = time.time()
    train_loss = train(train_data_loader)
    val_loss = evaluate(dev_data_loader)
    end_time = time.time()
    scheduler.step(val_loss)

    epoch_mins, epoch_secs = epoch_time(start_time, end_time)
    print(f"Epoch: {epoch + 1:02} | Time: {epoch_mins}m {epoch_secs:.2f}s")
    print(f"\tTrain Loss: {train_loss:.3f} | Val Loss: {val_loss:.3f}")
    
    if val_loss < best_valid_loss:
        best_valid_loss = val_loss
        torch.save(model.state_dict(), 'model.pt')

Epoch: 01 | Time: 50m 2.92s
	Train Loss: 9.721 | Val Loss: 9.770
