In [1]:
import torch
from torch import nn
from torch import optim

In [2]:
class LSTM(nn.Module):
    """Single-layer LSTM implemented explicitly, one gate at a time.

    Each gate owns an input->hidden weight ``W_x*`` of shape
    (input_size, hidden_size), a hidden->hidden weight ``W_h*`` of shape
    (hidden_size, hidden_size) and a bias ``b_*`` of shape (hidden_size,).

    Args:
        input_size: number of features in each input step.
        hidden_size: size of the hidden state H and the memory cell C.
        sigma: standard deviation of the Gaussian used to initialise the
            weight matrices; biases start at zero.
    """

    def __init__(self, input_size, hidden_size, sigma=0.01):
        super(LSTM, self).__init__()

        self.input_size = input_size
        self.hidden_size = hidden_size

        def init_weight(*shape):
            # Small random weights registered as a trainable parameter.
            return nn.Parameter(torch.randn(*shape) * sigma)

        def triple():
            # (input->hidden weight, hidden->hidden weight, bias) for one gate.
            return (init_weight(input_size, hidden_size),
                    init_weight(hidden_size, hidden_size),
                    nn.Parameter(torch.zeros(hidden_size)))

        # The creation order below fixes the sequence of torch.randn calls, so
        # initialisation stays reproducible under a manual seed.
        self.W_xi, self.W_hi, self.b_i = triple()  # input gate
        self.W_xf, self.W_hf, self.b_f = triple()  # forget gate
        self.W_xo, self.W_ho, self.b_o = triple()  # output gate
        self.W_xc, self.W_hc, self.b_c = triple()  # input node

    def forward(self, inputs, H_C=None):
        """Run the LSTM over a sequence, one step at a time.

        Args:
            inputs: tensor iterated over its first (time) dimension. Each step
                ``X`` is multiplied by (input_size, hidden_size) weights and
                the default state uses ``inputs.shape[1]`` as the batch size,
                i.e. the layout is (seq_len, batch_size, input_size).
            H_C: optional initial state ``(H, C)``, each
                (batch_size, hidden_size). When omitted, both start as zeros
                with the same dtype and device as ``inputs``.

        Returns:
            ``(outputs, (H, C))`` where ``outputs`` is a Python list holding
            the hidden state after every step and ``(H, C)`` is the final
            state. For an empty sequence the list is empty and the initial
            state is returned unchanged.
        """
        if H_C is None:
            batch_size = inputs.shape[1]
            # new_zeros inherits dtype *and* device from inputs; plain
            # torch.zeros would always be float32 and break e.g. a .double()
            # model fed float64 inputs.
            H = inputs.new_zeros((batch_size, self.hidden_size))
            C = inputs.new_zeros((batch_size, self.hidden_size))
        else:
            H, C = H_C

        outputs = []
        for X in inputs:
            I = torch.sigmoid(X @ self.W_xi + H @ self.W_hi + self.b_i)
            F = torch.sigmoid(X @ self.W_xf + H @ self.W_hf + self.b_f)
            O = torch.sigmoid(X @ self.W_xo + H @ self.W_ho + self.b_o)
            C_tilde = torch.tanh(X @ self.W_xc + H @ self.W_hc + self.b_c)
            # Forget part of the old memory, write the gated candidate.
            C = F * C + I * C_tilde
            H = O * torch.tanh(C)
            outputs.append(H)

        return outputs, (H, C)

In [3]:
# Seed before any random draw so X, y and the later model initialisation
# are all reproducible.
torch.manual_seed(42)

# Problem dimensions.
sequence_length, input_size, hidden_size = 10, 5, 16
batch_size, num_classes = 32, 2

# Synthetic inputs, time-major: (sequence_length, batch_size, input_size).
X = torch.randn(sequence_length, batch_size, input_size)

# One random class label per sequence (many-to-one classification).
label_shape = (batch_size,)
y = torch.randint(0, num_classes, label_shape)

In [4]:
# Model: the LSTM encoder plus a linear head mapping the final hidden state
# to class logits. `hidden_size` comes from the configuration cell above and
# is deliberately not redefined here, so the two cells cannot disagree.
output_size = num_classes
learning_rate = 0.001

model = LSTM(input_size=input_size, hidden_size=hidden_size)
# Size the head from the model itself so it always matches the encoder.
fc = nn.Linear(model.hidden_size, output_size)

# Define loss and optimiser (one optimiser over encoder + head parameters)
criterion = nn.CrossEntropyLoss()
optimiser = optim.Adam(list(model.parameters()) + list(fc.parameters()), lr=learning_rate)

# Training loop: full-batch gradient descent on the single synthetic batch.
num_epochs = 20
for epoch in range(num_epochs):
    # reset gradients
    optimiser.zero_grad()

    # forward pass
    outputs, (H, C) = model(X)
    logits = fc(H)  # Use the last hidden state for many-to-one classification
    loss = criterion(logits, y)

    # backward pass
    loss.backward()
    optimiser.step()

    print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {loss.item()}")

Epoch 1/20, Loss: 0.6888551115989685
Epoch 2/20, Loss: 0.6884186267852783
Epoch 3/20, Loss: 0.687986433506012
Epoch 4/20, Loss: 0.6875566244125366
Epoch 5/20, Loss: 0.6871271729469299
Epoch 6/20, Loss: 0.6866956353187561
Epoch 7/20, Loss: 0.6862595677375793
Epoch 8/20, Loss: 0.6858169436454773
Epoch 9/20, Loss: 0.6853659749031067
Epoch 10/20, Loss: 0.6849049925804138
Epoch 11/20, Loss: 0.6844324469566345
Epoch 12/20, Loss: 0.6839473247528076
Epoch 13/20, Loss: 0.6834484338760376
Epoch 14/20, Loss: 0.6829347014427185
Epoch 15/20, Loss: 0.682405412197113
Epoch 16/20, Loss: 0.6818597316741943
Epoch 17/20, Loss: 0.6812968254089355
Epoch 18/20, Loss: 0.680715799331665
Epoch 19/20, Loss: 0.6801159977912903
Epoch 20/20, Loss: 0.6794965267181396


In [5]:
# Inference on the same batch used for training (there is no held-out split),
# so this shows fit to the training data rather than generalisation.
with torch.no_grad():
    outputs, (H, C) = model(X)
    logits = fc(H)
    # Predicted class = index of the largest logit per sequence.
    predictions = logits.argmax(dim=1)
    print(f"Predictions: {predictions}")
    print(f"Ground Truth: {y}")

Predictions: tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0])
Ground Truth: tensor([1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0, 0, 0, 1, 0, 0, 1, 1, 1,
        1, 0, 0, 0, 0, 1, 0, 0])
