# Accuracy of truncated basis expansions

In this notebook, we will investigate the accuracy of truncated basis expansions for scalar functions $[-1, 1] \to \R$.

In [None]:
import numpy as np
import seaborn as sns; sns.set()
import matplotlib.pyplot as plt
from typing import Callable

## Basis functions

We will first define several classes of basis functions.

### Fourier basis

We first look at perhaps the most well-known set of basis functions, the *Fourier basis*, consisting of sines and cosines
$$
    \begin{aligned}
        w_0(x) &= 1 \\
        w_1(x) &= \sin(\pi x) \\
        w_2(x) &= \cos(\pi x) \\
        w_3(x) &= \sin(2\pi x) \\
        w_4(x) &= \cos(2\pi x) \\
        &\vdots
    \end{aligned}
$$

These are orthogonal with respect to the inner product
$$
    (w_i, w_j) := \int_{-1}^{1} w_i(x) w_j(x) dx = \text{constant} \times \delta_{ij}.
$$
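
The orthogonality relation above can be checked numerically. The following is a minimal sketch, assuming Gauss-Legendre quadrature with 64 nodes (an arbitrary but ample choice) to approximate the integrals; the names `nodes`, `weights`, `features` and `gram` are introduced here for illustration only. For the functions listed above, the constant should come out as $2$ for $w_0$ and $1$ for the sines and cosines.

```python
import numpy as np

# Gauss-Legendre nodes and weights on [-1, 1] (64 nodes is an assumed, ample choice)
nodes, weights = np.polynomial.legendre.leggauss(64)

# Evaluate w_0, ..., w_4 at the quadrature nodes (one column per basis function)
features = np.column_stack([
    np.ones_like(nodes),        # w_0(x) = 1
    np.sin(np.pi * nodes),      # w_1(x) = sin(pi x)
    np.cos(np.pi * nodes),      # w_2(x) = cos(pi x)
    np.sin(2 * np.pi * nodes),  # w_3(x) = sin(2 pi x)
    np.cos(2 * np.pi * nodes),  # w_4(x) = cos(2 pi x)
])

# gram[i, j] approximates the inner product (w_i, w_j)
gram = features.T @ (weights[:, None] * features)
print(np.round(gram, 6))
```

The off-diagonal entries vanish up to rounding error, which is exactly the statement $(w_i, w_j) \propto \delta_{ij}$.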


In [None]:
def fourier_basis(x: np.ndarray, num_components: int) -> np.ndarray:
    """
    Generate a Fourier basis with the specified number of components.

    Parameters:
    - x: Input array of shape (num_samples,) containing the values at which to evaluate the basis functions.
    - num_components: Number of Fourier components to generate.

    Returns:
    - basis: Fourier basis array of shape (num_samples, num_components).
    """
    num_samples = len(x)
    basis = np.zeros((num_samples, num_components))

    for i in range(num_components):
        freq = ((i + 1) // 2) * np.pi  # Calculate the frequency of the current component
        basis[:, i] = np.cos(freq * x) if i % 2 == 0 else np.sin(freq * x)  # Evaluate the basis function

    return basis

Now, let us visualise the first 4 basis components...

In [None]:
num_components = 4
num_samples = 128
x = np.linspace(-1, 1, num_samples)  # uniformly spaced samples


In [None]:
def plot_features(inputs: np.ndarray, basis: Callable[[np.ndarray, int], np.ndarray], num_components: int) -> None:
    """
    Plot the features of a basis expansion.

    Parameters:
    - inputs: Input array of shape (num_samples,) containing the x-values.
    - basis: Function mapping (inputs, num_components) to a feature array of shape (num_samples, num_components).
    - num_components: Number of basis components.

    Returns:
    - None
    """

    # Evaluate the first num_components basis functions at the inputs
    features = basis(inputs, num_components)

    # Create subplots with the specified number of rows and one column
    fig, ax = plt.subplots(num_components, 1, figsize=(1.5*num_components, 7), layout="constrained")

    # Iterate over each feature and its corresponding axis
    for i, (f, a) in enumerate(zip(features.T, ax)):
        # Plot the feature values against the inputs
        a.plot(inputs, f, label=f'{basis.__name__} {i}')
        a.legend(loc='upper right')

In [None]:
plot_features(inputs=x, basis=fourier_basis, num_components=num_components)

### Legendre basis

Next, we consider an alternative basis consisting of orthogonal polynomials called *Legendre polynomials*.

Let $P_k$ ($k\geq 0$) denote the $k$-th Legendre polynomial.

The simplest way to define the Legendre polynomials is iteratively, by requiring the following conditions
- $P_k$ is a polynomial of degree $k$
- $P_k(1)=1$ (standardisation)
- $\int_{-1}^{1} P_i(x) P_j(x) dx = \text{Constant} \times \delta_{ij}$ (orthogonality)

Let us see some examples
1. $k=0$: $P_0(x) = 1$ is the only possible choice
2. $k=1$: $P_1(x) = ax + b$ satisfies the first condition, the second condition forces $a+b = 1$, and the last condition (orthogonality to $P_0$) forces $b=0$, thus $P_1(x)=x$
3. $k=2$: Using a similar argument, this time requiring orthogonality to both $P_0$ and $P_1$, $P_2(x) = (3x^2 - 1)/2$
4. $k=3$: $P_3 = (5x^3 - 3x)/2$
5. ...
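
The iterative construction above can be sketched in code. The snippet below is a minimal illustration, assuming a Gram-Schmidt procedure applied to the monomials $1, x, x^2, \dots$ followed by the standardisation $P_k(1)=1$; the helper names `inner` and `legendre_by_gram_schmidt` are hypothetical and not part of NumPy. It uses `numpy.polynomial.Polynomial`, whose coefficients are stored from lowest to highest degree.

```python
import numpy as np
from numpy.polynomial import Polynomial

def inner(p: Polynomial, q: Polynomial) -> float:
    """Inner product (p, q) = integral of p(x) q(x) over [-1, 1], computed exactly."""
    antiderivative = (p * q).integ()
    return float(antiderivative(1.0) - antiderivative(-1.0))

def legendre_by_gram_schmidt(max_degree: int) -> list:
    """Build P_0, ..., P_max_degree from the three defining conditions."""
    polys = []
    for k in range(max_degree + 1):
        p = Polynomial([0.0] * k + [1.0])  # start from the monomial x^k
        for q in polys:
            # Remove the component of p along each lower-degree polynomial
            p = p - q * (inner(p, q) / inner(q, q))
        polys.append(p / float(p(1.0)))    # standardise so that P_k(1) = 1
    return polys

polys = legendre_by_gram_schmidt(3)
for k, p in enumerate(polys):
    print(k, np.round(p.coef, 6))  # coefficients of 1, x, x^2, ...
```

The printed coefficients for $k=2$ and $k=3$ should match $(3x^2-1)/2$ and $(5x^3-3x)/2$ from the list above.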

There are many applications of Legendre polynomials
- Originally studied by [Adrien-Marie Legendre](https://en.wikipedia.org/wiki/Adrien-Marie_Legendre) in electrostatics and the universal law of gravitation
- Multipole expansions for electrostatic analysis
- Improving the training of sequence-to-sequence state-space models for natural language and related applications (https://proceedings.neurips.cc/paper/2019/file/952285b9b7e7a1be5aa7849f32ffff05-Paper.pdf)

To generate Legendre polynomials, we can simply call the ready-made routine `np.polynomial.legendre.legval`.

In [None]:
def legendre_basis(x: np.ndarray, num_components: int) -> np.ndarray:
    """
    Generate a Legendre basis with the specified number of components.

    Parameters:
    - x: Input array of shape (num_samples,) containing the values at which to evaluate the basis functions.
    - num_components: Number of Legendre components to generate.

    Returns:
    - basis: Legendre basis array of shape (num_samples, num_components).
    """
    num_samples = len(x)
    basis = np.zeros((num_samples, num_components))

    for i in range(num_components):
        # Calculate the Legendre polynomial coefficients for the current component
        coefficients = [0] * i + [1]
        basis[:, i] = np.polynomial.legendre.legval(x, coefficients)  # Evaluate the basis function

    return basis

Let us plot the first 4 Legendre polynomials. Observe that they look quite different from the Fourier basis functions.

In [None]:
plot_features(inputs=x, basis=legendre_basis, num_components=num_components)

### Radial basis functions

Now, we consider a very different type of basis functions, the *radial basis functions*.

These have the form
$$
    w_k(x) = \exp(-\lambda (x - c_k)^2)
$$
where the radial centers $c_k$ are chosen to be "distributed" more or less evenly in the domain of interest $[-1,1]$.

Notice that in general, these are not orthogonal in $L^2([-1,1])$!
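
To make the lack of orthogonality concrete, we can compute the inner products between radial basis functions numerically. This is a minimal sketch, assuming Gauss-Legendre quadrature with 64 nodes and equally spaced centers in $[-1,1]$; the helper `rbf_gram` and its parameter `lam` (standing for $\lambda$) are hypothetical names introduced for illustration.

```python
import numpy as np

def rbf_gram(lam: float, num_components: int = 4) -> np.ndarray:
    """Approximate Gram matrix of Gaussian radial basis functions on [-1, 1]."""
    nodes, weights = np.polynomial.legendre.leggauss(64)  # assumed quadrature rule
    centers = np.linspace(-1, 1, num_components)          # assumed equally spaced centers
    # features[i, k] = exp(-lam * (x_i - c_k)^2)
    features = np.exp(-lam * (nodes[:, None] - centers[None, :]) ** 2)
    return features.T @ (weights[:, None] * features)

for lam in (5.0, 50.0):
    gram = rbf_gram(lam)
    print(f"lambda = {lam}: inner product of the first two functions = {gram[0, 1]:.6f}")
```

Since every radial basis function is strictly positive, all inner products are strictly positive and can never vanish. A larger $\lambda$ gives narrower bumps and hence smaller overlaps, but not exact orthogonality.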

In [None]:
def radial_basis(x: np.ndarray, num_components: int) -> np.ndarray:
    """
    Generate a radial basis with the specified number of components.

    Parameters:
    - x: Input array of shape (num_samples,) containing the values at which to evaluate the basis functions.
    - num_components: Number of radial components to generate.

    Returns:
    - basis: Radial basis array of shape (num_samples, num_components).
    """
    num_samples = len(x)
    basis = np.zeros((num_samples, num_components))
    centers = np.linspace(-1, 1, num_components)

    for i, c in enumerate(centers):
        # Calculate the radial basis function for the current component
        basis[:, i] = np.exp(-5 * (x - c) ** 2)  # Evaluate the basis function

    return basis

Let us look at the first 4 radial basis functions for $\lambda = 5$, with the centers $\{c_k\}$ equally spaced in $[-1,1]$.
Other choices are of course possible!

In [None]:
plot_features(inputs=x, basis=radial_basis, num_components=num_components)

## Error of truncated expansion

Now, given a target function $f:[-1,1]\to\R$, we consider the truncated approximation
$$
    \hat{f}_n(x) = \sum_{k=0}^{n-1} a_k w_k(x)
$$
where $w_k(x)$ are the various basis elements defined above (indexed from $k=0$, so that $n$ components are used), and the coefficients $a_k$ are fitted via ordinary least squares.
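
As a rough illustration of the least-squares fit, the sketch below fits Legendre coefficients to the illustrative target $f(x) = x$ and reports the $R^2$ score for an increasing number of components. It assumes `np.linalg.lstsq` without a separate intercept (rather than the scikit-learn routine used later in this notebook) and builds the features with `np.polynomial.legendre.legvander`; the helper `r_squared` is a hypothetical name introduced here.

```python
import numpy as np

x = np.linspace(-1, 1, 128)  # uniformly spaced sample points
y = x                        # illustrative target f(x) = x

def r_squared(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """Coefficient of determination: 1 - (residual sum of squares) / (total sum of squares)."""
    ss_res = np.sum((y_true - y_pred) ** 2)
    ss_tot = np.sum((y_true - np.mean(y_true)) ** 2)
    return float(1.0 - ss_res / ss_tot)

scores = []
for n in range(1, 4):
    # Columns of the design matrix are P_0, ..., P_{n-1} evaluated at x
    design = np.polynomial.legendre.legvander(x, n - 1)
    # Ordinary least squares: minimise ||design @ coeffs - y||^2
    coeffs, *_ = np.linalg.lstsq(design, y, rcond=None)
    scores.append(r_squared(y, design @ coeffs))
    print(f"n = {n}: coefficients = {np.round(coeffs, 6)}, R^2 = {scores[-1]:.6f}")
```

With a single component only the constant $P_0$ is available, so the fit reduces to the sample mean; once $P_1(x) = x$ is included, the target is represented exactly.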

### Identity function

Let us first look at the target function $f(x) = x$.

In [None]:
def identity(x: np.ndarray) -> np.ndarray:
    return x

plt.plot(x, identity(x))
plt.xlabel(r"$x$")
plt.ylabel(r"$f(x)$")

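We use scikit-learn's `LinearRegression` routine to fit the coefficients $\{a_k\}$. Note that, by default, it also fits an intercept term in addition to the basis coefficients.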

In [None]:
from typing import Callable

import numpy as np
from sklearn.linear_model import LinearRegression

def fit(func: Callable[[np.ndarray], np.ndarray], x: np.ndarray, basis: Callable[[np.ndarray, int], np.ndarray], num_components: int) -> tuple:
    """
    Fit the coefficients of the truncated expansion using linear regression.

    Parameters:
    - func: Target function to approximate.
    - x: Input array of shape (num_samples,) containing the values at which to evaluate the basis functions.
    - basis: Basis function to use for the expansion.
    - num_components: Maximum number of basis components to use.

    Returns:
    - scores: List of R^2 scores for each number of components.
    - predictions: List of predicted values for each number of components.
    """

    scores, predictions = [], []

    # Iterate over each number of components
    for N in range(1, num_components+1):
        basis_x = basis(x, N)  # Generate the basis features
        reg = LinearRegression().fit(basis_x, func(x))  # Fit the coefficients
        score = reg.score(basis_x, func(x))  # Calculate the R^2 score
        prediction = reg.predict(basis_x)  # Make predictions
        scores.append(score)
        predictions.append(prediction)

    return scores, predictions

In [None]:
def analyse(func: Callable[[np.ndarray], np.ndarray], basis: Callable[[np.ndarray, int], np.ndarray], inputs: np.ndarray, num_components: int) -> None:
    """
    Analyze the accuracy of truncated basis expansions.

    Parameters:
    - func: Target function to approximate.
    - basis: Basis function to use for the expansion.
    - inputs: Input array of shape (num_samples,) containing the values at which to evaluate the basis functions.
    - num_components: Maximum number of basis components to use.

    Returns:
    - None
    """

    # Fit the coefficients and calculate the R^2 scores
    scores, predictions = fit(func=func, x=inputs, basis=basis, num_components=num_components)

    # Create a mosaic plot layout
    mosaic = np.asarray([['a)', 'e)'], ['b)', 'e)'], ['c)', 'e)'], ['d)', 'e)']]).T

    # Create a figure and subplots
    fig, ax = plt.subplot_mosaic(mosaic, layout='constrained', figsize=(15, 6))
    fig.suptitle(f'Target: {func.__name__}, Basis: {basis.__name__}')

    # Plot the R^2 scores
    ax['e)'].plot(range(1, num_components+1), scores, '-o')
    ax['e)'].set_ylabel(r'$R^2$')
    ax['e)'].set_xlabel(r'$N$')
    ax['e)'].set_ylim(0, 1.1)

    # Plot the target function and the prediction for one value of N per panel
    # (the mosaic has four panels, so at most the first four fits are shown)
    for n, (label, prediction) in enumerate(zip(['a)', 'b)', 'c)', 'd)'], predictions)):
        ax[label].set_ylabel(r'$f(x)$')
        ax[label].set_xlabel(r'$x$')
        ax[label].plot(inputs, func(inputs), '--', label='True')
        ax[label].plot(inputs, prediction, label=f'N={n+1}')
        ax[label].legend()

In [None]:
analyse(func=identity, basis=fourier_basis, inputs=x, num_components=num_components)
analyse(func=identity, basis=legendre_basis, inputs=x, num_components=num_components)
analyse(func=identity, basis=radial_basis, inputs=x, num_components=num_components)

Notice that **different basis sets have different approximation qualities at the same number of components**!

### Square wave

As a second example, we consider a target in the form of a square wave.

In [None]:
import matplotlib.pyplot as plt

def square_wave(x: np.ndarray) -> np.ndarray:
    return 1.0 * (np.sin(2 * np.pi * x) > 0)

plt.plot(x, square_wave(x))
plt.xlabel(r"$x$")
plt.ylabel(r"$f(x)$")

Let's compare the truncated approximation qualities.

In [None]:
analyse(func=square_wave, basis=fourier_basis, inputs=x, num_components=num_components)
analyse(func=square_wave, basis=legendre_basis, inputs=x, num_components=num_components)
analyse(func=square_wave, basis=radial_basis, inputs=x, num_components=num_components)

Notice that the approximation qualities are not only basis-dependent, **they are also target-dependent**!

### Rational function

We confirm these points by considering a third target in the form of a rational function, $f(x) = (x - 1)/(1 + x^2)$.

In [None]:
def rational_function(x: np.ndarray) -> np.ndarray:
    return (x - 1) / (1 + x**2)

plt.plot(x, rational_function(x))
plt.xlabel(r"$x$")
plt.ylabel(r"$f(x)$")

In [None]:
analyse(func=rational_function, basis=fourier_basis, inputs=x, num_components=num_components)
analyse(func=rational_function, basis=legendre_basis, inputs=x, num_components=num_components)
analyse(func=rational_function, basis=radial_basis, inputs=x, num_components=num_components)

## Summary

We see here that
- Different basis functions give different approximation qualities
- Which basis is better depends on which target we are considering

This raises a natural question: how does one obtain the "best" possible basis?