In [68]:
import torch.nn as nn
import torch
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
import torch.nn.functional as F
import pickle
from vocab import Vocab, VocabEntry

In [140]:
class ConcatAttention(nn.Module):
    def __init__(self,encoder_dim,decoder_dim):
        super(ConcatAttention, self).__init__()
        self.encoder_dim = encoder_dim
        self.decoder_dim = decoder_dim
        self.linear = nn.Linear(self.encoder_dim+self.decoder_dim,self.decoder_dim)
        self.W = nn.Linear(self.decoder_dim,1)
        self.relu = nn.ReLU()

    def forward(self, hidden,encoder_outputs):
        batchsize,maxlen,encoderdim = encoder_outputs.size()
        print(hidden.size())
        hidden = hidden.expand(-1,maxlen,-1) #B,1,D => B,L,D
        input = torch.cat([encoder_outputs,hidden],dim=2)
        energy = self.relu(self.linear(input))
        scores = self.W(energy).squeeze(-1) #B x L
        scores = F.softmax(scores,dim=1)
        return scores

In [199]:
class RNNDecoder(nn.Module):

    def __init__(self,vocab,embed_size,context_size,hidden_size,n_layers,dropout,attention):
        super(RNNDecoder, self).__init__()
        self.dropout = dropout
        self.vocab = vocab
        self.embed_size = embed_size
        self.hidden_size = hidden_size
        self.n_layers = n_layers
        self.embed = nn.Embedding(len(self.vocab),self.embed_size)
        self.context_size = context_size
        self.rnn = nn.LSTM(input_size=self.embed_size+self.context_size,hidden_size=self.hidden_size,
                               num_layers=self.n_layers, batch_first=True)
        self.embed_dropout = nn.Dropout(p=self.dropout)
        self.attention = attention
        self.linear = nn.Linear(self.hidden_size+self.context_size,len(vocab))

    def forward(self,input,hidden,encoder_inputs):
        embedding = self.embed(input)
        attn_scores = self.attention(hidden[0].permute(1,0,2), encoder_inputs).unsqueeze(1)  # bs x 1 x maxlen
        context = attn_scores.bmm(encoder_inputs).squeeze(1)
        input = torch.cat([embedding,context],dim=1)
        output,hidden = self.rnn(input.unsqueeze(1),hidden)
        output = torch.cat([output.squeeze(1),context],dim=1)
        output = F.log_softmax(self.linear(output),dim=1)
        return output,hidden,attn_scores.squeeze(1)


In [202]:
decoder_embed_size = 100
decoder_hidden_size = 100
decoder_context_size = 100
vocab = pickle.load(open('data/vocab.bin', 'rb'))
attention = ConcatAttention(100,100)
embed = nn.Embedding(100,100)
decoder = RNNDecoder(vocab.src,decoder_embed_size,
                     decoder_context_size,decoder_hidden_size,1,0.0,attention)
sentences = torch.LongTensor([[1,2,3,4],[4,5,6,7],[3,6,1,2]])
embeddings = embed(sentences)
encoder_outputs = embeddings
hidden = embeddings[:,1,:]
res = attention(hidden.unsqueeze(1),encoder_outputs)

torch.Size([3, 1, 100])


In [203]:
output,hidden,attn_scores = decoder(sentences[:,1],[hidden.unsqueeze(0),hidden.unsqueeze(0)],encoder_outputs)

torch.Size([3, 1, 100])


In [205]:
output,hidden,attn_scores = decoder(sentences[:,1],hidden,encoder_outputs)

torch.Size([3, 1, 100])


In [207]:
hidden[0].size(),hidden[1].size()

(torch.Size([1, 3, 100]), torch.Size([1, 3, 100]))

In [59]:
encoder_outputs, encoder_hidden = encoder(source,sourcelens)
predicted_target = torch.zeros(target.size())
batch_size,max_len = target.size()
for idx in range(0,max_len):
    outputs,hidden,attn_scores = decoder(inputs,hidden,encoder_outputs)
    inputs = target[:,idx]
    predicted_target[:,idx] = outputs
    
loss = criterion(predicted_target.view(batch_size*max_len,len(vocab.trgt))
                 ,targets.view(batch_size*max_len))
    


tensor([1.0000, 1.0000, 1.0000], grad_fn=<SumBackward1>)

tensor([[1, 2, 3],
        [4, 4, 5],
        [6, 7, 3],
        [6, 1, 2]])