In [5]:
import torch
from torch import nn
from pipeline import *
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [10]:
class NMTEncoder(nn.Module):
    """Bidirectional GRU encoder over an embedded source sequence.

    Args:
        num_embeddings (int): number of rows in the source embedding table
        embedding_size (int): size of each embedding vector
        rnn_hidden_size (int): hidden size of each GRU direction
    """
    def __init__(self, num_embeddings, embedding_size, rnn_hidden_size):
        super(NMTEncoder, self).__init__()

        # Index 0 is the padding index: its embedding stays zero and gets no gradient.
        self.embedding = nn.Embedding(num_embeddings, embedding_size, padding_idx=0)
        self.gru = nn.GRU(embedding_size, rnn_hidden_size, bidirectional=True, batch_first=True)

    def forward(self, x, x_len):
        """Encode a padded batch of source sequences.

        Args:
            x (torch.Tensor): token indices, batch-first, i.e. (batch, seq_len)
            x_len (torch.Tensor): length of every sequence in ``x``. The
                sequences are packed with the default ``enforce_sorted``
                setting, so the batch is expected to be sorted by decreasing
                length.
        Returns:
            tuple: ``(x_unpacked, x_birnn_h)`` where ``x_unpacked`` holds the
                per-step GRU outputs re-padded to (batch, max_len, 2 * rnn_hidden_size)
                and ``x_birnn_h`` holds the final hidden states of both
                directions flattened to (batch, 2 * rnn_hidden_size).
        """
        embedded = self.embedding(x)

        # Create packed sequence; the lengths must live on the CPU.
        # (The original read `x.len`, which does not exist on a tensor.)
        lengths = x_len.detach().cpu().numpy()
        x_packed = pack_padded_sequence(embedded, lengths, batch_first=True)

        # x_birnn_h.shape = (num_rnn, batch_size, feature_size)
        x_birnn_out, x_birnn_h = self.gru(x_packed)
        # permute to (batch_size, num_rnn, feature_size)
        x_birnn_h = x_birnn_h.permute(1, 0, 2)

        # flatten to (batch_size, num_rnn * feature_size)
        x_birnn_h = x_birnn_h.contiguous().view(x_birnn_h.size(0), -1)

        x_unpacked, _ = pad_packed_sequence(x_birnn_out, batch_first=True)
        return x_unpacked, x_birnn_h
        
        

