## Burgers' equation

Tools: JAX, Flax, Optax and Matplotlib (install with `pip install jax jaxlib flax optax matplotlib`)

Goal: have a first simple 1D model to work with.

https://arxiv.org/pdf/1711.10561.pdf

Burgers' equation, with its initial and boundary conditions, becomes:
$$
\begin{aligned}
&u_t + u \, u_x - (0.01/\pi) \, u_{xx} = 0, \quad x \in [-1, 1], \quad t \in [0, 1], \\
&u(0, x) = -\sin(\pi x), \\
&u(t, -1) = u(t, 1) = 0
\end{aligned}
$$

In [None]:
import jax.numpy as np
import jax
from jax import grad, jit, vmap, jacfwd, jacrev
from jax import random
import flax
from flax import linen as nn
from flax.core import freeze, unfreeze
from typing import Sequence

# JAX uses explicit random number generation: the PRNG state is a
# key that is passed around by hand. This is much less error prone
# than the hidden global state used by tensorflow or pytorch, but
# more cumbersome. You need to split the key into 2 parts and use
# the sub-parts to generate your random numbers.
key = random.PRNGKey(0)
key, subkey = random.split(key)


class MLP(nn.Module):
    """Fully connected network applied to the concatenation of ``t`` and ``x``.

    One ``nn.Dense`` layer is created per entry of ``features``; ``tanh`` is
    applied after every layer except the last one.
    """

    # Output width of each dense layer, in order.
    features: Sequence[int]

    def setup(self):
        """Create one dense layer per requested width."""
        self.layers = [nn.Dense(width) for width in self.features]

    def __call__(self, t, x):
        """Return the network output for the inputs ``t`` and ``x``.

        ``t`` and ``x`` are joined with ``np.concatenate`` before the first
        layer.
        """
        activation = np.concatenate((t, x))
        depth = len(self.layers)
        for position, dense in enumerate(self.layers, start=1):
            activation = dense(activation)
            # No non-linearity after the output layer.
            if position < depth:
                activation = nn.tanh(activation)
        return activation

# Dummy inputs, used only to initialise the model and run it once.
x = np.zeros((1,))
t = np.zeros((1,))

model = MLP(features=[3,4,5,1])
params = model.init(subkey, t, x)
y = model.apply(params, t, x)

# `jax.tree_map` is deprecated (and removed in recent JAX releases);
# `jax.tree_util.tree_map` is the supported spelling of the same function.
print('initialized parameter shapes:\n', jax.tree_util.tree_map(np.shape, unfreeze(params)))
print('output:\n', y)

def u(t, x, params):
    """Evaluate the network at ``(t, x)`` and return its first output component.

    Uses the module-level ``model``; ``params`` is the parameter tree handed
    to ``model.apply``.
    """
    prediction = model.apply(params, t, x)
    return prediction[0]

In [None]:
# Initial condition at t = 0
def u0(x):
    """Return ``-sin(pi * x)``, the value used for ``u`` at ``t = 0``."""
    return np.negative(np.sin(np.pi * x))

# Second derivative (used for u_xx)
def hessian(f, index_derivation=0):
    """Return a function computing the second derivative of ``f``.

    Both differentiations are taken with respect to the positional argument
    number ``index_derivation``: reverse mode first, then forward mode.
    """
    first_derivative = jacrev(f, argnums=index_derivation)
    return jacfwd(first_derivative, argnums=index_derivation)

@jit
def f(t_, x_, params_):
    """Evaluate the left-hand side of the PDE for the network ``u``.

    Computes ``u_t + u * u_x - (0.01 / pi) * u_xx`` at ``(t_, x_)``, where the
    derivatives are obtained by automatic differentiation of ``u``.
    """
    value = u(t_, x_, params_)
    d_dt = grad(u, argnums=0)(t_, x_, params_)
    d_dx = grad(u, argnums=1)(t_, x_, params_)
    d2_dx2 = hessian(u, 1)(t_, x_, params_)
    viscosity = 0.01 / np.pi
    return d_dt + value * d_dx - viscosity * d2_dx2

