In [None]:
import torch
from torch import nn
from torch.nn import functional as F
from einops import rearrange, einsum
from torch.fft import rfft, irfft, rfft2, irfft2
import numpy as np
from matplotlib import pyplot as plt

# 1D

In [None]:
class LTI(nn.Module):
    """
    Linear time-invariant sequence layer parameterised by rational transfer
    functions and evaluated with FFTs.

    State-Free Inference of State-Space Models:
    The Transfer Function Approach
    https://arxiv.org/abs/2405.06147

    Parameters (``shape`` is ``(input_dim, output_dim)`` when ``mimo`` is true,
    otherwise ``(input_dim,)``):

    * ``h0``          -- ``(*shape, 1)``; gain applied to the current sample.
    * ``numerators``  -- ``(2, *shape, order)``; index 0 holds the causal
      branch, index 1 the anti-causal branch.
    * ``denumerator`` -- ``(*shape, order)``; denominator coefficients shared
      by both branches.

    All parameters start at zero, so a freshly built layer outputs zeros.

    Args:
        input_dim: number of input channels.
        output_dim: number of output channels (must equal ``input_dim`` unless
            ``mimo`` is true).
        order: number of learnable coefficients per polynomial (> 0).
        causal: when true only the causal branch is used.
        mimo: when true every input channel feeds every output channel;
            otherwise each channel is filtered independently.
    """

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        order: int,
        causal: bool,
        mimo: bool,
    ):
        super().__init__()
        assert mimo or (input_dim == output_dim), "SISO in/out dimensions must match."
        assert order > 0, "Order must be a positive integer."
        self.causal = causal
        self.mimo = mimo
        self.order = order

        shape = (input_dim, output_dim) if mimo else (input_dim,)
        self.h0 = nn.Parameter(torch.zeros((*shape, 1)))
        self.numerators = nn.Parameter(torch.zeros((2, *shape, order)))
        self.denumerator = nn.Parameter(torch.zeros((*shape, order)))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Filter ``x`` along its second-to-last axis.

        Args:
            x: tensor of shape ``(..., L, input_dim)``. ``L`` should be at
                least ``order + 1``: the length-``L`` FFTs below trim longer
                coefficient vectors.

        Returns:
            Tensor of shape ``(..., L, output_dim)``.
        """
        *_, L, _ = x.shape

        # Prepend the fixed leading coefficients: 1 for the denominator and
        # 0 for the numerators (the current-sample gain lives in ``h0``).
        a = F.pad(self.denumerator, (1, 0), value=1.0)
        b = F.pad(self.numerators, (1, 0), value=0.0)

        # Length-L impulse response obtained by dividing the two spectra.
        h_truncated = irfft(rfft(b, n=L) / rfft(a, n=L), n=L)

        # Zero-pad to 2L so the circular convolution below does not wrap
        # into the first L output samples.
        H_causal, H_anticausal = rfft(h_truncated, n=2 * L)
        H = self.h0 + H_causal
        if not self.causal:
            # Conjugating the spectrum of a real signal reverses it in time.
            H = H + H_anticausal.conj()

        # Convolution in the frequency domain.
        X = rfft(x, n=2 * L, dim=-2)
        if self.mimo:
            # X: (..., F, Din), H: (Din, Dout, F) -> (..., F, Dout)
            Y = torch.einsum("...li,iol->...lo", X, H)
        else:
            # X: (..., F, D), H: (D, F) -> (..., F, D)
            Y = torch.einsum("...ld,dl->...ld", X, H)
        y = irfft(Y, n=2 * L, dim=-2)
        return y[..., :L, :]

In [None]:
# Demo configuration for the 1D layer.
L = 128 + 1  # odd length, so L // 2 is the exact centre sample
D = 3  # number of channels
ORDER = 16  # coefficients per polynomial
SEED = 0

# Seed the global RNG so the randomly initialised demo models (and therefore
# the plots) are identical on every Restart & Run All.
torch.manual_seed(SEED)


def random_init(model: "LTI") -> "LTI":
    """
    Fill the transfer-function coefficients of ``model`` with N(0, 1) samples.

    ``model.numerators`` is drawn first, then ``model.denumerator``; ``h0`` is
    left untouched. The parameters are overwritten in place under
    ``torch.no_grad()`` instead of through ``.data``, so the ``Parameter``
    objects, their dtype and their device are preserved and autograd is not
    bypassed behind its back.

    Args:
        model: module exposing ``numerators`` and ``denumerator`` parameters.

    Returns:
        The same ``model`` instance, for chaining.
    """
    with torch.no_grad():
        for coefficients in (model.numerators, model.denumerator):
            coefficients.normal_()
    return model


def plot_impulse_response(model, L, D, title=""):
    """
    Plot the response of ``model`` to a unit impulse on each input channel.

    For every channel ``i`` an ``(L, D)`` input that is zero except for a 1 at
    time ``L // 2`` in channel ``i`` is fed through ``model``. Each channel
    gets its own row showing the impulse (black) and all ``D`` output channels.

    Args:
        model: callable mapping an ``(L, D)`` tensor to an ``(L, D)`` tensor.
        L: sequence length.
        D: number of channels; also the number of subplot rows.
        title: figure-level title.
    """
    plt.figure(figsize=(10, 2 * D))
    for i in range(D):
        x = torch.zeros((L, D))
        x[L // 2, i] = 1.0
        h = model(x).detach().cpu().numpy()

        # D rows x 1 column, so the grid follows D instead of assuming D == 3.
        plt.subplot(D, 1, i + 1)
        plt.plot(range(L), x[:, i], "k", label=f"$x_{i}$")
        for j in range(D):
            plt.plot(range(L), h[:, j], label=f"$h_{j}$")
        plt.legend(loc="upper left")
        plt.grid(True)
    plt.suptitle(title)
    plt.show()


# One randomly initialised model and one figure per (causal, mimo) setting.
for causal, mimo, title in (
    (True, False, "Causal SISO"),
    (True, True, "Causal MIMO"),
    (False, False, "Non-causal SISO"),
    (False, True, "Non-causal MIMO"),
):
    model = random_init(LTI(D, D, order=ORDER, causal=causal, mimo=mimo))
    plot_impulse_response(model, L, D, title=title)

# 2D

In [None]:
class LTI2d(nn.Module):
    """
    Two-dimensional linear shift-invariant layer parameterised by rational
    transfer functions and evaluated with 2D FFTs.

    State-Free Inference of State-Space Models:
    The Transfer Function Approach
    https://arxiv.org/abs/2405.06147

    Parameters (``shape`` is ``(input_dim, output_dim)`` when ``mimo`` is true,
    otherwise ``(input_dim,)``):

    * ``h0``          -- ``(*shape, 1, 1)``; gain applied to the current pixel.
    * ``numerators``  -- ``(4, *shape, order, order)``; one numerator per
      quadrant, in the order causal/causal, height-reversed, width-reversed,
      reversed along both axes.
    * ``denumerator`` -- ``(*shape, order, order)``; shared denominator,
      initialised to zero except for a 1 at ``[..., 0, 0]``.

    Args:
        input_dim: number of input channels.
        output_dim: number of output channels (must equal ``input_dim`` unless
            ``mimo`` is true).
        order: side length of each coefficient grid (> 0).
        causal: when true only the causal/causal quadrant is used.
        mimo: when true every input channel feeds every output channel;
            otherwise each channel is filtered independently.
    """

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        order: int,
        causal: bool,
        mimo: bool,
    ):
        super().__init__()
        assert mimo or (input_dim == output_dim), "SISO in/out dimensions must match."
        assert order > 0, "Order must be a positive integer."
        self.causal = causal
        self.mimo = mimo
        self.order = order

        shape = (input_dim, output_dim) if mimo else (input_dim,)
        self.h0 = nn.Parameter(torch.zeros((*shape, 1, 1)))
        self.numerators = nn.Parameter(torch.zeros((4, *shape, order, order)))
        self.denumerator = nn.Parameter(torch.zeros((*shape, order, order)))
        self.denumerator.data[..., 0, 0] = 1.0

    @staticmethod
    def _negate_frequency(H: torch.Tensor, dim: int) -> torch.Tensor:
        """
        Map DFT bin ``k`` to bin ``(-k) mod N`` along ``dim``.

        A plain ``torch.flip`` maps ``k`` to ``N - 1 - k``, which is off by
        one bin and corresponds to a phase-modulated signal instead of a
        reversed one. Rolling the flipped axis by one keeps bin 0 in place and
        gives the true index negation. ``dim`` must be an axis holding the
        full (not half) spectrum.
        """
        return torch.roll(torch.flip(H, dims=(dim,)), shifts=1, dims=dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Filter ``x`` over its height and width axes.

        Args:
            x: tensor of shape ``(..., L1, L2, input_dim)``.

        Returns:
            Tensor of shape ``(..., L1, L2, output_dim)``.
        """
        *_, L1, L2, _ = x.shape

        # Normalise the denominator so that its [0, 0] coefficient is 1.
        a = self.denumerator / self.denumerator[..., 0, 0].unsqueeze(-1).unsqueeze(-1)
        b = self.numerators

        # (L1, L2) impulse responses obtained by dividing the two spectra.
        h_truncated = irfft2(rfft2(b, s=(L1, L2)) / rfft2(a, s=(L1, L2)), s=(L1, L2))

        # Zero-pad to (2*L1, 2*L2) so the circular convolution below does not
        # wrap into the cropped output. Height is the full-FFT axis (-2),
        # width the half-spectrum axis (-1).
        H_cc, H_ca, H_ac, H_aa = rfft2(h_truncated, s=(2 * L1, 2 * L2))
        H = self.h0 + H_cc
        if not self.causal:
            # For a real kernel with spectrum H[k1, k2]:
            #   reversed in height -> H[-k1,  k2]
            #   reversed in width  -> H[ k1, -k2] = conj(H[-k1, k2])
            #   reversed in both   -> H[-k1, -k2] = conj(H[ k1, k2])
            H_ca = self._negate_frequency(H_ca, dim=-2)
            H_ac = self._negate_frequency(H_ac, dim=-2).conj()
            H_aa = H_aa.conj()
            H = H + (H_ca + H_ac + H_aa)

        # Convolution in the frequency domain.
        X = rfft2(x, s=(2 * L1, 2 * L2), dim=(-3, -2))
        if self.mimo:
            # X: (..., F1, F2, Din), H: (Din, Dout, F1, F2) -> (..., F1, F2, Dout)
            Y = torch.einsum("...hwi,iohw->...hwo", X, H)
        else:
            # X: (..., F1, F2, D), H: (D, F1, F2) -> (..., F1, F2, D)
            Y = torch.einsum("...hwd,dhw->...hwd", X, H)
        y = irfft2(Y, s=(2 * L1, 2 * L2), dim=(-3, -2))
        return y[..., :L1, :L2, :]

