In [None]:
import jax
from typing import Any, Callable, Sequence
from jax import random, numpy as jnp
import flax
from flax import linen as nn

In [None]:
class SimpleDense(nn.Module):
    """Minimal fully-connected layer computing ``inputs @ kernel + bias``.

    The kernel has shape ``(inputs.shape[-1], features)`` and the bias has
    shape ``(features,)``; both are created lazily on the first call.
    """

    # Size of the output (last) dimension.
    features: int
    # Initializer used for the "kernel" parameter.
    kernel_init: Callable = nn.initializers.lecun_normal()
    # Initializer used for the "bias" parameter.
    bias_init: Callable = nn.initializers.zeros_init()

    @nn.compact
    def __call__(self, inputs):
        """Apply the affine map along the last axis of ``inputs``."""
        # The input width is only known at call time, hence the compact-style
        # parameter declaration inside __call__.
        in_features = inputs.shape[-1]
        kernel = self.param("kernel", self.kernel_init, (in_features, self.features))
        bias = self.param("bias", self.bias_init, (self.features,))
        return jnp.dot(inputs, kernel) + bias

In [None]:
# Derive two PRNG keys from a fixed seed so the notebook is reproducible
# (random.split produces two keys by default).
key1, key2 = random.split(random.key(0))
# Toy input for the dense-layer demo: 4 samples with 4 features each.
x = random.uniform(key1, shape=(4, 4))

In [None]:
# Build a dense layer with 3 output features, initialise its parameters from
# key2 (x is passed so the kernel shape can be derived from its last axis),
# then run a forward pass.
model = SimpleDense(features=3)
params = model.init(key2, x)
y = model.apply(params, x)
# x is (4, 4) and features=3, so the output is (4, 3).
y.shape

In [None]:
# The kernel is stored as (input features, output features): (4, 3) here.
params["params"]["kernel"].shape

In [None]:
# Single-step update of a scalar linear filter, shaped for use as a lax.scan body.
def filter_step(params, carry, u_step):
    """Advance the filter by one time step.

    Args:
        params: Pair ``(b_coeff, a_coeff)`` of 1-D coefficient vectors.
        carry: Pair ``(u_carry, y_carry)`` holding the most recent past inputs
            and past outputs, newest first. ``u_carry`` is one element shorter
            than ``b_coeff``; ``y_carry`` has the same length as ``a_coeff``.
        u_step: Current (scalar) input sample.

    Returns:
        ``(carry, y_new)``: the updated history pair and the new output sample,
        where ``y_new = b_coeff . [u_step, u_carry] - a_coeff . y_carry``.
    """
    b_coeff, a_coeff = params
    u_past, y_past = carry

    # Input window seen by the numerator: current sample followed by history.
    u_window = jnp.r_[u_step, u_past]
    y_new = jnp.dot(b_coeff, u_window) - jnp.dot(a_coeff, y_past)

    # Shift both histories by one step: prepend the newest value and drop the
    # oldest, so the carry keeps a fixed shape across scan iterations.
    next_carry = (u_window[:-1], jnp.r_[y_new, y_past][:-1])
    return next_carry, y_new


# Lift the scalar one-step update to multi-channel systems with two nested vmaps.
# Inner vmap: params, carry and u_step are all mapped over their leading axis.
filter_step_simo = jax.vmap(filter_step, in_axes=(0, 0, 0))
# Outer vmap: params and carry are mapped over a further leading axis, while
# the same u_step is shared (in_axes=None) by every slice.
filter_step_mimo = jax.vmap(filter_step_simo, in_axes=(0, 0, None))


def mimo_filter(params, carry, u):
    """Run the multi-channel filter over a whole sequence.

    ``u`` is scanned along its leading axis; each slice is fed to
    ``filter_step_mimo`` together with the running ``carry``. The per-step
    outputs are stacked by ``lax.scan`` and then averaged over their last axis.
    The final carry is discarded.
    """
    def scan_body(state, u_t):
        return filter_step_mimo(params, state, u_t)

    _, y_stacked = jax.lax.scan(scan_body, carry, u)
    return y_stacked.mean(axis=-1)


# Batched version: params are shared, carry and u are mapped over axis 0.
batched_mimo_filter = jax.vmap(mimo_filter, in_axes=(None, 0, 0))

In [None]:
# filter_step_simo, filter_step_mimo, mimo_filter and batched_mimo_filter are
# defined once, in the cell directly above (together with filter_step).
# They are intentionally not re-defined here: a second verbatim copy would
# silently shadow the first and let the two drift apart when only one is edited.

