In [None]:
# Automatically reload imported modules (e.g. the local `models`, `data`, `tools` packages)
# before each cell runs, so edits to them are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

## Modèle KPP

Tools: JAX, Flax and Optax (install with `pip install jax jaxlib flax optax`).
If you are unfamiliar with JAX random number generation, check [this](https://jax.readthedocs.io/en/latest/jax.random.html).


\begin{equation} \label{eq:KPP_homog}
  \partial_t u(t,x) = D \Delta u + r u (1 - u), \quad t>0, \ x\in \Omega \subset \mathbb{R}^2,
\end{equation}


avec la condition initiale $u(0,\cdot)=u_0(\cdot)$, une gaussienne très concentrée (« piquée ») dans $\Omega$, et la condition au bord de Dirichlet $u(t,\cdot)=0$ sur $\partial\Omega$ pour tout $t>0$. On pourra prendre $\Omega=(0,1)\times(0,1)$.


In [None]:
import jax.numpy as np
import jax
from jax import grad, jit, vmap, jacfwd, jacrev
from jax import random
from models.nets import MLP
from functools import partial

# Root PRNG key; split before use so the same key is never consumed twice.
key = random.PRNGKey(0)
key, subkey = random.split(key)

# A test point: t has shape (1,), x has shape (2,) (a point of Omega = (0,1)^2).
x_test = np.ones(2) * 0.25
t_test = np.ones(1) * 0.25

# PINN u_theta(t, x); `features` presumably lists the layer widths, ending in a single output unit.
model = MLP(features=[20,40,80,80,80,40,20, 1])
# Parameters are initialised by tracing the model once on the test point.
init_params = model.init(subkey, t_test, x_test)

@jit
def u(t, x, params_):
    """Evaluate the PINN approximation u(t, x) with parameters ``params_``.

    ``model.apply`` returns an array; its first entry is returned as the value of u.
    """
    output = model.apply(params_, t, x)
    return output[0]

# Show the parameter pytree (shapes only) and one evaluation of the untrained network.
# `jax.tree_util.tree_map` is the stable API; the `jax.tree_map` alias is deprecated/removed in recent JAX.
print('initialized parameter shapes:\n', jax.tree_util.tree_map(np.shape, init_params))
# u takes (t, x) in that order, so label the value accordingly.
print(f'\nu(t, x): {u(t_test, x_test, init_params):.3f}')

In [None]:
# PDE coefficients: diffusion D and growth rate r in  u_t = D Δu + r u (1 - u).
D = 1.0
r = 1.0

def hessian(f, index_derivation=0):
    """Return a function computing the Hessian of ``f`` w.r.t. argument ``index_derivation``.

    Reverse mode gives the gradient, then forward mode differentiates it again
    (forward-over-reverse), both with respect to the same argument.
    """
    gradient_fn = jacrev(f, argnums=index_derivation)
    return jacfwd(gradient_fn, argnums=index_derivation)

def pde_rhs(t_, x_, params_):
    """Right-hand side of the KPP equation, D Δu + r u (1 - u), at a single point.

    Args:
        t_: time input passed to ``u`` (shape (1,) for the test point in this notebook).
        x_: spatial point passed to ``u`` (shape (2,) for the test point in this notebook).
        params_: network parameters.

    Returns:
        D * Δ_x u(t_, x_) + r * u * (1 - u), a scalar when ``u`` returns a scalar.
    """
    u_out = u(t_, x_, params_)
    # Laplacian in x = trace of the spatial Hessian (x is argument index 1 of u).
    lap_u = np.trace(np.squeeze(hessian(u, 1)(t_, x_, params_)))
    # Fisher-KPP logistic growth: the reaction term enters with a PLUS sign.
    return D * lap_u + r * u_out * (1 - u_out)

@jit
def f(t_, x_, params_):
    """PDE residual u_t - pde_rhs(t, x) at one point; zero where u satisfies u_t = pde_rhs."""
    time_derivative = grad(u, 0)(t_, x_, params_)
    return time_derivative - pde_rhs(t_, x_, params_)

In [None]:
# Sanity check: network value and PDE residual at the test point with untrained parameters.
u(t_test, x_test, init_params), f(t_test, x_test, init_params)

In [None]:
def u0(x):
    """Peaked Gaussian initial condition centred at (0.5, 0.5), scaled by 10 and clipped to [0, 1].

    ``x`` is a batch of points (presumably shape (N, 2)); a squared distance to the centre
    is computed per row.
    """
    centered = x - 0.5
    sq_dist = vmap(np.dot)(centered, centered)
    bump = 10 * np.exp(-100 * sq_dist)
    return np.clip(bump, 0.0, 1.0)

In [None]:
def loss(batches, params_, delta = 1e-1, loss_weights={'u': 1.0/3, 'f': 1.0/3, 'delta': 1.0/3}):
    """Weighted PINN loss.

    Args:
        batches: tuple (t, x, u_target, t_inside, x_inside) of boundary/initial points with
            their targets, followed by interior collocation points.
        params_: network parameters (the differentiated argument, index 1).
        delta: time step of the forward-Euler consistency term.
        loss_weights: weights of the 'u', 'f' and 'delta' terms.

    Returns:
        (total_loss, (loss_u, loss_f, loss_delta)); only the first entry is differentiated
        (``has_aux=True`` below).
    """
    t_bd, x_bd, u_bd, t_in, x_in = batches

    # Physics: squared PDE residual.
    def sq_residual(t, x):
        return f(t, x, params_) ** 2

    # Forward-Euler consistency: u(t + delta) ≈ u(t) + delta * rhs(t).
    def sq_delta(t, x):
        euler_gap = u(t + delta, x, params_) - u(t, x, params_) - delta * pde_rhs(t, x, params_)
        return euler_gap ** 2

    # Boundary / initial-condition fit.
    def sq_boundary(t, x, target):
        return np.squeeze((target - u(t, x, params_)) ** 2)

    loss_f = np.mean(vmap(sq_residual)(t_in, x_in))
    loss_delta = 0.5 * (np.mean(vmap(sq_delta)(t_bd, x_bd)) + np.mean(vmap(sq_delta)(t_in, x_in)))
    loss_u = np.mean(vmap(sq_boundary)(t_bd, x_bd, u_bd))

    weighted = (loss_weights['u'] * loss_u
                + loss_weights['f'] * loss_f
                + loss_weights['delta'] * loss_delta)
    return (weighted, (loss_u, loss_f, loss_delta))

losses_and_grad = jit(jax.value_and_grad(loss, argnums=1, has_aux=True))

In [None]:
# Testing the loss function on a tiny synthetic batch:
# 10 boundary points (t=0, x=(0,0), target u=0.4) and 10 interior points at t=0.25, x=(0.25, 0.25).
losses, grads = losses_and_grad((np.zeros((10, 1)), 
                                 np.zeros((10, 2)), 
                                 np.ones((10, 1))*0.4, 
                                 np.ones((10, 1))*0.25,
                                 np.ones((10, 2))*0.25),
                                 init_params)


a, (b, c, d) = losses
print(f"total loss: {a:.3f}, mse_u: {b:.3f}, mse_f: {c:.3f}, mse_delta: {d:.7f}")

#### Data and learning

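We build $N_u$ boundary data points, as mentioned in the paper (which uses $N_u = 100$; here `N_u = 128*5`). Half of them are at $t=0$ (initial condition $u_0$), the other half on the boundary $\partial\Omega$, where $u = 0$. Everything is wrapped into a dataset class (`KPPDataset`).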

In [None]:
from data import datasets

# Fresh key for the dataset's sampling; the dataset receives the initial condition u0.
# It is used below through `border_batch` (t, x, target) and `inside_batch` (t, x) — see data/datasets.py.
key, subkey = random.split(key, 2)
ds = datasets.KPPDataset(subkey, u0, batch_size=128, N_u=128*5)

In [None]:
# Optimizer
import optax
# Fresh PRNG key for parameter initialisation (independent of `init_params` above).
key, subkey = random.split(key, 2)
params = model.init(subkey, t_test, x_test)
# Adam with a fixed learning rate of 3e-4.
tx = optax.adam(learning_rate=0.0003)
opt_state = tx.init(params)

In [None]:
# Main train loop
steps = 3000
log_every = 100
losses_total, losses_u, losses_f, losses_delta = [], [], [], []

for i in range(steps):
    # Fresh keys each step: one for the boundary/initial batch, one for interior collocation points.
    key, subkey1, subkey2 = random.split(key, 3)
    tb, xb, ub = ds.border_batch(subkey1)
    tb_uni, xb_uni = ds.inside_batch(subkey2)

    # The boundary/initial term is weighted twice as much as each physics term.
    losses, grads = losses_and_grad((tb, xb, ub, tb_uni, xb_uni),
                                    params,
                                    loss_weights={'u': 2, 'f': 1, 'delta': 1})
    updates, opt_state = tx.update(grads, opt_state)
    params = optax.apply_updates(params, updates)

    total_loss_val, (mse_u_val, mse_f_val, mse_delta_val) = losses
    losses_total.append(total_loss_val)
    losses_u.append(mse_u_val)
    losses_f.append(mse_f_val)
    losses_delta.append(mse_delta_val)
    if (i + 1) % log_every == 0:
        # All component losses are reported on the same log10 scale (mse_delta used the natural log before).
        print(f'Loss at step {i+1}: {total_loss_val:.4f} / log10 mse_u: {np.log10(mse_u_val):.4f} / '
              f'log10 mse_f: {np.log10(mse_f_val):.4f} / log10 mse_delta: {np.log10(mse_delta_val):.4f}')

In [None]:
import matplotlib.pyplot as plt
# Training curves of every loss component on a common log10 scale.
fig, ax = plt.subplots()
for series, name in [(losses_total, "total"), (losses_u, "mse_u"),
                     (losses_f, "mse_f"), (losses_delta, "mse_delta")]:
    ax.plot(np.log10(np.array(series)), label=name)
ax.set(title="PINN training losses", xlabel="step", ylabel="log10(loss)")
ax.legend()
plt.show();

In [None]:
# Boundary/initial-condition loss on a linear scale.
plt.plot(losses_u)

#### Display


In [None]:
# Trained PINN evaluated on batches of (t, x) pairs (both mapped over their leading axis).
batched_u = vmap(partial(u, params_=params), in_axes=(0, 0), out_axes=0)


def batched_u0(t, x):
    """Initial condition u0 with a (t, x) signature for the display helper; ``t`` is ignored."""
    return u0(x)

In [None]:
from data.display import display_KPP_at_times
# display_KPP_at_times(batched_u0, 30, times=[0.0,])

In [None]:
# Render the trained PINN at several times (the 30 is presumably the spatial resolution — see data/display.py).
display_KPP_at_times(batched_u, 30, times=[0.0,0.01,0.1,0.2,1])

In [None]:
# Keep a mutable snapshot of the trained PINN parameters.
# `flax.core.unfreeze` accepts both a FrozenDict (older Flax) and a plain dict (what recent
# Flax versions return from `init`), whereas the `.unfreeze()` method only exists on FrozenDict.
# It returns a fresh nested dict, so later updates to `params` do not alias this copy.
from flax.core import unfreeze
super_params = unfreeze(params)

## KPP avec un solveur standard discrétisé

In [None]:
import jax.numpy as np
from models.solver import build_init_grid

Nx = Ny = 128


def u0_grid(x, y):
    """Peaked Gaussian initial condition on coordinate arrays (x, y), centred at (0.5, 0.5).

    Named ``u0_grid`` rather than ``u0`` so it does not shadow the batched PINN initial
    condition ``u0(x)`` from the first section, which ``batched_u0`` looks up by name at
    call time.
    """
    return 10 * np.exp(-((x - 0.5)**2 + (y - 0.5)**2) * 100)


xx, yy, uu = build_init_grid(u0_grid, Nx, Ny)
# Rescale so the initial grid peaks at exactly 1.
uu = uu / np.max(uu)
import matplotlib.pyplot as plt
plt.imshow(uu);

In [None]:
from perlin_numpy import generate_perlin_noise_2d
# Spatially varying growth rate r(x, y): Perlin noise with Nx//16 x Ny//16 periods, scaled by 4.
# perlin_numpy draws from NumPy's legacy global RNG, so seed it for a reproducible field.
import numpy as onp
onp.random.seed(0)
rgrid = generate_perlin_noise_2d((Nx, Ny), (Nx // 16, Ny // 16)) * 4

In [None]:
def r0(x, y):
    """Smooth growth-rate bump of height 4 centred at (0.25, 0.75)."""
    dx, dy = x - 0.25, y - 0.75
    return 4 * np.exp(-5 * (dx**2 + dy**2))

# Alternative smooth growth-rate field (uncomment to use it instead of the Perlin noise):
#rgrid = r0(xx, yy)
plt.imshow(rgrid);

In [None]:
import jax 
from models.solver import solver_iter

Nt = 300  # steps per unit time in the solver loop below (step argument 1./Nt)
# `grid` is the state advanced by the solver; start from the normalised initial condition.
grid = uu.copy()
plt.imshow(uu);
plt.colorbar();

In [None]:
# Run 2*Nt solver iterations with step 1/Nt (second argument of solver_iter, presumably dt),
# i.e. up to t = 2. Restart from the initial grid so re-running this cell is idempotent
# instead of continuing to evolve the previous result.
grid = uu.copy()
for _ in range(Nt * 2):
    grid = solver_iter(grid, 1./Nt, rgrid)
plt.imshow(grid);
plt.colorbar();

In [None]:
# Peak of the normalised initial condition (1 after the rescaling above).
np.max(uu)

In [None]:
# Peak of the solution after the solver loop.
np.max(grid)

In [None]:
# Change of the solution relative to the initial condition.
plt.imshow(grid - uu);
plt.colorbar();

## KPP avec NN discrétisé

- [ ] Discrétisation en temps : $u(x)$ est un vecteur de $T=64$ fonctions de $x \in \mathbb{R}^2$ (une par pas de temps).
- [x] Discrétisation en espace : $u(t, u_0)$ est une grille de taille $128 \times 128$ sur $\Omega$.
- [ ] Double discrétisation : $u$ est un volume de taille $T \times N_x \times N_y = 64 \times 32 \times 32$.

In [None]:
from models.nets import UNet
import jax.numpy as np
import jax
from jax import grad, jit, vmap, jacfwd, jacrev
from jax import random
from functools import partial


key = random.PRNGKey(1)
key, subkey = random.split(key)
# `training=False` presumably disables train-only behaviour (dropout / batch-norm updates) — confirm in models/nets.py.
unet = UNet(features=64, training=False)

init_rngs = {'params': random.PRNGKey(0)}

# Dummy inputs: one 128x128 single-channel grid (presumably NHWC) and a time vector of shape (1,).
unet_variables = unet.init(init_rngs, np.ones([1, 128, 128, 1]), np.ones([1]))

@jit
def unet_func(params, x, t):
    """Apply the global ``unet`` with parameter tree ``params`` to input grid ``x`` at time ``t``."""
    variables = {"params": params}
    return unet.apply(variables, x, t)

In [None]:
# Add batch and channel axes: (Nx, Ny) -> (1, Nx, Ny, 1).
u0_batched = np.expand_dims(uu, (0, -1))
# NOTE(review): this overwrites the PINN's `t_test` (shape (1,)) with shape (1, 1);
# re-running the PINN cells after this one would use the wrong shape.
t_test = np.array([[0.0]])
out_test = unet_func(unet_variables["params"], u0_batched, t_test)
print(out_test.shape)
plt.imshow(np.squeeze(out_test));

In [None]:
# Input grid, for comparison with the untrained UNet output.
plt.imshow(np.squeeze(u0_batched));

In [None]:
from functools import partial

def jac(params, f, t, x):
    """Jacobian of ``f(params, x, t)`` with respect to ``t``, trailing two axes squeezed.

    Forward-mode differentiation appends ``t.shape`` to the output shape; with ``t`` of
    shape (1, 1) (as ``t_test``), those two singleton axes are removed so the result has
    the same shape as ``f``'s output.
    """
    def f_of_t(t_value):
        return f(params, x, t_value)

    d_dt = jacfwd(f_of_t)(t)
    return np.squeeze(d_dt, axis=(-1, -2))

# Expect the same shape as the UNet output (the time axes are squeezed away).
jac(unet_variables["params"], unet_func, t_test, u0_batched).shape

In [None]:
from tools.jaxtools import laplacian_grid
# Same coefficients as in the PINN section (u_t = D Δu + r u (1 - u)), redefined for this section.
D = 1.0
r = 1.0

@jit
def mse_phy(params, t_, x_):
    """Mean squared KPP residual of the UNet prediction over the whole grid.

    The residual is du/dt - (D Δu + r u (1 - u)), where du/dt is the forward-mode
    derivative of the network w.r.t. ``t_`` and Δu comes from ``laplacian_grid``.

    Args:
        params: UNet parameter tree.
        t_: time input (shape (1, 1) in the calls in this notebook).
        x_: initial-condition grid (``u0_batched``, shape (1, Nx, Ny, 1), in this notebook).

    Returns:
        Scalar mean of the squared residual.
    """
    # Time derivative of the network output w.r.t. t_.
    du_t = jac(params, unet_func, t_, x_)
    u_out = unet_func(params, x_, t_)
    # laplacian_grid presumably expects a 2-D (Nx, Ny) array: drop batch/channel axes, then restore them.
    lap_u = np.expand_dims(laplacian_grid(np.squeeze(u_out)), (0, -1))
    # KPP right-hand side; the logistic reaction term has a PLUS sign.
    rhs = D * lap_u + r * u_out * (1 - u_out)
    residual = du_t - rhs
    return np.mean(residual ** 2)

# 1 on the outermost ring of grid cells (where u = 0 is imposed), 0 in the interior; shape (1, Nx, Ny, 1).
mask_border = np.expand_dims(1. - np.pad(np.ones((Nx-2, Ny-2)), ((1, 1), (1,1))), (0, -1))

@jit
def mse_border(params, t_, x_):
    """Squared UNet output on the boundary ring (via ``mask_border``), averaged over the full grid.

    Interior cells contribute zeros to the mean, so the value is the border error diluted
    by the ratio of border cells to total cells.
    """
    prediction = unet_func(params, x_, t_)
    masked_sq = mask_border * prediction ** 2
    return masked_sq.mean()

In [None]:
# Evaluate both loss terms on the initial grid with the untrained UNet.
(mse_phy(unet_variables["params"], t_test, u0_batched), 
 mse_border(unet_variables["params"], t_test, u0_batched))

In [None]:
def loss(t_, params_, delta = 1e-1, loss_weights={'border': 1., 'phy': 1.}):
    """Weighted UNet loss at time ``t_``, starting from the initial grid ``u0_batched``.

    Args:
        t_: time input (shape (1, 1) in the calls in this notebook).
        params_: UNet parameter tree (the differentiated argument, index 1).
        delta: unused by this loss.
        loss_weights: weights of the 'border' and 'phy' terms.

    Returns:
        (total_loss, (loss_border, loss_phy)); only the first entry is differentiated
        (``has_aux=True`` below).
    """
    # Each term comes from its matching function; the unpacking order matters because the
    # weights and the logged values are keyed by these names.
    loss_border = mse_border(params_, t_, u0_batched)
    loss_phy = mse_phy(params_, t_, u0_batched)
    total_loss = loss_weights['border'] * loss_border + loss_weights['phy'] * loss_phy
    return (total_loss, (loss_border, loss_phy))

losses_and_grad = jit(jax.value_and_grad(loss, 1, has_aux=True))

Sanity check: we verify that we can take the gradient of the loss with respect to the UNet parameters:

In [None]:
# With has_aux=True, `grad` returns (gradients, aux); unpack so `grads` holds only the gradient tree.
grads, aux_losses = grad(loss, 1, has_aux=True)(t_test, unet_variables["params"])

In [None]:
# Testing the loss function at a random time t ~ U[0, 1) of shape (1, 1).
key, subkey = random.split(key)
t = random.uniform(subkey, shape=(1,1))

losses, grads = losses_and_grad(t, unet_variables["params"])

a, (b, c) = losses
print(f"total loss: {a:.3f}, mse_border: {b:.3f}, mse_phy: {c:.3f}")

In [None]:
from flax.training import train_state
import optax

def create_train_state(rng, learning_rate, momentum):
    """Build a Flax ``TrainState`` for the UNet, optimised with SGD + momentum.

    The module is configured exactly like the global ``unet`` (features=64, training=False).
    ``state.params`` is fed to ``losses_and_grad``, which applies it through ``unet.apply``,
    so both modules must share the same architecture.

    Args:
        rng: PRNG key used to initialise the parameters.
        learning_rate: SGD learning rate.
        momentum: SGD momentum coefficient.

    Returns:
        A ``train_state.TrainState`` holding ``apply_fn``, ``params`` and the optimiser state.
    """
    cnn = UNet(features=64, training=False)
    params = cnn.init(rng, np.ones([1, 128, 128, 1]), np.ones([1]))['params']
    tx = optax.sgd(learning_rate, momentum)
    return train_state.TrainState.create(
        apply_fn=cnn.apply, params=params, tx=tx)

In [None]:
@jax.jit
def train_step(state, key):
    """Run one optimiser step of the UNet on a single random time.

    A time ``t ~ U[0, 1)`` of shape (1, 1) is drawn from a subkey of ``key``, and the
    gradients of the UNet loss at that time are applied to ``state``.

    Returns:
        The updated ``TrainState``. ``key`` is not returned, so callers must pass a fresh
        key on every call to sample a different time.
    """
    _, t_key = random.split(key)
    t_sample = random.uniform(t_key, shape=(1, 1))
    _, grads = losses_and_grad(t_sample, state.params)
    return state.apply_gradients(grads=grads)

In [None]:
# Independent keys for parameter initialisation and for sampling training times.
key, init_key, train_key = random.split(key, 3)
state = create_train_state(init_key, 0.01, 0.99)
from tqdm.auto import tqdm


def train_iter(state, n_iter, rng):
    """Run ``n_iter`` training steps and return the updated state.

    A fresh subkey is split off ``rng`` at every step so each step samples a new time;
    reusing one key would train on the same ``t`` at every step.
    """
    for _ in tqdm(range(n_iter)):
        rng, step_rng = random.split(rng)
        state = train_step(state, step_rng)
    return state


# Keep the trained state: it is used by the display cell below.
state = train_iter(state, 10, train_key)

In [None]:
# Display the UNet prediction at a chosen time, starting from the initial grid.
# NOTE(review): `time` would shadow the standard-library module name if it is ever imported.
time = 0.0
out_test = unet_func(state.params, u0_batched, np.ones([1])* time)
plt.imshow(np.squeeze(out_test));