In [1]:
%load_ext autoreload
%autoreload 2

from naml.dataset import Datasets
from naml.dataset.nmt import load_nmt
# Dataset handle built from the path string "~/naml-data" (presumably the local data/cache
# root). TODO confirm that Datasets expands "~" itself, since the raw string is passed as is.
datasets = Datasets("~/naml-data")
# Load the 'fra'/'eng' translation data. The call returns two values, unpacked here as the
# source-side and target-side words.
# NOTE(review): which language ends up as the source is not visible here — verify against load_nmt.
src_words, target_words = load_nmt(datasets, 'fra', 'eng')


## Encoder
We have an RNN where
$$
h_t = \text{RNN}(x_t, h_{t-1})
$$
The encoder converts the hidden states $h_1, \ldots, h_T$ from all timesteps into a context vector $c$:
$$
c = \text{Encoder}(\{h_1, h_2, \ldots, h_T\})
$$

In [None]:
from naml.modules import torch, nn, optim, F

class Encoder(nn.Module):
    """GRU-based sequence encoder.

    A slightly modified version of the RNN from Chapter 8: token ids are
    embedded and fed through a (multi-layer) GRU. There is deliberately no
    output projection layer; the GRU hidden states are returned as is.

    Args:
        vocab_size: number of distinct token ids accepted by the embedding.
        embed_size: dimension of each token embedding.
        num_hiddens: GRU hidden size.
        num_layers: number of stacked GRU layers.
    """

    def __init__(self, vocab_size, embed_size, num_hiddens, num_layers):
        super().__init__()
        self.vocab_size, self.embed_size, self.num_hiddens, self.num_layers = vocab_size, embed_size, num_hiddens, num_layers
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.rnn = nn.GRU(embed_size, num_hiddens, num_layers)
        # No dense/output layer here: hidden states are used as is.

    def forward(self, X : torch.Tensor, H : "torch.Tensor | None" = None):
        """Encode a batch of token-id sequences.

        Args:
            X: token ids, X[batch_size, num_steps].
            H: initial hidden state, H[num_layers, batch_size, num_hiddens].
                Optional: when omitted (None) the GRU starts from an all-zero
                state, which is exactly what ``begin_state`` would produce.

        Returns:
            (Y, H) where Y[num_steps, batch_size, num_hiddens] holds the
            top-layer hidden state at every step and
            H[num_layers, batch_size, num_hiddens] is the final hidden state.
        """
        # X[batch_size, num_steps] -> time-major for the GRU
        X = self.embedding(X.T)
        # X[num_steps, batch_size, embed_size]
        # nn.GRU treats a missing initial state (None) as zeros on the input's device.
        Y, H = self.rnn(X, H)
        # Y[num_steps, batch_size, num_hiddens], H[num_layers, batch_size, num_hiddens]
        return Y, H

    def begin_state(self, device : torch.device, batch_size : int):
        """Return an all-zero initial hidden state of shape
        [num_layers, batch_size, num_hiddens] allocated on ``device``."""
        return torch.zeros((self.num_layers, batch_size, self.num_hiddens), device=device)

# Shape check for the encoder: vocabulary of 10, 8-dim embeddings, 2 GRU layers of 16 units.
encoder = Encoder(vocab_size=10, embed_size=8, num_hiddens=16, num_layers=2).eval()
# Dummy batch: 4 sequences of 7 steps, every token id equal to 0.
X = torch.zeros(4, 7, dtype=torch.long)
H = encoder.begin_state(device=X.device, batch_size=X.shape[0])
Y, H = encoder(X, H)
# Displayed as (Y[num_steps, batch_size, num_hiddens], H[num_layers, batch_size, num_hiddens]).
(Y.shape, H.shape)

(torch.Size([7, 4, 16]), torch.Size([2, 4, 16]))

