TODO:
- Work on a parallelizable, optimized 2D RNN version
- Investigate applying attention between the top hidden state, the left hidden state, and the input x
- Apply rotations/reflections (rotate by 0/90/180/270 degrees; flip vertically/horizontally/diagonally/antidiagonally) and run the same LSTM on each variant, with learnable hidden and cell states for the surrounding border
- Investigate applying attention to the final hidden state vectors
- Investigate a multilayer 2D LSTM

In [1]:
import torch
import torch.nn as nn

class LSTM2DCell(nn.Module):
    """One step of a 2D LSTM that merges state from a left and a top neighbour.

    A single linear layer maps ``[x, h_left, h_top]`` to five gate
    pre-activations, laid out in this order along the feature axis:
    input gate, left forget gate, top forget gate, output gate, candidate.
    """

    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size

        # One projection produces all five gate blocks at once.
        self.W = nn.Linear(input_size + 2 * hidden_size, 5 * hidden_size)

    def forward(self, x, h_left, c_left, h_top, c_top):
        """Return the new ``(h, c)`` pair for one grid position.

        All arguments are concatenated / combined along ``dim=1``, so they
        are expected to be 2D tensors sharing their first dimension.
        """
        pre_activations = self.W(torch.cat((x, h_left, h_top), dim=1))
        in_pre, left_pre, top_pre, out_pre, cand_pre = torch.chunk(
            pre_activations, 5, dim=1
        )

        input_gate = torch.sigmoid(in_pre)
        forget_left = torch.sigmoid(left_pre)
        forget_top = torch.sigmoid(top_pre)
        output_gate = torch.sigmoid(out_pre)
        candidate = torch.tanh(cand_pre)

        # Each neighbour's cell memory is scaled by its own forget gate.
        c = input_gate * candidate + forget_left * c_left + forget_top * c_top
        h = output_gate * torch.tanh(c)
        return h, c

class LSTM2D(nn.Module):
    """Scan a grid in row-major order with a single shared ``LSTM2DCell``.

    Every position receives the hidden/cell state of its left and top
    neighbours; positions on the first row or first column receive zeros
    for the missing neighbour.
    """

    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.cell = LSTM2DCell(input_size, hidden_size)

    def forward(self, input_grid):
        """
        input_grid: Tensor of shape (H, W, B, input_size)
        Returns:
            h_grid: Tensor of shape (H, W, B, hidden_size)
        """
        H, W, B, _ = input_grid.shape
        hidden_size = self.cell.hidden_size

        # Nothing to scan: return an empty grid instead of failing in stack().
        if H == 0 or W == 0:
            return input_grid.new_zeros(H, W, B, hidden_size)

        # Shared zero state for the missing neighbours on the border.
        zero_state = input_grid.new_zeros(B, hidden_size)

        # States are collected in Python lists and stacked once at the end.
        # Writing them in place into a preallocated tensor whose slices were
        # already consumed by the cell can invalidate tensors autograd saved
        # for the backward pass.
        stacked_rows = []
        prev_h_row = prev_c_row = None
        for i in range(H):
            h_row, c_row = [], []
            for j in range(W):
                if j > 0:
                    h_left, c_left = h_row[j - 1], c_row[j - 1]
                else:
                    h_left = c_left = zero_state
                if i > 0:
                    h_top, c_top = prev_h_row[j], prev_c_row[j]
                else:
                    h_top = c_top = zero_state

                h, c = self.cell(input_grid[i, j], h_left, c_left, h_top, c_top)
                h_row.append(h)
                c_row.append(c)

            stacked_rows.append(torch.stack(h_row, dim=0))  # (W, B, hidden_size)
            # Only the previous row is needed as the "top" neighbour.
            prev_h_row, prev_c_row = h_row, c_row

        return torch.stack(stacked_rows, dim=0)  # (H, W, B, hidden_size)

In [2]:
# Smoke test: push one random grid through a freshly initialised LSTM2D.
H, W, B = 10, 10, 4
input_size, hidden_size = 16, 32

# The random input is drawn before the model is built, as in the original cell.
input_grid = torch.randn(H, W, B, input_size)
model = LSTM2D(input_size, hidden_size)

# Expected output shape: (H, W, B, hidden_size).
output = model(input_grid)

In [5]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torchinfo import summary

In [14]:
from __future__ import annotations

import time
from pathlib import Path

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# -----------------------------
#  2‑D LSTM building blocks
# -----------------------------

