# The Encoder-Decoder Architecture



In [1]:
from torch import nn
from d2l import torch as d2l

## Encoder

In [2]:
class Encoder(nn.Module):
    """Abstract encoder half of an encoder-decoder model.

    Concrete encoders subclass this and override ``forward``; the base
    class only fixes the call signature.
    """

    def __init__(self):
        super().__init__()

    def forward(self, X, *args):
        """Encode the input ``X`` (extra positional ``args`` are accepted).

        Raises:
            NotImplementedError: always; subclasses must provide the logic.
        """
        raise NotImplementedError()

## Decoder

In [3]:
class Decoder(nn.Module):
    """Abstract decoder half of an encoder-decoder model.

    Concrete decoders subclass this and override both ``init_state`` and
    ``forward``; the base class only fixes the call signatures.
    """

    def __init__(self):
        super().__init__()

    def init_state(self, enc_all_outputs, *args):
        """Build the decoder's starting state from ``enc_all_outputs``.

        Raises:
            NotImplementedError: always; subclasses must provide the logic.
        """
        raise NotImplementedError()

    def forward(self, X, state):
        """Decode the input ``X`` given the current decoder ``state``.

        Raises:
            NotImplementedError: always; subclasses must provide the logic.
        """
        raise NotImplementedError()

## Putting the Encoder and Decoder Together

In [4]:
class EncoderDecoder(d2l.Classifier):
    """Base model that chains an encoder into a decoder.

    The encoder's outputs seed the decoder's state via
    ``decoder.init_state``; the decoder then consumes ``dec_X`` from that
    state.
    """

    def __init__(self, encoder, decoder):
        super().__init__()
        # Assigned left to right, so the encoder is registered before the
        # decoder, exactly as with two separate assignments.
        self.encoder, self.decoder = encoder, decoder

    def forward(self, enc_X, dec_X, *args):
        """Run ``enc_X`` through the encoder, then ``dec_X`` through the decoder.

        Extra positional ``args`` are passed both to the encoder and to
        ``decoder.init_state``. Only the first element of the decoder's
        result is returned (presumably the outputs, with the updated state
        discarded; verify against concrete decoders).
        """
        state = self.decoder.init_state(self.encoder(enc_X, *args), *args)
        return self.decoder(dec_X, state)[0]