In [1]:
import torch
from torch import nn
from torch import optim

In [2]:
class LSTM(nn.Module):
    """Single-layer LSTM written out gate by gate with explicit parameters.

    ``forward`` iterates over the first dimension of ``inputs``, so input is
    time-major: ``(seq_len, batch_size, input_size)``.

    Args:
        input_size: Number of features in each time step.
        hidden_size: Number of features in the hidden state H and cell state C.
        sigma: Standard deviation of the Gaussian used to initialise every
            weight matrix. Biases are initialised to zero.
    """

    def __init__(self, input_size, hidden_size, sigma=0.01):
        super().__init__()

        self.input_size = input_size
        self.hidden_size = hidden_size

        def init_weight(*shape):
            return nn.Parameter(torch.randn(*shape) * sigma)

        def triple():
            # (input->hidden weight, hidden->hidden weight, bias) for one gate.
            # Creation order is kept fixed so seeded initialisation is reproducible.
            return (init_weight(input_size, hidden_size),
                    init_weight(hidden_size, hidden_size),
                    nn.Parameter(torch.zeros(hidden_size)))

        self.W_xi, self.W_hi, self.b_i = triple()  # input gate
        self.W_xf, self.W_hf, self.b_f = triple()  # forget gate
        self.W_xo, self.W_ho, self.b_o = triple()  # output gate
        self.W_xc, self.W_hc, self.b_c = triple()  # input node

    def forward(self, inputs, H_C=None):
        """Run the LSTM over every time step of ``inputs``.

        Args:
            inputs: Tensor of shape ``(seq_len, batch_size, input_size)``.
            H_C: Optional ``(H, C)`` tuple of initial hidden and cell states,
                each ``(batch_size, hidden_size)``. Zeros are used when omitted.

        Returns:
            ``(outputs, (H, C))``. ``outputs`` is a Python list holding the
            hidden state after each step (empty if ``seq_len == 0``), and
            ``(H, C)`` is the final state.
        """
        if H_C is None:
            # new_zeros matches the device AND dtype of ``inputs``. A plain
            # torch.zeros(..., device=...) is always the default dtype, which
            # breaks the matmuls below for e.g. a model converted with .double().
            batch_size = inputs.shape[1]
            H = inputs.new_zeros((batch_size, self.hidden_size))
            C = inputs.new_zeros((batch_size, self.hidden_size))
        else:
            H, C = H_C

        outputs = []
        for X in inputs:
            I = torch.sigmoid(
                torch.matmul(X, self.W_xi) +
                torch.matmul(H, self.W_hi) + self.b_i
            )
            F = torch.sigmoid(
                torch.matmul(X, self.W_xf) +
                torch.matmul(H, self.W_hf) + self.b_f
            )
            O = torch.sigmoid(
                torch.matmul(X, self.W_xo) +
                torch.matmul(H, self.W_ho) + self.b_o
            )
            C_tilde = torch.tanh(
                torch.matmul(X, self.W_xc) +
                torch.matmul(H, self.W_hc) + self.b_c
            )

            # Forget part of the old memory, write the gated candidate, then
            # expose a gated view of the new memory as the hidden state.
            C = F * C + I * C_tilde
            H = O * torch.tanh(C)

            outputs.append(H)

        return outputs, (H, C)

In [3]:
# Generate synthetic data for long-range dependency
def generate_data(num_samples, sequence_length, input_size, threshold):
    """Build a synthetic binary task whose label depends on both sequence ends.

    Args:
        num_samples: Number of sequences to draw.
        sequence_length: Number of time steps per sequence.
        input_size: Number of features per time step.
        threshold: Cut-off applied to (first step + last step) of feature 0.

    Returns:
        ``(X, y)``. ``X`` is standard-normal with shape
        ``(num_samples, sequence_length, input_size)``. ``y`` is a long tensor
        of shape ``(num_samples,)`` that is 1 where
        ``X[:, 0, 0] + X[:, -1, 0] > threshold`` and 0 otherwise.
    """
    sequences = torch.randn(num_samples, sequence_length, input_size)
    first_step = sequences[:, 0, 0]
    last_step = sequences[:, -1, 0]
    labels = (first_step + last_step > threshold).long()
    return sequences, labels

# Hyperparameters
seed = 0  # seeds torch's global RNG: data generation, weight init and batch shuffling
num_samples = 1000
sequence_length = 10  # the two label-determining steps are sequence_length - 1 steps apart
input_size = 1
hidden_size = 16
output_size = 2  # Binary classification
batch_size = 32
num_epochs = 100
threshold = 0.0  # sum of two standard normals is symmetric about 0 -> roughly balanced classes
learning_rate = 0.001

# Seed before any random draw so Restart & Run All reproduces every output.
torch.manual_seed(seed)

# Create dataset
X, y = generate_data(num_samples, sequence_length, input_size, threshold)

# Split into training and test sets (X is i.i.d. random, so a contiguous split is unbiased)
train_size = int(0.8 * num_samples)
test_size = num_samples - train_size

