# Introduction

In this notebook, we'll implement the forward pass of an SSM (State Space Model) using recurrence-based and convolution-based approaches. We'll also compare the two approaches in terms of speed and memory usage.

You will need a GPU for this notebook. You can enable one by changing the runtime type.

You can copy your solutions from the `q_coding_ssm_forward_cpu` notebook for the first part of the notebook.

## Imports

In [None]:
import time
import math
import matplotlib.pyplot as plt

import torch
import numpy as np
import torch.nn.functional as F

torch.set_default_device('cuda')

# SSM Update Rule

We consider an RNN described by the update:

$$
h_{t+1} = W\,h_t + U\,x_t + b
$$

for $t = 0, 1, \ldots, T - 1$. The variables are:

- $h_t \in \mathbb{R}^H$, the hidden state at time $t$.
- $x_t \in \mathbb{R}^D$, the input at time $t$ (for a single sequence; in the code, inputs are batched over $N$ sequences).
- $W \in \mathbb{R}^{H \times H}$, the recurrent weight matrix.
- $U \in \mathbb{R}^{H \times D}$, the input projection matrix.
- $b \in \mathbb{R}^H$, the bias vector.

$N$ is the batch size, $D$ is the input dimension, and $H$ is the hidden state dimension. We assume $h_0 = 0$, the all-zero vector of dimension $H$.

Below you will implement the forward pass of the SSM using a recurrence-based approach. The `unrolled_ssm_forward` function takes the weights $W$, $U$, $b$ and the input $x$, and returns the hidden states $h$ at every time step.

In [None]:
def unrolled_ssm_forward(W, U, b, x):
    """
    Unroll the linear RNN in time:
        h_{t+1} = W h_t + U x_t + b
    with initial h_0 = 0.

    Args:
      W: (H, H) weight matrix
      U: (H, D) input projection
      b: (H,)   bias
      x: (N, T, D) input sequence over T steps
    Returns:
      h_all: (N, T, H) hidden states for t=1..T
             (h_all[:, t] corresponds to h_{t+1} in the usual notation).
    """
    ##############################################################################
    #                   TODO: Implement the recurrent pass here                  #
    ##############################################################################

    ##############################################################################
    #                               END OF YOUR CODE                             #
    ##############################################################################
    return h_all

# Convolution Based Implementation

In the previous problem, you showed that the forward pass of an SSM can be implemented as a convolution. In this problem, you will implement the forward pass using that convolution-based approach. You can assume that $T$ is a power of 2.
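
As a reminder of why a convolution appears, the update rule can be unrolled by hand. The shorthand $s_t = U x_t + b$ is introduced here only for convenience (it matches the tensor `s` computed in `conv_ssm_forward`). With $h_0 = 0$:

$$
h_1 = s_0, \qquad h_2 = W s_0 + s_1, \qquad h_{t+1} = \sum_{k=0}^{t} W^{k}\, s_{t-k}
$$

So the matrix powers $W^0, W^1, \ldots, W^{T-1}$ play the role of the convolution kernel. How they are ordered along the last axis of the `(H, H, T)` tensor depends on how you pad the input and call the convolution, so treat this as a sketch of the idea rather than a prescribed layout.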


You will implement two functions:
- `make_conv_kernel(W, T)`: takes the recurrent weight matrix $W$ and the number of time steps $T$ and returns the convolution kernel $K$. Because $T$ is a power of 2, you can build the kernel with a divide-and-conquer approach.
- `conv_ssm_forward(W, U, b, x)`: takes the weights $W$, $U$, $b$ and the input $x$ and returns the hidden states $h$ at every time step.
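
The divide-and-conquer idea can be illustrated in the scalar case, where $W$ is a single number `w`. The helper below, `powers_by_doubling`, is a hypothetical name used only for this illustration and is not part of the assignment. It doubles the list of known powers on each pass, so it needs about $\log_2 T$ passes instead of $T$ sequential multiplications. It is a minimal sketch, assuming the same doubling pattern carries over to matrix powers.

```python
def powers_by_doubling(w, T):
    """Return [w**0, w**1, ..., w**(T-1)] for T a power of two (scalar sketch)."""
    assert T >= 1 and T & (T - 1) == 0, "T must be a power of two"
    powers = [1.0]   # currently holds w**0 .. w**(len(powers) - 1)
    w_block = w      # invariant: w_block == w ** len(powers)
    while len(powers) < T:
        # Shift every known power by len(powers) to get the next half.
        powers = powers + [w_block * p for p in powers]
        # Squaring keeps the invariant after the list has doubled.
        w_block = w_block * w_block
    return powers


example_powers = powers_by_doubling(2.0, 8)
```

Each pass reuses everything computed so far, which is what makes the approach attractive when the per-pass work can run in parallel on a GPU.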



In [None]:
def make_conv_kernel(W, T):
    """
    Build a 3D kernel tensor K of shape (H, H, T) we will use when implementing
    the ssm forward pass using conv1d.

    Args:
      W: (H, H) weight matrix
      T: (int) number of time steps

    Returns:
      kernel_for_conv: (H, H, T) tensor
    """
    ##############################################################################
    #                         TODO: Implement the kernel here                    #
    ##############################################################################

    ##############################################################################
    #                               END OF YOUR CODE                             #
    ##############################################################################
    return kernel_for_conv

In [None]:
def conv_ssm_forward(W, U, b, x):
    """
    Convolution-based forward pass for a batch of sequences.

    RNN update:  h_{t+1} = W h_t + U x_t + b

    Args:
      W: (H, H) weight matrix
      U: (H, D) input projection
      b: (H,)   bias
      x: (N, T, D) input (batch=N, time steps=T, input dim=D)

    Returns:
      h_all: (N, T, H) hidden states
    """
    N, T, D = x.shape
    H = W.shape[0]

    s = x @ U.T + b  # (N, T, H): per-step input term U x_t + b

    s = s.permute(0, 2, 1)  # (N, H, T): channels-first layout for conv1d

    # Build the kernel with shape (H, H, T).
    kernel = make_conv_kernel(W, T)
    ##############################################################################
    #                         TODO: Implement the convolution here               #
    ##############################################################################

    ##############################################################################
    #                               END OF YOUR CODE                             #
    ##############################################################################
    return h_all

# Sanity Check
We can compare the outputs of the two implementations to check if they are consistent.

In [None]:
def sanity_check():
    T = 8   # number of time steps
    H = 4   # hidden dimension
    D = 3   # input dimension
    N = 2

    torch.manual_seed(0)

    W = torch.randn(H, H) * 0.1
    U = torch.randn(H, D) * 0.1
    b = torch.randn(H) * 0.1

    x = torch.randn(N, T, D)

    h_unrolled = unrolled_ssm_forward(W, U, b, x)
    h_conv = conv_ssm_forward(W, U, b, x)

    diff = (h_unrolled - h_conv).abs().max()
    print("Unrolled h(t):")
    print(h_unrolled)
    print("\nConv-based h(t):")
    print(h_conv)
    print("\nMax absolute difference:", diff.item())

sanity_check()

# Implementation Complexity

We can now compare the efficiency of the two implementations. In particular, we measure how long each one takes to compute the hidden states for fixed weights as the number of time steps $T$ varies.

In [None]:
import ipywidgets as widgets
from ipywidgets import interact

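def measure_runtime(method_fn, W, U, b, x, warmup=1, repeats=10):
    # Warm-up runs (ignored in timing):
    for _ in range(warmup):
        method_fn(W, U, b, x)

    # CUDA kernels launch asynchronously, so wait for pending work
    # before starting and before stopping the clock.
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(repeats):
        method_fn(W, U, b, x)
    torch.cuda.synchronize()
    end = time.perf_counter()

    avg_time = (end - start) / repeats
    return avg_time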


def run():
    T_values_cache = {}
    times_unrolled_vs_T_cache = {}
    times_conv_vs_T_cache = {}

    for H in [4, 8, 16, 32, 64, 128, 256, 512]:
      # We'll keep D, N fixed
      D = 32
      N = 32

      T_values = [8, 32, 128, 256, 512]

      # Build random U, b
      U = torch.randn(H, D)*0.1
      b = torch.randn(H)*0.1

      times_unrolled_vs_T = []
      times_conv_vs_T = []

      for T in T_values:

          W = torch.randn(H, H)*0.05
          x = torch.randn(N, T, D)

          t_unrolled = measure_runtime(unrolled_ssm_forward, W, U, b, x)

          t_conv = measure_runtime(conv_ssm_forward, W, U, b, x)

          times_unrolled_vs_T.append(t_unrolled)
          times_conv_vs_T.append(t_conv)

      T_values_cache[H] = T_values
      times_unrolled_vs_T_cache[H] = times_unrolled_vs_T
      times_conv_vs_T_cache[H] = times_conv_vs_T
    return T_values_cache, times_unrolled_vs_T_cache, times_conv_vs_T_cache

T_values_cache, times_unrolled_vs_T_cache, times_conv_vs_T_cache = run()

@interact(H=widgets.FloatLogSlider(min=2, max=9, base=2, value=4, step=1))
def interactive_benchmark(H):
    """
    Compare unrolled vs. convolution-based SSM forward for various T,
    at a chosen hidden dimension H from the slider.
    """
    H = int(H)
    T_values = T_values_cache[H]
    T_unrolled = times_unrolled_vs_T_cache[H]
    T_conv = times_conv_vs_T_cache[H]

    # Plot
    plt.figure(figsize=(6,4))
    plt.plot(T_values, T_unrolled, label="Unrolled", marker='o')
    plt.plot(T_values, T_conv, label="Conv", marker='s')
    plt.title(f"Runtime vs T, H={H}")
    plt.xlabel("Time Steps (T)")
    plt.ylabel("Runtime (sec)")
    plt.yscale('log')
    plt.grid(True)
    plt.legend()
    plt.show()


### Question 6

What do you observe about the runtime of the two implementations as $T$ and $H$ increase? Do you observe the same trend as in the CPU notebook? Explain your reasoning.

# Introducing structure: Diagonal Weight Matrices

We can optimize the convolution implementation by constraining $W$ to be diagonal. With a diagonal $W$, each hidden unit evolves independently of the others, so the forward pass can be computed with a depthwise convolution, which is cheaper than the full convolution.
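
To see why a per-channel (depthwise) kernel is enough, it helps to look at a single hidden unit. The sketch below uses plain Python and two hypothetical helpers, `scalar_recurrence` and `scalar_causal_conv`, which are not part of the assignment. It assumes one diagonal entry `w` and one channel of the input term `s` (that is, $U x_t + b$ for that unit), and checks that the recurrence and a causal convolution with kernel $[1, w, w^2, \ldots]$ give the same values.

```python
def scalar_recurrence(w, s):
    """h_{t+1} = w * h_t + s_t with h_0 = 0; returns [h_1, ..., h_T]."""
    h, out = 0.0, []
    for s_t in s:
        h = w * h + s_t
        out.append(h)
    return out


def scalar_causal_conv(w, s):
    """Same values via h_{t+1} = sum_{k=0..t} w**k * s_{t-k}."""
    T = len(s)
    kernel = [w ** k for k in range(T)]  # one channel's kernel, length T
    return [sum(kernel[k] * s[t - k] for k in range(t + 1)) for t in range(T)]


w = 0.5
s = [1.0, 2.0, 3.0, 4.0]
h_rec = scalar_recurrence(w, s)
h_conv = scalar_causal_conv(w, s)
```

With $H$ hidden units there are $H$ such independent channels, each with its own length-$T$ kernel, which is the situation a grouped convolution is designed for.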

In [None]:
def make_diag_depthwise_kernel(W, T):
    """
    Construct a depthwise 1D conv kernel for W with shape (H,H). W is a diagonal matrix.

    Args:
      W: (H, H) diagonal matrix
      T: (int) number of time steps

    Return:
      kernel of shape (H, 1, T), which can be used in depthwise convolution
    """
    ##############################################################################
    #                         TODO: Implement the kernel here                    #
    ##############################################################################

    ##############################################################################
    #                               END OF YOUR CODE                             #
    ##############################################################################
    return kernel

In [None]:
def diag_conv_ssm_forward(W, U, b, x):
    """
    Convolution-based forward pass for an SSM with a diagonal W.

    RNN update:  h_{t+1} = W h_t + U x_t + b
    but W is diagonal, so h_{t+1}(i) = diagW[i]*h_t(i) + [U x_t + b](i).

    Args:
      W: (H, H) [diagonal in practice]
      U: (H, D)
      b: (H,)
      x: (N, T, D) => N = batch, T = time steps, D = input dim

    Returns:
      h_all: (N, T, H)
    """
    N, T, D = x.shape
    H = W.shape[0]

    s = x @ U.T + b

    s = s.permute(0, 2, 1)  # (N,H,T)

    # Build kernel (H,1,T) for depthwise conv
    kernel = make_diag_depthwise_kernel(W, T)
    ##############################################################################
    #                         TODO: Implement the convolution here               #
    #                         Hint: Use `groups` argument                        #
    ##############################################################################
    
    ##############################################################################
    #                               END OF YOUR CODE                             #
    ##############################################################################
    return h_all


# Optimizing the recurrent implementation

Similarly, we can also optimize the recurrent implementation by using the diagonal nature of the $W$ matrix.
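
The saving comes from the fact that multiplying by a diagonal matrix is the same as an elementwise product with its diagonal. The plain-Python sketch below illustrates this with made-up numbers; `matvec`, `d`, and `h` are hypothetical names used only here, and the snippet is not the assignment's solution.

```python
def matvec(M, v):
    """Dense matrix-vector product: O(H^2) multiplications."""
    return [sum(m_ij * v_j for m_ij, v_j in zip(row, v)) for row in M]


d = [0.5, -0.25, 2.0]   # diagonal entries of W (illustrative values)
h = [4.0, 8.0, 1.5]     # a hidden state (illustrative values)

# Build the full diagonal matrix W = diag(d).
W_dense = [[d[i] if i == j else 0.0 for j in range(len(d))] for i in range(len(d))]

dense_result = matvec(W_dense, h)                            # uses all H*H entries
elementwise_result = [d_i * h_i for d_i, h_i in zip(d, h)]   # uses only H entries
```

Per time step this replaces $O(H^2)$ multiplications with $O(H)$, so the benefit grows with the hidden dimension.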

In [None]:
def diag_unrolled_ssm_forward(W, U, b, x):
    """
    Forward pass for:
       h_{t+1} = W h_t + U x_t + b
    but W is diagonal, i.e. W = diag(diagW).

    Args:
      W: (H, H) diagonal matrix
      U: (H, D)
      b: (H,)
      x: (N, T, D)  => batch=N, time steps=T, input dim=D

    Returns:
      h_all: (N, T, H)
    """
    ##############################################################################
    #             TODO: Implement the optimized recurrent pass here              #
    ##############################################################################

    ##############################################################################
    #                               END OF YOUR CODE                             #
    ##############################################################################
    return h_all

# Sanity Check

As before, we compare the outputs of the two implementations to check that they are consistent.

In [None]:
def diag_sanity_check():
    T = 8   # number of time steps
    H = 4   # hidden dimension
    D = 3   # input dimension
    N = 2

    torch.manual_seed(0)

    W = torch.eye(H) * 0.1 #### WE USE A DIAGONAL MATRIX HERE

    U = torch.randn(H, D) * 0.1
    b = torch.randn(H) * 0.1

    x = torch.randn(N, T, D)

    h_unrolled = diag_unrolled_ssm_forward(W, U, b, x)

    h_conv = diag_conv_ssm_forward(W, U, b, x)

    diff = (h_unrolled - h_conv).abs().max()
    print("Unrolled h(t):")
    print(h_unrolled)
    print("\nConv-based h(t):")
    print(h_conv)
    print("\nMax absolute difference:", diff.item())

diag_sanity_check()

# Measure Runtime with Optimization

As before, we measure the runtime of both implementations, this time using the versions optimized for a diagonal $W$.

In [None]:
def diag_run():
    T_values_cache = {}
    times_unrolled_vs_T_cache = {}
    times_conv_vs_T_cache = {}

    for H in [4, 8, 16, 32, 64, 128, 256, 512]:
      # We'll keep D, N fixed
      D = 32
      N = 512

      T_values = [8, 32, 128, 256, 512]

      # Build random U, b
      U = torch.randn(H, D)*0.1
      b = torch.randn(H)*0.1

      times_unrolled_vs_T = []
      times_conv_vs_T = []

      for T in T_values:

          W = torch.eye(H, H)*0.05 #### WE USE A DIAGONAL MATRIX HERE

          x = torch.randn(N, T, D)

          t_unrolled = measure_runtime(diag_unrolled_ssm_forward, W, U, b, x)

          t_conv = measure_runtime(diag_conv_ssm_forward, W, U, b, x)

          times_unrolled_vs_T.append(t_unrolled)
          times_conv_vs_T.append(t_conv)

      T_values_cache[H] = T_values
      times_unrolled_vs_T_cache[H] = times_unrolled_vs_T
      times_conv_vs_T_cache[H] = times_conv_vs_T
    return T_values_cache, times_unrolled_vs_T_cache, times_conv_vs_T_cache

diag_T_values_cache, diag_times_unrolled_vs_T_cache, diag_times_conv_vs_T_cache = diag_run()

@interact(H=widgets.FloatLogSlider(min=2, max=9, base=2, value=4, step=1))
def interactive_benchmark(H):
    """
    Compare unrolled vs. diagonal-convolution RNN forward for various T,
    at a chosen hidden dimension H from the slider.
    """
    H = int(H)
    T_values = diag_T_values_cache[H]
    T_unrolled = diag_times_unrolled_vs_T_cache[H]
    T_conv = diag_times_conv_vs_T_cache[H]

    # Plot
    plt.figure(figsize=(6,4))
    plt.plot(T_values, T_unrolled, label="Unrolled", marker='o')
    plt.plot(T_values, T_conv, label="Diag-Conv", marker='s')
    plt.title(f"Runtime vs T, H={H} diagonal weights")
    plt.xlabel("Time Steps (T)")
    plt.ylabel("Runtime (sec)")
    plt.grid(True)
    plt.legend()
    plt.show()


### Question 7

What do you observe here? How do your findings differ from the unstructured matrix case? Explain your reasoning.