# Encoder-Decoder Architecture


This section converts the encoder-decoder architecture
into an interface that will be implemented later.

## Encoder

In the encoder interface,
we just specify that
the encoder takes variable-length sequences as the input `X`.
The implementation will be provided
by any model that inherits from this base `Encoder` class.


In [None]:
from torch import nn

#@save
class Encoder(nn.Module):
    """The base encoder interface for the encoder-decoder architecture."""
    def __init__(self, **kwargs):
        super(Encoder, self).__init__(**kwargs)

    def forward(self, X, *args):
        raise NotImplementedError
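
To make the interface concrete, here is a minimal sketch of how a model might inherit from the base encoder. The class `ToyGRUEncoder`, its constructor arguments (`vocab_size`, `embed_size`, `num_hiddens`), and the choice of a single-layer GRU are all assumptions made for illustration; they are not part of the interface above. The base class is repeated so that the snippet runs on its own.

```python
import torch
from torch import nn


class Encoder(nn.Module):
    """Base encoder interface, repeated here so the sketch is self-contained."""
    def forward(self, X, *args):
        raise NotImplementedError


class ToyGRUEncoder(Encoder):
    """Hypothetical encoder: embeds token indices and feeds them to a GRU."""
    def __init__(self, vocab_size, embed_size, num_hiddens):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.rnn = nn.GRU(embed_size, num_hiddens, batch_first=True)

    def forward(self, X, *args):
        # X: (batch_size, num_steps), holding token indices
        emb = self.embedding(X)  # (batch_size, num_steps, embed_size)
        outputs, state = self.rnn(emb)
        # outputs: (batch_size, num_steps, num_hiddens)
        # state: (num_layers=1, batch_size, num_hiddens)
        return outputs, state


encoder = ToyGRUEncoder(vocab_size=10, embed_size=8, num_hiddens=16)
X = torch.zeros((4, 7), dtype=torch.long)  # a batch of 4 sequences, 7 steps each
outputs, state = encoder(X)
```

Calling the base `Encoder` directly on `X` would raise `NotImplementedError`, since only subclasses define `forward`.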

## Decoder

In the following decoder interface,
we add an `init_state` method
to convert the encoder output (`enc_outputs`)
into the encoded state.
Note that this step
may need extra inputs such as 
the valid length of the input,
which was explained
in :numref:`subsec_mt_data_loading`.
To generate a variable-length sequence token by token,
at each time step the decoder
maps an input (e.g., the token generated at the previous time step)
and the encoded state
into an output token at the current time step.


In [None]:
#@save
class Decoder(nn.Module):
    """The base decoder interface for the encoder-decoder architecture."""
    def __init__(self, **kwargs):
        super(Decoder, self).__init__(**kwargs)

    def init_state(self, enc_outputs, *args):
        raise NotImplementedError

    def forward(self, X, state):
        raise NotImplementedError
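
The token-by-token generation described above can be sketched with a toy decoder. Everything beyond the base interface is an assumption for illustration: the class `ToyGRUDecoder`, the convention that `enc_outputs` is an `(outputs, state)` pair, the use of index 0 as the start token, and greedy selection via `argmax`. The encoder output is faked with zero tensors so that the snippet runs on its own.

```python
import torch
from torch import nn


class Decoder(nn.Module):
    """Base decoder interface, repeated here so the sketch is self-contained."""
    def init_state(self, enc_outputs, *args):
        raise NotImplementedError

    def forward(self, X, state):
        raise NotImplementedError


class ToyGRUDecoder(Decoder):
    """Hypothetical decoder: a GRU followed by a projection onto the vocabulary."""
    def __init__(self, vocab_size, embed_size, num_hiddens):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.rnn = nn.GRU(embed_size, num_hiddens, batch_first=True)
        self.dense = nn.Linear(num_hiddens, vocab_size)

    def init_state(self, enc_outputs, *args):
        # Assumption: enc_outputs is (outputs, state); keep the final hidden state
        return enc_outputs[1]

    def forward(self, X, state):
        emb = self.embedding(X)  # (batch_size, num_steps, embed_size)
        outputs, state = self.rnn(emb, state)
        return self.dense(outputs), state  # logits: (batch_size, num_steps, vocab_size)


decoder = ToyGRUDecoder(vocab_size=10, embed_size=8, num_hiddens=16)
decoder.eval()

# Stand-in for a real encoder output: (outputs, final hidden state)
fake_enc_outputs = (torch.zeros(2, 5, 16), torch.zeros(1, 2, 16))
state = decoder.init_state(fake_enc_outputs)

token = torch.zeros((2, 1), dtype=torch.long)  # assumed start-token index 0
generated = []
with torch.no_grad():
    for _ in range(3):  # generate three tokens, one per time step
        logits, state = decoder(token, state)  # logits: (2, 1, 10)
        token = logits.argmax(dim=2)  # greedy choice, shape (2, 1)
        generated.append(token)
generated = torch.cat(generated, dim=1)  # (batch_size, 3)
```

Each loop iteration feeds the previously generated token and the current state back into the decoder, which mirrors the description of mapping an input and the encoded state to an output token.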

## Putting the Encoder and Decoder Together

In the end,
the encoder-decoder architecture
contains both an encoder and a decoder,
and optionally extra arguments.
In the forward propagation,
the output of the encoder
is used to produce the encoded state,
and this state
is then used by the decoder as one of its inputs.


In [None]:
#@save
class EncoderDecoder(nn.Module):
    """The base class for the encoder-decoder architecture."""
    def __init__(self, encoder, decoder, **kwargs):
        super(EncoderDecoder, self).__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, enc_X, dec_X, *args):
        enc_outputs = self.encoder(enc_X, *args)
        dec_state = self.decoder.init_state(enc_outputs, *args)
        return self.decoder(dec_X, dec_state)
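
As a rough end-to-end illustration of the forward propagation described above, the sketch below plugs a toy encoder and decoder into the combined class. `ToyEncoder`, `ToyDecoder`, their GRU-based internals, and the tensor sizes are hypothetical choices, not part of the interfaces. The base classes are repeated so that the snippet runs on its own, and the extra argument `enc_valid_len` is passed only to show how `*args` reaches both the encoder and `init_state`; the toy classes ignore it.

```python
import torch
from torch import nn


class Encoder(nn.Module):
    """Base encoder interface (repeated so this sketch runs on its own)."""
    def forward(self, X, *args):
        raise NotImplementedError


class Decoder(nn.Module):
    """Base decoder interface (repeated so this sketch runs on its own)."""
    def init_state(self, enc_outputs, *args):
        raise NotImplementedError

    def forward(self, X, state):
        raise NotImplementedError


class EncoderDecoder(nn.Module):
    """Same wiring as the base class above."""
    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, enc_X, dec_X, *args):
        enc_outputs = self.encoder(enc_X, *args)
        dec_state = self.decoder.init_state(enc_outputs, *args)
        return self.decoder(dec_X, dec_state)


class ToyEncoder(Encoder):
    """Hypothetical encoder: embedding followed by a single-layer GRU."""
    def __init__(self, vocab_size, num_hiddens):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, num_hiddens)
        self.rnn = nn.GRU(num_hiddens, num_hiddens, batch_first=True)

    def forward(self, X, *args):
        # Extra args (e.g., valid lengths) are accepted but ignored in this toy
        return self.rnn(self.embedding(X))  # (outputs, state)


class ToyDecoder(Decoder):
    """Hypothetical decoder: starts from the encoder's final hidden state."""
    def __init__(self, vocab_size, num_hiddens):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, num_hiddens)
        self.rnn = nn.GRU(num_hiddens, num_hiddens, batch_first=True)
        self.dense = nn.Linear(num_hiddens, vocab_size)

    def init_state(self, enc_outputs, *args):
        return enc_outputs[1]  # final hidden state of the encoder

    def forward(self, X, state):
        outputs, state = self.rnn(self.embedding(X), state)
        return self.dense(outputs), state


model = EncoderDecoder(ToyEncoder(10, 16), ToyDecoder(10, 16))
enc_X = torch.zeros((2, 5), dtype=torch.long)  # source batch: 2 sequences, 5 steps
dec_X = torch.zeros((2, 4), dtype=torch.long)  # target batch: 2 sequences, 4 steps
enc_valid_len = torch.tensor([5, 3])           # forwarded through *args
Y, state = model(enc_X, dec_X, enc_valid_len)
# Y: (2, 4, 10) logits over the vocabulary; state: (1, 2, 16)
```

Note that the source and target sequences have different lengths here (5 and 4 steps), which is exactly what the encoder-decoder split is meant to allow.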