<a href="https://colab.research.google.com/github/dksifoua/Question-Answering/blob/master/1%20-%20DrQA%2C%20Document%20reader%20Question%20Answering.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Load dependencies

In [1]:
!pip install tqdm --upgrade >> /dev/null 2>&1
!pip install spacy --upgrade >> /dev/null 2>&1
!python -m spacy download en >> /dev/null 2>&1

In [131]:
import tqdm
import json
import spacy
import collections

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.nn.utils.rnn import pack_padded_sequence, pad_sequence, pad_packed_sequence
from torch.utils.data import Dataset, DataLoader

## Prepare data

***Download data***

In [3]:
!rm -rf ./data
!mkdir ./data

!wget --no-check-certificate \
    https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v1.1.json \
    -O ./data/train-v1.1.json

!wget --no-check-certificate \
    https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v1.1.json \
    -O ./data/dev-v1.1.json

--2020-10-31 18:25:53--  https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v1.1.json
Resolving rajpurkar.github.io (rajpurkar.github.io)... 185.199.111.153, 185.199.108.153, 185.199.110.153, ...
Connecting to rajpurkar.github.io (rajpurkar.github.io)|185.199.111.153|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 30288272 (29M) [application/json]
Saving to: ‘./data/train-v1.1.json’


2020-10-31 18:25:54 (38.8 MB/s) - ‘./data/train-v1.1.json’ saved [30288272/30288272]

--2020-10-31 18:25:55--  https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v1.1.json
Resolving rajpurkar.github.io (rajpurkar.github.io)... 185.199.111.153, 185.199.108.153, 185.199.110.153, ...
Connecting to rajpurkar.github.io (rajpurkar.github.io)|185.199.111.153|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 4854279 (4.6M) [application/json]
Saving to: ‘./data/dev-v1.1.json’


2020-10-31 18:25:55 (16.9 MB/s) - ‘./data/dev-v1.1.json’ saved [4854279/485427

***Load JSON data***

In [4]:
def load(path):
    """Read a SQuAD JSON file and return its list of articles."""
    # open() already raises FileNotFoundError for a missing path
    with open(path, mode='r', encoding='utf-8') as file:
        return json.load(file)['data']

In [5]:
train_raw_data = load('./data/train-v1.1.json')
valid_raw_data = load('./data/dev-v1.1.json')
print(f'Length of raw train data: {len(train_raw_data):,}')
print(f'Length of raw valid data: {len(valid_raw_data):,}')

Length of raw train data: 442
Length of raw valid data: 48
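
The two counts above are articles, not questions: each SQuAD v1.1 file nests paragraphs inside articles, and question/answer pairs inside paragraphs. A minimal sketch of that nesting, using a made-up one-article sample (the `toy` dict and the `count_qa_pairs` helper are illustrative names, not part of the notebook):

```python
# Hypothetical miniature of the SQuAD v1.1 layout (same keys the parser reads).
toy = {
    "data": [
        {
            "title": "Toy article",
            "paragraphs": [
                {
                    "context": "Paris is the capital of France.",
                    "qas": [
                        {
                            "id": "q1",
                            "question": "What is the capital of France?",
                            "answers": [{"text": "Paris", "answer_start": 0}],
                        }
                    ],
                }
            ],
        }
    ]
}


def count_qa_pairs(articles):
    """Count one pair per (question, answer) combination."""
    return sum(
        len(qa["answers"])
        for article in articles
        for para in article["paragraphs"]
        for qa in para["qas"]
    )


print(len(toy["data"]), count_qa_pairs(toy["data"]))
```

Counting one pair per answer rather than per question is why a split with several reference answers per question yields more pairs than questions.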


***Parse JSON data***

In [6]:
def parse(data, nlp=spacy.load('en')):
    qas = []
    for paragraphs in tqdm.tqdm(data):
        for para in paragraphs['paragraphs']:
            context = nlp(para['context'], disable=['parser'])
            for qa in para['qas']:
                id = qa['id']
                question = nlp(qa['question'], disable=['parser', 'tagger', 'ner'])
                for ans in qa['answers']:
                    qas.append({
                        'id': id,
                        'context': context,
                        'question': question,
                        'answer': nlp(ans['text'], disable=['parser', 'tagger', 'ner']),
                        'answer_start': ans['answer_start'],
                    })
    return qas

In [7]:
train_qas = parse(train_raw_data)
valid_qas = parse(valid_raw_data)
print()
print(f'Length of train qa pairs: {len(train_qas):,}')
print(f'Length of valid qa pairs: {len(valid_qas):,}')
print('==================== Example ====================')
print('Id:', train_qas[0]['id'])
print('Context:', train_qas[0]['context'])
print('Question:', train_qas[0]['question'])
print('Answer starts at:', train_qas[0]['answer_start'])
print('Answer:', train_qas[0]['answer'])

100%|██████████| 442/442 [05:28<00:00,  1.35it/s]
100%|██████████| 48/48 [00:38<00:00,  1.25it/s]


Length of train qa pairs: 87,599
Length of valid qa pairs: 34,726
Id: 5733be284776f41900661182
Context: Architecturally, the school has a Catholic character. Atop the Main Building's gold dome is a golden statue of the Virgin Mary. Immediately in front of the Main Building and facing it, is a copper statue of Christ with arms upraised with the legend "Venite Ad Me Omnes". Next to the Main Building is the Basilica of the Sacred Heart. Immediately behind the basilica is the Grotto, a Marian place of prayer and reflection. It is a replica of the grotto at Lourdes, France where the Virgin Mary reputedly appeared to Saint Bernadette Soubirous in 1858. At the end of the main drive (and in a direct line that connects through 3 statues and the Gold Dome), is a simple, modern stone statue of Mary.
Question: To whom did the Virgin Mary allegedly appear in 1858 in Lourdes France?
Answer starts at: 515
Answer: Saint Bernadette Soubirous





In [8]:
def test_answer_start(qas):
    """Check that every answer_start offset points at the answer text inside its context"""
    for qa in tqdm.tqdm(qas):
        answer = qa['answer'].text
        context = qa['context'].text
        answer_start = qa['answer_start']
        assert answer == context[answer_start:answer_start + len(answer)]

In [9]:
test_answer_start(train_qas)
test_answer_start(valid_qas)

100%|██████████| 87599/87599 [00:08<00:00, 9810.06it/s]
100%|██████████| 34726/34726 [00:03<00:00, 10883.37it/s]


***Add targets***

In [10]:
def add_targets(qas):
    """Add the answer's start and end token indices as qa['target']"""
    for qa in qas:
        context = qa['context']
        answer = qa['answer']
        ans_start = qa['answer_start']
        for i in range(len(context)):
            if context[i].idx == ans_start:
                ans = context[i:i + len(answer)]
                qa['target'] = [ans[0].i, ans[-1].i]
                break
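
The step above turns a character offset into a pair of token indices. A simplified sketch of the same idea, assuming a whitespace tokenizer in place of spaCy (`tokenize_with_offsets` and `char_to_token_span` are hypothetical helpers written for this illustration):

```python
def tokenize_with_offsets(text):
    """Whitespace tokenizer returning (token, start_char) pairs; a stand-in for spaCy's token.idx."""
    tokens, pos = [], 0
    for word in text.split():
        start = text.index(word, pos)
        tokens.append((word, start))
        pos = start + len(word)
    return tokens


def char_to_token_span(tokens, answer_start, n_answer_tokens):
    """Return [first, last] token indices, or None if no token starts at answer_start."""
    for i, (_, start) in enumerate(tokens):
        if start == answer_start:
            return [i, i + n_answer_tokens - 1]
    return None


context = "The Virgin Mary appeared to Saint Bernadette Soubirous in 1858"
answer = "Saint Bernadette Soubirous"
tokens = tokenize_with_offsets(context)
span = char_to_token_span(tokens, context.index(answer), len(answer.split()))
print(span)
```

When the offset falls in the middle of a token, no start matches and the pair gets no target, which is the case the later filtering step has to handle.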

In [11]:
%%time
add_targets(train_qas)
add_targets(valid_qas)
print(f'Length of train qa pairs: {len(train_qas):,}')
print(f'Length of valid qa pairs: {len(valid_qas):,}')

Length of train qa pairs: 87,599
Length of valid qa pairs: 34,726
CPU times: user 1.55 s, sys: 45.9 ms, total: 1.59 s
Wall time: 1.6 s


In [12]:
def filter_qas(qa):
    """Remove bad targets"""
    if 'target' in qa:
        start, end = qa['target']
        return qa['context'][start:end + 1].text == qa['answer'].text
    return False

In [13]:
%%time
train_qas = [*filter(filter_qas, train_qas)]
valid_qas = [*filter(filter_qas, valid_qas)]
print(f'Length of train qa pairs after filtering out bad qa pairs: {len(train_qas):,}')
print(f'Length of valid qa pairs after filtering out bad qa pairs: {len(valid_qas):,}')

Length of train qa pairs after filtering out bad qa pairs: 86,597
Length of valid qa pairs after filtering out bad qa pairs: 34,295
CPU times: user 1.28 s, sys: 974 µs, total: 1.28 s
Wall time: 1.28 s


In [14]:
def test_targets(qas):
    for qa in qas:
        if 'target' in qa:
            start, end = qa['target']
            assert qa['context'][start:end + 1].text == qa['answer'].text

In [15]:
%%time
test_targets(train_qas)
test_targets(valid_qas)

CPU times: user 1.24 s, sys: 966 µs, total: 1.24 s
Wall time: 1.25 s


***Add features***

In [16]:
def add_features(qas):
    """Add extra features: Exact Match, Part-of-Speech, Named Entity Recognition & Normalized Term Frequency"""
    for qa in tqdm.tqdm(qas):
        question = [token.text for token in qa['question']]
        context = qa['context']
        counts = collections.Counter(map(lambda token: token.text.lower(), context))
        freqs = {index: counts[token.text.lower()] for index, token in enumerate(context)}
        freqs_norm = sum(freqs.values())
        qa['em'], qa['pos'], qa['ner'], qa['ntf'] = zip(
            *map(lambda index: [
                context[index].text in question, context[index].tag_,
                context[index].ent_type_ or 'None',
                freqs[index] / freqs_norm
            ], range(len(context)))
        )

In [17]:
add_features(train_qas)
add_features(valid_qas)

100%|██████████| 86597/86597 [00:54<00:00, 1586.74it/s]
100%|██████████| 34295/34295 [00:21<00:00, 1594.11it/s]
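
To make two of these features concrete, here is a toy computation of exact match and normalized term frequency on plain string tokens. It is a sketch that mirrors the arithmetic of `add_features` without spaCy; the sample sentences are made up:

```python
import collections

context = ["The", "cat", "saw", "the", "dog"]
question = ["Who", "saw", "the", "dog", "?"]

# Term counts are case-insensitive, one count looked up per context position.
counts = collections.Counter(token.lower() for token in context)
per_position = [counts[token.lower()] for token in context]
norm = sum(per_position)  # normalizer: sum of the per-position counts

# Exact match compares the original token text, so it is case-sensitive.
em = [token in question for token in context]
ntf = [count / norm for count in per_position]

print(em)
print(ntf)
```

Note that "The" does not match the question's "the" because the exact-match test uses the raw text, while the frequency count lowercases first.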


***Build vocabularies***

In [18]:
class Vocab:

    def __init__(self):
        self.vocab = None
        self.word2count = None
        self.word2index = None
        self.index2word = None
    
    def build(self, data, specials):
        """
        :param List[Union[spacy.tokens.doc.Doc, str, Tuple]] data
        :param List[str] specials
        """
        words = list(specials) # copy so the caller's list is not mutated by +=
        type_0 = type(data[0])
        if type_0 == spacy.tokens.doc.Doc:
            for item in data: # context and question
                words += [word.text.lower() for word in item]
        elif type_0 == str: # id
            words += data
        elif type_0 == tuple: # pos and ner
            for item in data:
                words += [word.lower() for word in item]
        self.word2count = collections.Counter(words)
        self.vocab = sorted(self.word2count.keys())
        self.word2index = {word: index for index, word in enumerate(self.vocab)}
        self.index2word = {index: word for index, word in enumerate(self.vocab)}
    
    def __len__(self):
        return len(self.vocab)
    
    def stoi(self, word: str):
        return self.word2index[word]

    def itos(self, index: int):
        return self.index2word[index]

In [44]:
%%time
ID, POS, NER, TEXT = Vocab(), Vocab(), Vocab(), Vocab()

ids = [*map(lambda qa: qa['id'], train_qas)] + [*map(lambda qa: qa['id'], valid_qas)]
pos, ner, contexts, questions = zip(*map(lambda qa: (qa['pos'], qa['ner'], qa['context'], qa['question']), train_qas))

PAD_TOKEN = '<pad>'

ID.build(data=[*set(ids)], specials=[])
POS.build(data=[*set(pos)], specials=[PAD_TOKEN])
NER.build(data=[*set(ner)], specials=[PAD_TOKEN])
TEXT.build(data=[*set(contexts)] + [*set(questions)], specials=[PAD_TOKEN])

print(f'Length of ID vocabulary: {len(ID):,}')
print(f'Length of POS vocabulary: {len(POS):,}')
print(f'Length of NER vocabulary: {len(NER):,}')
print(f'Length of TEXT vocabulary: {len(TEXT):,}')

Length of ID vocabulary: 97,106
Length of POS vocabulary: 51
Length of NER vocabulary: 20
Length of TEXT vocabulary: 91,445
CPU times: user 6.51 s, sys: 95.9 ms, total: 6.61 s
Wall time: 6.61 s
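
The POS, NER and TEXT vocabularies above are built from the training pairs only, so `stoi` would raise a `KeyError` for any validation token that never occurs in training. One common way around this is an unknown-word fallback. The sketch below is an assumption, not the notebook's `Vocab`: `ToyVocab` and `UNK_TOKEN` are hypothetical names.

```python
PAD_TOKEN, UNK_TOKEN = "<pad>", "<unk>"  # UNK_TOKEN is an assumption, not in the notebook


class ToyVocab:
    """Minimal word/index mapping with an unknown-word fallback."""

    def __init__(self, words, specials=(PAD_TOKEN, UNK_TOKEN)):
        # Specials first, then sorted unique words, so special indices are stable.
        self.index2word = list(specials) + sorted(set(words) - set(specials))
        self.word2index = {word: index for index, word in enumerate(self.index2word)}

    def __len__(self):
        return len(self.index2word)

    def stoi(self, word):
        # Unseen words map to the index of UNK_TOKEN instead of raising KeyError.
        return self.word2index.get(word, self.word2index[UNK_TOKEN])

    def itos(self, index):
        return self.index2word[index]


vocab = ToyVocab(["the", "cat", "sat", "the"])
print(len(vocab), vocab.stoi("cat"), vocab.stoi("zebra"), vocab.itos(0))
```

Placing the specials first also pins the padding index to 0, whereas sorting all keys together leaves its position dependent on the data.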


***Build datasets***

In [45]:
class SQuADV1Dataset(Dataset):

    def __init__(self, data, id_vocab, pos_vocab, ner_vocab, text_vocab):
        self.data = data
        self.id_vocab = id_vocab
        self.pos_vocab = pos_vocab
        self.ner_vocab = ner_vocab
        self.text_vocab = text_vocab
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        item = self.data[idx]
        id = torch.LongTensor([self.id_vocab.stoi(item['id'])])
        ctx = torch.LongTensor([*map(lambda token: self.text_vocab.stoi(token.text.lower()), item['context'])])
        qst = torch.LongTensor([*map(lambda token: self.text_vocab.stoi(token.text.lower()), item['question'])])
        trg = torch.LongTensor(item['target'])
        em = torch.LongTensor(item['em'])
        pos = torch.LongTensor([*map(lambda token: self.pos_vocab.stoi(token.lower()), item['pos'])])
        ner = torch.LongTensor([*map(lambda token: self.ner_vocab.stoi(token.lower()), item['ner'])])
        ntf = torch.FloatTensor(item['ntf'])
        return id, ctx, qst, trg, em, pos, ner, ntf

In [46]:
train_dataset = SQuADV1Dataset(data=train_qas, id_vocab=ID, pos_vocab=POS, ner_vocab=NER, text_vocab=TEXT)
valid_dataset = SQuADV1Dataset(data=valid_qas, id_vocab=ID, pos_vocab=POS, ner_vocab=NER, text_vocab=TEXT)

id, ctx, qst, trg, em, pos, ner, ntf = train_dataset[0]
print(f'id shape: {id.shape}')
print(f'ctx shape: {ctx.shape}')
print(f'qst shape: {qst.shape}')
print(f'trg shape: {trg.shape}')
print(f'em shape: {em.shape}')
print(f'pos shape: {pos.shape}')
print(f'ner shape: {ner.shape}')
print(f'ntf shape: {ntf.shape}')

id shape: torch.Size([1])
ctx shape: torch.Size([142])
qst shape: torch.Size([14])
trg shape: torch.Size([2])
em shape: torch.Size([142])
pos shape: torch.Size([142])
ner shape: torch.Size([142])
ntf shape: torch.Size([142])


***Build data loaders***

In [64]:
class DotDict(dict):
    """Dot notation access to dictionary attributes"""
    __getattr__ = dict.get
    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__

In [129]:
def add_padding(batch, pad_token=PAD_TOKEN, text_vocab=TEXT, pos_vocab=POS, ner_vocab=NER, include_lengths=True):
    """Pad batch of sequence with different lengths"""
    batch_id, batch_ctx, batch_qst, batch_trg, batch_em, batch_pos, batch_ner, batch_ntf = zip(*batch)
    if include_lengths:
        len_ctx = torch.LongTensor([ctx.size(0) for ctx in batch_ctx])
        len_qst = torch.LongTensor([qst.size(0) for qst in batch_qst])
    batch_padded_id = pad_sequence(batch_id, batch_first=True)
    batch_padded_ctx = pad_sequence(batch_ctx, batch_first=True, padding_value=text_vocab.stoi(pad_token))
    batch_padded_qst = pad_sequence(batch_qst, batch_first=True, padding_value=text_vocab.stoi(pad_token))
    batch_padded_trg = pad_sequence(batch_trg, batch_first=True)
    batch_padded_em = pad_sequence(batch_em, batch_first=True)
    batch_padded_pos = pad_sequence(batch_pos, batch_first=True, padding_value=pos_vocab.stoi(pad_token))
    batch_padded_ner = pad_sequence(batch_ner, batch_first=True, padding_value=ner_vocab.stoi(pad_token))
    batch_padded_ntf = pad_sequence(batch_ntf, batch_first=True)
    return DotDict({
        'id': batch_padded_id,
        'ctx': (batch_padded_ctx, len_ctx) if include_lengths else batch_padded_ctx,
        'qst': (batch_padded_qst, len_qst) if include_lengths else batch_padded_qst,
        'trg': batch_padded_trg,
        'em': batch_padded_em,
        'pos': batch_padded_pos,
        'ner': batch_padded_ner,
        'ntf': batch_padded_ntf,
    })

In [130]:
BATCH_SIZE = 64

train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, collate_fn=add_padding, pin_memory=True)
valid_dataloader = DataLoader(valid_dataset, batch_size=BATCH_SIZE, collate_fn=add_padding, pin_memory=True)

for batch in train_dataloader:
    print('batch.id shape:', batch.id.shape)
    print('batch.ctx shape:', batch.ctx[0].shape, batch.ctx[1].shape)
    print('batch.qst shape:', batch.qst[0].shape, batch.qst[1].shape)
    print('batch.trg shape:', batch.trg.shape)
    print('batch.em shape:', batch.em.shape)
    print('batch.pos shape:', batch.pos.shape)
    print('batch.ner shape:', batch.ner.shape)
    print('batch.ntf shape:', batch.ntf.shape)
    break

batch.id shape: torch.Size([64, 1])
batch.ctx shape: torch.Size([64, 253]) torch.Size([64])
batch.qst shape: torch.Size([64, 19]) torch.Size([64])
batch.trg shape: torch.Size([64, 2])
batch.em shape: torch.Size([64, 253])
batch.pos shape: torch.Size([64, 253])
batch.ner shape: torch.Size([64, 253])
batch.ntf shape: torch.Size([64, 253])
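
The shapes above show every sequence in a batch stretched to the longest one (253 context tokens, 19 question tokens). As a plain-Python illustration of what right-padding with recorded lengths amounts to (`pad_batch` and `make_mask` are hypothetical helpers, not the `pad_sequence` API):

```python
def pad_batch(sequences, pad_value=0):
    """Right-pad variable-length lists to the longest one; also return the true lengths."""
    lengths = [len(seq) for seq in sequences]
    max_len = max(lengths)
    padded = [seq + [pad_value] * (max_len - len(seq)) for seq in sequences]
    return padded, lengths


def make_mask(lengths):
    """1 marks a real token, 0 marks padding."""
    max_len = max(lengths)
    return [[1 if i < n else 0 for i in range(max_len)] for n in lengths]


padded, lengths = pad_batch([[5, 3, 8], [7], [2, 9]], pad_value=0)
mask = make_mask(lengths)
print(padded)
print(lengths)
print(mask)
```

The lengths are what lets a recurrent layer skip the padded positions, and the mask is what lets an attention layer ignore them.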


***TODO: Download pretrained GloVe embeddings***

Downloading them from Colab takes about 16 minutes.

## Modeling

***Stacked Bidirectional LSTM Layer***

In [137]:
class StackedBiLSTMsLayer(nn.Module):

    def __init__(self, embedding_size, hidden_size, n_layers, dropout):
        super(StackedBiLSTMsLayer, self).__init__()
        self.embedding_size = embedding_size
        self.hidden_size = hidden_size
        self.n_layers = n_layers
        self.dropout = nn.Dropout(p=dropout)
        # One single-layer BiLSTM per level; levels after the first read the previous level's hidden_size * 2 output
        self.lstms = nn.ModuleList([nn.LSTM(embedding_size if i == 0 else hidden_size * 2, hidden_size,
                                            batch_first=True, num_layers=1, bidirectional=True)
                                    for i in range(n_layers)])
    
    def apply_lstm(self, layer, inputs, lengths):
        """
        :param nn.LSTM layer
        :param FloatTensor[batch_size, seq_len, input_size] inputs
        :param LongTensor[batch_size,] lengths
        :return FloatTensor[batch_size, seq_len, hidden_size * 2] out_padded
        """
        # Dropout expects a plain tensor, so apply it before packing; the lengths must be on the CPU
        packed = pack_padded_sequence(self.dropout(inputs), lengths.cpu(), batch_first=True, enforce_sorted=False)
        out_packed, _ = layer(packed)
        # total_length keeps seq_len equal to the padded input length
        out_padded, _ = pad_packed_sequence(out_packed, batch_first=True, total_length=inputs.size(1))
        return out_padded # [batch_size, seq_len, hidden_size * 2]
    
    def forward(self, input_embedded, sequence_lengths):
        """
        :param FloatTensor[batch_size, seq_len, embedding_size] input_embedded
        :param LongTensor[batch_size,] sequence_lengths
        :return FloatTensor[batch_size, seq_len, hidden_size * n_layers * 2]
        """
        outputs = [input_embedded] # packing happens inside apply_lstm, once per level
        for lstm in self.lstms:
            outputs.append(self.apply_lstm(layer=lstm, inputs=outputs[-1], lengths=sequence_lengths))
        # Concatenate the outputs of all levels, skipping the raw input
        return self.dropout(torch.cat(outputs[1:], dim=-1))

***Aligned Question Embedding Layer***

In [138]:
class AlignQuestionEmbeddingLayer(nn.Module):

    def __init__(self, hidden_size):
        super(AlignQuestionEmbeddingLayer, self).__init__()
        self.hidden_size = hidden_size
        self.linear = nn.Linear(hidden_size, hidden_size)

    def forward(self, ctx_embed, qst_embed, qst_mask):
        """
        :param FloatTensor[batch_size, ctx_len, embedding_size] ctx_embed
        :param FloatTensor[batch_size, qst_len, embedding_size] qst_embed
        :param IntTensor[batch_size, qst_len] qst_mask
        :return FloatTensor[batch_size, ctx_len, hidden_size]
        """
        ctx_embed = F.relu(self.linear(ctx_embed)) # [batch_size, ctx_len, hidden_size]
        qst_embed = F.relu(self.linear(qst_embed)) # [batch_size, qst_len, hidden_size]
        scores = torch.bmm(ctx_embed, qst_embed.transpose(-1, -2)) # [batch_size, ctx_len, qst_len]
        scores = scores.masked_fill(qst_mask.unsqueeze(1) == 0, -1e18) # large negative score => zero weight on padding
        attention_weights = F.softmax(scores, dim=-1) # [batch_size, ctx_len, qst_len]
        return torch.bmm(attention_weights, qst_embed)
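
Masking attention scores before a softmax only works if padded positions receive a very large negative score, because softmax still assigns real weight to a score near zero. A small pure-Python sketch of the difference (`masked_softmax` is a hypothetical helper written for this illustration):

```python
import math


def masked_softmax(scores, mask, fill=-1e18):
    """Softmax over `scores` where positions with mask == 0 are replaced by `fill` first."""
    filled = [s if m else fill for s, m in zip(scores, mask)]
    top = max(filled)  # subtract the max for numerical stability
    exps = [math.exp(s - top) for s in filled]
    total = sum(exps)
    return [e / total for e in exps]


scores = [2.0, 1.0, 0.5]
mask = [1, 1, 0]  # the last position is padding

good = masked_softmax(scores, mask)             # large negative fill
bad = masked_softmax(scores, mask, fill=1e-18)  # near-zero fill still receives weight
print(good)
print(bad)
```

With the large negative fill the padded position ends up with a weight of exactly zero; with the near-zero fill it keeps a noticeable share of the attention.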

***Question Encoding Layer***

In [143]:
class QuestionEncodingLayer(nn.Module):

    def __init__(self, embedding_size, hidden_size, dropout, n_layers):
        super(QuestionEncodingLayer, self).__init__()
        self.embedding_size = embedding_size
        self.hidden_size = hidden_size
        self.n_layers = n_layers
        self.stacked_bilstms_layer = StackedBiLSTMsLayer(embedding_size=embedding_size, hidden_size=hidden_size, n_layers=n_layers, dropout=dropout)
        self.linear = nn.Linear(embedding_size, 1)
    
    def linear_self_attention(self, qst_embed, qst_mask):
        """
        :param FloatTensor[batch_size, qst_len, embedding_size] qst_embed
        :param IntTensor[batch_size, qst_len] qst_mask
        :return FloatTensor[batch_size, qst_len]
        """
        scores = self.linear(qst_embed).squeeze(-1) # [batch_size, qst_len]
        scores = scores.masked_fill(qst_mask == 0, -1e18) # large negative score => zero weight on padding
        return F.softmax(scores, dim=-1)

    
    def forward(self, qst_embed, qst_lengths, qst_mask):
        """
        :param FloatTensor[batch_size, qst_len, embedding_size] qst_embed
        :param LongTensor[batch_size,] qst_lengths
        :param IntTensor[batch_size, qst_len] qst_mask
        :return FloatTensor[batch_size, hidden_size * n_layers * 2]
        """
        attention_weights = self.linear_self_attention(qst_embed=qst_embed, qst_mask=qst_mask) # [batch_size, qst_len]
        lstm_outputs = self.stacked_bilstms_layer(input_embedded=qst_embed, sequence_lengths=qst_lengths) # [batch_size, qst_len, hidden_size * n_layers * 2]
        return torch.bmm(attention_weights.unsqueeze(1), lstm_outputs).squeeze(1)

***BiLinear Attention Layer***

In [144]:
class BiLinearAttentionLayer(nn.Module):

    def __init__(self, ctx_size, qst_size):
        super(BiLinearAttentionLayer, self).__init__()
        self.ctx_size = ctx_size
        self.qst_size = qst_size
        self.linear = nn.Linear(qst_size, ctx_size)
    
    def forward(self, ctx_encoded, qst_encoded, ctx_mask):
        """
        :param FloatTensor[batch_size, ctx_len, ctx_size] ctx_encoded
        :param FloatTensor[batch_size, qst_size] qst_encoded
        :param IntTensor[batch_size, ctx_len] ctx_mask
        :return FloatTensor[batch_size, ctx_len]
        """
        qst_encoded = self.linear(qst_encoded) # [batch_size, ctx_size]
        scores = torch.bmm(ctx_encoded, qst_encoded.unsqueeze(-1)) # [batch_size, ctx_len, 1]
        scores = scores.squeeze(-1).masked_fill(ctx_mask == 0, -1e18) # [batch_size, ctx_len], padding gets a large negative score
        return scores
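
The score this layer produces for each context token is a bilinear form: project the question vector into the context space, then take a dot product with each token vector. A hand-checkable sketch with plain lists, omitting the bias term that `nn.Linear` adds (`matvec`, `bilinear_scores` and the numbers are illustrative assumptions):

```python
def matvec(matrix, vector):
    """Multiply a [rows, cols] matrix by a [cols] vector."""
    return [sum(w * x for w, x in zip(row, vector)) for row in matrix]


def bilinear_scores(ctx, W, q):
    """score_i = ctx_i . (W q) for every context token i."""
    projected = matvec(W, q)  # [ctx_size]
    return [sum(p * v for p, v in zip(token, projected)) for token in ctx]


ctx = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # 3 tokens, ctx_size = 2
W = [[1.0, 0.0, 2.0], [0.0, 1.0, 0.0]]      # [ctx_size, qst_size] = [2, 3]
q = [1.0, 2.0, 3.0]                         # qst_size = 3

scores = bilinear_scores(ctx, W, q)
print(scores)
```

Here `W q` is `[7.0, 2.0]`, so the third token, which overlaps with both components, gets the highest score.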

***Document reader Question Answering Model***
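
The `decode` method of the model defined in the next cell picks the best (start, end) pair from the full outer product of start and end scores, which means it can return an end index that precedes the start. A sketch of a constrained search over valid spans only follows; `best_span` is a hypothetical helper, and the `max_len=15` cap is an assumption chosen for illustration:

```python
def best_span(starts, ends, max_len=15):
    """Return (start, end, score) maximising starts[i] * ends[j] with i <= j <= i + max_len."""
    best = (0, 0, -1.0)
    for i, p_start in enumerate(starts):
        for j in range(i, min(i + max_len + 1, len(ends))):
            score = p_start * ends[j]
            if score > best[2]:
                best = (i, j, score)
    return best


starts = [0.1, 0.6, 0.3]
ends = [0.7, 0.1, 0.2]

start, end, proba = best_span(starts, ends)
print(start, end, proba)
```

In this toy input the unconstrained maximum is start 1 with end 0, which is not a valid span; restricting the search to `end >= start` selects tokens 1 to 2 instead.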

In [None]:
class DrQA(nn.Module):

    def __init__(self, vocab_size, embedding_size, n_extra_features, hidden_size, n_layers, dropout, pad_index):
        super(DrQA, self).__init__()
        self.vocab_size = vocab_size
        self.embedding_size = embedding_size
        self.n_extra_features = n_extra_features
        self.hidden_size = hidden_size
        self.n_layers = n_layers
        self.dropout = nn.Dropout(p=dropout)
        self.pad_index = pad_index
        self.embedding = nn.Embedding(vocab_size, embedding_size, padding_idx=pad_index)
        self.align_question_embedding_layer = AlignQuestionEmbeddingLayer(hidden_size=embedding_size)
        self.ctx_stacked_bi_lstm_layer = StackedBiLSTMsLayer(embedding_size=embedding_size * 2 + n_extra_features,
                                                             hidden_size=hidden_size, n_layers=n_layers, dropout=dropout)
        self.qst_encoding_layer = QuestionEncodingLayer(embedding_size=embedding_size, hidden_size=hidden_size, dropout=dropout, n_layers=n_layers)
        self.bilinear_attention_layer_start = BiLinearAttentionLayer(ctx_size=hidden_size * n_layers * 2, qst_size=hidden_size * n_layers * 2)
        self.bilinear_attention_layer_end = BiLinearAttentionLayer(ctx_size=hidden_size * n_layers * 2, qst_size=hidden_size * n_layers * 2)
    
    def make_ctx_mask(self, ctx_sequences):
        """
        :param LongTensor[batch_size, ctx_len] ctx_sequences
        :return IntTensor[batch_size, ctx_len]
        """
        return ctx_sequences != self.pad_index
    
    def make_qst_mask(self, qst_sequences):
        """
        :param LongTensor[batch_size, qst_len] qst_sequences
        :return IntTensor[batch_size, qst_len]
        """
        return qst_sequences != self.pad_index
    
    @staticmethod
    def decode(starts, ends):
        """
        :param FloatTensor[batch_size, ctx_len] starts
        :param FloatTensor[batch_size, ctx_len] ends
        :return list(int) start_indexes
        :return list(int) end_indexes
        :return list(float) pred_probas
        """
        start_indexes, end_indexes, pred_probas = [], [], []
        for i in range(starts.size(0)):
            probas = torch.ger(starts[i], ends[i]) # [ctx_len, ctx_len]
            proba, index = torch.topk(probas.view(-1), k=1)
            start_indexes.append(index.tolist()[0] // probas.size(1)) # row of the flattened [ctx_len, ctx_len] matrix
            end_indexes.append(index.tolist()[0] % probas.size(1))
            pred_probas.append(proba.tolist()[0])
        return start_indexes, end_indexes, pred_probas
    
    def forward(self, ctx_sequences, ctx_lengths, qst_sequences, qst_lengths, em_sequences, pos_sequences, ner_sequences, ntf_sequences):
        """
        :param LongTensor[batch_size, ctx_len] ctx_sequences
        :param Tensor[batch_size,] ctx_lengths
        :param LongTensor[batch_size, qst_len] qst_sequences
        :param Tensor[batch_size,] qst_lengths
        :param LongTensor[batch_size, ctx_len] em_sequences
        :param LongTensor[batch_size, ctx_len] pos_sequences
        :param LongTensor[batch_size, ctx_len] ner_sequences
        :param FloatTensor[batch_size, ctx_len] ntf_sequences
        :return Tensor[batch_size, ctx_len] starts
        :return Tensor[batch_size, ctx_len] ends
        """
        ctx_mask = self.make_ctx_mask(ctx_sequences) # [batch_size, ctx_len]
        qst_mask = self.make_qst_mask(qst_sequences) # [batch_size, qst_len]
        ctx_embedded = self.dropout(self.embedding(ctx_sequences)) # [batch_size, ctx_len, embedding_size]
        qst_embedded = self.dropout(self.embedding(qst_sequences)) # [batch_size, qst_len, embedding_size]
        ctx_aligned = self.align_question_embedding_layer(ctx_embed=ctx_embedded, qst_embed=qst_embedded,
                                                          qst_mask=qst_mask) # [batch_size, ctx_len, embedding_size]
        