In [1]:
from torch import nn
from utils_torch import *
import torch

# The Encoder-Decoder Architecture
Encoder-decoder architectures can handle inputs and outputs that both consist of variable-length sequences and thus are suitable for seq2seq problems such as machine translation. 

An encoder-decoder architecture consists of two major components:
1. Encoder: takes a variable-length sequence as input and transforms it into a state with a fixed shape;
2. Decoder: acts as a conditional language model, taking in the encoded input and the leftwards context of the target sequence and predicting the subsequent token in the target sequence.

![image.png](attachment:image.png)

## Encoder

In [2]:
class Encoder(nn.Module):
    """Abstract encoder of the encoder-decoder architecture.

    This class only fixes the interface; concrete encoders subclass it and
    override `forward`.
    """

    def __init__(self):
        super().__init__()

    def forward(self, X, *args):
        """Encode the variable-length input sequences ``X``.

        Extra positional arguments are accepted so that subclasses can take
        additional inputs later (e.g., length excluding padding).

        Raises:
            NotImplementedError: always; subclasses must override this method.
        """
        raise NotImplementedError

To generate a variable-length sequence token by token, at each time step the decoder maps an input (e.g., the token generated at the previous time step) and the encoded state to an output token for the current time step.

In [3]:
class Decoder(nn.Module):  # @save
    """Abstract decoder of the encoder-decoder architecture.

    This class only fixes the interface; concrete decoders subclass it and
    override both `init_state` and `forward`.
    """

    def __init__(self):
        super().__init__()

    def init_state(self, enc_all_outputs, *args):
        """Turn the encoder output (``enc_all_outputs``) into the encoded state.

        Extra positional arguments are accepted so that subclasses can take
        additional inputs later (e.g., length excluding padding).

        Raises:
            NotImplementedError: always; subclasses must override this method.
        """
        raise NotImplementedError

    def forward(self, X, state):
        """Decode the input ``X`` given ``state``.

        Raises:
            NotImplementedError: always; subclasses must override this method.
        """
        raise NotImplementedError

In [4]:
class EncoderDecoder(Classifier):
    """Model that chains an encoder and a decoder.

    NOTE(review): `Classifier` is not defined in this notebook; it is
    presumably provided by `utils_torch` - confirm.
    """

    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, enc_X, dec_X, *args):
        """Run the encoder on ``enc_X``, then the decoder on ``dec_X``.

        ``*args`` is forwarded unchanged to both the encoder call and
        ``decoder.init_state``.

        Returns:
            The first element of whatever the decoder returns.
        """
        # Step 1: encode the source input.
        encoded = self.encoder(enc_X, *args)
        # Step 2: derive the decoder's starting state from the encoder output.
        initial_state = self.decoder.init_state(encoded, *args)
        # Step 3: decode; only the first element of the result is returned.
        decoded = self.decoder(dec_X, initial_state)
        return decoded[0]

# Encoder-Decoder Seq2Seq for Machine Translation 
![image.png](attachment:image.png)


## Encoder

In [51]:
def init_seq2seq(module):
    """Initialize weights for Seq2Seq.

    Meant to be passed to ``nn.Module.apply``, which calls it once for every
    submodule. Weight matrices of linear and GRU layers are re-initialized
    in place with Xavier-uniform values; biases and all other module types
    are left untouched.

    Args:
        module: the (sub)module to initialize.

    Notes:
        * Linear layers are now actually initialized; the branch used to be
          a no-op, so they silently kept PyTorch's default initialization.
        * ``isinstance`` is used so that subclasses of ``nn.Linear`` and
          ``nn.GRU`` are covered as well.
        * Lazy linear layers whose weight has not been materialized yet are
          skipped, because they have no shape to initialize.
    """
    if isinstance(module, nn.Linear):
        # nn.LazyLinear is a subclass of nn.Linear; before its first forward
        # pass the weight is an UninitializedParameter and cannot be filled.
        if not isinstance(module.weight, nn.parameter.UninitializedParameter):
            nn.init.xavier_uniform_(module.weight)
    elif isinstance(module, nn.GRU):
        # A GRU's own parameters are named weight_ih_l0, weight_hh_l0,
        # bias_ih_l0, bias_hh_l0, weight_ih_l1, ... (one group per layer).
        # Only the weight matrices are re-initialized; biases are 1-D and
        # are not valid inputs for Xavier initialization.
        for name, param in module.named_parameters(recurse=False):
            if "weight" in name:
                nn.init.xavier_uniform_(param)

class Seq2SeqEncoder(Encoder):
    """Encoder for sequence-to-sequence learning: embedding followed by a GRU.

    NOTE(review): `GRU` is not defined in this notebook; it is presumably
    provided by `utils_torch` - confirm its signature there.
    """

    def __init__(self, vocab_size, embed_size,
                 num_hiddens, num_layers, dropout=0):
        super().__init__()
        # Lookup table that maps every token index to its feature vector.
        # Its weight matrix has one row per vocabulary entry (vocab_size)
        # and one column per feature dimension (embed_size); looking up
        # token index i returns row i (counting from 0).
        self.embedding = nn.Embedding(num_embeddings=vocab_size,
                                      embedding_dim=embed_size)
        self.rnn = GRU(embed_size, num_hiddens, num_layers, dropout)
        # Run the Seq2Seq weight initializer over this module and its children.
        self.apply(init_seq2seq)

    def forward(self, X, *args):
        """Encode a batch of token-index sequences.

        Args:
            X: token indices of shape (batch_size, num_steps).
            *args: accepted for interface compatibility; unused here.

        Returns:
            A tuple ``(outputs, state)`` as produced by ``self.rnn``.
        """
        # Transpose to time-major order and cast to int64 for the lookup.
        token_ids = X.t().type(torch.int64)
        # embedded shape: (num_steps, batch_size, embed_size)
        embedded = self.embedding(token_ids)
        # hidden_seq shape: (num_steps, batch_size, num_hiddens)
        # final_state shape: (num_layers, batch_size, num_hiddens)
        hidden_seq, final_state = self.rnn(embedded)
        return hidden_seq, final_state

In [53]:
# Build a two-layer GRU encoder with 16 hidden units and push a dummy
# all-zeros batch (4 sequences of 9 steps) through it.
vocab_size, embed_size = 10, 8
num_hiddens, num_layers = 16, 2
batch_size, num_steps = 4, 9
encoder = Seq2SeqEncoder(
    vocab_size=vocab_size,
    embed_size=embed_size,
    num_hiddens=num_hiddens,
    num_layers=num_layers,
)
X = torch.zeros(batch_size, num_steps)
enc_outputs, enc_state = encoder(X)

In [54]:
# Inspect the shapes of the encoder outputs and of the final hidden state.
# With the settings above they are (num_steps, batch_size, num_hiddens) = (9, 4, 16)
# and (num_layers, batch_size, num_hiddens) = (2, 4, 16).
enc_outputs.shape, enc_state.shape

(torch.Size([9, 4, 16]), torch.Size([2, 4, 16]))

## Decoder 

In [None]:
class Seq2SeqDecoder(Decoder):
    def __init__