TODO:
- Build a parallelizable, optimized 2D RNN implementation (CUDA/Triton/PyTorch)
- Investigate applying attention between the top hidden state, the left hidden state, and the input x
- Apply all 8 rotations/reflections of the image (rotate 0/90/180/270 degrees; flip vertically/horizontally/diagonally/anti-diagonally) and run the same LSTM over each, with learnable hidden and cell states along the surrounding border
- Investigate applying attention to the final hidden-state vectors
- Investigate a multilayer 2D LSTM
- Compare GRU vs. LSTM vs. other cell types
- Compare a multi-layer LSTM against stacking (feeding one LSTM's hidden states into another LSTM)
- Investigate skip connections in the LSTM

In [27]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

In [None]:
class MDLSTMCell(nn.Module):
    """
    One update step of a two-dimensional LSTM at grid position (i, j).

    The cell combines the input at (i, j) with the states of its upper
    neighbour (i-1, j) and its left neighbour (i, j-1):

        a = tanh(   W_a x + U_a y_top + V_a y_left)   candidate
        f = sigmoid(W_f x + U_f y_top + V_f y_left)   gate on c_top
        g = sigmoid(W_g x + U_g y_top + V_g y_left)   gate on c_left
        k = sigmoid(W_k x + U_k y_top + V_k y_left)   gate on the candidate
        o = sigmoid(W_o x + U_o y_top + V_o y_left)   output gate

        c = f * c_top + g * c_left + a * k
        y = o * tanh(c)

    Only the W_* projections carry a bias; U_* and V_* are bias-free.

    Args:
        input_size:  number of features at each grid position.
        hidden_size: width of the hidden (y) and cell (c) states.
    """

    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size

        # Registration order is Wa, Wf, Wg, Wk, Wo, then Ua..Uo (top state),
        # then Va..Vo (left state). This fixes both the state_dict key order
        # and the order in which the layers draw their random initial weights.
        for prefix, in_features, has_bias in (
            ("W", input_size, True),
            ("U", hidden_size, False),
            ("V", hidden_size, False),
        ):
            for gate in "afgko":
                setattr(
                    self,
                    prefix + gate,
                    nn.Linear(in_features, hidden_size, bias=has_bias),
                )

    def forward(self, x, y_top, c_top, y_left, c_left):
        """
        x:       (B, input_size)
        y_top:   (B, hidden_size)    c_top:  (B, hidden_size)
        y_left:  (B, hidden_size)    c_left: (B, hidden_size)
        returns: (y, c), each of shape (B, hidden_size)
        """
        def pre_activation(gate):
            # W_gate x + U_gate y_top + V_gate y_left, summed left to right.
            return (getattr(self, "W" + gate)(x)
                    + getattr(self, "U" + gate)(y_top)
                    + getattr(self, "V" + gate)(y_left))

        candidate = torch.tanh(pre_activation("a"))
        f, g, k, o = (torch.sigmoid(pre_activation(gate)) for gate in "fgko")

        cell_state = f * c_top + g * c_left + candidate * k
        return o * torch.tanh(cell_state), cell_state


class MDLSTM(nn.Module):
    """
    Scan an MDLSTMCell over a 2D grid in raster order.

    Rows are visited top to bottom and, inside a row, columns left to right.
    Position (i, j) receives the hidden/cell state of (i-1, j) above and of
    (i, j-1) to its left. Positions on the first row / first column use the
    learnable border states ``h_top``/``c_top`` and ``h_left``/``c_left``
    in place of the missing neighbour.

    Input shape: (B, H, W, input_size)
    Output:      (B, H, W, hidden_size) hidden states (cell states are not returned)
    """

    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.cell = MDLSTMCell(input_size, hidden_size)

        # Zero-initialised learnable border states of shape (1, hidden_size),
        # registered as h_top, c_top, h_left, c_left, h_bottom, c_bottom,
        # h_right, c_right.
        # NOTE(review): forward() only reads the top/left pair. The bottom/right
        # pairs look reserved for a multi-direction scan and receive no
        # gradient from this forward pass.
        for side in ("top", "left", "bottom", "right"):
            setattr(self, f"h_{side}", nn.Parameter(torch.zeros(1, hidden_size)))
            setattr(self, f"c_{side}", nn.Parameter(torch.zeros(1, hidden_size)))

    def forward(self, x):
        """
        x:       (B, H, W, input_size)
        returns: (B, H, W, hidden_size)
        """
        batch, height, width, _ = x.size()

        # Border states broadcast across the batch (expand returns views).
        border_h_top = self.h_top.expand(batch, -1)
        border_c_top = self.c_top.expand(batch, -1)
        border_h_left = self.h_left.expand(batch, -1)
        border_c_left = self.c_left.expand(batch, -1)

        row_tensors = []          # one (B, W, hidden_size) tensor per finished row
        above_h = above_c = None  # per-column states of the previous row

        for i in range(height):
            cur_h, cur_c = [], []
            for j in range(width):
                if i == 0:
                    up_h, up_c = border_h_top, border_c_top
                else:
                    up_h, up_c = above_h[j], above_c[j]

                if j == 0:
                    left_h, left_c = border_h_left, border_c_left
                else:
                    left_h, left_c = cur_h[-1], cur_c[-1]

                h_new, c_new = self.cell(x[:, i, j, :], up_h, up_c, left_h, left_c)
                cur_h.append(h_new)
                cur_c.append(c_new)

            row_tensors.append(torch.stack(cur_h, dim=1))
            above_h, above_c = cur_h, cur_c

        return torch.stack(row_tensors, dim=1)

In [62]:
class MDGRUCell(nn.Module):
    """
    One update step of a two-dimensional GRU at grid position (i, j).

    Besides the usual update (z), reset (r) and candidate (n) terms, the cell
    learns a per-unit softmax over the two incoming directions and uses it to
    blend ``y_top`` and ``y_left`` into a single previous state:

        z = sigmoid(W_z x + U_z y_top + V_z y_left)
        r = sigmoid(W_r x + U_r y_top + V_r y_left)
        n = tanh(W_n x + U_n (r * y_top) + V_n (r * y_left))
        [w_top, w_left] = softmax over directions of (W_d x + U_d y_top + V_d y_left)
        y = (1 - z) * n + z * (w_top * y_top + w_left * y_left)

    Args:
        input_size:  number of features at each grid position.
        hidden_size: width of the hidden state.
    """

    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size

        # Gate projections, registered as Wz, Wr, Wn (input, biased), then
        # Uz, Ur, Un (top state), then Vz, Vr, Vn (left state).
        for prefix, in_features, has_bias in (
            ("W", input_size, True),
            ("U", hidden_size, False),
            ("V", hidden_size, False),
        ):
            for gate in "zrn":
                setattr(
                    self,
                    prefix + gate,
                    nn.Linear(in_features, hidden_size, bias=has_bias),
                )

        # Direction-mixing scores: two per hidden unit (top, left).
        mix_width = 2 * hidden_size
        self.Wd = nn.Linear(input_size, mix_width)
        self.Ud = nn.Linear(hidden_size, mix_width, bias=False)
        self.Vd = nn.Linear(hidden_size, mix_width, bias=False)

    def forward(self, x, y_top, y_left):
        """
        x:       (B, input_size)
        y_top:   (B, hidden_size)
        y_left:  (B, hidden_size)
        returns: (B, hidden_size) new hidden state
        """
        z = torch.sigmoid(self.Wz(x) + self.Uz(y_top) + self.Vz(y_left))
        r = torch.sigmoid(self.Wr(x) + self.Ur(y_top) + self.Vr(y_left))

        reset_top = r * y_top
        reset_left = r * y_left
        n = torch.tanh(self.Wn(x) + self.Un(reset_top) + self.Vn(reset_left))

        # First hidden_size scores belong to "top", the second half to "left".
        scores = (self.Wd(x) + self.Ud(y_top) + self.Vd(y_left)).view(
            x.size(0), 2, self.hidden_size
        )
        weights = torch.softmax(scores, dim=1)
        w_top, w_left = weights[:, 0, :], weights[:, 1, :]

        blended = w_top * y_top + w_left * y_left
        return (1 - z) * n + z * blended

class MDGRU(nn.Module):
    """
    Raster-scan an MDGRUCell across a 2D grid.

    Rows are processed top to bottom and, within a row, columns left to right.
    Each position consumes the hidden state directly above it and directly to
    its left. On the first row / first column the learnable border vectors
    ``h_top`` / ``h_left`` stand in for the missing neighbour.

    Input shape: (B, H, W, input_size)
    Output:      (B, H, W, hidden_size) of hidden states
    """

    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.cell = MDGRUCell(input_size, hidden_size)

        # Zero-initialised, learnable (1, hidden_size) border states.
        self.h_top = nn.Parameter(torch.zeros(1, hidden_size))
        self.h_left = nn.Parameter(torch.zeros(1, hidden_size))

    def forward(self, x):
        """
        x:       (B, H, W, input_size)
        returns: (B, H, W, hidden_size)
        """
        batch, height, width, _ = x.size()

        # Broadcast the border states across the batch once (expand = view).
        top_border = self.h_top.expand(batch, -1)
        left_border = self.h_left.expand(batch, -1)

        finished_rows = []  # one (B, W, hidden_size) tensor per completed row
        above = None        # per-column hidden states of the previous row

        for i in range(height):
            current = []
            for j in range(width):
                up = top_border if i == 0 else above[j]
                left = left_border if j == 0 else current[-1]
                current.append(self.cell(x[:, i, j, :], up, left))
            finished_rows.append(torch.stack(current, dim=1))
            above = current

        return torch.stack(finished_rows, dim=1)

In [63]:
# class MDLSTM(nn.Module):
#     """
#     Multi-Directional 2D LSTM module with 4 directional passes.
#     Input shape: (B, H, W, input_size)
#     Output:      (B, H, W, 4 * hidden_size)
#     """
#     def __init__(self, input_size, hidden_size):
#         super().__init__()
#         self.input_size = input_size
#         self.hidden_size = hidden_size
#         self.cell = MDLSTMCell(input_size, hidden_size)

#         # Parameters for all 4 directions
#         self.h_top    = nn.Parameter(torch.zeros(1, hidden_size))
#         self.c_top    = nn.Parameter(torch.zeros(1, hidden_size))
#         self.h_left   = nn.Parameter(torch.zeros(1, hidden_size))
#         self.c_left   = nn.Parameter(torch.zeros(1, hidden_size))

#         self.h_right  = nn.Parameter(torch.zeros(1, hidden_size))
#         self.c_right  = nn.Parameter(torch.zeros(1, hidden_size))
#         self.h_bottom = nn.Parameter(torch.zeros(1, hidden_size))
#         self.c_bottom = nn.Parameter(torch.zeros(1, hidden_size))

#     def forward(self, x):
#         B, H, W, _ = x.size()

#         # Direction 1: top-left to bottom-right
#         h_grid = [[None for _ in range(W)] for _ in range(H)]
#         c_grid = [[None for _ in range(W)] for _ in range(H)]
#         for i in range(H):
#             for j in range(W):
#                 x_ij = x[:, i, j, :]
#                 h1 = h_grid[i-1][j] if i > 0 else self.h_top.expand(B, -1)
#                 c1 = c_grid[i-1][j] if i > 0 else self.c_top.expand(B, -1)
#                 h2 = h_grid[i][j-1] if j > 0 else self.h_left.expand(B, -1)
#                 c2 = c_grid[i][j-1] if j > 0 else self.c_left.expand(B, -1)
#                 h, c = self.cell(x_ij, h1, c1, h2, c2)
#                 h_grid[i][j] = h
#                 c_grid[i][j] = c
#         dir1 = torch.stack([torch.stack(row, dim=1) for row in h_grid], dim=1)

#         # Direction 2: top-right to bottom-left
#         h_grid = [[None for _ in range(W)] for _ in range(H)]
#         c_grid = [[None for _ in range(W)] for _ in range(H)]
#         for i in range(H):
#             for j in reversed(range(W)):
#                 x_ij = x[:, i, j, :]
#                 h1 = h_grid[i-1][j] if i > 0 else self.h_top.expand(B, -1)
#                 c1 = c_grid[i-1][j] if i > 0 else self.c_top.expand(B, -1)
#                 h2 = h_grid[i][j+1] if j < W-1 else self.h_right.expand(B, -1)
#                 c2 = c_grid[i][j+1] if j < W-1 else self.c_right.expand(B, -1)
#                 h, c = self.cell(x_ij, h1, c1, h2, c2)
#                 h_grid[i][j] = h
#                 c_grid[i][j] = c
#         dir2 = torch.stack([torch.stack(row, dim=1) for row in h_grid], dim=1)

#         # Direction 3: bottom-left to top-right
#         h_grid = [[None for _ in range(W)] for _ in range(H)]
#         c_grid = [[None for _ in range(W)] for _ in range(H)]
#         for i in reversed(range(H)):
#             for j in range(W):
#                 x_ij = x[:, i, j, :]
#                 h1 = h_grid[i+1][j] if i < H-1 else self.h_bottom.expand(B, -1)
#                 c1 = c_grid[i+1][j] if i < H-1 else self.c_bottom.expand(B, -1)
#                 h2 = h_grid[i][j-1] if j > 0 else self.h_left.expand(B, -1)
#                 c2 = c_grid[i][j-1] if j > 0 else self.c_left.expand(B, -1)
#                 h, c = self.cell(x_ij, h1, c1, h2, c2)
#                 h_grid[i][j] = h
#                 c_grid[i][j] = c
#         dir3 = torch.stack([torch.stack(row, dim=1) for row in h_grid], dim=1)

#         # Direction 4: bottom-right to top-left
#         h_grid = [[None for _ in range(W)] for _ in range(H)]
#         c_grid = [[None for _ in range(W)] for _ in range(H)]
#         for i in reversed(range(H)):
#             for j in reversed(range(W)):
#                 x_ij = x[:, i, j, :]
#                 h1 = h_grid[i+1][j] if i < H-1 else self.h_bottom.expand(B, -1)
#                 c1 = c_grid[i+1][j] if i < H-1 else self.c_bottom.expand(B, -1)
#                 h2 = h_grid[i][j+1] if j < W-1 else self.h_right.expand(B, -1)
#                 c2 = c_grid[i][j+1] if j < W-1 else self.c_right.expand(B, -1)
#                 h, c = self.cell(x_ij, h1, c1, h2, c2)
#                 h_grid[i][j] = h
#                 c_grid[i][j] = c
#         dir4 = torch.stack([torch.stack(row, dim=1) for row in h_grid], dim=1)

#         # Stack outputs from all directions: shape (B, H, W, hidden_size, 4)
#         h_out = torch.stack([dir1, dir2, dir3, dir4], dim=-1)
#         return h_out

In [64]:
# Pixel transform used by the torchvision datasets: PIL image -> tensor.
transform = transforms.Compose([transforms.ToTensor()])

# Both splits share root/download/transform settings; only `train` differs
# (train split first, then test split).
train_dataset, test_dataset = (
    datasets.MNIST(root='./data', train=is_train, download=True, transform=transform)
    for is_train in (True, False)
)

In [65]:
# Use Apple's Metal (MPS) backend when this PyTorch build reports it, else CPU.
device = torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")
print(device)

mps


In [66]:
# Preload each split onto `device`: images become float tensors divided by 255
# and gain a channel axis at dim 1, i.e. (N, H, W) -> (N, 1, H, W);
# labels are moved to `device` unchanged.
train_data = (train_dataset.data.to(device).float() / 255.0).unsqueeze(1)
train_targets = train_dataset.targets.to(device)

test_data = (test_dataset.data.to(device).float() / 255.0).unsqueeze(1)
test_targets = test_dataset.targets.to(device)

def get_batches(data, targets, batch_size, shuffle=False, generator=None):
    """
    Yield ``(data_batch, targets_batch)`` pairs of at most ``batch_size`` rows.

    By default both inputs are sliced along their first dimension in order;
    the final batch may be smaller than ``batch_size``.

    Args:
        data:       tensor (or sliceable sequence) indexed along its first dim.
        targets:    labels aligned row-for-row with ``data``.
        batch_size: positive number of rows per batch.
        shuffle:    if True, visit rows in a random order drawn with
                    ``torch.randperm`` (requires tensor inputs).
        generator:  optional CPU ``torch.Generator`` for reproducible shuffling.

    Raises:
        ValueError: if ``batch_size`` is not positive, or if ``data`` and
            ``targets`` have different lengths. Because this is a generator,
            the error surfaces when iteration starts.
    """
    n = len(data)
    if batch_size <= 0:
        # range() would raise for 0 but silently yield nothing for negatives.
        raise ValueError(f"batch_size must be positive, got {batch_size}")
    if len(targets) != n:
        raise ValueError(
            f"data and targets differ in length: {n} vs {len(targets)}"
        )

    if not shuffle:
        for start in range(0, n, batch_size):
            yield data[start:start + batch_size], targets[start:start + batch_size]
        return

    # Draw the permutation on the CPU so a CPU generator can drive it, then
    # move the index chunks to wherever each tensor lives before indexing.
    order = torch.randperm(n, generator=generator)
    for idx in order.split(batch_size):
        yield data[idx.to(data.device)], targets[idx.to(targets.device)]

batch_size = 10_000

# DataLoader wrappers over the torchvision datasets: the train split is
# shuffled, the test split is not. The training/validation loops in this
# notebook iterate over preloaded tensors instead of these loaders.
train_loader, test_loader = (
    DataLoader(dataset, batch_size=batch_size, shuffle=is_train)
    for dataset, is_train in ((train_dataset, True), (test_dataset, False))
)

In [71]:
class MNIST2DLSTMClassifier(nn.Module):
    """
    Image classifier built on a 2D recurrent raster scan (currently an MDGRU).

    The recurrent layer runs over the pixel grid. The hidden state at the
    bottom-right position is fed to a linear layer to produce class logits.
    With MDGRU's top/left recurrence, that position depends on every pixel.

    Args:
        hidden_size: width of the recurrent hidden state (default 32).
        num_classes: number of output logits (default 10).
        in_channels: channels per pixel, i.e. the recurrent input size (default 1).

    The defaults reproduce the original hard-coded MNIST configuration, so
    ``MNIST2DLSTMClassifier()`` builds the same modules with the same names.
    """

    def __init__(self, hidden_size=32, num_classes=10, in_channels=1):
        super().__init__()
        # Kept under the name `mdlstm` (although it holds an MDGRU) so existing
        # state_dict keys such as "mdlstm.cell.Wz.weight" still load.
        self.mdlstm = MDGRU(input_size=in_channels, hidden_size=hidden_size)
        # self.mdlstm2 = MDGRU(input_size=8, hidden_size=32)
        self.classifier = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        """
        x:       (B, C, H, W) images with C == in_channels, e.g. (B, 1, 28, 28)
        returns: (B, num_classes) logits
        """
        x = x.permute(0, 2, 3, 1).contiguous()  # (B, H, W, C): channels-last for the recurrent layer

        h = self.mdlstm(x)                      # (B, H, W, hidden_size)
        final_state = h[:, -1, -1, :]           # (B, hidden_size); no clone needed, nothing mutates it

        return self.classifier(final_state)

        # Two-layer experiment (disabled; unreachable after the return above).
        # NOTE(review): mdlstm2's input_size=8 does not match the first layer's
        # hidden_size output, and the classifier would need 2 * hidden_size
        # inputs for the concatenation; fix both before re-enabling.
        # h = self.mdlstm(x) # (B, H, W, hidden_size)
        # h2 = self.mdlstm2(h) # (B, H, W, hidden_size)

        # final_state = h[:, -1, -1, :].clone() # (B, hidden_size)
        # final_state2 = h2[:, -1, -1, :].clone() # (B, hidden_size)

        # logits = self.classifier(torch.concat([final_state, final_state2], dim=-1))
        # return logits

In [72]:
# class MNIST2DLSTMClassifier(nn.Module):
#     def __init__(self):
#         super().__init__()
#         self.hidden_size = 8
#         self.mdlstm = MDLSTM(input_size=1, hidden_size=self.hidden_size)
#         self.cls_token = nn.Parameter(torch.zeros(1, 1, self.hidden_size))
#         self.attn = nn.MultiheadAttention(embed_dim=self.hidden_size, num_heads=1, batch_first=True)
#         self.classifier = nn.Linear(self.hidden_size, 10)

#     def forward(self, x):
#         x = x.permute(0, 2, 3, 1).contiguous()  # (B, H, W, 1)
#         h = self.mdlstm(x)  # (B, H, W, hidden_size, 4)
#         print(h.shape)

#         h_top_left = h[:, 0, 0, :, :].mean(dim=-1)
#         h_top_right = h[:, 0, -1, :, :].mean(dim=-1)
#         h_bottom_left = h[:, -1, 0, :, :].mean(dim=-1)
#         h_bottom_right = h[:, -1, -1, :, :].mean(dim=-1)

#         corners = torch.stack([h_top_left, h_top_right, h_bottom_left, h_bottom_right], dim=1)  # (B, 4, hidden_size)

#         cls_token = self.cls_token.expand(x.size(0), -1, -1)  # (B, 1, hidden_size)
#         sequence = torch.cat([cls_token, corners], dim=1)  # (B, 5, hidden_size)

#         attn_output, _ = self.attn(sequence, sequence, sequence)  # (B, 5, hidden_size)
#         cls_output = attn_output[:, 0, :]  # (B, hidden_size)

#         logits = self.classifier(cls_output)
#         return logits

In [73]:
# Adam step size and the nominal number of training epochs.
learning_rate = 1e-3
epochs = 1_000

# Same device rule as earlier: MPS when available, otherwise CPU.
if torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

model = MNIST2DLSTMClassifier().to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=learning_rate)

