# Sequence-to-sequence RNN with Attention
We will now add attention to our sequence-to-sequence RNN. There are several ways to incorporate the context vector $c$ into the RNN architecture:
1. Add an additional term to the computation of the gates/states, i.e. treat it as an input just like $h_{t-1}$ and $x_t$. This is the approach of the original paper (Bahdanau et al., 2015), described in its Appendix A.
2. Concatenate it with the hidden state of the last time step $h_{t-1}$ and project the concatenation down from `enc_hidden_dim + dec_hidden_dim` to `dec_hidden_dim`.
3. Concatenate it with the input $x_t$ and project the concatenation down.
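The sketch below contrasts the three variants on a single decoder step. It is illustrative only: the toy sizes are arbitrary, variant 1 is realized by widening the `nn.LSTMCell` input (which adds one extra linear term $W_c c$ to every gate), and for variant 3 we assume the concatenation is projected back to `input_dim`, a choice the text above leaves open.

```python
import torch
import torch.nn as nn

# Toy sizes, assumed for illustration only
input_dim, enc_output_dim, dec_hidden_dim = 10, 30, 20

x_t = torch.randn(input_dim)              # decoder input at step t
h_prev = torch.zeros(dec_hidden_dim)      # h_{t-1}
cell_state = torch.zeros(dec_hidden_dim)  # LSTM cell state
context = torch.randn(enc_output_dim)     # attention context vector c

# Variant 1: c enters every gate. Concatenating it to the cell input
# gives each gate an extra linear term W_c c.
cell_1 = nn.LSTMCell(input_dim + enc_output_dim, dec_hidden_dim)
h1, _ = cell_1(torch.cat([x_t, context]), (h_prev, cell_state))

# Variant 2: concatenate c with h_{t-1}, project back to dec_hidden_dim
project_h = nn.Linear(enc_output_dim + dec_hidden_dim, dec_hidden_dim)
cell_2 = nn.LSTMCell(input_dim, dec_hidden_dim)
h2, _ = cell_2(x_t, (project_h(torch.cat([h_prev, context])), cell_state))

# Variant 3 (assumed target size): concatenate c with x_t, project back to input_dim
project_x = nn.Linear(enc_output_dim + input_dim, input_dim)
cell_3 = nn.LSTMCell(input_dim, dec_hidden_dim)
h3, _ = cell_3(project_x(torch.cat([x_t, context])), (h_prev, cell_state))

# All three produce a new decoder hidden state of size [dec_hidden_dim]
print(h1.shape, h2.shape, h3.shape)
```

Only variant 2 touches the recurrent state directly; variants 1 and 3 leave $h_{t-1}$ untouched and change what the cell sees as input.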

We will use variant 2 in this exercise. We'll make our lives a bit easier by implementing a 1-layer decoder and working with a batch size of 1.

In [1]:
import torch
import torch.nn as nn

Since we have to compute a new context vector at every decoding step, we can't use PyTorch's high-level `nn.LSTM` module, which processes the whole sequence in a single call. Instead, we implement a decoder class that steps an `nn.LSTMCell` manually. We start with the `__init__` method, where we define all parameters.
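Stepping a cell by hand loses nothing compared to `nn.LSTM`: the sketch below (arbitrary sizes, unbatched inputs as used throughout this notebook) copies the weights of a 1-layer `nn.LSTM` into an `nn.LSTMCell` and checks that looping over the time steps reproduces the full-sequence output. The marked spot in the loop is where a per-step context vector can be injected.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
input_dim, hidden_dim, seq_len = 10, 20, 8

lstm = nn.LSTM(input_dim, hidden_dim)      # consumes the whole sequence in one call
cell = nn.LSTMCell(input_dim, hidden_dim)  # consumes one time step per call

# Share weights so both modules compute the same function
with torch.no_grad():
    cell.weight_ih.copy_(lstm.weight_ih_l0)
    cell.weight_hh.copy_(lstm.weight_hh_l0)
    cell.bias_ih.copy_(lstm.bias_ih_l0)
    cell.bias_hh.copy_(lstm.bias_hh_l0)

x = torch.randn(seq_len, input_dim)  # unbatched: [seq_len, input_dim]
full_output, _ = lstm(x)             # initial states default to zeros

h = torch.zeros(hidden_dim)
c = torch.zeros(hidden_dim)
step_outputs = []
for x_t in x:
    # <- an attention step could modify h here before the cell update
    h, c = cell(x_t, (h, c))
    step_outputs.append(h)
step_output = torch.stack(step_outputs)  # [seq_len, hidden_dim]

print(torch.allclose(full_output, step_output, atol=1e-5))
```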

In [9]:
class DecoderLSTMWithAttention(nn.Module):
    
    def __init__(self, input_dim, enc_output_dim, dec_hidden_dim):
        super().__init__()
        self.dec_hidden_dim = dec_hidden_dim
        self.enc_output_dim = enc_output_dim
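        self.cell = nn.LSTMCell(input_dim, dec_hidden_dim)
        # Variant 2: maps the concatenation [h_{t-1}; c] back down to dec_hidden_dim
        self.project_down = nn.Linear(enc_output_dim + dec_hidden_dim, dec_hidden_dim)
        # Weight matrix of the multiplicative score h_enc^T W h_dec
        self.W = nn.Linear(dec_hidden_dim, enc_output_dim, bias=False)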

Add a `reset_parameters` method that initializes all parameters.

In [10]:
def reset_parameters(self):
    self.cell.reset_parameters()
    self.project_down.reset_parameters()
    self.W.reset_parameters()

DecoderLSTMWithAttention.reset_parameters = reset_parameters

Add a `forward` method that takes a target sequence `y` and the encoder hidden states `encoder_hidden_states` as input. `encoder_hidden_states` is a tensor of size `[sequence_length, enc_output_dim]`, where `enc_output_dim = num_directions * enc_hidden_dim`. At every step, `forward` should call `compute_context_vector`, which computes the attention-weighted context vector; we will implement it later. `forward` should return the decoder hidden states of all steps, stacked into a tensor of size `[target_length, dec_hidden_dim]`, together with the final `(hidden_state, cell_state)` pair.

In [11]:
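def forward(self, y, encoder_hidden_states):
    # Zero initial decoder states (batch size 1, so no batch dimension)
    hidden_state = encoder_hidden_states.new_zeros(self.dec_hidden_dim)
    cell_state = encoder_hidden_states.new_zeros(self.dec_hidden_dim)
    outputs = []
    for y_i in y:
        # Attend over the encoder states using h_{t-1}
        context_vector = self.compute_context_vector(hidden_state, encoder_hidden_states)
        # Variant 2: concatenate [h_{t-1}; c] and project down to dec_hidden_dim
        projected = self.project_down(torch.cat([hidden_state, context_vector]))
        hidden_state, cell_state = self.cell(y_i, (projected, cell_state))
        outputs.append(hidden_state)
    # [target_length, dec_hidden_dim], plus the final (h, c) pair
    return torch.stack(outputs), (hidden_state, cell_state)

DecoderLSTMWithAttention.forward = forward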

Now it's time to implement the `compute_context_vector` method. Its inputs are `previous_decoder_hidden_state` and `encoder_hidden_states`. Use either additive or multiplicative attention, as introduced in the course. Extend the trainable parameters in your `__init__` method if necessary and initialize them in `reset_parameters`.
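For reference, the two score functions and the resulting context vector can be written as follows. The parameter names $W$, $W_1$, $W_2$, $v$ are our own and may differ from the course notation; $W \in \mathbb{R}^{\text{enc\_output\_dim} \times \text{dec\_hidden\_dim}}$ matches the shape of `self.W` above.

$$
f_{\text{att}}^{\text{mult}}\big(h^{\text{dec}}_{t-1}, h^{\text{enc}}_j\big) = \big(h^{\text{enc}}_j\big)^\top W\, h^{\text{dec}}_{t-1},
\qquad
f_{\text{att}}^{\text{add}}\big(h^{\text{dec}}_{t-1}, h^{\text{enc}}_j\big) = v^\top \tanh\!\big(W_1 h^{\text{dec}}_{t-1} + W_2 h^{\text{enc}}_j\big)
$$

$$
a_j = \frac{\exp f_{\text{att}}\big(h^{\text{dec}}_{t-1}, h^{\text{enc}}_j\big)}{\sum_k \exp f_{\text{att}}\big(h^{\text{dec}}_{t-1}, h^{\text{enc}}_k\big)},
\qquad
c = \sum_j a_j\, h^{\text{enc}}_j
$$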

In [12]:
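def compute_context_vector(self, previous_decoder_hidden_state, encoder_hidden_states):
    # Multiplicative attention: f_att(h_dec, h_enc) = h_enc^T W h_dec.
    # self.W(h_dec) has size [enc_output_dim], so one matrix-vector product
    # scores all encoder states at once -> size [sequence_length]
    scores = encoder_hidden_states @ self.W(previous_decoder_hidden_state)
    attention_weights = torch.softmax(scores, dim=0)
    # Attention-weighted sum of the encoder states -> size [enc_output_dim]
    return attention_weights @ encoder_hidden_states

DecoderLSTMWithAttention.compute_context_vector = compute_context_vector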
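Two details matter when computing the context vector. First, building the scores with `torch.tensor([...])` copies their values and cuts them out of the autograd graph, whereas `torch.stack` keeps gradients flowing into `W`. Second, the weights must scale whole rows of `encoder_hidden_states`; writing `weights * encoder_hidden_states` would broadcast along the wrong axis. The standalone sketch below (arbitrary sizes) checks that a per-state loop and the single matrix-vector product give the same scores, and that `weights @ encoder_hidden_states` equals an explicit row-weighted sum.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
dec_hidden_dim, enc_output_dim, seq_len = 20, 30, 10
W = nn.Linear(dec_hidden_dim, enc_output_dim, bias=False)
h_dec = torch.randn(dec_hidden_dim)
encoder_hidden_states = torch.randn(seq_len, enc_output_dim)

# Loop version: one score h_enc^T W h_dec per encoder state, kept in the graph
loop_scores = torch.stack([h_enc @ W(h_dec) for h_enc in encoder_hidden_states])

# Vectorized version: a single matrix-vector product
vec_scores = encoder_hidden_states @ W(h_dec)

weights = torch.softmax(vec_scores, dim=0)  # [seq_len], sums to 1
context = weights @ encoder_hidden_states   # [enc_output_dim]

# Same result with explicit per-row weighting
context_explicit = (weights.unsqueeze(1) * encoder_hidden_states).sum(dim=0)

print(torch.allclose(loop_scores, vec_scores, atol=1e-5))
print(torch.allclose(context, context_explicit, atol=1e-5))
```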

**Sequence-to-sequence model.** We will use the following hyperparameters.

In [None]:
# Typically, encoder/decoder hidden dimensions are the same,
# but here we choose them differently to test our implementation.
embedding_dim = 10
enc_hidden_dim = 15
dec_hidden_dim = 20
num_layers = 2
bidirectional = True
num_directions = 2 if bidirectional else 1
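To see where `enc_output_dim = num_directions * enc_hidden_dim` comes from, the sketch below runs a bare `nn.LSTM` encoder with these hyperparameters on an unbatched source sequence of length 10. For a bidirectional LSTM, the top layer's forward and backward outputs are concatenated at each time step, and this is the tensor the decoder attends over.

```python
import torch
import torch.nn as nn

# Same hyperparameters as above
embedding_dim, enc_hidden_dim, num_layers = 10, 15, 2
bidirectional = True
num_directions = 2 if bidirectional else 1

encoder = nn.LSTM(embedding_dim, enc_hidden_dim, num_layers, bidirectional=bidirectional)
x = torch.randn(10, embedding_dim)  # unbatched source: [source_length, embedding_dim]
h0 = torch.zeros(num_layers * num_directions, enc_hidden_dim)
c0 = torch.zeros(num_layers * num_directions, enc_hidden_dim)

encoder_hidden_states, (h_n, c_n) = encoder(x, (h0, c0))
print(encoder_hidden_states.shape)  # torch.Size([10, 30]): [source_length, num_directions * enc_hidden_dim]
print(h_n.shape)                    # torch.Size([4, 15]): [num_layers * num_directions, enc_hidden_dim]
```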

Now we define the model.

In [None]:
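class Seq2seqLSTMWithAttention(nn.Module):
    
    def __init__(self, embedding_dim, enc_hidden_dim, num_enc_layers, bidirectional, dec_hidden_dim):
        super().__init__()
        # Multi-layer, optionally bidirectional encoder over the whole source sequence
        self.encoder = nn.LSTM(embedding_dim, enc_hidden_dim, num_enc_layers, bidirectional=bidirectional)
        enc_output_dim = enc_hidden_dim * (2 if bidirectional else 1)
        # 1-layer decoder that attends over all encoder outputs
        self.decoder = DecoderLSTMWithAttention(embedding_dim, enc_output_dim, dec_hidden_dim)

    def forward(self, x, y, h0, c0):
        encoder_hidden_states, _ = self.encoder(x, (h0, c0))  # [source_length, enc_output_dim]
        return self.decoder(y, encoder_hidden_states)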

Try your Module with an example input.

In [None]:
model = Seq2seqLSTMWithAttention(embedding_dim, enc_hidden_dim, num_layers, bidirectional, dec_hidden_dim)
x = torch.randn(10, embedding_dim)
y = torch.randn(8, embedding_dim)
h0 = torch.zeros(num_layers * num_directions, enc_hidden_dim)
c0 = torch.zeros(num_layers * num_directions, enc_hidden_dim)
outputs, _ = model(x, y, h0, c0)
assert list(outputs.shape) == [8, dec_hidden_dim], "Wrong output shape"

Create a subclass of your decoder LSTM that implements the other type of attention (additive or multiplicative) that you haven't implemented above. What do you need to change?

In [None]:
class DecoderLSTMWithAdditiveAttention(DecoderLSTMWithAttention):
    # or: DecoderLSTMWithMultiplicativeAttention
    pass
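Since the base class scores with multiplicative attention, the missing variant is additive (Bahdanau-style) attention. One possible sketch packages the additive score as a small standalone module; the name `AdditiveAttention` and the extra size `attention_dim` are our own hypothetical choices. Under this design, a decoder subclass would create the module in `__init__`, reset it in `reset_parameters`, and return its output from `compute_context_vector`, while `forward` stays unchanged.

```python
import torch
import torch.nn as nn

class AdditiveAttention(nn.Module):
    """Hypothetical helper: f_att(h_dec, h_enc) = v^T tanh(W_dec h_dec + W_enc h_enc)."""

    def __init__(self, dec_hidden_dim, enc_output_dim, attention_dim):
        super().__init__()
        self.W_dec = nn.Linear(dec_hidden_dim, attention_dim, bias=False)
        self.W_enc = nn.Linear(enc_output_dim, attention_dim, bias=False)
        self.v = nn.Linear(attention_dim, 1, bias=False)

    def reset_parameters(self):
        for layer in (self.W_dec, self.W_enc, self.v):
            layer.reset_parameters()

    def forward(self, h_dec, encoder_hidden_states):
        # h_dec: [dec_hidden_dim]; encoder_hidden_states: [seq_len, enc_output_dim]
        # W_dec(h_dec) is broadcast across the seq_len rows of W_enc(...)
        energies = torch.tanh(self.W_dec(h_dec) + self.W_enc(encoder_hidden_states))
        scores = self.v(energies).squeeze(-1)     # [seq_len]
        weights = torch.softmax(scores, dim=0)
        return weights @ encoder_hidden_states    # [enc_output_dim]

# Usage with the sizes of this notebook (enc_output_dim = 2 * 15)
attention = AdditiveAttention(dec_hidden_dim=20, enc_output_dim=30, attention_dim=16)
context = attention(torch.randn(20), torch.randn(10, 30))
print(context.shape)  # torch.Size([30])
```

Compared with the multiplicative score, the additive one needs an extra nonlinearity and an extra parameter vector $v$, but it does not require the decoder and encoder states to be compared through a single bilinear form.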

We can test our implementation with the code below.

In [None]:
enc_output_dim = enc_hidden_dim * num_directions
# Uncomment the version you just implemented
# model.decoder = DecoderLSTMWithAdditiveAttention(embedding_dim, enc_output_dim, dec_hidden_dim)
# model.decoder = DecoderLSTMWithMultiplicativeAttention(embedding_dim, enc_output_dim, dec_hidden_dim)
model.decoder.reset_parameters()
outputs, _ = model(x, y, h0, c0)
assert list(outputs.shape) == [8, dec_hidden_dim], "Wrong output shape"