TODO:
- Build a parallelizable, optimized 2D RNN implementation with CUDA/Triton/PyTorch
- Investigate applying attention between hidden state top, hidden state left, and x
- Apply rotations/reflections (rotate 0/90/180/270°, flip vertically/horizontally/diagonally/antidiagonally) to the input and run the same LSTM on every view, with learnable hidden and cell states on all surrounding borders
- Investigate applying attention to final hidden state vectors
- Investigate multilayer 2d LSTM
- Investigate GRU vs LSTM vs other approaches
- Compare a multi-layer LSTM with chaining LSTMs, i.e. feeding the hidden states of one LSTM into another LSTM
- Investigate skip connections in LSTM

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

In [2]:
class MDLSTMCell(nn.Module):
    """
    One update step of a two-dimensional LSTM.

    At grid position (i, j) the cell mixes the local input with the hidden and
    cell states of the neighbour above, (i-1, j), and to the left, (i, j-1):

        z_* = W_*(x) + U_*(y_top) + V_*(y_left)    for * in {a, f, g, k, o}
        a   = tanh(z_a)                             candidate
        f   = sigmoid(z_f)                          gate on c_top
        g   = sigmoid(z_g)                          gate on c_left
        k   = sigmoid(z_k)                          gate on the candidate
        o   = sigmoid(z_o)                          output gate
        c   = f * c_top + g * c_left + a * k
        y   = o * tanh(c)

    Only the W_* input projections carry a bias; U_* (top) and V_* (left)
    are bias-free. Submodules are named Wa..Wo, Ua..Uo, Va..Vo.
    """

    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size

        gate_names = ("a", "f", "g", "k", "o")
        # input projections W_* (with bias)
        for gate in gate_names:
            setattr(self, "W" + gate, nn.Linear(input_size, hidden_size))
        # projections of the top neighbour's hidden state, U_*
        for gate in gate_names:
            setattr(self, "U" + gate, nn.Linear(hidden_size, hidden_size, bias=False))
        # projections of the left neighbour's hidden state, V_*
        for gate in gate_names:
            setattr(self, "V" + gate, nn.Linear(hidden_size, hidden_size, bias=False))

    def _preactivation(self, gate, x, y_top, y_left):
        """Return W_gate(x) + U_gate(y_top) + V_gate(y_left)."""
        return (
            getattr(self, "W" + gate)(x)
            + getattr(self, "U" + gate)(y_top)
            + getattr(self, "V" + gate)(y_left)
        )

    def forward(self, x, y_top, c_top, y_left, c_left):
        """
        x:       (B, input_size)
        y_top:   (B, hidden_size)   hidden state of (i-1, j)
        c_top:   (B, hidden_size)   cell state of (i-1, j)
        y_left:  (B, hidden_size)   hidden state of (i, j-1)
        c_left:  (B, hidden_size)   cell state of (i, j-1)
        returns: (y, c), each (B, hidden_size)
        """
        candidate = torch.tanh(self._preactivation("a", x, y_top, y_left))
        gate_top = torch.sigmoid(self._preactivation("f", x, y_top, y_left))
        gate_left = torch.sigmoid(self._preactivation("g", x, y_top, y_left))
        gate_in = torch.sigmoid(self._preactivation("k", x, y_top, y_left))
        gate_out = torch.sigmoid(self._preactivation("o", x, y_top, y_left))

        cell_state = gate_top * c_top + gate_left * c_left + candidate * gate_in
        hidden = gate_out * torch.tanh(cell_state)
        return hidden, cell_state


class MDLSTM(nn.Module):
    """
    Sequential 2D LSTM: visits the grid row by row, left to right, calling
    ``MDLSTMCell`` once per position.

    Cells in the first row use the learnable ``h_top``/``c_top`` as their top
    neighbour; cells in the first column use ``h_left``/``c_left`` as their
    left neighbour.

    Input shape: (B, H, W, input_size)
    Output:      (B, H, W, hidden_size) of hidden states
    """

    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.cell = MDLSTMCell(input_size, hidden_size)

        # learnable boundary states (top row and left column)
        self.h_top = nn.Parameter(torch.zeros(1, hidden_size))
        self.c_top = nn.Parameter(torch.zeros(1, hidden_size))
        self.h_left = nn.Parameter(torch.zeros(1, hidden_size))
        self.c_left = nn.Parameter(torch.zeros(1, hidden_size))

    def forward(self, x):
        """
        x: (B, H, W, input_size)
        returns: h_out of shape (B, H, W, hidden_size)
        """
        batch, height, width, _ = x.size()

        # states of the row above; before the first row this is the top border
        above_h = [self.h_top.expand(batch, -1)] * width
        above_c = [self.c_top.expand(batch, -1)] * width

        stacked_rows = []
        for r in range(height):
            prev_h = self.h_left.expand(batch, -1)
            prev_c = self.c_left.expand(batch, -1)
            row_h, row_c = [], []
            for col in range(width):
                prev_h, prev_c = self.cell(
                    x[:, r, col, :], above_h[col], above_c[col], prev_h, prev_c
                )
                row_h.append(prev_h)
                row_c.append(prev_c)
            above_h, above_c = row_h, row_c
            stacked_rows.append(torch.stack(row_h, dim=1))   # (B, W, hidden)

        return torch.stack(stacked_rows, dim=1)              # (B, H, W, hidden)

