In [None]:
#| echo: false 
#| output: false
# Notebook setup: enable module auto-reloading and retina-resolution inline figures.
%load_ext autoreload
%autoreload 2
%config InlineBackend.figure_format = 'retina'

import os 
# NOTE(review): presumably switches JAX/XLA to on-demand ("platform") device
# memory allocation instead of preallocation; it is set before any JAX import
# in this file -- confirm against the JAX GPU memory allocation docs.
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"


# Differential Equations

We're going to take a small detour through ordinary differential equations (ODEs).
As should become evident shortly,
this topic is important for our journey through score modelling.

## Ordinary differential equations

Before we go into probability flow ODEs, though, 
we should revisit their more, ahem, _ordinary_ counterpart,
ordinary differential equations (ODEs).
ODEs are usually taught in undergraduate calculus classes,
since they involve differentiation and integration.
I do remember encountering them in high school in Singapore,
which is a testament to how advanced the mathematics curriculum in Singapore is.

ODEs are useful models of systems
where we believe that the rate of change of an output variable
is a math function of some input variable.
In abstract mathematical symbols:

$$\frac{dy}{dx} = f(x, \theta)$$

Here, $f$ simply refers to some mathematical function of $x$
and the function's parameters $\theta$.

### A classic ODE example

A classic ODE example that we might think of is that of a decay curve:

$$\frac{dy}{dt} = -y$$

Implemented in `diffrax`, which is a JAX package for differential equations,
and wrapped in Equinox as a parameterized function,
we have the following code:

In [None]:
from diffrax import diffeqsolve, Tsit5, ODETerm, SaveAt
import jax.numpy as np
from jax import vmap
import equinox as eqx

def exponential_decay_drift(t, y, args):
    """Right-hand side of the decay ODE ``dy/dt = -y``.

    Args:
        t: Current time (unused; the decay rate does not depend on time).
        y: Current state.
        args: Extra arguments (unused).

    Returns:
        The negated state, i.e. the rate of change of ``y``.
    """
    rate_of_change = -y
    return rate_of_change


class ODE(eqx.Module):
    """Callable wrapper that integrates ``dy/dt = drift(t, y, args)`` with diffrax.

    Calling an instance solves the ODE with the ``Tsit5`` solver between the
    first and last entries of ``ts`` and returns the solution evaluated at
    every entry of ``ts``.
    """

    # Vector field passed to ``ODETerm``.
    drift: callable

    def __call__(self, ts: np.ndarray, y0: float):
        """Solve the ODE from ``y0`` and evaluate the solution along ``ts``.

        Args:
            ts: Time grid; its end points bound the integration and the
                spacing of its first two entries is used as the initial step.
            y0: Initial state at ``ts[0]``.

        Returns:
            The dense solution evaluated (via ``vmap``) at each time in ``ts``.
        """
        start, stop = ts[0], ts[-1]
        first_step = ts[1] - ts[0]
        solution = diffeqsolve(
            ODETerm(self.drift),
            Tsit5(),
            t0=start,
            t1=stop,
            dt0=first_step,
            y0=y0,
            # dense=True keeps an interpolant so `evaluate` can be called below.
            saveat=SaveAt(ts=ts, dense=True),
        )
        return vmap(solution.evaluate)(ts)


For those of us who have learned about ODEs, 
the structure of the code above should look pretty familiar.
The diffrax API neatly organizes what we need to solve ODEs:

- the `ODETerm`, which is the $\frac{dy}{dt}$ equation,
- a `solver`, for which `diffrax` provides a library of them,
- the initial and end points $t_0$ and $t_1$ along the $t$ axis along with step size $dt$,
- the initial value of $y$, i.e. $y_0$.

At the end of it all we have a `sol`ution object that we can use later on.

With the solved ODE, we can use it to plot what the solution would look like.

In [None]:
# Solve the decay ODE on 1000 evenly spaced time points in [0, 10], starting from y = 3.
ts = np.linspace(0, 10, 1000)
ode = ODE(exponential_decay_drift)
ys = ode(ts=ts, y0=3)


In [None]:
#| code-fold: true 
#| fig-cap: Solution to the ODE $\frac{dy}{dt} = -y$.
#| label: fig-ode-exponential-decay
import matplotlib.pyplot as plt 
import seaborn as sns