In [None]:
# Demo configuration for the 2D layer.
L1 = L2 = 128 + 1  # odd side length, so L // 2 is the exact centre pixel
D = 3  # number of channels
ORDER = 16  # side length of each coefficient grid
SEED = 0

# Seed the global RNG so the randomly initialised demo models (and therefore
# the plots) are identical on every Restart & Run All.
torch.manual_seed(SEED)


def random_init(model: "LTI2d") -> "LTI2d":
    """
    Fill the transfer-function coefficients of ``model`` with N(0, 1) samples.

    ``model.numerators`` is drawn first, then ``model.denumerator``; ``h0`` is
    left untouched. The ``[..., 0, 0]`` denominator coefficient is then reset
    to 1.0: ``LTI2d.forward`` divides the whole denominator by that entry, so
    a random draw close to zero would blow up every other coefficient.

    The parameters are overwritten in place under ``torch.no_grad()`` instead
    of through ``.data``, so the ``Parameter`` objects, their dtype and their
    device are preserved.

    Args:
        model: module exposing ``numerators`` and ``denumerator`` parameters
            whose last two axes index the coefficient grid.

    Returns:
        The same ``model`` instance, for chaining.
    """
    with torch.no_grad():
        model.numerators.normal_()
        model.denumerator.normal_()
        model.denumerator[..., 0, 0] = 1.0
    return model