In [None]:
# A test point: t = x = 0.25, each stored as a length-1 array
x_test = np.ones(1) * 0.25
t_test = np.ones(1) * 0.25

# Smoke test: network output and PDE residual at the test point.
u(t_test, x_test, params), f(t_test, x_test, params)

In [None]:
def mse_f(t_, x_, params_):
    """Return the mean of the squared PDE residual ``f`` at ``(t_, x_)``."""
    residual = f(t_, x_, params_)
    return np.mean(np.square(residual))

mse_f(t_test, x_test, params)

In [None]:
# vmap maps mse_f over the leading axis of t and x; params are shared
batched_mse_f = vmap(mse_f, in_axes=(0, 0, None), out_axes=0)

def loss_f(t_, x_, params_):
    """Return the sum of ``mse_f`` over a batch of ``(t_, x_)`` points."""
    per_point = batched_mse_f(t_, x_, params_)
    return np.sum(per_point)

loss_f(np.zeros((10, 1)), np.zeros((10, 1)), params)

In [None]:
def mse_u(t_, x_, u_, params_):
    """Return the mean squared difference between ``u_`` and the network ``u``."""
    error = u_ - u(t_, x_, params_)
    return np.mean(np.square(error))

mse_u(np.zeros(1), np.zeros(1), np.zeros(1), params)

In [None]:
# vmap maps mse_u over the leading axis of t, x and u; params are shared
batched_mse_u = vmap(mse_u, in_axes=(0, 0, 0, None), out_axes=0)

def loss(t_, x_, u_, params_):
    """Return the sum of ``mse_u`` over a batch of ``(t_, x_, u_)`` samples."""
    per_point = batched_mse_u(t_, x_, u_, params_)
    return np.sum(per_point)

loss(np.zeros((10, 1)), np.zeros((10, 1)), np.zeros((10, 1)), params)

In [None]:
# total loss, todo
def mse(params_, t_u_, x_u_, u_, t_f_, x_f_):
    """Return the data-fit error ``mse_u`` plus the PDE residual error ``mse_f``.

    ``(t_u_, x_u_, u_)`` is the point with a known target value and
    ``(t_f_, x_f_)`` the point where the residual is evaluated.
    """
    data_term = jit(mse_u)(t_u_, x_u_, u_, params_)
    residual_term = jit(mse_f)(t_f_, x_f_, params_)
    return data_term + residual_term

In [None]:
# Value and gradient of `mse`; no argnums is given, so the gradient is taken
# with respect to the first argument (the parameters).
loss_grad_fn = jax.value_and_grad(mse)

# Outputs 2 objects: the loss and its gradients with respect to every parameter
loss_grad_fn(params, np.ones(1), np.zeros(1), np.zeros(1), np.ones(1), np.zeros(1))

#### Data and learning

We build $N_u = 100$ boundary data points, as mentioned in the paper: about half of them at $t=0$, the other half at $x = \pm 1$. We wrap them in a dataset class.

