## Problem statement

The goal is to construct a surrogate model that approximates the solution of a time-dependent PDE for a number of fields on the sphere. In abstract form this PDE is written as

$$
\frac{\partial \psi (t)}{\partial t}  = \mathcal{F}(\psi(t))
$$

Here $\mathcal{F}$ is a given (not necessarily linear) operator.

### Input
The input of the model is a tensor $X_{b,i,j}$ of shape $(B,n_{\text{in}},n)$ where $B$ is the batch size. This tensor represents $n_{\text{in}}$ functions, each of which is represented by $n$ dofs: for fixed $b$, $i$ the vector $X_{b,i,*}$ is the dof-vector of a Firedrake function. For simplicity we assume that all functions live in the same function space.

The $n_{\text{in}}$ input functions are further split into $n_{\text{in}}^{\text{(dyn)}}$ functions that represent *dynamic* quantities such as velocity, pressure etc. and $n_{\text{in}}^{\text{(ancil)}}$ functions that represent static *ancillary* quantities such as coordinates or orography. We thus write $X = (X^{\text{(dyn)}},X^{\text{(ancil)}})$ and obviously $n_{\text{in}}=n_{\text{in}}^{\text{(dyn)}}+n_{\text{in}}^{\text{(ancil)}}$.

### Output
Similarly, the output of the model is a tensor $Y_{b,i,j}$ of shape $(B,n_{\text{out}},n)$ which represents $n_{\text{out}}$ functions, each of which is represented by $n$ dofs: for fixed $b$, $i$ the vector $Y_{b,i,*}$ is the dof-vector of a Firedrake function.

The dof-vectors $Y^{(\text{true})}_{b,i,j}$ of the true (or target) functions are obtained by solving a high-resolution model which integrates the discretised version of the PDE above and then possibly projects onto $n_{\text{out}}\le n_{\text{in}}$ functions which represent the "observed" solution. Here $X_{b,i,j}$ defines the initial conditions $\psi(t=0)$ and $Y^{(\text{true})}_{b,i,j}$ is obtained from the solution $\psi(t=T)$ at some later time $T$. Of course, $Y^{(\text{true})}$ could also represent the "true" state of the dynamical system, obtained via data assimilation.

The goal is to find a learnable model $Y=\Phi_\theta(X)$ such that some loss $L(Y,Y^{(\text{true})})$ is minimised. The simplest loss function is the MSE loss $L(Y,Y^{(\text{true})}) = \frac{1}{2}||Y-Y^{(\text{true})}||^2$, which is proportional to the $L_2$ error in the function space.
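
As a minimal sketch, the MSE loss above can be evaluated directly on the dof-tensors. The function name `mse_loss` and the tensor sizes are illustrative assumptions and are not taken from the code base.

```python
import numpy as np


def mse_loss(Y, Y_true):
    """Return 1/2 * ||Y - Y_true||^2, summed over all tensor entries."""
    diff = Y - Y_true
    return 0.5 * np.sum(diff**2)


# Hypothetical sizes: B=2 samples, n_out=3 fields, n=5 dofs per field
rng = np.random.default_rng(0)
Y_true = rng.normal(size=(2, 3, 5))
Y = Y_true + 0.1  # perturb every entry by 0.1

loss = mse_loss(Y, Y_true)  # 0.5 * (30 entries) * 0.1^2
```

In training, the sum would typically be normalised by the batch size; that choice only rescales the gradient.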

## Model structure

As in [Ryan Keisler's paper](https://arxiv.org/abs/2202.07575) and in [GraphCast](https://www.science.org/doi/epdf/10.1126/science.adi2336) the model $\Phi_\theta = \mathcal{D}_{\theta_D}\circ\mathcal{P}_{\theta_P}\circ\mathcal{E}_{\theta_E}$ is split into three components:

1. an **encoder** $\mathcal{E}_{\theta_E}$ which maps the input tensor $X_{b,i,j}$ to the latent space via learnable embeddings
2. a **processor** or **latent model** $\mathcal{P}_{\theta_P}$ which solves a system of coupled time-dependent ODEs in the latent space. The coupling structure of the system of ODEs is defined by the graph of a dual mesh and the forcing is a learned function.
3. a **decoder** $\mathcal{D}_{\theta_D}$ which maps the solution from the latent space back to the output tensor $Y_{b,i,j}$ via learnable embeddings

The following figure summarises the structure of the model; the individual components are explained in detail below.

![Model structure](figures/model_structure.svg)
*Figure 1: model structure*

The main difference to [Ryan Keisler's paper](https://arxiv.org/abs/2202.07575) and [GraphCast](https://www.science.org/doi/epdf/10.1126/science.adi2336) is that the processor solves a time-dependent ODE instead of using message passing on a graph neural network. Hence, the model is a realisation of a [Neural ODE](https://arxiv.org/abs/1806.07366). However, the latent model only includes *local* interactions defined by the topology of the dual mesh. It can therefore be seen as a low-resolution discretisation of a learnable time-dependent PDE.


### Latent space
The latent space is constructed as follows: consider the dual mesh of a refined icosahedron. This dual mesh has $n_{\text{patch}}$ vertices and each vertex $\alpha\in\{0,1,\dots,n_{\text{patch}}-1\}$ has exactly three neighbours $\beta\in\mathcal{N}(\alpha)$. A state in the latent space is a tensor $Z_{b,\alpha,k}$ of shape $(B,n_{\text{patch}},d_{\text{lat}})$ where $d_{\text{lat}}$ is the dimension of the latent space. As for the input tensor, the tensor $Z$ is split into $d_{\text{lat}}^{\text{(dyn)}}$ dynamic components and $d_{\text{lat}}^{\text{(ancil)}}$ ancillary components in the latent dimension, so $Z=(Z^{\text{(dyn)}},Z^{\text{(ancil)}})$ and $d_{\text{lat}}=d_{\text{lat}}^{\text{(dyn)}}+d_{\text{lat}}^{\text{(ancil)}}$.

![dual mesh](figures/dual_mesh.svg)

*Figure 2: Vertices and edges of the dual mesh*
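
To make the neighbour structure concrete, the following sketch builds the dual graph of an *unrefined* icosahedron (an assumption made to keep the example short; the model uses a refined mesh) and allocates a latent state with the dynamic/ancillary split. All names and sizes are hypothetical and this is not how the repository constructs the mesh.

```python
import itertools
import numpy as np

# Vertices of a regular icosahedron with edge length 2
phi = (1 + np.sqrt(5)) / 2
verts = []
for a, b in itertools.product((-1, 1), repeat=2):
    verts += [(0, a, b * phi), (a, b * phi, 0), (b * phi, 0, a)]
verts = np.array(verts)

# Two vertices are joined by an edge if their distance equals the edge length
dist = np.linalg.norm(verts[:, None, :] - verts[None, :, :], axis=-1)
adjacent = np.isclose(dist, 2.0)

# Faces are the triples of mutually adjacent vertices
faces = [
    f
    for f in itertools.combinations(range(12), 3)
    if adjacent[f[0], f[1]] and adjacent[f[0], f[2]] and adjacent[f[1], f[2]]
]
n_patch = len(faces)

# Each dual-mesh vertex corresponds to one face; two dual vertices are
# neighbours if the corresponding faces share an edge (two common vertices)
neighbours = np.array(
    [
        [g for g in range(n_patch) if len(set(faces[f]) & set(faces[g])) == 2]
        for f in range(n_patch)
    ]
)

# Latent state of shape (B, n_patch, d_lat), split along the last axis
B, d_lat_dyn, d_lat_ancil = 2, 3, 2
Z = np.zeros((B, n_patch, d_lat_dyn + d_lat_ancil))
Z_dyn, Z_ancil = Z[..., :d_lat_dyn], Z[..., d_lat_dyn:]
```

The array `neighbours` has shape $(n_{\text{patch}},3)$, so row $\alpha$ lists the three indices in $\mathcal{N}(\alpha)$.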


### Component I: Encoder
The encoder consists of two parts:
#### Projection to VOM
For each $b,i$ the functions represented by $X_{b,i,*}$ are first projected to a vertex-only mesh (VOM). For each of the vertices of the dual mesh that defines the latent space we construct a little circular patch that consists of $p$ points, as shown in the following figure (in reality, however, the radii of the patches are adjusted such that they overlap and thus form a complete covering of the sphere).

![patch covering of the sphere](figures/patch_covering.svg)

*Figure 3: patch covering of the sphere with patches shown in red and dual mesh shown in black*

The vertices of the VOM are then given by the $n_{\text{patch}}\cdot p$ points of all patches. We use Firedrake's interpolation functionality to implement the mapping $X_{b,i,j}\mapsto \overline{X}_{b,\alpha,i,\ell}$ where $\overline{X}$ is a tensor of shape $(B,n_{\text{patch}},n_{\text{in}},p)$. For given $b,i$ this is a linear map $\mathcal{P}:\mathbb{R}^n \rightarrow \mathbb{R}^{n_{\text{patch}}\times p}$:

$$
\overline{X}_{b,*,i,*} = \mathcal{P} X_{b,i,*}
$$

Since the map is linear, derivatives transform with the adjoint $\mathcal{P}^\dagger$:

$$
\frac{\partial}{\partial X_{b,i,*}} = \mathcal{P}^\dagger \frac{\partial}{\partial \overline{X}_{b,*,i,*}}
$$

and this operation is also available in Firedrake. The class `FunctionToPatchInterpolationLayer` in [patch_interpolation.py](src/neural_pde/patch_interpolation.py) uses this observation to implement the projection as a subclass of a TensorFlow [Keras layer](https://www.tensorflow.org/api_docs/python/tf/keras/.layers) through which we can back-propagate. The tensor $\overline{X}=(\overline{X}^{\text{(dyn)}},\overline{X}^{\text{(ancil)}})$ is split into a *dynamic* and an *ancillary* component in the same way as the input tensor.
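
The adjoint rule can be checked numerically. In the sketch below a random dense matrix stands in for the interpolation operator $\mathcal{P}$ (an assumption; in the actual layer the operator and its adjoint are applied by Firedrake), and the gradient obtained with $\mathcal{P}^\dagger$ is compared against central finite differences. The quadratic test loss is chosen purely for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
n, n_patch, p = 6, 4, 3

# Stand-in for the linear interpolation operator, shape (n_patch * p, n)
P = rng.normal(size=(n_patch * p, n))
x = rng.normal(size=n)


def loss(x):
    """Illustrative scalar loss that depends on x only through P x."""
    x_bar = P @ x
    return 0.5 * np.sum(x_bar**2)


# Gradient with respect to the projected vector, then pulled back with P^T
grad_x_bar = P @ x
grad_x = P.T @ grad_x_bar

# Central finite differences in each coordinate direction
eps = 1e-6
grad_fd = np.array(
    [(loss(x + eps * e) - loss(x - eps * e)) / (2 * eps) for e in np.eye(n)]
)
```

Up to round-off, `grad_x` and `grad_fd` agree, which is the property a custom layer needs in order to support back-propagation.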

#### Learnable embedding
Next, the tensor $\overline{X}_{b,\alpha,i,\ell}$ is converted to a state in latent space via a learnable embedding. For this, we treat the dynamic and the ancillary components differently. Let $E^{\text{(dyn)}}_{\theta_E}:\mathbb{R}^{n_{\text{in}}\times p} \rightarrow \mathbb{R}^{d_{\text{lat}}^{\text{(dyn)}}}$ and $E^{\text{(ancil)}}_{\theta_E}:\mathbb{R}^{n_{\text{in}}^{\text{(ancil)}}\times p} \rightarrow \mathbb{R}^{d_{\text{lat}}^{\text{(ancil)}}}$ be the possibly non-linear encoder functions. Then for each sample $b$ and each patch $\alpha$

$$
\begin{aligned}
Z_{b,\alpha,*}^{\text{(dyn)}} &= E^{\text{(dyn)}}_{\theta_E} \left( \overline{X}^{\text{(dyn)}}_{b,\alpha,*,*},\overline{X}^{\text{(ancil)}}_{b,\alpha,*,*}\right)\\
Z_{b,\alpha,*}^{\text{(ancil)}} &= E^{\text{(ancil)}}_{\theta_E} \left( \overline{X}^{\text{(ancil)}}_{b,\alpha,*,*}\right)
\end{aligned}
$$

Note that the ancillary embedding depends only on the ancillary fields whereas the dynamic embedding depends on both the dynamic and the ancillary fields.
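
The following sketch illustrates the shapes involved in the embedding. It assumes a single dense layer with a `tanh` activation for each encoder function; the actual networks may differ, and all names and sizes are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
B, n_patch, p = 2, 20, 4
n_dyn, n_ancil = 3, 2  # number of dynamic / ancillary input fields
d_dyn, d_ancil = 5, 3  # dynamic / ancillary latent dimensions
n_in = n_dyn + n_ancil

# Projected input of shape (B, n_patch, n_in, p); dynamic fields come first
X_bar = rng.normal(size=(B, n_patch, n_in, p))
X_bar_ancil = X_bar[:, :, n_dyn:, :]

# Hypothetical single-layer embeddings, shared by all patches
W_dyn = rng.normal(size=(n_in * p, d_dyn))
W_ancil = rng.normal(size=(n_ancil * p, d_ancil))


def embed(x, W):
    """Flatten the (field, point) axes of each patch and apply a dense layer."""
    flat = x.reshape(x.shape[0], x.shape[1], -1)
    return np.tanh(flat @ W)


Z_dyn = embed(X_bar, W_dyn)  # sees dynamic and ancillary fields
Z_ancil = embed(X_bar_ancil, W_ancil)  # sees ancillary fields only
Z = np.concatenate([Z_dyn, Z_ancil], axis=-1)
```

Because the same weights are applied to every patch, the number of parameters is independent of $n_{\text{patch}}$.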


### Component II: Processor
In latent space we use the tensor $Z_{b,\alpha,k}=(Z_{b,\alpha,k}^{\text{(dyn)}},Z_{b,\alpha,k}^{\text{(ancil)}})$ as the initial condition for the following time-dependent ODE
$$
\begin{aligned}
\frac{\partial Z^{\text{(dyn)}}(t)}{\partial t} &= \mathcal{F}_{\theta_P} (Z^{\text{(dyn)}}(t),Z^{\text{(ancil)}})\qquad \text{with $Z^{\text{(dyn)}}(0) = Z^{\text{(dyn)}}$}\\
\frac{\partial Z^{\text{(ancil)}}(t)}{\partial t} &= 0\qquad \text{with $Z^{\text{(ancil)}}(0) = Z^{\text{(ancil)}}$}
\end{aligned}
$$

where $\mathcal{F}_{\theta_P}$ is a learnable function. The structure of this function is dictated by the topology of the dual mesh that defines the latent space. More specifically, let $F_{\theta_P}:\mathbb{R}^{4\times d_{\text{lat}}}\rightarrow \mathbb{R}^{d_{\text{lat}}^{\text{(dyn)}}}$. Then for each sample $b$ and vertex $\alpha$ we have that

$$
\frac{\partial Z^{\text{(dyn)}}_{b,\alpha,*}(t)}{\partial t} = F_{\theta_P} \left( (Z_{b,\alpha,*}(t),Z_{b,\beta_0,*}(t),Z_{b,\beta_1,*}(t),Z_{b,\beta_2,*}(t) )^\top \right)
$$

where $\beta_0$, $\beta_1$ and $\beta_2$ are the indices of the vertices that are direct neighbours of the vertex $\alpha$ in the dual mesh, as shown in the following figure.

![local interaction](figures/interaction.svg)

*Figure 4: local interaction with the function $F_{\theta_P}$ at vertex $\alpha$. The 5-dimensional state vectors at the vertices are represented by the little boxes for a latent space with $d_{\text{lat}}^{\text{(dyn)}}=3$ (indicated in red) and $d_{\text{lat}}^{\text{(ancil)}}=2$ (indicated in blue)*

In practice the time-dependent ODE is solved with a numerical timestepping method; the current implementation uses a simple forward-Euler scheme.
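
A forward-Euler processor can be sketched as follows. For brevity the example uses a toy graph (the four vertices of a tetrahedron, each with three neighbours) instead of the dual mesh, and a single dense layer with `tanh` as the forcing function. Both are assumptions made for illustration, and the names `forcing` and `processor` are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
B, d_dyn, d_ancil = 2, 3, 2
d_lat = d_dyn + d_ancil

# Toy graph: row alpha holds the three neighbour indices of vertex alpha
neighbours = np.array([[1, 2, 3], [0, 2, 3], [0, 1, 3], [0, 1, 2]])
n_patch = neighbours.shape[0]

# Hypothetical forcing function: one dense layer acting on the 4-point stencil
W = 0.1 * rng.normal(size=(4 * d_lat, d_dyn))


def forcing(Z):
    """Evaluate F for all samples and vertices; Z has shape (B, n_patch, d_lat)."""
    Z_nbr = Z[:, neighbours, :]  # (B, n_patch, 3, d_lat)
    stencil = np.concatenate([Z[:, :, None, :], Z_nbr], axis=2)
    return np.tanh(stencil.reshape(Z.shape[0], n_patch, -1) @ W)


def processor(Z0, tau=1.0, n_steps=10):
    """Integrate to latent time tau with forward Euler; ancillary part is fixed."""
    dt = tau / n_steps
    Z = Z0.copy()
    for _ in range(n_steps):
        Z[..., :d_dyn] += dt * forcing(Z)
    return Z


Z0 = rng.normal(size=(B, n_patch, d_lat))
Z_out = processor(Z0)
```

Only the first $d_{\text{lat}}^{\text{(dyn)}}$ latent components are updated, so the ancillary components of `Z_out` are identical to those of `Z0`.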

The output of the processor is the solution $Z'_{b,\alpha,k}=Z_{b,\alpha,k}(\tau)$ at the final latent time $\tau$.


### Component III: Decoder
Like the encoder, the decoder is also split into two parts.

#### Learnable embedding
First the tensor $Z'_{b,\alpha,k}$ is mapped to a tensor $\overline{Y}_{b,\alpha,i,\ell}$ on the VOM via a learnable embedding. For this let $D_{\theta_D}:\mathbb{R}^{d_{\text{lat}}} \rightarrow \mathbb{R}^{n_{\text{out}}\times p}$ be a learnable function. Then for each sample $b$ and each patch $\alpha$

$$
\overline{Y}_{b,\alpha,*,*} = D_{\theta_D}\left(Z'_{b,\alpha,*}\right)
$$
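
As a shape-level sketch, the decoder embedding can be written as a dense map followed by a reshape. A single linear layer is assumed here only for illustration, and the names and sizes are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
B, n_patch, d_lat, n_out, p = 2, 20, 8, 3, 4

# Output of the processor, shape (B, n_patch, d_lat)
Z_final = rng.normal(size=(B, n_patch, d_lat))

# Hypothetical linear decoder, shared by all patches
W_dec = rng.normal(size=(d_lat, n_out * p))

# Map each latent vector to n_out * p values and unflatten to (field, point)
Y_bar = (Z_final @ W_dec).reshape(B, n_patch, n_out, p)
```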

#### Reconstruction on original mesh
To recover the output tensor $Y_{b,i,j}$ which represents the solution functions on the original mesh we use the adjoint of the map $\mathcal{P}$ that is used for the projection to the VOM. For each sample $b$ and each output field $i$

$$
Y_{b,i,*} = \mathcal{P}^\dagger \overline{Y}_{b,*,i,*}
$$

Now derivatives transform with $\mathcal{P}$ itself:

$$
\frac{\partial} {\partial \overline{Y}_{b,*,i,*}} = \mathcal{P} \frac{\partial}{\partial Y_{b,i,*}}.
$$

The class `PatchToFunctionInterpolationLayer` in [patch_interpolation.py](src/neural_pde/patch_interpolation.py) implements this reconstruction as a subclass of a TensorFlow [Keras layer](https://www.tensorflow.org/api_docs/python/tf/keras/.layers) through which we can back-propagate.
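
The relation between projection and reconstruction can be illustrated with batched tensors. The sketch below again replaces $\mathcal{P}$ by a random dense operator (an assumption; the layers above delegate this to Firedrake) and verifies the adjoint identity $\langle \mathcal{P}X, \overline{Y}\rangle = \langle X, \mathcal{P}^\dagger\overline{Y}\rangle$. The function names `project` and `reconstruct` are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
B, n, n_patch, p, n_fields = 2, 7, 4, 3, 2

# Stand-in for the interpolation operator, stored with shape (n_patch, p, n)
P = rng.normal(size=(n_patch, p, n))


def project(X):
    """Apply P: (B, n_fields, n) -> (B, n_patch, n_fields, p)."""
    return np.einsum("alj,bij->bail", P, X)


def reconstruct(Y_bar):
    """Apply the adjoint of P: (B, n_patch, n_fields, p) -> (B, n_fields, n)."""
    return np.einsum("alj,bail->bij", P, Y_bar)


X = rng.normal(size=(B, n_fields, n))
Y_bar = rng.normal(size=(B, n_patch, n_fields, p))

# Adjoint identity: both inner products agree up to round-off
lhs = np.sum(project(X) * Y_bar)
rhs = np.sum(X * reconstruct(Y_bar))
```

Because of this identity, the backward pass of the projection is the forward pass of the reconstruction and vice versa.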

### Summary of learnable functions
The learnable functions in the following table are represented by small neural networks.

| component | function | from | to |
| ---- | ---- | ---- | ---- |
| encoder (dynamic) | $E^{\text{(dyn)}}_{\theta_E}$ | $\mathbb{R}^{n_{\text{in}}\times p}$ | $\mathbb{R}^{d_{\text{lat}}^{\text{(dyn)}}}$ |
| encoder (ancillary) | $E^{\text{(ancil)}}_{\theta_E}$ | $\mathbb{R}^{n_{\text{in}}^{\text{(ancil)}}\times p}$ | $\mathbb{R}^{d_{\text{lat}}^{\text{(ancil)}}}$ |
| processor forcing function | $F_{\theta_P}$ | $\mathbb{R}^{4\times d_{\text{lat}}}$ | $\mathbb{R}^{d_{\text{lat}}^{\text{(dyn)}}}$ |
| decoder | $D_{\theta_D}$ | $\mathbb{R}^{d_{\text{lat}}}$ | $\mathbb{R}^{n_{\text{out}}\times p}$ |

*Table 1: list of learnable functions*
