For Blog: [Luong Attention Understanding | Yam](https://yam.gift/2020/04/14/Paper/2020-04-14-Luong-Attention/)

In [6]:
import torch
import torch.nn as nn
from torch import optim
import torch.nn.functional as F
from torch.autograd import Variable

In [None]:
cd ~/Documents/Study/DL-Models/pytorch-batch-luong-attention/

In [None]:
from utils.embeddings import create_embedding_maps
from utils.batches import batches, data_from_batch
from models.luong_attention import luong_attention
from torch.autograd import Variable
from utils.tokens import Tokens

In [101]:
import importlib
def load_data(dataset_module, train_dir, debug_restrict_data):
    """Load the train/validation splits through a dataset-specific loader module.

    The loader is imported by name from under ``utils.load_and_preprocessing``
    and must expose ``load_data(train_dir, debug_restrict_data)`` returning
    exactly two values.

    Args:
        dataset_module: Name of the loader submodule (e.g. ``"translation"``).
        train_dir: Forwarded unchanged to the loader's ``load_data``.
        debug_restrict_data: Forwarded unchanged to the loader's ``load_data``.

    Returns:
        A ``(train, val)`` tuple built from the two values the loader returned.
    """
    loader = importlib.import_module(
        "utils.load_and_preprocessing.{}".format(dataset_module)
    )
    splits = loader.load_data(train_dir, debug_restrict_data)
    # Unpack explicitly so a loader that yields anything but two items fails here.
    train_split, val_split = splits
    return train_split, val_split

In [104]:
train, val = load_data("translation", "data/translation", None)

100%|██████████| 10835/10835 [00:00<00:00, 970679.74it/s]
100%|██████████| 10/10 [00:00<00:00, 48044.72it/s]


In [107]:
train.shape

(10835, 2)

In [108]:
val.shape

(10, 2)

In [794]:
train.head()

Unnamed: 0,source,target
0,je vais dormir .,i am going to bed .
1,je suis presque prete .,i am almost ready .
2,tu es encore un bleu .,you re still green .
3,c est toi qui m as entraine .,you re the one who trained me .
4,on apprend encore a se connaitre .,we re still getting to know each other .


In [808]:
encoder_embedding_map, \
decoder_embedding_map, \
encoder_embedding_matrix, \
decoder_embedding_matrix = create_embedding_maps(train, val, 64, different_vocab=True)

  0%|          | 0/2 [00:00<?, ?it/s]
100%|██████████| 10835/10835 [00:00<00:00, 111409.53it/s]
 50%|█████     | 1/2 [00:00<00:00,  9.12it/s]
100%|██████████| 10/10 [00:00<00:00, 45003.26it/s]
100%|██████████| 2/2 [00:00<00:00, 15.54it/s]
  0%|          | 0/2 [00:00<?, ?it/s]
  0%|          | 0/10835 [00:00<?, ?it/s][A
100%|██████████| 10835/10835 [00:00<00:00, 86121.99it/s][A
 50%|█████     | 1/2 [00:00<00:00,  7.59it/s]
100%|██████████| 10/10 [00:00<00:00, 36503.95it/s]
100%|██████████| 2/2 [00:00<00:00, 13.93it/s]


In [809]:
encoder_embedding_map.n_words

4490

In [810]:
encoder_embedding_matrix.weight.shape

torch.Size([4490, 64])

In [811]:
decoder_embedding_matrix.weight.shape

torch.Size([2927, 64])

In [812]:
decoder = luong_attention.LuongAttnDecoderRNN("general", 32, 64, 
                                              decoder_embedding_map.n_words, 
                                              2, 0.1, decoder_embedding_matrix, False)

In [813]:
decoder

LuongAttnDecoderRNN(
  (embedding): Embedding(2927, 64)
  (gru): GRU(64, 32, num_layers=2, dropout=0.1)
  (concat): Linear(in_features=64, out_features=32, bias=True)
  (out): Linear(in_features=32, out_features=2927, bias=True)
  (attn): Attn(
    (attn): Linear(in_features=32, out_features=32, bias=True)
  )
)

In [823]:
encoder = luong_attention.EncoderRNN(32, 64, 2, 0.1, encoder_embedding_matrix, "GRU", False)

In [824]:
encoder

EncoderRNN(
  (embedding): Embedding(4490, 64)
  (rnn): GRU(64, 32, num_layers=2, dropout=0.1, bidirectional=True)
)

In [825]:
# Take only the first mini-batch (batch_size=4, CPU tensors) from `batches`;
# the `break` leaves it bound to the notebook-level name `batch` for the cells below.
for batch in batches(train, encoder_embedding_map, decoder_embedding_map, use_cuda=False, batch_size=4):
    batch  # bare expression inside a loop body: evaluated and discarded, nothing is displayed
    break

In [826]:
batch

{'source_var': tensor([[  11,   48,    4,   32],
         [  49, 2573,   54,   17],
         [1186,  637, 3453, 1159],
         [  52,  747,    7,    7],
         [ 164,    7,    3,    3],
         [1187,    3,    0,    0],
         [ 128,    0,    0,    0],
         [1188,    0,    0,    0],
         [   7,    0,    0,    0],
         [   3,    0,    0,    0]]),
 'source_lengths': [10, 6, 5, 5],
 'target_var': tensor([[  12,   21,    4,   31],
         [  13,   13,   46,   54],
         [  44, 1844, 1961, 1938],
         [  29, 1845,    9,    9],
         [ 931,    9,    3,    3],
         [   9,    3,    0,    0],
         [   3,    0,    0,    0]]),
 'target_lengths': [7, 6, 5, 5]}

In [827]:
source_var, source_lengths, target_var, target_lengths = data_from_batch(batch)
current_batch_size = len(target_lengths)

In [828]:
encoder_outputs, encoder_hidden = encoder(source_var, encoder.init_hidden(current_batch_size), source_lengths)

In [829]:
encoder_outputs.shape

torch.Size([10, 4, 32])

In [830]:
encoder_hidden.shape

torch.Size([2, 4, 32])

In [831]:
decoder_input = Variable(torch.LongTensor([Tokens.SOS_token] * current_batch_size))
decoder_hidden = encoder_hidden

In [833]:
max_target_length = max(target_lengths)
all_decoder_outputs = Variable(torch.zeros(max_target_length, current_batch_size, decoder.output_size))
max_target_length

7

In [834]:
all_decoder_outputs.shape

torch.Size([7, 4, 2927])

In [835]:
decoder_hidden.shape

torch.Size([2, 4, 32])

In [837]:
encoder_outputs.shape

torch.Size([10, 4, 32])

In [838]:
# Run the decoder one time step at a time, up to the longest target in the batch.
# NOTE: `decoder_input` and `decoder_hidden` are overwritten on every step, so
# re-running this cell without re-initialising them continues from the last state.
for t in range(max_target_length):
    decoder_output, decoder_hidden, attn_weights = decoder(decoder_input, decoder_hidden, encoder_outputs)
    # Store this step's decoder output in row t of the pre-allocated buffer.
    all_decoder_outputs[t] = decoder_output
    # Teacher forcing: the next input is the ground-truth token at step t,
    # not the token the decoder itself scored highest.
    decoder_input = target_var[t]

In [918]:
all_decoder_outputs.shape

torch.Size([7, 4, 2927])

In [919]:
decoder_output.shape

torch.Size([4, 2927])

In [920]:
decoder_hidden.shape

torch.Size([2, 4, 32])

In [921]:
attn_weights.shape

torch.Size([4, 1, 10])

In [889]:
from utils.masked_cross_entropy import masked_cross_entropy, sequence_mask

In [856]:
masked_cross_entropy(all_decoder_outputs.transpose(0, 1).contiguous(),
                                        target_var.transpose(0, 1).contiguous(),
                                        target_lengths)

  log_probs_flat = F.log_softmax(logits_flat)


tensor(7.9472, grad_fn=<DivBackward0>)

In [862]:
target = target_var.transpose(0, 1).contiguous()

In [870]:
logits = all_decoder_outputs.transpose(0, 1).contiguous()
logits.shape

torch.Size([4, 7, 2927])

In [871]:
logits_flat = logits.view(-1, logits.size(-1))
logits_flat.shape

torch.Size([28, 2927])

In [874]:
log_probs_flat = F.log_softmax(logits_flat, dim=1)

In [877]:
target_flat = target.view(-1, 1)
target_flat.shape

torch.Size([28, 1])

In [878]:
# losses_flat: (batch * max_len, 1)
losses_flat = -torch.gather(log_probs_flat, dim=1, index=target_flat)

In [885]:
target.size()

torch.Size([4, 7])

In [886]:
losses = losses_flat.view(*target.size())

In [888]:
losses.shape

torch.Size([4, 7])

In [899]:
print(Variable(torch.LongTensor(target_lengths)), target.size(1))

tensor([7, 6, 5, 5]) 7


In [890]:
mask = sequence_mask(sequence_length=Variable(torch.LongTensor(target_lengths)), max_len=target.size(1))

In [906]:
import numpy as np
sequence_length=Variable(torch.LongTensor(target_lengths))

In [907]:
sequence_length

tensor([7, 6, 5, 5])

In [910]:
batch_size = sequence_length.size(0)
batch_size

4

In [911]:
# `max_len` is not assigned anywhere earlier in the notebook, so this cell raised
# NameError on a fresh kernel. Derive it from the padded target width, the same
# value that is passed as `max_len=` to `sequence_mask` above.
max_len = target.size(1)
# Position indices 0 .. max_len - 1, one per target time step.
seq_range = torch.arange(0, max_len).long()
seq_range

tensor([0, 1, 2, 3, 4, 5, 6])

In [912]:
seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
seq_range_expand

tensor([[0, 1, 2, 3, 4, 5, 6],
        [0, 1, 2, 3, 4, 5, 6],
        [0, 1, 2, 3, 4, 5, 6],
        [0, 1, 2, 3, 4, 5, 6]])

In [913]:
seq_range_expand = Variable(seq_range_expand)
seq_range_expand

tensor([[0, 1, 2, 3, 4, 5, 6],
        [0, 1, 2, 3, 4, 5, 6],
        [0, 1, 2, 3, 4, 5, 6],
        [0, 1, 2, 3, 4, 5, 6]])

In [914]:
seq_length_expand = sequence_length.unsqueeze(1).expand_as(seq_range_expand)
seq_length_expand

tensor([[7, 7, 7, 7, 7, 7, 7],
        [6, 6, 6, 6, 6, 6, 6],
        [5, 5, 5, 5, 5, 5, 5],
        [5, 5, 5, 5, 5, 5, 5]])

In [915]:
mask = seq_range_expand < seq_length_expand
mask

tensor([[ True,  True,  True,  True,  True,  True,  True],
        [ True,  True,  True,  True,  True,  True, False],
        [ True,  True,  True,  True,  True, False, False],
        [ True,  True,  True,  True,  True, False, False]])

In [917]:
losses

tensor([[7.8367, 7.9381, 7.8862, 7.8695, 8.0942, 7.8369, 8.1076],
        [8.0719, 7.9407, 7.8208, 7.8800, 7.8822, 8.0856, 8.3650],
        [7.9334, 7.9262, 7.9541, 7.7630, 8.1345, 8.3291, 8.3044],
        [7.7619, 7.9629, 8.0988, 7.8717, 8.1288, 8.3198, 8.2923]],
       grad_fn=<ViewBackward>)

In [895]:
target_var

tensor([[  12,   21,    4,   31],
        [  13,   13,   46,   54],
        [  44, 1844, 1961, 1938],
        [  29, 1845,    9,    9],
        [ 931,    9,    3,    3],
        [   9,    3,    0,    0],
        [   3,    0,    0,    0]])

In [892]:
target.size(1)

7

In [935]:
decoder_output.data.shape

torch.Size([4, 2927])

In [927]:
topv, topi = decoder_output.data.topk(1)

In [928]:
topv

tensor([[0.5208],
        [0.4515],
        [0.4341],
        [0.5078]])

In [938]:
topi

tensor([[ 260],
        [2120],
        [2853],
        [2853]])

In [936]:
topi.squeeze().detach().shape

torch.Size([4])

In [952]:
topi[0].item()

260

In [961]:
[index2word(topi[i].item()) for i in range(batch_size)]

['agent', 'ideal', 'startled', 'startled']

In [960]:
def index2word(index: int):
    """Translate a decoder vocabulary index into its word.

    The lookup goes through ``decoder_embedding_map.index2word``, where
    ``decoder_embedding_map`` is the notebook-level map returned by
    ``create_embedding_maps``.
    """
    vocabulary = decoder_embedding_map.index2word
    word = vocabulary[index]
    return word

## PyTorch

In [10]:
class AttnDecoderRNN(nn.Module):
    """Single-step GRU decoder with learned, fixed-length attention.

    Attention weights are produced by a linear layer over the concatenation of
    the current token embedding and the previous hidden state, and that layer
    has ``max_length`` outputs. Consequently ``encoder_outputs`` must have
    exactly ``max_length`` rows, and the decoder handles one token at a time
    (the embedding is flattened with ``view(1, 1, -1)``, so batch size is 1).

    Args:
        hidden_size: Size of the embedding, the GRU state and the encoder outputs.
        output_size: Vocabulary size (embedding rows and output logits).
        dropout_p: Dropout probability applied to the token embedding.
        max_length: Number of encoder positions the attention layer scores.
        verbose: When true (default, the original behaviour) the shape of every
            intermediate tensor is printed on each forward pass.
    """

    def __init__(self, hidden_size, output_size, dropout_p=0.1, max_length=9, verbose=True):
        super(AttnDecoderRNN, self).__init__()
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.dropout_p = dropout_p
        self.max_length = max_length
        self.verbose = verbose

        self.embedding = nn.Embedding(self.output_size, self.hidden_size)
        self.attn = nn.Linear(self.hidden_size * 2, self.max_length)
        self.attn_combine = nn.Linear(self.hidden_size * 2, self.hidden_size)
        self.dropout = nn.Dropout(self.dropout_p)
        self.gru = nn.GRU(self.hidden_size, self.hidden_size)
        self.out = nn.Linear(self.hidden_size, self.output_size)

    def _trace(self, *parts):
        """Print a shape-tracing line unless tracing has been switched off."""
        if self.verbose:
            print(*parts)

    def forward(self, input, hidden, encoder_outputs):
        """Decode one token.

        Args:
            input: Index tensor holding a single token.
            hidden: Previous hidden state; only ``hidden[0]`` (the first layer)
                is used, both for attention and as the GRU's initial state.
            encoder_outputs: Tensor of shape ``(max_length, hidden_size)``.

        Returns:
            ``(output, hidden, attn_weights)`` where ``output`` holds
            log-probabilities of shape ``(1, output_size)``, ``hidden`` is the new
            GRU state of shape ``(1, 1, hidden_size)`` and ``attn_weights`` has
            shape ``(1, max_length)``.
        """
        # Each layer is evaluated once; the traces only read shapes of tensors
        # that are needed anyway.
        token_embedding = self.embedding(input)
        self._trace("self embedding shape: ", token_embedding.shape)
        embedded = token_embedding.view(1, 1, -1)
        self._trace("embedded shape: ", embedded.shape)
        embedded = self.dropout(embedded)

        # Score every encoder position from [embedding ; previous hidden state].
        attn_weights = F.softmax(self.attn(torch.cat((embedded[0], hidden[0]), 1)), dim=1)
        self._trace("attn_weights shape: ", attn_weights.shape)

        # Weighted sum of the encoder outputs: (1, 1, L) x (1, L, H) -> (1, 1, H).
        batched_weights = attn_weights.unsqueeze(0)
        attn_applied = torch.bmm(batched_weights, encoder_outputs.unsqueeze(0))
        self._trace("attn_weights.unsqueeze(0)", batched_weights.shape)
        self._trace("attn_applied", attn_applied.shape)

        output = torch.cat((embedded[0], attn_applied[0]), 1)
        combined = self.attn_combine(output)
        self._trace("self.attn_combine(output)", combined.shape)
        output = combined.unsqueeze(0)

        output = F.relu(output)
        self._trace("hidden shape", hidden.shape)
        # The GRU here has a single layer, so keep only the first layer's state.
        hidden = hidden[0].unsqueeze(0)
        self._trace("hidden shape", hidden.shape)
        self._trace("output shape", output.shape)
        output, hidden = self.gru(output, hidden)
        self._trace("output, hidden", output.shape, hidden.shape)

        output = F.log_softmax(self.out(output[0]), dim=1)
        self._trace("output", output.shape)
        return output, hidden, attn_weights

    def initHidden(self):
        """Return an all-zero initial hidden state of shape ``(1, 1, hidden_size)``.

        The tensor is created on the device that holds this module's parameters,
        so the method does not depend on a notebook-level ``device`` variable
        having been defined first.
        """
        return torch.zeros(1, 1, self.hidden_size, device=self.embedding.weight.device)

In [11]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cpu')

In [None]:
attn_decoder = AttnDecoderRNN(32, decoder_embedding_map.n_words, dropout_p=0.1).to(device)

In [685]:
decoder_input = Variable(torch.LongTensor([Tokens.SOS_token] * current_batch_size))
decoder_hidden = encoder_hidden

In [686]:
embedded = nn.Embedding(2927, 32)(decoder_input).view(1, 1, -1)

In [687]:
encoder_outputs.squeeze(1).shape

torch.Size([8, 4, 32])

In [None]:
decoder_output2, decoder_hidden2, attn_weights2 = attn_decoder(decoder_input, 
                                                               decoder_hidden, 
                                                               encoder_outputs.squeeze(1)
                                                              )

## Modified Luong Attention Decoder

In [13]:
class LuongAttnDecoder(nn.Module):
    """One decoding step of a batched GRU decoder with Luong-style attention.

    Attention scores are the dot product between a linear projection of the GRU
    output and every encoder output; the softmax-weighted context is
    concatenated with the GRU output, passed through a ``tanh`` layer and
    projected onto the vocabulary.

    Args:
        hidden_size: GRU state size; the encoder outputs must have the same
            feature size because they are multiplied with the projected state.
        input_size: Embedding dimension fed to the GRU.
        output_size: Vocabulary size (embedding rows and output logits).
        n_layers: Number of stacked GRU layers.
        dropout: Dropout probability between GRU layers.
        verbose: When true (default, the original behaviour) the shape of every
            intermediate tensor is printed on each forward pass.
    """

    def __init__(self, hidden_size, input_size, output_size,  n_layers, dropout, verbose=True):
        super(LuongAttnDecoder, self).__init__()
        self.hidden_size = hidden_size
        self.input_size = input_size
        self.output_size = output_size
        self.n_layers = n_layers
        self.dropout = dropout
        self.verbose = verbose

        self.embedding = nn.Embedding(self.output_size, self.input_size)
        self.attn = nn.Linear(self.hidden_size, self.hidden_size)
        self.gru = nn.GRU(self.input_size, self.hidden_size, n_layers, dropout=self.dropout)
        self.concat = nn.Linear(self.hidden_size * 2, self.hidden_size)
        self.out = nn.Linear(self.hidden_size, self.output_size)

    def _trace(self, *parts):
        """Print a shape-tracing line unless tracing has been switched off."""
        if self.verbose:
            print(*parts)

    def forward(self, input, hidden, encoder_outputs):
        """Decode one time step for the whole batch.

        Args:
            input: Token indices of shape ``(batch,)``.
            hidden: Previous GRU state of shape ``(n_layers, batch, hidden_size)``.
            encoder_outputs: Tensor of shape ``(src_len, batch, hidden_size)``.

        Returns:
            ``(output, hidden, attn_weights)`` where ``output`` holds unnormalised
            vocabulary scores of shape ``(batch, output_size)``, ``hidden`` is the
            new GRU state and ``attn_weights`` has shape ``(batch, 1, src_len)``.
        """
        batch_size = input.size(0)
        self._trace(input.shape)
        # Embed once; the trace only reads the shape of the tensor used below.
        token_embedding = self.embedding(input)
        self._trace(token_embedding.shape)
        embedded = token_embedding.view(1, batch_size, self.input_size)
        self._trace(embedded.shape)
        output, hidden = self.gru(embedded, hidden)
        self._trace(output.shape, hidden.shape)

        # (batch, 1, H) x (batch, H, src_len) -> (batch, 1, src_len)
        projected_state = self.attn(output).transpose(1, 0)
        encoder_features = encoder_outputs.permute(1, 2, 0)
        # NOTE: no length mask is applied, so every source position (including
        # padded ones, if the batch has any) receives some attention weight.
        attn_weights = F.softmax(torch.bmm(projected_state, encoder_features), dim=2)
        self._trace(attn_weights.shape)

        # (batch, 1, src_len) x (batch, src_len, H) -> (batch, 1, H)
        context = attn_weights.bmm(encoder_outputs.transpose(0, 1))
        self._trace(context.shape)
        output = output.squeeze(0)
        context = context.squeeze(1)
        self._trace(output.shape, context.shape)
        concat_input = torch.cat((output, context), 1)
        self._trace(concat_input.shape)
        # torch.tanh replaces the deprecated nn.functional.tanh.
        concat_output = torch.tanh(self.concat(concat_input))
        self._trace(concat_output.shape)
        output = self.out(concat_output)
        return output, hidden, attn_weights

In [14]:
latt_decoder = LuongAttnDecoder(32, 64, 2927, 2, 0.1)

In [None]:
decoder_input = Variable(torch.LongTensor([Tokens.SOS_token] * current_batch_size))
decoder_hidden = encoder_hidden

In [788]:
encoder_outputs.shape

torch.Size([8, 4, 32])

In [789]:
decoder_output3, decoder_hidden3, attn_weights3 = latt_decoder(
    decoder_input, decoder_hidden, encoder_outputs)

torch.Size([4])
torch.Size([4, 64])
torch.Size([1, 4, 64])
torch.Size([1, 4, 32]) torch.Size([2, 4, 32])
torch.Size([4, 1, 8])
torch.Size([4, 1, 32])
torch.Size([4, 32]) torch.Size([4, 32])
torch.Size([4, 64])
torch.Size([4, 32])




In [790]:
decoder_output3.shape

torch.Size([4, 2927])

In [791]:
decoder_hidden3.shape

torch.Size([2, 4, 32])

In [792]:
attn_weights3.shape

torch.Size([4, 1, 8])

In [729]:
encoder_outputs.shape

torch.Size([8, 4, 32])

In [728]:
encoder_outputs.permute(1, 2, 0).shape

torch.Size([4, 32, 8])

## Playground

In [91]:
# Minimal LSTM shape check: input_size=10, hidden_size=20, one layer.
rnn = nn.LSTM(10, 20)
# With the default batch_first=False the input is (seq_len=1, batch=3, input_size=10).
input = torch.randn(1, 3, 10)  # NOTE: shadows the built-in input() for the rest of the notebook
# Initial hidden and cell states, each (num_layers=1, batch=3, hidden_size=20).
h0 = torch.randn(1, 3, 20)
c0 = torch.rand(1, 3, 20)  # drawn with rand (uniform) whereas h0 uses randn (normal)
output, (hn, cn) = rnn(input, (h0, c0))

In [962]:
rnn = nn.GRU(10, 20)
input = torch.randn(1, 3, 10)
h0 = torch.randn(1, 3, 20)
output, hn = rnn(input, h0)

In [965]:
hn.shape

torch.Size([1, 3, 20])

In [966]:
output.shape

torch.Size([1, 3, 20])

In [612]:
from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence

In [641]:
embedding = nn.Embedding(2927, 32)

In [644]:
source_var

tensor([[  4,  32],
        [180, 685],
        [ 38,  93],
        [ 36, 183],
        [490, 200],
        [ 59,   7],
        [491,   3],
        [  7,   0],
        [  3,   0]])

In [645]:
embedded = embedding(source_var)

In [655]:
embedded.shape

torch.Size([9, 2, 32])

In [657]:
embedded.view(1, 2, -1).shape

torch.Size([1, 2, 288])

In [647]:
source_var.shape

torch.Size([9, 2])

In [650]:
source_lengths

[9, 7]

In [755]:
gru = nn.GRU(64, 32, 2, dropout=0.2)

In [774]:
input = torch.randn(8, 1, 10)
mat2 = torch.randn(8, 10, 32)
res = torch.bmm(input, mat2)
res.size()

torch.Size([8, 1, 32])