## QA over unstructured data

Using Match LSTM, Pointer Networks, as mentioned in paper https://arxiv.org/pdf/1608.07905.pdf

We start with the pre-processing provided by https://github.com/MurtyShikhar/Question-Answering to clean up the data and make neat para, ques files.


### @TODOs:

1. [done] _Figure out how to put in real, pre-trained embeddings in embeddings layer._
2. [done] _Explicitly provide batch size when instantiating model_
3. is ./val.ids.* validation set or test set?: **validation**
4. [done:em] emInstead of test loss, calculate test acc metrics
    1. todo: new metrics like P, R, F1
5. Update unit test codes

In [1]:
# Codeblock to pull up embeddings. Needs to run before following imports
import numpy as np

# Macros 
DATA_LOC = './data/squad/'
EMBEDDING_FILE = 'glove.trimmed.300.npz'
VOCAB_FILE = 'vocab.dat'

file_loc = DATA_LOC + EMBEDDING_FILE
glove_file = np.load(open(file_loc))['glove']

In [2]:
from __future__ import unicode_literals, print_function, division
import matplotlib.pyplot as plt
from io import open
import numpy as np
import unicodedata
import string
import random
import time
import re
import os


import torch
import torch.nn as nn
from torch import optim
import torch.nn.functional as F
from torch.autograd import Variable

device = torch.device("cuda")

torch.manual_seed(42)
np.random.seed(42)

#### Debug Legend

- 5: Print everything that goes in every tensor.
- 4: ??
- 3: Check every model individually
- 2: Print things in training loops
- 1: ??

In [3]:
# Macros 
DATA_LOC = './data/squad/'
DEBUG = 2

# nn Macros
QUES_LEN, PARA_LEN =  30, 770
VOCAB_SIZE = glove_file.shape[1]                  # @TODO: get actual size
HIDDEN_DIM = 150
EMBEDDING_DIM = 300
BATCH_SIZE = 81                  # Might have total 100 batches.
EPOCHS = 1
TEST_EVERY_ = 1

### Encoder 
Use a simple lstm class to have encoder for question and paragraph. 
The output of these will be used in the match lstm

$H^p = LSTM(P)$ 


$H^q = LSTM(Q)$

In [4]:
class Encoder(nn.Module):
    
    def __init__(self, inputlen, macros, glove_file):
        super(Encoder, self).__init__()
        
        # Catch dim
        self.inputlen = inputlen
        self.hiddendim = macros['hidden_dim']
        self.embeddingdim =  macros['embedding_dim']
        self.vocablen = macros['vocab_size']
        
        self.batch_size = macros['batch_size']
        self.debug = macros['debug']
        
        # Embedding Layer
#         self.embedding = nn.Embedding(self.vocablen, self.embeddingdim)
        self.embedding = nn.Embedding.from_pretrained(torch.FloatTensor(glove_file))
        self.embedding.weight.requires_grad = False
       
        # LSTM Layer
        self.lstm = nn.LSTM(self.embeddingdim, self.hiddendim)
        
    def init_hidden(self, batch_size):
        
        # Returns a new hidden layer var for LSTM
        return (torch.zeros((1, batch_size, self.hiddendim), device=device), 
                torch.zeros((1, batch_size, self.hiddendim), device=device))
    
    def forward(self, x, h):
        
        # Input: x (batch, len ) (current input)
        # Hidden: h (1, batch, hiddendim) (last hidden state)
        
        # Batchsize: b int (inferred)
        b = x.shape[0]
        
        if self.debug > 4: print("x:\t", x.shape)
        if self.debug > 4: print("h:\t", h[0].shape, h[1].shape)
        
        x_emb = self.embedding(x)
        if self.debug > 4: print("x_emb:\t", x_emb.shape)
            
        ycap, h = self.lstm(x_emb.view(-1, b, self.embeddingdim), h)
        if self.debug > 4: print("ycap:\t", ycap.shape)
        
        return ycap, h
    
    
# with torch.no_grad():
#     print ("Trying out question encoder LSTM")
#     model = Encoder(QUES_LEN, HIDDEN_DIM, EMBEDDING_DIM, VOCAB_SIZE)
#     dummy_x = torch.tensor([22,45,12], dtype=torch.long)
#     hidden = model.init_hidden()
#     ycap, h = model(dummy_x, hidden)
    
#     print(ycap.shape)
#     print(h[0].shape, h[1].shape)


if DEBUG > 2:
    with torch.no_grad():

        dummy_para = torch.randint(0,VOCAB_SIZE-1,(PARA_LEN*BATCH_SIZE,), device=device).view(BATCH_SIZE,PARA_LEN).long()
    #     print (dummy_para.shape)
        dummy_question = torch.randint(0,VOCAB_SIZE-1,(QUES_LEN*BATCH_SIZE,), device=device).view(BATCH_SIZE,QUES_LEN).long()
    #     print (dummy_question.shape)

    #     print("LSTM with batches")
        ques_model = Encoder(QUES_LEN, HIDDEN_DIM, EMBEDDING_DIM, VOCAB_SIZE).cuda(device)
        para_model = Encoder(QUES_LEN, HIDDEN_DIM, EMBEDDING_DIM, VOCAB_SIZE).cuda(device)
        ques_hidden = ques_model.init_hidden()
        para_hidden = para_model.init_hidden()
        ques_embedded,hidden_ques = ques_model(dummy_question,ques_hidden)
        para_embedded,hidden_para = para_model(dummy_para,para_hidden)
        
        print (ques_embedded.shape) # question_length,batch,embedding_dim
        print (para_embedded.shape) # para_length,batch,embedding_dim
        print (hidden_para[0].shape,hidden_para[1].shape)

### Match LSTM

Use a match LSTM to compute a **summarized sequential vector** for the paragraph w.r.t the question.

Consider the summarized vector ($H^r$) as the output of a new decoder, where the inputs are $H^p, H^q$ computed above. 

1. Attend the para word $i$ with the entire question ($H^q$)
  
    1. $\vec{G}_i = tanh(W^qH^q + repeat(W^ph^p_i + W^r\vec{h^r_{i-1} + b^p}))$
    
    2. *Computing it*: Here, $\vec{G}_i$ is equivalent to `energy`, computed differently.
    
    3. Use a linear layer to compute the content within the $repeat$ fn.
    
    4. Add with another linear (without bias) with $H_q$
    
    5. $tanh$ the bloody thing
  
  