def plot_impulse_response(model, H, W, D, title=""):
    """
    Show the 2D response of ``model`` to a unit impulse on each input channel.

    Row ``i`` of the ``D x D`` grid corresponds to an ``(H, W, D)`` input that
    is zero except for a 1 at the centre pixel of channel ``i``; column ``j``
    shows output channel ``j``. All panels in a row share one symmetric colour
    range taken from the largest absolute value of that row's response.

    Args:
        model: callable mapping an ``(H, W, D)`` tensor to an ``(H, W, D)`` tensor.
        H: input height.
        W: input width.
        D: number of channels.
        title: figure-level title.
    """
    side = 3 * D
    fig = plt.figure(figsize=(side, side))
    centre = (H // 2, W // 2)
    for i in range(D):
        impulse = torch.zeros((H, W, D))
        impulse[centre[0], centre[1], i] = 1.0
        response = model(impulse).detach().cpu().numpy()
        # Symmetric colour limit; it does not depend on the output channel.
        limit = float(np.abs(response).max())
        for j in range(D):
            ax = fig.add_subplot(D, D, 1 + D * i + j)
            ax.set_title(f"$h_{j} | \\delta_{i}$")
            ax.imshow(response[:, :, j], vmin=-limit, vmax=limit, cmap="RdBu_r")
            ax.set_xticks([])
            ax.set_yticks([])
    fig.suptitle(title)
    fig.tight_layout()
    plt.show()


# One randomly initialised model and one figure per (causal, mimo) setting.
for causal, mimo, title in (
    (True, False, "Causal SISO"),
    (True, True, "Causal MIMO"),
    (False, False, "Non-causal SISO"),
    (False, True, "Non-causal MIMO"),
):
    model = random_init(LTI2d(D, D, order=ORDER, causal=causal, mimo=mimo))
    plot_impulse_response(model, L1, L2, D, title=title)