

S4 exploration
---
Taken from https://srush.github.io/annotated-s4/. State space models (SSMs) are an alternative to attention-only models (see Mamba next). The big problem with attention is that its cost scales quadratically with the context length, hence workarounds such as sliding-window attention. Key ideas: long-range sequence modelling; continuous vs. discrete time (think neural ODEs); and running the model as a CNN at training time but as an RNN at inference (convolution is an operation that mixes information across the vectors of a sequence, analogous to attention).



In [2]:
# !pip install equinox

In [5]:
from functools import partial
import jax
import jax.numpy as jnp
import jax.random as jr
import equinox

In [7]:
key = jr.PRNGKey(2024)

State space model: $x'(t) = \boldsymbol{A} x(t) + \boldsymbol{B} u(t)$ and $y(t) = \boldsymbol{C} x(t) + \boldsymbol{D} u(t)$. This is an ODE with respect to time that maps a 1-D input signal $u(t)$ to an $N$-dimensional latent state $x(t)$, which is then projected to a 1-D output signal $y(t)$. The matrices in bold are learnt parameters, although the authors adopt the convention $\boldsymbol{D}=0$, since the $\boldsymbol{D}u(t)$ term is just a skip connection and is easy to add back.
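To make the continuous view concrete (the "think neural ODEs" point), here is a minimal sketch that integrates the SSM with forward Euler. The helper `euler_ssm`, the constant input $u(t)=1$ and the diagonal $\boldsymbol{A} = -\boldsymbol{I}$ are illustrative assumptions, not part of the annotated S4 code. For this choice each state component has the exact solution $1 - e^{-t}$, and a nonzero $\boldsymbol{D}$ just shifts the output by $\boldsymbol{D}u$, which is what "skip connection" means here.

```python
import math
import jax
import jax.numpy as jnp

def euler_ssm(A, B, C, D, u_fn, t_end, dt):
    # Hypothetical helper: integrate x'(t) = A x(t) + B u(t) from x(0) = 0 with forward Euler,
    # then read out y(t_end) = C x(t_end) + D u(t_end)
    n_steps = int(round(t_end / dt))

    def step(x, k):
        t = k * dt
        return x + dt * (A @ x + B[:, 0] * u_fn(t)), None

    x_end, _ = jax.lax.scan(step, jnp.zeros(A.shape[0]), jnp.arange(n_steps))
    return (C @ x_end)[0] + D * u_fn(t_end)

N = 2
A = -jnp.eye(N)        # assumed stable, diagonal dynamics
B = jnp.ones((N, 1))
C = jnp.ones((1, N))
u_fn = lambda t: 1.0   # assumed constant input u(t) = 1

y_no_skip = euler_ssm(A, B, C, 0.0, u_fn, t_end=1.0, dt=1e-3)
y_skip = euler_ssm(A, B, C, 0.5, u_fn, t_end=1.0, dt=1e-3)
print(y_no_skip, 2 * (1 - math.exp(-1.0)))  # Euler result vs exact value, both about 1.264
print(y_skip - y_no_skip)                    # about 0.5, i.e. D * u
```

Shrinking `dt` brings the Euler result closer to the exact value; the discretisation below replaces this naive stepping with the bilinear method.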

In [None]:
# initialise random matrices

def random_params(N, key):
  a_key, b_key, c_key = jr.split(key, 3)

  A = jr.uniform(key=a_key, shape=(N,N))
  B = jr.uniform(key=b_key, shape=(N,1))
  C = jr.uniform(key=c_key, shape=(1,N))

  return A,B,C
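A quick check on what `random_params` returns, re-defined here so the snippet runs on its own. One thing worth noticing (an observation about this initialisation, not a claim from the source): with i.i.d. uniform $[0, 1)$ entries, $\boldsymbol{A}$ is almost surely a positive matrix, so by the Perron–Frobenius theorem its largest eigenvalue is real and positive. The continuous system $x' = \boldsymbol{A}x$ therefore grows without bound along that eigenvector, which is fine for checking shapes but is not a stable initialisation.

```python
import jax.numpy as jnp
import jax.random as jr

def random_params(N, key):
    # Same initialisation as the cell above: i.i.d. uniform [0, 1) entries
    a_key, b_key, c_key = jr.split(key, 3)
    A = jr.uniform(key=a_key, shape=(N, N))
    B = jr.uniform(key=b_key, shape=(N, 1))
    C = jr.uniform(key=c_key, shape=(1, N))
    return A, B, C

A, B, C = random_params(4, jr.PRNGKey(2024))
print(A.shape, B.shape, C.shape)  # (4, 4) (4, 1) (1, 4)

# Largest real part among the eigenvalues of A (the Perron root for a positive matrix)
max_real = jnp.max(jnp.real(jnp.linalg.eigvals(A)))
print(max_real > 0)  # True: x' = A x is unstable for this A
```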

Discretise the input sequence $u(t) \to (u_{0}, u_{1},...)$, i.e. sample the underlying signal at intervals of step size $\Delta$ so that $u_{k} = u(k \cdot \Delta)$. To turn the continuous parameters into discrete ones, the authors use the bilinear method from control theory (which maps the continuous-time Laplace domain to the discrete-time $z$-domain): $\bar{\boldsymbol{A}} = (\boldsymbol{I} - \frac{\Delta}{2} \cdot \boldsymbol{A})^{-1}(\boldsymbol{I} + \frac{\Delta}{2} \cdot \boldsymbol{A})$, $\bar{\boldsymbol{B}} = (\boldsymbol{I} - \frac{\Delta}{2} \cdot \boldsymbol{A})^{-1} \Delta \boldsymbol{B}$ and $\bar{\boldsymbol{C}}= \boldsymbol{C}$. Like Euler's method, this is a numerical integration scheme (the bilinear transform corresponds to the trapezoidal rule), and it lets us write the ODE as a recurrence relation, namely $x_k = \bar{\boldsymbol{A}}x_{k-1} + \bar{\boldsymbol{B}}u_k$. With $\boldsymbol{D} = 0$, the output is simply $y_k = \bar{\boldsymbol{C}} x_k = \bar{\boldsymbol{C}} (\bar{\boldsymbol{A}}x_{k-1} + \bar{\boldsymbol{B}}u_k)$. This looks like an RNN layer, with $x_k$ playing the role of the hidden state.
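A worked scalar example of the bilinear formulas, assuming a hypothetical 1-D state with $A=-1$, $B=1$ and $\Delta=0.1$. With $N=1$ the matrix inverse becomes a division, so $\bar{A} = 0.95/1.05 \approx 0.90476$, close to the exact decay factor $e^{A\Delta} = e^{-0.1} \approx 0.90484$, and $\bar{B} = 0.1/1.05 \approx 0.09524$. Running the recurrence on a constant input $u_k = 1$ settles at $\bar{B}/(1-\bar{A}) = 1$, the same steady state $-B/A$ as the continuous ODE.

```python
import math

def bilinear_scalar(a, b, step):
    # 1-D version of the bilinear formulas:
    # A_bar = (1 - step/2 * a)^{-1} (1 + step/2 * a),  B_bar = (1 - step/2 * a)^{-1} * step * b
    inv = 1.0 / (1.0 - step * a / 2.0)
    return inv * (1.0 + step * a / 2.0), inv * step * b

a, b, step = -1.0, 1.0, 0.1
a_bar, b_bar = bilinear_scalar(a, b, step)
print(a_bar, math.exp(a * step))  # 0.9047619... vs 0.9048374...

# Run the recurrence x_k = a_bar * x_{k-1} + b_bar * u_k on a constant input u_k = 1
x = 0.0
for _ in range(200):
    x = a_bar * x + b_bar * 1.0
print(x)  # approaches b_bar / (1 - a_bar) = 1.0, the continuous steady state -b/a
```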

In [None]:
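# compute bilinear transformation
def discretise(A, B, C, step):
  # A_bar = (I - step/2 A)^{-1} (I + step/2 A), B_bar = (I - step/2 A)^{-1} step B, C_bar = C
  I = jnp.eye(A.shape[0])
  bl = jnp.linalg.inv(I - (step / 2.) * A)
  A_bl = bl @ (I + (step / 2.) * A)
  B_bl = (bl * step) @ B

  return A_bl, B_bl, C

def scan_SSM(A_bl, B_bl, C_bl, u, x0):
  # one step of the recurrence: x_k = A_bar x_{k-1} + B_bar u_k, y_k = C_bar x_k
  def step(x_k_1, u_k):
    x_k = A_bl @ x_k_1 + B_bl @ u_k
    y_k = C_bl @ x_k
    return x_k, y_k

  # carry the state (starting at x0) and scan over the input sequence u; returns (final state, stacked outputs)
  return jax.lax.scan(step, x0, u)

def run_SSM(A, B, C, u):
  L = u.shape[0]
  N = A.shape[0]
  A_bl, B_bl, C_bl = discretise(A, B, C, step=1./L)

  # u has shape (L,): give each u_k shape (1,), start from the zero state and keep only the outputs y_k
  return scan_SSM(A_bl, B_bl, C_bl, jnp.expand_dims(u, axis=-1), jnp.zeros((N,)))[1]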
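The "CNN at training to RNN at inference" idea follows from unrolling the recurrence. Starting from $x_{-1} = 0$, we get $y_k = \sum_{j=0}^{k} \bar{\boldsymbol{C}}\bar{\boldsymbol{A}}^{j}\bar{\boldsymbol{B}}\,u_{k-j}$, so the whole output is a causal convolution of $u$ with the kernel $\bar{K} = (\bar{\boldsymbol{C}}\bar{\boldsymbol{B}}, \bar{\boldsymbol{C}}\bar{\boldsymbol{A}}\bar{\boldsymbol{B}}, \dots, \bar{\boldsymbol{C}}\bar{\boldsymbol{A}}^{L-1}\bar{\boldsymbol{B}})$. The self-contained sketch here checks numerically that both views agree. The helper names (`ssm_recurrent`, `ssm_kernel`, `ssm_convolutional`) and the stable choice of $\boldsymbol{A}$ near $-\boldsymbol{I}$ are illustrative assumptions. Building $\bar{K}$ by repeated matrix products like this is the naive approach; S4's contribution is computing this kernel efficiently.

```python
import jax
import jax.numpy as jnp
import jax.random as jr

def discretise(A, B, C, step):
    # Bilinear transform, same formulas as in the notebook
    I = jnp.eye(A.shape[0])
    bl = jnp.linalg.inv(I - (step / 2.0) * A)
    return bl @ (I + (step / 2.0) * A), (bl * step) @ B, C

def ssm_recurrent(A_bl, B_bl, C_bl, u):
    # RNN view: x_k = A_bl x_{k-1} + B_bl u_k, y_k = C_bl x_k, with x_{-1} = 0
    def step(x, u_k):
        x = A_bl @ x + B_bl[:, 0] * u_k
        return x, (C_bl @ x)[0]
    _, ys = jax.lax.scan(step, jnp.zeros(A_bl.shape[0]), u)
    return ys

def ssm_kernel(A_bl, B_bl, C_bl, L):
    # K_j = C_bl A_bl^j B_bl for j = 0, ..., L-1; the carry holds A_bl^j B_bl
    def step(Aj_B, _):
        return A_bl @ Aj_B, (C_bl @ Aj_B)[0, 0]
    _, K = jax.lax.scan(step, B_bl, None, length=L)
    return K

def ssm_convolutional(A_bl, B_bl, C_bl, u):
    # CNN view: causal convolution of u with K (the first L entries of a 'full' convolution)
    K = ssm_kernel(A_bl, B_bl, C_bl, u.shape[0])
    return jnp.convolve(u, K, mode="full")[: u.shape[0]]

N, L = 4, 16
k_a, k_b, k_c, k_u = jr.split(jr.PRNGKey(0), 4)
A = -jnp.eye(N) + 0.1 * jr.normal(k_a, (N, N))  # assumed: eigenvalues close to -1, so powers of A_bl stay bounded
B = jr.normal(k_b, (N, 1))
C = jr.normal(k_c, (1, N))
u = jr.normal(k_u, (L,))

A_bl, B_bl, C_bl = discretise(A, B, C, step=1.0 / L)
y_rnn = ssm_recurrent(A_bl, B_bl, C_bl, u)
y_cnn = ssm_convolutional(A_bl, B_bl, C_bl, u)
print(jnp.allclose(y_rnn, y_cnn, atol=1e-4))  # both views give the same outputs
```

At training time the whole sequence is known, so the convolution can be computed in parallel; at inference the recurrence needs only the current state $x_k$ per new input.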