2. Softmax over it to get $\alpha$ weights.

    1. $\vec{\alpha_i} = softmax(w^t\vec{G}_i + repeat(b))$
    
3. Use the attention weight vector $\vec{\alpha_i}$ to obtain a weighted version of the question and concat it with the current token of the passage to form a vector $\vec{z_i}$

4. Use $\vec{z_i}$ to compute the desired $h^r_i$:

    1. $ h^r_i = LSTM(\vec{z_i}, h^r_{i-1}) $
    


In [5]:
class MatchLSTMEncoder(nn.Module):
    
    def __init__(self, macros):
        
        super(MatchLSTMEncoder, self).__init__()
        
        self.hidden_dim = macros['hidden_dim']
        self.ques_len = macros['ques_len']
        self.batch_size = macros['batch_size']
        self.debug = macros['debug']    
        
        # Catch lens and params
        self.lin_g_repeat = nn.Linear(2*self.hidden_dim, self.hidden_dim)
        self.lin_g_nobias = nn.Linear(self.hidden_dim, self.hidden_dim)
        
        self.alpha_i_w = nn.Parameter(torch.FloatTensor(self.hidden_dim, 1))
        self.alpha_i_b = nn.Parameter(torch.FloatTensor((1)))
        
        self.lstm_summary = nn.LSTM(self.hidden_dim*(self.ques_len+2), self.hidden_dim)
                                      
    
    def forward(self, H_p, h_ri, H_q, hidden):
        """
            Ideally, we would have manually unrolled the lstm 
            but due to memory constraints, we do it in the module.
        """
        
        # Find the batchsize
        batch_size = H_p.shape[1]
        
        H_r = torch.empty((0, batch_size, self.hidden_dim), device=device, dtype=torch.float)
        H_r = torch.cat((H_r, h_ri), dim=0)
        
        if self.debug > 4:
            print( "H_p:\t\t\t", H_p.shape)
            print( "h_ri:\t\t\t", h_ri.shape)
            print( "H_q:\t\t\t", H_q.shape)
        
        for i in range(H_p.shape[0]):
            
            lin_repeat_input = torch.cat((H_p[i].view(1, batch_size, -1), H_r[i].view(1, batch_size, -1)), dim=2)
            if self.debug > 4: print("lin_repeat_input:\t", lin_repeat_input.shape)

            lin_g_input_b = self.lin_g_repeat(lin_repeat_input)
            if self.debug > 4: print("lin_g_input_b unrepeated:", lin_g_input_b.shape)

            lin_g_input_b = lin_g_input_b.repeat(H_q.shape[0], 1, 1)
            if self.debug > 4: print("lin_g_input_b:\t\t", lin_g_input_b.shape)

            # lin_g_input_a = self.lin_g_nobias.matmul(H_q.view(-1, self.ques_len, self.hidden_dim)) #self.lin_g_nobias(H_q)
            lin_g_input_a =  self.lin_g_nobias(H_q)
            if self.debug > 4: print("lin_g_input_a:\t\t", lin_g_input_a.shape)

            G_i = F.tanh(lin_g_input_a + lin_g_input_b)
            if self.debug > 4: print("G_i:\t\t\t", G_i.shape)
            # Note; G_i should be a 1D vector over ques_len

            # Attention weights
            alpha_i_input_a = G_i.view(batch_size, -1, self.hidden_dim).matmul(self.alpha_i_w).view(batch_size, 1, -1)
            if self.debug > 4: print("alpha_i_input_a:\t", alpha_i_input_a.shape)

            alpha_i_input = alpha_i_input_a.add_(self.alpha_i_b.view(-1,1,1).repeat(1,1,self.ques_len))
            if self.debug > 4: print("alpha_i_input:\t\t", alpha_i_input.shape)

            # Softmax over alpha inputs
            alpha_i = F.softmax(alpha_i_input, dim=-1)
            if self.debug > 4: print("alpha_i:\t\t", alpha_i.shape)

            # Weighted summary of question with alpha    
            z_i_input_b = (
                            H_q.view(batch_size, self.ques_len, -1) *
                           (alpha_i.view(batch_size, self.ques_len, -1).repeat(1, 1, self.hidden_dim))
                          ).view(self.ques_len,batch_size, -1)
            if self.debug > 4: print("z_i_input_b:\t\t", z_i_input_b.shape)

            z_i = torch.cat((H_p[i].view(1, batch_size, -1), z_i_input_b), dim=0)
            if self.debug > 4: print("z_i:\t\t\t", z_i.shape)

            # Pass z_i, h_ri to the LSTM 
            lstm_input = torch.cat((z_i.view(1, batch_size,-1), H_r[i].view(1, batch_size, -1)), dim=2)
            if self.debug > 4: print("lstm_input:\t\t", lstm_input.shape)

            # Take input from LSTM, concat in H_r and nullify the temp var.
            h_ri, hidden = self.lstm_summary(lstm_input, hidden)
            H_r = torch.cat((H_r, h_ri), dim=0)
            h_ri = None
            
            if self.debug > 4:
                print("\tH_r:\t\t\t", H_r.shape)
#                 print("hidden new:\t\t", hidden[0].shape, hidden[1].shape)

        return H_r[1:]
    
    def init_hidden(self, batch_size):
        # Before we've done anything, we dont have any hidden state.
        # Refer to the Pytorch documentation to see exactly
        # why they have this dimensionality.
        # The axes semantics are (num_layers, minibatch_size, hidden_dim)
        return (torch.zeros((1, batch_size, self.hidden_dim), device=device),
                torch.zeros((1, batch_size, self.hidden_dim), device=device))

# with torch.no_grad():
#     model = MatchLSTMEncoder(HIDDEN_DIM, QUES_LEN)
#     h_pi = torch.randn(1, BATCH_SIZE, HIDDEN_DIM)
#     h_ri = torch.randn(1, BATCH_SIZE, HIDDEN_DIM)
#     hidden = model.init_hidden()
#     H_q = torch.randn(QUES_LEN, BATCH_SIZE, HIDDEN_DIM)
    
