In [1]:
import torch
from torch import nn
from torch import optim

# Overview

- Batch size: $n$
- Input size: $d$
- Input: `X`, $X_t \in \mathbb R^{n \times d}$
  - Thus a "batch" at time $t$ comprises training examples $\mathbf x_1, \mathbf x_2, \ldots, \mathbf x_n \in \mathbb R^{1 \times d}$ packaged as *row vectors* into an $n \times d$ matrix like so:
    $$
        X_t \coloneqq
            \begin{pmatrix}
                \mathbf x_1 \\ \mathbf x_2 \\ \vdots \\ \mathbf x_n
            \end{pmatrix}
    $$
- Hidden size: $h$
- Hidden state: `H`, $H_{t - 1} \in \mathbb R^{n \times h}$
- Memory cell: `C`, $C_{t - 1} \in \mathbb R^{n \times h}$
- Forget gate: `F`, $F_t \in \mathbb R^{n \times h}$
- Input gate: `I`, $I_t \in \mathbb R^{n \times h}$
- Input node: `C_tilde`, $\tilde C_t \in \mathbb R^{n \times h}$
- Output gate: `O`, $O_t \in \mathbb R^{n \times h}$

Define an affine map
\begin{align*}
    \mathrm{Aff}_{W_i, \mathbf b} \coloneqq \mathrm{Aff} : \mathbb R^{n \times d_1} \times \cdots \times \mathbb R^{n \times d_\ell}
        &\to \mathbb R^{n \times h}
\\
    (X_1, \ldots, X_\ell) &\mapsto \left( \sum_{i = 1}^\ell X_i W_i \right) \oplus \mathbf b
\end{align*}
where $W_i \in \mathbb R^{d_i \times h}$ and $\mathbf b \in \mathbb R^{1 \times h}$ are weights and biases, and $\oplus$ denotes row-wise addition (the row vector $\mathbf b$ is added to every row). The arguments may have different widths $d_i$: below, $X_t$ has width $d$ and $H_{t - 1}$ has width $h$.

The gates and input node are calculated thus:
\begin{align*}
    F_t &\coloneqq (\sigma \circ \mathrm{Aff})(X_t, H_{t - 1}) \\
    I_t &\coloneqq (\sigma \circ \mathrm{Aff})(X_t, H_{t - 1}) \\
    O_t &\coloneqq (\sigma \circ \mathrm{Aff})(X_t, H_{t - 1}) \\
    \tilde C_t &\coloneqq (\tanh \circ \mathrm{Aff})(X_t, H_{t - 1}) .
\end{align*}
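Each of the four maps has its own weights and bias (e.g. `W_xf`, `W_hf`, `b_f` for the forget gate). Here $\sigma : \mathbb R^{n \times h} \to \mathbb R^{n \times h}$ is the gate activation function applied component-wise; the implementation below uses $\mathrm{sigmoid}$, which keeps every gate entry in $(0, 1)$. Finally, the two outputs are computed as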
\begin{align*}
    C_t &\coloneqq F_t \odot C_{t - 1} + I_t \odot \tilde C_t
\\
    H_t &\coloneqq O_t \odot \tanh(C_t) ,
\end{align*}
where $\odot$ is the Hadamard product.

In [2]:
class LSTM(nn.Module):
    """Single-layer LSTM implemented from scratch with explicit per-gate parameters.

    Each of the input gate, forget gate, output gate and input node owns an
    input-to-hidden weight ``W_x*`` of shape ``(input_size, hidden_size)``, a
    hidden-to-hidden weight ``W_h*`` of shape ``(hidden_size, hidden_size)``
    and a bias ``b_*`` of shape ``(hidden_size,)``.

    Args:
        input_size: Number of features in each input vector.
        hidden_size: Number of hidden units.
        sigma: Scale applied to the standard-normal samples used to
            initialise the weight matrices. Biases start at zero.
    """

    def __init__(self, input_size, hidden_size, sigma=0.01):
        super().__init__()

        self.input_size = input_size
        self.hidden_size = hidden_size

        def init_weight(*shape):
            return nn.Parameter(torch.randn(*shape) * sigma)

        def triple():
            # (input-to-hidden weight, hidden-to-hidden weight, bias)
            return (
                init_weight(input_size, hidden_size),
                init_weight(hidden_size, hidden_size),
                nn.Parameter(torch.zeros(hidden_size)),
            )

        self.W_xi, self.W_hi, self.b_i = triple()  # input gate
        self.W_xf, self.W_hf, self.b_f = triple()  # forget gate
        self.W_xo, self.W_ho, self.b_o = triple()  # output gate
        self.W_xc, self.W_hc, self.b_c = triple()  # input node

    def _affine(self, X, H, W_x, W_h, b):
        """Return ``X @ W_x + H @ W_h + b`` with the bias broadcast over rows."""
        return torch.matmul(X, W_x) + torch.matmul(H, W_h) + b

    def forward(self, inputs, H_C=None):
        """Run the LSTM over a time-major sequence.

        Args:
            inputs: Tensor of shape ``(seq_len, batch_size, input_size)``;
                it is iterated along its first axis.
            H_C: Optional ``(H, C)`` pair holding the initial hidden state and
                memory cell, each of shape ``(batch_size, hidden_size)``.
                Zeros are used when omitted.

        Returns:
            ``(outputs, (H, C))`` where ``outputs`` is a list with the hidden
            state after every time step and ``(H, C)`` is the final state.
        """
        if H_C is None:
            # Initial state with shape (batch_size, hidden_size). It takes the
            # dtype of the inputs as well as their device so the module keeps
            # working after e.g. ``model.double()``.
            state_shape = (inputs.shape[1], self.hidden_size)
            H = torch.zeros(state_shape, dtype=inputs.dtype,
                            device=inputs.device)
            C = torch.zeros(state_shape, dtype=inputs.dtype,
                            device=inputs.device)
        else:
            H, C = H_C

        outputs = []
        for X in inputs:
            I = torch.sigmoid(self._affine(X, H, self.W_xi, self.W_hi, self.b_i))
            F = torch.sigmoid(self._affine(X, H, self.W_xf, self.W_hf, self.b_f))
            O = torch.sigmoid(self._affine(X, H, self.W_xo, self.W_ho, self.b_o))
            C_tilde = torch.tanh(
                self._affine(X, H, self.W_xc, self.W_hc, self.b_c)
            )

            # Keep part of the old memory, write part of the candidate.
            C = F * C + I * C_tilde
            H = O * torch.tanh(C)

            outputs.append(H)

        return outputs, (H, C)

In [3]:
# Generate synthetic data for long-range dependency
def generate_data(num_samples, sequence_length, input_size, threshold):
    """Draw random sequences labelled by their first and last time steps.

    Returns:
        A pair ``(X, y)``. ``X`` holds standard-normal samples with shape
        ``(num_samples, sequence_length, input_size)``. ``y`` is a long tensor
        of shape ``(num_samples,)`` that is 1 where feature 0 of the first
        step plus feature 0 of the last step exceeds ``threshold``, else 0.
    """
    sequences = torch.randn(num_samples, sequence_length, input_size)
    first_feature = sequences[:, :, 0]
    endpoint_sum = first_feature[:, 0] + first_feature[:, -1]
    labels = torch.gt(endpoint_sum, threshold).to(torch.long)
    return sequences, labels

# Hyperparameters
seed = 0  # fixes data generation, batch shuffling and weight initialisation
num_samples = 1000
sequence_length = 10  # the label depends on the first and the last time step
input_size = 1
hidden_size = 16
output_size = 2  # Binary classification
batch_size = 32
num_epochs = 100
threshold = 0.0
learning_rate = 0.001

# Seed the global RNG before anything random happens so that
# Restart Kernel -> Run All reproduces the same numbers.
torch.manual_seed(seed)

# Create dataset
X, y = generate_data(num_samples, sequence_length, input_size, threshold)

# Split into training and test sets (80% / 20%, in generation order)
train_size = int(0.8 * num_samples)
test_size = num_samples - train_size

train_X, test_X = X[:train_size], X[train_size:]
train_y, test_y = y[:train_size], y[train_size:]

# Data loaders: shuffle the training batches, keep the test order fixed
train_data = torch.utils.data.TensorDataset(train_X, train_y)
train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True)

test_data = torch.utils.data.TensorDataset(test_X, test_y)
test_loader = torch.utils.data.DataLoader(test_data, batch_size=batch_size, shuffle=False)

# LSTM encoder followed by a fully connected classification head
class SequenceClassifier(nn.Module):
    """Classify a whole sequence from the LSTM's hidden state at the last step."""

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.lstm = LSTM(input_size, hidden_size)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        """Return class logits for a batch-first input ``x``."""
        # The LSTM iterates over its leading axis, so reorder
        # (batch, seq_len, input_size) -> (seq_len, batch, input_size).
        time_major = x.permute(1, 0, 2)
        hidden_states, _ = self.lstm(time_major)
        final_hidden = hidden_states[-1]
        return self.fc(final_hidden)

In [4]:
# Model, loss function and optimizer
model = SequenceClassifier(input_size, hidden_size, output_size)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Optimise for num_epochs passes over the training set
for epoch in range(num_epochs):
    model.train()
    batch_losses = []
    for batch_X, batch_y in train_loader:
        optimizer.zero_grad()
        outputs = model(batch_X)
        loss = criterion(outputs, batch_y)
        loss.backward()
        optimizer.step()
        batch_losses.append(loss.item())
    total_loss = sum(batch_losses)

    # Report the mean batch loss every tenth epoch
    if not epoch % 10:
        print(f"Epoch {epoch}, Loss: {total_loss / len(train_loader)}")

# Accuracy on the held-out split, with gradient tracking switched off
model.eval()
correct, total = 0, 0
with torch.no_grad():
    for batch_X, batch_y in test_loader:
        outputs = model(batch_X)
        predictions = torch.argmax(outputs, dim=1)
        correct += int(predictions.eq(batch_y).sum())
        total += len(batch_y)

print(f"Test Accuracy: {correct / total:.2%}")

Epoch 0, Loss: 0.690939257144928
Epoch 10, Loss: 0.5118549025058746
Epoch 20, Loss: 0.4067698395252228
Epoch 30, Loss: 0.27837460935115815
Epoch 40, Loss: 0.2187814524769783
Epoch 50, Loss: 0.1719493129849434
Epoch 60, Loss: 0.15739808827638627
Epoch 70, Loss: 0.13721859976649284
Epoch 80, Loss: 0.13250648885965347
Epoch 90, Loss: 0.11599033936858177
Test Accuracy: 94.50%


In [5]:
# Spot check: print a few test sequences next to the model's prediction
# and the expected label
model.eval()

# First five examples of the test split
example_X, example_y = test_X[:5], test_y[:5]

with torch.no_grad():
    outputs = model(example_X)
    predictions = torch.argmax(outputs, dim=1)

print("Input sequences, model predictions, and expected outputs:")
for i in range(example_X.shape[0]):
    report = (
        f"Sequence {i + 1}:",
        f"  Input: {example_X[i].squeeze(-1).numpy()}",
        f"  Prediction: {predictions[i].item()}",
        f"  Expected: {example_y[i].item()}",
    )
    print(*report, sep="\n")

Input sequences, model predictions, and expected outputs:
Sequence 1:
  Input: [ 1.3026822  -0.91571444  0.59605736 -0.6351513  -0.04250909  2.0221457
  1.2990092  -0.54431796  1.2314717  -2.2671065 ]
  Prediction: 1
  Expected: 0
Sequence 2:
  Input: [ 0.9253716  -1.5769877  -0.3064221   1.0338665  -2.0728173   1.443539
  0.71105564 -0.6184454  -0.23085828  0.9072259 ]
  Prediction: 1
  Expected: 1
Sequence 3:
  Input: [-1.3915073  -0.48210853  0.35327092  0.44365886 -0.76442474  0.32325813
 -0.12421709 -0.25350264  0.14209668  1.8866785 ]
  Prediction: 1
  Expected: 1
Sequence 4:
  Input: [ 2.2446215  -0.17471376  0.6772357  -1.3377231  -1.3301219  -0.28191957
 -0.20286323  1.2673936   0.65420157 -1.82388   ]
  Prediction: 1
  Expected: 1
Sequence 5:
  Input: [ 0.6712648  -1.0581725  -0.56995696  0.16180553  0.556193    0.5367684
  0.74680084  0.4026059  -0.3619496  -0.07254669]
  Prediction: 1
  Expected: 1
