In [65]:
import jax 
import jax.numpy as jnp
import optax
import numpy as np

In [66]:
np.random.seed(42)

N_0 = 100
N_b = 100
N_r = 10_000

tmin, tmax = 0., 1.
xmin, xmax = -1. ,1.

# boundary conditions
# U[0, x] = -sin(pi*x)
t_0 = jnp.ones([N_0, 1], dtype='float32')*0.
x_0 = np.random.uniform(low=xmin, high=xmax, size=(N_0, 1))
ic_0 = -jnp.sin(jnp.pi*x_0) 
IC_0 = jnp.concatenate([t_0, x_0, ic_0], axis=1)

# U[t, -1] = 0
t_b1 = np.random.uniform(low=tmin, high=tmax, size=(N_b, 1))
x_b1 = jnp.ones_like(t_b1) * -1
bc_1 = jnp.zeros_like(t_b1)
BC_1 = jnp.concatenate([t_b1, x_b1, bc_1], axis=1)

# U[t, 1] = 0
t_b2 = np.random.uniform(low=tmin, high=tmax, size=(N_b, 1))
x_b2 = jnp.ones_like(t_b2) 
bc_2 = jnp.zeros_like(t_b2)
BC_2 = jnp.concatenate([t_b2, x_b2, bc_2], axis=1)

conds = [IC_0, BC_1, BC_2]

#collocation points
t = np.random.uniform(low=tmin, high=tmax, size=(N_r, 1))
x = np.random.uniform(low=xmin, high=xmax, size=(N_r, 1))
# colloc = jnp.concatenate([t_c, x_c], axis=1) 

In [67]:
def init_params(layers):
    keys = jax.random.split(jax.random.PRNGKey(0), len(layers) - 1)
    params = []
    for key, n_in, n_out in zip(keys, layers[:-1], layers[1:]):
        lb, ub = -(1 / jnp.sqrt(n_in)), (1 / jnp.sqrt(n_in))    # xavier initialization
        W = lb + (ub - lb) * jax.random.uniform(key, shape=(n_in, n_out))
        B = jax.random.uniform(key,shape=(n_out,))
        params.append({'W':W,'B':B})
    return params

def model(params, t, x):
    X = jnp.concatenate([t, x], axis=1)
    *hidden, last = params
    for layer in hidden :
        X = jnp.matmul(X, layer['W']) + layer['B']
        X = jax.nn.tanh(X)
    return jnp.matmul(X, last['W']) + last['B']

In [68]:
params = init_params([2] + [20]*3 + [1])
optimizer = optax.adam(learning_rate=0.01)
opt_state = optimizer.init(params)

@jax.jit
def loss_fn(params, t, x):
    u = lambda t, x : model(params, t, x)
    u_t = lambda t, x: jax.grad(lambda t, x: jnp.sum(u(t, x)), 0)(t, x)
    u_x = lambda t, x: jax.grad(lambda t, x: jnp.sum(u(t, x)), 1)(t, x)
    u_xx = lambda t, x: jax.grad(lambda t, x: jnp.sum(u_x(t, x)), 1)(t, x)

    residual = u_t(t, x) + u(t, x)*u_x(t, x) - (0.01/jnp.pi)*u_xx(t, x)
    loss = jnp.mean(residual**2)
    
    for cond in conds :
        t_b, x_b, u_b = cond[:, [0]], cond[:, [1]], cond[:, [2]]  
        loss += jnp.mean((u(t_b, x_b) - u_b)**2)
    return loss

@jax.jit
def train_step(params, opt_state, t, x):
    # grads = jax.grad(loss_fn)(params, t, x)
    loss, grads = jax.value_and_grad(loss_fn)(params, t, x)
    updates, opt_state = optimizer.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss

n_epochs = 2000
for epoch in range(1, n_epochs + 1):
    params, opt_state, loss = train_step(params, opt_state, t, x)

    if epoch % (n_epochs // 10) == 0:
        print(f'[{epoch:5d}/{n_epochs}] loss: {loss:.3e}')

[  200/2000] loss: 1.119e-01
[  400/2000] loss: 8.660e-02
[  600/2000] loss: 7.786e-02
[  800/2000] loss: 6.193e-02
[ 1000/2000] loss: 3.762e-02
[ 1200/2000] loss: 3.208e-02
[ 1400/2000] loss: 2.484e-02
[ 1600/2000] loss: 1.225e-01
[ 1800/2000] loss: 1.135e-02
[ 2000/2000] loss: 9.186e-03