In [None]:
class Dataset():
    """Training points on the initial line ``t = 0`` and the walls ``x = -1, 1``.

    ``N_u`` points are generated once at construction time and served in
    consecutive mini-batches by ``next_batch``. ``uniform_batch`` draws
    fresh random points inside the domain ``t in [0, 1)``, ``x in [-1, 1)``.
    """

    def __init__(self, key, batch_size=10, N_u=100):
        """Generate the ``N_u`` points from the PRNG ``key``.

        Args:
            key: JAX PRNG key used to generate the data.
            batch_size: number of points returned per batch.
            N_u: total number of generated points.
        """
        self.batch_size = batch_size
        self.N_u = N_u
        self.curr_idx = 0
        x_, t_, u_ = self.generate_data(key, N_u)
        self.x_data = x_
        self.t_data = t_
        self.u_data = u_

    def generate_data(self, key, N_u):
        """Return ``(x_data, t_data, u_data)``, three arrays of shape ``(N_u,)``.

        Each point is, with probability 1/2, an initial-condition point
        (``t = 0``, ``x`` uniform in ``[-1, 1)``, ``u = u0(x)``) and otherwise
        a wall point (``x`` equal to -1 or 1, ``t`` uniform in ``[0, 1)``,
        ``u = 0``).
        """
        # One independent sub-key per random draw: a key must never be used
        # for sampling and then split again.
        type_key, x_key, side_key, t_key = random.split(key, 4)
        is_initial = random.uniform(type_key, (N_u,)) > 0.5
        x_initial = random.uniform(x_key, (N_u,)) * 2.0 - 1.0
        x_wall = (random.uniform(side_key, (N_u,)) > 0.5) * 2.0 - 1.0
        x_data = np.where(is_initial, x_initial, x_wall)
        t_data = np.where(is_initial, 0.0, random.uniform(t_key, (N_u,)))
        u_data = np.where(is_initial, u0(x_data), 0.0)
        return x_data, t_data, u_data

    def next_batch(self):
        """Return the next ``(t, x, u)`` mini-batch, each of shape ``(n, 1)``.

        Batches are consecutive slices of the stored data; after the slice
        that reaches the end of the data the iteration restarts from the
        beginning. ``n`` equals ``batch_size`` except possibly for the last
        slice when ``N_u`` is not a multiple of ``batch_size``.
        """
        bstart = self.curr_idx * self.batch_size
        # Clip to N_u (exclusive slice end) so the last sample is included.
        bend = min(bstart + self.batch_size, self.N_u)
        if bend >= self.N_u:
            self.curr_idx = 0
        else:
            self.curr_idx = self.curr_idx + 1
        x_ = np.expand_dims(self.x_data[bstart:bend], -1)
        t_ = np.expand_dims(self.t_data[bstart:bend], -1)
        u_ = np.expand_dims(self.u_data[bstart:bend], -1)
        return t_, x_, u_

    def uniform_batch(self, key):
        """Return ``(tb, xb)``, ``batch_size`` random points inside the domain.

        ``tb`` is uniform in ``[0, 1)`` and ``xb`` uniform in ``[-1, 1)``;
        both have shape ``(batch_size, 1)``.
        """
        key1, key2 = random.split(key, 2)
        tb = random.uniform(key1, (self.batch_size, 1))
        # The spatial domain is [-1, 1], not the default [0, 1).
        xb = random.uniform(key2, (self.batch_size, 1), minval=-1.0, maxval=1.0)
        return tb, xb
    
# Give the dataset its own sub-key: `key` is split again later on, so passing
# it directly would make the dataset and later draws share a random stream.
key, data_key = random.split(key)
ds = Dataset(data_key)

In [None]:
# Optimizer

import optax
# Re-initialise the network parameters with a newly split sub-key.
key, subkey = random.split(key, 2)
params = model.init(subkey, t, x)
# Stochastic gradient descent with a fixed learning rate of 0.05.
tx = optax.sgd(learning_rate=0.05)
opt_state = tx.init(params)

In [None]:
steps = 1000
# NOTE(review): `batch` is not referenced in this cell; the batch size
# actually used is the one stored in `ds`.
batch = 10
# Loss value and gradient with respect to the parameters, which are
# argument 3 of `loss` and argument 2 of `loss_f`.
loss_grad_fn = jit(jax.value_and_grad(loss, 3))
loss_grad_f_fn = jit(jax.value_and_grad(loss_f, 2))

# Main train loop
for i in range(steps):
    # Data-fit loss on the next mini-batch of the dataset.
    tb, xb, ub = ds.next_batch()
    loss_val, grads = loss_grad_fn(tb, xb, ub, params)

    # f_loss: computed on random points for monitoring only; its gradient
    # `grads_f` is not applied to the parameters for now.
    key, subkey = random.split(key, 2)
    tb_uni, xb_uni = ds.uniform_batch(subkey)
    loss_f_val, grads_f = loss_grad_f_fn(tb_uni, xb_uni, params)
    
    # Parameter update from the data-fit gradient only.
    updates, opt_state = tx.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    
    # Disabled update from the PDE residual gradient:
    #updates, opt_state = tx.update(grads_f, opt_state)
    #params = optax.apply_updates(params, updates)
    if i % 10 == 0:
        print(f'Loss step {i}: mse_u: {loss_val:.2f} / mse_f: {loss_f_val:.2f}')

