In [None]:
%load_ext autoreload
%autoreload 2

## Modèle KPP

Tools: jax, flax and optax (to install: `pip install jax jaxlib flax optax`).
If you are unfamiliar with Jax random generation, check [this](https://jax.readthedocs.io/en/latest/jax.random.html)


$$
\begin{equation} \label{eq:KPP_homog}
  \partial_t u(t,x) = D \Delta u + r u (1 - u), \ t>0, \ x\in \Omega \subset \mathbb{R}^2,
\end{equation}
$$


avec pour condition initiale $u(0,\cdot)=u_0(\cdot)$ une gaussienne très piquée (« peakée ») dans $\Omega$, et la condition au bord $u(t,\cdot)=0$ sur $\partial\Omega$ pour tout $t>0$. On pourra prendre $\Omega=(0,1)\times(0,1)$.


In [None]:
import jax.numpy as np
import jax
from jax import grad, jit, vmap, jacfwd, jacrev
from jax import random
from models.nets import MLP
from functools import partial

# Seeded PRNG: `key` is kept for later splits, `subkey` is consumed by the initialization.
key, subkey = random.split(random.PRNGKey(0))

# A single (t, x) test point used to trace the network and sanity-check shapes.
x_test = np.ones(2) * 0.25   # spatial point, shape (2,)
t_test = np.ones(1) * 0.25   # time, shape (1,)

# Presumably one layer per entry: seven hidden widths of 20, then a 1-unit output.
model = MLP(features=[20] * 7 + [1])
init_params = model.init(subkey, t_test, x_test)

@jit
def u(t, x, params_):
    """Network surrogate u_theta(t, x).

    Runs the MLP with parameters `params_` and returns the first (only) entry
    of its output, so the result is a scalar.
    """
    network_output = model.apply(params_, t, x)
    return network_output[0]

# `jax.tree_map` is deprecated and removed in recent JAX releases; the
# `jax.tree_util.tree_map` spelling works across versions.
print('initialized parameter shapes:\n', jax.tree_util.tree_map(np.shape, init_params))
# Label matches the call's argument order u(t, x, params).
print(f'\nu(t, x): {u(t_test, x_test, init_params):.3f}')

In [None]:
# Coefficients of the KPP equation: diffusion D and logistic growth rate r.
D = 1.0
r = 1.0

def hessian(f, index_derivation=0):
    """Return a function computing the Hessian of scalar-valued `f`.

    Derivatives are taken w.r.t. positional argument `index_derivation`:
    reverse mode for the gradient, then forward mode over it
    (forward-over-reverse).
    """
    gradient_fn = jacrev(f, index_derivation)
    return jacfwd(gradient_fn, index_derivation)

@jit
def f(t_, x_, params_):
    """KPP residual  u_t - D * Laplacian_x(u) - r * u * (1 - u)  at one point (t_, x_).

    Vanishes wherever the network output satisfies the PDE. The result has the
    shape of `grad(u, 0)`, i.e. the shape of `t_` ((1,) for the test point).
    """
    value = u(t_, x_, params_)
    dudt = grad(u, 0)(t_, x_, params_)
    # Laplacian = trace of the Hessian w.r.t. the spatial argument x_.
    hess_x = np.squeeze(hessian(u, 1)(t_, x_, params_))
    laplacian = np.trace(hess_x)
    reaction = r * value * (1 - value)
    return dudt - D * laplacian - reaction

In [None]:
# Network value and PDE residual at the test point, with untrained parameters.
(u(t_test, x_test, init_params), f(t_test, x_test, init_params))

In [None]:
def u0(x, amplitude=10.0, center=0.5, sharpness=100.0):
    """Initial condition u(0, x): a sharply peaked Gaussian bump.

        u0(x) = amplitude * exp(-sharpness * |x - center|^2)

    Args:
        x: point(s) with the 2 spatial coordinates on the last axis, e.g.
            shape (2,) for a single point or (N, 2) for a batch.
        amplitude: height of the bump (default 10, the original value).
        center: centre of the bump; scalar or length-2 array broadcast
            against `x` (default 0.5, i.e. the middle of (0,1)^2).
        sharpness: inverse squared width (default 100, the original value).

    Returns:
        Array with the last axis reduced: a scalar for a single point,
        shape (N,) for a batch of N points.
    """
    # Reduce over the last axis instead of `vmap(np.dot)`. This is identical
    # for (N, 2) batches but also correct for a single (2,) point, where vmap
    # would map over the two coordinates and return their squares instead.
    sq_dist = np.sum((x - center) ** 2, axis=-1)
    # NOTE(review): the default amplitude 10 exceeds the carrying capacity u = 1
    # of the logistic term r u (1 - u); confirm this is intended.
    return amplitude * np.exp(-sq_dist * sharpness)

In [None]:
def loss(batches, params_, weight_f=1.0):
    """PINN loss for the KPP problem.

    Args:
        batches: tuple (t_, x_, u_, tf_, xf_):
            t_, x_, u_ -- border/initial points and the target values of u there;
            tf_, xf_   -- collocation points where the PDE residual is penalised.
        params_: network parameters (the argument differentiated below).
        weight_f: weight of the physics (PDE residual) term. The original code
            hard-coded 0 here, which silently disabled the PDE so the network
            only fitted the border data.

    Returns:
        (total, (loss_u, loss_f)) with total = loss_u + weight_f * loss_f.
        Only `total` is differentiated (has_aux=True below).
    """
    t_, x_, u_, tf_, xf_ = batches

    # Physics: mean squared PDE residual over the collocation points.
    def residual_sq(t, x):
        return f(t, x, params_=params_) ** 2

    loss_f = np.mean(vmap(residual_sq, (0, 0), 0)(tf_, xf_))

    # Border / initial data: mean squared error against the target values.
    def data_sq_err(t, x, target):
        return np.mean((target - u(t, x, params_)) ** 2)

    loss_u = np.mean(vmap(data_sq_err, (0, 0, 0), 0)(t_, x_, u_))

    return (loss_u + weight_f * loss_f, (loss_u, loss_f))

losses_and_grad = jit(jax.value_and_grad(loss, 1, has_aux=True))

In [None]:
# Smoke test of the loss on a dummy batch of 10 points: border points
# (t=0, x=(0,0), target u=0.4) and collocation points (t=0.25, x=(0.25,0.25)).
losses, grads = losses_and_grad(
    (
        np.zeros((10, 1)),        # t_  : border times
        np.zeros((10, 2)),        # x_  : border positions
        np.ones((10, 1)) * 0.4,   # u_  : target values
        np.ones((10, 1)) * 0.25,  # tf_ : collocation times
        np.ones((10, 2)) * 0.25,  # xf_ : collocation positions
    ),
    init_params,
)

a, (b, c) = losses
print(f"total loss: {a:.3f}, mse_u: {b:.3f}, mse_f: {c:.3f}")

In [None]:
# FIXME(review): `dsb` is not defined anywhere in this notebook (it comes from a
# deleted cell), so this cell raised NameError on Restart & Run All and blocked
# every later cell. The same inspection is done on `ds` (the KPPDataset built
# further down); the original lines are kept below, disabled.
# tb, xb = dsb.inside_batch()
# tb.shape, xb.shape

In [None]:
# FIXME(review): disabled -- `dsb` is undefined on a fresh kernel (leftover from a
# deleted cell); `ds.border_batch()` is inspected further down instead.
# tb,xb,ub = dsb.border_batch()

In [None]:
# FIXME(review): disabled -- this depended on the `dsb` scratch cells above, and
# `dsb` is undefined on a fresh kernel; the `ds.border_batch()` shapes are shown
# further down.
# tb.shape, xb.shape, ub.shape

In [None]:
# Draw fresh subkeys from the notebook PRNG (`subkey4`/`subkey5` were otherwise
# undefined on a fresh kernel; this advances `key`), then build a (5, 1) array
# of 0.0/1.0 flags and a (5, 1) array of uniform samples.
key, subkey4, subkey5 = random.split(key, 3)
a1 = (random.uniform(subkey4, (5, 1)) > 0.5) * 1.0
a2 = random.uniform(subkey5, (5, 1))

In [None]:
# (5, 2) array: the flag column next to the uniform column.
np.hstack((a1, a2))

In [None]:
from data import datasets

# Exploratory dataset (batch_size=32, N_u=200).
# NOTE(review): `ds` is rebuilt below with batch_size=64, N_u=500, which replaces this one.
key, subkey = random.split(key)
ds = datasets.KPPDataset(subkey, u0, batch_size=32, N_u=200)

In [None]:
# Shapes of one batch of interior (collocation) points.
tb, xb = ds.inside_batch()
(tb.shape, xb.shape)

In [None]:
# Shapes of one batch of border points together with their target values.
tb, xb, ub = ds.border_batch()
(tb.shape, xb.shape, ub.shape)

#### Data and learning

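We build $N_u$ boundary data points, as mentioned in the paper (which uses $N_u = 100$; the cell below uses `N_u=500`): half of them at $t=0$ (initial condition $u_0$), the other half on the spatial boundary $x \in \partial\Omega$ (where $u=0$). They are wrapped into a dataset class.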

In [None]:
from data import datasets

# Dataset used for training (batch_size=64, N_u=500).
key, subkey = random.split(key)
ds = datasets.KPPDataset(subkey, u0, batch_size=64, N_u=500)

In [None]:
# Sample one border batch for inspection. The time component is bound to `tb`
# rather than `_`: in IPython, assigning `_` shadows the last-output variable
# and IPython then stops updating `_`, `__` and `___`.
tb, xb, ub = ds.border_batch()

In [None]:
# Optimizer: Adam (learning rate 1e-3) on a fresh initialization of the network.
import optax

key, subkey = random.split(key)
params = model.init(subkey, t_test, x_test)  # drawn from a new subkey, not `init_params`
tx = optax.adam(learning_rate=1e-3)
opt_state = tx.init(params)

In [None]:
# Main train loop.
# The whole optimization step (loss/gradient evaluation + Adam update) is
# compiled as one XLA computation; previously the optax update ran op-by-op
# in Python at every iteration.
@jit
def train_step(params_, opt_state_, batches):
    """Apply one Adam update; return (new params, new optimizer state, losses).

    `losses` are evaluated at the parameters *before* the update, as in the
    original loop.
    """
    step_losses, step_grads = losses_and_grad(batches, params_)
    step_updates, new_opt_state = tx.update(step_grads, opt_state_)
    return optax.apply_updates(params_, step_updates), new_opt_state, step_losses

steps = 5000
for i in range(steps):
    tb, xb, ub = ds.border_batch()
    tb_uni, xb_uni = ds.inside_batch()

    params, opt_state, losses = train_step(params, opt_state, (tb, xb, ub, tb_uni, xb_uni))
    total_loss_val, (mse_u_val, mse_f_val) = losses

    if i % 100 == 99:
        print(f'Loss at step {i+1}: {total_loss_val:.4f} / mse_u: {mse_u_val:.4f} / mse_f: {mse_f_val:.4f}')

#### Display


In [None]:
# Vectorized evaluators (t, x) -> u over batches of points.
# `partial` binds the *current* trained `params` now; a lambda would read the global later.
batched_u = vmap(partial(u, params_=params), in_axes=(0, 0), out_axes=0)

def batched_u0(t, x):
    """Initial condition seen as a (t, x) function; `t` is ignored."""
    return u0(x)

In [None]:
#def u0(x):
#    return 10 *np.exp(-vmap(np.dot)(x-0.5,x-0.5)*100)

In [None]:
from data.display import display_KPP_at_times
# Plot the initial condition at t = 0.
# NOTE(review): 30 is presumably the plotting grid resolution -- confirm in data.display.
display_KPP_at_times(batched_u0, 30, times=[0.0])

In [None]:
# Trained solution at four equally spaced times: 0, 0.25, 0.5, 0.75.
display_KPP_at_times(batched_u, 30, times=[0.25 * k for k in range(4)])