In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset
import json
import numpy as np
import random
import os
import unicodedata
from collections import Counter
import re
import string
import copy
from tqdm import tqdm

In [3]:
class StackedBRNN(nn.Module):
    """Stacked bi-directional RNNs.

    Differs from the standard PyTorch multi-layer RNN in that every layer is
    kept as its own module, so the outputs of all layers can optionally be
    concatenated (the output feature size is then
    ``num_layers * 2 * hidden_size`` instead of ``2 * hidden_size``).

    Args:
        input_size: feature size of the input sequence.
        hidden_size: hidden size of each direction of each layer.
        num_layers: number of stacked bidirectional layers.
        dropout_rate: dropout probability applied to the input of every layer.
        dropout_output: also apply dropout to the final output.
        rnn_type: RNN class to instantiate (e.g. ``nn.LSTM``).
        concat_layers: concatenate the outputs of all layers along the
            feature dimension instead of returning only the last layer.
        padding: pack sequences by their true length during training too
            (this is always done in eval mode).
    """

    def __init__(self, input_size, hidden_size, num_layers,
                 dropout_rate=0, dropout_output=False, rnn_type=nn.LSTM,
                 concat_layers=False, padding=False):
        super(StackedBRNN, self).__init__()
        self.padding = padding
        self.dropout_output = dropout_output
        self.dropout_rate = dropout_rate
        self.num_layers = num_layers
        self.concat_layers = concat_layers
        self.rnns = nn.ModuleList()
        for i in range(num_layers):
            # Layers after the first consume the (forward + backward) output
            # of the previous layer.
            layer_input_size = input_size if i == 0 else 2 * hidden_size
            self.rnns.append(rnn_type(layer_input_size, hidden_size,
                                      num_layers=1,
                                      bidirectional=True))

    def forward(self, x, x_mask):
        """Encode either padded or non-padded sequences.

        Can choose to either handle or ignore variable length sequences.
        Always handle padding in eval.

        Args:
            x: batch * len * hdim
            x_mask: batch * len (1 for padding, 0 for true)
        Output:
            x_encoded: batch * len * hdim_encoded
        """
        if x_mask.sum() == 0:
            # No padding anywhere in the batch: the fast path is exact.
            output = self._forward_unpadded(x, x_mask)
        elif self.padding or not self.training:
            # Pad if we care or if it is during eval.
            output = self._forward_padded(x, x_mask)
        else:
            # Training without explicit padding handling: accept the
            # approximation for speed.
            output = self._forward_unpadded(x, x_mask)

        return output.contiguous()

    def _forward_unpadded(self, x, x_mask):
        """Faster encoding that ignores any padding."""
        # Transpose batch and sequence dims (the RNNs are sequence-first).
        x = x.transpose(0, 1)

        # Encode all layers
        outputs = [x]
        for i in range(self.num_layers):
            rnn_input = outputs[-1]

            # Apply dropout to hidden input
            if self.dropout_rate > 0:
                rnn_input = F.dropout(rnn_input,
                                      p=self.dropout_rate,
                                      training=self.training)
            # Forward; [0] keeps the per-step outputs, drops the final state.
            outputs.append(self.rnns[i](rnn_input)[0])

        # Concat hidden layers (outputs[0] is the raw input, so skip it)
        if self.concat_layers:
            output = torch.cat(outputs[1:], 2)
        else:
            output = outputs[-1]

        # Transpose back
        output = output.transpose(0, 1)

        # Dropout on output layer
        if self.dropout_output and self.dropout_rate > 0:
            output = F.dropout(output,
                               p=self.dropout_rate,
                               training=self.training)
        return output

    def _forward_padded(self, x, x_mask):
        """Slower (significantly), but more precise, encoding that handles
        padding by packing every sequence to its true length.
        """
        # True length of every sequence. No ``squeeze()`` here: ``sum(1)`` is
        # already 1-D, and squeezing would collapse a batch of one into a
        # 0-dim tensor that can neither be sorted by index nor iterated.
        lengths = x_mask.eq(0).long().sum(1)
        sorted_lengths, idx_sort = torch.sort(lengths, dim=0, descending=True)
        _, idx_unsort = torch.sort(idx_sort, dim=0)

        # Sort x by decreasing length, as pack_padded_sequence expects
        x = x.index_select(0, idx_sort)

        # Transpose batch and sequence dims
        x = x.transpose(0, 1)

        # Pack it up (plain Python ints keep the lengths on the CPU)
        rnn_input = nn.utils.rnn.pack_padded_sequence(x, sorted_lengths.tolist())

        # Encode all layers
        outputs = [rnn_input]
        for i in range(self.num_layers):
            rnn_input = outputs[-1]

            # Apply dropout to input. Dropout works on the flat data tensor of
            # the packed sequence, which is then re-wrapped.
            if self.dropout_rate > 0:
                dropout_input = F.dropout(rnn_input.data,
                                          p=self.dropout_rate,
                                          training=self.training)
                rnn_input = nn.utils.rnn.PackedSequence(dropout_input,
                                                        rnn_input.batch_sizes)
            outputs.append(self.rnns[i](rnn_input)[0])

        # Unpack everything
        for i, o in enumerate(outputs[1:], 1):
            outputs[i] = nn.utils.rnn.pad_packed_sequence(o)[0]

        # Concat hidden layers or take final
        if self.concat_layers:
            output = torch.cat(outputs[1:], 2)
        else:
            output = outputs[-1]

        # Transpose and restore the original batch order
        output = output.transpose(0, 1)
        output = output.index_select(0, idx_unsort)

        # Unpacking only pads up to the longest sequence in the batch; pad up
        # to the original (mask) sequence length with zeros of the same
        # dtype/device as the output.
        if output.size(1) != x_mask.size(1):
            padding = output.new_zeros(output.size(0),
                                       x_mask.size(1) - output.size(1),
                                       output.size(2))
            output = torch.cat([output, padding], 1)

        # Dropout on output layer
        if self.dropout_output and self.dropout_rate > 0:
            output = F.dropout(output,
                               p=self.dropout_rate,
                               training=self.training)
        return output