# Draw the solved trajectory on the current axes and label both axes.
ax = plt.gca()
ax.plot(ts, ys)
ax.set_xlabel("t")
ax.set_ylabel("y")
sns.despine()

The solution of the ODE that we had above is an exponential decay,
and that is exactly what we see in the curve above.

And if we wanted to run the ODE from multiple starting points:

In [None]:
#| code-fold: true
#| fig-cap: Multiple solutions to the ODE $\frac{dy}{dt} = -y$.
#| label: fig-ode-multiple-decay

# Solve from ten starting values (0..9) at once; each column of `ys` is one trajectory.
ys = ode(ts=ts, y0=np.arange(0, 10))

ax = plt.gca()
for trajectory in ys.T:
    ax.plot(ts, trajectory)
ax.set_xlabel("t")
ax.set_ylabel("y")
sns.despine()

## Stochastic Differential Equations

Stochastic differential equations (SDEs) extend ODEs
by adding noise at each step of the solve.
SDEs can thus be thought of as having a "drift" component,
in which the system being modeled by the SDE "drifts" through the vector field,
and a "diffusion" component,
in which the system's state is perturbed with additional noise.
Let's look at an example from diffrax's documentation,
modified to be a bit simpler.
We basically have:

$$\frac{dy}{dt} = -y + N(0, 0.1)$$

where $N(0, 0.1)$ refers to a draw from an i.i.d. Gaussian.

More generally, we have the following form:

$$dx = f(x, t)dt + g(t)dw$$

Here, 

