# Demonstration of SSM Functionality
This Jupyter Notebook gives an overview of how State Space Models (SSMs) work.

The code is adapted and simplified from [The Annotated S4](https://srush.github.io/annotated-s4/).

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from celluloid import Camera
from IPython.display import Image, Video, HTML

## Discretization
Discretization will be explained in a later chapter.

For now, note that an SSM can be formulated in continuous time or in discrete time. Because this notebook works with sequences sampled at a fixed time step, we first discretize the continuous-time matrices $A$ and $B$ with the bilinear transform.

In [None]:
def discretize(A: np.ndarray, B: np.ndarray, dt: float) -> tuple[np.ndarray, np.ndarray]:
    """
    Converts continuous-time matrices A, B into discrete-time versions using the bilinear transform.

    Parameters:
        A (ndarray): continuous-time state transition matrix, shape (N, N)
        B (ndarray): continuous-time input matrix, shape (N, 1)
        dt (float): time step for discretization

    Returns:
        (A_d, B_d) (ndarray, ndarray): discrete-time state transition matrix A_d and discrete-time input matrix B_d
    """
    I = np.eye(A.shape[0])
    M = I - 0.5 * dt * A  # shared left-hand factor (I - dt/2 * A)

    # A_d = (I - dt/2 A)^-1 (I + dt/2 A),  B_d = (I - dt/2 A)^-1 (dt B)
    # Solving the linear systems avoids forming the inverse explicitly, which is cheaper and more
    # numerically stable. (solve_triangular does not apply here: I - dt/2 A is not triangular in general.)
    A_d = np.linalg.solve(M, I + 0.5 * dt * A)
    B_d = np.linalg.solve(M, B * dt)
    return A_d, B_d
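As a quick sanity check of the bilinear transform, the cell below compares it with the exact discretization of the scalar system $x' = ax$, whose state decays by $e^{a\,\Delta t}$ per step, and confirms that the (stable) spring matrix used later maps to a discrete matrix with all eigenvalues inside the unit circle. This is a sketch: `bilinear_discretize` is a hypothetical helper that repeats the formula from `discretize()` so the cell runs on its own, and $a = -1$, $\Delta t = 0.1$ are illustrative values.

```python
import numpy as np

def bilinear_discretize(A: np.ndarray, B: np.ndarray, dt: float):
    """Hypothetical helper: same bilinear transform as discretize()."""
    I = np.eye(A.shape[0])
    M = I - 0.5 * dt * A
    return np.linalg.solve(M, I + 0.5 * dt * A), np.linalg.solve(M, dt * B)

# Scalar system x' = a x: the exact per-step decay factor is exp(a * dt)
a, dt = -1.0, 0.1
A_d, _ = bilinear_discretize(np.array([[a]]), np.array([[1.0]]), dt)
print("bilinear:", A_d[0, 0], "exact:", np.exp(a * dt))

# Spring matrix with k=40, b=5, m=1: both eigenvalues have negative real part
A_spring = np.array([[0.0, 1.0], [-40.0, -5.0]])
A_spring_d, _ = bilinear_discretize(A_spring, np.array([[0.0], [1.0]]), 1.0 / 200)
print("continuous eigenvalues:", np.linalg.eigvals(A_spring))
print("|discrete eigenvalues|:", np.abs(np.linalg.eigvals(A_spring_d)))  # all below 1
```

The bilinear factor $0.95 / 1.05 \approx 0.90476$ is close to the exact $e^{-0.1} \approx 0.90484$. More importantly, the bilinear transform maps the whole left half-plane into the unit disk, so a stable continuous-time system stays stable after discretization.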

## Recurrent Calculation of the Output
In continuous time, the SSM is given by:
$$
\begin{aligned}
x'(t) &= Ax(t) + Bu(t) &&\text{(State Equation)}\\
y(t) &= Cx(t) + Du(t) &&\text{(Output Equation)}
\end{aligned}
$$

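After discretization, $A$ and $B$ are replaced by $\bar{A}$ and $\bar{B}$ (the outputs of `discretize()`), while $C$ stays the same. Dropping the feedthrough term $Du$ (equivalently, setting $D = 0$), the discretized SSM is: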

$$
\begin{aligned}
x_k &= \bar{A}x_{k-1} + \bar{B}u_{k} &&\text{(State Equation)}\\
y_k &= Cx_k &&\text{(Output Equation)}
\end{aligned}
$$

We can compute the state $x_k$ and the output $y_k$ recurrently, one step at a time. This is implemented in the `step_SSM()` and `run_SSM()` functions below.

In [None]:
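def step_SSM(A: np.ndarray, B: np.ndarray, C: np.ndarray, x_prev: np.ndarray, u_k: float):
    """Computes one recurrence step: given the previous state x_{k-1} and the input u_k,
    returns the new state x_k and the output y_k of the (discretized) SSM."""
    x_k = A @ x_prev + B * u_k  # state equation
    y_k = C @ x_k               # output equation
    return x_k, y_k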

In [None]:
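def run_SSM(A: np.ndarray, B: np.ndarray, C: np.ndarray, x0: np.ndarray, u: np.ndarray) -> np.ndarray:
    """Runs the SSM recurrently over all inputs u, starting from the initial state x0, and returns the outputs y"""
    y = np.zeros(len(u), dtype=float)  # float output, even if u is an integer array
    x_k = x0
    for idx, u_k in enumerate(u):
        x_k, y_k = step_SSM(A, B, C, x_k, u_k)
        y[idx] = y_k.item()  # y_k has shape (1, 1); store it as a scalar
    return y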
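To see the recurrence at work, here is a case small enough to trace by hand. It is a sketch with hypothetical scalar values $\bar{A} = 0.5$, $\bar{B} = 1$, $C = 1$, using the same update as `step_SSM()`. Feeding a unit impulse into a zero state gives $y_k = 0.5^k$: the state keeps a memory of the impulse that shrinks by a factor $\bar{A}$ at every step.

```python
import numpy as np

# Hypothetical scalar SSM chosen so the recurrence can be traced by hand
A_bar = np.array([[0.5]])
B_bar = np.array([[1.0]])
C_demo = np.array([[1.0]])

u_demo = np.array([1.0, 0.0, 0.0, 0.0])  # unit impulse at k = 0
x = np.zeros((1, 1))                     # zero initial state
y_demo = np.zeros(len(u_demo))
for i, u_i in enumerate(u_demo):
    x = A_bar @ x + B_bar * u_i          # state equation (same update as step_SSM)
    y_demo[i] = (C_demo @ x).item()      # output equation
print(y_demo)  # expected values: 1, 0.5, 0.25, 0.125
```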

## Example: Spring System
The following example models a mass that is attached to a wall by a spring and slowed by a damper.
We apply an input force $u_k$ to the mass and use the SSM to compute its position $y_k$.

The derivation of the $A$, $B$ and $C$ matrices is given in the article.
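A brief sketch of that derivation, assuming a linear spring with constant $k$ and a linear damper with coefficient $b$ (matching the parameters in the next cell): let $p(t)$ be the position of the mass and $u(t)$ the applied force. Newton's second law gives

$$
m\,p''(t) = u(t) - k\,p(t) - b\,p'(t).
$$

Choosing the state $x(t) = \begin{pmatrix} p(t) \\ p'(t) \end{pmatrix}$ (position and velocity) turns this second-order equation into the first-order system

$$
\begin{aligned}
x'(t) &= \begin{pmatrix} 0 & 1 \\ -k/m & -b/m \end{pmatrix} x(t) + \begin{pmatrix} 0 \\ 1/m \end{pmatrix} u(t) &&\text{(State Equation)}\\
y(t) &= \begin{pmatrix} 1 & 0 \end{pmatrix} x(t) &&\text{(Output Equation)}
\end{aligned}
$$

which has the form of the continuous SSM above with $D = 0$.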

In [None]:
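# Parameters
k = 40  # spring constant
b = 5   # damping coefficient
m = 1   # mass

# Initialize the continuous-time SSM matrices (state: position and velocity)
A = np.array([[0, 1],
              [-k/m, -b/m]])
B = np.array([[0],
              [1.0/m]])
C = np.array([[1.0, 0]])

# Generate an input signal: the parts of a sine wave above 0.5, zero elsewhere
L = 200
step = 1.0 / L
ks = np.arange(L)
signal = np.sin(10 * ks * step)
u = signal * (signal > 0.5)

# Discretize the SSM
A, B = discretize(A, B, step)

# Run the SSM from a small initial state (position and velocity both 0.001)
x_0 = np.full((A.shape[0], 1), 0.001)
y = run_SSM(A, B, C, x_0, u)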
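One way to check that the discretized spring model behaves physically: under a constant force $F$ the mass should settle where the spring balances it, $k\,p = F$. The fixed point of the recurrence, $x = \bar{A}x + \bar{B}F$, gives exactly this position, because the bilinear transform preserves the continuous-time steady state $-A^{-1}BF$. This is a sketch: `bilinear_discretize` is a hypothetical helper repeating the formula from `discretize()`, and $F = 1$ is a hypothetical test input.

```python
import numpy as np

def bilinear_discretize(A: np.ndarray, B: np.ndarray, dt: float):
    """Hypothetical helper: same bilinear transform as discretize()."""
    I = np.eye(A.shape[0])
    M = I - 0.5 * dt * A
    return np.linalg.solve(M, I + 0.5 * dt * A), np.linalg.solve(M, dt * B)

# Spring-mass-damper matrices with the parameters used above (k=40, b=5, m=1)
k_s, b_s, m_s = 40.0, 5.0, 1.0
A_c = np.array([[0.0, 1.0], [-k_s / m_s, -b_s / m_s]])
B_c = np.array([[0.0], [1.0 / m_s]])
C_c = np.array([[1.0, 0.0]])
A_bar, B_bar = bilinear_discretize(A_c, B_c, 1.0 / 200)

F = 1.0  # hypothetical constant input force

# Fixed point of the recurrence: x = A_bar x + B_bar F  <=>  (I - A_bar) x = B_bar F
x_ss = np.linalg.solve(np.eye(2) - A_bar, B_bar * F)
p_ss = (C_c @ x_ss).item()
print("steady-state position:", p_ss, " F / k:", F / k_s)

# Simulating the recurrence from rest converges to the same position
x = np.zeros((2, 1))
for _ in range(4000):  # 20 s of simulated time at dt = 1/200
    x = A_bar @ x + B_bar * F
p_sim = (C_c @ x).item()
print("simulated position after 20 s:", p_sim)
```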

In [None]:
# Plot Parameters
mass_width = 0.0016
mass_height = 0.1
spring_height = 0.05
num_coils = 15
wall_x = 0
time_s = 5

fig, (ax1, ax2, ax3) = plt.subplots(3, figsize=(12, 8))
camera = Camera(fig)
ax1.set_title("Force $u_k$")
ax2.set_title("Position $y_k$")
ax3.set_title("Object")
ax1.set_xticks([])
ax2.set_xticks([])
ax3.set_xticks([])
ax3.set_xlim(-0.4 * max(y), 1.4 * max(y))
ax3.set_ylim(-0.2, 0.2)
fig.tight_layout()

# Animate plot over time (the frame index is i, so the spring constant k is not overwritten)
for i in range(L):
    # Plot applied force
    ax1.plot(ks[:i], u[:i], color="red")

    # Plot object position
    ax2.plot(ks[:i], y[:i], color="blue")

    # Plot wall
    ax3.plot([wall_x, wall_x], [-0.2, 0.2], color='black', linewidth=2)

    # Plot spring
    spring_x = np.linspace(wall_x, y[i], num=500)
    spring_y = (spring_height / 2) * np.sin(2 * np.pi * num_coils * np.linspace(0, 1, len(spring_x)))
    ax3.plot(spring_x, spring_y, color="gray", linewidth=2)

    # Plot Mass
    mass = plt.Rectangle((y[i], -mass_height / 2),
                         mass_width, mass_height,
                         fc="steelblue", ec="black")
    ax3.add_patch(mass)

    camera.snap()

interval = int(time_s * 1000 / L)
anim = camera.animate(interval=interval)

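# Save the animation (create the output folder first if it does not exist)
import os
os.makedirs("media", exist_ok=True)
anim.save("media/ssm.gif", dpi=150)
# Image("media/ssm.gif")
# anim.save("media/ssm.mp4", fps=25)
# Video("media/ssm.mp4")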

# Show the animation
plt.close()
HTML(anim.to_html5_video())

## Convolution
**TODO:** Add code to explain the convolutional representation of the SSM.

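Unrolling the recurrence from a zero initial state gives $y_k = \sum_{j=0}^{k} C\bar{A}^{k-j}\bar{B}\,u_j$, so the whole output sequence is a convolution of the input $u$ with the kernel $\bar{K}$ of length $L$ (the sequence length):

$$
\bar{K} = \left(C\bar{B},\ C\bar{A}\bar{B},\ \dots,\ C\bar{A}^{L-1}\bar{B}\right), \qquad y = \bar{K} * u
$$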
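A minimal sketch of the convolutional view, assuming a zero initial state (the spring example above starts from $x = 0.001$, which would add an extra free-response term that the convolution does not capture). The helpers `ssm_kernel`, `causal_convolve`, and `run_recurrence`, as well as the matrices, are hypothetical and chosen only for illustration. The final check confirms that convolving with $\bar{K}$ reproduces the recurrence.

```python
import numpy as np

def ssm_kernel(A_bar, B_bar, C, L):
    """Return K_bar[i] = C A_bar^i B_bar for i = 0, ..., L-1."""
    K = np.empty(L)
    v = B_bar                       # holds A_bar^i B_bar
    for i in range(L):
        K[i] = (C @ v).item()
        v = A_bar @ v
    return K

def causal_convolve(K, u):
    """y_k = sum_{j<=k} K[k-j] u_j, i.e. the first len(u) samples of the full convolution."""
    return np.convolve(u, K)[: len(u)]

def run_recurrence(A_bar, B_bar, C, u):
    """Reference recurrence with zero initial state (same update as step_SSM)."""
    x = np.zeros_like(B_bar, dtype=float)
    y = np.empty(len(u))
    for k, u_k in enumerate(u):
        x = A_bar @ x + B_bar * u_k
        y[k] = (C @ x).item()
    return y

# Hypothetical stable discrete system (eigenvalue magnitude about 0.86)
A_bar = np.array([[0.9, 0.1], [-0.2, 0.8]])
B_bar = np.array([[0.0], [1.0]])
C_demo = np.array([[1.0, 0.0]])
u_demo = np.sin(np.linspace(0, 3, 50))

K = ssm_kernel(A_bar, B_bar, C_demo, len(u_demo))
y_conv = causal_convolve(K, u_demo)
y_rec = run_recurrence(A_bar, B_bar, C_demo, u_demo)
print(np.allclose(y_conv, y_rec))  # True
```

Here `np.convolve` computes the direct $O(L^2)$ convolution, which is fine for short sequences; for long sequences an FFT-based convolution is the usual choice.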

## Linear State Space Layer (LSSL)
**TODO:** After presenting all concepts of the SSM (Discretization, Recurrent Representation, Convolution), add some code that puts everything into one LSSL class.

Reference: https://tinkerd.net/blog/machine-learning/state-space-models/#linear-state-space-layers-lssl
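A minimal sketch of what such a class could look like, combining the three pieces of this notebook (bilinear discretization, the recurrent view, and the convolutional view) plus the skip term $Du$ that was dropped earlier. The class name `LSSL`, its methods, and the value $D = 0.5$ are hypothetical choices for illustration, not the implementation from the LSSL paper or the reference above: there is a single input channel, no training, and no special initialization of $A$.

```python
import numpy as np

class LSSL:
    """Hypothetical single-channel Linear State Space Layer (sketch only:
    no training, fixed step size, no special initialization of A)."""

    def __init__(self, A, B, C, D=0.0, dt=0.01):
        A = np.asarray(A, dtype=float)
        B = np.asarray(B, dtype=float).reshape(-1, 1)
        self.C = np.asarray(C, dtype=float).reshape(1, -1)
        self.D = float(D)
        # 1) Discretization: bilinear transform, as in discretize()
        I = np.eye(A.shape[0])
        M = I - 0.5 * dt * A
        self.A_bar = np.linalg.solve(M, I + 0.5 * dt * A)
        self.B_bar = np.linalg.solve(M, dt * B)

    def recurrent(self, u):
        """2) Recurrent view: one state update per input sample (zero initial state)."""
        x = np.zeros_like(self.B_bar)
        y = np.empty(len(u))
        for k, u_k in enumerate(u):
            x = self.A_bar @ x + self.B_bar * u_k
            y[k] = (self.C @ x).item() + self.D * u_k
        return y

    def kernel(self, L):
        """K_bar[i] = C A_bar^i B_bar for i = 0, ..., L-1."""
        K = np.empty(L)
        v = self.B_bar
        for i in range(L):
            K[i] = (self.C @ v).item()
            v = self.A_bar @ v
        return K

    def convolutional(self, u):
        """3) Convolutional view: causal convolution with K_bar plus the D skip term."""
        u = np.asarray(u, dtype=float)
        K = self.kernel(len(u))
        return np.convolve(u, K)[: len(u)] + self.D * u

# Usage: the spring system from above (k=40, b=5, m=1) with an illustrative skip term D = 0.5
layer = LSSL(A=[[0.0, 1.0], [-40.0, -5.0]], B=[0.0, 1.0], C=[1.0, 0.0], D=0.5, dt=1.0 / 200)
u_demo = np.random.default_rng(0).standard_normal(100)
print(np.allclose(layer.recurrent(u_demo), layer.convolutional(u_demo)))  # True: both views agree
```

Both methods compute the same function. The convolutional view processes the whole sequence at once, which parallelizes well during training, while the recurrent view only needs the current state at each step, which suits step-by-step inference. Full LSSL layers typically stack many such SSMs, one per feature channel, and learn their parameters.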
