In [1]:
import torch 
import torch.nn as nn 
import torch.nn.functional as F

In [None]:
class BahdanauAttention(nn.Module):
    '''
        Additive (Bahdanau-style) attention.

        Every encoder position i is scored against the decoder state s with
            score_i = V(tanh(W1(h_i) + W2(s)))
        and the context vector is the softmax-weighted sum of the encoder outputs.
    '''
    def __init__(self, encoder_hidden_dim, decoder_hidden_dim, attention_dim):
        '''
            encoder_hidden_dim ----> feature size of each encoder output
            decoder_hidden_dim ----> feature size of the decoder hidden state
            attention_dim      ----> size of the shared space both are projected into
        '''
        super(BahdanauAttention, self).__init__()
        self.W1 = nn.Linear(in_features=encoder_hidden_dim, out_features=attention_dim)
        # W2 has to land in the same attention space as W1: the two projections are
        # added in forward(), which fails when their last dimensions differ.
        self.W2 = nn.Linear(in_features=decoder_hidden_dim, out_features=attention_dim)
        self.V = nn.Linear(attention_dim, 1)

    def forward(self, encoder_outputs, decoder_hidden):
        '''
            encoder_outputs ----> shape (batch, input_len, encoder_hidden_dim)
            decoder_hidden  ----> shape (batch, decoder_hidden_dim)

            returns (context_vector, attention_weights)
                context_vector    ----> shape (batch, encoder_hidden_dim)
                attention_weights ----> shape (batch, input_len), each row sums to 1
        '''

        decoder_hidden = decoder_hidden.unsqueeze(1) # decoder_hidden_shape ---> (batch, 1, decoder_hidden_dim)

        # The (batch, 1, attention_dim) decoder projection broadcasts over input_len.
        score = torch.tanh(self.W1(encoder_outputs) + self.W2(decoder_hidden))
        # score shape ---> (batch, input_len, attention_dim)

        energy = self.V(score).squeeze(-1)
        # energy shape ----> (batch, input_len)

        attention_weights = F.softmax(energy, dim=1)
        # attention_weights ----> (batch, input_len)

        # attention_weights.unsqueeze(1) ---> (batch, 1, input_len)
        # encoder_outputs                ---> (batch, input_len, encoder_hidden_dim)
        context = torch.bmm(attention_weights.unsqueeze(1), encoder_outputs) # shape --> (batch, 1, encoder_hidden_dim)

        context_vector = context.squeeze(1) # shape ---> (batch, encoder_hidden_dim)

        return context_vector, attention_weights

In [None]:
class LuongDotAttention(nn.Module):
    '''
        Dot-product (Luong "dot") attention.

        Has no learnable parameters: each encoder position is scored by its
        dot product with the decoder state.
    '''
    def __init__(self):
        super().__init__()

    def forward(self, encoder_outputs, decoder_hidden):
        '''
            encoder_outputs ----> shape (batch, input_len, encoder_hidden_dim)
            decoder_hidden  ----> shape (batch, decoder_hidden_dim)

            The batched matrix product below requires
            encoder_hidden_dim == decoder_hidden_dim.

            returns (context_vector, attention_weights)
                context_vector    ----> shape (batch, encoder_hidden_dim)
                attention_weights ----> shape (batch, input_len)
        '''
        # Treat the decoder state as a single-row query per batch element.
        query = decoder_hidden[:, None, :]                 # (batch, 1, decoder_hidden_dim)
        keys_transposed = encoder_outputs.permute(0, 2, 1)  # (batch, encoder_hidden_dim, input_len)

        alignment = torch.bmm(query, keys_transposed)       # (batch, 1, input_len)

        # Normalise over the input positions (the last axis).
        weights = torch.softmax(alignment, dim=2)           # (batch, 1, input_len)

        # (batch, 1, input_len) x (batch, input_len, encoder_hidden_dim)
        pooled = torch.bmm(weights, encoder_outputs)        # (batch, 1, encoder_hidden_dim)

        # Drop the singleton query axis from both results.
        return pooled.squeeze(1), weights.squeeze(1)


