In [None]:
%load_ext autoreload
%autoreload 2

## Burgers' equation

Tools: JAX, Flax and Optax (install with `pip install jax jaxlib flax optax`).
If you are unfamiliar with JAX random number generation, see [the `jax.random` documentation](https://jax.readthedocs.io/en/latest/jax.random.html).

Goal: build a first simple 1D model to work with, similar to the one in [this paper](https://arxiv.org/pdf/1711.10561.pdf).



Burgers' equation becomes:
$$
\begin{aligned}
& u_t + u \, u_x - (0.01/\pi) \, u_{xx} = 0, \quad x \in [-1, 1], \quad t \in [0, 1], \\
& u(0, x) = -\sin(\pi x), \\
& u(t, -1) = u(t, 1) = 0
\end{aligned}
$$

In [None]:
import jax.numpy as np
import jax
from jax import grad, jit, vmap, jacfwd, jacrev
from jax import random
from models.nets import MLP
from functools import partial

# Root PRNG key of the notebook; a sub-key is split off for the initialisation below.
key = random.PRNGKey(0)
key, subkey = random.split(key)

# A single (t, x) test point, each coordinate stored as a length-1 array.
x_test = 0.25 * np.ones(1)
t_test = 0.25 * np.ones(1)

# Feature list: seven entries of 20 followed by a final 1
# (presumably hidden widths and output size -- see models.nets.MLP).
model = MLP(features=[20] * 7 + [1])
init_params = model.init(subkey, t_test, x_test)

@jit
def u(t, x, params_):
    """Apply the network with parameters `params_` to (t, x) and return element 0 of its output."""
    prediction = model.apply(params_, t, x)
    return prediction[0]

# Show the shape of every leaf of the parameter pytree.
# jax.tree_util.tree_map is the stable spelling; the top-level jax.tree_map
# alias is deprecated and absent from recent JAX releases.
print('initialized parameter shapes:\n', jax.tree_util.tree_map(np.shape, init_params))
# Network output at the test point; the label follows the call signature u(t, x).
print(f'\nu(t, x): {u(t_test, x_test, init_params):.3f}')

In [None]:
# Condition imposed at t = 0
def u0(x):
    """Return -sin(pi * x), the profile prescribed at t = 0."""
    phase = np.pi * x
    return np.negative(np.sin(phase))

# Second derivative helper (used for u_xx)
def hessian(f, index_derivation=0):
    """Return a function computing the second derivative of `f` with respect to
    its positional argument number `index_derivation`.

    Built as forward-mode differentiation applied on top of reverse-mode
    differentiation, both taken with respect to the same argument.
    """
    first_derivative = jacrev(f, argnums=index_derivation)
    return jacfwd(first_derivative, argnums=index_derivation)

@jit
def f(t, x, params_):
    """Evaluate the PDE residual u_t + u * u_x - (0.01 / pi) * u_xx at (t, x).

    All derivatives are taken of `u` by automatic differentiation, with the
    network parameters `params_` held fixed. The result is squeezed before
    being returned.
    """
    value = u(t, x, params_)
    d_dt = grad(u, argnums=0)(t, x, params_)
    d_dx = grad(u, argnums=1)(t, x, params_)
    # Second derivative with respect to x; [0] drops the leading axis of the result.
    d2_dx2 = hessian(u, 1)(t, x, params_)[0]
    viscosity = 0.01 / np.pi
    residual = d_dt + value * d_dx - viscosity * d2_dx2
    return np.squeeze(residual)

In [None]:
# Sanity check: evaluate the network and the PDE residual at the test point
(
    u(t_test, x_test, init_params),
    f(t_test, x_test, init_params),
)

In [None]:
def loss(batches, params_):
    """Compute the training loss for the parameters `params_`.

    `batches` is a 5-tuple (t_, x_, u_, tf_, xf_): points (t_, x_) with target
    values u_ for the data term, and points (tf_, xf_) at which the PDE
    residual `f` is penalised. Every array is mapped over its axis 0.

    Returns (loss_f + loss_u, (loss_u, loss_f)); the second element is
    auxiliary output, so only the sum is differentiated when this function is
    wrapped with has_aux=True.
    """
    t_, x_, u_, tf_, xf_ = batches

    # Physics term: mean squared PDE residual.
    def squared_residual(t, x):
        return f(t, x, params_=params_) ** 2

    residuals = vmap(squared_residual, in_axes=(0, 0), out_axes=0)(tf_, xf_)
    loss_f = np.mean(residuals)

    # Data term: mean squared error between targets and network output.
    def squared_error(t, x, target):
        return np.mean((target - u(t, x, params_)) ** 2)

    errors = vmap(squared_error, in_axes=(0, 0, 0), out_axes=0)(t_, x_, u_)
    loss_u = np.mean(errors)

    total = loss_f + loss_u
    return (total, (loss_u, loss_f))

# Jitted callable returning ((total, (loss_u, loss_f)), gradient of total w.r.t. argument 1, i.e. params_).
_loss_value_and_grad = jax.value_and_grad(loss, argnums=1, has_aux=True)
losses_and_grad = jit(_loss_value_and_grad)

In [None]:
# Smoke test of the loss and its gradient on a constant dummy batch
dummy_shape = (10, 1)
dummy_batches = (
    np.zeros(dummy_shape),         # t_
    np.zeros(dummy_shape),         # x_
    np.ones(dummy_shape) * 0.4,    # u_
    np.ones(dummy_shape) * 0.25,   # tf_
    np.ones(dummy_shape) * 0.25,   # xf_
)
losses, grads = losses_and_grad(dummy_batches, init_params)

# Unpack (total, (loss_u, loss_f)) in the order returned by `loss`.
a, (b, c) = losses
print(f"total loss: {a:.3f}, mse_u: {b:.3f}, mse_f: {c:.3f}")

#### Data and learning

We build $N_u$ boundary data points as mentioned in the paper (the paper uses $N_u = 100$; the cell below passes `N_u=200`). Half of them are at $t=0$, the other half at $x = \pm 1$. They are wrapped in a dataset class.

In [None]:
from data import datasets

# Split off a dedicated sub-key for the dataset (two keys is the default split).
key, subkey = random.split(key)
ds = datasets.BurgersDataset(
    subkey,
    u0,
    batch_size=32,
    N_u=200,
)

In [None]:
# Optimizer
import optax
# Re-initialise the network from a fresh sub-key so training starts from its own parameters.
key, subkey = random.split(key)
params = model.init(subkey, t_test, x_test)

# Adam with a fixed learning rate of 1e-3; its state is created for `params`.
tx = optax.adam(1e-3)
opt_state = tx.init(params)

In [None]:
# Main train loop: each step draws a border batch and an interior batch,
# evaluates the loss and its gradient, and applies one optimizer update.
steps = 5000
for i in range(steps):
    # JAX random functions are deterministic in their key, so the key must be
    # split on every step. Reusing the same key would make any key-driven
    # sampling return the same points at every iteration, and would tie the
    # border and interior draws to one random stream.
    key, border_key, inside_key = random.split(key, 3)
    tb, xb, ub = ds.border_batch(border_key)
    tb_uni, xb_uni = ds.inside_batch(inside_key)

    losses, grads = losses_and_grad((tb, xb, ub, tb_uni, xb_uni),
                                    params)
    updates, opt_state = tx.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    # `losses` is (total, (loss_u, loss_f)), matching the return of `loss`.
    total_loss_val, (mse_u_val, mse_f_val) = losses

    # Report every 100th step (steps are displayed 1-based).
    if i % 100 == 99:
        print(f'Loss at step {i+1}: {total_loss_val:.4f} / mse_u: {mse_u_val:.4f} / mse_f: {mse_f_val:.4f}')

#### Display


In [None]:
# Freeze the trained parameters into u, then map it over axis 0 of both t and x.
_u_trained = partial(u, params_=params)
batched_u = vmap(_u_trained, in_axes=(0, 0), out_axes=0)

In [None]:
from data.display import display_burgers_grid, display_burgers_slice

# Show the trained, batched network with the project's grid display helper.
# NOTE(review): the meaning of 100 is defined in data.display (presumably a grid resolution) -- TODO confirm.
display_burgers_grid(batched_u, 100)

In [None]:
# Show the trained, batched network with the project's slice display helper for four slice values.
# NOTE(review): the meaning of 30 and the axis the slices refer to are defined in data.display
# (presumably points per slice and values of t) -- TODO confirm.
display_burgers_slice(batched_u, 30, slices=[0.0, 0.25, 0.5, 0.75])