In [3]:
class MDLSTM(nn.Module):
    """
    2D LSTM that sweeps the grid from the top-left to the bottom-right corner,
    one anti-diagonal (i + j constant) at a time.

    Cells on one anti-diagonal depend only on cells of the previous one, so
    every step is a single batched ``MDLSTMCell`` call (H + W - 1 calls instead
    of H * W).

    Input shape: (B, H, W, input_size)
    Output:      (B, H, W, hidden_size) of hidden states
    """

    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.cell = MDLSTMCell(input_size, hidden_size)

        # a top-left -> bottom-right sweep only needs top and left boundaries
        self.h_top = nn.Parameter(torch.zeros(1, hidden_size))
        self.c_top = nn.Parameter(torch.zeros(1, hidden_size))
        self.h_left = nn.Parameter(torch.zeros(1, hidden_size))
        self.c_left = nn.Parameter(torch.zeros(1, hidden_size))

    def forward(self, x):
        """
        x: (B, H, W, input_size)
        returns: h_out of shape (B, H, W, hidden_size)
        """
        B, H, W, _ = x.shape
        dev = x.device
        hs = self.hidden_size
        stride = W + 1  # row length of the padded state grid

        # flatten the input grid to (B, H*W, input_size): grid (r, c) -> r*W + c.
        # reshape (not view) so non-contiguous inputs are accepted as well.
        x_flat = x.reshape(B, H * W, self.input_size)

        # Padded states live on an (H+1) x (W+1) grid flattened to
        # (B, (H+1)*(W+1), hidden): padded row 0 is the top boundary, padded
        # column 0 the left boundary, and grid cell (r, c) sits at padded
        # (r+1, c+1), i.e. flat index (r+1)*stride + (c+1).
        pad_h = x.new_zeros(B, (H + 1) * stride, hs)
        pad_c = torch.zeros_like(pad_h)
        pad_h[:, 1:stride, :] = self.h_top
        pad_c[:, 1:stride, :] = self.c_top
        left_idx = torch.arange(1, H + 1, device=dev) * stride
        pad_h[:, left_idx, :] = self.h_left
        pad_c[:, left_idx, :] = self.c_left

        out_h = x.new_zeros(B, H * W, hs)

        # padded coordinates i' in [1, H], j' in [1, W]  =>  k = i' + j' in [2, H + W]
        for k in range(2, H + W + 1):
            i_pos = torch.arange(max(1, k - W), min(H, k - 1) + 1, device=dev)
            N = i_pos.numel()
            if N == 0:
                continue
            j_pos = k - i_pos

            idx_cur = i_pos * stride + j_pos          # padded (i', j')
            idx_top = idx_cur - stride                # padded (i'-1, j')
            idx_left = idx_cur - 1                    # padded (i', j'-1)
            # same cells in the unpadded (H, W) grid; the padded index must NOT
            # be used here, it points at the wrong pixel and runs past H*W
            idx_out = (i_pos - 1) * W + (j_pos - 1)

            # gather the whole diagonal and fold it into the batch dimension
            # (explicit sizes: reshape(-1) is ambiguous for empty batches)
            x_f = x_flat.index_select(1, idx_out).reshape(B * N, self.input_size)
            h1_f = pad_h.index_select(1, idx_top).reshape(B * N, hs)
            c1_f = pad_c.index_select(1, idx_top).reshape(B * N, hs)
            h2_f = pad_h.index_select(1, idx_left).reshape(B * N, hs)
            c2_f = pad_c.index_select(1, idx_left).reshape(B * N, hs)

            # single call for the entire diagonal
            h_f, c_f = self.cell(x_f, h1_f, c1_f, h2_f, c2_f)
            h_k = h_f.reshape(B, N, hs)
            c_k = c_f.reshape(B, N, hs)

            # later diagonals read from the padded state; only h is returned
            pad_h[:, idx_cur, :] = h_k
            pad_c[:, idx_cur, :] = c_k
            out_h[:, idx_out, :] = h_k

        # reshape to (B, H, W, hidden)
        return out_h.reshape(B, H, W, hs)

In [4]:
# class MDLSTM(nn.Module):
#     """
#     Multi-Directional 2D LSTM module with 4 directional passes.
#     Input shape: (B, H, W, input_size)
#     Output:      (B, H, W, 4 * hidden_size)
#     """
#     def __init__(self, input_size, hidden_size):
#         super().__init__()
#         self.input_size = input_size
#         self.hidden_size = hidden_size
#         self.cell = MDLSTMCell(input_size, hidden_size)

#         # Parameters for all 4 directions
#         self.h_top    = nn.Parameter(torch.zeros(1, hidden_size))
#         self.c_top    = nn.Parameter(torch.zeros(1, hidden_size))
#         self.h_left   = nn.Parameter(torch.zeros(1, hidden_size))
#         self.c_left   = nn.Parameter(torch.zeros(1, hidden_size))

#         self.h_right  = nn.Parameter(torch.zeros(1, hidden_size))
#         self.c_right  = nn.Parameter(torch.zeros(1, hidden_size))
#         self.h_bottom = nn.Parameter(torch.zeros(1, hidden_size))
#         self.c_bottom = nn.Parameter(torch.zeros(1, hidden_size))

#     def forward(self, x):
#         B, H, W, _ = x.size()

#         # Direction 1: top-left to bottom-right
#         h_grid = [[None for _ in range(W)] for _ in range(H)]
#         c_grid = [[None for _ in range(W)] for _ in range(H)]
#         for i in range(H):
#             for j in range(W):
#                 x_ij = x[:, i, j, :]
#                 h1 = h_grid[i-1][j] if i > 0 else self.h_top.expand(B, -1)
#                 c1 = c_grid[i-1][j] if i > 0 else self.c_top.expand(B, -1)
#                 h2 = h_grid[i][j-1] if j > 0 else self.h_left.expand(B, -1)
#                 c2 = c_grid[i][j-1] if j > 0 else self.c_left.expand(B, -1)
#                 h, c = self.cell(x_ij, h1, c1, h2, c2)
#                 h_grid[i][j] = h
#                 c_grid[i][j] = c
#         dir1 = torch.stack([torch.stack(row, dim=1) for row in h_grid], dim=1)

