In [11]:
import torch
from torch.autograd import Variable
from torch import nn
import torch.nn.functional as F

import torchvision.datasets as dset
import torchvision.transforms as T

class AttentionWrapper(nn.Module):
    def __init__(self, rnn, use_attention):
        super(AttentionWrapper, self).__init__()
        self.rnn_cell = rnn
        self.attention = use_attention
        self.projection_for_decoderRNN = nn.Linear(512, 256, bias=False)
    def forward(self, memory, decoder_input, cell_hidden):
        """
        memory = (batch_size, encoder_T, dim)
        decoder_input = (batch_size, dim)
        cell_hidden (previous time step cell state) = (batch, dim)
        """
        batch_size = memory.size(0)
        #cell_input = torch.cat((decoder_input, prev_attention), -1) -- why do we have to concat?
        cell_input = decoder_input
        query = self.rnn_cell(cell_input, cell_hidden)
        #feed into attention
        attention_weights = self.attention(query, memory)
        #make context vector
        attention_weights = F.softmax(attention_weights, dim=-1)
        context = torch.bmm(attention_weights.view(batch_size, 1, -1), memory).squeeze(1)
        out = self.projection_for_decoderRNN(torch.cat([context, query],dim=-1))
        return out, query, attention_weights


class BahdanauAttention(nn.Module):
    def __init__(self):
        super(BahdanauAttention, self).__init__()
        self.v = nn.Linear(256,1,bias=False)
        self.query_layer = nn.Linear(256,256,bias=False)
        self.tanh = nn.Tanh()
    def forward(self, query, memory):
        """
        query : (batch, 1 ,dim)
        """
        if query.dim() == 2:
            query = query.unsqueeze(1)
        attention_weight = self.v(self.tanh(self.query_layer(query) + memory))
        return attention_weight


def test_attention_wrapper():
    B = 2

    encoder_outputs = Variable(torch.rand(B, 100, 256))

    memory_layer = nn.Linear(256, 256)
    query = Variable(torch.rand(B, 128))

    attention_mechanism = BahdanauAttention()

    # Attention context + input
    rnn = nn.GRUCell(128, 256)

    attention_rnn = AttentionWrapper(rnn, attention_mechanism)
    initial_attention = Variable(torch.zeros(B, 256))
    cell_state = Variable(torch.zeros(B, 256))

    cell_output, attention, alignment = attention_rnn(encoder_outputs,query, cell_state)

    print("Cell output size:", cell_output.size())
    print("Attenn output size:", attention.size())
    print("Alignment size:", alignment.size())

    assert (alignment.sum(-1) == 1).data.all()


test_attention_wrapper()

Cell output size: torch.Size([2, 256])
Attenn output size: torch.Size([2, 256])
Alignment size: torch.Size([2, 100, 1])