In [4]:
class SeqAttnMatch(nn.Module):
    """Given sequences X and Y, match sequence Y to each element in X.

    * o_i = sum(alpha_j * y_j) for i in X
    * alpha_j = softmax(y_j * x_i)

    Args:
        input_size: feature size shared by X and Y.
        identity: if True, score with a plain dot product instead of a
            shared linear projection followed by ReLU.
    """

    def __init__(self, input_size, identity=False):
        super(SeqAttnMatch, self).__init__()
        if not identity:
            self.linear = nn.Linear(input_size, input_size)
        else:
            self.linear = None

    def forward(self, x, y, y_mask):
        """
        Args:
            x: batch * len1 * hdim
            y: batch * len2 * hdim
            y_mask: batch * len2 (1 for padding, 0 for true)
        Output:
            matched_seq: batch * len1 * hdim
        """
        # Project vectors. nn.Linear acts on the last dimension, so the 3-D
        # tensors can be passed directly; this also accepts non-contiguous
        # inputs, which an explicit ``view`` would reject.
        if self.linear is not None:
            x_proj = F.relu(self.linear(x))
            y_proj = F.relu(self.linear(y))
        else:
            x_proj = x
            y_proj = y

        # Compute scores: batch * len1 * len2
        scores = x_proj.bmm(y_proj.transpose(2, 1))

        # Mask padding positions of y so they get zero attention weight
        pad_positions = y_mask.unsqueeze(1).expand(scores.size())
        scores = scores.masked_fill(pad_positions, -float('inf'))

        # Normalize over the positions of y
        alpha = F.softmax(scores, dim=-1)

        # Take weighted average of the (unprojected) y vectors
        matched_seq = alpha.bmm(y)
        return matched_seq

In [5]:
class BilinearSeqAttn(nn.Module):
    """Bilinear attention of a single vector y over a sequence X:

    * o_i = softmax(x_i'Wy) for x_i in X.

    With ``normalize=False`` the exponentiated raw scores are returned
    instead of a (log-)softmax.
    """

    def __init__(self, x_size, y_size, identity=False, normalize=True):
        super(BilinearSeqAttn, self).__init__()
        self.normalize = normalize

        # identity=True means a plain dot product: no learned projection of y.
        self.linear = None if identity else nn.Linear(y_size, x_size)

    def forward(self, x, y, x_mask):
        """
        Args:
            x: batch * len * hdim1
            y: batch * hdim2
            x_mask: batch * len (1 for padding, 0 for true)
        Output:
            alpha = batch * len
        """
        projected_y = y if self.linear is None else self.linear(y)

        # One score per position of x: x_i . (W y)
        logits = torch.bmm(x, projected_y.unsqueeze(2)).squeeze(2)
        logits.data.masked_fill_(x_mask.data, -float('inf'))

        if not self.normalize:
            return logits.exp()
        if self.training:
            # Log-probabilities during training (for an NLL loss)
            return F.log_softmax(logits, dim=-1)
        # Plain 0-1 probabilities otherwise
        return F.softmax(logits, dim=-1)

In [6]:
class LinearSeqAttn(nn.Module):
    """Self attention over a sequence, using one learned scoring vector:

    * o_i = softmax(Wx_i) for x_i in X.
    """

    def __init__(self, input_size):
        super(LinearSeqAttn, self).__init__()
        # Maps every hidden vector to a single scalar score.
        self.linear = nn.Linear(input_size, 1)

    def forward(self, x, x_mask):
        """
        Args:
            x: batch * len * hdim
            x_mask: batch * len (1 for padding, 0 for true)
        Output:
            alpha: batch * len
        """
        batch_size, seq_len = x.size(0), x.size(1)

        # Score all positions at once on a (batch * len, hdim) view.
        flat_scores = self.linear(x.view(-1, x.size(-1)))
        scores = flat_scores.view(batch_size, seq_len)

        # Padding positions must not receive any attention mass.
        scores.data.masked_fill_(x_mask.data, -float('inf'))
        return F.softmax(scores, dim=-1)

