In [1]:
import jax, jax.numpy as jnp
import optax
from jax import random, grad, jacfwd, jit, vmap
import equinox as eqx
import matplotlib.pyplot as plt
import optax as op
import os, sys, time
from functools import partial

In [2]:
def laplace_operator(model):
    """Build a batched second-derivative operator for ``model``.

    Nesting two forward-mode Jacobians gives, for a single point of shape
    ``(in,)`` mapped to ``(out,)``, the Hessian of shape ``(out, in, in)``.
    The returned function is vmapped over the leading axis of its input, so
    calling it on points of shape ``(batch, in)`` yields ``(batch, out, in, in)``.

    Note: despite the name this returns the full Hessian, not its trace; for a
    1-D input the two coincide up to the trailing singleton axes.
    """
    second_derivative = jacfwd(jacfwd(model))
    return vmap(second_derivative)

def loss_residual_poisson(model, t, y):
    """Sum of squared Poisson residuals ``(Δu - y) ** 2`` over collocation points.

    Args:
        model: callable mapping a single point of shape ``(in,)`` to ``(out,)``.
        t: collocation points, shape ``(batch, in)``.
        y: source term at each point; must broadcast against ``(batch, out)``
            (e.g. shape ``(batch, out)``).

    Returns:
        Scalar sum of the squared residuals.
    """
    hessian = laplace_operator(model)(t)  # (batch, out, in, in)
    # The Laplacian is the trace of the Hessian over its two input axes.
    # Reducing them also fixes broadcasting: subtracting a (batch, out) target
    # from the raw (batch, out, in, in) Hessian aligns y's batch axis with the
    # Hessian's trailing axes and sums over every (prediction, target) pair.
    laplacian = jnp.trace(hessian, axis1=-2, axis2=-1)  # (batch, out)
    return jnp.sum((laplacian - y) ** 2)

def loss_boundary_poisson(model, t, u):
    """Mean squared mismatch between the model and prescribed boundary values.

    Args:
        model: callable mapping a single point of shape ``(in,)`` to ``(out,)``
            (equinox layers act on unbatched inputs).
        t: boundary points, shape ``(batch, in)``.
        u: target values at those points; must broadcast against ``(batch, out)``.

    Returns:
        Scalar mean of the squared differences over all entries.
    """
    # Map the model over the batch axis, as ``laplace_operator`` and ``loss`` do;
    # calling an unbatched model on the whole (batch, in) array mixes samples.
    predictions = vmap(model)(t)
    return jnp.mean((predictions - u) ** 2)

In [3]:
# One random collocation point drawn from [-1, 1), with source term sin(x).
# y_key is not used in this cell.
x_key, y_key = random.split(random.PRNGKey(0), num=2)
x = random.uniform(x_key, shape=(1, 1), minval=-1.0, maxval=1.0)
y = jnp.sin(x)

In [4]:
# Tiny 1 -> 10 -> 1 tanh network used for the hand-derived gradient checks.
# keys[1] is not used here; only the first and last keys seed layers.
keys = random.split(random.PRNGKey(0), num=3)
pinn = eqx.nn.Sequential(
    [
        eqx.nn.Linear(in_features=1, out_features=10, key=keys[0]),
        eqx.nn.Lambda(jax.nn.tanh),
        eqx.nn.Linear(in_features=10, out_features=1, key=keys[2]),
    ]
)

In [5]:
# Autodiff gradient of the residual loss w.r.t. the output layer's weights.
residual_grads = eqx.filter_grad(loss_residual_poisson)(pinn, x, y)
residual_grads.layers[2].weight

Array([[ 1.7511047e-03,  3.7859907e-05, -9.9056400e-03, -3.6682673e-02,
        -1.1085793e-03,  3.1637501e-02, -9.4369409e-04,  1.3873098e-02,
        -1.9107487e-02, -4.4103645e-02]], dtype=float32)

In [6]:
# Hand-derived gradient of the residual loss w.r.t. the output weights W2.
# With u(x) = W2 tanh(W1 x + b1) + b2, u''(x) = sum_k W2_k tanh''(z_k) W1_k^2,
# so d/dW2_k of (u'' - y)^2 is 2 (u'' - y) tanh''(z_k) W1_k^2.
d2_tanh = vmap(vmap(grad(grad(jnp.tanh))))
hidden_w = pinn.layers[0].weight.T
pre_act = x @ hidden_w + pinn.layers[0].bias
residual = laplace_operator(pinn)(x).ravel() - y
2 * residual * d2_tanh(pre_act) * hidden_w ** 2

Array([[ 1.7511046e-03,  3.7859911e-05, -9.9056391e-03, -3.6682677e-02,
        -1.1085737e-03,  3.1637501e-02, -9.4369409e-04,  1.3873094e-02,
        -1.9107487e-02, -4.4103649e-02]], dtype=float32)

In [86]:
# Autodiff gradient of the residual loss w.r.t. the hidden layer's weights.
hidden_weight_grad = eqx.filter_grad(loss_residual_poisson)(pinn, x, y).layers[0].weight
hidden_weight_grad

