# In [None]:

import logging
import os
import pickle

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence


class Coattention_Encoder(nn.Module):
    """Co-attention between a question and a document, fused by a BiLSTM.

    The question representation is passed through a tanh-activated linear
    projection, an affinity matrix between every document position and every
    question position is computed, attention weights are derived from it in
    both directions, and the resulting co-attention context is concatenated
    with the document representation and handed to ``Fusion_BiLSTM``.

    Args:
        hidden_dim: Size ``l`` of the last dimension of both inputs.
        maxout_pool_size: Not used by this module; accepted so existing
            callers keep working.
        embedding_matrix: Not used by this module; accepted so existing
            callers keep working.
        max_number_of_iterations: Not used by this module; accepted so
            existing callers keep working.
        dropout_ratio: Dropout probability applied to the BiLSTM input and
            forwarded to ``Fusion_BiLSTM``.
    """

    # Shape diagnostics go through logging (DEBUG level) instead of print(),
    # so forward() stays silent unless the caller opts in.
    _logger = logging.getLogger(__name__)

    def __init__(self, hidden_dim, maxout_pool_size, embedding_matrix, max_number_of_iterations, dropout_ratio):
        super(Coattention_Encoder, self).__init__()
        self.hidden_dim = hidden_dim

        # nn.Linear(input_dim, output_dim)
        self.question_proj = nn.Linear(hidden_dim, hidden_dim)

        self.fusion_bilstm = Fusion_BiLSTM(hidden_dim, dropout_ratio)
        self.dropout = nn.Dropout(p=dropout_ratio)

    def forward(self, question_representation, context_representation,document_word_sequence_mask):
        """Compute the co-attention encoding of the document.

        Notation: ``B`` is the batch size, ``m + 1`` the document length,
        ``n + 1`` the question length and ``l`` equals ``hidden_dim``.

        Args:
            question_representation: Tensor expected to be B x (n + 1) x l.
            context_representation: Tensor expected to be B x (m + 1) x l.
            document_word_sequence_mask: Mask passed unchanged to the fusion
                BiLSTM together with its input.

        Returns:
            The output ``U`` of ``self.fusion_bilstm``.
        """
        D = context_representation  # B x (m + 1) x l

        self._logger.debug("question_representation (output of encoder layer) size: %s",
                           question_representation.size())
        self._logger.debug("context_representation (output of encoder layer) size: %s",
                           D.size())

        # Pass the question through a non-linearity. nn.Linear acts on the
        # last dimension of an N-D input, so no flatten/unflatten via view()
        # is needed; this also accepts non-contiguous tensors, on which
        # view() would raise.
        Q = torch.tanh(self.question_proj(question_representation))  # B x (n + 1) x l

        # Step 1: affinity between every document and question position.
        Q_transpose = torch.transpose(Q, 1, 2)  # B x l x (n + 1)
        L = torch.bmm(D, Q_transpose)  # B x (m + 1) x (n + 1)

        # Step 2: attention weights normalised over the question axis, and
        # the document summaries they produce.
        A_Q = F.softmax(L, dim=2)  # B x (m + 1) x (n + 1)
        D_transpose = torch.transpose(D, 1, 2)  # B x l x (m + 1)
        C_Q = torch.bmm(D_transpose, A_Q)  # B x l x (n + 1)

        # Step 3: attention weights normalised over the document axis,
        # applied to the question stacked with its document summaries.
        A_D = F.softmax(torch.transpose(L, 1, 2), dim=2)  # B x (n + 1) x (m + 1)
        C_D = torch.bmm(torch.cat((Q_transpose, C_Q), 1), A_D)  # B x 2l x (m + 1)
        C_D_transpose = torch.transpose(C_D, 1, 2)  # B x (m + 1) x 2l

        # Step 4: fuse the co-attention context with the document itself.
        bi_lstm_input = torch.cat((C_D_transpose, D), 2)  # B x (m + 1) x 3l
        bi_lstm_input = self.dropout(bi_lstm_input)

        U = self.fusion_bilstm(bi_lstm_input, document_word_sequence_mask)

        self._logger.debug("U (output of co-attention encoder) size: %s", U.size())

        return U

#         loss, index_start, index_end = self.decoder(U, document_word_sequence_mask, answer_start, answer_end)
#         if answer_start is not None:
#             return loss, index_start, index_end
#         else:
#             return index_start, index_end


class Fusion_BiLSTM(nn.Module):
    """Single-layer bidirectional LSTM over padded, variable-length sequences.

    The LSTM takes inputs whose last dimension is ``3 * hidden_dim`` and,
    being bidirectional, produces outputs whose last dimension is
    ``2 * hidden_dim``. Dropout is applied to the LSTM output.

    Args:
        hidden_dim: Hidden size of each LSTM direction.
        dropout_ratio: Dropout probability applied to the LSTM output.
    """

    def __init__(self, hidden_dim, dropout_ratio):
        super(Fusion_BiLSTM, self).__init__()
        # nn.LSTM(input_size, hidden_size, num_layers). The LSTM's own
        # ``dropout`` option only acts between stacked layers; with a single
        # layer it does nothing except emit a UserWarning, so it is not
        # passed. Dropout is applied to the output in forward() instead.
        self.fusion_bilstm = nn.LSTM(3 * hidden_dim, hidden_dim, 1, batch_first=True,
                                     bidirectional=True)
        self.dropout = nn.Dropout(p=dropout_ratio)

    def forward(self, word_sequence_embeddings, word_sequence_mask):
        """Run the BiLSTM over the non-padded prefix of every sequence.

        Args:
            word_sequence_embeddings: Batch-first tensor, B x T x 3*hidden_dim.
            word_sequence_mask: B x T mask whose row sums give the length of
                each sequence (assumes non-zero entries mark real tokens and
                every row has at least one; packing rejects zero lengths).

        Returns:
            Tensor of shape B x T x 2*hidden_dim in the original batch order;
            positions beyond a sequence's length are zero before dropout.
        """
        # Length of every sequence in the batch.
        lengths = torch.sum(word_sequence_mask, 1)

        # Packing (with the default enforce_sorted=True) needs the batch in
        # order of decreasing length; remember how to undo the permutation.
        sorted_lengths, sort_order = torch.sort(lengths, 0, True)
        _, restore_order = torch.sort(sort_order, 0)
        sorted_embeddings = torch.index_select(word_sequence_embeddings, 0, sort_order)

        # pack_padded_sequence requires the lengths as a CPU int64 tensor,
        # whatever device and dtype the mask has.
        packed_input = pack_padded_sequence(
            sorted_embeddings,
            sorted_lengths.to(device="cpu", dtype=torch.int64),
            batch_first=True,
        )

        # A packed input yields a packed output.
        packed_output, _ = self.fusion_bilstm(packed_input)

        # Pad back to the full input length (not just the longest sequence
        # in the batch) so the output stays aligned with the mask.
        padded_output, _ = pad_packed_sequence(
            packed_output,
            batch_first=True,
            total_length=word_sequence_embeddings.size(1),
        )

        # Restore the original batch order.
        restored_output = torch.index_select(padded_output, 0, restore_order)
        return self.dropout(restored_output)