#### Display


In [None]:
# Network output at t = 0 and x = x_test with the current parameters.
u(np.zeros(1,), x_test, params)

In [None]:
import matplotlib.pyplot as plt

def display_u(resolution):
    """Plot the network prediction ``u(t, x)`` on a regular grid.

    The horizontal axis is ``t`` in ``[0, 1]`` and the vertical axis is ``x``
    in ``[-1, 1]``, with ``x = -1`` at the bottom. The module-level
    ``params`` are used for the prediction.

    Args:
        resolution: number of grid points along each of the two axes.
    """
    nt, nx = resolution, resolution
    t_axis = np.linspace(0, 1, nt)
    x_axis = np.linspace(-1, 1, nx)
    # With the default 'xy' indexing both arrays have shape (nx, nt):
    # rows follow x and columns follow t.
    tv, xv = np.meshgrid(t_axis, x_axis)

    batched_u = vmap(u, (0, 0, None), 0)
    values = batched_u(tv.reshape((-1, 1)), xv.reshape((-1, 1)), params)
    grid = values.reshape(nx, nt)

    # `extent` and `origin` make the axes show the real (t, x) coordinates
    # instead of hand-computed tick labels on pixel indices.
    plt.imshow(grid, extent=(0.0, 1.0, -1.0, 1.0), origin='lower', aspect='auto')
    plt.xlabel('t')
    plt.ylabel('x')
    plt.colorbar()

In [None]:
# Plot the prediction with 100 grid points per axis.
display_u(100)

todo: 
- sample $N_f$ points for evaluation of $f$ 

## Modèle KPP

$$
\begin{equation} \label{eq:KPP_homog}
  \partial_t u(t,x) = D \Delta u + r u (1 - u), \ t>0, \ x\in \Omega \subset \mathbb{R}^2,
\end{equation}
$$


avec la condition initiale $u(0,\cdot)=u_0(\cdot)$ dans $\Omega$ et la condition au bord $u(t,\cdot )=0$ sur $\partial\Omega$ pour tout $t>0$. On pourra prendre $\Omega=(0,1)\times(0,1)$.

To be continued.

In [None]:
### OLD

import jax.numpy as np
from jax import grad, jit, vmap, jacfwd, jacrev
from jax import random
import flax
from flax import linen as nn


# JAX uses explicit random number generation: the PRNG state is a
# key that is passed around by hand. This is much less error prone
# than the hidden global state used by tensorflow or pytorch, but
# more cumbersome. You need to split the key into 2 parts and use
# the sub-parts to generate your random numbers.
key = random.PRNGKey(0)
key, subkey = random.split(key)

# Parameters of the linear model below: a (2, 1) weight matrix drawn from
# a normal distribution and a bias initialised to zero.
W = random.normal(subkey, (2, 1))
b = np.zeros(1)
params = (W, b)

# Simple u function: only 3 params!
def u(t, x, params):
    """Return the affine map ``[t, x] . W + b`` with ``(W, b) = params``."""
    weights, bias = params
    return np.concatenate((t, x)) @ weights + bias

For beginners: a less NumPy-idiomatic way to build the data (with explicit loops)

In [None]:
# Number of training points to generate.
N_u = 100
data = []

# Build the points one at a time; every random draw gets its own sub-key.
for i in range(N_u):
    x_data,t_data,u_data = 0.0, 0.0, 0.0
    key, subkey = random.split(key)
    # Coin flip between an initial-condition point and a wall point.
    if  random.uniform(subkey)>0.5:
        t_data = 0
        key, subkey = random.split(key)
        # x uniform in [-1, 1)
        x_data = random.uniform(subkey)*2-1
        u_data = u0(x_data)
        # t=0, u= -sin(pi x)
    else:
        key, subkey = random.split(key)
        # t uniform in [0, 1)
        t_data = random.uniform(subkey)
        key, subkey = random.split(key)
        # x is either -1 or 1, where u is set to 0
        x_data = (random.uniform(subkey)>0.5)*2-1
        u_data = 0
    data.append([x_data,t_data,u_data])
# Display the generated list of [x, t, u] triples.
data