In [None]:
def fixed_std_initializer(std):
    """
    Returns a Flax-style initializer that draws weights from a zero-mean normal
    distribution with a fixed standard deviation.

    Args:
    std (float): The desired standard deviation of the weights.

    Returns:
    An initializer function with signature ``(key, shape, dtype=jnp.float32)``
    whose result always has the requested ``dtype``.
    """
    def initializer(key, shape, dtype=jnp.float32):
        # Cast the scale to the requested dtype before multiplying: a strongly
        # typed std (e.g. a NumPy scalar or an array) would otherwise promote
        # the product and return a dtype different from the one asked for.
        scale = jnp.asarray(std, dtype=dtype)
        # Standard-normal samples scaled to the requested standard deviation.
        return jax.random.normal(key, shape, dtype) * scale
    return initializer

In [None]:
class MimoLTI(nn.Module):
    """Multi-input multi-output linear filter layer with learnable coefficients.

    Each (output, input) channel pair owns one ``b_coeff`` vector of length
    ``nb`` and one ``a_coeff`` vector of length ``na``. The input sequence is
    filtered from zero initial conditions with ``batched_mimo_filter``, which
    averages the per-pair responses over the input-channel axis.

    ``inputs`` is indexed as ``(batch, time, in_channels)`` and the result has
    shape ``(batch, time, out_channels)``.
    """

    # Number of input channels (size of the last axis of ``inputs``).
    in_channels: int = 1
    # Number of output channels.
    out_channels: int = 1
    # Length of each b_coeff vector (taps applied to current and past inputs).
    nb: int = 3
    # Length of each a_coeff vector (taps applied to past outputs).
    na: int = 2

    # Initializer shared by both coefficient tensors.
    kernel_init: Callable = fixed_std_initializer(1e-3)

    @nn.compact
    def __call__(self, inputs):
        """Filter ``inputs`` starting from an all-zero history."""
        b_coeff = self.param(
            "b_coeff",
            self.kernel_init,
            (self.out_channels, self.in_channels, self.nb),
        )
        a_coeff = self.param(
            "a_coeff",
            self.kernel_init,
            (self.out_channels, self.in_channels, self.na),
        )

        # lax.scan requires the carry to keep the same dtype on every step.
        # The step function mixes inputs with the coefficients, so the history
        # buffers must start in the promoted dtype (otherwise e.g. float64
        # inputs with float32 zeros fail after the first update).
        carry_dtype = jnp.result_type(inputs, b_coeff, a_coeff)
        batch_size = inputs.shape[0]
        pair_shape = (batch_size, self.out_channels, self.in_channels)

        # Zero initial conditions: nb - 1 past inputs and na past outputs
        # for every (batch, output, input) combination.
        u_carry = jnp.zeros(pair_shape + (self.nb - 1,), dtype=carry_dtype)
        y_carry = jnp.zeros(pair_shape + (self.na,), dtype=carry_dtype)

        return batched_mimo_filter((b_coeff, a_coeff), (u_carry, y_carry), inputs)

In [None]:
# Problem sizes for the MimoLTI demo below.
I = 3 # number of input channels (last axis of u)
O = 2 # number of output channels
T = 1000 # number of time steps
B = 32 # batch size
na = 4 # length of each a_coeff vector (past outputs kept in y_carry)
nb = 5 # length of each b_coeff vector (current input plus nb - 1 past inputs)

In [None]:
# Use fresh, separate PRNG keys for the synthetic input and for parameter
# initialisation. key2 was already consumed by the SimpleDense init above, and
# feeding one key to two different consumers makes their draws statistically
# dependent.
data_key, init_key = random.split(random.key(1))

# Random input batch of shape (batch, time, input channels).
u = random.normal(data_key, (B, T, I))

# Keyword arguments guard against silently swapping nb and na.
model = MimoLTI(in_channels=I, out_channels=O, nb=nb, na=na)
params = model.init(init_key, u)
y1 = model.apply(params, u)
# Expected: (B, T, O)
y1.shape

In [None]:
# Call the filter directly with the learned coefficients and the same zero
# initial conditions that MimoLTI.__call__ builds internally.
u_carry = jnp.zeros((B, O, I, nb - 1))  # nb - 1 past inputs per (output, input) pair
y_carry = jnp.zeros((B, O, I, na))  # na past outputs per (output, input) pair
carry = (u_carry, y_carry)

y2 = batched_mimo_filter(
    tuple(params["params"][name] for name in ("b_coeff", "a_coeff")),
    carry,
    u,
)
y2.shape