In [206]:
import torch

"""
This experiment shows that the correct split of the hidden-layer is the following:

outputs, hidden = rnn.forward(inputs, hidden)

Batch-First == True 
outputs = outputs.view(bs,      seqlen,   ndirects, hidden_size)
hidden  =  hidden.view(nlayers, ndirects, bs,       hidden_size)

Batch-First == False 
outputs = outputs.view(seqlen,  bs,       ndirects, hidden_size)
hidden  =  hidden.view(nlayers, ndirects, bs,       hidden_size)

"""

'\nThis experiment shows that the correct split of the hidden-layer is the following:\n\noutputs, hidden = rnn.forward(inputs, hidden)\n\nBatch-First == True \noutputs = outputs.view(bs,      seqlen,   ndirects, hidden_size)\nhidden  =  hidden.view(nlayers, ndirects, bs,       hidden_size)\n\nBatch-First == False \noutputs = outputs.view(seqlen,  bs,       ndirects, hidden_size)\nhidden  =  hidden.view(nlayers, ndirects, bs,       hidden_size)\n\n'

In [207]:
# Dimensions shared by the GRU experiments in the cells that follow.
bs = 5             # batch size
seqlen = 10        # number of time steps per sequence
dim = 20           # GRU input feature size
hidden_size = 40   # GRU hidden size (per direction)
nlayers = 3        # number of stacked GRU layers
ndirects = 2       # number of directions; 2 makes the GRU bidirectional

In [208]:
"""
verify view-split is correct for batch-first == False

A GRU is run on a (seqlen, bs, dim) input. Its outputs and final hidden state
are split with .view(), and for each direction the final hidden state of the
last layer is asserted to equal the output at the time step where that
direction finishes (last step for forward, first step for backward).
"""
# Reference the class through `torch`: only `import torch` is in scope, so a
# bare `nn.GRU` raises NameError on a fresh kernel.
gru = torch.nn.GRU(dim, hidden_size, num_layers=nlayers, bidirectional=(ndirects==2), batch_first=False)

input = torch.ones(seqlen, bs, dim)
hidden = torch.ones(ndirects*nlayers, bs, hidden_size)

# Call the module itself instead of .forward() so registered hooks are run.
encoder_outputs, gru_hidden = gru(input, hidden)
print("encoder_outputs.shape: {}".format(encoder_outputs.shape))
print("gru_hidden.shape:      {}".format(gru_hidden.shape))
encoder_outputs = encoder_outputs.view(seqlen, bs, ndirects, hidden_size)
gru_hidden      = gru_hidden.view(nlayers, ndirects, bs, hidden_size)
print("encoder_outputs.shape: {}".format(encoder_outputs.shape))
print("gru_hidden.shape:      {}".format(gru_hidden.shape))

# Indices are derived from the configured sizes rather than hard-coded, so the
# check stays valid when seqlen / nlayers are changed.
last_step = seqlen - 1
last_layer = nlayers - 1

# assume that the forward direction is pos 0
one = encoder_outputs[:,:,0]
two = gru_hidden[:, 0]
# the forward direction finishes at the last position of the sequence
one = one[last_step]
# the outputs come from the last layer
two = two[last_layer]
assert torch.equal(one, two), "forward direction: last-step output != last-layer hidden state"

# The backward direction only exists for a bidirectional GRU.
if ndirects == 2:
    # assume that the backward direction is pos 1
    one = encoder_outputs[:,:,1]
    two = gru_hidden[:, 1]
    # the backward direction finishes at position 0 of the sequence
    one = one[0]
    # the outputs come from the last layer
    two = two[last_layer]
    assert torch.equal(one, two), "backward direction: first-step output != last-layer hidden state"

encoder_outputs.shape: torch.Size([10, 5, 80])
gru_hidden.shape:      torch.Size([6, 5, 40])
encoder_outputs.shape: torch.Size([10, 5, 2, 40])
gru_hidden.shape:      torch.Size([3, 2, 5, 40])


In [209]:
"""
verify view-split is correct for batch-first == True

A GRU is run on a (bs, seqlen, dim) input. Its outputs and final hidden state
are split with .view(), and for each direction the final hidden state of the
last layer is asserted to equal the output at the time step where that
direction finishes (last step for forward, first step for backward).
"""
# Reference the class through `torch`: only `import torch` is in scope, so a
# bare `nn.GRU` raises NameError on a fresh kernel.
gru = torch.nn.GRU(dim, hidden_size, num_layers=nlayers, bidirectional=(ndirects==2), batch_first=True)

input = torch.ones(bs, seqlen, dim)
hidden = torch.ones(ndirects*nlayers, bs, hidden_size)

# Call the module itself instead of .forward() so registered hooks are run.
encoder_outputs, gru_hidden = gru(input, hidden)
print("encoder_outputs.shape: {}".format(encoder_outputs.shape))
print("gru_hidden.shape:      {}".format(gru_hidden.shape))

encoder_outputs = encoder_outputs.view(bs, seqlen, ndirects, hidden_size)
gru_hidden      = gru_hidden.view(nlayers, ndirects, bs, hidden_size)
print("encoder_outputs.shape: {}".format(encoder_outputs.shape))
print("gru_hidden.shape:      {}".format(gru_hidden.shape))

# Indices are derived from the configured sizes rather than hard-coded, so the
# check stays valid when seqlen / nlayers are changed.
last_step = seqlen - 1
last_layer = nlayers - 1

# assume that the forward direction is pos 0
one = encoder_outputs[:,:,0]
two = gru_hidden[:, 0]
# the forward direction finishes at the last position of the sequence (dim 1 here)
one = one[:,last_step]
# the outputs come from the last layer
two = two[last_layer]
assert torch.equal(one, two), "forward direction: last-step output != last-layer hidden state"


# The backward direction only exists for a bidirectional GRU.
if ndirects == 2:
    # assume that the backward direction is pos 1
    one = encoder_outputs[:,:,1]
    two = gru_hidden[:, 1]
    # the backward direction finishes at position 0 of the sequence (dim 1 here)
    one = one[:,0]
    # the outputs come from the last layer
    two = two[last_layer]
    assert torch.equal(one, two), "backward direction: first-step output != last-layer hidden state"

encoder_outputs.shape: torch.Size([5, 10, 80])
gru_hidden.shape:      torch.Size([6, 5, 40])
encoder_outputs.shape: torch.Size([5, 10, 2, 40])
gru_hidden.shape:      torch.Size([3, 2, 5, 40])
