In [None]:
%load_ext autoreload
%autoreload 2

## Burgers' equation

Tools: JAX, Flax and Optax (install with `pip install jax jaxlib flax optax`).
If you are unfamiliar with JAX random number generation, see the [`jax.random` documentation](https://jax.readthedocs.io/en/latest/jax.random.html).

Goal: build a first simple 1D model to work with, similar to [this paper](https://arxiv.org/pdf/1711.10561.pdf).



The problem we solve (Burgers' equation) reads:
$$
u_t + u\,u_x - \frac{0.01}{\pi}\,u_{xx} = 0, \quad x \in [-1, 1],\ t \in [0, 1], \\
u(0, x) = -\sin(\pi x), \\
u(t, -1) = u(t, 1) = 0
$$

In [None]:
import jax.numpy as np
import jax
from jax import grad, jit, vmap, jacfwd, jacrev
from jax import random
from models.nets import MLP
from functools import partial

key = random.PRNGKey(0)
key, subkey = random.split(key)

# A test point
x_test = np.ones(1) * 0.25
t_test = np.ones(1) * 0.25

model = MLP(features=[20,20,20,20,20,20,20, 1])
init_params = model.init(subkey, t_test, x_test)

@jit
def u(t, x, params_):
    """Network prediction u(t, x): the first entry of the MLP output.

    ``params_`` is a parameter pytree as returned by ``model.init`` (or a trained copy).
    """
    return model.apply(params_, t, x)[0]

# `jax.tree_map` is deprecated; `jax.tree_util.tree_map` is the supported equivalent.
print('initialized parameter shapes:\n', jax.tree_util.tree_map(np.shape, init_params))
print(f'\nu(x, t): {u(t_test, x_test, init_params):.3f}')

In [None]:
# Initial (t = 0) condition of the Burgers problem
def u0(x):
    """Return u(0, x) = -sin(pi * x), elementwise over ``x``."""
    angle = np.pi * x
    return -np.sin(angle)

# Second derivatives (e.g. u_xx) via forward-over-reverse differentiation
def hessian(f, index_derivation=0):
    """Return a function computing the Hessian of ``f`` w.r.t. argument number ``index_derivation``."""
    first_derivative = jacrev(f, index_derivation)
    return jacfwd(first_derivative, index_derivation)

@jit
def f(t, x, params_):
    """Burgers residual u_t + u * u_x - (0.01 / pi) * u_xx at a single point (t, x).

    Returns ``np.squeeze`` of the residual; presumably a scalar when t and x
    have shape (1,).
    """
    args = (t, x, params_)
    viscosity = 0.01 / np.pi
    u_val = u(*args)
    du_dt = grad(u, 0)(*args)
    du_dx = grad(u, 1)(*args)
    # Hessian w.r.t. x is (1, 1) for 1D x; keep its first row.
    d2u_dx2 = hessian(u, 1)(*args)[0]
    residual = du_dt + u_val * du_dx - viscosity * d2u_dx2
    return np.squeeze(residual)

In [None]:
# Sanity check: network output and PDE residual at the test point
(u(t_test, x_test, init_params), f(t_test, x_test, init_params))

In [None]:
def loss(batches, params_):
    """Physics-informed loss for Burgers' equation.

    ``batches`` is ``(t_border, x_border, u_border, t_collocation, x_collocation)``.
    Returns ``(loss_f + loss_u, (loss_u, loss_f))`` so that it can be used with
    ``jax.value_and_grad(..., has_aux=True)``.
    """
    t_b, x_b, u_b, t_col, x_col = batches

    # Physics term: squared PDE residual at each collocation point, averaged.
    def squared_residual(t, x):
        return f(t, x, params_=params_) ** 2
    loss_f = np.mean(vmap(squared_residual, in_axes=(0, 0), out_axes=0)(t_col, x_col))

    # Border term: squared error against the known values, averaged.
    def point_error(t, x, u_target):
        return np.mean((u_target - u(t, x, params_)) ** 2)
    loss_u = np.mean(vmap(point_error, in_axes=(0, 0, 0), out_axes=0)(t_b, x_b, u_b))

    # Only the first element is differentiated; the nested tuple is auxiliary output.
    return loss_f + loss_u, (loss_u, loss_f)

losses_and_grad = jit(jax.value_and_grad(loss, 1, has_aux=True))

In [None]:
# Testing the loss function on a dummy batch of 10 constant points, ordered as
# (t_border, x_border, u_border, t_collocation, x_collocation)
losses, grads = losses_and_grad(
    (
        np.zeros((10, 1)),
        np.zeros((10, 1)),
        np.ones((10, 1)) * 0.4,
        np.ones((10, 1)) * 0.25,
        np.ones((10, 1)) * 0.25,
    ),
    init_params,
)

a, (b, c) = losses
print(f"total loss: {a:.3f}, mse_u: {b:.3f}, mse_f: {c:.3f}")

#### Data and learning

We build the boundary data points as described in the paper: half of them lie on $t=0$, the other half on $x = \pm 1$. The paper uses $N_u = 100$ boundary points; the code below uses `N_u=200`. The sampling is wrapped in a dataset class.

In [None]:
from data import datasets

# Fresh subkey for the dataset's random sampling
key, subkey = random.split(key)
ds = datasets.BurgersDataset(subkey, u0, N_u=200, batch_size=32)

In [None]:
# Optimizer
import optax
# Re-initialize the network with a fresh subkey and attach an Adam optimizer
key, subkey = random.split(key)
params = model.init(subkey, t_test, x_test)
tx = optax.adam(learning_rate=1e-3)
opt_state = tx.init(params)

In [None]:
# Main train loop: one Adam step per (border batch, interior batch) pair
steps = 5000
for i in range(steps):
    tb, xb, ub = ds.border_batch()
    tb_uni, xb_uni = ds.inside_batch()

    losses, grads = losses_and_grad((tb, xb, ub, tb_uni, xb_uni), params)
    total_loss_val, (mse_u_val, mse_f_val) = losses

    updates, opt_state = tx.update(grads, opt_state)
    params = optax.apply_updates(params, updates)

    # Report every 100 completed steps
    if (i + 1) % 100 == 0:
        print(f'Loss at step {i+1}: {total_loss_val:.4f} / mse_u: {mse_u_val:.4f} / mse_f: {mse_f_val:.4f}')

#### Display


In [None]:
# Batched evaluation of the trained network: `params` is bound now via partial,
# and vmap maps over the leading axis of both t and x.
batched_u = vmap(partial(u, params_=params), in_axes=(0, 0), out_axes=0)

In [None]:
from data.display import display_burgers_grid, display_burgers_slice

# Plot the learned solution; 100 is presumably the grid resolution.
# NOTE(review): confirm against data.display.display_burgers_grid.
display_burgers_grid(batched_u, 100)

In [None]:
# Profiles of the learned solution at the listed slices (presumably times t);
# 30 is presumably the number of samples per slice.
# NOTE(review): confirm against data.display.display_burgers_slice.
display_burgers_slice(batched_u, 30, slices=[0.0, 0.25, 0.5, 0.75])

## Modèle KPP

$$
\begin{equation} \label{eq:KPP_homog}
  \partial_t u(t,x) = D \Delta u + r u (1 - u), \ t>0, \ x\in \Omega \subset \mathbb{R}^2,
\end{equation}
$$


avec la condition initiale $u(0,\cdot)=u_0(\cdot)$ dans $\Omega$ et la condition au bord $u(t,\cdot )=0$ sur $\partial\Omega$ pour tout $t>0$. On pourra prendre $\Omega=(0,1)\times(0,1)$.

In [None]:
import jax.numpy as np
import jax
from jax import grad, jit, vmap, jacfwd, jacrev
from jax import random
import flax
from flax import linen as nn
from flax.core import freeze, unfreeze
from typing import Sequence

# Reproducible RNG: seed 0, then split off a subkey for parameter initialization
key, subkey = random.split(random.PRNGKey(0))

class MLP(nn.Module):
    """Fully connected network on the concatenated input ``(t, x)``.

    One ``nn.Dense`` layer per entry of ``features``; tanh follows every layer
    except the last, which is linear.
    """
    features: Sequence[int]

    def setup(self):
        self.layers = [nn.Dense(feat) for feat in self.features]

    def __call__(self, t, x):
        # Concatenate along the last axis: identical to the original behavior for
        # single points (t: (1,), x: (d,)), and also accepts batches
        # (t: (B, 1), x: (B, d)).
        h = np.concatenate((t, x), axis=-1)
        last = len(self.layers) - 1
        for i, layer in enumerate(self.layers):
            h = layer(h)
            if i != last:
                h = nn.tanh(h)
        return h

# Dummy single point (x in 2D, t in 1D) used to initialize the network
x = np.zeros((2,))
t = np.zeros((1,))

model = MLP(features=[20,20,1])
params = model.init(subkey, t, x)
y = model.apply(params, t, x)

# `jax.tree_map` is deprecated; `jax.tree_util.tree_map` is the supported equivalent.
print('initialized parameter shapes:\n', jax.tree_util.tree_map(np.shape, unfreeze(params)))
print('output:\n', y)

def u(t, x, params):
    """Network prediction u(t, x): first entry of the MLP output (a scalar for single points)."""
    return model.apply(params, t, x)[0]

In [None]:
# KPP coefficients: D multiplies the Laplacian, r the logistic reaction term
D, r = 1.0, 1.1

def hessian(f, index_derivation=0):
    """Hessian of ``f`` w.r.t. argument number ``index_derivation`` (jacfwd over jacrev)."""
    gradient_fn = jacrev(f, index_derivation)
    return jacfwd(gradient_fn, index_derivation)

@jit
def f(t_, x_, params_):
    """KPP residual u_t - D * Laplacian(u) - r * u * (1 - u) at a single point (t_, x_)."""
    # The Laplacian is the trace of the (squeezed) Hessian of u w.r.t. x.
    laplacian = np.trace(np.squeeze(hessian(u, 1)(t_, x_, params_)))
    du_dt = grad(u, 0)(t_, x_, params_)
    u_val = u(t_, x_, params_)
    reaction = r * u_val * (1 - u_val)
    return du_dt - D * laplacian - reaction

In [None]:
# A test point: t = 0.25 and x = (0.25, 0.25)
t_test, x_test = np.ones(1) * 0.25, np.ones(2) * 0.25

(u(t_test, x_test, params), f(t_test, x_test, params))

In [None]:
def mse_f(t_, x_, params_):
    """Mean of the squared KPP residual at a single point (t_, x_)."""
    residual = f(t_, x_, params_)
    return np.mean(residual ** 2)

mse_f(t_test, x_test, params)

In [None]:
# Vectorize mse_f over the leading axis of t and x; params are shared (not mapped)
batched_mse_f = vmap(mse_f, in_axes=(0, 0, None), out_axes=0)

def loss_f(t_, x_, params_):
    """Average squared residual over a batch of collocation points."""
    per_point = batched_mse_f(t_, x_, params_)
    return np.mean(per_point)

loss_f(np.zeros((10, 1)), np.zeros((10, 2)), params)

In [None]:
def mse_u(t_, x_, u_, params_):
    """Mean squared error between the target ``u_`` and the prediction at (t_, x_)."""
    prediction = u(t_, x_, params_)
    return np.mean((u_ - prediction) ** 2)

mse_u(np.zeros(1), np.zeros(2), np.zeros(1), params)

In [None]:
# Vectorize mse_u over the leading axis of t, x and u; params are shared (not mapped)
batched_mse_u = vmap(mse_u, in_axes=(0, 0, 0, None), out_axes=0)

def loss_u(t_, x_, u_, params_):
    """Average squared error against the targets ``u_`` over a batch of supervised points."""
    per_point = batched_mse_u(t_, x_, u_, params_)
    return np.mean(per_point)

loss_u(
    np.zeros((10, 1)),
    np.zeros((10, 2)),
    np.ones((10, 1)) * 0.4,
    params,
)

In [None]:
def total_loss(t_, x_, u_, t2_, x2_, params_):
    """PINN objective: residual loss on (t2_, x2_) plus data loss on (t_, x_, u_)."""
    residual_term = loss_f(t2_, x2_, params_)
    data_term = loss_u(t_, x_, u_, params_)
    return residual_term + data_term

total_loss(
    np.zeros((10, 1)),
    np.zeros((10, 2)),
    np.ones((10, 1)) * 0.4,
    np.ones((10, 1)) * 0.25,
    np.ones((10, 2)) * 0.25,
    params,
)

In [None]:
# Currently unused helper: the training loop uses loss_grad_total_fn instead.
def add_gradients(grad1, grad2, alpha=1.0, beta=1.0):
    """Return the weighted sum ``alpha * grad1 + beta * grad2`` of two gradient pytrees.

    Both pytrees must share the same structure (e.g. two gradients w.r.t. the
    same parameters). The previous body ignored its arguments and read the
    undefined globals ``grads`` / ``grads_f`` via the removed ``jax.tree_multimap``.
    """
    return jax.tree_util.tree_map(lambda g1, g2: g1 * alpha + g2 * beta, grad1, grad2)

add_grads = jit(add_gradients)

# Jitted (value, gradient) pairs; each gradient is taken w.r.t. the params argument
# (index 3 of loss_u, 2 of loss_f, 5 of total_loss).
# A broken, non-executed draft of a combined `losses_and_grad` was removed here.
loss_grad_fn = jit(jax.value_and_grad(loss_u, 3))
loss_grad_f_fn = jit(jax.value_and_grad(loss_f, 2))
loss_grad_total_fn = jit(jax.value_and_grad(total_loss, 5))

In [None]:
def u0(x, center=0.5, sharpness=30.0):
    """KPP initial condition: a Gaussian bump exp(-sharpness * |x - center|^2).

    The squared distance is taken over the last axis, so a single point of shape
    (d,) gives a scalar and a batch of shape (N, d) gives shape (N,). With the
    default arguments this matches the original batched-only version.
    """
    diff = x - center
    return np.exp(-np.sum(diff * diff, axis=-1) * sharpness)

In [None]:
class DatasetKPP():
    """Training data for the KPP problem on the unit square.

    Holds ``N_u`` supervised points (initial-condition points at t = 0 or
    homogeneous Dirichlet boundary points), served by ``next_batch``, and
    ``N_f`` uniformly sampled collocation points (t in [0, 1), x in [0, 1)^2),
    served by ``uniform_batch``. Both samplers walk through their data
    sequentially and wrap around; the batch just before a wrap may be shorter
    than requested.
    """
    def __init__(self, key, batch_size=10, N_u=500, N_f=5000):
        self.batch_size = batch_size
        self.N_u = N_u
        self.N_f = N_f
        self.curr_idx = 0
        self.curr_f_idx = 0
        # Independent subkeys: reusing `key` for several draws correlates them.
        key_data, key_t, key_x = random.split(key, 3)
        x_, t_, u_ = self.generate_data(key_data, N_u)
        self.x_data = x_
        self.t_data = t_
        self.u_data = u_
        self.t_f_data = random.uniform(key_t, (self.N_f, 1))
        self.x_f_data = random.uniform(key_x, (self.N_f, 2))

    def generate_data(self, key, N_u):
        """Sample ``N_u`` supervised points.

        About half are initial-condition points (t = 0, x uniform in the square,
        target ``u0(x)``). The rest are boundary points: t uniform in [0, 1),
        one coordinate of x fixed to 0 or 1, the other uniform, and target 0.

        Returns ``(x_data (N_u, 2), t_data (N_u, 1), u_data (N_u, 1))``.
        """
        key_type, key_x0, key_edge, key_axis, key_side, key_t = random.split(key, 6)
        is_initial = random.uniform(key_type, (N_u, 1)) > 0.5

        x_initial = random.uniform(key_x0, (N_u, 2))

        # Boundary points: choose which coordinate lies on the boundary and whether
        # it sits at 0 or 1; the other coordinate stays uniform. This covers the
        # whole edge of the square, whereas rounding both coordinates only ever
        # produced the four corners.
        x_edge = random.uniform(key_edge, (N_u, 2))
        axis = random.randint(key_axis, (N_u,), 0, 2)
        side = (random.uniform(key_side, (N_u,)) > 0.5).astype(x_edge.dtype)
        x_boundary = x_edge.at[np.arange(N_u), axis].set(side)

        x_data = np.where(is_initial, x_initial, x_boundary)
        t_data = np.where(is_initial, 0.0, random.uniform(key_t, (N_u, 1)))
        u_data = np.where(is_initial, np.expand_dims(u0(x_data), -1), 0.0)
        return x_data, t_data, u_data

    def next_batch(self):
        """Return the next ``(t, x, u)`` slice of the supervised data.

        The last slice before wrapping back to the start may be shorter than
        ``batch_size``.
        """
        bstart = self.curr_idx * self.batch_size
        bend = (self.curr_idx + 1) * self.batch_size
        if bend >= self.N_u:
            # Slice ends are exclusive: clip to N_u (not N_u - 1) so the last
            # sample is included, then restart from the beginning next call.
            bend = self.N_u
            self.curr_idx = 0
        else:
            self.curr_idx = self.curr_idx + 1
        x_ = self.x_data[bstart:bend]
        t_ = self.t_data[bstart:bend]
        u_ = self.u_data[bstart:bend]
        return t_, x_, u_

    def uniform_batch(self, size):
        """Return the next ``(t, x)`` slice of ``size`` collocation points.

        The last slice before wrapping back to the start may be shorter than ``size``.
        """
        bstart = self.curr_f_idx
        bend = bstart + size
        if bend >= self.N_f:
            # Exclusive slice end: clip to N_f so the last point is included.
            bend = self.N_f
            self.curr_f_idx = 0
        else:
            self.curr_f_idx = self.curr_f_idx + size
        t_b = self.t_f_data[bstart:bend]
        x_b = self.x_f_data[bstart:bend]
        return t_b, x_b

In [None]:
# Give the dataset its own subkey instead of the carry `key`, which is split
# again afterwards (reusing a key for two purposes correlates the draws).
key, subkey = random.split(key)
ds = DatasetKPP(subkey, batch_size=32)

In [None]:
# Optimizer
import optax
# Re-initialize the network with a fresh subkey and attach an Adam optimizer
key, subkey = random.split(key)
params = model.init(subkey, t, x)
tx = optax.adam(learning_rate=1e-3)
opt_state = tx.init(params)

In [None]:
steps = 1000
batch_nf = 512
# Main train loop: one Adam step per (supervised batch, collocation batch) pair.
# NOTE: the batch just before each wrap-around is shorter, so the jitted
# loss_grad_total_fn is retraced for that shape.
for i in range(steps):
    tb, xb, ub = ds.next_batch()
    tb_uni, xb_uni = ds.uniform_batch(batch_nf)

    # Bind to `total_loss_val`, not `total_loss`, so the `total_loss` function
    # defined earlier is not shadowed (re-running earlier cells would break).
    total_loss_val, total_grads = loss_grad_total_fn(tb, xb, ub, tb_uni, xb_uni, params)
    updates, opt_state = tx.update(total_grads, opt_state)
    params = optax.apply_updates(params, updates)

    if i % 100 == 0:
        print(f'Loss step {i}: loss: {total_loss_val:.4f}')

In [None]:
import matplotlib.pyplot as plt
def display_sliceKPP(resolution=30, slices=(0.0,), vmin=-0.05, vmax=0.2):
    """Plot the learned solution u(t, ., .) on a ``resolution`` x ``resolution`` grid of [0, 1]^2.

    One subplot per time in ``slices``; uses the global ``u`` and ``params``.
    Image rows follow the second coordinate and columns the first; ``vmin`` /
    ``vmax`` fix a shared color scale.
    """
    num = len(slices)
    plt.figure(figsize=(10, 6))
    batched_u = vmap(u, (0, 0, None), 0)

    coords = np.linspace(0, 1, resolution)
    # grid_x[j, i] = coords[i] and grid_y[j, i] = coords[j]
    grid_x, grid_y = np.meshgrid(coords, coords, indexing='xy')
    points = np.stack((grid_x.ravel(), grid_y.ravel()), axis=-1)

    # Ticks at 0, 0.2, ..., 1.0 placed on the pixels holding those coordinates
    tick_values = [k / 5.0 for k in range(6)]
    tick_positions = [v * (resolution - 1) for v in tick_values]
    tick_labels = [round(v, 2) for v in tick_values]

    for i, t_slice in enumerate(slices):
        plt.subplot(1, num, i + 1)
        tt = np.ones((points.shape[0], 1)) * t_slice
        # One batched call for the whole grid instead of one call per row
        map_out = batched_u(tt, points, params).reshape(resolution, resolution)
        plt.imshow(map_out, vmin=vmin, vmax=vmax)
        plt.title(f't = {t_slice}')
        plt.xticks(tick_positions, tick_labels)
        plt.yticks(tick_positions, tick_labels)

    plt.colorbar()

In [None]:
# Learned solution at t = 0 on a 20 x 20 grid
display_sliceKPP(resolution=20)

In [None]:
# Evolution of the learned solution over time
display_sliceKPP(resolution=20, slices=[0, 0.25, 0.5, 0.75])

In [None]:
# Scratch check: u0 on a batch of 10 points at (1, 1)
u0(np.ones(shape=(10, 2)))

In [None]:
# Scratch check: shape produced by repeating a length-1 array 10 times
np.repeat(np.ones(1), repeats=10).shape