Array([[ 0.00534572],
       [-0.00030895],
       [ 0.02145175],
       [-0.00532079],
       [ 0.00334739],
       [ 0.02240073],
       [-0.00480167],
       [-0.01863186],
       [-0.00197127],
       [-0.03655439]], dtype=float32)

In [83]:
W_1 = pinn.layers[0].weight.T  # (1, 10)
W_2 = pinn.layers[2].weight.T  # (10, 1)
b_1 = pinn.layers[0].bias
b_2 = pinn.layers[2].bias  # not needed below: b2 drops out of u''

# Hand-derived gradient of the residual loss w.r.t. the hidden weights W1.
# d u''/dW1_k = W2_k (2 W1_k tanh''(z_k) + W1_k^2 x tanh'''(z_k)), with z = x W1 + b1.
z_1 = x @ W_1 + b_1
tanh_d2 = vmap(vmap(grad(grad(jnp.tanh))))(z_1)
tanh_d3 = vmap(vmap(grad(grad(grad(jnp.tanh)))))(z_1)
residual = laplace_operator(pinn)(x).ravel() - y
2 * residual * (W_2.T * tanh_d2 * (2 * W_1) + W_2.T * W_1 ** 2 * x * tanh_d3)

Array([[ 0.00534572, -0.00030895,  0.02145175, -0.00532079,  0.00334739,
         0.02240073, -0.00480167, -0.01863185, -0.00197127, -0.03655439]],      dtype=float32)

Array([[-0.19146267]], dtype=float32)

In [61]:
# Analytic second derivative u''(x) = sum_k W2_k tanh''(z_k) W1_k^2 of the network,
# written as a (1, 10) @ (10, 1) contraction; it should match laplace_operator(pinn)(x).
out_weight = pinn.layers[2].weight  # (1, 10)
in_weight = pinn.layers[0].weight  # (10, 1)
pre_activation = x @ in_weight.T + pinn.layers[0].bias
curvature = vmap(vmap(grad(grad(jnp.tanh))))(pre_activation)
(out_weight * curvature) @ (in_weight ** 2)

Array([[0.01843224]], dtype=float32)

In [5]:
# BATCH_SIZE = 100
# x = x.reshape((-1, BATCH_SIZE, 1))
# y = y.reshape((-1, BATCH_SIZE, 1))

In [10]:
def loss(model, x, y):
    """Mean squared error of ``model`` mapped over the leading (batch) axis of ``x``."""
    predictions = vmap(model)(x)
    squared_error = (predictions - y) ** 2
    return squared_error.mean()

def fit(model, optimizer: optax.GradientTransformation, batches=None, targets=None, log_every=100):
    """Train ``model`` with ``optimizer`` for one pass over paired batches.

    Args:
        model: equinox model to train.
        optimizer: optax gradient transformation.
        batches: iterable of input batches. Defaults to the notebook-global ``x``
            (the original behaviour); pass data explicitly to avoid hidden state.
        targets: iterable of label batches aligned with ``batches``. Defaults to
            the notebook-global ``y``.
        log_every: print the loss every ``log_every`` steps; falsy disables logging.

    Returns:
        The trained model.
    """
    if batches is None:
        batches = x
    if targets is None:
        targets = y

    opt_state = optimizer.init(eqx.filter(model, eqx.is_array))

    @eqx.filter_jit
    def step(model, opt_state, batch, labels):
        loss_value, grads = eqx.filter_value_and_grad(loss)(model, batch, labels)
        # Pass only the array leaves as ``params``, matching the tree used in
        # ``optimizer.init``; transforms that read params (e.g. weight decay)
        # may trip over non-array leaves such as the function in ``eqx.nn.Lambda``.
        updates, opt_state = optimizer.update(grads, opt_state, eqx.filter(model, eqx.is_array))
        model = eqx.apply_updates(model, updates)
        return model, opt_state, loss_value

    for i, (batch, labels) in enumerate(zip(batches, targets)):
        model, opt_state, loss_value = step(model, opt_state, batch, labels)
        if log_every and i % log_every == 0:
            print(f"Step {i}, loss = {loss_value}")

    return model



In [11]:
# Wider 1 -> 100 -> 1 tanh network; keys[1] would seed the optional middle
# layer that is commented out in the list below.
keys = random.split(random.PRNGKey(0), num=3)
wide_layers = [
    eqx.nn.Linear(1, 100, key=keys[0]),
    eqx.nn.Lambda(jax.nn.tanh),
    # eqx.nn.Linear(100, 100, key=keys[1]),
    # eqx.nn.Lambda(jax.nn.tanh),
    eqx.nn.Linear(100, 1, key=keys[2]),
]
pinn = eqx.nn.Sequential(wide_layers)
# params, static = eqx.partition(pinn, eqx.is_array)
# optimizer = optax.adam(1e-2)
# pinn = fit(pinn, optimizer)