> - $f(x, t)$ is a drift function that produces a vector output,
> - $g(t)$ is a diffusion function that produces a scalar output,
> - and $dw$ is infinitesimal white noise.  
> 
> (paraphrased from Yang's blog)


This is implemented in code below.

In [None]:
from jax import random
import jax.numpy as np
from diffrax import (
    ControlTerm,
    MultiTerm,
    VirtualBrownianTree,
)

class SDE(eqx.Module):
    """Callable wrapper that integrates ``dy = drift dt + diffusion dw`` with diffrax.

    Calling an instance draws one Brownian path from ``key``, solves the SDE
    between the first and last entries of ``ts``, and returns the solution
    evaluated at every entry of ``ts``.
    """

    # Deterministic vector field, wrapped in an ``ODETerm``.
    drift: callable
    # Noise-scale vector field, wrapped in a ``ControlTerm`` driven by Brownian motion.
    diffusion: callable

    def __call__(self, ts: np.ndarray, y0: float, key: random.PRNGKey):
        """Solve the SDE from ``y0`` along ``ts`` using the noise seeded by ``key``.

        Args:
            ts: Time grid; may be ascending or descending (the reverse-time
                cell in this notebook passes a descending grid).
            y0: Initial state at ``ts[0]``.
            key: JAX PRNG key that seeds the Brownian path.

        Returns:
            The dense solution evaluated (via ``vmap``) at each time in ``ts``.
        """
        t_start, t_end = ts[0], ts[-1]
        # The Brownian path is defined over an interval, independent of the
        # direction of integration, so always hand it ordered end points.
        # NOTE(review): recent diffrax releases appear to reject t0 >= t1 here;
        # confirm against the installed version.
        brownian_motion = VirtualBrownianTree(
            np.minimum(t_start, t_end),
            np.maximum(t_start, t_end),
            tol=1e-3,
            shape=(),
            key=key,
        )
        terms = MultiTerm(ODETerm(self.drift), ControlTerm(self.diffusion, brownian_motion))
        # NOTE(review): Tsit5 is an ODE solver; diffrax documents dedicated SDE
        # solvers -- consider whether one of those is more appropriate here.
        solver = Tsit5()
        # dense=True keeps an interpolant so `evaluate` can be called below.
        saveat = SaveAt(ts=ts, dense=True)
        sol = diffeqsolve(
            terms,
            solver,
            t0=t_start,
            t1=t_end,
            dt0=ts[1] - ts[0],
            y0=y0,
            saveat=saveat,
        )
        return vmap(sol.evaluate)(ts)


### Noisy Decay

In [None]:
# Each instance of random noise is paired with one SDE.
from functools import partial

def homoskedastic_diffusion(t, y, args):
    """Diffusion term with a constant noise scale.

    Args:
        t: Current time (unused).
        y: Current state (unused).
        args: Extra arguments (unused).

    Returns:
        The fixed noise scale ``0.1``.
    """
    noise_scale = 0.1
    return noise_scale

n_timesteps = 17
n_starting = 1001


# Root key for this experiment; `key` stays bound to it because later cells read it.
key = random.PRNGKey(55)
# Never consume the same key twice: derive one subkey for the starting points
# and a separate one for the per-trajectory Brownian noise.
y0_key, path_key = random.split(key)
y0s = random.normal(y0_key, shape=(n_starting,))
keys = random.split(path_key, n_starting)
ts = np.linspace(0, 6, n_timesteps)
sde = SDE(drift=exponential_decay_drift, diffusion=homoskedastic_diffusion)
# Fix the time grid so that vmap only maps over (y0, key) pairs.
sde = partial(sde, ts)
ys = vmap(sde)(y0s, keys)
for y in ys:
    plt.plot(ts, y, alpha=0.01, color="blue")

plt.xlabel("t")
plt.ylabel("y")
sns.despine()

### Noising SDE

Another SDE that we might want is something that has increasing amounts of noise over time.
Drift would basically be a 0 term, while the diffusion term would be some multiplier on time.

In [None]:
def constant_drift(t, y, args):
    """Drift term that is identically zero.

    Args:
        t: Current time (unused).
        y: Current state (unused).
        args: Extra arguments (unused).

    Returns:
        The integer ``0``.
    """
    no_drift = 0
    return no_drift

def time_dependent_diffusion(t, y, args):
    """Diffusion term whose noise scale grows linearly with time.

    Args:
        t: Current time.
        y: Current state (unused).
        args: Extra arguments (unused).

    Returns:
        ``0.1 * t``.
    """
    scale = 0.1
    return scale * t


sde = SDE(drift=constant_drift, diffusion=time_dependent_diffusion)
ts = np.linspace(0, 5, n_timesteps)
# Fix the time grid so that vmap only maps over (y0, key) pairs.
sde = partial(sde, ts)
# Derive a stream specific to this experiment from the shared root `key`, then
# use separate subkeys for the starting points and the Brownian noise so that
# no key is consumed twice.
y0_key, path_key = random.split(random.fold_in(key, 1))
y0s = random.normal(y0_key, shape=(n_starting,)) * 0.1
keys = random.split(path_key, n_starting)
noising_ys = vmap(sde)(y0s, keys)

In [None]:
#| code-fold: true
#| label: fig-noising-sde
#| fig-cap: A "noising" SDE that progressively adds more noise over time.
# Draw every trajectory first; the labels, title and despine apply to the whole
# axes, so they only need to be set once rather than once per trajectory.
for y in noising_ys:
    plt.plot(ts, y, color="blue", alpha=0.01)

plt.xlabel("t")
plt.ylabel("y")
plt.title(f"{n_starting} sample trajectories")
sns.despine()

We are able to obtain greater amounts of noise from a tight starting point.


At each timepoint, there is also a marginal distribution.

In [None]:
import numpy as onp

# One histogram panel per time point on a 3-column grid; the row count is
# derived from the number of time points instead of being hard-coded.
n_cols = 3
n_rows = -(-len(ts) // n_cols)  # ceiling division
fig, axes = plt.subplots(figsize=(8, 10), nrows=n_rows, ncols=n_cols, sharex=True)
axes = axes.flatten()

# Columns of `noising_ys` are iterated alongside the entries of `ts`.
for ax, t, y in zip(axes, ts, noising_ys.T):
    plt.sca(ax)
    plt.hist(onp.array(y), bins=30)
    plt.title(f"time={t:.1f}")

sns.despine()
# Remove whichever trailing panels received no histogram.
for unused_ax in axes[len(ts):]:
    plt.delaxes(unused_ax)
plt.tight_layout()

### Oscillating SDE 

Let's do one final one just to hammer home the point.
We have a cosine drift term with homoskedastic noise.

In [None]:
def cosine_drift(t, y, args):
    """Drift term that oscillates as the cosine of time.

    Args:
        t: Current time.
        y: Current state (unused).
        args: Extra arguments (unused).

    Returns:
        ``cos(t)``, matching the function name and the accompanying text.
    """
    return np.cos(t)

sde = SDE(drift=cosine_drift, diffusion=homoskedastic_diffusion)
ts = np.linspace(1, 10, n_timesteps)
# Fix the time grid so that vmap only maps over (y0, key) pairs.
sde = partial(sde, ts)
# Derive a stream specific to this experiment from the shared root `key`, then
# use separate subkeys for the Brownian noise and the starting points so that
# no key is consumed twice. The sample count follows `n_starting` rather than
# a repeated literal.
y0_key, path_key = random.split(random.fold_in(key, 2))
keys = random.split(path_key, n_starting)
oscillating_y0s = random.normal(y0_key, shape=(n_starting,)) * 0.1
oscillating_ys = vmap(sde)(oscillating_y0s, keys)


In [None]:
# Overlay all simulated trajectories with high transparency, then label the axes.
ax = plt.gca()
for trajectory in oscillating_ys:
    ax.plot(ts, trajectory, color="blue", alpha=0.01)

ax.set_xlabel("t")
ax.set_ylabel("y")
sns.despine()


In [None]:
import numpy as onp

# One histogram panel per time point on a 3-column grid; the row count is
# derived from the number of time points instead of being hard-coded.
n_cols = 3
n_rows = -(-len(ts) // n_cols)  # ceiling division
fig, axes = plt.subplots(figsize=(8, 10), nrows=n_rows, ncols=n_cols, sharex=True)
axes = axes.flatten()

# Columns of `oscillating_ys` are iterated alongside the entries of `ts`.
for ax, t, y in zip(axes, ts, oscillating_ys.T):
    plt.sca(ax)
    plt.hist(onp.array(y), bins=30)
    plt.title(f"time={t:.1f}")

sns.despine()
# Remove whichever trailing panels received no histogram.
for unused_ax in axes[len(ts):]:
    plt.delaxes(unused_ax)
plt.tight_layout()

## Reverse Time SDEs

With constant drift and time-dependent diffusion, we can noise up data in a continuous fashion.
How do we go backwards?
Here is where solving the reverse time SDE will come in. 
Again, we need to set up the drift and diffusion terms.
Here, drift is:

$$f(x, t) - g^2(t) \nabla_x \log p_t (x) $$

And diffusion is:

$$g(t) dw$$


However, the tricky part here is that we don't have access to
$\nabla_x \log p_t (x)$ (the true score function).
As such, we need to bring out our score model approximator!
To train the score model approximator,
we need the analogous score matching objective for continuous time problems.

In an ideal situation,
we would train the score matching model
using a weighted combination of Fisher divergences:

$$\mathbb{E}_{t \sim U(0, T)} \mathbb{E}_{p_t(x)} \left[ \lambda(t) \| \nabla_x \log p_t(x) - s_{\theta}(x, t) \|^2_2 \right]$$

Now, just like before, we don't have access to $\nabla_x \log p_t (x)$,
so we instead use the score matching objective by Hyvärinen.

In [None]:
from score_models.models.feedforward import FeedForwardModel1D

# Instantiate the FeedForwardModel1D imported above from the project-local
# score_models package.
# NOTE(review): `model` is not referenced again in the visible cells.
model = FeedForwardModel1D(width_size=256, depth=2)
# Bare expression: displays the model as the notebook cell's output.
model

In [None]:
def reverse_drift(t, y, args):
    """Drift used when running the noising process backwards in time.

    Parameters follow the ``(t, y, args)`` order used by the other vector
    fields in this notebook (``constant_drift``, ``time_dependent_diffusion``)
    and by diffrax when it calls a term positionally.

    Args:
        t: Current time.
        y: Current state.
        args: Extra arguments, forwarded to the forward drift and diffusion.

    Returns:
        ``f - 0.5 * g**2 * s`` where ``f`` is the forward drift, ``g`` the
        forward diffusion and ``s`` the Gaussian score evaluated at ``y``.
    """
    # Floor the time at 0.001 only for the logarithm below, so that log(t)
    # stays finite at t = 0. `maximum` is equivalent to a lower-bound-only clip.
    t_safe = np.maximum(t, 0.001)
    f = np.asarray(constant_drift(t, y, args))
    g = np.asarray(time_dependent_diffusion(t, y, args))
    # NOTE(review): GaussianModel is not imported in the visible cells; it is
    # presumably a score function of a Gaussian with the given mean and
    # log-standard-deviation -- confirm where it comes from.
    gaussian_score = GaussianModel(mu=0, log_sigma=0.1 * np.log(t_safe))
    s = np.asarray(gaussian_score(y))
    # NOTE(review): the surrounding text gives the reverse-time SDE drift as
    # f - g^2 * score; the 0.5 factor here matches the probability flow ODE.
    # Left unchanged because this function is also used as the ODE drift below.
    return f - 0.5 * g**2 * s


# Smoke test at t = 0, exercising the time floor above.
reverse_drift(y=3.0, t=0.0, args=())

In [None]:
# Inspect the shape of the simulated noising trajectories; presumably
# (n_starting, n_timesteps), since the solve was vmapped over the starting
# points and evaluated along `ts` -- confirm from the cell output.
noising_ys.shape

In [None]:
# Time must run backwards here: the grid goes from t = 5 down to t = 0.
ts = np.linspace(5, 0, n_timesteps)
sde = partial(SDE(reverse_drift, time_dependent_diffusion), ts)
keys = random.split(random.PRNGKey(45), n_starting)
# Start each reverse trajectory from the final column of the noising run.
y0s = vmap(sde)(noising_ys.T[-1], keys)

ax = plt.gca()
for trajectory in y0s:
    ax.plot(ts, trajectory, color="blue", alpha=0.1)

ax.set_xlabel("t")
ax.set_ylabel("y")
# Flip the x-axis so that the plot reads left-to-right in integration order.
ax.invert_xaxis()
sns.despine()


Some notes here: 

- In computational experiments, I noticed under/overflow issues with certain parameter scales.

## Probability Flow ODEs

Now that we've recapped what an ODE is, let's examine what probability flow ODEs are.
Probability flow ODEs have the following form:

$$dx = [f(x,t) - \frac{1}{2} g^2(t) \nabla_x \log p_t (x)] dt$$



Just like the SDE above,

$$dx = f(x, t)dt + g(t)dw$$

the terms carry the same meaning:

> - $f(x, t)$ is a drift function that produces a vector output,
> - $g(t)$ is a diffusion function that produces a scalar output,
> - and $dw$ is infinitesimal white noise.  
> 
> (paraphrased from Yang's blog)


So here, to obtain the appropriate ODE,
we just need to take the diffusion term $g(t)$, square it, halve it,
multiply it with the score function (or an estimator of it),
and subtract the result from the drift $f(x, t)$.

Let's experiment here with Gaussians, just to see how this works out.

In [None]:
def combined_term(t, y, args):
    """Probability flow ODE drift ``f - 0.5 * g**2 * score``.

    Parameters follow the ``(t, y, args)`` order used by the other vector
    fields in this notebook (``constant_drift``, ``time_dependent_diffusion``)
    and by diffrax when it calls a term positionally.

    Args:
        t: Current time.
        y: Current state.
        args: Extra arguments, forwarded to the forward drift and diffusion.

    Returns:
        The forward drift minus half the squared diffusion times the
        Gaussian score evaluated at ``y``.
    """
    f = constant_drift(t, y, args)
    g = time_dependent_diffusion(t, y, args)
    # NOTE(review): GaussianModel is not imported in the visible cells; it is
    # presumably a score function of a Gaussian with the given mean and
    # log-standard-deviation -- confirm where it comes from.
    gaussian_score = GaussianModel(mu=0, log_sigma=t)
    s = gaussian_score(y)
    return f - 0.5 * g**2 * s


# The deterministic solve reuses `reverse_drift` as the ODE's vector field.
ode_combined = ODE(reverse_drift)
ode = ode_combined

# Integrate backwards from t = 5 to t = 0 for 100 standard-normal starting points.
key = random.PRNGKey(55)
ts = np.linspace(5, 0, 1000)
y0s = random.normal(key=key, shape=(100,))
ys = ode(ts, y0s)

# Each column of `ys` is one trajectory.
ax = plt.gca()
for trajectory in ys.T:
    ax.plot(ts, trajectory, color="blue", alpha=0.1)

sns.despine()
ax.set_ylim(-3, 3)