#     op, hid = model(h_pi, h_ri, H_q, hidden)
    
#     print("\nDone:op", op.shape)
#     print("Done:hid", hid[0].shape, hid[1].shape)

if DEBUG > 2:
    with torch.no_grad():
        matchLSTMEncoder = MatchLSTMEncoder(HIDDEN_DIM, QUES_LEN).cuda(device)
        hidden = matchLSTMEncoder.init_hidden()
        para_embedded = torch.rand((PARA_LEN, BATCH_SIZE, HIDDEN_DIM), device=device)
        ques_embedded = torch.rand((QUES_LEN, BATCH_SIZE, HIDDEN_DIM), device=device)
        h_ri = torch.randn(1, BATCH_SIZE, HIDDEN_DIM, device=device)
    #     if DEBUG:
    #         print ("init h_ri shape is: ", h_ri.shape)
    #         print ("the para length is ", len(para_embedded))
        H_r = matchLSTMEncoder(para_embedded.view(-1,BATCH_SIZE,HIDDEN_DIM),
                               h_ri, 
                               ques_embedded, 
                               hidden)
        print("H_r: ", H_r.shape)
        
        
        

### Pointer Network

Using a ptrnet over $H_r$ to unfold and get most probable spans.
We use the **boundry model** to do that (predict start and end of seq).

A simple energy -> softmax -> decoder. Where softmaxed energy is supervised.

In [6]:
class PointerDecoder(nn.Module):
    
    def __init__(self, macros):
        super(PointerDecoder, self).__init__()
        
        # Keep args
        self.hidden_dim = macros['hidden_dim']
        self.batch_size = macros['batch_size']
        self.para_len = macros['para_len']
        self.debug = macros['debug']
        
        self.lin_f_repeat = nn.Linear(self.hidden_dim, self.hidden_dim)
        self.lin_f_nobias = nn.Linear(self.hidden_dim, self.hidden_dim, bias=False)
        
        self.beta_k_w = nn.Parameter(torch.FloatTensor(self.hidden_dim, 1))
        self.beta_k_b = nn.Parameter(torch.FloatTensor(1))
        
        self.lstm = nn.LSTM(self.hidden_dim*(PARA_LEN+1), self.hidden_dim)

    
    def init_hidden(self, batch_size):
        # Before we've done anything, we dont have any hidden state.
        # Refer to the Pytorch documentation to see exactly
        # why they have this dimensionality.
        # The axes semantics are (num_layers, minibatch_size, hidden_dim)
        return (torch.zeros((1, batch_size, self.hidden_dim), device=device),
                torch.zeros((1, batch_size, self.hidden_dim), device=device))
    
    def forward(self, h_ak, H_r, hidden):
        
        # h_ak (current decoder's last op) (1,batch,hiddendim)
        # H_r (weighted summary of para) (P, batch, hiddendim)
        batch_size = H_r.shape[1]
        
        if self.debug > 4:
            print("h_ak:\t\t\t", h_ak.shape)
            print("H_r:\t\t\t", H_r.shape)
            print("hidden:\t\t\t", hidden[0].shape, hidden[1].shape)
            
        # Prepare inputs for the tanh used to compute energy
        f_input_b = self.lin_f_repeat(h_ak)
        if self.debug > 4: print("f_input_b unrepeated:  ", f_input_b.shape)
        
        #H_r shape is ([PARA_LEN, BATCHSIZE, EmbeddingDIM])
        f_input_b = f_input_b.repeat(H_r.shape[0], 1, 1)
        if self.debug > 4: print("f_input_b repeated:\t", f_input_b.shape)
            
        f_input_a = self.lin_f_nobias(H_r)
        if self.debug > 4: print("f_input_a:\t\t", f_input_a.shape)
            
        # Send it off to tanh now
        F_k = F.tanh(f_input_a+f_input_b)
        if self.debug > 4: print("F_k:\t\t\t", F_k.shape) #PARA_LEN,BATCHSIZE,EmbeddingDim
            
        # Attention weights
        beta_k_input_a = F_k.view(batch_size, -1, self.hidden_dim).matmul(self.beta_k_w).view(batch_size, 1, -1)
        if self.debug > 4: print("beta_k_input_a:\t\t", beta_k_input_a.shape)
            
        beta_k_input = beta_k_input_a.add_(self.beta_k_b.repeat(1,1,self.para_len))
        if self.debug > 4: print("beta_k_input:\t\t", beta_k_input.shape)
            
        beta_k = F.softmax(beta_k_input, dim=-1)
        if self.debug > 4: print("beta_k:\t\t\t", beta_k.shape)
            
        lstm_input_a = H_r.view(batch_size, self.para_len, -1) * (beta_k.view(batch_size, self.para_len, -1).repeat(1,1,self.hidden_dim))
        if self.debug > 4: print("lstm_input_a:\t\t", lstm_input_a.shape)
            
        lstm_input = torch.cat((lstm_input_a.view(1, batch_size,-1), h_ak.view(1, batch_size, -1)), dim=2)
        if self.debug > 4: print("lstm_input:\t\t", lstm_input.shape)
        
        h_ak, hidden = self.lstm(lstm_input, hidden)
        
        return h_ak, hidden, beta_k
            
if DEBUG > 2:
    with torch.no_grad():
        pointerDecoder = PointerDecoder(HIDDEN_DIM).cuda(device)
        h_ak = torch.randn(1,BATCH_SIZE,HIDDEN_DIM, device=device)
    #     H_r = torch.randn(PARA_LEN, BATCH_SIZE, HIDDEN_DIM)
        pointerHidden = pointerDecoder.init_hidden()
        h_ak, hidden, beta_k = pointerDecoder(h_ak, para_embedded, hidden)
        print (beta_k.shape)

# Pull the real data from disk.

Files stored in `./data/squad/train.ids.*`
Pull both train and test.