In [74]:
# One line per parameter tensor: name, element count and trainability.
for name, param in model.named_parameters():
    print(f"{name}: {param.numel()} params, requires_grad={param.requires_grad}")

# Model-wide element count, preceded by a blank separator line.
total_params = sum(tensor.numel() for tensor in model.parameters())
print(f"\n{total_params}")

mdlstm.h_top: 32 params, requires_grad=True
mdlstm.h_left: 32 params, requires_grad=True
mdlstm.cell.Wz.weight: 32 params, requires_grad=True
mdlstm.cell.Wz.bias: 32 params, requires_grad=True
mdlstm.cell.Wr.weight: 32 params, requires_grad=True
mdlstm.cell.Wr.bias: 32 params, requires_grad=True
mdlstm.cell.Wn.weight: 32 params, requires_grad=True
mdlstm.cell.Wn.bias: 32 params, requires_grad=True
mdlstm.cell.Uz.weight: 1024 params, requires_grad=True
mdlstm.cell.Ur.weight: 1024 params, requires_grad=True
mdlstm.cell.Un.weight: 1024 params, requires_grad=True
mdlstm.cell.Vz.weight: 1024 params, requires_grad=True
mdlstm.cell.Vr.weight: 1024 params, requires_grad=True
mdlstm.cell.Vn.weight: 1024 params, requires_grad=True
mdlstm.cell.Wd.weight: 64 params, requires_grad=True
mdlstm.cell.Wd.bias: 64 params, requires_grad=True
mdlstm.cell.Ud.weight: 2048 params, requires_grad=True
mdlstm.cell.Vd.weight: 2048 params, requires_grad=True
classifier.weight: 320 params, requires_grad=True
class

