In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

SEED = 1  # global torch RNG seed: makes later torch.randn draws and nn.LSTM's default weight init repeatable on Restart & Run All
torch.manual_seed(SEED);  # trailing ';' hides the returned Generator repr in the notebook output

<torch._C.Generator at 0x7f20f4018dd0>

In [66]:
# Demo: run an nn.LSTM over a short random sequence twice -- first one step at
# a time, then the whole sequence in a single call -- and inspect the shapes.
IN_DIM = 2   # features per time step (nn.LSTM input_size)
OUT_DIM = 3  # hidden-state size (nn.LSTM hidden_size)
SEQ_NUM = 6  # number of time steps in the sequence

lstm = nn.LSTM(IN_DIM, OUT_DIM)  # input_size=IN_DIM, hidden_size=OUT_DIM
# A sequence of SEQ_NUM separate steps, each a (1, IN_DIM) tensor.
input_steps = [torch.randn(1, IN_DIM) for _ in range(SEQ_NUM)]

# Initial state (h_0, c_0), each shaped (num_layers=1, batch=1, OUT_DIM).
hidden = (torch.randn(1, 1, OUT_DIM),
          torch.randn(1, 1, OUT_DIM))
print('Input shape: ', input_steps[0].shape)
# Same first step after reshaping to (seq_len=1, batch=1, IN_DIM).
print('Input shape: ', input_steps[0].view(1, 1, -1).shape)

for step in input_steps:
    # Step through the sequence one element at a time; after each call,
    # `hidden` holds (h_t, c_t) and is fed back in on the next call.
    # nn.LSTM (default batch_first=False) expects (seq_len, batch, input_size),
    # so each (1, IN_DIM) step is reshaped to (1, 1, IN_DIM).
    out, hidden = lstm(step.view(1, 1, -1), hidden)
print('out: ', out, '\n', out.shape, '\n')
print('hidden', hidden, '\n', hidden[0].shape, '\n', hidden[1].shape, '\n')

# Alternatively, run the entire sequence in one call.
# The first value returned holds the hidden state for every time step; the
# second is only the final (h_n, c_n). With a single layer, the last slice of
# `out` equals h_n (checked by the assert at the end of this cell).
# Why both are returned:
#   - `out` gives access to all hidden states in the sequence;
#   - `hidden` lets you continue the sequence (and backpropagate through it)
#     by passing it back to the lstm at a later time.
print(len(input_steps))
# Stack the steps into (SEQ_NUM, IN_DIM), then insert the batch dim as dim 1,
# giving (seq_len=SEQ_NUM, batch=1, IN_DIM).
inputs = torch.cat(input_steps).view(len(input_steps), 1, -1)
print('Input shape: ', inputs.shape)
# Draw a NEW random (h_0, c_0), different from the one used for the loop above,
# so these outputs are not expected to match the step-by-step run.
hidden = (torch.randn(1, 1, OUT_DIM),
          torch.randn(1, 1, OUT_DIM))
out, hidden = lstm(inputs, hidden)
print('out: ', out, '\n', out.shape, '\n')
print('hidden', hidden, '\n', hidden[0].shape, '\n', hidden[1].shape)

# Inline sanity checks for the claims made in the comments above.
assert out.shape == (SEQ_NUM, 1, OUT_DIM)
assert torch.allclose(out[-1], hidden[0][-1]), "last output step should equal h_n"

Input shape:  torch.Size([1, 2])
Input shape:  torch.Size([1, 1, 2])
out:  tensor([[[-0.3508, -0.0305, -0.0209]]], grad_fn=<CatBackward>) 
 torch.Size([1, 1, 3]) 

hidden (tensor([[[-0.3508, -0.0305, -0.0209]]], grad_fn=<ViewBackward>), tensor([[[-0.6717, -0.0446, -0.0398]]], grad_fn=<ViewBackward>)) 
 torch.Size([1, 1, 3]) 
 torch.Size([1, 1, 3]) 

6
Input shape:  torch.Size([6, 1, 2])
out:  tensor([[[-0.3693,  0.3967, -0.1858]],

        [[-0.1375,  0.3370, -0.2531]],

        [[-0.3246, -0.0412, -0.0182]],

        [[-0.3911, -0.1074, -0.0485]],

        [[-0.3773, -0.0558, -0.0146]],

        [[-0.3753, -0.0224,  0.0149]]], grad_fn=<CatBackward>) 
 torch.Size([6, 1, 3]) 

hidden (tensor([[[-0.3753, -0.0224,  0.0149]]], grad_fn=<ViewBackward>), tensor([[[-0.7461, -0.0327,  0.0286]]], grad_fn=<ViewBackward>)) 
 torch.Size([1, 1, 3]) 
 torch.Size([1, 1, 3])