In [7]:
def prepare_data(data_loc, macros):
    """
        Given the dataloc and the data available in a specific format, it would pick the data up, and make trainable matrices,
        Harvest train_P, train_Q, train_Y, test_P, test_Q, test_Y matrices in this format
        
        **return_type**: np matrices
    """
    
    # Unpacking macros
    PARA_LEN = macros['para_len']
    QUES_LEN = macros['ques_len']
    
    train_q = np.asarray([[int(x) for x in datum.split()] for datum in list(open(os.path.join(data_loc, 'train.ids.question')))])
    train_p = np.asarray([[int(x) for x in datum.split()] for datum in list(open(os.path.join(data_loc, 'train.ids.context')))])
    train_y = np.asarray([[int(x) for x in datum.split()] for datum in list(open(os.path.join(data_loc, 'train.span')))])

    test_q = np.asarray([[int(x) for x in datum.split()] for datum in list(open(os.path.join(data_loc, 'val.ids.question')))])
    test_p = np.asarray([[int(x) for x in datum.split()] for datum in list(open(os.path.join(data_loc, 'val.ids.context')))])
    test_y = np.asarray([[int(x) for x in datum.split()] for datum in list(open(os.path.join(data_loc, 'val.span')))])

    print("Train Q: ", train_q.shape)
    print("Train P: ", train_p.shape)
    print("Train Y: ", train_y.shape)
    print("Test Q: ", test_q.shape)
    print("Test P: ", test_p.shape)
    print("Test Y: ", test_y.shape)
    
    """
        Parse the semi-raw data:
            - shuffle
            - pad, prepare
            - dump useless vars
    """
    # Shuffle data
    index_train, index_test = np.arange(len(train_p)), np.arange(len(test_p))
    np.random.shuffle(index_train)
    np.random.shuffle(index_test)

    train_p, train_q, train_y = train_p[index_train], train_q[index_train], train_y[index_train]
    test_p, test_q, test_y = test_p[index_test], test_q[index_test], test_y[index_test]

#     sanity_check(train_p, train_y)
    
    # Pad and prepare
    train_P = np.zeros((len(train_p), PARA_LEN))
    train_Q = np.zeros((len(train_q), QUES_LEN))
    train_Y_start = np.zeros((len(train_p), PARA_LEN))
    train_Y_end = np.zeros((len(train_p), PARA_LEN))

    test_P = np.zeros((len(test_p), PARA_LEN))
    test_Q = np.zeros((len(test_q), QUES_LEN))
    test_Y_start = np.zeros((len(test_p), PARA_LEN))
    test_Y_end = np.zeros((len(test_p), PARA_LEN))

    crop_train = []    # Remove these rows from training
    for i in range(len(train_p)):
        p = train_p[i]
        q = train_q[i]
        y = train_y[i]

        # First see if you can keep this example or not (due to size)
        if y[0] > PARA_LEN or y[1] > PARA_LEN:
            crop.append(i)
            continue


        train_P[i, :min(PARA_LEN, len(p))] = p[:min(PARA_LEN, len(p))]
        train_Q[i, :min(QUES_LEN, len(q))] = p[:min(QUES_LEN, len(q))]
        train_Y_start[i, y[0]] = 1
        train_Y_end[i, y[1]] = 1

    crop_test = []
    for i in range(len(test_p)):
        p = test_p[i]
        q = test_q[i]
        y = test_y[i]

        # First see if you can keep this example or not (due to size)
        if y[0] > PARA_LEN or y[1] > PARA_LEN:
            crop.append(i)
            continue

        test_P[i, :min(PARA_LEN, len(p))] = p[:min(PARA_LEN, len(p))]
        test_Q[i, :min(QUES_LEN, len(q))] = p[:min(QUES_LEN, len(q))]
        test_Y_start[i, y[0]] = 1
        test_Y_end[i, y[1]] = 1


    # Let's free up some memory now
    train_p, train_q, train_y, test_p, test_q, test_y = None, None, None, None, None, None
    
    return train_P, train_Q, train_Y_start, train_Y_end, test_P, test_Q, test_Y_start, test_Y_end

# Training, and running the model
- Write a train fn
- Write a training loop invoking it
- Fill in real data

----------

Feats:
- Function to test every n epochs.
- Report train accuracy every epoch
- Store the train, test accuracy for every instance.


In [8]:
def train(para_batch,
          ques_batch,
          answer_start_batch,
          answer_end_batch,
          ques_model,
          para_model,
          match_LSTM_encoder_model,
          pointer_decoder_model,
          optimizer, 
          loss_fn,
          macros,
          debug=2):

    """
    
    :param para_batch: paragraphs (batch, max_seq_len_para) 
    :param ques_batch: questions corresponding to para (batch, max_seq_len_ques)
    :param answer_start_batch: one-hot vector denoting pos of span start (batch, max_seq_len_para)
    :param answer_end_batch: one-hot vector denoting pos of span end (batch, max_seq_len_para)
    
    # Models
    :param ques_model: model to encode ques
    :param para_model: model to encode para
    :param match_LSTM_encoder_model: model to match para, ques to get para summary
    :param pointer_decoder_model: model to get a pointer over start and end span pointer
    
    # Loss and Optimizer.
    :param loss_fn: 
    :param optimizer: 
    
    :return: 
    
    
    NOTE: When using MSE, 
        - target labels are one-hot
        - target label is float tensor
        - shape (batch, 1, len)
        
        When using CrossEntropy
        - target is not onehot
        - long
        - shape (batch, )
    """
    
#     DEBUG = debug
#     BATCH_SIZE = macros['batch_size']
#     HIDDEN_DIM = macros['hidden_dim']
    
    if debug >=2: 
        print("\tpara_batch:\t\t", para_batch.shape)
        print("\tques_batch:\t\t", ques_batch.shape)
        print("\tanswer_start_batch:\t", answer_start_batch.shape)
        print("\tanswer_end_batch:\t\t", answer_end_batch.shape)
    
    # Wiping all gradients
    optimizer.zero_grad()
    
    # Initializing all hidden states.
    hidden_quesenc = ques_model.init_hidden(macros['batch_size'])
    hidden_paraenc = para_model.init_hidden(macros['batch_size'])
    hidden_mlstm = match_LSTM_encoder_model.init_hidden(macros['batch_size'])
    hidden_ptrnet = pointer_decoder_model.init_hidden(macros['batch_size'])
    h_ri = torch.zeros((1, macros['batch_size'], macros['hidden_dim']), dtype=torch.float, device=device)
    h_ak = torch.zeros((1, macros['batch_size'], macros['hidden_dim']), dtype=torch.float, device=device)
    if debug >= 2: print("------------Instantiated hidden states------------")
    
    #passing the data through LSTM pre-processing layer
    H_q, ques_model_hidden = ques_model(ques_batch, hidden_quesenc)
    H_p, para_model_hidden = para_model(para_batch, hidden_paraenc)
    if debug >= 2: 
        print("\tH_q:\t\t", H_q.shape)
        print("\tH_p:\t\t", H_p.shape)
        print("\tH_ri:\t\t", h_ri.shape)
