# Continuous PDE Dynamics Forecasting with Implicit Neural Representations

### Yin et al., ICLR 2023

Tutorial presented by: Jaisidh Singh and Arkadijs Sergejevs

## Introduction

- Partial differential equations (PDEs) are indispensable tools for modeling the dynamics and temporal evolution of physical phenomena.

- Classically, the dynamics described by PDEs are computed with numerical solvers, which are highly accurate but computationally expensive.

- Recently, data produced by simulations has enabled deep-learning-based PDE forecasting, which is what this work addresses.

## Motivation and Problem Statement

- PDE forecasting in areas like weather prediction can require extrapolating to new, arbitrary points in time or space.

- Current data-driven forecasting methods have notable drawbacks because they often rely on fixed discretizations of the spatial domain. As a result, they

    1. Do not generalize outside the spatial grid seen during training (the train grid).
    2. Perform poorly at unobserved spatial locations and on free-form grids.
    3. May not forecast well from new initial conditions.
    4. Cannot forecast long-term (beyond the training horizon).

## Proposed solution

- This work presents `DINo`, a method that models the continuous-time dynamics of spatially continuous functions. It does so by

    1. using Implicit Neural Representations (INRs) to embed spatial observations into a latent space, independently of how they were discretized;
    2. modeling the temporal evolution with an ordinary differential equation (ODE) in that latent embedding space.
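To make the two ingredients concrete, here is a toy NumPy sketch. It is not the authors' implementation. A randomly initialised, untrained INR `decode(coords, alpha)` maps any set of 2D coordinates to field values, conditioned on a latent code `alpha`. A small network `f` defines the latent ODE $\frac{d\alpha_t}{dt} = f(\alpha_t)$, which is integrated with a classic RK4 step. The sine-activated decoder with latent shift modulation, the layer sizes, and every name below are assumptions made for illustration. The step that fits `alpha0` to observed data is omitted.

```python
import numpy as np

rng = np.random.default_rng(0)

LATENT_DIM = 8   # size of the latent code alpha_t (hypothetical choice)
HIDDEN = 32      # width of the toy networks (hypothetical choice)

# --- Toy INR decoder: (coordinates, latent code) -> field values ---
# Random, untrained weights; in a real model these would be learned.
W_in = rng.normal(size=(2, HIDDEN)) * 3.0             # 2D coords -> sinusoidal features
W_mod = rng.normal(size=(LATENT_DIM, HIDDEN)) * 0.1   # latent code -> per-unit shift
W_out = rng.normal(size=(HIDDEN, 1)) / np.sqrt(HIDDEN)

def decode(coords, alpha):
    """Evaluate the field at arbitrary (N, 2) coordinates, conditioned on alpha."""
    h = np.sin(coords @ W_in + alpha @ W_mod)  # latent code shifts every hidden unit
    return (h @ W_out)[:, 0]                   # (N,) field values

# --- Toy latent dynamics: d alpha / dt = f(alpha) ---
A1 = rng.normal(size=(LATENT_DIM, HIDDEN)) / np.sqrt(LATENT_DIM)
A2 = rng.normal(size=(HIDDEN, LATENT_DIM)) / np.sqrt(HIDDEN) * 0.5

def f(alpha):
    return np.tanh(alpha @ A1) @ A2

def rk4_step(alpha, dt):
    """One classic 4th-order Runge-Kutta step of the latent ODE."""
    k1 = f(alpha)
    k2 = f(alpha + 0.5 * dt * k1)
    k3 = f(alpha + 0.5 * dt * k2)
    k4 = f(alpha + dt * k3)
    return alpha + dt / 6.0 * (k1 + 2 * k2 + 2 * k3 + k4)

def rollout(alpha0, dt, n_steps):
    """Integrate forward in latent space; returns an (n_steps + 1, LATENT_DIM) array."""
    traj = [alpha0]
    for _ in range(n_steps):
        traj.append(rk4_step(traj[-1], dt))
    return np.stack(traj)

# Forecast in latent space, then decode on two unrelated spatial discretizations.
alpha0 = rng.normal(size=LATENT_DIM)
traj = rollout(alpha0, dt=0.1, n_steps=20)

grid_1d = np.linspace(0, 1, 16)
regular = np.stack(np.meshgrid(grid_1d, grid_1d), axis=-1).reshape(-1, 2)  # 16x16 grid
scattered = rng.uniform(0, 1, size=(50, 2))                                # free-form points

field_regular = decode(regular, traj[-1])
field_scattered = decode(scattered, traj[-1])
print(traj.shape, field_regular.shape, field_scattered.shape)  # (21, 8) (256,) (50,)
```

The decoder takes coordinates as input, so the same latent trajectory can be decoded on the regular grid and on the scattered points alike. This is the discretization independence the method relies on.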

## PDE Refresher

**Formally**: A PDE relates an unknown function $u(x, t)$ to its partial derivatives through an equation of the form

$$F\left(x, t, u(x, t), \frac{\partial}{\partial t}u(x, t), \frac{\partial}{\partial x}u(x, t), \frac{\partial^2}{\partial x^2}u(x, t), \dots\right) = 0$$

**Intuitively**: Since $u(x, t)$ is a time-varying function of space, we can equivalently view it as a family of spatial functions $v_t(x) = u(x, t)$, one for each time $t$. The same PDE then describes how $v_t$ evolves over time.
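Both views can be checked numerically on a known solution. This sketch uses the 1D heat equation $\frac{\partial}{\partial t}u(x, t) = \alpha \frac{\partial^2}{\partial x^2} u(x, t)$ from the example below. For the hypothetical initial profile $\sin(kx)$, its exact solution is $u(x, t) = e^{-\alpha k^2 t}\sin(kx)$. The snippet curries $u$ into $v_t$ and estimates the residual $F = \frac{\partial u}{\partial t} - \alpha \frac{\partial^2 u}{\partial x^2}$ with central finite differences. That residual should vanish up to discretization error. The names `u`, `v`, `residual`, and the value `K = 3` are illustrative choices.

```python
import numpy as np

ALPHA = 0.02  # diffusivity, matching the heat-equation example
K = 3.0       # wavenumber of the hypothetical initial profile sin(K x)

def u(x, t):
    """Exact solution of u_t = ALPHA * u_xx with u(x, 0) = sin(K x)."""
    return np.exp(-ALPHA * K**2 * t) * np.sin(K * x)

def v(t):
    """Curried view: fix t and return the spatial function v_t(x) = u(x, t)."""
    return lambda x: u(x, t)

def residual(x, t, h=1e-3):
    """F = u_t - ALPHA * u_xx, estimated with central finite differences of step h."""
    u_t = (u(x, t + h) - u(x, t - h)) / (2 * h)
    u_xx = (u(x + h, t) - 2 * u(x, t) + u(x - h, t)) / h**2
    return u_t - ALPHA * u_xx

x = np.random.default_rng(0).uniform(-2, 2, size=100)  # arbitrary query points
v_half = v(0.5)                                          # snapshot of the field at t = 0.5

print(np.allclose(v_half(x), u(x, 0.5)))   # True: v_t is just u with t fixed
print(np.max(np.abs(residual(x, 0.5))))    # close to zero (finite-difference error only)
```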

### Example: Heat Equation

Given by $\frac{\partial}{\partial t}u(x, t) = \alpha \frac{\partial^2}{\partial x^2} u(x, t)$ in 1D, or $\frac{\partial u}{\partial t} = \alpha \left(\frac{\partial^2 u}{\partial x^2} + \frac{\partial^2 u}{\partial y^2}\right)$ in 2D, this PDE is visualised below on a 2D grid with $\alpha = 0.02$ and insulated (zero-flux) boundaries:

```python
import numpy as np
import matplotlib.pyplot as plt
from IPython.display import HTML
from matplotlib.animation import FuncAnimation

plt.ioff()  # build the figure without displaying it; the saved GIF is shown instead

nx, ny = 41, 41  # Number of grid points (41 points on [-2, 2] give a spacing of 0.1)
dx = dy = 0.1    # Grid spacing, consistent with the linspace grid below
alpha = 0.02     # Thermal diffusivity
dt = 0.1         # Time step: alpha*dt*(1/dx**2 + 1/dy**2) = 0.4 <= 0.5, so the explicit scheme is stable
steps = 200      # Number of animation frames
skip = 3         # Show every nth arrow

x = np.linspace(-2, 2, nx)
y = np.linspace(-2, 2, ny)
X, Y = np.meshgrid(x, y)
u = np.zeros((ny, nx))
u[ny//3:2*ny//3, nx//3:2*nx//3] = 100  # Square hot region
u += 50 * np.exp(-((X-1)**2 + (Y-1)**2)/0.5)  # Gaussian hot spot

def evolve_temperature(u):
    """One forward-Euler step of u_t = alpha * (u_xx + u_yy)."""
    # Edge padding copies boundary values outward, giving a zero normal gradient
    # at the walls (insulated: no heat flows in or out).
    up = np.pad(u, 1, mode="edge")
    d2udx2 = (up[1:-1, 2:] - 2*u + up[1:-1, :-2]) / dx**2
    d2udy2 = (up[2:, 1:-1] - 2*u + up[:-2, 1:-1]) / dy**2

    dudt = alpha * (d2udx2 + d2udy2)
    return u + dt * dudt

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))
fig.suptitle('2D Heat Equation')

heatmap = ax1.imshow(u, extent=[-2, 2, -2, 2], origin='lower',
                     cmap='inferno', vmin=0, vmax=100)
plt.colorbar(heatmap, ax=ax1, label='Temperature')
ax1.set_title('Temperature Field')

# Spatial gradient of the temperature (points from cold towards hot)
dx_temp = np.gradient(u, dx, axis=1)
dy_temp = np.gradient(u, dy, axis=0)
quiver = ax2.quiver(
    X[::skip, ::skip],
    Y[::skip, ::skip],
    dx_temp[::skip, ::skip],
    dy_temp[::skip, ::skip],
    np.sqrt(dx_temp[::skip, ::skip]**2 + dy_temp[::skip, ::skip]**2),
    cmap='viridis'
    )
plt.colorbar(quiver, ax=ax2, label='Temperature Gradient Magnitude')
ax2.set_title('Temperature Gradient')

def update(frame):
    global u

    # Advance the simulation by one time step and refresh both panels
    u = evolve_temperature(u)
    heatmap.set_array(u)

    dx_temp = np.gradient(u, dx, axis=1)
    dy_temp = np.gradient(u, dy, axis=0)
    quiver.set_UVC(dx_temp[::skip, ::skip], dy_temp[::skip, ::skip],
                   np.sqrt(dx_temp[::skip, ::skip]**2 + dy_temp[::skip, ::skip]**2))

    return heatmap, quiver

ani = FuncAnimation(fig, update, frames=steps, interval=50, blit=True)
ani.save("heat_equation_2d.gif", writer="pillow")

plt.close(fig)  # release the figure; the animation is already saved to disk

HTML('<img src="heat_equation_2d.gif" width="1000" align="center">')
```
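The "chosen for stability" time step relies on the standard bound for forward Euler with the five-point Laplacian. The explicit scheme is stable when $\alpha\,\Delta t\,(1/\Delta x^2 + 1/\Delta y^2) \le 1/2$. The sketch below checks this bound. It uses a stripped-down solver with insulated (edge-padded) boundaries and a square hot region, and compares $\Delta t = 0.1$ with a slightly larger $\Delta t = 0.15$ that violates the bound. The helper names `stability_number` and `run_heat` are hypothetical.

```python
import numpy as np

def stability_number(alpha, dt, dx, dy):
    """alpha * dt * (1/dx^2 + 1/dy^2); forward Euler is stable when this is <= 1/2."""
    return alpha * dt * (1.0 / dx**2 + 1.0 / dy**2)

def run_heat(alpha, dt, dx, dy, n_steps, n=41):
    """Forward-Euler / five-point-stencil heat solver on an n x n insulated grid."""
    u = np.zeros((n, n))
    u[n // 3:2 * n // 3, n // 3:2 * n // 3] = 100.0  # square hot region
    for _ in range(n_steps):
        up = np.pad(u, 1, mode="edge")                # zero-flux boundaries
        lap = ((up[1:-1, 2:] - 2 * u + up[1:-1, :-2]) / dx**2
               + (up[2:, 1:-1] - 2 * u + up[:-2, 1:-1]) / dy**2)
        u = u + dt * alpha * lap
    return u

for dt in (0.1, 0.15):
    s = stability_number(0.02, dt, 0.1, 0.1)
    u_final = run_heat(0.02, dt, 0.1, 0.1, n_steps=200)
    print(f"dt={dt}: stability number={s:.2f}, max |u| after 200 steps={np.abs(u_final).max():.3g}")
```

With $\Delta t = 0.1$, each update is a convex combination of a cell and its neighbours, so the temperature can never exceed its initial maximum. With $\Delta t = 0.15$, the highest-frequency modes are amplified by a factor of roughly $1.4$ in magnitude at every step, so the solution blows up.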

## `DINo` Notation and Formulation