In [None]:
class LuongGeneralAttention(nn.Module):
    '''
        Luong "general" attention.

        Scores each encoder position with  score_i = s . Wa(h_i)  and returns the
        softmax-weighted sum of the ORIGINAL encoder outputs as the context vector.
    '''
    def __init__(self, hidden_dim):
        '''
            hidden_dim ----> feature size shared by encoder outputs and decoder state
        '''
        super(LuongGeneralAttention, self).__init__()
        # Note: the bias of Wa adds the same offset (s . b) to every position's
        # score, so it cancels inside the softmax; it is kept so the parameter
        # names stay the same.
        self.Wa = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, encoder_outputs, decoder_hidden):
        '''
            encoder_outputs ----> shape (batch, input_len, hidden_dim)
            decoder_hidden  ----> shape (batch, hidden_dim)

            returns (context_vector, attention_weights)
                context_vector    ----> shape (batch, hidden_dim)
                attention_weights ----> shape (batch, input_len)
        '''

        # Keep the projection in its own variable: it is only used for scoring.
        # Rebinding encoder_outputs here would make the context a weighted sum of
        # the projected states instead of the encoder states themselves.
        projected_outputs = self.Wa(encoder_outputs)
        # shape --> (batch, input_len, hidden_dim)

        decoder_hidden = decoder_hidden.unsqueeze(1) # shape --> (batch, 1, hidden_dim)

        scores = torch.bmm(decoder_hidden, projected_outputs.transpose(1,2))
        # shape --> (batch, 1, input_len)

        attention_weights = F.softmax(scores.squeeze(1), dim=1)
        # shape --> (batch, input_len)

        context = torch.bmm(attention_weights.unsqueeze(1), encoder_outputs)
        # shape --> (batch, 1, hidden_dim)

        return context.squeeze(1), attention_weights

In [None]:
class LuongConcatAttention(nn.Module):
    '''
        Luong "concat" attention.

        Each encoder output is concatenated with the decoder state, passed
        through Wa and tanh, and reduced to a scalar score by Va.
    '''
    def __init__(self, hidden_dim, attention_dim):
        '''
            hidden_dim    ----> feature size shared by encoder outputs and decoder state
            attention_dim ----> width of the hidden layer that scores each position
        '''
        super().__init__()
        # Input is the concatenation [h_i ; s_t], hence twice the hidden size.
        self.Wa = nn.Linear(2 * hidden_dim, attention_dim)
        self.Va = nn.Linear(attention_dim, 1)

    def forward(self, encoder_outputs, decoder_hidden):
        '''
            encoder_outputs ----> shape (batch, input_len, hidden_dim)
            decoder_hidden  ----> shape (batch, hidden_dim)

            returns (context_vector, attention_weights)
                context_vector    ----> shape (batch, hidden_dim)
                attention_weights ----> shape (batch, input_len)
        '''
        steps = encoder_outputs.size(1)

        # Pair the decoder state with every input position.
        tiled_state = decoder_hidden.unsqueeze(1).expand(-1, steps, -1)
        # (batch, input_len, hidden_dim)

        pairs = torch.cat((encoder_outputs, tiled_state), dim=-1)
        # (batch, input_len, hidden_dim * 2)

        hidden = F.tanh(self.Wa(pairs))
        # (batch, input_len, attention_dim)

        logits = self.Va(hidden).squeeze(-1)
        # (batch, input_len)

        attention_weights = F.softmax(logits, dim=-1)
        # (batch, input_len)

        # (batch, 1, input_len) x (batch, input_len, hidden_dim) -> (batch, 1, hidden_dim)
        pooled = torch.bmm(attention_weights.unsqueeze(1), encoder_outputs)

        return pooled.squeeze(1), attention_weights