#         raw_input("Check memory and ye shall continue")
        print("------------Encoded hidden states------------")
    
    H_r = match_LSTM_encoder_model(H_p.view(-1, macros['batch_size'], macros['hidden_dim']), h_ri, H_q, hidden_mlstm)
    if debug >= 2: print("------------Passed through matchlstm------------")
    
    #Passing the paragraph embddin via pointer network to generate final answer pointer.
    h_ak, hidden_ptrnet, beta_k_start = pointer_decoder_model(h_ak, H_r, hidden_ptrnet)
    h_ak, hidden_ptrnet, beta_k_end = pointer_decoder_model(h_ak, H_r, hidden_ptrnet)
    if debug >= 2: print("------------Passed through pointernet------------")

        
    # For crossentropy
#     _, answer_start_batch = answer_start_batch.max(dim=1)
#     _, answer_end_batch = answer_end_batch.max(dim=1)
#     print("labels: ", answer_start_batch.shape)
    
    
    #How will we manage batches for loss.
    loss = loss_fn(beta_k_start, answer_start_batch)
    loss += loss_fn(beta_k_end, answer_end_batch)
    if debug >= 2: print("------------Calculated loss------------")
    
    loss.backward()
    if debug >= 2: print("------------Calculated Gradients------------")
    
    #optimization step
    optimizer.step()
    if debug >= 2: print("------------Updated weights.------------")
    
    return loss

In [9]:
# Predict function (no grad, no eval)
def predict(para_batch,
            ques_batch,
            ques_model,
            para_model,
            match_LSTM_encoder_model,
            pointer_decoder_model,
            macros,
            debug):
    """
        Function which returns the model's output based on a given set of P&Q's. 
        Does not convert to strings, gives the direct model output.
        
        Expects:
            four models
            data
            misc macros
    """
    
    BATCH_SIZE = macros['batch_size']
    HIDDEN_DIM = macros['hidden_dim']
    DEBUG = debug
    
    if debug >=2: 
        print("\tpara_batch:\t\t", para_batch.shape)
        print("\tques_batch:\t\t", ques_batch.shape)
        
    with torch.no_grad():    

        # Initializing all hidden states.
        hidden_quesenc = ques_model.init_hidden(BATCH_SIZE)
        hidden_paraenc = para_model.init_hidden(BATCH_SIZE)
        hidden_mlstm = match_LSTM_encoder_model.init_hidden(BATCH_SIZE)
        hidden_ptrnet = pointer_decoder_model.init_hidden(BATCH_SIZE)
        h_ri = torch.zeros((1, BATCH_SIZE, HIDDEN_DIM), dtype=torch.float, device=device)
        h_ak = torch.zeros((1, BATCH_SIZE, HIDDEN_DIM), dtype=torch.float, device=device)
        if DEBUG >= 2: print("------------Instantiated hidden states------------")
            
        #passing the data through LSTM pre-processing layer
        H_q, ques_model_hidden = ques_model(ques_batch, hidden_quesenc)
        H_p, para_model_hidden = para_model(para_batch, hidden_paraenc)
        if DEBUG >= 2: 
            print("\tH_q:\t\t", H_q.shape)
            print("\tH_p:\t\t", H_p.shape)
            print("\tH_ri:\t\t", h_ri.shape)
#             raw_input("Check memory and ye shall continue")
            print("------------Encoded hidden states------------")

        H_r = match_LSTM_encoder_model(H_p.view(-1, BATCH_SIZE, HIDDEN_DIM), h_ri, H_q, hidden_mlstm)
        if DEBUG >= 2: print("------------Passed through matchlstm------------")

        #Passing the paragraph embddin via pointer network to generate final answer pointer.
        _, _, beta_k_start = pointer_decoder_model(h_ak, H_r, hidden_ptrnet)
        _, _, beta_k_end = pointer_decoder_model(h_ak, H_r, hidden_ptrnet)
        if DEBUG >= 2: print("------------Passed through pointernet------------")
                            
        # For crossentropy
#         _, answer_start_batch = answer_start_batch.max(dim=1)
#         _, answer_end_batch = answer_end_batch.max(dim=1)
#         print("labels: ", answer_start_batch.shape)
            
        #How will we manage batches for loss.
        loss = loss_fn(beta_k_start, answer_start_batch)
        loss += loss_fn(beta_k_end, answer_end_batch)
        if debug >= 2: print("------------Calculated loss------------")
            
        return (beta_k_start, beta_k_end, loss)


In [10]:
# Eval function (no grad no eval no nothing)
def eval(y_cap, y, metrics={'em':None}):
    """ 
        Returns the exact-match (em) metric by default.
        Can specifiy more in a list (TODO)
        
        Inputs:
        - y_cap: list of two tensors (start, end) of dim [BATCH_SIZE, PARA_LEN] each
        - y: list of two tensors (start, end) of dim [BATCH_SIZE, 1] each
    """
    y_cap_max_start, y_cap_max_end = torch.argmax(y_cap[0], dim=1).float(), \
                                     torch.argmax(y_cap[1], dim=1).float()
    
    if "em" in metrics.keys():
        metrics['em'] = (y[0].eq(y_cap_max_start) & y[1].eq(y_cap_max_end)).sum().item()/ float(y[0].shape[0])
        
    if DEBUG >= 2: 
        print("Test performance: ", metrics)
        print("------------Evaluated------------")
        
    return metrics

if DEBUG >= 5:
    # Testing this function
    metrics = {'em':None}
    y = torch.randint(0, PARA_LEN, (BATCH_SIZE,)).float(), torch.randint(0, PARA_LEN, (BATCH_SIZE,)).float()
    y_cap = torch.rand((BATCH_SIZE, PARA_LEN)), torch.rand((BATCH_SIZE, PARA_LEN))
    print(eval(y_cap, y))   