class RowCol2DLSTM(nn.Module):
    """A simple 2‑D LSTM classifier: a row LSTM followed by a column LSTM.

    Both LSTMs are bidirectional, so each emits twice its hidden size per step.

    Args:
        row_hidden: Hidden units per direction in the row LSTM.
        col_hidden: Hidden units per direction in the column LSTM.
        num_classes: Output classes.
    """

    def __init__(self, *, row_hidden: int = 8, col_hidden: int = 8, num_classes: int = 10):
        super().__init__()
        self.row_hidden = row_hidden
        self.col_hidden = col_hidden

        # Reads each row as a sequence of W single-value pixels.
        self.row_lstm = nn.LSTM(
            input_size=1,
            hidden_size=row_hidden,
            num_layers=1,
            bidirectional=True,
            batch_first=True,
        )

        # Reads each column as a sequence of H row-LSTM feature vectors.
        self.col_lstm = nn.LSTM(
            input_size=row_hidden * 2,
            hidden_size=col_hidden,
            num_layers=1,
            bidirectional=True,
            batch_first=True,
        )

        # Global representation → logits
        self.classifier = nn.Linear(col_hidden * 2, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return class logits of shape (B, num_classes) for x of shape (B, 1, H, W)."""
        b, c, h, w = x.shape
        assert c == 1, "Expect single‑channel greyscale images"

        # ── Row pass ────────────────────────────────────────────────────────────
        # All B·H rows are independent sequences, so they go through the row
        # LSTM as one batch instead of one Python-level call per row.
        row_seqs = x.reshape(b * h, w, 1)                       # (B·H, W, 1)
        row_out, _ = self.row_lstm(row_seqs)                    # (B·H, W, 2·row_hidden)
        row_feats = row_out.reshape(b, h, w, 2 * self.row_hidden)

        # ── Column pass ─────────────────────────────────────────────────────────
        # Likewise, all B·W columns are processed in a single call.
        col_seqs = row_feats.transpose(1, 2).reshape(b * w, h, 2 * self.row_hidden)
        col_out, _ = self.col_lstm(col_seqs)                    # (B·W, H, 2·col_hidden)
        # Features at the last step along each column.
        # NOTE(review): for the reverse direction of a bidirectional LSTM the
        # last output step is the first one it processes — confirm this is the
        # intended summary of the column.
        col_feats = col_out[:, -1, :].reshape(b, w, 2 * self.col_hidden)

        # ── Classification head ────────────────────────────────────────────────
        global_feat = col_feats.mean(dim=1)                     # (B, 2·col_hidden)
        logits = self.classifier(global_feat)                   # (B, num_classes)
        return logits

In [None]:
import torch
import torch.nn as nn

class MDLSTMCell(nn.Module):
    """
    A single 2D LSTM cell that takes input x, hidden+cell states from the top
    (h1, c1) and from the left (h2, c2), and computes the new (h, c).
    """
    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.input_size  = input_size
        self.hidden_size = hidden_size
        # One projection yields five gate blocks, in order: i, f1, f2, o, g.
        self.linear = nn.Linear(input_size + 2*hidden_size, 5*hidden_size)

    def forward(self, x, h1, c1, h2, c2):
        """
        x:      (B, input_size)
        h1, c1: (B, hidden_size) from top
        h2, c2: (B, hidden_size) from left
        returns: (h, c) each (B, hidden_size)
        """
        # Concatenate the input with both neighbouring hidden states.
        combined = torch.cat([x, h1, h2], dim=1)  # (B, input + 2*hidden)
        i, f1, f2, o, g = self.linear(combined).chunk(5, dim=1)

        i  = torch.sigmoid(i)   # input gate
        f1 = torch.sigmoid(f1)  # forget gate applied to the top cell state
        f2 = torch.sigmoid(f2)  # forget gate applied to the left cell state
        o  = torch.sigmoid(o)   # output gate
        g  = torch.tanh(g)      # candidate cell update

        # New cell state mixes both neighbours' memories with the candidate.
        c = f1 * c1 + f2 * c2 + i * g
        h = o * torch.tanh(c)
        return h, c


class MDLSTM(nn.Module):
    """
    2D LSTM module that applies MDLSTMCell over a 2D grid in row-major order.
    Input shape: (B, H, W, input_size)
    Output:      (B, H, W, hidden_size) of hidden states

    Positions on the first row / first column take their missing top / left
    neighbour state from learnable boundary parameters (initialised to zero).
    """
    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.input_size  = input_size
        self.hidden_size = hidden_size
        self.cell = MDLSTMCell(input_size, hidden_size)
        # NOTE(review): cell2 is never called by forward() below; it is kept
        # so parameter names and initialisation order stay unchanged.
        self.cell2 = MDLSTMCell(hidden_size, hidden_size)
        # Learnable boundary states, broadcast over the batch in forward().
        self.h_top = nn.Parameter(torch.zeros(1, hidden_size))
        self.c_top = nn.Parameter(torch.zeros(1, hidden_size))
        self.h_left = nn.Parameter(torch.zeros(1, hidden_size))
        self.c_left = nn.Parameter(torch.zeros(1, hidden_size))

    def forward(self, x):
        """
        x: (B, H, W, input_size)
        returns: h_out of shape (B, H, W, hidden_size)
        """
        B, H, W, _ = x.size()

        # Nothing to scan: return an empty grid instead of failing in stack().
        if H == 0 or W == 0:
            return self.h_top.new_zeros(B, H, W, self.hidden_size)

        # The boundary states are the same for every border position, so the
        # batch broadcast is done once rather than inside the loops.
        top_h = self.h_top.expand(B, -1)
        top_c = self.c_top.expand(B, -1)
        left_h = self.h_left.expand(B, -1)
        left_c = self.c_left.expand(B, -1)

        stacked_rows = []
        prev_h_row = prev_c_row = None
        for i in range(H):
            h_row = []
            c_row = []
            for j in range(W):
                x_ij = x[:, i, j, :]  # (B, input_size)

                # State coming from the cell above (or the top boundary).
                if i > 0:
                    h1, c1 = prev_h_row[j], prev_c_row[j]
                else:
                    h1, c1 = top_h, top_c

                # State coming from the cell to the left (or the left boundary).
                if j > 0:
                    h2, c2 = h_row[j-1], c_row[j-1]
                else:
                    h2, c2 = left_h, left_c

                h_ij, c_ij = self.cell(x_ij, h1, c1, h2, c2)
                h_row.append(h_ij)
                c_row.append(c_ij)

            stacked_rows.append(torch.stack(h_row, dim=1))  # (B, W, hidden_size)
            # Only the previous row is needed as the "top" neighbour.
            prev_h_row, prev_c_row = h_row, c_row

        # Stack the rows to get shape (B, H, W, hidden_size)
        h_out = torch.stack(stacked_rows, dim=1)
        return h_out

In [80]:
# The only preprocessing step is conversion to a tensor.
transform = transforms.Compose([transforms.ToTensor()])

# FashionMNIST train and test splits, downloaded into ./data if missing.
train_dataset = datasets.FashionMNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)
test_dataset = datasets.FashionMNIST(
    root='./data',
    train=False,
    download=True,
    transform=transform,
)

In [81]:
# Use the MPS backend when it is available, otherwise fall back to the CPU.
if torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print(device)

mps


In [88]:
# Move the raw image tensors and labels to the device once, up front.
# Pixels are cast to float and divided by 255, then a channel axis is inserted
# at dim 1 (presumably giving (N, 1, 28, 28) — verify against the dataset).
train_data = (train_dataset.data.to(device).float() / 255.0).unsqueeze(1)
train_targets = train_dataset.targets.to(device)

test_data = (test_dataset.data.to(device).float() / 255.0).unsqueeze(1)
test_targets = test_dataset.targets.to(device)

def get_batches(data, targets, batch_size):
    """Yield consecutive ``(data, targets)`` slices of at most ``batch_size`` items.

    Slices are taken in order with no shuffling; the final slice is shorter
    when ``len(data)`` is not a multiple of ``batch_size``.
    """
    for lo in range(0, len(data), batch_size):
        hi = lo + batch_size
        yield data[lo:hi], targets[lo:hi]

# Number of samples per mini-batch.
batch_size = 5000

# DataLoaders over the torchvision datasets; only the training one shuffles.
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
)
test_loader = DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,
)

In [89]:
class MNIST2DLSTMClassifier(nn.Module):
    """Image classifier: a 2D LSTM over single-channel pixels plus a linear head."""

    def __init__(self, hidden_size=32, num_classes=10):
        super().__init__()

        self.mdlstm = MDLSTM(input_size=1, hidden_size=hidden_size)
        self.classifier = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        """Map images x of shape (B, 1, 28, 28) to logits of shape (B, num_classes)."""
        # Channels-first → channels-last: (B, 1, H, W) becomes (B, H, W, 1).
        grid = x.permute(0, 2, 3, 1).contiguous()

        # Hidden state for every grid position: (B, H, W, hidden_size).
        hidden = self.mdlstm(grid)

        # Only the last-row, last-column hidden state is used as the image
        # summary (no pooling over positions is performed).
        bottom_right = hidden[:, -1, -1, :].clone()

        return self.classifier(bottom_right)  # (B, num_classes)

In [90]:
# Optimisation hyper-parameters.
learning_rate = 1e-3
epochs = 1000

# Use the MPS backend when it is available, otherwise the CPU.
if torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

# Fresh model on the chosen device, with its loss and optimiser.
model = MNIST2DLSTMClassifier().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

In [91]:
# List every named parameter with its element count, then the overall total.
total_params = 0
for name, param in model.named_parameters():
    print(f"{name}: {param.numel()} params, requires_grad={param.requires_grad}")
    total_params += param.numel()

# Blank separator line followed by the total.
print(f"\n{total_params}")

mdlstm.cell.linear.weight: 10400 params, requires_grad=True
mdlstm.cell.linear.bias: 160 params, requires_grad=True
classifier.weight: 320 params, requires_grad=True
classifier.bias: 10 params, requires_grad=True

10890


In [None]:
# Train with per-epoch validation and patience-based early stopping.
#
# The loop bound is named and also used in the progress messages. Previously
# the messages printed the separately configured `epochs` value while the
# loop ran for a different, hard-coded number of iterations, so the label
# could read e.g. "Epoch [1500/1000]".
max_epochs = 10000
patience = 1000  # epochs without validation-loss improvement before stopping
best_val_loss = float('inf')
no_improvement_epochs = 0

# Validation logits of every epoch.
# NOTE(review): this list grows by one full set of validation outputs per
# epoch and is not read inside this block — confirm it is still needed.
all_outputs = []

for epoch in range(max_epochs):
    # ── Training pass ───────────────────────────────────────────────────────
    model.train()
    running_loss = 0.0
    num_batches = 0

    for data, target in get_batches(train_data, train_targets, batch_size):
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()

        optimizer.step()
        running_loss += loss.item()
        num_batches += 1

    print(f"Epoch [{epoch + 1}/{max_epochs}], Training Loss: {running_loss / num_batches:.4f}")

    # ── Validation pass ─────────────────────────────────────────────────────
    model.eval()
    val_loss = 0.0
    correct = 0
    total = 0
    num_batches = 0
    epoch_outputs = []

    with torch.no_grad():
        for data, target in get_batches(test_data, test_targets, batch_size):
            outputs = model(data)
            loss = criterion(outputs, target)
            val_loss += loss.item()
            predicted = outputs.argmax(dim=1)
            total += target.size(0)
            num_batches += 1
            correct += (predicted == target).sum().item()

            epoch_outputs.append(outputs)

    all_outputs_tensor = torch.cat(epoch_outputs, dim=0)
    all_outputs.append(all_outputs_tensor)

    # Summary statistics of the raw logits, printed as a quick health check.
    max_val = torch.max(all_outputs_tensor).item()
    min_val = torch.min(all_outputs_tensor).item()
    median_val = torch.median(all_outputs_tensor).item()
    mean_val = torch.mean(all_outputs_tensor).item()

    accuracy = 100 * correct / total
    val_loss /= num_batches  # mean of the per-batch losses
    print(f"Epoch [{epoch + 1}/{max_epochs}], Validation Loss: {val_loss:.4f}, Validation Accuracy: {accuracy:.2f}%")
    print(f"Output Summary: Max={max_val:.4f}, Min={min_val:.4f}, Median={median_val:.4f}, Mean={mean_val:.4f}")
    print()

    # ── Early stopping bookkeeping ──────────────────────────────────────────
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        no_improvement_epochs = 0
    else:
        no_improvement_epochs += 1

    if no_improvement_epochs >= patience:
        print(f"Early stopping triggered after {epoch + 1} epochs.")
        break

Epoch [1/1000], Training Loss: 2.2957
Epoch [1/1000], Validation Loss: 2.2898, Validation Accuracy: 13.75%
Output Summary: Max=0.4475, Min=-0.3677, Median=0.0605, Mean=0.0795

Epoch [2/1000], Training Loss: 2.2215
Epoch [2/1000], Validation Loss: 2.1714, Validation Accuracy: 19.55%
Output Summary: Max=0.5182, Min=-0.3746, Median=0.1035, Mean=0.0700

Epoch [3/1000], Training Loss: 2.1153
Epoch [3/1000], Validation Loss: 2.0406, Validation Accuracy: 22.57%
Output Summary: Max=0.7091, Min=-0.6550, Median=0.0728, Mean=0.0765

Epoch [4/1000], Training Loss: 2.0633
Epoch [4/1000], Validation Loss: 2.0161, Validation Accuracy: 25.48%
Output Summary: Max=0.9221, Min=-0.8647, Median=0.0966, Mean=0.1004

Epoch [5/1000], Training Loss: 1.9554
Epoch [5/1000], Validation Loss: 1.9014, Validation Accuracy: 22.25%
Output Summary: Max=1.0513, Min=-1.0341, Median=0.0903, Mean=0.0891

Epoch [6/1000], Training Loss: 1.8490
Epoch [6/1000], Validation Loss: 1.7864, Validation Accuracy: 29.71%
Output Summar

KeyboardInterrupt: 