In [7]:
class RnnDocReader(nn.Module):
    """Document reader network that scores answer span start/end positions.

    The document is encoded from word embeddings, a question-aware attention
    embedding and optional manual features; the question is encoded
    separately and merged into a single vector; bilinear attention between
    the two yields per-token start and end scores.

    Args:
        vocab_size: number of rows of the embedding matrix (index 0 is the
            padding index).
        num_features: number of manual features per document token.
        embedding_dim: word embedding size.
        normalize: whether the start/end attention layers normalize their
            scores (see ``BilinearSeqAttn``).
        hidden_size: hidden size of each direction of every RNN layer.
        num_layers: number of stacked bidirectional layers per encoder.
        dropout_rate: dropout probability used inside both encoders.
    """

    def __init__(self, vocab_size, num_features, embedding_dim=300, normalize=True,
                 hidden_size=128, num_layers=3, dropout_rate=0.4):
        super(RnnDocReader, self).__init__()
        # Word embeddings; row 0 is reserved for padding.
        self.embedding = nn.Embedding(vocab_size,
                                      embedding_dim,
                                      padding_idx=0)

        # Projection for attention weighted question
        self.qemb_match = SeqAttnMatch(embedding_dim)

        # Input size to RNN: word emb + question emb + manual features
        doc_input_size = embedding_dim * 2 + num_features

        # RNN document encoder
        self.doc_rnn = StackedBRNN(
            input_size=doc_input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            dropout_rate=dropout_rate,
            dropout_output=True,
            concat_layers=True,
            rnn_type=nn.LSTM,
            padding=False,  # padding is only handled explicitly in eval mode
        )

        # RNN question encoder
        self.question_rnn = StackedBRNN(
            input_size=embedding_dim,
            hidden_size=hidden_size,
            num_layers=num_layers,
            dropout_rate=dropout_rate,
            dropout_output=True,
            concat_layers=True,
            rnn_type=nn.LSTM,
            padding=False,
        )

        # Output sizes of the encoders: 2 directions per layer, and the
        # outputs of all layers are concatenated (concat_layers=True).
        doc_hidden_size = 2 * hidden_size * num_layers
        question_hidden_size = 2 * hidden_size * num_layers

        # Question merging
        self.self_attn = LinearSeqAttn(question_hidden_size)

        # Bilinear attention for span start/end
        self.start_attn = BilinearSeqAttn(
            doc_hidden_size,
            question_hidden_size,
            normalize=normalize,
        )
        self.end_attn = BilinearSeqAttn(
            doc_hidden_size,
            question_hidden_size,
            normalize=normalize,
        )

    def _weighted_avg(self, x, weights):
        """Return a weighted average of x (a sequence of vectors).

        Args:
            x: batch * len * hdim
            weights: batch * len, sum(dim = 1) = 1
        Output:
            x_avg: batch * hdim
        """
        return weights.unsqueeze(1).bmm(x).squeeze(1)

    def forward(self, x1, x1_f, x1_mask, x2, x2_mask, dropout_emb=0.3):
        """Inputs:
        x1 = document word indices             [batch * len_d]
        x1_f = document word features indices  [batch * len_d * nfeat]
               (may be None when no manual features are used)
        x1_mask = document padding mask        [batch * len_d]
        x2 = question word indices             [batch * len_q]
        x2_mask = question padding mask        [batch * len_q]
        dropout_emb = dropout probability applied to both embeddings

        Returns:
            (start_scores, end_scores), each [batch * len_d].
        """
        # Embed both document and question
        x1_emb = self.embedding(x1)
        x2_emb = self.embedding(x2)

        # Dropout on embeddings
        if dropout_emb > 0:
            x1_emb = nn.functional.dropout(x1_emb, p=dropout_emb,
                                           training=self.training)
            x2_emb = nn.functional.dropout(x2_emb, p=dropout_emb,
                                           training=self.training)

        # Form document encoding inputs
        drnn_input = [x1_emb]

        # Add attention-weighted question representation
        x2_weighted_emb = self.qemb_match(x1_emb, x2_emb, x2_mask)
        drnn_input.append(x2_weighted_emb)

        # Add manual features. The batcher yields None when there are no
        # features, and torch.cat cannot concatenate None.
        if x1_f is not None:
            drnn_input.append(x1_f)

        # Encode document with RNN
        doc_hiddens = self.doc_rnn(torch.cat(drnn_input, 2), x1_mask)

        # Encode question with RNN + merge hiddens
        question_hiddens = self.question_rnn(x2_emb, x2_mask)
        q_merge_weights = self.self_attn(question_hiddens, x2_mask)
        question_hidden = self._weighted_avg(question_hiddens, q_merge_weights)

        # Predict start and end positions
        start_scores = self.start_attn(doc_hiddens, question_hidden, x1_mask)
        end_scores = self.end_attn(doc_hiddens, question_hidden, x1_mask)
        return start_scores, end_scores

In [8]:
class Dictionary(object):
    """Two-way mapping between tokens and integer indices.

    Indices 0 and 1 are reserved for the NULL (padding) and UNK (unknown)
    tokens; regular tokens start at ``START``. String keys are Unicode
    normalized (NFD) on lookup and on ``add``.
    """
    NULL = '<NULL>'
    UNK = '<UNK>'
    START = 2

    @staticmethod
    def normalize(token):
        """Return the NFD-normalized form of ``token``."""
        return unicodedata.normalize('NFD', token)

    def __init__(self):
        self.tok2ind = {self.NULL: 0, self.UNK: 1}
        self.ind2tok = {0: self.NULL, 1: self.UNK}

    def __len__(self):
        return len(self.tok2ind)

    def __iter__(self):
        return iter(self.tok2ind)

    def __contains__(self, key):
        """Membership test for an index (int) or a token (str)."""
        key_type = type(key)
        if key_type is int:
            return key in self.ind2tok
        if key_type is str:
            return self.normalize(key) in self.tok2ind
        # Any other key type falls through (implicit None, i.e. "not in").

    def __getitem__(self, key):
        """Map index -> token or token -> index, falling back to UNK."""
        key_type = type(key)
        if key_type is int:
            return self.ind2tok.get(key, self.UNK)
        if key_type is str:
            unk_index = self.tok2ind.get(self.UNK)
            return self.tok2ind.get(self.normalize(key), unk_index)
        # Any other key type falls through (implicit None).

    def __setitem__(self, key, item):
        """Set one direction of the mapping; (int, str) or (str, int) only.

        The key is stored as given (no normalization is applied here).
        """
        pair = (type(key), type(item))
        if pair == (int, str):
            self.ind2tok[key] = item
        elif pair == (str, int):
            self.tok2ind[key] = item
        else:
            raise RuntimeError('Invalid (key, item) types.')

    def add(self, token):
        """Index a new token at the next free position (no-op if known)."""
        token = self.normalize(token)
        if token in self.tok2ind:
            return
        next_index = len(self.tok2ind)
        self.tok2ind[token] = next_index
        self.ind2tok[next_index] = token

    def tokens(self):
        """Get dictionary tokens.

        Return all the words indexed by this dictionary, except for special
        tokens.
        """
        special = {'<NULL>', '<UNK>'}
        return [tok for tok in self.tok2ind if tok not in special]