In [75]:
patience = 1000
best_val_loss = float('inf')
no_improvement_epochs = 0

# Test-set logits from every epoch, moved to the CPU so that a long run does
# not keep one full logits tensor per epoch alive in accelerator memory.
all_outputs = []

# Bounded by `epochs` so the loop length matches the "[epoch/epochs]" progress
# messages printed below.
for epoch in range(epochs):
    model.train()
    running_loss = 0.0
    seen = 0
    num_batches = 0

    # get_batches walks the tensors in a fixed order; draw a fresh permutation
    # each epoch so the model does not see identical batches every time.
    perm = torch.randperm(len(train_data)).to(train_data.device)
    for idx in perm.split(batch_size):
        data, target = train_data[idx], train_targets[idx]
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()

        optimizer.step()
        # Assuming criterion uses reduction='mean' (the CrossEntropyLoss
        # default), weight each batch mean by its size so a shorter final
        # batch is not over-counted.
        running_loss += loss.item() * target.size(0)
        seen += target.size(0)
        num_batches += 1

    print(f"Epoch [{epoch + 1}/{epochs}], Training Loss: {running_loss / seen:.4f}")

    model.eval()
    val_loss = 0.0
    correct = 0
    total = 0
    num_batches = 0
    epoch_outputs = []

    with torch.no_grad():
        for data, target in get_batches(test_data, test_targets, batch_size):
            outputs = model(data)
            loss = criterion(outputs, target)
            val_loss += loss.item() * target.size(0)  # size-weighted, see above
            predicted = outputs.argmax(dim=1)
            total += target.size(0)
            num_batches += 1
            correct += (predicted == target).sum().item()

            epoch_outputs.append(outputs)

    all_outputs_tensor = torch.cat(epoch_outputs, dim=0)
    all_outputs.append(all_outputs_tensor.cpu())

    max_val = torch.max(all_outputs_tensor).item()
    min_val = torch.min(all_outputs_tensor).item()
    median_val = torch.median(all_outputs_tensor).item()
    mean_val = torch.mean(all_outputs_tensor).item()

    accuracy = 100 * correct / total
    val_loss /= total  # per-sample mean over the whole test set
    print(f"Epoch [{epoch + 1}/{epochs}], Validation Loss: {val_loss:.4f}, Validation Accuracy: {accuracy:.2f}%")
    print(f"Output Summary: Max={max_val:.4f}, Min={min_val:.4f}, Median={median_val:.4f}, Mean={mean_val:.4f}")
    print()

    # Early stopping on validation loss.
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        no_improvement_epochs = 0
    else:
        no_improvement_epochs += 1

    if no_improvement_epochs >= patience:
        print(f"Early stopping triggered after {epoch + 1} epochs.")
        break

