In [1]:
import torch
import torch.nn as nn

In [7]:
# Reproducibility: nn.LSTM's weight initialisation and every torch.randn call
# below draw from torch's global RNG, so seed it at the top of the cell. This
# also makes the cell produce the same numbers each time it is re-run.
SEED = 1
torch.manual_seed(SEED)

INPUT_SIZE = 3   # features per time step
HIDDEN_SIZE = 3  # size of the hidden/cell state (and of each output step)
SEQ_LEN = 5      # number of time steps in the toy sequence

lstm = nn.LSTM(INPUT_SIZE, HIDDEN_SIZE)
# A sequence of SEQ_LEN steps, each a (1, INPUT_SIZE) tensor.
step_inputs = [torch.randn(1, INPUT_SIZE) for _ in range(SEQ_LEN)]

# Initial (h_0, c_0), each shaped (num_layers=1, batch=1, HIDDEN_SIZE).
hidden = (torch.randn(1, 1, HIDDEN_SIZE),
          torch.randn(1, 1, HIDDEN_SIZE))
for step in step_inputs:
    # Step through the sequence one element at a time, reshaping each step to
    # (seq_len=1, batch=1, INPUT_SIZE). After each step, `hidden` holds the
    # latest (h, c) pair and is fed back in for the next step.
    out, hidden = lstm(step.view(1, 1, -1), hidden)

# Alternatively, run the entire sequence in a single call.
# The first value returned by the LSTM ("out") holds the hidden state h_t for
# every time step; the second ("hidden") is the final (h_n, c_n) pair.
# The last slice of "out" equals hidden[0] (h_n); hidden[1] is the cell state
# c_n, which "out" does not contain (checked by the assertion at the end).
# "out" gives access to all hidden states in the sequence, while "hidden" lets
# you continue the sequence (and backpropagate through it) by passing it back
# into the lstm at a later time.
# Stack the steps into (seq_len, batch=1, INPUT_SIZE): nn.LSTM defaults to
# batch_first=False, so the batch dimension is the 2nd one.
inputs = torch.cat(step_inputs).view(len(step_inputs), 1, -1)
print(inputs.shape)
# Fresh *random* initial state (not zeros), so this run does not continue from
# the step-by-step loop above.
hidden = (torch.randn(1, 1, HIDDEN_SIZE), torch.randn(1, 1, HIDDEN_SIZE))
print(hidden[0].shape)
out, hidden = lstm(inputs, hidden)
print(out)
print(hidden)

# Sanity checks on the claims in the comments above.
assert out.shape == (SEQ_LEN, 1, HIDDEN_SIZE)
assert torch.allclose(out[-1:], hidden[0])  # last output step == h_n

torch.Size([5, 1, 3])
torch.Size([1, 1, 3])
tensor([[[ 0.0287,  0.0790,  0.2252]],

        [[ 0.0983, -0.0231,  0.2040]],

        [[ 0.1608, -0.2256,  0.2451]],

        [[ 0.1576, -0.1451,  0.1595]],

        [[ 0.1995, -0.1398,  0.1382]]], grad_fn=<StackBackward0>)
(tensor([[[ 0.1995, -0.1398,  0.1382]]], grad_fn=<StackBackward0>), tensor([[[ 0.3620, -0.4810,  0.4798]]], grad_fn=<StackBackward0>))
