In [1]:
import jax, jax.numpy as jnp
import optax
from jax import random, grad, jacfwd, jit, vmap
import equinox as eqx
import matplotlib.pyplot as plt
import optax as op
import os, sys, time
from functools import partial

In [6]:
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '0'

In [30]:
X, Y = jnp.meshgrid(jnp.linspace(-1, 1, 100), jnp.linspace(-1, 1, 100))
X, Y = X.ravel(), Y.ravel()

key, subkey = random.split(random.PRNGKey(0))
X = random.permutation(subkey, X)

key, subkey = random.split(key)
Y = random.permutation(subkey, Y)

X = X[:256]
Y = Y[:256]

mean = jnp.array([jnp.mean(X), jnp.mean(Y)])
std = jnp.array([jnp.std(X), jnp.std(Y)])

X_bnd_left = jnp.array([
    -jnp.ones(256),
    jnp.linspace(-1, 1, 256)
]).T


In [32]:
X_bnd_left

Array([[-1.        , -1.        ],
       [-1.        , -0.99215686],
       [-1.        , -0.9843137 ],
       [-1.        , -0.9764706 ],
       [-1.        , -0.96862745],
       [-1.        , -0.9607843 ],
       [-1.        , -0.9529412 ],
       [-1.        , -0.94509804],
       [-1.        , -0.9372549 ],
       [-1.        , -0.92941177],
       [-1.        , -0.92156863],
       [-1.        , -0.9137255 ],
       [-1.        , -0.90588236],
       [-1.        , -0.8980392 ],
       [-1.        , -0.8901961 ],
       [-1.        , -0.88235295],
       [-1.        , -0.8745098 ],
       [-1.        , -0.8666667 ],
       [-1.        , -0.85882354],
       [-1.        , -0.8509804 ],
       [-1.        , -0.84313726],
       [-1.        , -0.8352941 ],
       [-1.        , -0.827451  ],
       [-1.        , -0.81960785],
       [-1.        , -0.8117647 ],
       [-1.        , -0.8039216 ],
       [-1.        , -0.79607844],
       [-1.        , -0.7882353 ],
       [-1.        ,

In [3]:
class InputNormalizer(eqx.Module):
    mean: jax.Array = eqx.field(static=True)
    std: jax.Array = eqx.field(static=True)

    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, x):
        return (x - self.mean) / self.std

In [None]:
keys = random.split(random.PRNGKey(0), 4)
pinn = eqx.Sequential([
        eqx.nn.BatchNorm(inputs_size=2, axis_name='batch')
        eqx.nn.Dense(2, 512, keys[0]),
        eqx.nn.Lambda(jnp.sin),
        eqx.nn.Dense(512, 512, keys[1]),
        eqx.nn.Lambda(jnp.sin),
        eqx.nn.Dense(512, 512, keys[2]),
        eqx.nn.Lambda(jnp.sin),
        eqx.nn.Dense(512, 1, keys[3]),
    ])


In [None]:
def laplace_operator(model):
    return vmap(jacfwd(jacfwd(model, argnums=0), argnums=0), in_axes=0)

def loss_residual_poisson(model, t, y):
    return jnp.sum((laplace_operator(model)(t) - y)**2)

def loss_boundary_poisson(model, t, u):
    return jnp.mean((model(t) - u)**2)    