Epoch [1/1000], Training Loss: 2.3119
Epoch [1/1000], Validation Loss: 2.3022, Validation Accuracy: 11.38%
Output Summary: Max=0.0927, Min=-0.1625, Median=-0.0830, Mean=-0.0516

Epoch [2/1000], Training Loss: 2.3015
Epoch [2/1000], Validation Loss: 2.3008, Validation Accuracy: 11.35%
Output Summary: Max=0.0832, Min=-0.1894, Median=-0.0326, Mean=-0.0430

Epoch [3/1000], Training Loss: 2.3014
Epoch [3/1000], Validation Loss: 2.3005, Validation Accuracy: 11.82%
Output Summary: Max=0.0812, Min=-0.1855, Median=-0.0334, Mean=-0.0387



KeyboardInterrupt: 

In [61]:
patience = 1000
best_val_loss = float('inf')
no_improvement_epochs = 0

# Test-set logits from every epoch, moved to the CPU so that a long run does
# not keep one full logits tensor per epoch alive in accelerator memory.
all_outputs = []

# Bounded by `epochs` so the loop length matches the "[epoch/epochs]" progress
# messages printed below.
for epoch in range(epochs):
    model.train()
    running_loss = 0.0
    seen = 0
    num_batches = 0

    # get_batches walks the tensors in a fixed order; draw a fresh permutation
    # each epoch so the model does not see identical batches every time.
    perm = torch.randperm(len(train_data)).to(train_data.device)
    for idx in perm.split(batch_size):
        data, target = train_data[idx], train_targets[idx]
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()

        optimizer.step()
        # Assuming criterion uses reduction='mean' (the CrossEntropyLoss
        # default), weight each batch mean by its size so a shorter final
        # batch is not over-counted.
        running_loss += loss.item() * target.size(0)
        seen += target.size(0)
        num_batches += 1

    print(f"Epoch [{epoch + 1}/{epochs}], Training Loss: {running_loss / seen:.4f}")

    model.eval()
    val_loss = 0.0
    correct = 0
    total = 0
    num_batches = 0
    epoch_outputs = []

    with torch.no_grad():
        for data, target in get_batches(test_data, test_targets, batch_size):
            outputs = model(data)
            loss = criterion(outputs, target)
            val_loss += loss.item() * target.size(0)  # size-weighted, see above
            predicted = outputs.argmax(dim=1)
            total += target.size(0)
            num_batches += 1
            correct += (predicted == target).sum().item()

            epoch_outputs.append(outputs)

    all_outputs_tensor = torch.cat(epoch_outputs, dim=0)
    all_outputs.append(all_outputs_tensor.cpu())

    max_val = torch.max(all_outputs_tensor).item()
    min_val = torch.min(all_outputs_tensor).item()
    median_val = torch.median(all_outputs_tensor).item()
    mean_val = torch.mean(all_outputs_tensor).item()

    accuracy = 100 * correct / total
    val_loss /= total  # per-sample mean over the whole test set
    print(f"Epoch [{epoch + 1}/{epochs}], Validation Loss: {val_loss:.4f}, Validation Accuracy: {accuracy:.2f}%")
    print(f"Output Summary: Max={max_val:.4f}, Min={min_val:.4f}, Median={median_val:.4f}, Mean={mean_val:.4f}")
    print()

    # Early stopping on validation loss.
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        no_improvement_epochs = 0
    else:
        no_improvement_epochs += 1

    if no_improvement_epochs >= patience:
        print(f"Early stopping triggered after {epoch + 1} epochs.")
        break

