In [1]:
from ssm_layer import SSM
import torch 
import numpy as np
from torch_scan import torch_scan

In [None]:
# Example Run
from functools import partial
def example_mass(k, b, m):
    """Return the (A, B, C) state-space matrices of a mass-spring-damper.

    The state is (position, velocity) and the scalar input is a force, so

        d/dt [x, v] = [v, -(k/m) x - (b/m) v + u/m],   y = x.

    Args:
        k: Spring constant.
        b: Damping coefficient.
        m: Mass (must be non-zero).

    Returns:
        Tuple ``(A, B, C)`` of float32 tensors with shapes (2, 2), (2, 1)
        and (1, 2).
    """
    spring_term = -k / m
    damping_term = -b / m
    state_matrix = torch.tensor(
        [[0, 1], [spring_term, damping_term]], dtype=torch.float32
    )
    # Only the velocity component is driven by the force, scaled by 1/m.
    input_matrix = torch.tensor([[0], [1.0 / m]], dtype=torch.float32)
    # Observe the position component only.
    output_matrix = torch.tensor([[1.0, 0]], dtype=torch.float32)
    return state_matrix, input_matrix, output_matrix

@partial(np.vectorize, signature="()->()")
def example_force(t):
    """Pulsed force: ``sin(10 t)`` where it exceeds 0.5, zero elsewhere.

    Vectorized elementwise over ``t`` via ``np.vectorize``; each element is
    produced as a float32 scalar tensor before NumPy collects the results.
    """
    wave = np.sin(10 * t)
    # Multiplying by the boolean mask zeroes every sample at or below 0.5.
    above_threshold = wave > 0.5
    pulse = wave * above_threshold
    return torch.tensor(pulse, dtype=torch.float32)

def example_ssm(k=40, b=5, m=1, L=100, path="images/test_ssm.gif"):
    """Simulate a mass-spring-damper SSM under a pulsed force and animate it.

    Builds (A, B, C) with ``example_mass``, samples ``example_force`` at ``L``
    evenly spaced times ``0, 1/L, ..., (L-1)/L``, runs the SSM on that input
    and saves a three-panel animation (force, position, object) to ``path``.

    Args:
        k: Spring constant passed to ``example_mass``.
        b: Damping coefficient passed to ``example_mass``.
        m: Mass passed to ``example_mass``.
        L: Number of time samples of the input force.
        path: Output GIF path. Missing parent directories are created.

    Notes:
        Saving uses matplotlib's ``"imagemagick"`` writer, which needs the
        ImageMagick executable to be installed.
    """
    from pathlib import Path

    # (A, B, C) matrices of the mass-spring-damper, printed for inspection.
    ssm = example_mass(k=k, b=b, m=m)
    for i, val in enumerate(ssm):
        print(f"{i} -> {val}")
    state = SSM(*ssm)

    # L samples of u(t).
    step = 1.0 / L
    ks = np.arange(L)
    u = example_force(ks * step)

    # Approximation of y(t).
    # NOTE(review): y is indexed as y[idx, 0] below, so run_SSM presumably
    # returns shape (L, 1) -- confirm against ssm_layer.SSM.
    y = state.run_SSM(*ssm, u)

    # Plotting ---
    import matplotlib.pyplot as plt
    import seaborn
    from celluloid import Camera

    seaborn.set_context("paper")
    fig, (ax1, ax2, ax3) = plt.subplots(3)
    camera = Camera(fig)
    ax1.set_title("Force $u_k$")
    ax2.set_title("Position $y_k$")
    ax3.set_title("Object")
    ax1.set_xticks([], [])
    ax2.set_xticks([], [])

    # Animate plot over time, one snapshot every second sample. The loop
    # index is not called `k` so it does not shadow the spring constant.
    for idx in range(0, L, 2):
        ax1.plot(ks[:idx], u[:idx], color="red")
        ax2.plot(ks[:idx], y[:idx], color="blue")
        pos = y[idx, 0]
        ax3.boxplot(
            [[pos - 0.04, pos, pos + 0.04]],
            showcaps=False,
            whis=False,
            vert=False,
            widths=10,
        )
        camera.snap()
    anim = camera.animate()

    # anim.save does not create directories; ensure the target folder exists.
    out = Path(path)
    out.parent.mkdir(parents=True, exist_ok=True)
    anim.save(str(out), dpi=150, writer="imagemagick")
    # Release the figure so repeated calls do not accumulate open figures
    # (and the notebook does not render a redundant static last frame).
    plt.close(fig)


example_ssm()

![Animation of the mass-spring SSM: input force, position output, and moving object](images/test_ssm.gif)