In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
  device = torch.device("cuda")
else:
  device = torch.device("cpu")
print(device)

cuda


In [70]:
class EncoderRNN(nn.Module):
  """LSTM sequence encoder that merges the two directions by summation.

  Args:
    emb_size: Embedding width used when a fresh embedding table is created.
    hidden_size: Hidden width of each LSTM direction.
    num_layers: Number of stacked LSTM layers.
    bidirectional: Whether the LSTM reads the sequence in both directions.
    vocab_size: Vocabulary size used when a fresh embedding table is created.
    text_embedding_vectors: Optional 2-D float tensor of pretrained embeddings.
      When given, it is loaded frozen and `vocab_size` is not used.
    dropout: Dropout probability applied between stacked LSTM layers.
  """

  def __init__(self, emb_size, hidden_size, num_layers, bidirectional, vocab_size, text_embedding_vectors, dropout=0):
    super(EncoderRNN, self).__init__()
    self.hidden_size = hidden_size
    self.num_layers = num_layers
    self.bidirectional = bidirectional
    # Identity test: `== None` on a tensor/array is a comparison, not a None check.
    if text_embedding_vectors is None:
      self.embedding = nn.Embedding(vocab_size, emb_size)
    else:
      self.embedding = nn.Embedding.from_pretrained(
          embeddings=text_embedding_vectors, freeze=True)
    # Take the input width from the embedding table itself so that pretrained
    # vectors whose width differs from `emb_size` still fit the LSTM.
    self.lstm = nn.LSTM(self.embedding.embedding_dim, hidden_size, num_layers,
                        batch_first=True, bidirectional=bidirectional,
                        dropout=dropout)
    # Not used by forward(); kept so existing checkpoints keep the same keys.
    self.thought = nn.Linear(hidden_size*2, hidden_size)


  def forward(self, input_seq, hidden=None):
    """Encode a batch of token-id sequences.

    Args:
      input_seq: Integer tensor of token ids, [batch, max_length].
      hidden: Optional initial LSTM state `(h0, c0)`, each of shape
        [num_layers * num_directions, batch, hidden]; zeros when None.

    Returns:
      outputs: [batch, max_length, hidden]; forward and backward features are
        summed when the encoder is bidirectional.
      encoder_hidden: [num_layers, batch, hidden]; final hidden state per
        layer, with the two directions summed when bidirectional.
    """
    embedded = self.embedding(input_seq) #[batch, max_length, emb_size]
    # outputs: [batch, max_length, hidden * num_directions]
    # hn:      [num_layers * num_directions, batch, hidden]
    outputs, (hn, _) = self.lstm(embedded, hidden)

    if self.bidirectional:
      outputs = outputs[:, :, :self.hidden_size] + outputs[:, :, self.hidden_size:] #[batch, max_length, hidden]
      # hn is ordered layer-major, direction-minor, so splitting the leading
      # dim pairs each layer's forward and backward states for any num_layers.
      hn = hn.reshape(self.num_layers, 2, hn.size(1), self.hidden_size)
      encoder_hidden = hn[:, 0] + hn[:, 1]
    else:
      # Single direction: nothing to merge.
      encoder_hidden = hn

    return outputs, encoder_hidden

In [71]:
# Smoke test: 2-layer bidirectional encoder, 300-d embeddings, 600 hidden
# units, a 2000-word vocabulary and no pretrained vectors.
en = EncoderRNN(
    emb_size=300,
    hidden_size=600,
    num_layers=2,
    bidirectional=True,
    vocab_size=2000,
    text_embedding_vectors=None,
    dropout=0,
)
# Batch of 128 random sequences of length 20 with token ids in [0, 1000).
outputs, encoder_hidden = en(torch.randint(low=0, high=1000, size=(128, 20)))

In [65]:
# NOTE(review): `a` is not defined in any visible cell, so this fails on a fresh
# kernel. The recorded shape matches `encoder_hidden` from the cell that runs
# `en(...)` — presumably `a` was a tuple of per-layer hidden states; confirm.
torch.stack(a, 0).shape

torch.Size([2, 128, 600])

In [22]:
# Stand-alone experiment: 2-layer bidirectional, batch-first LSTM mapping
# 300 input features to 600 hidden units per direction.
lstm = nn.LSTM(input_size=300, hidden_size=600, num_layers=2,
               batch_first=True, bidirectional=True)
input = torch.randn(128, 20, 300)  # (batch, seq_len, features)
# Initial states: (num_layers * num_directions, batch, hidden) = (4, 128, 600).
h0, c0 = torch.randn(4, 128, 600), torch.randn(4, 128, 600)
output, (hn, cn) = lstm(input, hx=(h0, c0))

In [4]:
# Batch-first LSTM output; the last dim is 1200 = 2 directions * 600 hidden units.
output.shape

torch.Size([128, 20, 1200])

In [23]:
# Final hidden state; the leading dim is 4 = 2 layers * 2 directions.
hn.shape

torch.Size([4, 128, 600])

In [24]:
# Split the flat leading dim (4) into (num_layers=2, num_directions=2);
# the resulting shape is (num_layers, num_directions, batch, hidden).
hn2 = hn.view(2, 2, 128, 600)

In [29]:
# Sanity check of the view: flat index 3 maps to [1, 1] after the split
# (3 == 1 * 2 + 1). Compared here for batch element 0.
hn[3, 0, :] == hn2[1, 1, 0, :]

tensor([True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, Tr

In [None]:
# Inspect the first entry along dim 0 of the final hidden state
# (presumably layer 0, forward direction — confirm against the nn.LSTM docs).
hn[0]

In [18]:
# First 600 features of the output for batch element 0 at time step 0
# (presumably the forward-direction half — confirm against the nn.LSTM docs).
output[0, 0, :600].shape

torch.Size([600])

In [17]:
# `hn` is 3-D (4, 128, 600), so four indices would raise IndexError. Index the
# 4-D (num_layers, num_directions, batch, hidden) view `hn2` instead.
hn2[1, 0, 0, :].shape

torch.Size([600])

In [21]:
# At time step 0, the backward half of the output (features 600:) should equal
# the last layer's final backward hidden state. `hn` is 3-D, so the 4-index
# lookup has to go through the 4-D view `hn2`.
output[0, 0, 600:] == hn2[1, 1, 0, :]

tensor([True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, Tr

In [28]:
# bool is a subclass of int in Python, so False multiplies as 0 (and True as 1).
2*False

0