Epoch [1/1000], Training Loss: 2.3199
Epoch [1/1000], Validation Loss: 2.2994, Validation Accuracy: 10.05%
Output Summary: Max=0.3197, Min=-0.4100, Median=0.0177, Mean=-0.0204

Epoch [2/1000], Training Loss: 2.2885
Epoch [2/1000], Validation Loss: 2.2744, Validation Accuracy: 11.07%
Output Summary: Max=0.2439, Min=-0.3253, Median=-0.0285, Mean=-0.0319

Epoch [3/1000], Training Loss: 2.2686
Epoch [3/1000], Validation Loss: 2.2575, Validation Accuracy: 21.28%
Output Summary: Max=0.2824, Min=-0.2883, Median=-0.0509, Mean=-0.0330

Epoch [4/1000], Training Loss: 2.2546
Epoch [4/1000], Validation Loss: 2.2601, Validation Accuracy: 18.81%
Output Summary: Max=0.4426, Min=-0.3655, Median=-0.0352, Mean=-0.0317



KeyboardInterrupt: 

In [43]:
patience = 1000
best_val_loss = float('inf')
no_improvement_epochs = 0

# Test-set logits from every epoch, moved to the CPU so that a long run does
# not keep one full logits tensor per epoch alive in accelerator memory.
all_outputs = []