## Decoder
The decoder is another RNN that takes the previous hidden state $h_{t-1}$, the previous output $y_{t-1}$, and the context vector $c$ to produce the new hidden state $h_t$ and output $y_t$:
$$
h_t = \text{RNN}(y_{t-1}, h_{t-1}, c)
$$

In [None]:
class Decoder(nn.Module):
    """GRU decoder whose input at every step is [embedding | context].

    The context is the last layer of the hidden state handed to ``forward``,
    repeated along the time axis and concatenated to each token embedding.
    A dense layer maps the GRU outputs to vocabulary-sized scores.
    """

    def __init__(self, vocab_size, embed_size, num_hiddens, num_layers):
        super().__init__()
        self.vocab_size = vocab_size
        self.embed_size = embed_size
        self.num_hiddens = num_hiddens
        self.num_layers = num_layers
        self.embedding = nn.Embedding(vocab_size, embed_size)
        # The GRU consumes the embedding concatenated with the context vector.
        rnn_input_size = embed_size + num_hiddens
        self.rnn = nn.GRU(rnn_input_size, num_hiddens, num_layers)
        self.dense = nn.Linear(num_hiddens, vocab_size)

    def forward(self, X : torch.Tensor, H : torch.Tensor):
        """Decode token ids X[batch_size, num_steps] starting from state
        H[num_layers, batch_size, num_hiddens].

        Returns (Y, H) with Y[batch_size, num_steps, vocab_size] and the
        updated H[num_layers, batch_size, num_hiddens].
        """
        embedded = self.embedding(X.T)
        # embedded[num_steps, batch_size, embed_size]
        num_steps = embedded.shape[0]
        context = H[-1].repeat(num_steps, 1, 1)
        # context[num_steps, batch_size, num_hiddens]
        rnn_input = torch.cat((embedded, context), dim=2)
        rnn_output, H = self.rnn(rnn_input, H)
        # rnn_output[num_steps, batch_size, num_hiddens]
        scores = self.dense(rnn_output)
        # scores[num_steps, batch_size, vocab_size] -> batch-major
        return scores.permute(1, 0, 2), H

    def begin_state(self, device : torch.device, batch_size : int):
        """Zero hidden state shaped [num_layers, batch_size, num_hiddens] on ``device``."""
        state_shape = (self.num_layers, batch_size, self.num_hiddens)
        return torch.zeros(state_shape, device=device)

# Shape check for the encoder -> decoder hand-off, using identically sized models.
encoder = Encoder(vocab_size=10, embed_size=8, num_hiddens=16, num_layers=2).eval()
decoder = Decoder(vocab_size=10, embed_size=8, num_hiddens=16, num_layers=2).eval()

# Dummy batch: 4 sequences of 7 steps, every token id equal to 0.
X = torch.zeros(4, 7, dtype=torch.long)
H = encoder.begin_state(device=X.device, batch_size=X.shape[0])

# The encoder's final hidden state seeds the decoder.
Y, H = encoder(X, H)
Y, H = decoder(X, H)
# Displayed as (Y[batch_size, num_steps, vocab_size], H[num_layers, batch_size, num_hiddens]).
(Y.shape, H.shape)

(torch.Size([4, 7, 10]), torch.Size([2, 4, 16]))

In [33]:
from naml.sequence import zero_one_mask
# Mask demo: 4 sequences padded to 7 steps, with valid lengths 1, 2, 3 and 4.
size = (4,7)
# NOTE(review): torch.Tensor([...]) builds a float32 tensor; torch.tensor([...]) would keep
# the lengths as integers — confirm which dtype zero_one_mask expects before changing.
lens = torch.Tensor([1,2,3,4])
# In the recorded output, row i is True for its first lens[i] positions and False afterwards.
zero_one_mask(size,lens)

tensor([[ True, False, False, False, False, False, False],
        [ True,  True, False, False, False, False, False],
        [ True,  True,  True, False, False, False, False],
        [ True,  True,  True,  True, False, False, False]])