# Simple Demonstration for HiPPO Functionality
This code adapts and simplifies the core functionality of the [HiPPO demonstration notebook](https://github.com/state-spaces/s4/blob/main/notebooks/hippo_function_approximation.ipynb) from the official S4 repository.

It demonstrates the following concepts in detail:
- Generation of the $A$ and $B$ matrices using the scaled Legendre measure
- Discretization of the $A$ and $B$ matrices using the bilinear transform
- Projection of a function $f(t)$ into the coefficients $c$
- Construction of the approximation $g(t)$ from the coefficients $c$

It also provides a function, copied from the original notebook, that generates a random band-limited white-noise signal.

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.special import eval_legendre

## Generating the HiPPO LegS matrices $A$ and $B$
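For the scaled Legendre measure (HiPPO-LegS), the coefficients $c(t)$ evolve according to
$$
\frac{d}{dt}c(t) = -\frac{1}{t}Ac(t) + \frac{1}{t}Bf(t),
$$
where the matrices are defined as follows:
$$
\begin{aligned}
A_{nk} &= \begin{cases}(2n+1)^{\frac{1}{2}}(2k+1)^{\frac{1}{2}} & \text{if } n>k\\ n+1 & \text{if } n=k\\ 0 & \text{if } n<k\end{cases}\\
B_n &= (2n+1)^{\frac{1}{2}}
\end{aligned}
$$

Note that `hippo_legs` below returns the negated matrix $-A$ (and calls it `A`), so that with the stored matrix the dynamics take the form $\frac{d}{dt}c(t) = \frac{1}{t}\left(Ac(t) + Bf(t)\right)$ used by the rest of the code.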

In [None]:
def hippo_legs(N: int) -> tuple[np.ndarray, np.ndarray]:
    """
    Initializes HiPPO-LegS matrices A and B.
    A is returned negated (-A), so the LegS dynamics read dc/dt = (1/t) * (A c + B f).

    Parameters:
        N (int): number of Legendre polynomials to use

    Returns:
        (A, B) (ndarray, ndarray): state transition matrix A and input matrix B
    """
    A = np.zeros((N, N))
    B = np.zeros((N,))
    for n in range(N):
        B[n] = np.sqrt(2*n + 1)
        for k in range(N):
            if n > k:
                A[n, k] = - (np.sqrt(2*n + 1) * np.sqrt(2*k + 1))
            elif n == k:
                A[n, k] = - (n + 1)
            else:
                A[n, k] = 0
    return A, B

In [None]:
# Show the matrices A and B for N=4
A, B = hippo_legs(4)
display(A)
display(B)
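Because the entries of $A$ below the diagonal are the products $(2n+1)^{\frac{1}{2}}(2k+1)^{\frac{1}{2}}$, i.e. the strictly lower part of an outer product, the double loop in `hippo_legs` can also be written with NumPy array operations. The following is a minimal sketch; `hippo_legs_vectorized` is a hypothetical name introduced here (not from the S4 repository), and it follows the same negated-$A$ convention as `hippo_legs`:

```python
import numpy as np

def hippo_legs_vectorized(N: int) -> tuple[np.ndarray, np.ndarray]:
    """Hypothetical loop-free variant of hippo_legs (same negated-A convention)."""
    q = np.sqrt(2 * np.arange(N) + 1.0)        # q_n = (2n+1)^(1/2)
    A = np.tril(np.outer(q, q), k=-1)          # entries with n > k: q_n * q_k
    A += np.diag(np.arange(N) + 1.0)           # diagonal entries: n + 1
    B = q.copy()                               # B_n = (2n+1)^(1/2)
    return -A, B                               # negate A, as hippo_legs does

A_vec, B_vec = hippo_legs_vectorized(4)
print(A_vec)
print(B_vec)
```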

## Discretization of the $A$ and $B$ matrices
The $A$ and $B$ matrices can be discretized with step size $\Delta$ using the bilinear transform:
$$
\begin{aligned}
\bar{A}&=(I-\frac{\Delta}{2}A)^{-1}\cdot(I+\frac{\Delta}{2}A)\\
\bar{B}&=(I-\frac{\Delta}{2}A)^{-1}\cdot \Delta B
\end{aligned}
$$

In [None]:
def discretize_hippo(A: np.ndarray, B: np.ndarray, dt: float) -> tuple[np.ndarray, np.ndarray]:
    """
    Converts continuous-time matrices A, B into discrete-time versions using the bilinear transform.

    Parameters:
        A (ndarray): continuous-time state transition matrix
        B (ndarray): continuous-time input matrix
        dt (float): time step for discretization
    
    Returns:
        (A_d, B_d) (ndarray, ndarray): discrete-time state transition matrix A_d and discrete-time input matrix B_d
    """
    I = np.eye(A.shape[0])
    # Solve the linear systems instead of explicitly forming the inverse of (I - dt/2 * A)
    A_d = np.linalg.solve(I - 0.5 * A * dt, I + 0.5 * A * dt)
    B_d = np.linalg.solve(I - 0.5 * A * dt, B * dt)

    # Alternatively, for HiPPO-LegS A (and therefore I - dt/2 * A) is lower triangular, so a
    # triangular solver is cheaper (requires `import scipy.linalg as la`):
    # A_d = la.solve_triangular(I - 0.5 * A * dt, I + 0.5 * A * dt, lower=True)
    # B_d = la.solve_triangular(I - 0.5 * A * dt, B * dt, lower=True)
    return A_d, B_d

In [None]:
# Show the discretized matrices for dt=0.1
A_d, B_d = discretize_hippo(A, B, 0.1)
display(A_d)
display(B_d)
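As a sanity check on the bilinear transform used by `discretize_hippo`, the sketch below compares the bilinear $\bar{A}$ with the exact discretization $e^{\Delta A}$ (computed with `scipy.linalg.expm`) for shrinking step sizes, and checks that all eigenvalues of $\bar{A}$ lie inside the unit circle. It rebuilds the negated LegS matrix inline so it runs on its own; `legs_A` and `bilinear` are illustrative helper names, not part of the original code.

```python
import numpy as np
from scipy.linalg import expm

def legs_A(N: int) -> np.ndarray:
    # Negated HiPPO-LegS state matrix, same convention as hippo_legs
    q = np.sqrt(2 * np.arange(N) + 1.0)
    return -(np.tril(np.outer(q, q), k=-1) + np.diag(np.arange(N) + 1.0))

def bilinear(A: np.ndarray, B: np.ndarray, dt: float) -> tuple[np.ndarray, np.ndarray]:
    # Same formulas as discretize_hippo, written with np.linalg.solve
    I = np.eye(A.shape[0])
    A_d = np.linalg.solve(I - 0.5 * dt * A, I + 0.5 * dt * A)
    B_d = np.linalg.solve(I - 0.5 * dt * A, dt * B)
    return A_d, B_d

A4 = legs_A(4)
B4 = np.sqrt(2 * np.arange(4) + 1.0)

errors, radii = [], []
for dt in [0.1, 0.01, 0.001]:
    A_bar, _ = bilinear(A4, B4, dt)
    errors.append(np.abs(A_bar - expm(dt * A4)).max())     # deviation from exact e^{dt A}
    radii.append(np.abs(np.linalg.eigvals(A_bar)).max())  # spectral radius of A_bar
    print(f"dt={dt:g}: max error {errors[-1]:.2e}, spectral radius {radii[-1]:.4f}")
```

The bilinear transform is a second-order approximation of the matrix exponential, so the error should shrink roughly by a factor of 1000 for every tenfold reduction of $\Delta$; since all eigenvalues of the LegS matrix are negative, the discrete system stays stable for any $\Delta > 0$.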

## Projection of a function $f(t)$ into coefficients $c$
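We can project the function $f(t)$ onto a basis of Legendre polynomials. This yields the coefficients $c$.

We do this by iteratively applying the state equation. Because the LegS dynamics carry the time-dependent factor $\frac{1}{t}$, the discretized matrices change at every step: measuring time in steps ($t = k$), we discretize $\frac{1}{k}A$ and $\frac{1}{k}B$ with the bilinear transform and $\Delta = 1$, which gives

$$
\begin{aligned}
c_k &= \bar{A}_k c_{k-1} + \bar{B}_k f_k\\
\bar{A}_k &= (I-\frac{1}{2k}A)^{-1}\cdot(I+\frac{1}{2k}A)\\
\bar{B}_k &= (I-\frac{1}{2k}A)^{-1}\cdot \frac{1}{k}B
\end{aligned}
$$

with $c_0 = 0$ and $A$ the negated matrix returned by `hippo_legs`. We take the final coefficients as the result: since the LegS measure is uniform over the whole history $[0, t]$, they summarize the entire signal $f$, not just a recent window.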

In [None]:
def project_function(f: np.ndarray, A: np.ndarray, B: np.ndarray) -> np.ndarray:
    """
    Projects a time-series f(t) into a coefficient vector c representing the Legendre polynomial expansion.\\
    Each time step updates the coefficients using the discretized HiPPO system.

    Parameters:
        f (ndarray): time-series data to project
        A (ndarray): continuous-time state transition matrix (negated HiPPO-LegS A, as returned by hippo_legs)
        B (ndarray): continuous-time input matrix

    Returns:
        c (ndarray): coefficients of the Legendre polynomial expansion
    """
    c = np.zeros_like(B)
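    for i, f_t in enumerate(f):
        k = i + 1               # current time in steps (1-indexed), i.e. t = k
        A_t = A / k             # LegS time scaling: dc/dt = (1/t) * (A c + B f)
        B_t = B / k
        A_d, B_d = discretize_hippo(A_t, B_t, 1.0)  # dt=1 because time is measured in steps; 1/k handles the scaling
        c = A_d @ c + B_d * f_t     # apply the discretized state equation to update the coefficients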
    return c
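For $N = 1$ the recurrence can be solved by hand, which makes the meaning of the coefficients concrete. With the negated convention, $A = -1$ and $B = 1$, so $\bar{A}_k = \frac{2k-1}{2k+1}$ and $\bar{B}_k = \frac{2}{2k+1}$. Multiplying the recurrence by $2k+1$ gives $(2k+1)c_k = (2k-1)c_{k-1} + 2f_k$, which telescopes to $c_k = \frac{2k}{2k+1}\cdot\frac{1}{k}\sum_{j=1}^{k} f_j$: the single LegS coefficient tracks (up to the factor $\frac{2k}{2k+1}$) the running mean of the whole history. The sketch below checks this numerically; `legs_scalar_step` is a helper name introduced here for illustration.

```python
import numpy as np

def legs_scalar_step(c: float, f_k: float, k: int) -> float:
    # N = 1: negated A = -1 and B = 1, scaled by 1/k and discretized with the bilinear transform (dt = 1)
    a, b = -1.0 / k, 1.0 / k
    a_bar = (1 + 0.5 * a) / (1 - 0.5 * a)   # equals (2k - 1) / (2k + 1)
    b_bar = b / (1 - 0.5 * a)               # equals 2 / (2k + 1)
    return a_bar * c + b_bar * f_k

rng = np.random.default_rng(0)
f_demo = rng.normal(size=500)

c0 = 0.0
for i, f_k in enumerate(f_demo):
    c0 = legs_scalar_step(c0, f_k, i + 1)   # k = i + 1, as in project_function

K = len(f_demo)
closed_form = 2 * K / (2 * K + 1) * f_demo.mean()
print(c0, closed_form)
```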

## Reconstruction of the approximation $g(t)$
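We can construct the approximation $g(t) \approx f(t)$ by evaluating the Legendre polynomials $P_i$, weighting them with the corresponding coefficients $c_i$, and summing the results. The polynomials are evaluated on their natural domain $[-1, 1]$, which is mapped linearly onto the time window of the signal.

$$
g(t) = \sum_{i=0}^{N-1} c_i P_i(t)
$$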
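The Legendre polynomials are orthogonal but not orthonormal on $[-1, 1]$: $\int_{-1}^{1} P_m(x)P_n(x)\,dx = \frac{2}{2n+1}\delta_{mn}$. Mapped to $[0, 1]$ and multiplied by $(2n+1)^{\frac{1}{2}}$, they become orthonormal; this is the basis the HiPPO-LegS coefficients refer to, and the reason the coefficients are multiplied by $B_n = (2n+1)^{\frac{1}{2}}$ before reconstruction in the final cell. The sketch below verifies both facts with Gauss-Legendre quadrature from `numpy.polynomial.legendre.leggauss`; `n_basis` is an arbitrary choice.

```python
import numpy as np
from numpy.polynomial.legendre import leggauss
from scipy.special import eval_legendre

n_basis = 6
x, w = leggauss(n_basis)   # n_basis-point Gauss-Legendre rule on [-1, 1]: exact up to degree 2*n_basis - 1
P = np.array([eval_legendre(n, x) for n in range(n_basis)])   # P[n, j] = P_n(x_j)

gram = (P * w) @ P.T            # gram[m, n] = integral over [-1, 1] of P_m * P_n
print(np.round(gram, 6))        # diagonal entries are 2 / (2n + 1)

# Substituting x = 2u - 1 (dx = 2 du) and scaling by sqrt(2n + 1) gives the basis on [0, 1]
scale = np.sqrt(2 * np.arange(n_basis) + 1.0)
gram_unit = 0.5 * scale[:, None] * gram * scale[None, :]   # integral over [0, 1] of g_m * g_n
print(np.round(gram_unit, 6))   # identity matrix
```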

In [None]:
def reconstruct_function(c: np.ndarray, n_samples: int) -> np.ndarray:
    """
    Given Legendre coefficients c, reconstruct the function approximation.

    Parameters:
        c (ndarray): coefficients of the Legendre polynomial expansion
        n_samples (int): number of samples to generate for the approximation

    Returns:
        approx (ndarray): reconstructed function values
    """
    t = np.linspace(-1, 1, n_samples)   # Legendre polynomials are defined on [-1, 1]
    approx = np.zeros_like(t)
    for i, c_i in enumerate(c):
        P_i = eval_legendre(i, t)
        approx += c_i * P_i
    return approx
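NumPy can also evaluate this sum directly: `numpy.polynomial.legendre.legval(x, c)` computes $\sum_i c_i P_i(x)$. The following sketch (`reconstruct_function_legval` is a name chosen here, not part of the original code) should give the same values as the loop in `reconstruct_function` up to floating-point rounding:

```python
import numpy as np
from numpy.polynomial.legendre import legval
from scipy.special import eval_legendre

def reconstruct_function_legval(c, n_samples: int) -> np.ndarray:
    # Evaluate sum_i c[i] * P_i(t) on the same grid over [-1, 1] as reconstruct_function
    t = np.linspace(-1, 1, n_samples)
    return legval(t, np.asarray(c, dtype=float))

c_demo = [0.1, 0.2, 0.7, -0.3, -0.1]
t_demo = np.linspace(-1, 1, 1000)
loop_version = sum(c_i * eval_legendre(i, t_demo) for i, c_i in enumerate(c_demo))
print(np.abs(reconstruct_function_legval(c_demo, 1000) - loop_version).max())
```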

In [None]:
# Plot the first 5 Legendre polynomials and a function reconstructed from hand-picked coefficients
c = [0.1, 0.2, 0.7, -0.3, -0.1]
n_samples = 1000
t = np.linspace(-1, 1, n_samples)

# Plot the Legendre polynomials
plt.figure(figsize=(12, 6))
for i in range(len(c)):
    P_i = eval_legendre(i, t)
    plt.plot(t, P_i, label=f'$P_{i}(t)$')

# Plot the reconstruction
approx = reconstruct_function(c, n_samples)
plt.plot(t, approx, label=r'$\sum c_i \cdot P_i(t)$', linestyle='--', color='black')

plt.title(f'First {len(c)} Legendre Polynomials and Function Reconstruction')
plt.xlabel('t')
plt.ylabel('Signal Value')
plt.legend(loc='lower right')

## Synthetic Data Generation
Helper function to generate a band-limited signal.  
Copied from [https://github.com/state-spaces/s4/blob/main/notebooks/hippo_function_approximation.ipynb](https://github.com/state-spaces/s4/blob/main/notebooks/hippo_function_approximation.ipynb).

In [None]:
def whitesignal(period, dt, freq, rms=0.5):
    """
    Produces output signal of length period / dt, band-limited to frequency freq\\
    Adapted from the nengo library
    """

    if freq is not None and freq < 1. / period:
        raise ValueError(f"Make ``{freq=} >= 1. / {period=}`` to produce a non-zero signal",)

    nyquist_cutoff = 0.5 / dt
    if freq > nyquist_cutoff:
        raise ValueError(f"{freq} must not exceed the Nyquist frequency for the given dt ({nyquist_cutoff:0.3f})")

    n_coefficients = int(np.ceil(period / dt / 2.))
    shape = (n_coefficients + 1,)
    sigma = rms * np.sqrt(0.5)
    coefficients = 1j * np.random.normal(0., sigma, size=shape)
    coefficients[..., -1] = 0.
    coefficients += np.random.normal(0., sigma, size=shape)
    coefficients[..., 0] = 0.

    set_to_zero = np.fft.rfftfreq(2 * n_coefficients, d=dt) > freq
    coefficients *= (1-set_to_zero)
    power_correction = np.sqrt(1. - np.sum(set_to_zero, dtype=float) / n_coefficients)
    if power_correction > 0.: coefficients /= power_correction
    coefficients *= np.sqrt(2 * n_coefficients)
    signal = np.fft.irfft(coefficients, axis=-1)
    signal = signal - signal[..., :1]  # Start from 0
    return signal
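The core idea behind `whitesignal` is to draw random Fourier coefficients, zero every frequency bin above `freq` (and the DC bin), and transform back to the time domain; the power correction and the shift to start at 0 are refinements on top. The following simplified sketch implements only that core idea, using the FFT of Gaussian white noise as the random coefficients, and checks that the spectrum is empty above `freq`. `bandlimited_noise` is a hypothetical helper, not part of nengo or the S4 repository.

```python
import numpy as np

def bandlimited_noise(period: float, dt: float, freq: float, rms: float = 0.5, rng=None) -> np.ndarray:
    # Simplified sketch: no power correction and no shift to start at 0, unlike whitesignal
    rng = np.random.default_rng() if rng is None else rng
    n = int(round(period / dt))
    spectrum = np.fft.rfft(rng.normal(size=n))       # Fourier coefficients of Gaussian white noise
    freqs = np.fft.rfftfreq(n, d=dt)                 # frequency of each bin (cycles per unit time)
    spectrum[freqs > freq] = 0.0                     # remove everything above the band limit
    spectrum[0] = 0.0                                # remove the DC component (zero mean)
    signal = np.fft.irfft(spectrum, n=n)
    return signal * rms / np.sqrt(np.mean(signal**2))   # rescale to the requested RMS

x = bandlimited_noise(period=2.0, dt=1e-3, freq=3.0, rng=np.random.default_rng(0))
X = np.abs(np.fft.rfft(x))
freqs = np.fft.rfftfreq(len(x), d=1e-3)
print(len(x), np.sqrt(np.mean(x**2)), X[freqs > 3.0].max() / X.max())
```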

## Approximating a signal using HiPPO
The code below approximates a signal $f(t)$ by $g(t)$: it projects $f$ onto a basis of Legendre polynomials and then reconstructs it from the coefficients $c$.

Note that we use a relatively small number of Legendre polynomials to make the difference between $f(t)$ and $g(t)$ visible in the plot. Using more polynomials yields a much closer approximation.

In [None]:
# Plotting the original and approximated signal

# Parameters
T=2         # Time period
dt=1e-3     # Time step
N=20        # Number of Legendre polynomials (Choose N=32 for near perfect approximation!)
freq=3.0    # Band limit (maximum frequency) of the signal
np.random.seed(42)  # For reproducibility

# Original signal
t = np.arange(0.0, T, dt)
f_t = whitesignal(T, dt, freq=freq)

# Approximated signal
A, B = hippo_legs(N)
c = project_function(f_t, A, B)
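# The coefficients c refer to the orthonormal basis sqrt(2n+1) * P_n (plain Legendre polynomials are not
# orthonormal); since B_n = sqrt(2n+1), B * c gives the coefficients of the plain P_n used by reconstruct_function
c_scaled = B * c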
g_t = reconstruct_function(c_scaled, len(t))

# Plotting the results
plt.figure(figsize=(16, 6))
plt.plot(t, f_t, color='blue', label='Input Signal $f(t)$')
plt.plot(t, g_t, color='orange', linestyle='dashed', label='HiPPO Approximation $g(t)$')
plt.fill_between(t, f_t, g_t, color='red', alpha=0.3, label='Error')

plt.title('HiPPO-LegS Signal Approximation')
plt.xlabel('t')
plt.ylabel('Signal Value')
plt.legend(loc='upper right')
plt.show()
plt.close()
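To put a number on how the approximation improves with more polynomials, the sketch below repeats the pipeline of this notebook in compact form and reports the root-mean-square error between $f$ and $g$ for several values of $N$. The helper names are our own, and the signal is a fixed sine wave instead of `whitesignal`, so the result does not depend on the random seed; each step applies the discretized update $c_k = (I-\frac{1}{2k}A)^{-1}\left[(I+\frac{1}{2k}A)c_{k-1} + \frac{1}{k}Bf_k\right]$ with a single linear solve.

```python
import numpy as np
from scipy.special import eval_legendre

def legs_matrices(N: int) -> tuple[np.ndarray, np.ndarray]:
    # Negated HiPPO-LegS A and B, as in hippo_legs
    q = np.sqrt(2 * np.arange(N) + 1.0)
    A = -(np.tril(np.outer(q, q), k=-1) + np.diag(np.arange(N) + 1.0))
    return A, q.copy()

def legs_project(f: np.ndarray, A: np.ndarray, B: np.ndarray) -> np.ndarray:
    # Time-varying bilinear recurrence: discretize A/k and B/k with dt = 1 at step k
    I = np.eye(A.shape[0])
    c = np.zeros(A.shape[0])
    for k, f_k in enumerate(f, start=1):
        c = np.linalg.solve(I - 0.5 * A / k, (I + 0.5 * A / k) @ c + (B / k) * f_k)
    return c

def legs_reconstruct(c: np.ndarray, n_samples: int) -> np.ndarray:
    # g = sum_n c_n * sqrt(2n+1) * P_n, evaluated on [-1, 1] (mapped onto the window [0, T])
    z = np.linspace(-1, 1, n_samples)
    scale = np.sqrt(2 * np.arange(len(c)) + 1.0)
    return sum(s * c_n * eval_legendre(n, z) for n, (s, c_n) in enumerate(zip(scale, c)))

T_demo, dt_demo = 2.0, 1e-3
t_demo = np.arange(0.0, T_demo, dt_demo)
f_demo = np.sin(2 * np.pi * t_demo)   # smooth test signal with two periods on [0, T]

rms_errors = {}
for N_demo in [4, 8, 16, 32]:
    A_n, B_n = legs_matrices(N_demo)
    g_demo = legs_reconstruct(legs_project(f_demo, A_n, B_n), len(t_demo))
    rms_errors[N_demo] = np.sqrt(np.mean((f_demo - g_demo) ** 2))
    print(f"N={N_demo:2d}: RMS error {rms_errors[N_demo]:.2e}")
```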