#### Figuring out how stacked LSTMs work


In [10]:
%load_ext autoreload
%autoreload 2

import sys  # noqa
from pathlib import Path  # noqa

# Add parent directory to Python path to import from sibling directories
sys.path.append(str(Path.cwd().parent))

import torch  # noqa
import torch.nn as nn  # noqa
import torch.optim as optim  # noqa

from a2c.convlstm import MultiLayerLinearLSTM  # noqa


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [19]:
def test_overfit_multilayer_linearlstm():
    """
    A simple test to ensure that the MultiLayerLinearLSTM can overfit
    on a small synthetic sequence dataset -- but now using run_n_ticks()
    for each time step in the sequence. Overfitting a small dataset
    ensures that forward/backward passes are working as intended
    and the model can eventually drive the loss near zero.
    """

    torch.manual_seed(0)  # For reproducibility

    # ----------------------------------------------------------------
    # HYPERPARAMETERS
    # ----------------------------------------------------------------
    input_dim = 10
    hidden_dim = 8
    num_layers = 2
    seq_length = 5
    batch_size = 4
    lr = 1e-3
    num_epochs = 5000

    # Choose how many ticks to run per time step
    n_ticks_per_step = 3  # Example: run 3 ticks each time we see a new input vector

    # ----------------------------------------------------------------
    # 1) CREATE A DUMMY "SEQUENCE" DATASET
    # ----------------------------------------------------------------
    # We'll generate random input sequences (batch_size, seq_length, input_dim)
    # and want the model to predict the same sequence (identity mapping).
    x_data = torch.randn(batch_size, seq_length, input_dim)
    y_data = x_data.clone().detach()

    # ----------------------------------------------------------------
    # 2) DEFINE A SMALL MODEL THAT USES MultiLayerLinearLSTM
    # ----------------------------------------------------------------
    class OverfitModel(nn.Module):
        def __init__(self, input_dim, hidden_dim, num_layers):
            super().__init__()
            self.hidden_dim = hidden_dim
            self.num_layers = num_layers
            # Our custom multi-layer LSTM
            self.rnn = MultiLayerLinearLSTM(input_dim, hidden_dim, num_layers)
            # Final projection to get back to input_dim
            self.output_layer = nn.Linear(hidden_dim, input_dim)

        @property
        def memory_size(self):
            # For each layer, we have (c, h) = 2 * hidden_dim
            return 2 * self.hidden_dim * self.num_layers

        def forward(self, x, memory=None, n_ticks=1):
            """
            x       : (batch, seq_length, input_dim)
            memory  : (batch, 2 * hidden_dim * num_layers) if provided
            n_ticks : number of times to run the LSTM per input
            returns : (predictions, updated_memory),
                      where predictions has shape (batch, seq_length, input_dim)
            """
            batch_size, seq_len, _ = x.shape

            # If no memory is provided, initialize all c/h to zeros
            if memory is None:
                c_prev_list = [
                    torch.zeros(batch_size, self.hidden_dim, device=x.device)
                    for _ in range(self.num_layers)
                ]
                h_prev_list = [
                    torch.zeros(batch_size, self.hidden_dim, device=x.device)
                    for _ in range(self.num_layers)
                ]
            else:
                c_prev_list, h_prev_list = self.unpack_memory(memory)

            outputs = []

            # For each time step in the sequence
            for t in range(seq_len):
                i_t = x[:, t, :]  # shape: (batch, input_dim)

                # 3) Run the multi-layer LSTM for n_ticks for each i_t
                #    Instead of a single forward pass, we do run_n_ticks.
                #    This will repeatedly update c_prev_list/h_prev_list.
                c_prev_list, h_prev_list = self.rnn.run_n_ticks(
                    i_t, c_prev_list, h_prev_list, n_ticks=n_ticks
                )

                # The top-layer hidden state after run_n_ticks becomes our "embedding"
                out_t = self.output_layer(h_prev_list[-1])  # (batch, input_dim)
                outputs.append(out_t.unsqueeze(1))  # keep seq dimension

            # Concatenate outputs along time dimension
            outputs = torch.cat(outputs, dim=1)  # (batch, seq_length, input_dim)

            # Repack the final states into a single memory tensor
            new_memory = self.pack_memory(c_prev_list, h_prev_list)

            return outputs, new_memory

        def pack_memory(self, c_list, h_list):
            """
            Convert [c_1, h_1, ..., c_D, h_D] into a single tensor of shape
            (batch_size, 2 * hidden_dim * num_layers).
            """
            new_memory = []
            for d in range(self.num_layers):
                new_memory.append(c_list[d])
                new_memory.append(h_list[d])
            return torch.cat(new_memory, dim=1)

        def unpack_memory(self, memory):
            """
            Reverse of pack_memory. Extract each layer's (c, h) from memory.
            """
            c_prev_list = []
            h_prev_list = []
            slice_size = self.hidden_dim
            for d in range(self.num_layers):
                c_start = 2 * d * slice_size
                h_start = c_start + slice_size
                c_d = memory[:, c_start : c_start + slice_size]
                h_d = memory[:, h_start : h_start + slice_size]
                c_prev_list.append(c_d)
                h_prev_list.append(h_d)
            return c_prev_list, h_prev_list

    # ----------------------------------------------------------------
    # 3) INITIALIZE MODEL, OPTIMIZER, LOSS
    # ----------------------------------------------------------------
    model = OverfitModel(input_dim, hidden_dim, num_layers)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.MSELoss()

    # We will keep track of memory across epochs if we like,
    # or re-init each epoch. For demonstration we keep it persistent.
    memory = None

    # ----------------------------------------------------------------
    # 4) TRAINING LOOP TO OVERFIT
    # ----------------------------------------------------------------
    for epoch in range(num_epochs):
        optimizer.zero_grad()

        # Forward pass (run_n_ticks is invoked inside model.forward for each step)
        preds, memory = model(x_data, memory=memory, n_ticks=n_ticks_per_step)

        # Compute MSE loss
        loss = criterion(preds, y_data)

        # Backward
        loss.backward()
        optimizer.step()

        # Optionally detach memory to avoid accumulating gradients through time
        memory = memory.detach()

        if (epoch + 1) % 100 == 0:
            print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.6f}")

    # Final check: after training, the loss should be close to zero
    print("Final loss:", loss.item())
    assert loss.item() < 1e-3, (
        "Loss did not go below 1e-3. Overfitting test failed. "
        "Check your LSTM or training logic."
    )


