In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

In [2]:
SEQ_LEN = 10
BATCH_SIZE = 3
INPUT_DIM = 30
OUTPUT_DIM = 37
ENC_EMB_DIM = DEC_EMB_DIM = 32
ENC_HID_DIM = DEC_HID_DIM = 64
ENC_DROPOUT = DEC_DROPOUT = 0.5
SEED = 0

# Seed torch's global RNG so the toy batches (and the randomly initialised
# layers / dropout masks in later cells) are reproducible on Restart & Run All.
torch.manual_seed(SEED)


def make_token_batch(seq_len, batch_size, vocab_size):
    """Build a random (seq_len, batch_size) tensor of token ids for a toy vocabulary.

    Row 0 is filled with id 0 and the last row with ``vocab_size - 1``
    (presumably start-/end-of-sequence markers); every other position is drawn
    uniformly from the remaining ids ``1 .. vocab_size - 2``.

    Args:
        seq_len: number of time steps (rows).
        batch_size: number of sequences (columns).
        vocab_size: size of the vocabulary; ids range over ``0 .. vocab_size - 1``.

    Returns:
        int64 tensor of shape (seq_len, batch_size).

    Raises:
        ValueError: if ``vocab_size < 3`` (no id left for regular tokens).
    """
    if vocab_size < 3:
        raise ValueError(
            f"vocab_size must be >= 3 to leave room for regular tokens, got {vocab_size}")
    # torch.randint's upper bound is exclusive: `vocab_size - 1` keeps the
    # reserved last id out of the body while still allowing id `vocab_size - 2`.
    tokens = torch.randint(1, vocab_size - 1, size=(seq_len, batch_size))
    tokens[0, :] = 0
    tokens[-1, :] = vocab_size - 1
    return tokens


x = make_token_batch(SEQ_LEN, BATCH_SIZE, INPUT_DIM)
y = make_token_batch(SEQ_LEN, BATCH_SIZE, OUTPUT_DIM)

print(x, x.shape, end='\n\n')
print(y, y.shape)

tensor([[ 0,  0,  0],
        [27, 27, 13],
        [ 7, 16, 15],
        [ 9, 13, 26],
        [10,  8,  8],
        [18,  7, 10],
        [11, 21, 12],
        [23, 13,  9],
        [27, 17,  7],
        [29, 29, 29]]) torch.Size([10, 3])

tensor([[ 0,  0,  0],
        [27, 17, 12],
        [16,  1, 34],
        [ 8, 31,  3],
        [20, 11, 17],
        [28, 27, 12],
        [ 4, 19, 20],
        [14, 13, 34],
        [30, 18, 24],
        [36, 36, 36]]) torch.Size([10, 3])


In [3]:
class Encoder(nn.Module):
    """Embedding + bidirectional GRU encoder with a projected final state.

    Args:
        input_dim: source vocabulary size (rows of the embedding table).
        emb_dim: embedding width fed into the GRU.
        enc_hid_dim: hidden size of each GRU direction.
        dec_hid_dim: width of the projected state returned as ``hidden``.
        dropout: dropout probability applied to the embeddings.

    ``forward(src)`` returns ``(outputs, hidden)``:
        outputs: the GRU's per-step outputs, both directions concatenated on the
            last axis (width ``2 * enc_hid_dim``).
        hidden: the GRU's last two final-state slices (forward and backward
            direction of its single layer) concatenated, passed through ``fc``
            and squashed with tanh, shape (batch, dec_hid_dim).

    The GRU is not ``batch_first``, so ``src`` is expected as (seq_len, batch)
    token ids — matches how the notebook calls it.
    """

    def __init__(self, input_dim, emb_dim, enc_hid_dim, dec_hid_dim, dropout):
        super().__init__()
        # Registration order is kept as-is (embedding, rnn, fc, dropout) so the
        # module's children / state_dict ordering is unchanged.
        self.embedding = nn.Embedding(num_embeddings=input_dim, embedding_dim=emb_dim)
        self.rnn = nn.GRU(input_size=emb_dim, hidden_size=enc_hid_dim, bidirectional=True)
        self.fc = nn.Linear(in_features=2 * enc_hid_dim, out_features=dec_hid_dim)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, src):
        emb = self.dropout(self.embedding(src))
        outputs, h_n = self.rnn(emb)
        # Iterating the last two slices yields (h_n[-2], h_n[-1]): the forward
        # and backward final states, joined feature-wise.
        final_states = torch.cat(tuple(h_n[-2:]), dim=1)
        return outputs, torch.tanh(self.fc(final_states))

