In [None]:
from functools import partial

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import scipy.signal as signal
from jax import random

from util import filter_step

In [None]:
# Filter coefficients to experiment with. The leading denominator
# coefficient a0 is implicitly 1 and is therefore not stored.
b_coeff = jnp.array([1.0, 0.0, 1.0])  # numerator: b0, b1, b2
a_coeff = jnp.array([-0.3, -0.5])  # denominator: a1, a2

In [None]:
# Zero initial state: the two most recent past inputs and outputs.
u_carry = jnp.zeros(2)  # u-1, u-2
y_carry = jnp.zeros(2)  # y-1, y-2

In [None]:
# Simulate one step of the difference equation by hand:
#   y[n] = b0*u[n] + b1*u[n-1] + b2*u[n-2] - a1*y[n-1] - a2*y[n-2]
u_step = jnp.array(1.0)  # u0

u_window = jnp.r_[u_step, u_carry]  # u0, u-1, u-2
y_new = jnp.dot(b_coeff, u_window) - jnp.dot(a_coeff, y_carry)

# Shifted state for the next step. It is stored under new names rather than
# overwriting u_carry / y_carry, so re-running this cell gives the same result.
u_carry_next = u_window[:-1]                # u0, u-1
y_carry_next = jnp.r_[y_new, y_carry][:-1]  # y0, y-1

y_new

In [None]:
# Random test input, made reproducible by a fixed PRNG seed.
# The seed and the signal length happen to share a value but are unrelated.
SEED = 200
N_SAMPLES = 200
key = random.key(SEED)
u = random.normal(key, (N_SAMPLES,))

In [None]:
# Simulate the whole input sequence with lax.scan, starting from a zero state.
u_carry = jnp.zeros(2)  # u-1, u-2
y_carry = jnp.zeros(2)  # y-1, y-2
carry = (u_carry, y_carry)
param = (b_coeff, a_coeff)
# Bind the coefficients up front. scan then calls func(carry, u_n) for each
# sample, i.e. filter_step(param, carry, u_n). Unlike the previous lambda,
# this does not shadow the input array `u`.
func = partial(filter_step, param)
last_carry, y = jax.lax.scan(func, carry, u)

In [None]:
# Reference simulation with scipy. lfilter expects the full denominator
# [a0, a1, a2], so the implicit a0 = 1 is prepended here.
y_lfilter = signal.lfilter(b_coeff, jnp.r_[1.0, a_coeff], u, axis=0)

In [None]:
# Overlay both simulations and their difference.
fig, ax = plt.subplots()
ax.plot(y, "k", label="lax.scan + filter_step")
ax.plot(y_lfilter, "b", label="scipy.signal.lfilter")
ax.plot(y - y_lfilter, "r", label="difference")
ax.set(title="Filter output: JAX scan vs. scipy lfilter", xlabel="sample n", ylabel="y[n]")
ax.legend();