In [20]:
test_overfit_multilayer_linearlstm()

Epoch [100/5000], Loss: 0.667332
Epoch [200/5000], Loss: 0.416788
Epoch [300/5000], Loss: 0.245756
Epoch [400/5000], Loss: 0.156294
Epoch [500/5000], Loss: 0.108760
Epoch [600/5000], Loss: 0.076148
Epoch [700/5000], Loss: 0.060032
Epoch [800/5000], Loss: 0.048084
Epoch [900/5000], Loss: 0.040887
Epoch [1000/5000], Loss: 0.038212
Epoch [1100/5000], Loss: 0.036612
Epoch [1200/5000], Loss: 0.035683
Epoch [1300/5000], Loss: 0.035018
Epoch [1400/5000], Loss: 0.034926
Epoch [1500/5000], Loss: 0.035180
Epoch [1600/5000], Loss: 0.033926
Epoch [1700/5000], Loss: 0.033326
Epoch [1800/5000], Loss: 0.032906
Epoch [1900/5000], Loss: 0.032651
Epoch [2000/5000], Loss: 0.032485
Epoch [2100/5000], Loss: 0.032367
Epoch [2200/5000], Loss: 0.032274
Epoch [2300/5000], Loss: 0.032194
Epoch [2400/5000], Loss: 0.032124
Epoch [2500/5000], Loss: 0.032060
Epoch [2600/5000], Loss: 0.031969
Epoch [2700/5000], Loss: 0.031887
Epoch [2800/5000], Loss: 0.031827
Epoch [2900/5000], Loss: 0.031812
Epoch [3000/5000], Loss

AssertionError: Loss did not go below 1e-3. Overfitting test failed. Check your LSTM or training logic.