### This is an [example](https://tomekkorbak.com/2020/06/26/implementing-attention-in-pytorch/) implementation of the attention mechanism in PyTorch.

#### First, the LSTM encoder layer.
![LSTM cell diagram](https://miro.medium.com/max/1400/1*IM5fjlTYrdYD5XAq2dVEvQ.png "LSTM")


In [1]:
import torch
import torch.nn as nn
from torch import optim
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np

In [2]:
class Encoder(nn.Module):
    """Single-layer LSTM encoder that processes one sequence at a time (batch size 1).

    Args:
        input_size: number of features per time step.
        hidden_size: LSTM hidden size per direction.
        bidirectional: run the LSTM in both directions; each output step then
            carries ``2 * hidden_size`` features.
    """

    def __init__(self, input_size, hidden_size, bidirectional=True):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.bidirectional = bidirectional
        self.lstm = nn.LSTM(input_size, hidden_size, bidirectional=bidirectional)

    def forward(self, inputs, hidden):
        """Encode ``inputs`` and return ``(output, (h_n, c_n))`` from the LSTM.

        ``inputs`` is flattened into ``[seq_len, 1, input_size]``, so its total
        number of elements must be a multiple of ``input_size``. ``hidden`` is
        an ``(h_0, c_0)`` pair such as the one produced by ``init_hidden``.
        """
        sequence = inputs.view(-1, 1, self.input_size)
        return self.lstm(sequence, hidden)

    def init_hidden(self):
        """Return zero ``(h_0, c_0)`` tensors of shape ``[num_directions, 1, hidden_size]``."""
        num_directions = 1 + int(self.bidirectional)
        state_shape = (num_directions, 1, self.hidden_size)
        return torch.zeros(state_shape), torch.zeros(state_shape)

### The attention layer


Self-attention: note that the keys and the values are the same tensor. The formula pictured below corresponds to `MultiplicativeAttention`.
<div>
<img src="https://media.geeksforgeeks.org/wp-content/uploads/20200812212119/encoderselfattention.PNG" width="600"/>

</div>

In [3]:
class Attention(torch.nn.Module):
    """Base class for attention over a sequence of encoder outputs.

    The values double as keys (there is no separate key tensor), and no batch
    dimension is used.

    query -- decoder state, nominally shape [decoder_dim];
    values -- encoder outputs, shape [seq_length, encoder_dim].

    Subclasses implement ``_get_weights`` to return unnormalised scores whose
    last axis runs over the sequence.
    """

    def __init__(self, encoder_dim: int, decoder_dim: int):
        super().__init__()
        self.encoder_dim = encoder_dim
        self.decoder_dim = decoder_dim

    def _get_weights(
        self,
        query: torch.Tensor,   # [decoder_dim]
        values: torch.Tensor,  # [seq_length, encoder_dim]
    ) -> torch.Tensor:
        """Return unnormalised attention scores, sequence axis last.

        Raises:
            NotImplementedError: always; concrete subclasses must override this.
        """
        raise NotImplementedError(
            f"{type(self).__name__} must implement _get_weights()")

    def forward(self,
        query: torch.Tensor,  # [decoder_dim]
        values: torch.Tensor, # [seq_length, encoder_dim]
        ) -> torch.Tensor:
        """Return the attention-weighted average of ``values``.

        The result has shape [encoder_dim] for a 1-D query. Any leading axes
        on the scores (e.g. a query of shape [1, decoder_dim]) are kept.
        """
        weights = self._get_weights(query, values)  # [..., seq_length]
        # Normalise over the sequence axis, which is always the last one.
        # dim=0 was only right for 1-D scores: with [1, seq_length] scores it
        # normalises over the size-1 axis, yields all ones, and the weighted
        # average degenerates into a plain sum of the values.
        weights = torch.nn.functional.softmax(weights, dim=-1)
        return weights @ values  # [..., encoder_dim]

    
class AdditiveAttention(Attention):
    """Additive (Bahdanau-style) attention.

    score_i = v . tanh(W_1 @ query + W_2 @ values[i])

    ``query`` should have shape [decoder_dim] (a [1, decoder_dim] query
    broadcasts the same way); ``values`` has shape [seq_length, encoder_dim].
    """

    def __init__(self, encoder_dim, decoder_dim):
        super().__init__(encoder_dim, decoder_dim)
        self.v = torch.nn.Parameter(
            torch.empty(self.decoder_dim).uniform_(-0.1, 0.1))
        self.W_1 = torch.nn.Linear(self.decoder_dim, self.decoder_dim)
        self.W_2 = torch.nn.Linear(self.encoder_dim, self.decoder_dim)

    def _get_weights(self,
        query: torch.Tensor,  # [decoder_dim]
        values: torch.Tensor,  # [seq_length, encoder_dim]
    ):
        """Return unnormalised additive scores of shape [seq_length]."""
        # Broadcasting adds the single projected query to every projected value
        # row. This matches repeating the query seq_length times, but applies
        # W_1 only once. It also drops the per-call debug print and avoids
        # mutating an autograd output in place.
        scores = self.W_1(query) + self.W_2(values)  # [seq_length, decoder_dim]
        return torch.tanh(scores) @ self.v  # [seq_length]
    

class MultiplicativeAttention(Attention):
    """Multiplicative (bilinear) attention.

    score_i = (query @ W @ values[i]) / sqrt(decoder_dim), where W is a learned
    [decoder_dim, encoder_dim] matrix.
    """

    def __init__(self, encoder_dim: int, decoder_dim: int):
        super().__init__(encoder_dim, decoder_dim)
        initial = torch.FloatTensor(self.decoder_dim, self.encoder_dim)
        self.W = torch.nn.Parameter(initial.uniform_(-0.1, 0.1))

    def _get_weights(self,
        query: torch.Tensor,  # [decoder_dim]
        values: torch.Tensor, # [seq_length, encoder_dim]
    ):
        """Return unnormalised bilinear scores, shape [seq_length] for a 1-D query."""
        projected = query @ self.W  # [encoder_dim]
        # NOTE(review): scaled dot-product attention usually divides by the
        # sqrt of the dimension being dotted (encoder_dim here); this uses
        # decoder_dim. The two coincide only when the dims match — confirm intent.
        scale = np.sqrt(self.decoder_dim)
        return projected @ values.T / scale  # [seq_length]

#### Let's review the dimensions

In [4]:
bidirectional = True

input_dim = 5
output_dim = 16  # not used in this cell
hidden_dim = 8
# The encoder emits hidden_dim features per direction, so the attention
# dimension follows from the LSTM configuration instead of being hard-coded.
encoder_dim = hidden_dim * (1 + int(bidirectional))

c = Encoder(input_dim, hidden_dim, bidirectional)
# Call modules directly (not .forward()) so registered hooks run.
a, b = c(torch.randn(10), c.init_hidden())
print(f"{a.shape}--a.shape")
print(f"{b[0].shape}--b[0].shape")
print(f"{b[1].shape}--b[1].shape")

# AdditiveAttention(encoder_dim=encoder_dim, decoder_dim=encoder_dim) is a
# drop-in alternative. The query here is itself an encoder output, so
# decoder_dim must equal encoder_dim.
x = MultiplicativeAttention(encoder_dim=encoder_dim, decoder_dim=encoder_dim)
# a is [seq_len, 1, encoder_dim]; squeeze only the batch axis so a length-1
# sequence still yields a 2-D values matrix.
y = x(query=a[0], values=a.squeeze(1))


print(f"{y.shape}--w.shape")

torch.Size([2, 1, 16])--a.shape
torch.Size([2, 1, 8])--b[0].shape
torch.Size([2, 1, 8])--b[1].shape
torch.Size([1, 16])--w.shape
