# Encoder Decoder

In [1]:
import collections
import math
import torch
from torch import nn
from d2l import torch as d2l
import wandb
# Start a Weights & Biases run; the training loop further down logs its
# per-batch loss to this run via `wandb.log`.
wandb.init(project='encoder_decoder')

[34m[1mwandb[0m: Currently logged in as: [33mingambe[0m (use `wandb login --relogin` to force relogin)
[34m[1mwandb[0m: wandb version 0.12.9 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


For many sequence-to-sequence problems, such as machine translation, you cannot obtain a correct output by simply translating the input word by word.  
You first have to read the full source sentence to get its context before translating it.  
Moreover, the input and output sequences have variable (and generally different) lengths, so there is no fixed position-by-position mapping from input tokens to output tokens.

To handle this kind of problem, we need an architecture that first captures the global context of the input before it starts producing any output: the *encoder-decoder architecture*.

The first component, the *encoder*, takes a variable-length sequence as input and transforms it into a fixed-shape state that represents its context.  
The second component, the *decoder*, maps this fixed-shape encoded state to a variable-length output sequence.  


<center>
    <img src='data/seq2seq.svg' width="65%" style="margin-left:auto; margin-right:auto"/>
    <p style="font-size:14px;">Source: <a href='https://d2l.ai'>D2L</a></p>
</center>

We use an embedding layer to obtain the feature vector for each token in the input sequence. 
The weight of an embedding layer is a matrix whose number of rows equals the size of the input vocabulary (`vocab_size`) and whose number of columns equals the dimension of the feature vector (`embed_size`). 
For any input token index $i$, the embedding layer fetches the $i^{\mathrm{th}}$ row (starting from $0$) of the weight matrix and returns it as that token's feature vector.

In [2]:
nn.Embedding(3, 4)

# NOTE: this is not the last expression of the cell, so the module above is
# created and discarded without being displayed.
# An embedding with 3 rows and 4 columns maps each of the 3 token indices to
# its own 4-dimensional vector (values below are illustrative only):
# 0 -> [0.3, 0.5, 0.12, 0.4]
# 1 -> [0.23, 0.3, 0.22, 0.42]
# 2 -> [0.03, 0.12, 0.11, 0.65]

# Compare with one-hot encoding, where each index becomes a vector whose
# length is the vocabulary size:
# cat 0 => [1, 0, 0]
# dog 1 => [0, 1, 0]
# car 2 => [0, 0, 1]

# BEFORE (one-hot): with a 30 000-word vocabulary, a word such as "teaching"
# is a sparse vector of length 30 000 containing a single 1:
# teaching: [0, 0, 0, 0, 0, 1, ...., 0, ..., 0] # length 30 000

nn.Embedding(30000, 256)

# AFTER (embedding): the same word becomes a dense vector of length 256:
# teaching: [0.1, 0.24, -0.3, ..., 0.9, 0.88] # length 256

# Tokenization assigns each word an integer index; these indices are what
# the embedding layer consumes:
# word_a, word_b, word_c
# 0       1       2


class Seq2SeqEncoder(nn.Module):
    """GRU encoder for sequence-to-sequence learning.

    Embeds a batch of source token indices and runs it through a stack of
    GRU layers. The final hidden state is what the decoder consumes as the
    fixed-shape summary of the source sequence.

    Args:
        vocab_size: size of the source vocabulary.
        embed_size: dimension of the token embeddings.
        num_hiddens: number of hidden units in each GRU layer.
        num_layers: number of stacked GRU layers.
        dropout: dropout probability passed to `nn.GRU`.
    """

    def __init__(self, vocab_size, embed_size, num_hiddens, num_layers,
                 dropout=0, **kwargs):
        super().__init__(**kwargs)
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.rnn = nn.GRU(embed_size, num_hiddens, num_layers, dropout=dropout)

    def forward(self, X, *args):
        """Encode `X`, a (batch_size, num_steps) tensor of token indices.

        Extra positional arguments are accepted and ignored.

        Returns:
            (output, state) where `output` is
            (num_steps, batch_size, num_hiddens) and `state` is
            (num_layers, batch_size, num_hiddens).
        """
        # (batch_size, num_steps) -> (batch_size, num_steps, embed_size)
        # -> time-major (num_steps, batch_size, embed_size), the layout
        # nn.GRU expects by default.
        embedded = self.embedding(X).permute(1, 0, 2)
        # No initial hidden state is passed, so the GRU starts from zeros.
        return self.rnn(embedded)

In [3]:
encoder = Seq2SeqEncoder(vocab_size=10, embed_size=8, num_hiddens=16,
                         num_layers=2)
# Dummy mini-batch: 4 sequences of 7 token indices (all index 0).
X = torch.zeros((4, 7), dtype=torch.long)
output, state = encoder(X)
# Output is time-major (num_steps, batch_size, num_hiddens);
# state is (num_layers, batch_size, num_hiddens).
assert output.shape == (7, 4, 16) and state.shape == (2, 4, 16)
output.shape, state.shape

(torch.Size([7, 4, 16]), torch.Size([2, 4, 16]))

In [4]:
class Seq2SeqDecoder(nn.Module):
    """GRU decoder for sequence-to-sequence learning.

    At every time step, the GRU input is the embedding of the current target
    token concatenated with a context vector: the top-layer hidden state of
    the `state` passed to `forward`.

    Args:
        vocab_size: size of the target vocabulary (also the number of output logits).
        embed_size: dimension of the token embeddings.
        num_hiddens: number of hidden units in each GRU layer.
        num_layers: number of stacked GRU layers.
        dropout: dropout probability passed to `nn.GRU`.
    """

    def __init__(self, vocab_size, embed_size, num_hiddens, num_layers,
                 dropout=0, **kwargs):
        super().__init__(**kwargs)
        self.embedding = nn.Embedding(vocab_size, embed_size)
        # GRU input is [token embedding ; context], hence the widened input size.
        self.rnn = nn.GRU(embed_size + num_hiddens, num_hiddens, num_layers,
                          dropout=dropout)
        self.dense = nn.Linear(num_hiddens, vocab_size)

    def init_state(self, enc_outputs, *args):
        """Return item 1 of `enc_outputs`, i.e. the encoder's final hidden state."""
        return enc_outputs[1]

    def forward(self, X, state):
        """Decode a (batch_size, num_steps) tensor of target token indices.

        Args:
            X: token indices, shape (batch_size, num_steps).
            state: hidden state, shape (num_layers, batch_size, num_hiddens).

        Returns:
            (logits, state): logits of shape (batch_size, num_steps, vocab_size)
            and the updated hidden state (num_layers, batch_size, num_hiddens).

        NOTE(review): the context is read from the *incoming* `state`. If the
        decoder is called repeatedly (e.g. step-by-step decoding) with the
        state it returned, the context becomes the decoder's own hidden state
        rather than the encoder's final state.
        """
        # (batch_size, num_steps) -> time-major (num_steps, batch_size, embed_size)
        embedded = self.embedding(X).permute(1, 0, 2)
        # Broadcast the top-layer state to every time step:
        # (batch_size, num_hiddens) -> (num_steps, batch_size, num_hiddens)
        num_steps = embedded.shape[0]
        context = state[-1].unsqueeze(0).expand(num_steps, -1, -1)
        rnn_input = torch.cat((embedded, context), dim=2)
        rnn_output, new_state = self.rnn(rnn_input, state)
        # Project to vocabulary logits, then back to batch-major layout.
        logits = self.dense(rnn_output).permute(1, 0, 2)
        return logits, new_state

In [5]:
decoder = Seq2SeqDecoder(vocab_size=10, embed_size=8, num_hiddens=16,
                         num_layers=2)
# The decoder starts from the encoder's final hidden state.
state = decoder.init_state(encoder(X))
# Reuse X as a dummy target batch (both toy vocabularies have size 10).
output, state = decoder(X, state)
# Logits are batch-major (batch_size, num_steps, vocab_size);
# state is (num_layers, batch_size, num_hiddens).
assert output.shape == (4, 7, 10) and state.shape == (2, 4, 16)
output.shape, state.shape

(torch.Size([4, 7, 10]), torch.Size([2, 4, 16]))

In [6]:
class EncoderDecoder(nn.Module):
    """Encoder-decoder wrapper: encode the source, then decode from its state.

    Args:
        encoder: module whose output is handed to `decoder.init_state`.
        decoder: module exposing `init_state(enc_outputs, ...)` and
            `forward(dec_X, state)`.
    """

    def __init__(self, encoder, decoder):
        super(EncoderDecoder, self).__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, enc_X, dec_X, *args):
        """Run the encoder on `enc_X`, then the decoder on `dec_X`.

        Extra positional `*args` (e.g. source valid lengths) are forwarded to
        both `encoder(enc_X, *args)` and `decoder.init_state(enc_outputs, *args)`,
        matching the `*args` those methods accept. With no extra arguments the
        behaviour is exactly as before.

        Returns:
            Whatever `decoder(dec_X, dec_state)` returns.
        """
        enc_outputs = self.encoder(enc_X, *args)
        dec_state = self.decoder.init_state(enc_outputs, *args)
        return self.decoder(dec_X, dec_state)

In [7]:
def xavier_init_weights(m):
    """Xavier-uniform initialize the weights of `m` in place.

    Intended for `net.apply(xavier_init_weights)`, which calls it on every
    submodule. Only modules whose exact type is `nn.Linear` or `nn.GRU` are
    touched; their biases, and every other module type, keep the default
    initialization.
    """
    if type(m) is nn.Linear:
        nn.init.xavier_uniform_(m.weight)
    elif type(m) is nn.GRU:
        # Public API instead of the private `_flat_weights_names`/`_parameters`.
        # For a GRU this yields weight_ih_l*, weight_hh_l*, bias_ih_l*,
        # bias_hh_l*; only the weight matrices are re-initialized.
        for name, param in m.named_parameters(recurse=False):
            if "weight" in name:
                nn.init.xavier_uniform_(param)

Because sequences are padded at the end to a fixed length, the padding tokens must be excluded from the cross-entropy loss (`CrossEntropyLoss`).

In [8]:
def sequence_mask(X, valid_len, value=0):
    """Mask irrelevant entries in sequences.

    For each row `b` of `X`, every position `t >= valid_len[b]` along
    dimension 1 is overwritten with `value`.

    Args:
        X: tensor whose first two dimensions are (batch, positions); it is
            modified in place.
        valid_len: 1-D tensor with one valid length per row of `X`.
        value: fill value for the masked positions.

    Returns:
        `X` itself, after masking.
    """
    positions = torch.arange(X.size(1), dtype=torch.float32, device=X.device)
    # (1, positions) vs (batch, 1) broadcasts to a (batch, positions) keep-mask.
    keep = positions.unsqueeze(0) < valid_len.unsqueeze(1)
    X[~keep] = value
    return X

class MaskedSoftmaxCELoss(nn.CrossEntropyLoss):
    """The softmax cross-entropy loss with masks.

    Expected shapes:
        pred:      (batch_size, num_steps, vocab_size) logits
        label:     (batch_size, num_steps) target token indices
        valid_len: (batch_size,) number of real (non-padding) tokens per row

    `forward` returns a (batch_size,) tensor: for each sequence, the sum of
    its unmasked per-token losses divided by `num_steps`. Masked positions
    contribute 0 but still count in the denominator.
    """

    def forward(self, pred, label, valid_len):
        # Per-token losses are required so they can be masked before reducing.
        # Note that this overrides any `reduction` given at construction.
        self.reduction = 'none'
        mask = sequence_mask(torch.ones_like(label), valid_len).float()
        # nn.CrossEntropyLoss expects the class dimension second:
        # (batch_size, vocab_size, num_steps).
        per_token = super().forward(pred.permute(0, 2, 1), label)
        return (per_token * mask).mean(dim=1)

Download the machine translation dataset, which contains English sentences paired with their French translations.  
Each input/output sequence has a fixed length, specified by `num_steps`.

In [None]:
# Model size
embed_size = 32
num_hiddens = 32
num_layers = 2
dropout = 0.1

# Data: mini-batches of sentence pairs, each sequence `num_steps` tokens long
batch_size = 64
num_steps = 10
train_iter, src_vocab, tgt_vocab = d2l.load_data_nmt(batch_size, num_steps)

# Optimization
lr = 0.005
num_epochs = 300

In [10]:
encoder = Seq2SeqEncoder(len(src_vocab), embed_size, num_hiddens, num_layers, dropout)
decoder = Seq2SeqDecoder(len(tgt_vocab), embed_size, num_hiddens, num_layers, dropout)

net = EncoderDecoder(encoder, decoder)

net.apply(xavier_init_weights)

# Use the learning rate from the configuration cell (previously hard-coded to 1e-3).
optimizer = torch.optim.Adam(net.parameters(), lr=lr)
loss = MaskedSoftmaxCELoss()
net.train()

cum_losses = []  # average training loss per valid target token, one entry per epoch

for epoch in range(num_epochs):

    metric = d2l.Accumulator(2)  # Sum of training loss, no. of tokens

    for batch in train_iter:
        X, X_valid_len, Y, Y_valid_len = batch
        # Clear gradients from the previous step; without this they accumulate
        # across every batch and epoch.
        optimizer.zero_grad()
        # Teacher forcing: the decoder input is <bos> followed by the target
        # shifted right by one position.
        bos = torch.tensor([tgt_vocab['<bos>']] * Y.shape[0],
                           device=Y.device).reshape(-1, 1)
        dec_input = torch.cat([bos, Y[:, :-1]], 1)
        Y_hat, _ = net(X, dec_input)
        l = loss(Y_hat, Y, Y_valid_len)
        batch_loss = l.sum()
        batch_loss.backward()
        torch.nn.utils.clip_grad_norm_(net.parameters(), 1)
        optimizer.step()
        batch_loss_value = batch_loss.item()
        metric.add(batch_loss_value, Y_valid_len.sum().item())
        wandb.log({
        'loss': batch_loss_value
        })

    if metric[1] > 0:
        cum_losses.append(metric[0] / metric[1])
        

wandb: Network error (ConnectionError), entering retry loop.
wandb: Network error (ConnectionError), entering retry loop.
