In [None]:
import jax
from jax import random
import jax.numpy as jnp
from util import filter_step

In [None]:
# Problem dimensions used by every cell in this notebook.
T = 1000  # number of time steps in the input sequence
I = 3  # number of outputs (one filter channel per output)
nb = 5  # numerator (b) coefficients per channel
na = 4  # denominator (a) coefficients per channel, not counting the leading 1

In [None]:
# Give each coefficient array its own subkey. Passing the same key to both
# draws would make a_coeff and b_coeff come from the same random stream, so
# they would not be independent; JAX keys are meant to be consumed only once.
key = random.key(200)
key, b_key, a_key = random.split(key, 3)
b_coeff = random.normal(b_key, (I, nb)) * 1e-3  # numerator taps, one row per output
a_coeff = random.normal(a_key, (I, na)) * 1e-3  # denominator taps, one row per output

In [None]:
# Initial conditions: every channel starts from an all-zero state.
# The buffers presumably hold the most recent past samples (u[t-1], u[t-2], ...
# and y[t-1], y[t-2], ...); the exact ordering is defined by filter_step.
u_carry = jnp.zeros(shape=(I, nb - 1))  # nb - 1 input-history slots per channel
y_carry = jnp.zeros(shape=(I, na))  # na output-history slots per channel

In [None]:
# Pack the state and the coefficients into the two tuples that are handed to
# the vmapped filter step: carry = (input history, output history) and
# params = (b coefficients, a coefficients).
carry = (u_carry, y_carry)
params = (b_coeff, a_coeff)

In [None]:
# Single-step smoke test of the vmapped filter.
# Draw the test sample from a stream derived with fold_in rather than from
# `key` itself: `key` already seeded the coefficient draws, and reusing it
# would make this sample statistically dependent on them.
u_step = random.normal(random.fold_in(key, 1), (I,))
# vmap over the leading (channel) axis of every argument.
filter_step_simo = jax.vmap(filter_step, in_axes=(0, 0, 0))  # params, carry, u_step
carry_new, y_new = filter_step_simo(params, carry, u_step)
# Display the resulting shapes as the cell output.
carry_new[0].shape, carry_new[1].shape, y_new.shape

In [None]:
def func(carry, u_t):
    """Scan body: advance every channel by one sample using the fixed `params`.

    Returns whatever `filter_step_simo` returns, i.e. the pair
    (updated carry, output for this step) that `jax.lax.scan` expects.
    """
    return filter_step_simo(params, carry, u_t)


# Draw the input sequence from its own derived stream; using `key` directly
# would correlate it with the other random draws made from the same key.
u = random.normal(random.fold_in(key, 2), (T, I))
# scan iterates over the leading (time) axis of u and stacks the step outputs.
carry_last, y_all = jax.lax.scan(func, carry, u)
# Keep the per-channel outputs in y_all; y is their mean over the last axis.
y = y_all.mean(axis=-1)

In [None]:
import matplotlib.pyplot as plt
import scipy.signal
import numpy as np

# Reference result: run each channel through scipy.signal.lfilter.
# Convert the arrays to NumPy once up front instead of indexing them per iteration.
b_np, a_np, u_np = (np.asarray(arr) for arr in (b_coeff, a_coeff, u))

y_filt = np.empty((T, I))
for ch, (b_row, a_row) in enumerate(zip(b_np, a_np)):
    # lfilter takes the full denominator, so prepend the leading 1.
    denominator = np.r_[1.0, a_row]
    y_filt[:, ch] = scipy.signal.lfilter(b_row, denominator, u_np[:, ch])
# Collapse the channel axis the same way as the scan result.
y_filt = y_filt.mean(axis=-1)

In [None]:
# Overlay the scan result, the SciPy reference and their difference.
# Labels, legend and title make the figure readable on its own.
fig, ax = plt.subplots()
ax.plot(y, "b", label="jax.lax.scan (channel mean)")
ax.plot(y_filt, "k", label="scipy.signal.lfilter (channel mean)")
ax.plot(y - y_filt, "r", label="difference")
ax.set_xlabel("time step")
ax.set_ylabel("filter output")
ax.set_title("Scan-based filter vs. SciPy reference")
ax.legend()
plt.show()