In [1]:
import torch
import torch.nn as nn

In [5]:
batch_size = 2
seq_len = 10
input_dim = 200
emb_dim = 5
enc_hid_dim = 3
dec_hid_dim = 3
dropout = 0.3

In [60]:
x = torch.randint(1, input_dim-1, (seq_len, batch_size))
x[0] = torch.LongTensor([0]*2)
x[-1] = torch.LongTensor([input_dim-1]*2)
x

tensor([[  0,   0],
        [ 51,  63],
        [133,  67],
        [ 10, 118],
        [154,  84],
        [138,  12],
        [ 84, 139],
        [183, 159],
        [ 52,  58],
        [199, 199]])

In [61]:
embedding = nn.Embedding(input_dim, emb_dim)

In [62]:
embedding.weight.data.shape

torch.Size([200, 5])

In [64]:
embedded = embedding(x)

In [66]:
embedded.shape # seq_len, batch_size, emb_dim

torch.Size([10, 2, 5])

In [67]:
# num_layers: default is 1
# bias: default is True
# batch_first: default is False
#     if True, input and output tensor's shape is
#         (batch, seq, feature)
# dropout: default is 0
# bidirectional: default is False
rnn = nn.GRU(emb_dim, enc_hid_dim)

In [71]:
w_ih, w_hh, b_ih, b_hh = list(rnn.parameters())

In [74]:
w_ih.shape, w_hh.shape, b_ih.shape, b_hh.shape
# Since bidirectional is False, num_directions is 1.
# Since rnn's hidden unit is GRU, gate_size is 3.
# w_ih's shape: (9, 5) == (gate_size * hidden_size, input_size)
# w_hh's shape: (9, 3) == (gate_size * hidden_size, hidden_size)
# b_ih's shape: (9) == (gate_size * hidden_size)
# b_hh's shape: (9) == (gate_size * hidden_size)

(torch.Size([9, 5]), torch.Size([9, 3]), torch.Size([9]), torch.Size([9]))

In [287]:
num_layers = 1
num_directions = 1
max_batch_size = embedded.size(1)

In [289]:
hx = torch.zeros(num_layers*num_directions,
                 max_batch_size, enc_hid_dim)

In [291]:
hx.shape

torch.Size([1, 2, 3])

In [270]:
embedded.shape

torch.Size([10, 2, 5])

In [300]:
r_t = torch.sigmoid(
    embedded.matmul(w_ih[:3,:].T.contiguous()) + b_ih[:3] + 
    hx.matmul(w_hh[:3,:].T.contiguous()) + b_hh[:3])

In [304]:
z_t = torch.sigmoid(
    embedded.matmul(w_ih[3:6,:].T.contiguous()) + b_ih[3:6] + 
    hx.matmul(w_hh[3:6,:].T.contiguous()) + b_hh[3:6])

In [327]:
n_t = torch.tanh(
    embedded.matmul(w_ih[6:,:].T.contiguous()) + b_ih[6:] + 
    r_t * (hx.matmul(w_hh[6:,:].T.contiguous()) + b_hh[6:]))

In [332]:
h_t = (1 - z_t) * n_t + z_t * hx

In [337]:
outputs, hiddens = rnn(embedded)

In [338]:
outputs.shape, hiddens.shape

(torch.Size([10, 2, 3]), torch.Size([1, 2, 3]))

In [340]:
h_t

tensor([[[ 0.1079, -0.1235,  0.0699],
         [ 0.1079, -0.1235,  0.0699]],

        [[ 0.1702,  0.8597,  0.5208],
         [-0.1957,  0.5138,  0.0871]],

        [[-0.3517,  0.5044,  0.0150],
         [ 0.1909, -0.0404,  0.4272]],

        [[ 0.4123,  0.2468,  0.2104],
         [-0.1673,  0.6032, -0.0051]],

        [[ 0.1030, -0.2435,  0.0642],
         [ 0.0549,  0.6815,  0.3285]],

        [[-0.2560,  0.4002, -0.1629],
         [ 0.0359,  0.3880,  0.2751]],

        [[ 0.0549,  0.6815,  0.3285],
         [-0.3142,  0.6174, -0.1427]],

        [[ 0.3864,  0.1922,  0.2020],
         [-0.3375,  0.1038,  0.1796]],

        [[ 0.1717, -0.2850, -0.2530],
         [ 0.5966,  0.6466, -0.0836]],

        [[-0.2410,  0.7671, -0.0040],
         [-0.2410,  0.7671, -0.0040]]], grad_fn=<AddBackward0>)

In [341]:
hiddens

tensor([[[-0.1882,  0.7398, -0.0412],
         [-0.2194,  0.8933, -0.1362]]], grad_fn=<StackBackward>)

In [277]:
w_ih[:3,:].permute(1,0)

tensor([[ 0.0171,  0.0963,  0.5388],
        [ 0.1237, -0.3646, -0.0688],
        [ 0.5130,  0.2453,  0.3575],
        [-0.4128, -0.5138,  0.5604],
        [ 0.3955,  0.5427,  0.3718]], grad_fn=<PermuteBackward>)

In [278]:
w_ih[:3,:].T

tensor([[ 0.0171,  0.0963,  0.5388],
        [ 0.1237, -0.3646, -0.0688],
        [ 0.5130,  0.2453,  0.3575],
        [-0.4128, -0.5138,  0.5604],
        [ 0.3955,  0.5427,  0.3718]], grad_fn=<PermuteBackward>)

In [240]:
w_ih[:3,:]

tensor([[ 0.0171,  0.1237,  0.5130, -0.4128,  0.3955],
        [ 0.0963, -0.3646,  0.2453, -0.5138,  0.5427],
        [ 0.5388, -0.0688,  0.3575,  0.5604,  0.3718]],
       grad_fn=<SliceBackward>)