# Bounded by `epochs` so the loop length matches the "[epoch/epochs]" progress
# messages printed below.
for epoch in range(epochs):
    model.train()
    running_loss = 0.0
    seen = 0
    num_batches = 0

    # get_batches walks the tensors in a fixed order; draw a fresh permutation
    # each epoch so the model does not see identical batches every time.
    perm = torch.randperm(len(train_data)).to(train_data.device)
    for idx in perm.split(batch_size):
        data, target = train_data[idx], train_targets[idx]
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()

        optimizer.step()
        # Assuming criterion uses reduction='mean' (the CrossEntropyLoss
        # default), weight each batch mean by its size so a shorter final
        # batch is not over-counted.
        running_loss += loss.item() * target.size(0)
        seen += target.size(0)
        num_batches += 1

    print(f"Epoch [{epoch + 1}/{epochs}], Training Loss: {running_loss / seen:.4f}")

    model.eval()
    val_loss = 0.0
    correct = 0
    total = 0
    num_batches = 0
    epoch_outputs = []

    with torch.no_grad():
        for data, target in get_batches(test_data, test_targets, batch_size):
            outputs = model(data)
            loss = criterion(outputs, target)
            val_loss += loss.item() * target.size(0)  # size-weighted, see above
            predicted = outputs.argmax(dim=1)
            total += target.size(0)
            num_batches += 1
            correct += (predicted == target).sum().item()

            epoch_outputs.append(outputs)

    all_outputs_tensor = torch.cat(epoch_outputs, dim=0)
    all_outputs.append(all_outputs_tensor.cpu())

    max_val = torch.max(all_outputs_tensor).item()
    min_val = torch.min(all_outputs_tensor).item()
    median_val = torch.median(all_outputs_tensor).item()
    mean_val = torch.mean(all_outputs_tensor).item()

    accuracy = 100 * correct / total
    val_loss /= total  # per-sample mean over the whole test set
    print(f"Epoch [{epoch + 1}/{epochs}], Validation Loss: {val_loss:.4f}, Validation Accuracy: {accuracy:.2f}%")
    print(f"Output Summary: Max={max_val:.4f}, Min={min_val:.4f}, Median={median_val:.4f}, Mean={mean_val:.4f}")
    print()

    # Early stopping on validation loss.
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        no_improvement_epochs = 0
    else:
        no_improvement_epochs += 1

    if no_improvement_epochs >= patience:
        print(f"Early stopping triggered after {epoch + 1} epochs.")
        break