#         # Direction 2: top-right to bottom-left
#         h_grid = [[None for _ in range(W)] for _ in range(H)]
#         c_grid = [[None for _ in range(W)] for _ in range(H)]
#         for i in range(H):
#             for j in reversed(range(W)):
#                 x_ij = x[:, i, j, :]
#                 h1 = h_grid[i-1][j] if i > 0 else self.h_top.expand(B, -1)
#                 c1 = c_grid[i-1][j] if i > 0 else self.c_top.expand(B, -1)
#                 h2 = h_grid[i][j+1] if j < W-1 else self.h_right.expand(B, -1)
#                 c2 = c_grid[i][j+1] if j < W-1 else self.c_right.expand(B, -1)
#                 h, c = self.cell(x_ij, h1, c1, h2, c2)
#                 h_grid[i][j] = h
#                 c_grid[i][j] = c
#         dir2 = torch.stack([torch.stack(row, dim=1) for row in h_grid], dim=1)

#         # Direction 3: bottom-left to top-right
#         h_grid = [[None for _ in range(W)] for _ in range(H)]
#         c_grid = [[None for _ in range(W)] for _ in range(H)]
#         for i in reversed(range(H)):
#             for j in range(W):
#                 x_ij = x[:, i, j, :]
#                 h1 = h_grid[i+1][j] if i < H-1 else self.h_bottom.expand(B, -1)
#                 c1 = c_grid[i+1][j] if i < H-1 else self.c_bottom.expand(B, -1)
#                 h2 = h_grid[i][j-1] if j > 0 else self.h_left.expand(B, -1)
#                 c2 = c_grid[i][j-1] if j > 0 else self.c_left.expand(B, -1)
#                 h, c = self.cell(x_ij, h1, c1, h2, c2)
#                 h_grid[i][j] = h
#                 c_grid[i][j] = c
#         dir3 = torch.stack([torch.stack(row, dim=1) for row in h_grid], dim=1)

#         # Direction 4: bottom-right to top-left
#         h_grid = [[None for _ in range(W)] for _ in range(H)]
#         c_grid = [[None for _ in range(W)] for _ in range(H)]
#         for i in reversed(range(H)):
#             for j in reversed(range(W)):
#                 x_ij = x[:, i, j, :]
#                 h1 = h_grid[i+1][j] if i < H-1 else self.h_bottom.expand(B, -1)
#                 c1 = c_grid[i+1][j] if i < H-1 else self.c_bottom.expand(B, -1)
#                 h2 = h_grid[i][j+1] if j < W-1 else self.h_right.expand(B, -1)
#                 c2 = c_grid[i][j+1] if j < W-1 else self.c_right.expand(B, -1)
#                 h, c = self.cell(x_ij, h1, c1, h2, c2)
#                 h_grid[i][j] = h
#                 c_grid[i][j] = c
#         dir4 = torch.stack([torch.stack(row, dim=1) for row in h_grid], dim=1)

#         # Stack outputs from all directions: shape (B, H, W, hidden_size, 4)
#         h_out = torch.stack([dir1, dir2, dir3, dir4], dim=-1)
#         return h_out

In [17]:
import torch
import torch.nn as nn