In [4]:
# NOTE: no .eval() call, so the module stays in training mode and its dropout
# is active — forward passes below are stochastic.
encoder = Encoder(
    input_dim=INPUT_DIM,
    emb_dim=ENC_EMB_DIM,
    enc_hid_dim=ENC_HID_DIM,
    dec_hid_dim=DEC_HID_DIM,
    dropout=ENC_DROPOUT,
)

In [5]:
# Encode the whole source batch: per-step outputs (src_len, batch, 2 * ENC_HID_DIM)
# and the projected state (batch, DEC_HID_DIM), per the shapes printed below.
encoder_outputs, hidden = encoder(x)

In [6]:
# Encoder results: per-step outputs and the projected final state.
encoder_outputs.size(), hidden.size()

(torch.Size([10, 3, 128]), torch.Size([3, 64]))

In [7]:
# First decoder input: row 0 of the target batch (all zeros, set in the data cell),
# kept 2-D as (1, batch) because the GRU here is sequence-first.
trg = y[:1, :]
trg.size()

torch.Size([1, 3])

In [8]:
# Decoder building blocks, created as loose layers so each step can be inspected.
# The GRU consumes the target embedding plus the attention context; the
# bidirectional encoder output is 2 * ENC_HID_DIM wide.
embedding = nn.Embedding(num_embeddings=OUTPUT_DIM, embedding_dim=DEC_EMB_DIM)
rnn = nn.GRU(input_size=2 * ENC_HID_DIM + DEC_EMB_DIM, hidden_size=DEC_HID_DIM)
# Output projection is sized for context + decoder state + embedding
# (presumably concatenated in a later step).
fc_out = nn.Linear(
    in_features=2 * ENC_HID_DIM + DEC_HID_DIM + DEC_EMB_DIM,
    out_features=OUTPUT_DIM,
)
dropout = nn.Dropout(p=DEC_DROPOUT)

In [9]:
# Embed the (1, batch) decoder input and apply decoder dropout -> (1, batch, DEC_EMB_DIM).
embedded = dropout(embedding(trg))

In [10]:
# (1, batch, DEC_EMB_DIM)
embedded.size()

torch.Size([1, 3, 32])

In [11]:
# Recap before building attention: the two tensors the attention layer combines.
(encoder_outputs.size(), hidden.size())

(torch.Size([10, 3, 128]), torch.Size([3, 64]))

In [12]:
# Additive attention: score = v^T tanh(W [s; h]), where s is the repeated
# decoder state and h each encoder output (see the concatenation below).
attn = nn.Linear(in_features=2 * ENC_HID_DIM + DEC_HID_DIM, out_features=DEC_HID_DIM)
v = nn.Linear(in_features=DEC_HID_DIM, out_features=1, bias=False)

In [13]:
# encoder_outputs is still sequence-first here: (src_len, batch, 2 * ENC_HID_DIM).
src_len, batch_size = encoder_outputs.shape[:2]  # 10, 3

In [14]:
# Copy the decoder state once per source position:
# (batch, DEC_HID_DIM) -> (batch, src_len, DEC_HID_DIM).
# NOTE(review): this rebinds `hidden`, so the cell is not idempotent — re-running
# it operates on the already-repeated tensor and raises. Re-run from the encoder cell.
hidden = hidden[:, None, :].repeat(1, src_len, 1)
hidden.size()

torch.Size([3, 10, 64])

In [15]:
# Batch-first layout to line up with `hidden`:
# (src_len, batch, 2 * ENC_HID_DIM) -> (batch, src_len, 2 * ENC_HID_DIM).
# NOTE(review): rebinds `encoder_outputs`; running this cell twice swaps the axes back.
encoder_outputs = encoder_outputs.transpose(0, 1)
encoder_outputs.size()

torch.Size([3, 10, 128])

In [16]:
# Pair the repeated decoder state with every encoder output along the feature axis.
attn_input = torch.cat([hidden, encoder_outputs], dim=-1)
attn_input.size()

