In [56]:
## https://github.com/ASEM000/Physics-informed-neural-network-in-JAX/blob/main/%5B1%5D_ODE_PINN.ipynb
import jax
import jax.numpy as jnp
import optax

In [57]:
key = jax.random.PRNGKey(0)
t = jax.random.uniform(key, minval=0, maxval=jnp.pi, shape=(100, 1))
# t = jnp.linspace(0, jnp.pi, 100).reshape(-1, 1)

In [58]:
def init_params(layers):
    keys = jax.random.split(jax.random.PRNGKey(0), len(layers)-1)
    params = list()
    for key, n_in, n_out in zip(keys, layers[:-1], layers[1:]):
        lb, ub = -(1 / jnp.sqrt(n_in)), (1 / jnp.sqrt(n_in)) # xavier initialization lower and upper bound
        W = lb + (ub-lb) * jax.random.uniform(key,shape=(n_in,n_out))
        B = jax.random.uniform(key,shape=(n_out,))
        params.append({'W':W,'B':B})
    return params

def model(params, t):
    *hidden, last = params
    for layer in hidden :
        t = jnp.matmul(t, layer['W']) + layer['B']
        t = jax.nn.tanh(t)
    return jnp.matmul(t, last['W']) + last['B']

In [59]:
params = init_params([1] + [20]*4 + [1])
optimizer = optax.adam(learning_rate=0.001)
opt_state = optimizer.init(params)

@jax.jit
def loss_fn(params, t):
    u = lambda t: model(params, t)
    u_t = lambda t: jax.grad(lambda _t: jnp.sum(u(_t)))(t)
    u_tt = lambda t: jax.grad(lambda _t: jnp.sum(u_t(_t)))(t)

    residual = -t * jnp.cos(2 * jnp.pi * t) + u_t(t) + u_tt(t)
    ode_loss = jnp.mean(residual**2)

    t_ic1, u_ic1 = jnp.array([[0.]]), jnp.array([[1.]])   # u(0) = 1
    ic1_loss = jnp.mean((u(t_ic1) - u_ic1)**2)

    t_ic2, u_ic2 = jnp.array([[0.]]), jnp.array([[10.]])  # u'(0) = 10
    ic2_loss = jnp.mean((u_t(t_ic2) - u_ic2)**2)
    
    loss = ode_loss + ic1_loss + ic2_loss
    return  loss

@jax.jit
def train_step(params, opt_state, t):
    # grads = jax.grad(loss_fn)(params, t)
    loss, grads = jax.value_and_grad(loss_fn)(params, t)
    updates, opt_state = optimizer.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss

n_epochs = 10000
for epoch in range(1, n_epochs + 1):
    params, opt_state, loss = train_step(params, opt_state, t)

    if epoch % (n_epochs // 10) == 0:
        print(f"[{epoch:5d}/{n_epochs}] loss: {loss:.3e}")

[ 1000/10000] loss: 1.859e+00
[ 2000/10000] loss: 7.550e-01
[ 3000/10000] loss: 1.086e-03
[ 4000/10000] loss: 1.788e-04
[ 5000/10000] loss: 1.205e-04
[ 6000/10000] loss: 3.551e-05
[ 7000/10000] loss: 7.183e-05
[ 8000/10000] loss: 6.923e-06
[ 9000/10000] loss: 5.177e-06
[10000/10000] loss: 1.402e-05