In [11]:
def training_loop(_models, _data, _macros, _epochs=EPOCHS, _save_best=False, _test_every=0, _debug=2):
    """
        > Instantiate models
        > Instantiate loss, optimizer
        > Instantiate ways to store loss

        > Per epoch
            > sample batch and give to train fn
            > get loss
            > if epoch %k ==0: get test accuracy

        > have fn to calculate test accuracy
    """

    # Unpack data
    DEBUG = _debug
    train_P = _data['train']['P']
    train_Q = _data['train']['Q']
    train_Y_start = _data['train']['Ys']
    train_Y_end = _data['train']['Ye']
    test_P = _data['test']['P']
    test_Q = _data['test']['Q']
    test_Y_start = _data['test']['Ys']
    test_Y_end = _data['test']['Ye']
                                 
    ques_model, para_model, match_LSTM_encoder_model, pointer_decoder_model = _models
    _data = None

    # Instantiate Loss
    loss_fn = nn.MSELoss()
    optimizer = optim.Adamax(list(filter(lambda p: p.requires_grad, ques_model.parameters())) + 
                             list(filter(lambda p: p.requires_grad, para_model.parameters())) + 
                             list(match_LSTM_encoder_model.parameters()) + 
                             list(pointer_decoder_model.parameters()))

    # Losses
    train_losses = []
    test_losses = []
    test_em = []

    # Training Loop
    for epoch in range(_epochs):
        print("Epoch: ", epoch, "/", _epochs)

        epoch_loss = []
        epoch_time = time.time()

        for iter in range(int(len(train_P)/BATCH_SIZE)):
#         for iter in range(4):

            batch_time = time.time()

            # Sample batch and train on it
            sample_index = np.random.randint(0, len(train_P), _macros['batch_size'])

            loss = train(
                para_batch = torch.tensor(train_P[sample_index], dtype=torch.long, device=device),
                ques_batch = torch.tensor(train_Q[sample_index], dtype=torch.long, device=device),
                answer_start_batch = torch.tensor(train_Y_start[sample_index], dtype=torch.float, device=device).view( _macros['batch_size'], 1, _macros['para_len']),
                answer_end_batch = torch.tensor(train_Y_end[sample_index], dtype=torch.float, device=device).view(_macros['batch_size'], 1, _macros['para_len']),
                ques_model = ques_model,
                para_model = para_model,
                match_LSTM_encoder_model = match_LSTM_encoder_model,
                pointer_decoder_model = pointer_decoder_model,
                optimizer = optimizer, 
                loss_fn= loss_fn,
                macros=_macros,
                debug=_macros['debug']
            )
            

            epoch_loss.append(loss.item())
            print("Batch:\t%d" % iter,"/%d\t: " % (len(train_P)/_macros['batch_size']),
                  "%s" % (time.time() - batch_time), 
                  "\t%s" % (time.time() - epoch_time), 
                  "\tloss:%f" % loss.item())
#                   end=None if iter+1 == 4 else "\r")
#                   end=None if iter+1 == int(len(train_P)/BATCH_SIZE) else "\r")
                 
#         print("Time taken in epoch: %s" % (time.time() - epoch_time))
        train_losses.append(epoch_loss)

        if _test_every and epoch % _test_every == 0:
            
            
            y_cap_start, y_cap_end, test_loss = predict(
                para_batch = torch.tensor(test_P, dtype=torch.long, device=device),
                ques_batch = torch.tensor(test_Q, dtype=torch.long, device=device),
                ques_model = ques_model,
                para_model = para_model,
                match_LSTM_encoder_model = match_LSTM_encoder_model,
                pointer_decoder_model = pointer_decoder_model,
                macros = _macros,
                debug = _macros['debug']
            )
            metrics = eval(y=(torch.tensor(test_Y_start, dtype=torch.long, device=device).view( -1, _macros['para_len']),
                         torch.tensor(test_Y_end, dtype=torch.long, device=device).view(-1, _macros['para_len'])),
                      y_cap=[y_cap_start, y_cap_end])
            
            test_losses.append(test_loss)
            test_em.append(metrics['em'])
            
        
    return train_losses, test_losses, test_em
            
            
            

In [12]:
def visualize_loss(loss, _label="Some label", _only_epoch=True):
    """
        Fn to visualize loss.
        Expects either
            - [int, int] for epoch level stuff
            - [ [int, int], [int, int] ] for batch level data. 
    """
    
    plt.rcParams['figure.figsize'] = [15, 8] 
    
    # Detect input format
    if type(loss[0]) == int:
        
        plt.plot(loss)
        plt.ylabel(_label)
        plt.show()
        
    elif type(loss[0]) == list:
        
        if _only_epoch:
            loss = [ sum(x) for x in loss ]
            
        else:
            loss = [ y for x in loss for y in x ]
            
        plt.plot(loss)
        plt.ylabel(_label)
        plt.show()        

## Orchestrator

One cell which instantiates and runs everything

In [None]:
"""
    Cell which pulls everything together.

    > init models
    > get data prepared
    > pass models and data to training loop
    > gets trained models and loss
    > saves models
    > visualizes loss?

No other function but this one ever sees global macros!
"""
macros = {
    "ques_len": QUES_LEN,
    "hidden_dim": HIDDEN_DIM, 
    "vocab_size": VOCAB_SIZE, 
    "batch_size": BATCH_SIZE,
    "para_len": PARA_LEN,
    "embedding_dim": EMBEDDING_DIM,
    "debug": 1
} 

data = {'train':{}, 'test':{}}
data['train']['P'], data['train']['Q'], data['train']['Ys'], data['train']['Ye'], \
data['test']['P'], data['test']['Q'], data['test']['Ys'], data['test']['Ye'] = \
    prepare_data(DATA_LOC, macros)

# # Instantiate models
ques_model = Encoder(QUES_LEN, macros, glove_file).cuda(device)
para_model = Encoder(PARA_LEN, macros, glove_file).cuda(device)
match_LSTM_encoder_model = MatchLSTMEncoder(macros).cuda(device)
pointer_decoder_model = PointerDecoder(macros).cuda(device)