torch.Size([3, 10, 192])

In [17]:
# Hidden layer of the alignment model: tanh(W [s; h]) -> (batch, src_len, DEC_HID_DIM).
energy = attn(attn_input).tanh()
energy.size()

torch.Size([3, 10, 64])

In [18]:
# Collapse each energy vector to one unnormalised score per source position:
# (batch, src_len, 1) -> (batch, src_len).
attention = v(energy).squeeze(-1)
attention

tensor([[ 0.0189,  0.0935,  0.0608,  0.2113,  0.1403,  0.1453,  0.0755,  0.1548,
          0.1265,  0.0429],
        [-0.0150,  0.0411,  0.0435,  0.0629,  0.0530,  0.0198,  0.0844,  0.1897,
          0.0761,  0.1103],
        [ 0.0310,  0.0026,  0.1806,  0.2035,  0.0789,  0.1099,  0.0044,  0.1412,
          0.0234,  0.0902]], grad_fn=<SqueezeBackward1>)

In [19]:
# Normalise the scores over source positions so each row sums to 1.
annotation = attention.softmax(dim=1)
annotation

tensor([[0.0914, 0.0985, 0.0953, 0.1108, 0.1032, 0.1037, 0.0967, 0.1047, 0.1018,
         0.0937],
        [0.0920, 0.0973, 0.0976, 0.0995, 0.0985, 0.0953, 0.1017, 0.1129, 0.1008,
         0.1043],
        [0.0944, 0.0917, 0.1096, 0.1121, 0.0990, 0.1021, 0.0919, 0.1054, 0.0937,
         0.1001]], grad_fn=<SoftmaxBackward>)

In [20]:
# Softmax written out by hand, to show what F.softmax computes:
#   a_ij = exp(s_ij - max_k s_ik) / sum_j exp(s_ij - max_k s_ik)
# Subtracting the row max leaves the result unchanged mathematically but keeps
# exp() from overflowing to inf (and the ratio from turning into nan) for large scores.
shifted = attention - attention.max(dim=1, keepdim=True).values
exp_scores = torch.exp(shifted)
a = exp_scores / exp_scores.sum(dim=1, keepdim=True)

# The hand-rolled version must agree with the library softmax from the previous cell.
assert torch.allclose(a, annotation)

a

tensor([[0.0914, 0.0985, 0.0953, 0.1108, 0.1032, 0.1037, 0.0967, 0.1047, 0.1018,
         0.0937],
        [0.0920, 0.0973, 0.0976, 0.0995, 0.0985, 0.0953, 0.1017, 0.1129, 0.1008,
         0.1043],
        [0.0944, 0.0917, 0.1096, 0.1121, 0.0990, 0.1021, 0.0919, 0.1054, 0.0937,
         0.1001]], grad_fn=<DivBackward0>)

In [21]:
# Add a singleton query axis for bmm: (batch, src_len) -> (batch, 1, src_len).
# NOTE(review): rebinds `a`; re-running this cell adds yet another axis.
a = a[:, None, :]
a.size()

torch.Size([3, 1, 10])

In [22]:
# bmm partner for `a`: (batch, src_len, 2 * ENC_HID_DIM).
encoder_outputs.size()

torch.Size([3, 10, 128])

In [23]:
# Attention-weighted sum of encoder outputs (the context vector).
# Batched matmul: (b, n, m) x (b, m, p) ==>> (b, n, p), here (3, 1, 10) x (3, 10, 128).
weighted = a.bmm(encoder_outputs)
weighted.size()

torch.Size([3, 1, 128])

In [24]:
# Back to sequence-first to sit next to `embedded`:
# (batch, 1, 2 * ENC_HID_DIM) -> (1, batch, 2 * ENC_HID_DIM).
# NOTE(review): rebinds `weighted`; running this cell twice undoes the swap.
weighted = weighted.transpose(0, 1)
weighted.size()

torch.Size([1, 3, 128])

In [25]:
# Decoder GRU input: embedded token and context vector side by side
# (DEC_EMB_DIM + 2 * ENC_HID_DIM = 32 + 128 = 160 features).
rnn_input = torch.cat([embedded, weighted], dim=-1)
rnn_input.size()

torch.Size([1, 3, 160])