In [1]:
import torch
import torch.nn as nn

# LSTM cell

<img src="./assets/1.png" width="500"/>

- Inputs:
    - $c_{t-1}, h_{t-1}$: previous cell state and previous hidden state
    - $x_t$: input data at time step $t$


#### Forget gate
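- $f = 0$: completely forget the memory stored in $c_{t-1}$
- $f = 1$: completely keep the memory stored in $c_{t-1}$
- (The gate acts element-wise; after the sigmoid each entry of $f$ lies in $(0, 1)$.)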

$$ f = \sigma(W_{if}x_t + b_{if} + W_{hf}h_{t-1} + b_{hf})$$


#### Input gate
- Determines how much of the (transformed) new external input — the cell-gate candidate $g$, computed from $x_t$ — is written into the cell state.



$$i = \sigma(W_{ii}x_t + b_{ii} + W_{hi}h_{t-1} + b_{hi})$$


#### Cell gate
- Non-linear transformation of the new external input $x_t$ (together with $h_{t-1}$) into candidate values for the cell state

$$g = \tanh(W_{ig}x_t + b_{ig} + W_{hg}h_{t-1} + b_{hg})$$


#### Output gate
- Controls how much of the new cell state $c_t$ is exposed as the output (i.e. the hidden state $h_t$)

$$ o = \sigma(W_{io}x_t + b_{io} + W_{ho}h_{t-1} + b_{ho})$$


- Outputs
    + $c_t = f \odot c_{t-1} + i \odot g$
    + $h_t = o \odot \tanh(c_t)$
    + where $\odot$ denotes the element-wise (Hadamard) product

## Example

<img src="./assets/2.png" width="500"/>


In [2]:
# Sizes for the single-cell example: each x_t has INPUT_DIM features,
# while the hidden/cell states h and c have HIDDEN_DIM features.
INPUT_DIM, HIDDEN_DIM = 10, 20

# One LSTM cell; the loop further below calls it once per time step.
rnn = nn.LSTMCell(input_size=INPUT_DIM, hidden_size=HIDDEN_DIM)

In [3]:
# Sequence length: X holds T time steps, one INPUT_DIM-dim vector per step.
T = 6
X = torch.randn((T, INPUT_DIM))

# Random initial hidden and cell state for a batch of one
# (drawn in the order h, then c).
h, c = torch.randn((1, HIDDEN_DIM)), torch.randn((1, HIDDEN_DIM))

In [4]:
NUM_CELL = 3
# The loop indexes X[t], so we cannot unroll more cells than X has time steps.
assert NUM_CELL <= T, f"NUM_CELL ({NUM_CELL}) must not exceed the sequence length T ({T})"

# Hidden state emitted at each unrolled step: (NUM_CELL, batch=1, HIDDEN_DIM).
outs = torch.zeros(NUM_CELL, 1, HIDDEN_DIM)

# Step from copies of the initial (h, c) instead of overwriting h and c, so
# re-running this cell replays the same steps rather than continuing from the
# previous run's final state. The final state ends up in h_t, c_t.
h_t, c_t = h, c
for t in range(NUM_CELL):
    x_t = X[t].unsqueeze(0)  # (INPUT_DIM,) -> (1, INPUT_DIM): batch of one, like h_t / c_t
    h_t, c_t = rnn(x_t, (h_t, c_t))
    outs[t] = h_t

outs.size()

torch.Size([3, 1, 20])

# LSTM layers
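`batch_first=True`: the input and output tensors use the `(batch, seq, feature)` layout, i.e. batch size at dim 0. The returned hidden/cell states `h_n`, `c_n` are not affected and keep batch at dim 1 (see the sizes printed below).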

In [5]:
INPUT_DIM = 258
BATCH_SIZE = 64
SEQ_LEN = 192  # number of time steps in each sequence

# Input in batch_first layout: (batch, seq, feature).
X = torch.rand(BATCH_SIZE, SEQ_LEN, INPUT_DIM)

In [6]:
HID_DIM, N_LAYERS = 512, 2

# Stacked unidirectional LSTM with N_LAYERS layers that takes and returns
# (batch, seq, feature) tensors. Per the nn.LSTM docs, `dropout` is applied
# between stacked layers.
rnn = nn.LSTM(INPUT_DIM, HID_DIM, num_layers=N_LAYERS, dropout=0.4,
              batch_first=True)

#### Feed without h_0, c_0 (PyTorch defaults them to zeros)

In [7]:
# No (h_0, c_0) passed: per the nn.LSTM docs both default to zeros.
out, (hidden, cell) = rnn(X)

# Output sequence, final hidden state, final cell state -- one size per line.
for tensor in (out, hidden, cell):
    print(tensor.size())

torch.Size([64, 192, 512])
torch.Size([2, 64, 512])
torch.Size([2, 64, 512])


#### Feed with explicitly zero-initialized h_0, c_0

In [8]:
# Explicit initial states of shape (num_layers * num_directions, batch, hidden_size).
# batch_first does not apply to them: batch stays at dim 1.
num_directions = 2 if rnn.bidirectional else 1

# Match X's dtype/device instead of hard-coding `.float()`, so this cell keeps
# working if the default dtype or the device of X / the model changes.
h_0 = torch.zeros(N_LAYERS * num_directions, BATCH_SIZE, HID_DIM,
                  dtype=X.dtype, device=X.device)
c_0 = torch.zeros_like(h_0)

out, (hidden, cell) = rnn(X, (h_0, c_0))

print(out.size())
print(hidden.size())
print(cell.size())

torch.Size([64, 192, 512])
torch.Size([2, 64, 512])
torch.Size([2, 64, 512])


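# Bidirectional LSTM layers

With `bidirectional=True`, the forward and backward outputs are concatenated, so the last dim of `out` becomes `2 * hidden_size`. The hidden/cell states also get `2 * num_layers` entries along dim 0 (one per layer and direction).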

In [9]:
# Re-create the inputs so the bidirectional section runs on its own.
INPUT_DIM = 258
BATCH_SIZE = 64
SEQ_LEN = 192  # number of time steps in each sequence

# Input in batch_first layout: (batch, seq, feature).
X = torch.rand(BATCH_SIZE, SEQ_LEN, INPUT_DIM)

In [10]:
HID_DIM, N_LAYERS = 512, 2

# Same stacked LSTM configuration as before, but with bidirectional=True:
# each layer processes the sequence in both directions.
bi_rnn = nn.LSTM(INPUT_DIM, HID_DIM, num_layers=N_LAYERS, dropout=0.4,
                 batch_first=True, bidirectional=True)

#### Feed without h_0, c_0 (PyTorch defaults them to zeros)

In [11]:
# No (h_0, c_0) passed: per the nn.LSTM docs both default to zeros.
out, (hidden, cell) = bi_rnn(X)

# Output sequence, final hidden state, final cell state -- one size per line.
for tensor in (out, hidden, cell):
    print(tensor.size())

torch.Size([64, 192, 1024])
torch.Size([4, 64, 512])
torch.Size([4, 64, 512])


#### Feed with explicitly zero-initialized h_0, c_0

In [12]:
# Explicit initial states of shape (num_layers * num_directions, batch, hidden_size);
# a bidirectional LSTM keeps one state per layer *and* direction, hence 2 * N_LAYERS.
num_directions = 2 if bi_rnn.bidirectional else 1

# Match X's dtype/device instead of hard-coding `.float()`, so this cell keeps
# working if the default dtype or the device of X / the model changes.
h_0 = torch.zeros(N_LAYERS * num_directions, BATCH_SIZE, HID_DIM,
                  dtype=X.dtype, device=X.device)
c_0 = torch.zeros_like(h_0)

out, (hidden, cell) = bi_rnn(X, (h_0, c_0))

print(out.size())
print(hidden.size())
print(cell.size())

torch.Size([64, 192, 1024])
torch.Size([4, 64, 512])
torch.Size([4, 64, 512])