Epoch [1/1000], Training Loss: 2.2841
Epoch [1/1000], Validation Loss: 2.2534, Validation Accuracy: 20.79%
Output Summary: Max=0.6553, Min=-0.3376, Median=0.1000, Mean=0.1031

Epoch [2/1000], Training Loss: 2.2355
Epoch [2/1000], Validation Loss: 2.2170, Validation Accuracy: 20.38%
Output Summary: Max=0.6273, Min=-0.3303, Median=0.1150, Mean=0.1061

Epoch [3/1000], Training Loss: 2.2054
Epoch [3/1000], Validation Loss: 2.1807, Validation Accuracy: 19.85%
Output Summary: Max=0.6261, Min=-0.3204, Median=0.0984, Mean=0.1073

Epoch [4/1000], Training Loss: 2.1697
Epoch [4/1000], Validation Loss: 2.1444, Validation Accuracy: 26.45%
Output Summary: Max=0.6688, Min=-0.3628, Median=0.1046, Mean=0.1189

Epoch [5/1000], Training Loss: 2.1296
Epoch [5/1000], Validation Loss: 2.1278, Validation Accuracy: 17.13%
Output Summary: Max=0.8167, Min=-0.4004, Median=0.1101, Mean=0.1277

Epoch [6/1000], Training Loss: 2.1130
Epoch [6/1000], Validation Loss: 2.1043, Validation Accuracy: 20.50%
Output Summar

