# Neural Operators: Part I

In this example, we look at functional data to see what kinds of problems we can model with it. The components we will cover are:

* Interpolation
* Forecasting
* Mapping

In [None]:
import sys, os
from pyprojroot import here


# walk up the directory tree to find the project root (marked by a ".local" file)

root = here(project_files=[".local"])
# append to path
sys.path.append(str(root))

%load_ext autoreload
%autoreload 2

In [None]:
import jax
import jax.numpy as jnp
import jax.random as jrandom

import matplotlib.pyplot as plt
import seaborn as sns

sns.set_context(context="poster", font_scale=0.8)

%matplotlib inline

Let's assume we have a function $s:\mathbb{R}\rightarrow \mathbb{R}$ that maps a scalar value for time, $t$, to another scalar value. A concrete example is temperature: the "real" function is a continuous representation of temperature over time.

As a simple example, let's define the function as:

$$
s(t):=\sin(\pi t)
$$

We can define a time vector, $\boldsymbol{t}\in\mathbb{R}^{D_t}$, as a set of $D_t$ discrete time inputs ranging from $-\pi/2$ to $\pi/2$ (below, $D_t=100$).

In [None]:
# define the function
s_fn = lambda t: jnp.sin(jnp.pi * t)

# define the time vector
t_vector = jnp.linspace(-jnp.pi / 2.0, jnp.pi / 2.0, 100)

# get s vector.
s_vector = jax.vmap(s_fn)(t_vector)

Let's do a demo plot of this.

In [None]:
# demo plot of function
fig, ax = plt.subplots(figsize=(8, 5))
ax.plot(t_vector, s_vector, color="red", label="True Temperatures", zorder=1)
ax.set(xlabel="$t$", ylabel="$s(t)$")

plt.tight_layout()
plt.legend(fontsize=14)
plt.show()

## Discrete Observations

This plot shows a nice continuous curve, but in practice life doesn't work like that. We typically measure discrete signals that represent the continuous form of the true function. So in this case, let's sample the function at 25 observation points and 15 query points, and plot the values at the query points.

In [None]:
num_obs = 25
num_query = 15

t_data, s_data = {}, {}

t_data["obs"] = jnp.linspace(-0.95 * jnp.pi / 2.0, 0.95 * jnp.pi / 2.0, num_obs)
t_data["query"] = jnp.linspace(-0.95 * jnp.pi / 2.0, 0.95 * jnp.pi / 2.0, num_query)

# get s vector.
s_data["obs"] = jax.vmap(s_fn)(t_data["obs"])
s_data["query"] = jax.vmap(s_fn)(t_data["query"])

In [None]:
# demo plot of function
fig, ax = plt.subplots(figsize=(8, 5))
ax.plot(t_vector, s_vector, color="red", label="Continuous Signal", zorder=1)
# ax.stem(t_data["obs"], s_data["obs"], label="Observed Signal", linefmt="black", markerfmt=".k", basefmt="black")
ax.stem(
    t_data["query"],
    s_data["query"],
    label="Query Signal",
    linefmt="c",
    markerfmt=".c",
    basefmt="c",
)

ax.set(xlabel="$t$", ylabel="$s(t)$")

plt.tight_layout()
plt.legend(fontsize=14)
plt.show()

We immediately see that the signal we actually measure is discrete, even though it lives on a continuous domain. In other words, a vector of temperature observations, $\boldsymbol{s}\in\mathbb{R}^{D_t}$, is really a functional dataset: each temperature value corresponds to a point in time.
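Concretely, in the notation used here, the observation vector is just the function $s$ evaluated at each entry of the time vector $\boldsymbol{t}$:

$$
\boldsymbol{t} = \left[t_1, t_2, \ldots, t_{D_t}\right]^\top, \qquad
\boldsymbol{s} = \left[s(t_1), s(t_2), \ldots, s(t_{D_t})\right]^\top \in \mathbb{R}^{D_t}
$$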

More concisely, the time component, $t$, lives in a continuous space $\mathcal{A}$, and the temperature values live in another continuous space, $\mathcal{B}$. The set of all continuous maps from $\mathcal{A}$ to $\mathcal{B}$ is itself a space, which we write as:

$$
C(\mathcal{A},\mathcal{B})
$$

So, in our case, we don't really observe temperature in a vacuum. What we really observe is a function, $s$, that maps the space of time, $\mathcal{T}$, to the space of temperature, $\mathcal{S}$; that is, $s$ is an element of

$$
C(\mathcal{T},\mathcal{S})
$$

where $t\in\mathcal{T}$ and $s(t)\in\mathcal{S}$. In practical terms, we have temperature values together with metadata attached to them, namely time.
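Because $s$ lives on a continuous domain, we can ask for its value at times we never observed. Below is a minimal sketch, assuming piecewise-linear interpolation (`jnp.interp`) as the reconstruction method; it is a simple baseline, not the model this notebook builds toward. The names `t_obs`, `s_obs`, and `t_query` mirror the observation/query split used earlier but are defined locally so the block runs on its own.

```python
import jax.numpy as jnp

# Same toy function as above: s(t) = sin(pi * t)
s_fn = lambda t: jnp.sin(jnp.pi * t)

# Observation grid: the functional dataset is the pair (t_obs, s_obs)
t_obs = jnp.linspace(-0.95 * jnp.pi / 2.0, 0.95 * jnp.pi / 2.0, 25)
s_obs = s_fn(t_obs)

# Query times, mostly not on the observation grid
t_query = jnp.linspace(-0.95 * jnp.pi / 2.0, 0.95 * jnp.pi / 2.0, 15)

# Piecewise-linear interpolation evaluates the discrete data at new times
s_query_interp = jnp.interp(t_query, t_obs, s_obs)

# Compare against the true function values at the query times
max_abs_err = jnp.max(jnp.abs(s_query_interp - s_fn(t_query)))
print(f"max |interp - true| at query points: {float(max_abs_err):.4f}")
```

With 25 observations the spacing is small relative to the period of the sine, so the linear reconstruction stays close to the true curve; fewer observations would increase the error.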

## Real World Data

In the real world, the temperature values we observe are never completely continuous; capturing and recording all of that information would be prohibitively expensive. Instead, we interact with the world through observations in the form of discrete signals. Below are four aspects of real-world data that can cause serious problems; the three checked items are discussed in this notebook:

* [ ] Representation - Discrete, Continuous, Precision
* [x] Discretization - Coarse, Fine, Adaptive
* [x] Sparsity, Irregularity - Interpolation, Graphs
* [x] Noise - Filtering, Smoothing, Modeling

### Discretized Signals

There is an adequate level of discretization that is sufficient to capture the essence of a signal. A discretization that is too coarse will miss high-frequency content. A discretization that is too fine wastes computation and storage because the extra samples carry redundant information.

Note: the relevant concept here is the Nyquist-Shannon sampling theorem. It states that a band-limited signal can be reconstructed exactly if it is sampled at more than twice its highest frequency (the Nyquist rate); sampling below this rate causes aliasing.
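As a quick worked check, here is a sketch that applies the sampling criterion to the three grid sizes used below (5, 25, and 100 points). It assumes the toy signal $s(t)=\sin(\pi t)$, whose frequency is $\pi/(2\pi)=0.5$ cycles per unit of $t$, so the sample spacing must stay below $1/(2 \cdot 0.5) = 1$. The names `f_max`, `max_spacing`, and `spacings` are illustrative.

```python
import jax.numpy as jnp

# s(t) = sin(pi * t): angular frequency pi, i.e. f = pi / (2 * pi) = 0.5 cycles per unit of t
f_max = 0.5
# Sampling criterion: spacing must be strictly below 1 / (2 * f_max)
max_spacing = 1.0 / (2.0 * f_max)

t_min, t_max = -0.95 * jnp.pi / 2.0, 0.95 * jnp.pi / 2.0
spacings = {}
for name, n in {"sparse": 5, "medium": 25, "dense": 100}.items():
    t = jnp.linspace(t_min, t_max, n)
    spacings[name] = float(t[1] - t[0])  # uniform grid, so one gap suffices
    ok = spacings[name] < max_spacing
    print(f"{name:>6}: dt = {spacings[name]:.4f}, below limit {max_spacing:.1f}: {ok}")
```

Under this idealized criterion even the 5-point grid is (just) sufficient, which may look surprising in the plots. The theorem assumes ideal sinc reconstruction of a band-limited signal on an unbounded domain; a short finite window and simpler reconstruction schemes generally need more samples.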

Here, we have an example of how we can create some fake data using different discretizations.

In [None]:
t_discrete, s_discrete = {}, {}

t_discrete["sparse"] = jnp.linspace(-0.95 * jnp.pi / 2.0, 0.95 * jnp.pi / 2.0, 5)
t_discrete["medium"] = jnp.linspace(-0.95 * jnp.pi / 2.0, 0.95 * jnp.pi / 2.0, 25)
t_discrete["dense"] = jnp.linspace(-0.95 * jnp.pi / 2.0, 0.95 * jnp.pi / 2.0, 100)

# get s vector.
s_discrete["sparse"] = jax.vmap(s_fn)(t_discrete["sparse"])
s_discrete["medium"] = jax.vmap(s_fn)(t_discrete["medium"])
s_discrete["dense"] = jax.vmap(s_fn)(t_discrete["dense"])

Plotting each discretization against the continuous signal lets us judge intuitively which ones are adequate.

In [None]:
for iname, i_time in t_discrete.items():
    i_s = s_discrete[iname]
    # demo plot of function
    fig, ax = plt.subplots(figsize=(8, 5))
    ax.plot(t_vector, s_vector, color="red", label="Continuous Signal", zorder=1)
    ax.stem(
        i_time,
        i_s,
        label="Discrete Signal",
        linefmt="black",
        markerfmt=".k",
        basefmt="black",
    )

    ax.set(xlabel="$t$", ylabel="$s(t)$", title=iname)

    plt.tight_layout()
    plt.legend(fontsize=14)
    plt.show()
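To make the visual judgment quantitative, one option is to reconstruct the signal from each set of samples and measure the worst-case error against the true function on a fine reference grid. This sketch assumes piecewise-linear interpolation (`jnp.interp`) as the reconstruction; `t_ref` and `errors` are illustrative names.

```python
import jax.numpy as jnp

s_fn = lambda t: jnp.sin(jnp.pi * t)
t_min, t_max = -0.95 * jnp.pi / 2.0, 0.95 * jnp.pi / 2.0

# Fine reference grid on which the reconstruction is compared to the truth
t_ref = jnp.linspace(t_min, t_max, 1_000)
s_ref = s_fn(t_ref)

errors = {}
for name, n in {"sparse": 5, "medium": 25, "dense": 100}.items():
    t = jnp.linspace(t_min, t_max, n)
    # Reconstruct the continuous signal from its samples (piecewise-linear)
    s_rec = jnp.interp(t_ref, t, s_fn(t))
    errors[name] = float(jnp.max(jnp.abs(s_rec - s_ref)))
    print(f"{name:>6}: max reconstruction error = {errors[name]:.4f}")
```

Piecewise-linear reconstruction is cruder than the ideal reconstruction assumed by the sampling theorem, so it needs denser samples: the 5-point grid visibly cuts off the peaks of the sine.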

### Sparse, Irregular Sampling

In [None]:
t_irregular, s_irregular = {}, {}

keys = jrandom.PRNGKey(123)

keys, *uniform = jrandom.split(keys, 4)

t_irregular["sparse"] = jrandom.uniform(
    key=uniform[0], shape=(10,), minval=-0.95 * jnp.pi / 2.0, maxval=0.95 * jnp.pi / 2.0
)
t_irregular["medium"] = jrandom.uniform(
    key=uniform[1], shape=(25,), minval=-0.95 * jnp.pi / 2.0, maxval=0.95 * jnp.pi / 2.0
)
t_irregular["dense"] = jrandom.uniform(
    key=uniform[2],
    shape=(100,),
    minval=-0.95 * jnp.pi / 2.0,
    maxval=0.95 * jnp.pi / 2.0,
)

# get s vector.
s_irregular["sparse"] = jax.vmap(s_fn)(t_irregular["sparse"])
s_irregular["medium"] = jax.vmap(s_fn)(t_irregular["medium"])
s_irregular["dense"] = jax.vmap(s_fn)(t_irregular["dense"])

In [None]:
for iname, i_time in t_irregular.items():
    i_s = s_irregular[iname]
    # demo plot of function
    fig, ax = plt.subplots(figsize=(8, 5))
    ax.plot(t_vector, s_vector, color="red", label="Continuous Signal", zorder=1)
    ax.stem(
        i_time,
        i_s,
        label="Irregular Signal",
        linefmt="black",
        markerfmt=".k",
        basefmt="black",
    )

    ax.set(xlabel="$t$", ylabel="$s(t)$", title=iname)

    plt.tight_layout()
    plt.legend(fontsize=14)
    plt.show()
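Irregular samples need a little extra care before they can be used. This sketch (with illustrative names `t_irr`, `gaps`, `t_grid`) shows two practical points: uniform random draws come out unordered, so they must be sorted before computing gaps or interpolating, and with irregular spacing it is the largest gap, not the average one, that limits what can be resolved.

```python
import jax.numpy as jnp
import jax.random as jrandom

s_fn = lambda t: jnp.sin(jnp.pi * t)
t_min, t_max = -0.95 * jnp.pi / 2.0, 0.95 * jnp.pi / 2.0

key = jrandom.PRNGKey(123)
t_irr = jrandom.uniform(key, shape=(25,), minval=t_min, maxval=t_max)

# Sort the random sample times so they are increasing
t_sorted = jnp.sort(t_irr)
s_sorted = s_fn(t_sorted)

# Gap statistics: the maximum gap is the worst-resolved region
gaps = jnp.diff(t_sorted)
print(f"mean gap = {float(gaps.mean()):.4f}, max gap = {float(gaps.max()):.4f}")

# jnp.interp requires increasing sample locations; reconstruct on a regular grid
t_grid = jnp.linspace(t_min, t_max, 200)
s_rec = jnp.interp(t_grid, t_sorted, s_sorted)
```

Outside the range of the samples, `jnp.interp` holds the first and last sample values constant by default, so the reconstruction near the edges is only as good as the nearest observation.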

### Noisy Signal

In [None]:
t_noisy, s_noisy = {}, {}

keys, *normal = jrandom.split(keys, 4)

# same 25-point grid for every case; only the noise standard deviation changes
noise_levels = {"low": 0.01, "medium": 0.1, "high": 0.5}

for (iname, noise_std), ikey in zip(noise_levels.items(), normal):
    t_noisy[iname] = jnp.linspace(-0.95 * jnp.pi / 2.0, 0.95 * jnp.pi / 2.0, 25)
    # get s vector.
    s_noisy[iname] = jax.vmap(s_fn)(t_noisy[iname])
    # add i.i.d. Gaussian noise with standard deviation noise_std
    s_noisy[iname] += noise_std * jrandom.normal(key=ikey, shape=t_noisy[iname].shape)

In [None]:
for iname, i_time in t_noisy.items():
    i_s = s_noisy[iname]
    # demo plot of function
    fig, ax = plt.subplots(figsize=(8, 5))
    ax.plot(t_vector, s_vector, color="red", label="Continuous Signal", zorder=1)
    ax.scatter(i_time, i_s, label="Noisy Observations", zorder=2, color="black")

    ax.set(xlabel="$t$", ylabel="$s(t)$", title=iname)

    plt.tight_layout()
    plt.legend(fontsize=14)
    plt.show()
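The "Filtering, Smoothing" item in the list of real-world problems can be illustrated with about the simplest smoother there is: a centered moving average. This is a minimal sketch, not a method from the original notebook; the helper `moving_average` and its `window` parameter are hypothetical names, and the approach assumes equally spaced samples, which holds for the 25-point grid used here.

```python
import jax.numpy as jnp
import jax.random as jrandom

s_fn = lambda t: jnp.sin(jnp.pi * t)
t = jnp.linspace(-0.95 * jnp.pi / 2.0, 0.95 * jnp.pi / 2.0, 25)
s_true = s_fn(t)

# High-noise observations (noise std 0.5)
key = jrandom.PRNGKey(0)
s_obs = s_true + 0.5 * jrandom.normal(key, shape=t.shape)


def moving_average(x, window=3):
    """Centered moving average; near the edges, divide by the number of points actually used."""
    kernel = jnp.ones(window)
    summed = jnp.convolve(x, kernel, mode="same")
    counts = jnp.convolve(jnp.ones_like(x), kernel, mode="same")
    return summed / counts


s_smooth = moving_average(s_obs, window=5)

rmse = lambda a, b: float(jnp.sqrt(jnp.mean((a - b) ** 2)))
print(f"RMSE noisy:    {rmse(s_obs, s_true):.3f}")
print(f"RMSE smoothed: {rmse(s_smooth, s_true):.3f}")
```

The window size trades variance for bias: a wider window averages away more noise but also flattens the peaks of the sine. For the low-noise case, smoothing may make the fit worse rather than better.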