In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# Fix the global RNG seed so the random tensors created below are reproducible.
SEED = 1
torch.manual_seed(SEED)

<torch._C.Generator at 0x7f7aebfd64b0>

In [2]:
# LSTM taking 3 features per time step and producing a 3-dimensional hidden state.
lstm = nn.LSTM(
    input_size=3,
    hidden_size=3,
)
# A sequence of 5 time steps, each stored as its own (1, 3) random tensor.
inputs = list(torch.randn(1, 3) for _ in range(5))

In [3]:
# Show the generated sequence: a Python list of five (1, 3) tensors.
inputs

[tensor([[-0.5525,  0.6355, -0.3968]]),
 tensor([[-0.6571, -1.6428,  0.9803]]),
 tensor([[-0.0421, -0.8206,  0.3133]]),
 tensor([[-1.1352,  0.3773, -0.2824]]),
 tensor([[-2.5667, -1.4303,  0.5009]])]

In [4]:
# Initial LSTM state as an (h_0, c_0) pair, each of shape (1, 1, 3).
# All-ones tensors are used here; torch.randn(1, 1, 3) would give a random start instead.
hidden = tuple(torch.ones(1, 1, 3) for _ in range(2))
hidden

(tensor([[[1., 1., 1.]]]), tensor([[[1., 1., 1.]]]))

In [5]:
# --- Variant 1: feed the sequence to the LSTM one time step at a time. ---
for step_input in inputs:
    # Each element is (1, 3); add a leading axis so the LSTM sees (1, 1, 3).
    # The returned state is fed back in on the next iteration, so after every
    # step `hidden` carries the state accumulated so far.
    out, hidden = lstm(step_input.unsqueeze(0), hidden)

# --- Variant 2: feed the whole sequence in a single call. ---
# The LSTM returns two things:
#   * `out`    - the hidden state at every time step of the sequence;
#   * `hidden` - only the most recent (hidden state, cell state) pair.
# The last slice of `out` therefore equals the first element of `hidden`
# (compare them in the printed output below).
# Why both are returned:
#   * `out` gives access to all hidden states in the sequence;
#   * `hidden` lets you continue the sequence (and backpropagate through it)
#     by passing it back to the LSTM at a later time.
# Stacking the per-step (1, 3) tensors inserts the extra 2nd dimension,
# giving one tensor of shape (len(inputs), 1, 3).
inputs = torch.stack(inputs)
# Start from a fresh, randomly initialised state rather than the one left
# over from the step-by-step loop.
hidden = tuple(torch.randn(1, 1, 3) for _ in range(2))
out, hidden = lstm(inputs, hidden)
print(out, hidden, sep="\n")

tensor([[[-0.2682,  0.0304, -0.1526]],

        [[-0.5370,  0.0346, -0.1958]],

        [[-0.3947,  0.0391, -0.1217]],

        [[-0.1854,  0.0740, -0.0979]],

        [[-0.3600,  0.0893,  0.0215]]], grad_fn=<MkldnnRnnLayerBackward0>)
(tensor([[[-0.3600,  0.0893,  0.0215]]], grad_fn=<StackBackward0>), tensor([[[-1.1298,  0.4467,  0.0254]]], grad_fn=<StackBackward0>))


In [None]:
# Show `inputs` again: the previous cell rebound it from a list of five
# (1, 3) tensors to a single tensor of shape (5, 1, 3).
inputs