KeyboardInterrupt: 

In [44]:
# Print each parameter's name on one line, followed by its full tensor value.
for name, param in model.named_parameters():
    print(name, param, sep="\n")

mdlstm.h_top
Parameter containing:
tensor([[ 0.0116, -0.0051, -0.0116,  0.0060,  0.0063, -0.0119,  0.0113,  0.0055,
          0.0075,  0.0069,  0.0084, -0.0022,  0.0008, -0.0121,  0.0124, -0.0099,
         -0.0023,  0.0312,  0.0016, -0.0096,  0.0381, -0.0050,  0.0045,  0.0029,
         -0.0007,  0.0322, -0.0168,  0.0047, -0.0082,  0.0145, -0.0064,  0.0059]],
       device='mps:0', requires_grad=True)
mdlstm.c_top
Parameter containing:
tensor([[-0.0530, -0.0043, -0.0117,  0.0019, -0.0147, -0.0101, -0.0035, -0.0124,
          0.0069, -0.0061,  0.0036,  0.0016, -0.0005,  0.0021,  0.0026, -0.0028,
          0.0034,  0.0028,  0.0014, -0.0138,  0.0114, -0.0232,  0.0087, -0.0034,
         -0.0057,  0.0190,  0.0038,  0.0015, -0.0064,  0.0016,  0.0047, -0.0061]],
       device='mps:0', requires_grad=True)
mdlstm.h_left
Parameter containing:
tensor([[-0.0113,  0.0151,  0.0132, -0.0168, -0.0094, -0.0049, -0.0044, -0.0004,
          0.0180, -0.0222, -0.0023, -0.0076, -0.0136,  0.0012,  0.0054, -0.