In [None]:
import jax
from jax import random
import jax.numpy as jnp
from util import filter_step

In [None]:
# Signal dimensions.
B = 32     # batch size
T = 1_000  # number of time steps
I = 3      # number of input channels
O = 2      # number of output channels

# Coefficients per (output, input) filter.
nb = 5  # numerator coefficients (b)
na = 4  # denominator coefficients (a), not counting the leading 1

In [None]:
# JAX PRNG keys must not be reused: two draws from the same key are not
# independent. Split the seed key once so that each coefficient tensor gets its
# own subkey, and keep a fresh `key` around for the cells further down.
key, key_a, key_b = random.split(random.key(200), 3)

# One small random filter per (output, input) pair.
a_coeff = random.normal(key_a, (O, I, na)) * 1e-3  # denominator, without the leading 1
b_coeff = random.normal(key_b, (O, I, nb)) * 1e-3  # numerator

In [None]:
# Zero initial conditions: one state vector per (output, input) filter.
# NOTE(review): presumably u_carry holds the nb - 1 most recent inputs and
# y_carry the na most recent outputs -- confirm against util.filter_step.
u_carry = jnp.zeros(shape=(O, I, nb - 1))
y_carry = jnp.zeros(shape=(O, I, na))

In [None]:
params = (b_coeff, a_coeff)  # numerator coefficients first, then denominator
carry = (u_carry, y_carry)   # initial filter state

# A single time step of input: one sample per input channel. The shape must be
# (I,) because filter_step_mimo broadcasts u_step over the output axis and
# filter_step_simo then maps its leading axis together with the input-channel
# axis of params/carry; this is also the shape of one row of the (T, I) input
# that jax.lax.scan feeds in. The previous (O, 1) shape could not be mapped.
# A dedicated subkey avoids reusing `key` for another draw.
u_step = random.normal(random.fold_in(key, 1), (I,))

In [None]:
# Inner map: vectorise filter_step over the leading axis of every argument
# (params, carry, u_step). An integer in_axes applies to all positional
# arguments and to every leaf of tuple-valued ones.
filter_step_simo = jax.vmap(filter_step, in_axes=0)

# Outer map: vectorise over the leading axis of params and carry, while the
# same u_step is handed unchanged to every mapped instance (in_axes=None).
# With (O, I, ...) shaped params/carry the outer map runs over outputs and the
# inner one over inputs.
filter_step_mimo = jax.vmap(filter_step_simo, in_axes=(0, 0, None))

In [None]:
# Random excitation: T time steps for each of the I input channels. The subkey
# is derived with fold_in so this draw does not reuse `key`, which earlier
# cells already consumed for other random tensors.
u = random.normal(random.fold_in(key, 2), (T, I))
def mimo_filter(params, carry, u):
    """Run the MIMO filter over a whole input sequence.

    Scans ``filter_step_mimo`` along the leading axis of ``u`` starting from
    ``carry`` and averages the stacked per-step outputs over their last axis.

    Args:
        params: filter parameters, passed unchanged to ``filter_step_mimo``
            at every step.
        carry: initial filter state handed to ``jax.lax.scan``.
        u: input sequence; one slice along the leading axis is consumed per step.

    Returns:
        The stacked step outputs averaged over the last axis (presumably the
        input-channel axis -- confirm against ``util.filter_step``). The final
        carry is discarded.
    """

    def step(state, u_t):
        # scan expects (carry, x) -> (new_carry, y); params are closed over.
        return filter_step_mimo(params, state, u_t)

    _, outputs = jax.lax.scan(step, carry, u)
    return outputs.mean(axis=-1)

# Filter the full (unbatched) input sequence. mimo_filter is a jax.lax.scan of
# filter_step_mimo over time followed by a mean over the last axis.
y = mimo_filter(params=params, carry=carry, u=u)

In [None]:
import matplotlib.pyplot as plt
from scipy.signal import lfilter
import numpy as np

# Reference result: filter every input channel through each (output, input)
# filter with scipy, then combine the channels exactly like mimo_filter does.
y_filt = np.empty((T, O, I))
for idx_o, idx_i in np.ndindex(O, I):
    numerator = b_coeff[idx_o, idx_i]
    denominator = np.r_[1.0, a_coeff[idx_o, idx_i]]  # prepend the leading 1
    y_filt[:, idx_o, idx_i] = lfilter(numerator, denominator, u[:, idx_i]).ravel()
y_filt = y_filt.mean(axis=-1)  # average over the input-channel axis

In [None]:
# One figure per output channel: JAX result, scipy reference and their
# difference. Labels, title and legend make each figure readable on its own.
for idx in range(O):
    fig, ax = plt.subplots()
    ax.plot(y[:, idx], "b", label="JAX scan filter")
    ax.plot(y_filt[:, idx], "k", label="scipy.signal.lfilter")
    ax.plot(y[:, idx] - y_filt[:, idx], "r", label="difference")
    ax.set(title=f"Output channel {idx}", xlabel="time step", ylabel="amplitude")
    ax.legend()

In [None]:
# Batched run: B independent input sequences, each with its own initial state.
# NOTE: this cell rebinds u, u_carry and y_carry to their batched versions, so
# the unbatched cells above must not be re-run after it without a restart.
# The subkey is derived with fold_in so this draw does not reuse `key`.
u = random.normal(random.fold_in(key, 3), (B, T, I))
# initial conditions (zeros), one set per batch element
u_carry = jnp.zeros((B, O, I, nb - 1))
y_carry = jnp.zeros((B, O, I, na))
# params are shared across the batch (None); carry and u are mapped over their
# leading batch axis.
batched_mimo_filter = jax.vmap(mimo_filter, in_axes=(None, 0, 0))
batched_mimo_filter(params, (u_carry, y_carry), u).shape