# # Instantiate models
# ques_model = Encoder(QUES_LEN, macros, glove_file)
# para_model = Encoder(PARA_LEN, macros, glove_file)
# match_LSTM_encoder_model = MatchLSTMEncoder(macros)
# pointer_decoder_model = PointerDecoder(macros)

op = training_loop(_models=[ques_model, para_model, match_LSTM_encoder_model, pointer_decoder_model],
                       _data=data,
                       _debug=macros['debug'],
                      _save_best=True,
#                           _test_every=TEST_EVERY_,
                       _test_every=0,
                      _epochs=EPOCHS,
                      _macros=macros)    

Train Q:  (81403,)
Train P:  (81403,)
Train Y:  (81403, 2)
Test Q:  (4285,)
Test P:  (4285,)
Test Y:  (4285, 2)
Epoch:  0 / 1
Batch:	0 /1004	:  14.0413222313 	14.0414600372 	loss:0.005163
Batch:	1 /1004	:  9.0444560051 	23.0869269371 	loss:0.005034
Batch:	2 /1004	:  6.06823992729 	29.155602932 	loss:0.005131
Batch:	3 /1004	:  6.33431410789 	35.4911708832 	loss:0.005163
Batch:	4 /1004	:  6.00137805939 	41.493792057 	loss:0.005163
Batch:	5 /1004	:  6.22808694839 	47.7224519253 	loss:0.005163
Batch:	6 /1004	:  6.36011505127 	54.0839819908 	loss:0.005195
Batch:	7 /1004	:  6.34052300453 	60.4247879982 	loss:0.005195
Batch:	8 /1004	:  6.04881191254 	66.4752149582 	loss:0.005163
Batch:	9 /1004	:  6.19745612144 	72.6738300323 	loss:0.005195
Batch:	10 /1004	:  6.2030351162 	78.8786859512 	loss:0.005195
Batch:	11 /1004	:  6.26238012314 	85.1426079273 	loss:0.005195
Batch:	12 /1004	:  6.41577196121 	91.5586099625 	loss:0.005195
Batch:	13 /1004	:  6.13599395752 	97.6958420277 	loss:0.005195
Batch:

Batch:	129 /1004	:  6.08014798164 	814.191941977 	loss:0.005195
Batch:	130 /1004	:  6.12432813644 	820.316874981 	loss:0.005195
Batch:	131 /1004	:  6.21467995644 	826.532151937 	loss:0.005163
Batch:	132 /1004	:  5.91621398926 	832.448695898 	loss:0.005195
Batch:	133 /1004	:  6.31686210632 	838.766077995 	loss:0.005195
Batch:	134 /1004	:  6.21613407135 	844.983551025 	loss:0.005195
Batch:	135 /1004	:  6.07984399796 	851.065192938 	loss:0.005195
Batch:	136 /1004	:  6.6127409935 	857.679018021 	loss:0.005163
Batch:	137 /1004	:  6.13017892838 	863.810738087 	loss:0.005195
Batch:	138 /1004	:  6.40625405312 	870.217717886 	loss:0.005195
Batch:	139 /1004	:  6.27568817139 	876.493800879 	loss:0.005195
Batch:	140 /1004	:  6.16348719597 	882.659224987 	loss:0.005195
Batch:	141 /1004	:  6.25278496742 	888.913482904 	loss:0.005195
Batch:	142 /1004	:  6.14163804054 	895.056382895 	loss:0.005163
Batch:	143 /1004	:  6.10298991203 	901.160661936 	loss:0.005195
Batch:	144 /1004	:  6.3518910408 	907.513

Batch:	258 /1004	:  6.14872694016 	1615.149297 	loss:0.005131
Batch:	259 /1004	:  6.0907227993 	1621.24024892 	loss:0.005131
Batch:	260 /1004	:  6.16551589966 	1627.40696406 	loss:0.005195
Batch:	261 /1004	:  6.31291794777 	1633.72010493 	loss:0.005163
Batch:	262 /1004	:  6.09975790977 	1639.82084489 	loss:0.005131
Batch:	263 /1004	:  6.23997998238 	1646.06238103 	loss:0.005195
Batch:	264 /1004	:  6.10430502892 	1652.16752601 	loss:0.005195
Batch:	265 /1004	:  6.18547296524 	1658.35416389 	loss:0.005195
Batch:	266 /1004	:  6.17042899132 	1664.52495503 	loss:0.005195
Batch:	267 /1004	:  6.06668305397 	1670.59192991 	loss:0.005131
Batch:	268 /1004	:  6.36845493317 	1676.961411 	loss:0.005195
Batch:	269 /1004	:  6.33138990402 	1683.29385209 	loss:0.005163
Batch:	270 /1004	:  6.14858603477 	1689.44327903 	loss:0.005195
Batch:	271 /1004	:  6.04453492165 	1695.48946404 	loss:0.005195
Batch:	272 /1004	:  6.05099892616 	1701.54168892 	loss:0.005099
Batch:	273 /1004	:  6.05780506134 	1707.60032

Batch:	387 /1004	:  6.20923399925 	2414.62722993 	loss:0.005099
Batch:	388 /1004	:  6.07170987129 	2420.69950604 	loss:0.005195
Batch:	389 /1004	:  6.06850409508 	2426.76903391 	loss:0.005195
Batch:	390 /1004	:  6.19575595856 	2432.96605492 	loss:0.005163
Batch:	391 /1004	:  6.32635211945 	2439.29383898 	loss:0.005131
Batch:	392 /1004	:  6.0762860775 	2445.37135506 	loss:0.005195
Batch:	393 /1004	:  6.28105902672 	2451.65331888 	loss:0.005195
Batch:	394 /1004	:  6.39584803581 	2458.05032897 	loss:0.005195
Batch:	395 /1004	:  5.99245095253 	2464.04414988 	loss:0.005163
Batch:	396 /1004	:  6.09572696686 	2470.14327598 	loss:0.005195
Batch:	397 /1004	:  6.07305502892 	2476.2176621 	loss:0.005195
Batch:	398 /1004	:  6.22309494019 	2482.4412961 	loss:0.005195
Batch:	399 /1004	:  6.04945397377 	2488.49196696 	loss:0.005195
Batch:	400 /1004	:  6.4056019783 	2494.89897609 	loss:0.005195
Batch:	401 /1004	:  6.13421297073 	2501.03442788 	loss:0.005195
Batch:	402 /1004	:  6.16437792778 	2507.1991

