In [8]:
import torch
from torch import nn

In [4]:
# Prefer the first CUDA GPU; fall back to the CPU when CUDA is not available.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

# Encoder Decoder Framework
A transformer model [(paper: Attention is all you need)](https://arxiv.org/pdf/1706.03762.pdf) consists of an encoder and a decoder. Therefore, we first build a general encoder-decoder backbone and test it.

In PyTorch, all models and components should be implemented by subclassing [torch.nn.Module](https://pytorch.org/docs/stable/generated/torch.nn.Module.html). An `nn.Module` subclass should define at least an `__init__` and a `forward` method.

Note that the decoder takes its own inputs in addition to the encoder outputs, as that is how the decoder is designed in the transformer paper.

In [31]:
class EncoderDecoderModel(nn.Module):
    """Generic encoder-decoder wrapper.

    Runs ``encoder`` on the encoder inputs (unless precomputed encoder outputs
    are supplied) and then runs ``decoder`` on the decoder inputs together
    with the encoder outputs.

    Args:
        encoder: Module called as ``encoder(inputs=...)``.
        decoder: Module called as
            ``decoder(inputs=..., encoder_outputs=...)``.
    """

    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(
        self,
        encoder_inputs=None,
        decoder_inputs=None,
        encoder_outputs=None,
    ):
        """Run the encoder (if needed) followed by the decoder.

        Args:
            encoder_inputs: Inputs passed to the encoder. Only used when
                ``encoder_outputs`` is None.
            decoder_inputs: Inputs passed to the decoder.
            encoder_outputs: Optional precomputed encoder outputs. When
                given, the encoder is not called.

        Returns:
            dict with keys ``"encoder_outputs"`` and ``"decoder_outputs"``.

        Raises:
            ValueError: If both ``encoder_inputs`` and ``encoder_outputs``
                are None, since there is nothing to feed the decoder with.
        """
        # Compute encoder outputs if they are not provided
        if encoder_outputs is None:
            if encoder_inputs is None:
                raise ValueError(
                    "Either encoder_inputs or encoder_outputs must be provided."
                )
            encoder_outputs = self.encoder(inputs=encoder_inputs)

        # Compute decoder outputs from the encoder outputs and the
        # decoder's own inputs
        decoder_outputs = self.decoder(
            inputs=decoder_inputs,
            encoder_outputs=encoder_outputs,
        )

        # Return both encoder and decoder outputs
        return {
            "encoder_outputs": encoder_outputs,
            "decoder_outputs": decoder_outputs,
        }

# Now let's build a dummy encoder and a dummy decoder to test the `EncoderDecoderModel`

In [32]:
class DummyEncoder(nn.Module):
    """Minimal stand-in encoder consisting of a single linear layer.

    Args:
        dim_x: ``in_features`` of the linear layer.
        dim_y: ``out_features`` of the linear layer.
    """

    def __init__(self, dim_x, dim_y):
        super().__init__()
        self.l1 = nn.Linear(in_features=dim_x, out_features=dim_y)

    def forward(self, inputs):
        """Apply the linear layer and return the result under 'hidden_states'."""
        return {'hidden_states': self.l1(inputs)}

In [33]:
class DummyDecoder(nn.Module):
    """Minimal stand-in decoder consisting of a single linear layer.

    The decoder inputs are summed with the encoder hidden states before the
    linear layer is applied.

    Args:
        dim_x: ``in_features`` of the linear layer.
        dim_y: ``out_features`` of the linear layer.
    """

    def __init__(self, dim_x, dim_y):
        super().__init__()
        self.l1 = nn.Linear(in_features=dim_x, out_features=dim_y)

    def forward(self, inputs, encoder_outputs):
        """Combine ``inputs`` with the encoder's 'hidden_states' and project.

        Returns a dict holding the projected tensor under 'hidden_states'.
        """
        combined = inputs + encoder_outputs['hidden_states']
        return {'hidden_states': self.l1(combined)}

# Now test the encoder decoder model with random inputs

In [34]:
# Assemble the test model from two 100 -> 100 dummy components.
dummy_encoder = DummyEncoder(dim_x=100, dim_y=100)
dummy_decoder = DummyDecoder(dim_x=100, dim_y=100)
dummy_encoder_decoder = EncoderDecoderModel(dummy_encoder, dummy_decoder)

In [35]:
# Move the model to `device`. This call is the last expression in the cell,
# so the returned module's repr is displayed as the cell output.
dummy_encoder_decoder.to(device)

EncoderDecoderModel(
  (encoder): DummyEncoder(
    (l1): Linear(in_features=100, out_features=100, bias=True)
  )
  (decoder): DummyDecoder(
    (l1): Linear(in_features=100, out_features=100, bias=True)
  )
)

In [36]:
# Random encoder/decoder inputs of shape (3, 100), created directly on `device`.
encoder_inputs = torch.rand(3, 100, device=device)
decoder_inputs = torch.rand(3, 100, device=device)

In [43]:
# Run a forward pass (encoder outputs are computed from `encoder_inputs`)
# and check the shape of the decoder output.
forward_outputs_with_only_enc = dummy_encoder_decoder(
    encoder_inputs=encoder_inputs,
    decoder_inputs=decoder_inputs,
)
assert forward_outputs_with_only_enc['decoder_outputs']['hidden_states'].shape == (3, 100)

In [54]:
# Show the device of the encoder and decoder hidden states, one per line.
print(
    forward_outputs_with_only_enc['encoder_outputs']['hidden_states'].device,
    forward_outputs_with_only_enc['decoder_outputs']['hidden_states'].device,
    sep='\n',
)

cuda:0
cuda:0


In [62]:
# Make sure all model parameters are on the same kind of device as the
# configured `device` (only the device type is compared, not the index).
assert all(
    param.device.type == torch.device(device).type
    for param in dummy_encoder_decoder.parameters()
)