In [1]:
%load_ext autoreload
%autoreload 2

## Modèle KPP

Tools: jax (to install `pip install jax jaxlib flax`)
If you are unfamiliar with Jax random generation, check [this](https://jax.readthedocs.io/en/latest/jax.random.html)


$$
\begin{equation} \label{eq:KPP_homog}
  \partial_t u(t,x) = D \Delta u + r u (1 - u), \ t>0, \ x\in \Omega \subset \mathbb{R}^2,
\end{equation}
$$


avec la condition initiale $u(0,\cdot)=u_0(\cdot)$ une gausienne dans $\Omega$ très "peakée", et la condition au bord $u(t,\cdot )=0$ sur $\partial\Omega$ pour tout $t>0$. On pourra prendre $\Omega=(0,1)\times(0,1)$.


In [2]:
import jax.numpy as np
import jax
from jax import grad, jit, vmap, jacfwd, jacrev
from jax import random
from models.nets import MLP
from functools import partial

key = random.PRNGKey(0)
key, subkey = random.split(key)

# A test point
x_test = np.ones(2) * 0.25
t_test = np.ones(1) * 0.25

model = MLP(features=[20,40,80,80,80,40,20, 1])
init_params = model.init(subkey, t_test, x_test)

@jit
def u(t, x, params_):
    return model.apply(params_, t, x)[0]

print('initialized parameter shapes:\n', jax.tree_map(np.shape, init_params))
print(f'\nu(x, t): {u(t_test, x_test, init_params):.3f}')



initialized parameter shapes:
 FrozenDict({
    params: {
        layers_0: {
            bias: (20,),
            kernel: (3, 20),
        },
        layers_1: {
            bias: (40,),
            kernel: (20, 40),
        },
        layers_2: {
            bias: (80,),
            kernel: (40, 80),
        },
        layers_3: {
            bias: (80,),
            kernel: (80, 80),
        },
        layers_4: {
            bias: (80,),
            kernel: (80, 80),
        },
        layers_5: {
            bias: (40,),
            kernel: (80, 40),
        },
        layers_6: {
            bias: (20,),
            kernel: (40, 20),
        },
        layers_7: {
            bias: (1,),
            kernel: (20, 1),
        },
    },
})

u(x, t): -0.026


In [3]:
D = 1.0
r = 1.0

def hessian(f, index_derivation=0):
    return jacfwd(jacrev(f,index_derivation),index_derivation)

def pde_rhs(t_, x_, params_):
    u_out = u(t_, x_, params_)
    lap_u = np.trace(np.squeeze(hessian(u,1)(t_, x_, params_)))
    return D*lap_u - r*u_out*(1-u_out)

@jit
def f(t_, x_, params_):
    u_t = grad(u,0)(t_, x_, params_)
    f_out = u_t - pde_rhs(t_, x_, params_)
    return f_out

In [4]:
u(t_test, x_test, init_params), f(t_test, x_test, init_params)

(DeviceArray(-0.02555121, dtype=float32),
 DeviceArray([-0.02761158], dtype=float32))

In [5]:
def u0(x):
    return np.clip(10 *np.exp(-vmap(np.dot)(x-0.5,x-0.5)*100), 0.0, 1.0)

In [6]:
def loss(batches, params_, delta = 1e-1, loss_weights={'u': 1.0/3, 'f': 1.0/3, 'delta': 1.0/3}):
    t_, x_, u_, tf_, xf_ = batches
    
    # Physics with mse_f
    mse_f = lambda t,x: partial(f, params_=params_)(t,x)**2
    v_mse_f = vmap(mse_f, (0,0), 0)
    loss_f = np.mean(v_mse_f(tf_, xf_))
    
    # Delta physics
    def mse_delta(t_, x_, params_):
        return (u(t_ + delta, x_, params_) - u(t_, x_, params_) - delta * pde_rhs(t_, x_, params_))**2
    v_mse_delta = vmap(partial(mse_delta, params_=params_), (0,0), 0)
    loss_delta = .5 * (np.mean(v_mse_delta(t_, x_)) + np.mean(v_mse_delta(tf_, xf_)))

    # Borders with mse_u
    def mse_u(t_, x_, u_, params_):
        return np.squeeze((u_ - u(t_, x_, params_))**2)
    v_mse_u = vmap(partial(mse_u, params_=params_), (0,0,0), 0)
    loss_u = np.mean(v_mse_u(t_, x_, u_))
    
    # total loss, then aux loss values. Only the first output is differentiated (because of has_aux=True below)
    total_loss = loss_weights['u'] * loss_u + loss_weights['f'] * loss_f + loss_weights['delta'] * loss_delta
    return (total_loss, (loss_u, loss_f, loss_delta))

losses_and_grad = jit(jax.value_and_grad(loss, 1, has_aux=True))

In [7]:
# Testing the loss function
losses, grads = losses_and_grad((np.zeros((10, 1)), 
                                 np.zeros((10, 2)), 
                                 np.ones((10, 1))*0.4, 
                                 np.ones((10, 1))*0.25,
                                 np.ones((10, 2))*0.25),
                                 init_params)


a, (b, c, d) = losses
print(f"total loss: {a:.3f}, mse_u: {b:.3f}, mse_f: {c:.3f}, mse_delta: {d:.7f}")

total loss: 0.054, mse_u: 0.160, mse_f: 0.001, mse_delta: 0.0000000


#### Data and learning

We build $N_u = 100$ boundary data points as mentionned in the paper. Half of them for $t=0$, the other half for $x= \pm 1$. Wrap it into a dataset class

In [8]:
from data import datasets

key, subkey = random.split(key, 2)
ds = datasets.KPPDataset(subkey, u0, batch_size=128, N_u=128*5)

In [9]:
# Optimizer
import optax
key, subkey = random.split(key, 2)
params = model.init(subkey, t_test, x_test)
tx = optax.adam(learning_rate=0.0003)
opt_state = tx.init(params)

In [10]:
# Main train loop
steps = 10000
losses_total, losses_u, losses_f, losses_delta = [], [], [], []

for i in range(steps):
    key, subkey1, subkey2 = random.split(key, 3)
    tb, xb, ub = ds.border_batch(subkey1)
    tb_uni, xb_uni = ds.inside_batch(subkey2)

    losses, grads = losses_and_grad((tb, xb, ub, tb_uni, xb_uni), 
                                    params,
                                    loss_weights={'u': 2, 'f': 1, 'delta': 1})
    updates, opt_state = tx.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    total_loss_val, (mse_u_val, mse_f_val, mse_delta_val) = losses
    losses_total.append(total_loss_val)
    losses_u.append(mse_u_val)
    losses_f.append(mse_f_val)
    losses_delta.append(mse_delta_val)    
    if i % 100 == 99:
        print(f'Loss at step {i+1}: {total_loss_val:.4f} / mse_u: {np.log10(mse_u_val):.4f} / mse_f: {np.log10(mse_f_val):.4f} /  mse_delta: {np.log(mse_delta_val):.6f}') 

Loss at step 100: 0.0856 / mse_u: -1.3749 / mse_f: -2.9086 /  mse_delta: -20.427872
Loss at step 200: 0.0196 / mse_u: -2.0638 / mse_f: -2.6281 /  mse_delta: -19.878386
Loss at step 300: 0.0104 / mse_u: -2.4322 / mse_f: -2.5167 /  mse_delta: -19.684092
Loss at step 400: 0.0092 / mse_u: -2.4754 / mse_f: -2.6067 /  mse_delta: -19.088905
Loss at step 500: 0.0022 / mse_u: -3.4203 / mse_f: -2.8520 /  mse_delta: -19.861916
Loss at step 600: 0.0061 / mse_u: -2.6042 / mse_f: -2.9377 /  mse_delta: -19.744234
Loss at step 700: 0.0047 / mse_u: -2.7359 / mse_f: -2.9983 /  mse_delta: -20.061142
Loss at step 800: 0.0034 / mse_u: -2.9913 / mse_f: -2.8641 /  mse_delta: -19.517435
Loss at step 900: 0.0019 / mse_u: -3.7133 / mse_f: -2.8106 /  mse_delta: -19.844921
Loss at step 1000: 0.0032 / mse_u: -2.9441 / mse_f: -3.0495 /  mse_delta: -19.740261
Loss at step 1100: 0.0015 / mse_u: -3.8653 / mse_f: -2.9176 /  mse_delta: -20.290577
Loss at step 1200: 0.0016 / mse_u: -3.3585 / mse_f: -3.1667 /  mse_delta: 

In [None]:
import matplotlib.pyplot as plt
plt.plot(np.log10(np.array(losses_total)), label="total")
plt.plot(np.log10(np.array(losses_u)), label="mse_u")
plt.plot(np.log10(np.array(losses_f)), label="mse_f")
plt.plot(np.log10(np.array(losses_delta)), label="mse_delta")
#plt.plot(losses_f, label="mse_f")
plt.legend()
plt.show();

In [None]:
plt.plot(losses_u)

#### Display


In [None]:
batched_u = vmap(partial(u, params_=params), (0, 0), 0)
batched_u0 = lambda t,x: u0(x)

In [None]:
display_KPP_at_times(batched_u, 30, times=[0.0,0.01,0.1,0.2,1])

In [None]:
# Copying the current params
super_params = params.unfreeze().copy()