train_X, test_X = X[:train_size], X[train_size:]
train_y, test_y = y[:train_size], y[train_size:]
assert len(train_X) == train_size and len(test_X) == test_size

# Data loaders
train_data = torch.utils.data.TensorDataset(train_X, train_y)
train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True)

test_data = torch.utils.data.TensorDataset(test_X, test_y)
test_loader = torch.utils.data.DataLoader(test_data, batch_size=batch_size, shuffle=False)

# Fully connected classification head
class SequenceClassifier(nn.Module):
    """Classify a sequence from the LSTM's hidden state after the last step.

    Args:
        input_size: Features per time step, passed to the LSTM.
        hidden_size: LSTM hidden-state width and input width of the linear head.
        output_size: Number of classes (logits produced per sample).
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.lstm = LSTM(input_size, hidden_size)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        """Map batch-first ``x`` of shape (batch, seq_len, input_size) to class logits."""
        # The LSTM loops over dim 0, so move time to the front:
        # (batch, seq_len, input_size) -> (seq_len, batch, input_size).
        time_major = x.permute(1, 0, 2)
        hidden_states, _ = self.lstm(time_major)
        final_hidden = hidden_states[-1]
        return self.fc(final_hidden)

In [4]:
# Training setup
model = SequenceClassifier(input_size, hidden_size, output_size)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)


def train_one_epoch(loader):
    """Run one optimisation pass over ``loader`` and return the mean batch loss."""
    model.train()
    running_loss = 0
    for inputs_batch, labels_batch in loader:
        optimizer.zero_grad()
        batch_loss = criterion(model(inputs_batch), labels_batch)
        batch_loss.backward()
        optimizer.step()
        running_loss += batch_loss.item()
    return running_loss / len(loader)


def count_correct(loader):
    """Return ``(correct, total)``: argmax hits and number of samples over ``loader``."""
    model.eval()
    hits = seen = 0
    with torch.no_grad():
        for inputs_batch, labels_batch in loader:
            predicted = model(inputs_batch).argmax(dim=1)
            hits += (predicted == labels_batch).sum().item()
            seen += labels_batch.size(0)
    return hits, seen


# Training loop: report the mean batch loss every 10th epoch.
for epoch in range(num_epochs):
    mean_loss = train_one_epoch(train_loader)
    if epoch % 10 == 0:
        print(f"Epoch {epoch}, Loss: {mean_loss}")

# Evaluation
correct, total = count_correct(test_loader)
print(f"Test Accuracy: {correct / total:.2%}")

Epoch 0, Loss: 0.6944447898864746
Epoch 10, Loss: 0.557159765958786
Epoch 20, Loss: 0.48266212940216063
Epoch 30, Loss: 0.47914249300956724
Epoch 40, Loss: 0.47136218309402467
Epoch 50, Loss: 0.36938887476921084
Epoch 60, Loss: 0.2578399443626404
Epoch 70, Loss: 0.22192418575286865
Epoch 80, Loss: 0.1905925652384758
Epoch 90, Loss: 0.1663687564432621
Test Accuracy: 96.00%


In [5]:
# Example Test: Print inputs, predictions, and expected outputs
model.eval()

# Select a few samples from the test set
example_X, example_y = test_X[:5], test_y[:5]

with torch.no_grad():
    outputs = model(example_X)
    predictions = outputs.argmax(dim=1)

print("Input sequences, model predictions, and expected outputs:")
for number, (sequence, predicted, expected) in enumerate(
        zip(example_X, predictions, example_y), start=1):
    print(f"Sequence {number}:")
    print(f"  Input: {sequence.squeeze(-1).numpy()}")
    print(f"  Prediction: {predicted.item()}")
    print(f"  Expected: {expected.item()}")

Input sequences, model predictions, and expected outputs:
Sequence 1:
  Input: [-0.20597954 -1.198035    1.6662222  -0.859738    0.47962168 -0.12167113
  1.0808519   0.35500762  1.585527   -0.42019925]
  Prediction: 0
  Expected: 0
Sequence 2:
  Input: [-0.6446867  -1.4907006   0.9217508   0.30053535 -1.6256307   0.1004274
  0.23721616 -1.6597656   0.6699194  -0.01552773]
  Prediction: 0
  Expected: 0
Sequence 3:
  Input: [-0.8113887   0.09116855 -0.69002235  0.7237788  -1.1907287  -0.9923395
 -0.108426   -0.22399668 -0.79839045 -1.5641514 ]
  Prediction: 0
  Expected: 0
Sequence 4:
  Input: [ 1.2063242   0.29052877 -0.26716873  0.7285104  -0.14099081  0.5734504
  0.18021065 -0.43092886 -0.7307493  -0.17880793]
  Prediction: 1
  Expected: 1
Sequence 5:
  Input: [ 0.3999417   0.59730524  2.2906618   1.2805545   1.1404402   0.32964057
 -0.96988887 -2.097572   -0.145355    0.74130315]
  Prediction: 1
  Expected: 1