In [11]:
class NMTDecoder(nn.Module):
    """GRU-cell decoder that attends over the encoder states at every step.

    Args:
        num_embeddings (int): number of embeddings; also the number of unique words in the target vocabulary
        embedding_size (int): size of the embedding vector
        rnn_hidden_size (int): size of the hidden RNN state
        bos_index (int): BEGIN-OF-SEQUENCE index
    """
    def __init__(self, num_embeddings, embedding_size, rnn_hidden_size, bos_index):
        super(NMTDecoder, self).__init__()
        self._rnn_hidden_size = rnn_hidden_size
        self.t_embedding = nn.Embedding(num_embeddings, embedding_size, padding_idx=0)
        # Each step consumes the target embedding concatenated with the previous context vector.
        self.gru_cell = nn.GRUCell(embedding_size + rnn_hidden_size, rnn_hidden_size)
        self.hidden_map = nn.Linear(rnn_hidden_size, rnn_hidden_size)
        # The classifier sees [context vector ; hidden state], hence twice the hidden size.
        self.classifier = nn.Linear(rnn_hidden_size * 2, num_embeddings)
        self.bos_index = bos_index

    def _init_indices(self, batch_size):
        """Return a (batch_size,) int64 tensor filled with the BOS index."""
        return torch.ones(batch_size, dtype=torch.int64) * self.bos_index

    def _init_context_vectors(self, batch_size):
        """Return an all-zero (batch_size, rnn_hidden_size) context tensor."""
        # Single leading underscore: a double underscore inside the class body
        # would be name-mangled and miss the attribute set in __init__.
        return torch.zeros(batch_size, self._rnn_hidden_size)

    def forward(self, encoder_state, init_hidden_state, target_seq):
        """The forward pass of the model

        Every step is fed the ground-truth token from ``target_seq``.

        Args:
            encoder_state (torch.Tensor): output of the NMTEncoder, batch on
                dim 0 (assumed (batch, src_len, rnn_hidden_size) -- TODO confirm
                against ``verbose_attention``)
            init_hidden_state (torch.Tensor): last hidden state in the NMTEncoder
            target_seq (torch.Tensor): target text data tensor, (batch, tgt_len)
        Returns:
            output_vectors (torch.Tensor): prediction vectors at each output
                step, shape (batch, tgt_len, num_embeddings)
        """
        # Iterate over time steps: (batch, tgt_len) -> (tgt_len, batch)
        target_seq = target_seq.permute(1, 0)

        batch_size = encoder_state.size(0)
        device = encoder_state.device

        # Map the encoder's final hidden state to the decoder's initial hidden state
        h_t = self.hidden_map(init_hidden_state).to(device)
        context_vectors = self._init_context_vectors(batch_size).to(device)

        output_vectors = []
        self._cached_p_attn = []
        self._cached_ht = []
        self._cached_decoder_state = encoder_state.cpu().detach().numpy()

        output_sequence_size = target_seq.size(0)

        for i in range(output_sequence_size):
            # Step 1: Embed word and concat with previous context
            y_input_vector = self.t_embedding(target_seq[i])
            rnn_input = torch.cat([y_input_vector, context_vectors], dim=1)

            # Step 2: Make a GRU step, getting a new hidden vector
            h_t = self.gru_cell(rnn_input, h_t)
            self._cached_ht.append(h_t)

            # Step 3: Use current hidden vector to attend to the encoder state.
            # `verbose_attention` is not defined in this file; it is presumably
            # provided by the `pipeline` star import -- verify.
            context_vectors, p_attn, _ = verbose_attention(encoder_state_vectors=encoder_state,
                                                           query_vector=h_t)

            # auxiliary: cache the attention probabilities for visualization
            self._cached_p_attn.append(p_attn.cpu().detach().numpy())

            # Step 4: use current context vector and hidden state for prediction
            prediction_vector = torch.cat((context_vectors, h_t), dim=1)
            score_for_y_t_index = self.classifier(prediction_vector)

            output_vectors.append(score_for_y_t_index)

        if not output_vectors:
            # torch.stack rejects an empty list; return an empty result instead.
            return encoder_state.new_zeros(batch_size, 0, self.classifier.out_features)

        # (tgt_len, batch, vocab) -> (batch, tgt_len, vocab)
        return torch.stack(output_vectors).permute(1, 0, 2)
            
        
        
        
        
        
        

In [None]:
class NMTModel(nn.Module):
    """ A Neural Translation Model"""
    def __init__(self, source_vocab_size, source_embedding_size, target_vocab_size, target_embedding_size, 
                encoding_size, target_bos_index):
        
        """
        Args:
            source_vocab_size (int): number of unique words in source vocabulary
            source_embedding_size (int): size of source embedding vector
            target_vocab_size (int): number of unique words in target vocabulary
            target_embedding_size (int): size of target embedding vector
            encoding_size (int): size of encoder RNN
            target_bos_index (int): index for BEGIN-OF-SEQUENCE token
        """
        
        super(NMTModel, self).__init__()
        self.encoder = NMTEncoder(num_embeddings=source_vocab_size,
                                  embedding_size=source_embedding_size,
                                  rnn_hidden_size=encoding_size)
        # The encoder GRU is bidirectional, so its flattened final state is
        # twice the per-direction size; the decoder's hidden size must match.
        decoding_size = encoding_size * 2
        self.decoder = NMTDecoder(num_embeddings=target_vocab_size,
                                  embedding_size=target_embedding_size,
                                  rnn_hidden_size=decoding_size,
                                  bos_index=target_bos_index)