In [9]:
class ReaderDataset(Dataset):
    """Dataset that turns preprocessed reader examples into tensors.

    Args:
        examples: list of example dicts; the fields read here are
            'document', 'question', 'id' and, depending on the indexed
            features, 'qlemma', 'lemma', 'pos', 'ner'; optionally 'answers'.
        word_dict: mapping token -> index (indexed with ``word_dict[token]``).
        feature_dict: mapping feature name -> column index. Only the features
            present in this mapping are computed; an empty mapping disables
            manual features altogether.
        single_answer: return only the first answer span, as tensors.
    """

    def __init__(self, examples, word_dict, feature_dict, single_answer=False):
        self.examples = examples
        self.word_dict = word_dict
        self.feature_dict = feature_dict
        self.single_answer = single_answer

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, index):
        """Torchify a single example.

        Returns:
            (document, features, question, id) when the example has no
            'answers' field, otherwise
            (document, features, question, start, end, id).
            ``features`` is None when ``feature_dict`` is empty.
        """
        ex = self.examples[index]

        # Index words
        document = torch.LongTensor([self.word_dict[w] for w in ex['document']])
        question = torch.LongTensor([self.word_dict[w] for w in ex['question']])

        # Create extra features vector (None if no features are indexed)
        features = self._build_features(ex)

        # Maybe return without target
        if 'answers' not in ex:
            return document, features, question, ex['id']

        # ...or with target(s) (might still be empty if answers is empty)
        if self.single_answer:
            assert(len(ex['answers']) > 0)
            start = torch.LongTensor(1).fill_(ex['answers'][0][0])
            end = torch.LongTensor(1).fill_(ex['answers'][0][1])
        else:
            start = [a[0] for a in ex['answers']]
            end = [a[1] for a in ex['answers']]

        return document, features, question, start, end, ex['id']

    def _build_features(self, ex):
        """Build the (len_d, num_features) manual feature matrix of an example.

        Every feature is only written when its name is present in
        ``feature_dict``, so partial (or empty) feature dictionaries are
        handled instead of failing on a missing key or a None matrix.
        """
        feature_dict = self.feature_dict
        if len(feature_dict) == 0:
            return None

        doc = ex['document']
        features = torch.zeros(len(doc), len(feature_dict))

        # f_{exact_match}: cased, uncased and lemma overlap with the question
        if 'in_question' in feature_dict:
            col = feature_dict['in_question']
            q_words_cased = set(ex['question'])
            for i, w in enumerate(doc):
                if w in q_words_cased:
                    features[i][col] = 1.0
        if 'in_question_uncased' in feature_dict:
            col = feature_dict['in_question_uncased']
            q_words_uncased = {w.lower() for w in ex['question']}
            for i, w in enumerate(doc):
                if w.lower() in q_words_uncased:
                    features[i][col] = 1.0
        if 'in_question_lemma' in feature_dict:
            col = feature_dict['in_question_lemma']
            q_lemma = set(ex['qlemma'])
            if q_lemma:
                for i in range(len(doc)):
                    if ex['lemma'][i] in q_lemma:
                        features[i][col] = 1.0

        # f_{token} (POS): one-hot over the tags indexed in feature_dict
        for i, w in enumerate(ex['pos']):
            f = 'pos=%s' % w
            if f in feature_dict:
                features[i][feature_dict[f]] = 1.0

        # f_{token} (NER)
        for i, w in enumerate(ex['ner']):
            f = 'ner=%s' % w
            if f in feature_dict:
                features[i][feature_dict[f]] = 1.0

        # f_{token} (TF): relative frequency of the lower-cased token
        if 'tf' in feature_dict:
            col = feature_dict['tf']
            counter = Counter(w.lower() for w in doc)
            doc_len = len(doc)
            for i, w in enumerate(doc):
                features[i][col] = counter[w.lower()] * 1.0 / doc_len

        return features

    def lengths(self):
        """Return (document length, question length) for every example."""
        return [(len(ex['document']), len(ex['question']))
                for ex in self.examples]

