In [71]:
import jax
import jax.numpy as jnp
from flax import nnx
import optax

In [95]:
class MLP(nnx.Module):
    def __init__(self, din: int, dmid: int, dout: int, *, rngs: nnx.Rngs):
        self.linear1 = nnx.Linear(din, dmid, rngs=rngs)
        self.activation = nnx.tanh
        self.linear2 = nnx.Linear(dmid, dout, rngs=rngs)

    def __call__(self, x, t):
        x = jnp.hstack([x, t])
        x = self.linear1(x)
        x = self.activation(x)
        x = self.linear2(x)
        return x

In [79]:
model = MLP(2, 32, 1, rngs=nnx.Rngs(0))

x = jnp.ones((20, 1))*0.01
t = jnp.ones((20, 1))*0.01
u = jnp.ones((20, 1))*0.01

u_pred = model(x, t)

print(x.shape, t.shape, u_pred.shape)

(20, 1) (20, 1) (20, 1)


In [81]:
def residual_loss(model, x, t):
    u = model(x, t)
    u_t = jax.jacrev(model, argnums=1)(x, t)
    # u_x = jax.jacrev(model, argnums=0)(x, t)
    u_xx = jax.jacrev(jax.jacrev(model, argnums=0), argnums=0)(x, t)
    residual = u_t - u_xx
    return jnp.mean(residual**2)

residual_loss(model, x, t)

Array(0.0113041, dtype=float32)

In [86]:
model.train()

@nnx.jit
def train_step(model, x, t, u_true):
    def loss_fn(model):
        u = model(x, t)
        u_t = jax.jacrev(model, argnums=1)(x, t)
        # u_x = jax.jacrev(model, argnums=0)(x, t)
        u_xx = jax.jacrev(jax.jacrev(model, argnums=0), argnums=0)(x, t)
        residual = u_t - u_xx
        loss_pde = jnp.mean(residual**2)
        loss_mse = jnp.mean((u - u_true)**2)
        return loss_pde + loss_mse
    
    grads = nnx.grad(loss_fn)(model)
    _, params, rest = nnx.split(model, nnx.Param, ...)
    params = jax.tree.map(lambda p, g: p - 0.01 * g, params, grads)
    nnx.update(model, nnx.GraphState.merge(params, rest))

In [87]:
train_step(model, x, t, u)

In [103]:
learning_rate = 0.005
momentum = 0.9
optimizer = nnx.Optimizer(model, optax.adamw(learning_rate, momentum))
metrics = nnx.MultiMetric(
  accuracy=nnx.metrics.Accuracy(),
  loss=nnx.metrics.Average('loss'),
)

def loss_fn(model, x, t, u):
    u_pred = model(x, t)
    loss = jnp.mean((u_pred - u)**2)
    return loss, u_pred

@nnx.jit
def train_step(model, optimizer, x, t, u):
    grad_fn = nnx.value_and_grad(loss_fn, has_aux=True)
    (loss, u_pred), grads = grad_fn(model, x, t, u)
    # metrics.update(loss=loss, logits=u_pred, labels=u)  # In-place updates.
    optimizer.update(grads)                               # In-place updates.

@nnx.jit
def eval_step(model, metrics: nnx.MultiMetric, x, t, u):
    loss, u_pred = loss_fn(model, x, t, u)
    # metrics.update(loss=loss, logits=u_pred, labels=u)  # In-place updates.

In [104]:
train_step(model, optimizer, x, t, u)

In [161]:
model = MLP(2, 32, 1, rngs=nnx.Rngs(0))
learning_rate = 0.001
momentum = 0.9
optimizer = nnx.Optimizer(model, optax.adamw(learning_rate, momentum))
metrics = nnx.MultiMetric(
#   accuracy=nnx.metrics.Accuracy(),
  loss=nnx.metrics.Average('loss'),
)

x = jnp.ones((20, 1))*0.01
t = jnp.ones((20, 1))*0.01
u = jnp.ones((20, 1))*0.01

