In [15]:
from functools import partial
import jax
import jax.numpy as np
from flax import linen as nn
from jax.nn.initializers import lecun_normal, normal
from jax.numpy.linalg import eigh, inv, matrix_power
from jax.scipy.signal import convolve

In [16]:
# Root PRNG key (seed 1); handed to random_SSM to draw the random matrices.
rng = jax.random.PRNGKey(1)

In [17]:
def random_SSM(rng, N):
    """Draw the three matrices of a randomly initialised state space model.

    Args:
        rng: JAX PRNG key; it is split into three independent sub-keys,
            one per matrix.
        N: dimension of the latent (hidden) state.

    Returns:
        Tuple ``(A, B, C)`` of uniformly sampled arrays with shapes
        ``(N, N)``, ``(N, 1)`` and ``(1, N)``.
    """
    key_A, key_B, key_C = jax.random.split(rng, 3)
    state_matrix = jax.random.uniform(key_A, (N, N))
    input_matrix = jax.random.uniform(key_B, (N, 1))
    output_matrix = jax.random.uniform(key_C, (1, N))
    return state_matrix, input_matrix, output_matrix

In [18]:
# Draw random SSM matrices with a 16-dimensional latent state.
A, B, C = random_SSM(rng, N=16)

In [19]:
def discretize(A, B, C, step):
    """Discretize continuous-time SSM matrices with the bilinear method.

    Computes::

        Ab = (I - step/2 * A)^-1 @ (I + step/2 * A)
        Bb = ((I - step/2 * A)^-1 * step) @ B

    Args:
        A: square state matrix.
        B: input matrix.
        C: output matrix; it is not discretized.
        step: discretization step size.

    Returns:
        Tuple ``(Ab, Bb, C)``. ``Ab`` and ``Bb`` keep the shapes of ``A`` and
        ``B``; ``C`` is returned as passed in.
    """
    half_step_A = (step / 2.0) * A
    identity = np.eye(A.shape[0])
    # The same inverse is shared by the discrete state and input matrices.
    left_inverse = inv(identity - half_step_A)
    A_bar = left_inverse @ (identity + half_step_A)
    B_bar = (left_inverse * step) @ B
    return A_bar, B_bar, C

In [20]:
# Discretize the random SSM with step size 0.1 (discretize returns C unchanged).
Ab, Bb, C = discretize(A, B, C, step=0.1)

In [21]:
def scan_SSM(Ab, Bb, Cb, u, x0):
    """Run the discrete SSM recurrence over an input sequence.

    For every element ``u_k`` along the leading axis of ``u``::

        x_k = Ab @ x_{k-1} + Bb @ u_k
        y_k = Cb @ x_k

    Args:
        Ab, Bb, Cb: discrete state, input and output matrices.
        u: input sequence, scanned along its leading axis.
        x0: initial state.

    Returns:
        The result of ``jax.lax.scan``: a pair ``(final state, stacked y_k)``.
    """

    def advance(prev_state, u_k):
        next_state = Ab @ prev_state + Bb @ u_k
        # The carry is the new state; the per-step emission is its projection.
        return next_state, Cb @ next_state

    return jax.lax.scan(advance, x0, u)

## Example from mechanics

In [22]:
def example_mass(k, b, m):
    """Build the SSM matrices for a mass attached to a wall with a spring.

    Args:
        k: spring constant.
        b: friction constant.
        m: mass.

    Returns:
        Tuple ``(A, B, C)`` with shapes ``(2, 2)``, ``(2, 1)`` and ``(1, 2)``:
        ``A = [[0, 1], [-k/m, -b/m]]``, ``B = [[0], [1/m]]``, ``C = [[1, 0]]``.
    """
    spring_term = -k / m
    friction_term = -b / m
    input_gain = 1.0 / m

    A = np.array([[0, 1], [spring_term, friction_term]])
    B = np.array([[0], [input_gain]])
    # C picks out the first component of the 2-D latent state.
    C = np.array([[1.0, 0]])
    return A, B, C

In [23]:
ssm = example_mass(k=40, b=5, m=1)
# Print the shape of each matrix in the (A, B, C) tuple. The loop variable is
# named explicitly because binding `_` at notebook scope overrides IPython's
# last-output variable.
for matrix in ssm:
    print(matrix.shape)

# 2D latent state

(2, 2)
(2, 1)
(1, 2)


In [24]:
@partial(np.vectorize, signature="()->()")
def example_force(t):
    """Forcing signal: ``sin(10 * t)`` where it exceeds 0.5, zeroed elsewhere.

    The function is written for a scalar ``t``; the ``np.vectorize`` decorator
    (signature ``()->()``) applies it elementwise to array inputs.
    """
    wave = np.sin(10 * t)
    above_threshold = wave > 0.5
    # Multiplying by the boolean mask zeroes the samples at or below 0.5.
    return wave * above_threshold

In [25]:
# Sample the forcing signal at L evenly spaced times k * step, k = 0 .. L-1.
L = 100
step = 1 / L
ks = np.arange(L)  # integer sample indices
u = example_force(step * ks)

print(u.shape)

(100,)


In [26]:
# Simulate the spring-mass SSM driven by u using the recurrent (scan) form.
A, B, C = example_mass(k=40, b=5, m=1)
L = u.shape[0]
step = 1 / L
N = A.shape[0]  # latent state dimension
Ab, Bb, Cb = discretize(A, B, C, step=step)
# u gets a trailing axis so every scan step receives a length-1 input vector;
# [1] selects the stacked outputs from the (final state, outputs) pair.
scan_SSM(Ab, Bb, Cb, u[:, None], np.zeros(N))[1].shape

(100, 1)

In [27]:
# scan_SSM needs the initial state x0 as its fifth argument, and each scan
# step must receive a length-1 input vector (not a scalar) for `Bb @ u_k`,
# so u gets a trailing axis and the state starts at zero.
scan_SSM(Ab, Bb, C, u[:, np.newaxis], np.zeros((Ab.shape[0],)))

In [14]:
def K_conv(Ab, Bb, Cb, L):
    """Build the length-``L`` convolution kernel of a discrete SSM.

    Entry ``l`` of the kernel is ``Cb @ Ab**l @ Bb`` for ``l = 0 .. L-1``.

    Args:
        Ab, Bb, Cb: discrete state, input and output matrices.
        L: number of kernel entries to compute.

    Returns:
        Array holding the ``L`` kernel entries, one scalar per power of ``Ab``.
    """
    taps = []
    for power in range(L):
        tap = Cb @ matrix_power(Ab, power) @ Bb
        # reshape() with no arguments turns the single-element product into a scalar.
        taps.append(tap.reshape())
    return np.array(taps)


# Build the L-entry convolution kernel from the discretized SSM and show its shape.
K = K_conv(Ab, Bb, Cb, L)
print(K.shape)

NameError: name 'Cb' is not defined

In [13]:
# Convolve the input with the kernel. mode="full" keeps every overlap, so for
# 1-D inputs the result has len(u) + len(K) - 1 samples, i.e. it is longer than u.
convolve(u, K, mode="full").shape

NameError: name 'K' is not defined