In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

In [3]:
SEQ_LEN = 10
BATCH_SIZE = 3
INPUT_DIM = 30
OUTPUT_DIM = 37
ENC_EMB_DIM = DEC_EMB_DIM = 32
ENC_HID_DIM = DEC_HID_DIM = 64
ENC_DROPOUT = DEC_DROPOUT = 0.5

x = torch.randint(0+1, INPUT_DIM-2, size=(SEQ_LEN, BATCH_SIZE))
x[0, :] = 0 
x[-1, :] = INPUT_DIM - 1

y = torch.randint(0+1, OUTPUT_DIM-2, size=(SEQ_LEN, BATCH_SIZE))
y[0, :] = 0
y[-1, :] = OUTPUT_DIM - 1

print(x, x.shape, end='\n\n')
print(y, y.shape)

tensor([[ 0,  0,  0],
        [23, 20, 16],
        [ 3, 23, 11],
        [21,  9, 18],
        [27,  9, 11],
        [19, 18,  1],
        [ 7,  8, 11],
        [15,  9,  5],
        [ 8, 11, 26],
        [29, 29, 29]]) torch.Size([10, 3])

tensor([[ 0,  0,  0],
        [25, 27, 22],
        [32, 16, 19],
        [20, 28, 32],
        [27, 15,  3],
        [18, 20, 22],
        [ 7,  6,  5],
        [28, 33, 21],
        [ 8,  7, 32],
        [36, 36, 36]]) torch.Size([10, 3])


In [4]:
class Encoder(nn.Module):
    
    def __init__(self, input_dim, emb_dim, hid_dim, dropout):
        super(Encoder, self).__init__()
        self.embedding = nn.Embedding(input_dim, emb_dim)
        self.rnn = nn.GRU(emb_dim, hid_dim)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, src):
        embedded = self.dropout(self.embedding(src))
        outputs, hidden = self.rnn(embedded)
        return hidden

In [7]:
encoder = Encoder(INPUT_DIM, ENC_EMB_DIM, ENC_HID_DIM, ENC_DROPOUT)

In [11]:
context = encoder(x)

In [12]:
context.shape

torch.Size([1, 3, 64])

In [13]:
hidden = context

In [15]:
y.shape

torch.Size([10, 3])

In [18]:
embedding = nn.Embedding(OUTPUT_DIM, DEC_EMB_DIM)
rnn = nn.GRU(DEC_EMB_DIM+DEC_HID_DIM, DEC_HID_DIM)
fc_out = nn.Linear(DEC_EMB_DIM+DEC_HID_DIM*2, OUTPUT_DIM)
dropout = nn.Dropout(DEC_DROPOUT)

In [31]:
trg = y[0, :]
trg = trg.unsqueeze(0)
trg.shape

torch.Size([1, 3])

In [32]:
embedded = embedding(trg)

In [33]:
embedded.shape

torch.Size([1, 3, 32])

In [35]:
context.shape

torch.Size([1, 3, 64])

In [36]:
emb_con = torch.cat((embedded, context), dim=2)
emb_con.shape

torch.Size([1, 3, 96])

In [42]:
output_, hidden_ = rnn(emb_con, hidden) # f(y_{t-1}, context, h_{t-1})

In [43]:
output_.shape, hidden_.shape

(torch.Size([1, 3, 64]), torch.Size([1, 3, 64]))

In [44]:
embedded.shape, hidden.shape, context.shape

(torch.Size([1, 3, 32]), torch.Size([1, 3, 64]), torch.Size([1, 3, 64]))

In [45]:
output_ = torch.cat(
    (embedded.squeeze(0), 
     hidden.squeeze(0), 
     context.squeeze(0)),
    dim=1)

In [46]:
output_.shape

torch.Size([3, 160])

In [49]:
fc_out(output_).shape

torch.Size([3, 37])

In [51]:
output_.argmax(dim=1)

tensor([21, 21, 21])