In [10]:
class AverageMeter(object):
    """Running average tracker: keeps the last value, the weighted sum,
    the total weight and the resulting mean.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Forget everything recorded so far."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as observed ``n`` times and refresh the mean."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

In [32]:
class Model():
    def __init__(self):
        self.device = self._get_device()
        self._set_random_seed()
    
    def _get_device(self, show_info = False):
        if torch.cuda.is_available():    
            device = torch.device("cuda")

            if show_info:
                print('There are %d GPU(s) available.' % torch.cuda.device_count())
                print('We will use the GPU:', torch.cuda.get_device_name(0))

        else:
            device = torch.device("cpu")

            if show_info:
                print('No GPU available, using the CPU instead.')

        return device
    
    def _set_random_seed(self, seed=1013):
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
        np.random.seed(seed)
        random.seed(seed)
        os.environ['PYTHONHASHSEED'] = str(seed)
    
    def _load_data(self, filename, uncased_question=False, uncased_doc=False, skip_no_answer=False):
        """Load examples from preprocessed file.
        One example per line, JSON encoded.
        """
        # Load JSON lines
        with open(filename) as f:
            examples = [json.loads(line) for line in f]

        # Make case insensitive?
        if uncased_question or uncased_doc:
            for ex in examples:
                if uncased_question:
                    ex['question'] = [w.lower() for w in ex['question']]
                if uncased_doc:
                    ex['document'] = [w.lower() for w in ex['document']]

        # Skip unparsed (start/end) examples
        if skip_no_answer:
            examples = [ex for ex in examples if len(ex['answers']) > 0]

        return examples
    
    def _load_text(self, filename):
        """Load the paragraphs only of a SQuAD dataset. Store as qid -> text."""
        # Load JSON file
        with open(filename) as f:
            examples = json.load(f)['data']

        texts = {}
        for article in examples:
            for paragraph in article['paragraphs']:
                for qa in paragraph['qas']:
                    texts[qa['id']] = paragraph['context']
        return texts

    def _load_answers(self, filename):
        """Load the answers only of a SQuAD dataset. Store as qid -> [answers]."""
        # Load JSON file
        with open(filename) as f:
            examples = json.load(f)['data']

        ans = {}
        for article in examples:
            for paragraph in article['paragraphs']:
                for qa in paragraph['qas']:
                    ans[qa['id']] = list(map(lambda x: x['text'], qa['answers']))
        return ans

    def _build_feature_dict(self, examples):
        """Index features (one hot) from fields in examples and options."""
        def _insert(feature):
            if feature not in feature_dict:
                feature_dict[feature] = len(feature_dict)

        feature_dict = {}

        # Exact match features
        _insert('in_question')
        _insert('in_question_uncased')
        _insert('in_question_lemma')

        # Part of speech tag features
        for ex in examples:
            for w in ex['pos']:
                _insert('pos=%s' % w)

        # Named entity tag features
        for ex in examples:
            for w in ex['ner']:
                _insert('ner=%s' % w)

        # Term frequency feature
        _insert('tf')
        
        return feature_dict

    def _build_word_dict(self, examples, embedding_file):
        """Return a dictionary from question and document words in
        provided examples.

        Only words whose normalized form appears as the first field of a line
        of ``embedding_file`` are kept (if that file yields no words at all,
        every word is kept).

        Args:
            examples: iterable of example dicts with 'question' and
                'document' token lists.
            embedding_file: path to a UTF-8 text file of embeddings, one
                space separated entry per line, word first.
        """
        # Put all the words in embedding_file into a set. The encoding is
        # explicit so the result does not depend on the platform locale.
        valid_words = set()
        with open(embedding_file, encoding='utf-8') as f:
            for line in f:
                valid_words.add(Dictionary.normalize(line.rstrip().split(' ')[0]))

        # Collect the normalized words of all questions and documents.
        words = set()
        for ex in examples:
            for field in ('question', 'document'):
                for w in ex[field]:
                    w = Dictionary.normalize(w)
                    if valid_words and w not in valid_words:
                        continue
                    words.add(w)

        # Index in sorted order: iterating the set directly would assign
        # indices in hash order, which changes from one interpreter run to
        # the next.
        word_dict = Dictionary()
        for w in sorted(words):
            word_dict.add(w)

        return word_dict

    def _load_embeddings(self, model, word_dict, embedding_file):
        """Load pretrained embeddings for a given list of words, if they exist.

        Args:
            words: iterable of tokens. Only those that are indexed in the
                dictionary are kept.
            embedding_file: path to text file of embeddings, space separated.
        """
        words = {w for w in word_dict.tokens()}
        embedding = model.embedding.weight.data

        # When normalized, some words are duplicated. (Average the embeddings).
        vec_counts = {}
        with open(embedding_file) as f:
            # Skip first line if of form count/dim.
            line = f.readline().rstrip().split(' ')
            if len(line) != 2:
                f.seek(0)
            for line in f:
                parsed = line.rstrip().split(' ')
                assert(len(parsed) == embedding.size(1) + 1)
                w = word_dict.normalize(parsed[0])
                if w in words:
                    vec = torch.Tensor([float(i) for i in parsed[1:]])
                    if w not in vec_counts:
                        vec_counts[w] = 1
                        embedding[word_dict[w]].copy_(vec)
                    else:
                        # 'WARN: Duplicate embedding found
                        vec_counts[w] = vec_counts[w] + 1
                        embedding[word_dict[w]].add_(vec)

        for w, c in vec_counts.items():
            embedding[word_dict[w]].div_(c)

        print('Loaded %d embeddings (%.2f%%)' %
                    (len(vec_counts), 100 * len(vec_counts) / len(words)))

    def _top_question_words(self, examples, word_dict, tune_partial=1000):
        """Return the ``tune_partial`` most frequent in-dictionary question
        words of the examples, as (word, count) pairs.
        """
        word_count = Counter()
        for ex in examples:
            normalized = (Dictionary.normalize(w) for w in ex['question'])
            # Only words known to the dictionary are counted.
            word_count.update(w for w in normalized if w in word_dict)
        return word_count.most_common(tune_partial)

    def _tune_embeddings(self, words, model, word_dict):
        """Unfix the embeddings of a list of words. This is only relevant if
        only some of the embeddings are being tuned (tune_partial = N).

        Shuffles the N specified words to the front of the dictionary, and saves
        the original vectors of the other N + 1:vocab words in a fixed buffer.

        Args:
            words: iterable of tokens contained in dictionary.
        """
        words = {w for w in words}

        # Shuffle words and vectors
        embedding = model.embedding.weight.data
        for idx, swap_word in enumerate(words, word_dict.START):
            # Get current word + embedding for this index
            curr_word = word_dict[idx]
            curr_emb = embedding[idx].clone()
            old_idx = word_dict[swap_word]

            # Swap embeddings + dictionary indices
            embedding[idx].copy_(embedding[old_idx])
            embedding[old_idx].copy_(curr_emb)
            word_dict[swap_word] = idx
            word_dict[idx] = swap_word
            word_dict[curr_word] = old_idx
            word_dict[old_idx] = curr_word

        # Save the original, fixed embeddings
        model.register_buffer(
            'fixed_embedding', embedding[idx + 1:].clone()
        )

    def _init_optimizer(self, model, weight_decay=0):
        """Initialize an adamax optimizer for the free parameters of the network.
        """
        parameters = [p for p in model.parameters() if p.requires_grad]

        optimizer = torch.optim.Adamax(parameters, weight_decay=weight_decay)
        
        return optimizer

    def _batchify(self, batch):
        """Gather a batch of individual examples into one batch."""
        NUM_INPUTS = 3
        NUM_TARGETS = 2
        NUM_EXTRA = 1

        ids = [ex[-1] for ex in batch]
        docs = [ex[0] for ex in batch]
        features = [ex[1] for ex in batch]
        questions = [ex[2] for ex in batch]

        
        # Batch documents and features
        max_length = max([d.size(0) for d in docs])
        x1 = torch.LongTensor(len(docs), max_length).zero_()
        x1_mask = torch.BoolTensor(len(docs), max_length).fill_(1) # ByteTensor
        if features[0] is None:
            x1_f = None
        else:
            x1_f = torch.zeros(len(docs), max_length, features[0].size(1))
        for i, d in enumerate(docs):
            x1[i, :d.size(0)].copy_(d)
            x1_mask[i, :d.size(0)].fill_(0)
            if x1_f is not None:
                x1_f[i, :d.size(0)].copy_(features[i])

        # Batch questions
        max_length = max([q.size(0) for q in questions])
        x2 = torch.LongTensor(len(questions), max_length).zero_()
        x2_mask = torch.BoolTensor(len(questions), max_length).fill_(1)
        for i, q in enumerate(questions):
            x2[i, :q.size(0)].copy_(q)
            x2_mask[i, :q.size(0)].fill_(0)

        # Maybe return without targets
        if len(batch[0]) == NUM_INPUTS + NUM_EXTRA:
            return x1, x1_f, x1_mask, x2, x2_mask, ids

        elif len(batch[0]) == NUM_INPUTS + NUM_EXTRA + NUM_TARGETS:
            # ...Otherwise add targets
            if torch.is_tensor(batch[0][3]):
                y_s = torch.cat([ex[3] for ex in batch])
                y_e = torch.cat([ex[4] for ex in batch])
            else:
                y_s = [ex[3] for ex in batch]
                y_e = [ex[4] for ex in batch]
        else:
            raise RuntimeError('Incorrect number of inputs per example.')
        
        return x1, x1_f, x1_mask, x2, x2_mask, y_s, y_e, ids

    def _normalize_answer(self, s):
        """Lower text and remove punctuation, articles and extra whitespace."""
        def remove_articles(text):
            return re.sub(r'\b(a|an|the)\b', ' ', text)

        def white_space_fix(text):
            return ' '.join(text.split())

        def remove_punc(text):
            exclude = set(string.punctuation)
            return ''.join(ch for ch in text if ch not in exclude)

        def lower(text):
            return text.lower()

        return white_space_fix(remove_articles(remove_punc(lower(s))))

    def _exact_match_score(self, prediction, ground_truth):
        """Check if the prediction is a (soft) exact match with the ground truth."""
        return self._normalize_answer(prediction) == self._normalize_answer(ground_truth)

    def _f1_score(self, prediction, ground_truth):
        """Compute the geometric mean of precision and recall for answer tokens."""
        prediction_tokens = self._normalize_answer(prediction).split()
        ground_truth_tokens = self._normalize_answer(ground_truth).split()
        common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
        num_same = sum(common.values())
        if num_same == 0:
            return 0
        precision = 1.0 * num_same / len(prediction_tokens)
        recall = 1.0 * num_same / len(ground_truth_tokens)
        f1 = (2 * precision * recall) / (precision + recall)
        return f1
    
    def _metric_max_over_ground_truths(self, metric_fn, prediction, ground_truths):
        """Given a prediction and multiple valid answers, return the score of
        the best prediction-answer_n pair given a metric function.
        """
        scores_for_ground_truths = []
        for ground_truth in ground_truths:
            score = metric_fn(prediction, ground_truth)
            scores_for_ground_truths.append(score)
        return max(scores_for_ground_truths)

    def _decode(self, score_s, score_e, top_n=1, max_len=None):
        """Take argmax of constrained score_s * score_e.

        Args:
            score_s: independent start predictions
            score_e: independent end predictions
            top_n: number of top scored pairs to take
            max_len: max span length to consider
        """
        pred_s = []
        pred_e = []
        pred_score = []
        max_len = max_len or score_s.size(1)
        for i in range(score_s.size(0)):
            # Outer product of scores to get full p_s * p_e matrix
            scores = torch.ger(score_s[i], score_e[i])

            # Zero out negative length and over-length span scores
            scores.triu_().tril_(max_len - 1)

            # Take argmax or top n
            scores = scores.numpy()
            scores_flat = scores.flatten()
            if top_n == 1:
                idx_sort = [np.argmax(scores_flat)]
            elif len(scores_flat) < top_n:
                idx_sort = np.argsort(-scores_flat)
            else:
                idx = np.argpartition(-scores_flat, top_n)[0:top_n]
                idx_sort = idx[np.argsort(-scores_flat[idx])]
            s_idx, e_idx = np.unravel_index(idx_sort, scores.shape)
            pred_s.append(s_idx)
            pred_e.append(e_idx)
            pred_score.append(scores_flat[idx_sort])
        return pred_s, pred_e, pred_score
    
    def _predict(self, model, ex, top_n=1):
        """Forward a batch of examples only to get predictions.
        """
        # Eval mode
        model.eval()

        # Transfer to GPU
        inputs = [e if e is None else e.to(self.device) for e in ex[:5]]

        # Run forward
        with torch.no_grad():
            score_s, score_e = model(*inputs)

        # Decode predictions
        score_s = score_s.data.to('cpu')
        score_e = score_e.data.to('cpu')

        args = (score_s, score_e, top_n, 15)
        
        return self._decode(*args)

    def _validate_official(self, data_loader, model, global_stats,
                      offsets, texts, answers):
        """Evaluate the model on a whole data loader with the official metrics.

        The predicted token span of every example is mapped back to a
        character span of the raw context, and the resulting text is scored
        with exact match and F1 against all accepted answers (best score
        over the ground truths counts).

        Args:
            data_loader: iterable of batches; the last entry of a batch holds
                the example ids, the first one is a tensor whose size(0) is
                the batch size.
            model: network handed to _predict.
            global_stats: dict providing the current 'epoch' for logging.
            offsets: Map of qid --> character (start, end) pairs of the
                context tokens.
            texts: Map of qid --> raw context text (matches offsets).
            answers: Map of qid --> list of accepted answers.
        Returns:
            dict with 'exact_match' and 'f1', both as percentages.
        """
        f1_meter = AverageMeter()
        em_meter = AverageMeter()

        n_seen = 0
        for batch in data_loader:
            qids = batch[-1]
            n_in_batch = batch[0].size(0)
            starts, ends, _ = self._predict(model, batch)

            for k in range(n_in_batch):
                qid = qids[k]

                # Token span --> character span --> predicted answer text
                char_from = offsets[qid][starts[k][0]][0]
                char_to = offsets[qid][ends[k][0]][1]
                prediction = texts[qid][char_from:char_to]

                # Best score over all accepted answers
                ground_truths = answers[qid]
                em_meter.update(self._metric_max_over_ground_truths(
                    self._exact_match_score, prediction, ground_truths))
                f1_meter.update(self._metric_max_over_ground_truths(
                    self._f1_score, prediction, ground_truths))

            n_seen += n_in_batch

        em_pct = em_meter.avg * 100
        f1_pct = f1_meter.avg * 100

        head = 'dev valid official: Epoch = %d | EM = %.2f | ' % (
            global_stats['epoch'], em_pct)
        tail = 'F1 = %.2f | examples = %d' % (f1_pct, n_seen)
        print(head + tail)

        return {'exact_match': em_pct, 'f1': f1_pct}
    
    def _save(self, model, word_dict, feature_dict, filename):
        state_dict = copy.copy(model.state_dict())
        if 'fixed_embedding' in state_dict:
            state_dict.pop('fixed_embedding')
        params = {
            'state_dict': state_dict,
            'word_dict': word_dict,
            'feature_dict': feature_dict,
        }
        try:
            torch.save(params, filename)
        except BaseException:
            print('WARN: Saving failed... continuing anyway.')

    def init_model(self, train_file, dev_file, embedding_file):
        """Load the data and build a ready-to-train reader model.

        Args:
            train_file: training examples; loaded with skip_no_answer=True.
            dev_file: validation examples.
            embedding_file: pretrained embeddings, used both to build the
                word dictionary and to initialise the embedding layer.
        Returns:
            Tuple (train_exs, dev_exs, word_dict, feature_dict, optimizer,
            model), with the model already moved to self.device.
        """
        train_exs = self._load_data(train_file, skip_no_answer=True)
        dev_exs = self._load_data(dev_file)

        # Create a feature dict out of the annotations in the data
        feature_dict = self._build_feature_dict(train_exs)

        # Build a dictionary from the data questions + words (train/dev splits)
        word_dict = self._build_word_dict(train_exs + dev_exs, embedding_file)

        # Initialize model
        vocab_size = len(word_dict)
        num_features = len(feature_dict)
        model = RnnDocReader(vocab_size, num_features)

        # Load pretrained embeddings for words in dictionary. Use the
        # argument, not the notebook-level EMBEDDING_FILE constant, so the
        # embeddings always match the file the dictionary was built from.
        self._load_embeddings(model, word_dict, embedding_file)

        # Set up partial tuning of embeddings
        top_words = self._top_question_words(train_exs, word_dict)
        self._tune_embeddings([w[0] for w in top_words], model, word_dict)

        # Set up optimizer
        optimizer = self._init_optimizer(model)

        # Move model to the configured device
        model = model.to(self.device)

        return train_exs, dev_exs, word_dict, feature_dict, optimizer, model
        
    def create_dataloaders(self, dev_json_file, train_exs, dev_exs, word_dict, feature_dict,
                           train_batch_size=32, dev_batch_size=128):
        """Build the train/dev data loaders and the official-eval lookups.

        Args:
            dev_json_file: original dev set JSON, source of the raw texts
                and of the accepted answers.
            train_exs: training examples.
            dev_exs: validation examples; each provides 'id' and 'offsets'.
            word_dict: word dictionary passed to the datasets.
            feature_dict: feature dictionary passed to the datasets.
            train_batch_size: batch size of the (randomly sampled) train
                loader.
            dev_batch_size: batch size of the (sequential) dev loader.
        Returns:
            Tuple (train_loader, dev_loader, dev_texts, dev_offsets,
            dev_answers).
        """
        # If we are doing official evals then we need to:
        #   1) Load the original text to retrieve spans from offsets.
        #   2) Load the (multiple) text answers for each question.
        dev_texts = self._load_text(dev_json_file)
        dev_offsets = {ex['id']: ex['offsets'] for ex in dev_exs}
        dev_answers = self._load_answers(dev_json_file)

        train_dataset = ReaderDataset(train_exs, word_dict, feature_dict, single_answer=True)
        dev_dataset = ReaderDataset(dev_exs, word_dict, feature_dict, single_answer=False)

        train_sampler = torch.utils.data.sampler.RandomSampler(train_dataset)
        train_loader = torch.utils.data.DataLoader(
                train_dataset,
                batch_size=train_batch_size,
                sampler=train_sampler,
                collate_fn=self._batchify,
            )

        dev_sampler = torch.utils.data.sampler.SequentialSampler(dev_dataset)
        dev_loader = torch.utils.data.DataLoader(
                dev_dataset,
                batch_size=dev_batch_size,
                sampler=dev_sampler,
                collate_fn=self._batchify,
            )

        return train_loader, dev_loader, dev_texts, dev_offsets, dev_answers
        
    def train(self, train_loader, dev_loader, model, optimizer, dev_offsets, dev_texts, dev_answers, num_epochs, model_file, valid_metric,
              word_dict=None, feature_dict=None):
        """Train the reader, validating and checkpointing after every epoch.

        Args:
            train_loader: iterable of training batches; ex[:5] are the model
                inputs, ex[5] and ex[6] the start and end targets.
            dev_loader: batches used for validation after each epoch.
            model: network to train; must expose `embedding` and
                `fixed_embedding`.
            optimizer: optimizer stepping the model parameters.
            dev_offsets: forwarded to _validate_official.
            dev_texts: forwarded to _validate_official.
            dev_answers: forwarded to _validate_official.
            num_epochs: number of passes over train_loader.
            model_file: destination of the best checkpoint.
            valid_metric: key of the validation result ('exact_match' or
                'f1') that decides which epoch is best.
            word_dict: word dictionary stored with the checkpoint.
            feature_dict: feature dictionary stored with the checkpoint.
                When either dictionary is omitted, the module-level variable
                of the same name is used, which is what this method always
                did before the parameters existed.
        Raises:
            NameError: if a dictionary is neither passed in nor defined at
                module level.
        """
        # Resolve the dictionaries up front so a missing one is reported
        # before any training time is spent, not at the first checkpoint.
        module_scope = globals()
        if word_dict is None:
            word_dict = module_scope.get('word_dict')
        if feature_dict is None:
            feature_dict = module_scope.get('feature_dict')
        if word_dict is None or feature_dict is None:
            raise NameError('train() needs word_dict and feature_dict: pass '
                            'them as arguments or define them at module level')

        stats = {'epoch': 0, 'best_valid': 0}
        for epoch in range(num_epochs):
            stats['epoch'] = epoch

            # Validation leaves the model in eval mode, so switch back to
            # train mode at the start of every epoch.
            model.train()

            # Wrap the loader itself so tqdm can show the total batch count
            for ex in tqdm(train_loader):
                # Transfer to the configured device
                inputs = [e if e is None else e.to(self.device) for e in ex[:5]]
                target_s = ex[5].to(self.device)
                target_e = ex[6].to(self.device)

                # Run forward
                score_s, score_e = model(*inputs)

                # Compute loss: start and end positions are scored separately
                loss = F.nll_loss(score_s, target_s) + F.nll_loss(score_e, target_e)

                # Clear gradients and run backward
                optimizer.zero_grad()
                loss.backward()

                # Clip gradients
                torch.nn.utils.clip_grad_norm_(model.parameters(), 10)

                # Update parameters
                optimizer.step()

                # Reset any partially fixed parameters (e.g. rare words)
                embedding = model.embedding.weight.data
                fixed_embedding = model.fixed_embedding

                # Embeddings to fix are the last indices
                offset = embedding.size(0) - fixed_embedding.size(0)
                if offset >= 0:
                    embedding[offset:] = fixed_embedding

            print(f"train: Epoch {stats['epoch']} done.")

            # Validate official
            result = self._validate_official(dev_loader, model, stats,
                                    dev_offsets, dev_texts, dev_answers)

            # Save best valid
            if result[valid_metric] > stats['best_valid']:
                self._save(model, word_dict, feature_dict, model_file)
                stats['best_valid'] = result[valid_metric]

### Train DrQA's Document Reader and evaluate on the validation dataset after each epoch

In [105]:
# Paths and training settings
TRAIN_FILE = 'data/datasets/SQuAD-v1.1-train-processed-spacy.txt'
DEV_FILE = 'data/datasets/SQuAD-v1.1-dev-processed-spacy.txt'
DEV_JSON_FILE = 'data/datasets/SQuAD-v1.1-dev.json'
EMBEDDING_FILE = 'data/embeddings/glove.840B.300d.txt'
NUM_EPOCHS = 13
VALID_METRIC = 'f1'
MODEL_FILE = 'tmp/drqa-models/document_reader.mdl'

# Make sure the checkpoint directory exists before training starts. A failed
# save is only reported as a warning, so on a fresh checkout a missing
# directory would otherwise mean training for hours without keeping a model.
os.makedirs(os.path.dirname(MODEL_FILE), exist_ok=True)

model_class = Model()

# Load the data, build the dictionaries and set up model + optimizer
train_exs, dev_exs, word_dict, feature_dict, optimizer, model = model_class.init_model(
    TRAIN_FILE,
    DEV_FILE,
    EMBEDDING_FILE,
)

# Batch iterators plus the lookups needed for the official evaluation
train_loader, dev_loader, dev_texts, dev_offsets, dev_answers = model_class.create_dataloaders(
    DEV_JSON_FILE,
    train_exs,
    dev_exs,
    word_dict,
    feature_dict,
)

# Train, validating after every epoch and keeping the best model on disk
model_class.train(train_loader, dev_loader, model, optimizer, dev_offsets, dev_texts, dev_answers,
                    NUM_EPOCHS, MODEL_FILE, VALID_METRIC)

Loaded 91231 embeddings (100.00%)


2709it [13:16,  3.40it/s]


train: Epoch 0 done.
dev valid official: Epoch = 0 | EM = 62.36 | F1 = 72.68 | examples = 10570
train: Epoch 1 done.
dev valid official: Epoch = 1 | EM = 63.96 | F1 = 73.81 | examples = 10570
train: Epoch 2 done.
dev valid official: Epoch = 2 | EM = 64.90 | F1 = 74.42 | examples = 10570
train: Epoch 3 done.
dev valid official: Epoch = 3 | EM = 65.82 | F1 = 75.27 | examples = 10570
train: Epoch 4 done.
dev valid official: Epoch = 4 | EM = 66.60 | F1 = 75.74 | examples = 10570
train: Epoch 5 done.
dev valid official: Epoch = 5 | EM = 67.07 | F1 = 76.55 | examples = 10570
train: Epoch 6 done.
dev valid official: Epoch = 6 | EM = 67.99 | F1 = 77.09 | examples = 10570
train: Epoch 7 done.
dev valid official: Epoch = 7 | EM = 67.69 | F1 = 77.31 | examples = 10570
train: Epoch 8 done.
dev valid official: Epoch = 8 | EM = 68.18 | F1 = 77.56 | examples = 10570
train: Epoch 9 done.
dev valid official: Epoch = 9 | EM = 67.78 | F1 = 77.23 | examples = 10570
train: Epoch 10 done.
dev valid official