Batch:	516 /1004	:  6.5406639576 	3216.70316195 	loss:0.005195
Batch:	517 /1004	:  6.16429901123 	3222.86875796 	loss:0.005163
Batch:	518 /1004	:  6.47137093544 	3229.34196591 	loss:0.005195
Batch:	519 /1004	:  6.37333893776 	3235.71664596 	loss:0.005195
Batch:	520 /1004	:  6.18980622292 	3241.90771508 	loss:0.005195
Batch:	521 /1004	:  6.41914510727 	3248.32730389 	loss:0.005195
Batch:	522 /1004	:  6.10790205002 	3254.43649697 	loss:0.005131
Batch:	523 /1004	:  6.09047198296 	3260.52825904 	loss:0.005195
Batch:	524 /1004	:  6.00865888596 	3266.5380249 	loss:0.005195
Batch:	525 /1004	:  6.58293199539 	3273.12139201 	loss:0.005195
Batch:	526 /1004	:  6.02468705177 	3279.1464951 	loss:0.005195
Batch:	527 /1004	:  5.96345186234 	3285.11011004 	loss:0.005195
Batch:	528 /1004	:  6.12014102936 	3291.23157096 	loss:0.005195
Batch:	529 /1004	:  6.38738012314 	3297.62050605 	loss:0.005195
Batch:	530 /1004	:  6.21269893646 	3303.83344889 	loss:0.005195
Batch:	531 /1004	:  6.11884284019 	3309.952

Batch:	645 /1004	:  6.09649085999 	4019.52063203 	loss:0.005195
Batch:	646 /1004	:  6.14832305908 	4025.67012787 	loss:0.005163
Batch:	647 /1004	:  6.18088293076 	4031.85187793 	loss:0.005195
Batch:	648 /1004	:  6.10018205643 	4037.95294905 	loss:0.005195
Batch:	649 /1004	:  6.28775000572 	4044.24092889 	loss:0.005163
Batch:	650 /1004	:  6.24008584023 	4050.48159909 	loss:0.005131
Batch:	651 /1004	:  6.43125700951 	4056.91343689 	loss:0.005163
Batch:	652 /1004	:  6.10264587402 	4063.01806903 	loss:0.005163
Batch:	653 /1004	:  6.37452292442 	4069.39390588 	loss:0.005163
Batch:	654 /1004	:  6.29731893539 	4075.69258404 	loss:0.005131
Batch:	655 /1004	:  6.20227599144 	4081.89608002 	loss:0.005195
Batch:	656 /1004	:  6.35914802551 	4088.25690198 	loss:0.005195
Batch:	657 /1004	:  6.11828899384 	4094.37695289 	loss:0.005195
Batch:	658 /1004	:  6.56483602524 	4100.94306803 	loss:0.005195
Batch:	659 /1004	:  6.33607411385 	4107.2811079 	loss:0.005163
Batch:	660 /1004	:  6.15590000153 	4113.4

Batch:	774 /1004	:  6.5792889595 	4861.47758007 	loss:0.005195
Batch:	775 /1004	:  6.78965497017 	4868.26903605 	loss:0.005131
Batch:	776 /1004	:  6.79726696014 	4875.0678699 	loss:0.005195
Batch:	777 /1004	:  6.56897091866 	4881.63810587 	loss:0.005163
Batch:	778 /1004	:  6.46389508247 	4888.10346389 	loss:0.005195
Batch:	779 /1004	:  6.71716809273 	4894.82162499 	loss:0.005163
Batch:	780 /1004	:  6.75686907768 	4901.57979894 	loss:0.005195
Batch:	781 /1004	:  6.48928499222 	4908.07054996 	loss:0.005163
Batch:	782 /1004	:  6.73110389709 	4914.80304003 	loss:0.005131
Batch:	783 /1004	:  6.68858098984 	4921.49206495 	loss:0.005195
Batch:	784 /1004	:  6.28634881973 	4927.779598 	loss:0.005131
Batch:	785 /1004	:  6.76571297646 	4934.54660797 	loss:0.005163
Batch:	786 /1004	:  6.657351017 	4941.20492196 	loss:0.005163
Batch:	787 /1004	:  6.78531599045 	4947.99131203 	loss:0.005163
Batch:	788 /1004	:  6.15209507942 	4954.14433503 	loss:0.005195
Batch:	789 /1004	:  6.88902902603 	4961.034470

In [15]:
op[0]


[[0.005162738263607025,
  0.005034471862018108,
  0.00513067189604044,
  0.005162738263607025,
  0.005162738263607025,
  0.005162738263607025,
  0.005194805096834898,
  0.005194805096834898,
  0.005162738263607025,
  0.005194805096834898,
  0.005194805096834898,
  0.005194805096834898,
  0.005194805096834898,
  0.005194805096834898,
  0.005194805096834898,
  0.005162738263607025,
  0.005162738263607025,
  0.005194805096834898,
  0.005194805096834898,
  0.005162738263607025,
  0.005194805096834898,
  0.005194805096834898,
  0.005098605062812567,
  0.00513067189604044,
  0.005162738263607025,
  0.005194805096834898,
  0.005194805096834898,
  0.005162738263607025,
  0.005194805096834898,
  0.005194805096834898,
  0.005162738263607025,
  0.005194805096834898,
  0.005194805096834898,
  0.005162738263607025,
  0.005162738263607025,
  0.005162738263607025,
  0.00513067189604044,
  0.005194805096834898,
  0.005194805096834898,
  0.005162738263607025,
  0.005194805096834898,
  0.005130671896040

#### Visualizations

So far, we plot the training losss. 
Shall we superimpose test loss on it too? We don't calculate test loss per batch though (fortunately).

In [None]:
# Visualizations
print("Training Loss")
visualize_loss(op[0], "train loss", _only_epoch=False)

if len(op[1]) > 0:

    print("Validation Loss")
    visualize_loss(op[1], "validation loss")