class AntidiagMDLSTM(nn.Module):
    """
    MD-LSTM whose update order follows the antidiagonals (i + j constant).

    Every cell on an antidiagonal depends only on cells of the previous
    antidiagonal, so all of them are updated with one batched ``MDLSTMCell``
    call per step (H + W - 1 steps instead of H * W).

    Input :  (B, H, W, input_size)
    Output:  (B, H, W, hidden_size)

    ``H`` and ``W`` given to the constructor select the grid size whose index
    tables are pre-computed. Inputs of any other spatial size are still
    handled; their tables are built on the fly.
    """

    def __init__(self, input_size: int, hidden_size: int, H: int, W: int):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.H, self.W = H, W

        self.cell = MDLSTMCell(input_size, hidden_size)

        # learnable borders
        self.register_parameter("h_top", nn.Parameter(torch.zeros(1, hidden_size)))
        self.register_parameter("c_top", nn.Parameter(torch.zeros(1, hidden_size)))
        self.register_parameter("h_left", nn.Parameter(torch.zeros(1, hidden_size)))
        self.register_parameter("c_left", nn.Parameter(torch.zeros(1, hidden_size)))

        # pre-computed antidiagonal index tables for the configured (H, W);
        # non-persistent buffers follow .to(device) but stay out of state_dict
        diag_i, diag_j, diag_len = self._antidiagonals(H, W)
        self.register_buffer("_diag_i", diag_i, persistent=False)
        self.register_buffer("_diag_j", diag_j, persistent=False)
        self._diag_len = diag_len            # real (unpadded) length per antidiagonal
        self._max_diag_len = diag_i.size(1)

    @staticmethod
    def _antidiagonals(H: int, W: int, device=None):
        """
        Build the (row, col) indices of every antidiagonal of an H x W grid.

        Returns ``(diag_i, diag_j, lengths)``. ``diag_i`` and ``diag_j`` have
        shape (H + W - 1, max_len), and each row is right-padded with -1.
        ``lengths[d]`` is the number of real entries in row ``d``.
        """
        rows, lengths = [], []
        for d in range(H + W - 1):
            # i + j = d with 0 <= i < H and 0 <= j < W
            i = torch.arange(max(0, d - (W - 1)), min(H - 1, d) + 1, device=device)
            rows.append(i)
            lengths.append(i.numel())
        if not rows:
            empty = torch.empty(0, 0, dtype=torch.long, device=device)
            return empty, empty.clone(), lengths
        cols = [d - i for d, i in enumerate(rows)]
        pad = nn.utils.rnn.pad_sequence
        return (
            pad(rows, batch_first=True, padding_value=-1),
            pad(cols, batch_first=True, padding_value=-1),
            lengths,
        )

    # ------------------------------------------------------------------------------------
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        x: (B, H, W, input_size)
        returns: (B, H, W, hidden_size)
        """
        B, H, W, _ = x.shape
        device = x.device
        hs = self.hidden_size

        if (H, W) == (self.H, self.W):
            diag_i = self._diag_i.to(device)
            diag_j = self._diag_j.to(device)
            diag_len = self._diag_len
        else:
            # the cached tables describe a different grid; indexing with them
            # would go out of bounds, so build matching ones
            diag_i, diag_j, diag_len = self._antidiagonals(H, W, device=device)

        h = x.new_zeros(B, H, W, hs)
        c = x.new_zeros(B, H, W, hs)

        for d, n in enumerate(diag_len):                   # the only sequential loop
            if n == 0:
                continue
            # real entries sit at the front of each padded row; slicing with the
            # host-side length avoids a boolean mask and its device sync
            i = diag_i[d, :n]
            j = diag_j[d, :n]

            # gather current inputs in one read
            x_d = x[:, i, j, :]                            # (B, n, input_size)

            # top neighbours; row 0 uses the learnable top border instead.
            # clamp keeps the (discarded) gather for row 0 in bounds.
            has_top = (i > 0)[None, :, None]
            i_up = (i - 1).clamp_min(0)
            h1 = torch.where(has_top, h[:, i_up, j, :], self.h_top.expand(B, n, -1))
            c1 = torch.where(has_top, c[:, i_up, j, :], self.c_top.expand(B, n, -1))

            # left neighbours; column 0 uses the learnable left border instead
            has_left = (j > 0)[None, :, None]
            j_left = (j - 1).clamp_min(0)
            h2 = torch.where(has_left, h[:, i, j_left, :], self.h_left.expand(B, n, -1))
            c2 = torch.where(has_left, c[:, i, j_left, :], self.c_left.expand(B, n, -1))

            # fully-vectorised cell update for the whole antidiagonal
            # (explicit sizes: reshape(-1) is ambiguous when B == 0)
            h_new, c_new = self.cell(
                x_d.reshape(B * n, x.shape[-1]),
                h1.reshape(B * n, hs),
                c1.reshape(B * n, hs),
                h2.reshape(B * n, hs),
                c2.reshape(B * n, hs),
            )

            # scatter results back
            h[:, i, j, :] = h_new.reshape(B, n, hs)
            c[:, i, j, :] = c_new.reshape(B, n, hs)

        return h

In [3]:
# Single-step preprocessing pipeline shared by both splits.
transform = transforms.Compose([transforms.ToTensor()])

# MNIST train split first, then the test split (downloaded to ./data if missing).
train_dataset, test_dataset = (
    datasets.MNIST(root='./data', train=is_train, download=True, transform=transform)
    for is_train in (True, False)
)

In [4]:
# Prefer the MPS backend when it is available, otherwise fall back to the CPU.
device = torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")
print(device)

mps


In [5]:
# Move the raw dataset tensors onto `device`, scale pixel values by 1/255 and
# insert a channel axis at dim 1 (images then presumably (N, 1, 28, 28)).
train_data, test_data = (
    (dataset.data.to(device).float() / 255.0).unsqueeze(1)
    for dataset in (train_dataset, test_dataset)
)
train_targets, test_targets = (
    dataset.targets.to(device) for dataset in (train_dataset, test_dataset)
)

def get_batches(data, targets, batch_size, *, shuffle=False):
    """
    Yield ``(data, targets)`` mini-batches of at most ``batch_size`` rows.

    data, targets: indexable sequences of equal length (tensors in practice).
    batch_size:    rows per batch; the last batch may be smaller.
    shuffle:       if True, rows are visited in a fresh random order on every
                   call. This requires ``data`` to be a torch tensor. The default
                   (False) keeps the original sequential order.

    Raises ValueError (on first iteration) if the lengths differ; previously
    mismatched inputs were silently paired up and truncated.
    """
    n = len(data)
    if len(targets) != n:
        raise ValueError(f"data has {n} rows but targets has {len(targets)}")

    if shuffle:
        order = torch.randperm(n, device=data.device)
        for start in range(0, n, batch_size):
            idx = order[start:start + batch_size]
            yield data[idx], targets[idx]
    else:
        for start in range(0, n, batch_size):
            yield data[start:start + batch_size], targets[start:start + batch_size]

batch_size = 10000

# NOTE: the training/evaluation cells in this notebook iterate with
# get_batches() over the preloaded tensors; these loaders are not used there.
train_loader, test_loader = (
    DataLoader(dataset, batch_size=batch_size, shuffle=reshuffle)
    for dataset, reshuffle in ((train_dataset, True), (test_dataset, False))
)

In [9]:
class MNIST2DLSTMClassifier(nn.Module):
    """
    Image classifier built on a single 2D LSTM.

    The image is scanned by ``MDLSTM`` and the hidden state at the
    bottom-right pixel is passed to a linear layer. For a top-left to
    bottom-right sweep, that pixel is the one whose state depends on every
    other pixel.

    hidden_size: width of the 2D LSTM state (default 32, as before).
    num_classes: number of output logits (default 10, as before).
    """

    def __init__(self, hidden_size: int = 32, num_classes: int = 10):
        super().__init__()
        self.mdlstm = MDLSTM(input_size=1, hidden_size=hidden_size)
        self.classifier = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        """
        x: (B, 1, H, W) single-channel images, e.g. (B, 1, 28, 28) for MNIST
        returns: (B, num_classes) logits
        """
        # MDLSTM expects channels last: (B, H, W, 1)
        x = x.permute(0, 2, 3, 1).contiguous()

        h = self.mdlstm(x)                # (B, H, W, hidden_size)
        final_state = h[:, -1, -1, :]     # (B, hidden_size); no copy needed
        return self.classifier(final_state)

In [10]:
# class MNIST2DLSTMClassifier(nn.Module):
#     def __init__(self):
#         super().__init__()
#         self.hidden_size = 8
#         self.mdlstm = MDLSTM(input_size=1, hidden_size=self.hidden_size)
#         self.cls_token = nn.Parameter(torch.zeros(1, 1, self.hidden_size))
#         self.attn = nn.MultiheadAttention(embed_dim=self.hidden_size, num_heads=1, batch_first=True)
#         self.classifier = nn.Linear(self.hidden_size, 10)

#     def forward(self, x):
#         x = x.permute(0, 2, 3, 1).contiguous()  # (B, H, W, 1)
#         h = self.mdlstm(x)  # (B, H, W, hidden_size, 4)
#         print(h.shape)

#         h_top_left = h[:, 0, 0, :, :].mean(dim=-1)
#         h_top_right = h[:, 0, -1, :, :].mean(dim=-1)
#         h_bottom_left = h[:, -1, 0, :, :].mean(dim=-1)
#         h_bottom_right = h[:, -1, -1, :, :].mean(dim=-1)

#         corners = torch.stack([h_top_left, h_top_right, h_bottom_left, h_bottom_right], dim=1)  # (B, 4, hidden_size)

#         cls_token = self.cls_token.expand(x.size(0), -1, -1)  # (B, 1, hidden_size)
#         sequence = torch.cat([cls_token, corners], dim=1)  # (B, 5, hidden_size)

#         attn_output, _ = self.attn(sequence, sequence, sequence)  # (B, 5, hidden_size)
#         cls_output = attn_output[:, 0, :]  # (B, hidden_size)

#         logits = self.classifier(cls_output)
#         return logits

In [11]:
# Optimisation hyper-parameters.
learning_rate = 1e-3
epochs = 1000

# Model, loss and optimiser (MPS when available, otherwise CPU).
device = torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")
model = MNIST2DLSTMClassifier().to(device)
# alternative: model = torch.compile(MNIST2DLSTMClassifier()).to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

In [12]:
# List every parameter tensor, then the overall number of parameters.
total_params = 0
for name, param in model.named_parameters():
    print(f"{name}: {param.numel()} params, requires_grad={param.requires_grad}")
    total_params += param.numel()

print()
print(total_params)

mdlstm.h_top: 32 params, requires_grad=True
mdlstm.c_top: 32 params, requires_grad=True
mdlstm.h_left: 32 params, requires_grad=True
mdlstm.c_left: 32 params, requires_grad=True
mdlstm.cell.Wa.weight: 32 params, requires_grad=True
mdlstm.cell.Wa.bias: 32 params, requires_grad=True
mdlstm.cell.Wf.weight: 32 params, requires_grad=True
mdlstm.cell.Wf.bias: 32 params, requires_grad=True
mdlstm.cell.Wg.weight: 32 params, requires_grad=True
mdlstm.cell.Wg.bias: 32 params, requires_grad=True
mdlstm.cell.Wk.weight: 32 params, requires_grad=True
mdlstm.cell.Wk.bias: 32 params, requires_grad=True
mdlstm.cell.Wo.weight: 32 params, requires_grad=True
mdlstm.cell.Wo.bias: 32 params, requires_grad=True
mdlstm.cell.Ua.weight: 1024 params, requires_grad=True
mdlstm.cell.Uf.weight: 1024 params, requires_grad=True
mdlstm.cell.Ug.weight: 1024 params, requires_grad=True
mdlstm.cell.Uk.weight: 1024 params, requires_grad=True
mdlstm.cell.Uo.weight: 1024 params, requires_grad=True
mdlstm.cell.Va.weight: 1024

In [13]:
# Train for at most `epochs` epochs (previously a hard-coded 10000 that
# disagreed with the "/{epochs}" progress label), stopping early once the
# validation loss has not improved for `patience` consecutive epochs.
patience = 1000
best_val_loss = float('inf')
no_improvement_epochs = 0

all_outputs = []  # per-epoch validation logits, kept on the CPU

for epoch in range(epochs):
    model.train()
    running_loss = 0.0
    num_batches = 0

    for data, target in get_batches(train_data, train_targets, batch_size):
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()

        optimizer.step()
        running_loss += loss.item()
        num_batches += 1

    print(f"Epoch [{epoch + 1}/{epochs}], Training Loss: {running_loss / num_batches:.4f}")

    model.eval()
    val_loss = 0.0
    correct = 0
    total = 0
    num_batches = 0
    epoch_outputs = []

    with torch.no_grad():
        for data, target in get_batches(test_data, test_targets, batch_size):
            outputs = model(data)
            loss = criterion(outputs, target)
            val_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            total += target.size(0)
            num_batches += 1
            correct += (predicted == target).sum().item()

            epoch_outputs.append(outputs)

    all_outputs_tensor = torch.cat(epoch_outputs, dim=0)
    # one full set of test logits is kept per epoch; store it off the
    # accelerator so the history does not grow device memory
    all_outputs.append(all_outputs_tensor.cpu())

    max_val = torch.max(all_outputs_tensor).item()
    min_val = torch.min(all_outputs_tensor).item()
    median_val = torch.median(all_outputs_tensor).item()
    mean_val = torch.mean(all_outputs_tensor).item()

    accuracy = 100 * correct / total
    val_loss /= num_batches
    print(f"Epoch [{epoch + 1}/{epochs}], Validation Loss: {val_loss:.4f}, Validation Accuracy: {accuracy:.2f}%")
    print(f"Output Summary: Max={max_val:.4f}, Min={min_val:.4f}, Median={median_val:.4f}, Mean={mean_val:.4f}")
    print()

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        no_improvement_epochs = 0
    else:
        no_improvement_epochs += 1

    if no_improvement_epochs >= patience:
        print(f"Early stopping triggered after {epoch + 1} epochs.")
        break

Epoch [1/1000], Training Loss: 2.2956
Epoch [1/1000], Validation Loss: 2.2787, Validation Accuracy: 18.41%
Output Summary: Max=0.3589, Min=-0.3319, Median=-0.0739, Mean=-0.0826

Epoch [2/1000], Training Loss: 2.2739
Epoch [2/1000], Validation Loss: 2.2611, Validation Accuracy: 18.34%
Output Summary: Max=0.2645, Min=-0.3802, Median=-0.0657, Mean=-0.0659

Epoch [3/1000], Training Loss: 2.2547
Epoch [3/1000], Validation Loss: 2.2413, Validation Accuracy: 18.60%
Output Summary: Max=0.3291, Min=-0.4226, Median=-0.0528, Mean=-0.0485

Epoch [4/1000], Training Loss: 2.2344
Epoch [4/1000], Validation Loss: 2.2188, Validation Accuracy: 18.43%
Output Summary: Max=0.4217, Min=-0.4403, Median=-0.0220, Mean=-0.0433



KeyboardInterrupt: 

In [25]:
# Train for at most `epochs` epochs (previously a hard-coded 10000 that
# disagreed with the "/{epochs}" progress label), stopping early once the
# validation loss has not improved for `patience` consecutive epochs.
patience = 1000
best_val_loss = float('inf')
no_improvement_epochs = 0

all_outputs = []  # per-epoch validation logits, kept on the CPU

for epoch in range(epochs):
    model.train()
    running_loss = 0.0
    num_batches = 0

    for data, target in get_batches(train_data, train_targets, batch_size):
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()

        optimizer.step()
        running_loss += loss.item()
        num_batches += 1

    print(f"Epoch [{epoch + 1}/{epochs}], Training Loss: {running_loss / num_batches:.4f}")

    model.eval()
    val_loss = 0.0
    correct = 0
    total = 0
    num_batches = 0
    epoch_outputs = []

    with torch.no_grad():
        for data, target in get_batches(test_data, test_targets, batch_size):
            outputs = model(data)
            loss = criterion(outputs, target)
            val_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            total += target.size(0)
            num_batches += 1
            correct += (predicted == target).sum().item()

            epoch_outputs.append(outputs)

    all_outputs_tensor = torch.cat(epoch_outputs, dim=0)
    # one full set of test logits is kept per epoch; store it off the
    # accelerator so the history does not grow device memory
    all_outputs.append(all_outputs_tensor.cpu())

    max_val = torch.max(all_outputs_tensor).item()
    min_val = torch.min(all_outputs_tensor).item()
    median_val = torch.median(all_outputs_tensor).item()
    mean_val = torch.mean(all_outputs_tensor).item()

    accuracy = 100 * correct / total
    val_loss /= num_batches
    print(f"Epoch [{epoch + 1}/{epochs}], Validation Loss: {val_loss:.4f}, Validation Accuracy: {accuracy:.2f}%")
    print(f"Output Summary: Max={max_val:.4f}, Min={min_val:.4f}, Median={median_val:.4f}, Mean={mean_val:.4f}")
    print()

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        no_improvement_epochs = 0
    else:
        no_improvement_epochs += 1

    if no_improvement_epochs >= patience:
        print(f"Early stopping triggered after {epoch + 1} epochs.")
        break

Epoch [1/1000], Training Loss: 2.3316
Epoch [1/1000], Validation Loss: 2.3313, Validation Accuracy: 11.35%
Output Summary: Max=0.5718, Min=-0.2931, Median=0.1348, Mean=0.1629

Epoch [2/1000], Training Loss: 2.3282
Epoch [2/1000], Validation Loss: 2.3281, Validation Accuracy: 11.35%
Output Summary: Max=0.5527, Min=-0.2773, Median=0.1500, Mean=0.1587

Epoch [3/1000], Training Loss: 2.3252
Epoch [3/1000], Validation Loss: 2.3252, Validation Accuracy: 11.35%
Output Summary: Max=0.5342, Min=-0.2615, Median=0.1628, Mean=0.1547

Epoch [4/1000], Training Loss: 2.3225
Epoch [4/1000], Validation Loss: 2.3226, Validation Accuracy: 11.35%
Output Summary: Max=0.5164, Min=-0.2457, Median=0.1605, Mean=0.1509

Epoch [5/1000], Training Loss: 2.3201
Epoch [5/1000], Validation Loss: 2.3203, Validation Accuracy: 11.35%
Output Summary: Max=0.4993, Min=-0.2299, Median=0.1533, Mean=0.1471



KeyboardInterrupt: 

In [18]:
# Train for at most `epochs` epochs (previously a hard-coded 10000 that
# disagreed with the "/{epochs}" progress label), stopping early once the
# validation loss has not improved for `patience` consecutive epochs.
patience = 1000
best_val_loss = float('inf')
no_improvement_epochs = 0

all_outputs = []  # per-epoch validation logits, kept on the CPU

for epoch in range(epochs):
    model.train()
    running_loss = 0.0
    num_batches = 0

    for data, target in get_batches(train_data, train_targets, batch_size):
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()

        optimizer.step()
        running_loss += loss.item()
        num_batches += 1

    print(f"Epoch [{epoch + 1}/{epochs}], Training Loss: {running_loss / num_batches:.4f}")

    model.eval()
    val_loss = 0.0
    correct = 0
    total = 0
    num_batches = 0
    epoch_outputs = []

    with torch.no_grad():
        for data, target in get_batches(test_data, test_targets, batch_size):
            outputs = model(data)
            loss = criterion(outputs, target)
            val_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            total += target.size(0)
            num_batches += 1
            correct += (predicted == target).sum().item()

            epoch_outputs.append(outputs)

    all_outputs_tensor = torch.cat(epoch_outputs, dim=0)
    # one full set of test logits is kept per epoch; store it off the
    # accelerator so the history does not grow device memory
    all_outputs.append(all_outputs_tensor.cpu())

    max_val = torch.max(all_outputs_tensor).item()
    min_val = torch.min(all_outputs_tensor).item()
    median_val = torch.median(all_outputs_tensor).item()
    mean_val = torch.mean(all_outputs_tensor).item()

    accuracy = 100 * correct / total
    val_loss /= num_batches
    print(f"Epoch [{epoch + 1}/{epochs}], Validation Loss: {val_loss:.4f}, Validation Accuracy: {accuracy:.2f}%")
    print(f"Output Summary: Max={max_val:.4f}, Min={min_val:.4f}, Median={median_val:.4f}, Mean={mean_val:.4f}")
    print()

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        no_improvement_epochs = 0
    else:
        no_improvement_epochs += 1

    if no_improvement_epochs >= patience:
        print(f"Early stopping triggered after {epoch + 1} epochs.")
        break

Epoch [1/1000], Training Loss: 2.3082
Epoch [1/1000], Validation Loss: 2.2904, Validation Accuracy: 14.54%
Output Summary: Max=0.4717, Min=0.0500, Median=0.2391, Mean=0.2231

Epoch [2/1000], Training Loss: 2.2846
Epoch [2/1000], Validation Loss: 2.2709, Validation Accuracy: 18.62%
Output Summary: Max=0.4334, Min=-0.1091, Median=0.2322, Mean=0.2158

Epoch [3/1000], Training Loss: 2.2611
Epoch [3/1000], Validation Loss: 2.2473, Validation Accuracy: 16.00%
Output Summary: Max=0.5065, Min=-0.1740, Median=0.2446, Mean=0.2056

Epoch [4/1000], Training Loss: 2.2305
Epoch [4/1000], Validation Loss: 2.2166, Validation Accuracy: 20.91%
Output Summary: Max=0.5957, Min=-0.2189, Median=0.2327, Mean=0.1921

Epoch [5/1000], Training Loss: 2.2063
Epoch [5/1000], Validation Loss: 2.1864, Validation Accuracy: 24.79%
Output Summary: Max=0.6781, Min=-0.4170, Median=0.1958, Mean=0.1710

Epoch [6/1000], Training Loss: 2.1782
Epoch [6/1000], Validation Loss: 2.1614, Validation Accuracy: 26.55%
Output Summary

KeyboardInterrupt: 

In [61]:
# Train for at most `epochs` epochs (previously a hard-coded 10000 that
# disagreed with the "/{epochs}" progress label), stopping early once the
# validation loss has not improved for `patience` consecutive epochs.
patience = 1000
best_val_loss = float('inf')
no_improvement_epochs = 0

all_outputs = []  # per-epoch validation logits, kept on the CPU

for epoch in range(epochs):
    model.train()
    running_loss = 0.0
    num_batches = 0

    for data, target in get_batches(train_data, train_targets, batch_size):
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()

        optimizer.step()
        running_loss += loss.item()
        num_batches += 1

    print(f"Epoch [{epoch + 1}/{epochs}], Training Loss: {running_loss / num_batches:.4f}")

    model.eval()
    val_loss = 0.0
    correct = 0
    total = 0
    num_batches = 0
    epoch_outputs = []

    with torch.no_grad():
        for data, target in get_batches(test_data, test_targets, batch_size):
            outputs = model(data)
            loss = criterion(outputs, target)
            val_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            total += target.size(0)
            num_batches += 1
            correct += (predicted == target).sum().item()

            epoch_outputs.append(outputs)

    all_outputs_tensor = torch.cat(epoch_outputs, dim=0)
    # one full set of test logits is kept per epoch; store it off the
    # accelerator so the history does not grow device memory
    all_outputs.append(all_outputs_tensor.cpu())

    max_val = torch.max(all_outputs_tensor).item()
    min_val = torch.min(all_outputs_tensor).item()
    median_val = torch.median(all_outputs_tensor).item()
    mean_val = torch.mean(all_outputs_tensor).item()

    accuracy = 100 * correct / total
    val_loss /= num_batches
    print(f"Epoch [{epoch + 1}/{epochs}], Validation Loss: {val_loss:.4f}, Validation Accuracy: {accuracy:.2f}%")
    print(f"Output Summary: Max={max_val:.4f}, Min={min_val:.4f}, Median={median_val:.4f}, Mean={mean_val:.4f}")
    print()

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        no_improvement_epochs = 0
    else:
        no_improvement_epochs += 1

    if no_improvement_epochs >= patience:
        print(f"Early stopping triggered after {epoch + 1} epochs.")
        break

Epoch [1/1000], Training Loss: 2.3199
Epoch [1/1000], Validation Loss: 2.2994, Validation Accuracy: 10.05%
Output Summary: Max=0.3197, Min=-0.4100, Median=0.0177, Mean=-0.0204

Epoch [2/1000], Training Loss: 2.2885
Epoch [2/1000], Validation Loss: 2.2744, Validation Accuracy: 11.07%
Output Summary: Max=0.2439, Min=-0.3253, Median=-0.0285, Mean=-0.0319

Epoch [3/1000], Training Loss: 2.2686
Epoch [3/1000], Validation Loss: 2.2575, Validation Accuracy: 21.28%
Output Summary: Max=0.2824, Min=-0.2883, Median=-0.0509, Mean=-0.0330

Epoch [4/1000], Training Loss: 2.2546
Epoch [4/1000], Validation Loss: 2.2601, Validation Accuracy: 18.81%
Output Summary: Max=0.4426, Min=-0.3655, Median=-0.0352, Mean=-0.0317



KeyboardInterrupt: 

In [43]:
# Train `model` with early stopping on validation loss, logging per-epoch
# training/validation loss, validation accuracy and a summary of the raw
# validation outputs.
# NOTE(review): relies on `model`, `optimizer`, `criterion`, `get_batches`,
# `train_data`/`train_targets`, `test_data`/`test_targets` and `batch_size`
# being defined by earlier cells.
max_epochs = 10000  # single source of truth for the loop bound AND the "/N" shown in the logs
patience = 1000  # epochs without validation-loss improvement before stopping
best_val_loss = float('inf')
no_improvement_epochs = 0

# History of per-epoch validation outputs. Stored on the CPU so that keeping
# thousands of epochs of logits does not pile up in accelerator memory.
all_outputs = []

for epoch in range(max_epochs):
    # ---- training pass ----
    model.train()
    running_loss = 0.0
    num_batches = 0

    for data, target in get_batches(train_data, train_targets, batch_size):
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()

        optimizer.step()
        running_loss += loss.item()
        num_batches += 1

    print(f"Epoch [{epoch + 1}/{max_epochs}], Training Loss: {running_loss / num_batches:.4f}")

    # ---- validation pass (no gradient tracking) ----
    model.eval()
    val_loss = 0.0
    correct = 0
    total = 0
    num_batches = 0
    epoch_outputs = []

    with torch.no_grad():
        for data, target in get_batches(test_data, test_targets, batch_size):
            outputs = model(data)
            loss = criterion(outputs, target)
            val_loss += loss.item()
            # Predicted class = index of the largest output along dim 1.
            _, predicted = torch.max(outputs, 1)
            total += target.size(0)
            num_batches += 1
            correct += (predicted == target).sum().item()

            epoch_outputs.append(outputs)

    all_outputs_tensor = torch.cat(epoch_outputs, dim=0)
    # Keep a CPU copy for the history; the device tensor is freed once this
    # epoch's statistics below have been computed.
    all_outputs.append(all_outputs_tensor.cpu())

    max_val = torch.max(all_outputs_tensor).item()
    min_val = torch.min(all_outputs_tensor).item()
    median_val = torch.median(all_outputs_tensor).item()
    mean_val = torch.mean(all_outputs_tensor).item()

    accuracy = 100 * correct / total
    val_loss /= num_batches
    print(f"Epoch [{epoch + 1}/{max_epochs}], Validation Loss: {val_loss:.4f}, Validation Accuracy: {accuracy:.2f}%")
    print(f"Output Summary: Max={max_val:.4f}, Min={min_val:.4f}, Median={median_val:.4f}, Mean={mean_val:.4f}")
    print()

    # ---- early stopping on mean validation loss ----
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        no_improvement_epochs = 0
    else:
        no_improvement_epochs += 1

    if no_improvement_epochs >= patience:
        print(f"Early stopping triggered after {epoch + 1} epochs.")
        break

Epoch [1/1000], Training Loss: 2.2841
Epoch [1/1000], Validation Loss: 2.2534, Validation Accuracy: 20.79%
Output Summary: Max=0.6553, Min=-0.3376, Median=0.1000, Mean=0.1031

Epoch [2/1000], Training Loss: 2.2355
Epoch [2/1000], Validation Loss: 2.2170, Validation Accuracy: 20.38%
Output Summary: Max=0.6273, Min=-0.3303, Median=0.1150, Mean=0.1061

Epoch [3/1000], Training Loss: 2.2054
Epoch [3/1000], Validation Loss: 2.1807, Validation Accuracy: 19.85%
Output Summary: Max=0.6261, Min=-0.3204, Median=0.0984, Mean=0.1073

Epoch [4/1000], Training Loss: 2.1697
Epoch [4/1000], Validation Loss: 2.1444, Validation Accuracy: 26.45%
Output Summary: Max=0.6688, Min=-0.3628, Median=0.1046, Mean=0.1189

Epoch [5/1000], Training Loss: 2.1296
Epoch [5/1000], Validation Loss: 2.1278, Validation Accuracy: 17.13%
Output Summary: Max=0.8167, Min=-0.4004, Median=0.1101, Mean=0.1277

Epoch [6/1000], Training Loss: 2.1130
Epoch [6/1000], Validation Loss: 2.1043, Validation Accuracy: 20.50%
Output Summar

KeyboardInterrupt: 

In [44]:
# Dump every registered parameter of `model`: its qualified name on one line,
# followed by the tensor's printed representation.
for name, param in model.named_parameters():
    print(name, param, sep="\n")

mdlstm.h_top
Parameter containing:
tensor([[ 0.0116, -0.0051, -0.0116,  0.0060,  0.0063, -0.0119,  0.0113,  0.0055,
          0.0075,  0.0069,  0.0084, -0.0022,  0.0008, -0.0121,  0.0124, -0.0099,
         -0.0023,  0.0312,  0.0016, -0.0096,  0.0381, -0.0050,  0.0045,  0.0029,
         -0.0007,  0.0322, -0.0168,  0.0047, -0.0082,  0.0145, -0.0064,  0.0059]],
       device='mps:0', requires_grad=True)
mdlstm.c_top
Parameter containing:
tensor([[-0.0530, -0.0043, -0.0117,  0.0019, -0.0147, -0.0101, -0.0035, -0.0124,
          0.0069, -0.0061,  0.0036,  0.0016, -0.0005,  0.0021,  0.0026, -0.0028,
          0.0034,  0.0028,  0.0014, -0.0138,  0.0114, -0.0232,  0.0087, -0.0034,
         -0.0057,  0.0190,  0.0038,  0.0015, -0.0064,  0.0016,  0.0047, -0.0061]],
       device='mps:0', requires_grad=True)
mdlstm.h_left
Parameter containing:
tensor([[-0.0113,  0.0151,  0.0132, -0.0168, -0.0094, -0.0049, -0.0044, -0.0004,
          0.0180, -0.0222, -0.0023, -0.0076, -0.0136,  0.0012,  0.0054, -0.