In [135]:
## https://github.com/ASEM000/Physics-informed-neural-network-in-JAX/blob/main/%5B5%5D_System_of_ODEs_PINN.ipynb
import jax
import jax.numpy as jnp
import flax.linen as nn
import optax

In [136]:
# Collocation points
key = jax.random.PRNGKey(0)
t = jax.random.uniform(key, minval=0, maxval=jnp.pi, shape=(10000, 1))

In [137]:
def init_params(layers):
    keys = jax.random.split(jax.random.PRNGKey(0), len(layers) - 1)
    params = list()
    for key, n_in, n_out in zip(keys, layers[:-1], layers[1:]):
        lb, ub = -(1 / jnp.sqrt(n_in)), (1 / jnp.sqrt(n_in)) # xavier initialization lower and upper bound
        W = lb + (ub-lb) * jax.random.uniform(key,shape=(n_in,n_out))
        B = jax.random.uniform(key,shape=(n_out,))
        params.append({'W':W,'B':B})
    return params

def model(params, t):
    *hidden, last = params
    for layer in hidden :
        t = jnp.matmul(t, layer['W']) + layer['B']
        t = jax.nn.tanh(t)
    return jnp.matmul(t, last['W']) + last['B']

In [138]:
params = init_params([1] + [20]*2 + [2])
optimizer = optax.adam(learning_rate=0.01)
opt_state = optimizer.init(params)

@jax.jit
def loss_fn(params, t):
    x = lambda t: model(params, t)[:, [0]]
    y = lambda t: model(params, t)[:, [1]]
    x_t = lambda t: jax.grad(lambda t: jnp.sum(x(t)))(t)
    y_t = lambda t: jax.grad(lambda t: jnp.sum(y(t)))(t)

    residual_x = x_t(t) - x(t)
    residual_y = y_t(t) - x(t) + y(t)
    ode_loss = jnp.mean(residual_x**2) + jnp.mean(residual_y**2)

    # x[0] = 1, y[0] = 2
    t0, x0, y0 = jnp.array([[0.]]), jnp.array([[1.]]), jnp.array([[2.]])
    ic1_loss = jnp.mean((x(t0) - x0)**2)
    ic2_loss = jnp.mean((y(t0) - y0)**2)

    loss = ode_loss + ic1_loss + ic2_loss
    return  loss, (ic1_loss, ic2_loss)

@jax.jit
def train_step(params, opt_state, t):
    # grads = jax.grad(loss_fn)(params, t)
    # loss, grads = jax.value_and_grad(loss_fn)(params, t)
    (loss, aux), grads = jax.value_and_grad(loss_fn, has_aux=True)(params, t)
    updates, opt_state = optimizer.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    return params, opt_state, (loss, *aux)

n_epochs = 2000
for epoch in range(1, n_epochs + 1):
    params, opt_state, losses = train_step(params, opt_state, t)
    loss, ic1_loss, ic2_loss = losses[0], losses[1], losses[2]

    if epoch % (n_epochs // 10) == 0:
        print(f'[{epoch:5d}/{n_epochs}] loss: {loss:.3e}, ic1: {ic1_loss:.3e}, ic2: {ic2_loss:.3e}')

[  200/2000] loss: 3.839e-01, ic1: 1.484e-01, ic2: 2.349e-07
[  400/2000] loss: 3.525e-01, ic1: 1.366e-01, ic2: 3.548e-07
[  600/2000] loss: 1.815e-01, ic1: 8.646e-02, ic2: 1.200e-04
[  800/2000] loss: 8.636e-02, ic1: 2.598e-02, ic2: 2.084e-05
[ 1000/2000] loss: 4.519e-02, ic1: 1.875e-02, ic2: 2.918e-06
[ 1200/2000] loss: 4.289e-02, ic1: 2.148e-02, ic2: 9.932e-05
[ 1400/2000] loss: 1.471e-02, ic1: 5.264e-03, ic2: 4.512e-07
[ 1600/2000] loss: 1.201e-02, ic1: 6.632e-03, ic2: 3.439e-05
[ 1800/2000] loss: 4.926e-03, ic1: 1.543e-03, ic2: 1.643e-11
[ 2000/2000] loss: 3.149e-03, ic1: 9.238e-04, ic2: 3.327e-10