def loss_fn(model, x, t, u):
    pred = model(x, t)
    u_t = jax.jacrev(model, argnums=1)(x, t)
    u_xx = jax.jacrev(jax.jacrev(model, argnums=0), argnums=0)(x, t)

    residual = u_t - u_xx
    loss_pde = jnp.mean(residual**2)
    loss_mse = jnp.mean((pred - u)**2)
    loss = loss_pde + loss_mse
    return loss, pred

@nnx.jit
def train_step(model, optimizer, metrics, x, t, u):
    grad_fn = nnx.value_and_grad(loss_fn, has_aux=True)
    (loss, pred), grads = grad_fn(model, x, t, u)
    # metrics.update(loss=loss, logits=u_pred, labels=u)  # In-place updates.
    metrics.update(loss=loss)  # In-place updates.
    optimizer.update(grads)                             # In-place updates.


metrics_history = {"train_loss": []}
n_epochs = 100
for epoch in range(1, n_epochs + 1):
    train_step(model, optimizer, metrics, x, t, u)

    if epoch % (n_epochs // 10) == 0:
        for metric, value in metrics.compute().items():  # Compute the metrics.
            print(f"[{epoch:4d}/{n_epochs}] {metric}: {value:.2e}")
            metrics_history[f'train_{metric}'].append(value)
        metrics.reset()
        
metrics_history["train_loss"]

[  10/100] loss: 7.23e-03
[  20/100] loss: 1.44e-03
[  30/100] loss: 6.35e-05
[  40/100] loss: 1.58e-04
[  50/100] loss: 9.11e-05
[  60/100] loss: 6.07e-06
[  70/100] loss: 7.16e-06
[  80/100] loss: 3.90e-06
[  90/100] loss: 2.09e-07
[ 100/100] loss: 5.08e-07


[Array(0.00722719, dtype=float32),
 Array(0.00143648, dtype=float32),
 Array(6.3463885e-05, dtype=float32),
 Array(0.00015777, dtype=float32),
 Array(9.109734e-05, dtype=float32),
 Array(6.0693014e-06, dtype=float32),
 Array(7.158123e-06, dtype=float32),
 Array(3.896335e-06, dtype=float32),
 Array(2.0872065e-07, dtype=float32),
 Array(5.080557e-07, dtype=float32)]

In [163]:
class TrainState(nnx.Optimizer):
    def __init__(self, model, optimizer, metrics):
        self.metrics = metrics
        super().__init__(model, optimizer)

    def update(self, *, grads, **updates):
        self.metrics.update(**updates)
        super().update(grads)

model = MLP(2, 32, 1, rngs=nnx.Rngs(0))
learning_rate = 0.001
momentum = 0.9

metrics = nnx.metrics.Average()
state = TrainState(model, optax.adamw(learning_rate, momentum), metrics)

x = jnp.ones((20, 1))*0.01
t = jnp.ones((20, 1))*0.01
u = jnp.ones((20, 1))*0.01

def loss_fn(model, x, t, u):
    pred = model(x, t)
    u_t = jax.jacrev(model, argnums=1)(x, t)
    u_xx = jax.jacrev(jax.jacrev(model, argnums=0), argnums=0)(x, t)

    residual = u_t - u_xx
    loss_pde = jnp.mean(residual**2)
    loss_mse = jnp.mean((pred - u)**2)
    loss = loss_pde + loss_mse
    return loss

n_epochs = 100
for epoch in range(1, n_epochs + 1):
    grads = nnx.grad(loss_fn)(state.model, x, t, u)
    state.update(grads=grads, values=loss_fn(state.model, x, t, u))

    if epoch % (n_epochs // 10) == 0:
        print(f"[{epoch:4d}/{n_epochs}] loss: {state.metrics.compute():.2e}")

[  10/100] loss: 7.23e-03
[  20/100] loss: 4.33e-03
[  30/100] loss: 2.91e-03
[  40/100] loss: 2.22e-03
[  50/100] loss: 1.80e-03
[  60/100] loss: 1.50e-03
[  70/100] loss: 1.28e-03
[  80/100] loss: 1.12e-03
[  90/100] loss: 9.99e-04
[ 100/100] loss: 8.99e-04
