In [57]:
import jax
import jax.numpy as jnp
from flax import nnx
import optax

In [58]:
class PINN(nnx.Module):
    def __init__(self, hidden_size, rngs: nnx.Rngs):
        # self.net = nnx.Sequential([
        #     nnx.Linear(2, hidden_size, rngs=rngs),
        #     nnx.tanh,
        #     nnx.Linear(hidden_size, hidden_size, rngs=rngs),
        #     nnx.tanh,
        #     nnx.Linear(hidden_size, 1, rngs=rngs)
        # ])
        self.linear1 = nnx.Linear(2, hidden_size, rngs=rngs)
        self.linear2 = nnx.Linear(hidden_size, hidden_size, rngs=rngs)
        self.linear3 = nnx.Linear(hidden_size, 1, rngs=rngs)

    def __call__(self, t, x):
        inputs = jnp.concatenate([t, x], axis=-1)
        outputs = nnx.tanh(self.linear1(inputs))
        outputs = nnx.tanh(self.linear2(outputs))
        outputs = self.linear3(outputs)
        return outputs

In [59]:
# 하이퍼파라미터 설정
alpha = 1.0
hidden_size = 20

# 난수 키 생성
key = jax.random.PRNGKey(0)
rngs = nnx.Rngs(key)

# 모델 및 옵티마이저 초기화
model = PINN(hidden_size, rngs=rngs)
x = jnp.linspace(0, 1, 100).reshape(-1, 1)
t = jnp.linspace(0, 1, 100).reshape(-1, 1)

optimizer = nnx.Optimizer(model, optax.adam(learning_rate=0.001))

@nnx.jit
def loss_fn(model, t, x, alpha):
    u = model(t, x)
    u_t = jax.grad(lambda *args: model(*args).sum())(t, x)[1]
    u_x = jax.grad(lambda *args: model(*args).sum())(t, x)[0]
    u_xx = jax.grad(lambda *args: u_x.sum())(t, x)[0]

    # PDE 손실
    residual = u_t - alpha * u_xx
    pde_loss = jnp.mean(residual**2)

    # 경계 조건 손실
    x_bc = jnp.array([[0.0], [1.0]])
    t_bc = jnp.array([[0.0], [0.0]])
    u_bc = model(x_bc, t_bc)
    bc_loss = jnp.mean(u_bc**2)

    # 초기 조건 손실
    n_ic = 100 # 초기 조건 점의 개수
    x_ic = jnp.linspace(0, 1, n_ic).reshape(-1, 1)
    t_ic = jnp.zeros_like(x_ic)
    u_ic = model(x_ic, t_ic)
    ic_loss = jnp.mean((u_ic - jnp.sin(jnp.pi * x_ic))**2)

    loss = pde_loss + bc_loss + ic_loss
    return loss, (pde_loss, bc_loss, ic_loss)


@nnx.jit
def train_step(model, optimizer, t, x, alpha):
    (loss, aux), grads = nnx.value_and_grad(loss_fn, has_aux=True)(model, t, x, alpha)
    optimizer.update(grads)
    return loss, aux

n_epochs = 10000
for epoch in range(1, n_epochs + 1):
    loss, aux = train_step(model, optimizer, t, x, alpha)
    pde_loss, bc_loss, ic_loss = aux

    if epoch % (n_epochs // 10) == 0:
        print(f"[{epoch:5d}/{n_epochs}] Loss: {loss:.3e} "
              f"pde: {pde_loss:.3e} bc: {bc_loss:.3e} ic: {ic_loss:.3e}")

[ 1000/10000] Loss: 2.698e-03 pde: 3.732e-06 bc: 3.380e-04 ic: 2.356e-03
[ 2000/10000] Loss: 1.359e-03 pde: 3.435e-07 bc: 1.543e-04 ic: 1.204e-03
[ 3000/10000] Loss: 8.570e-04 pde: 1.226e-07 bc: 9.234e-05 ic: 7.645e-04
[ 4000/10000] Loss: 6.434e-04 pde: 4.469e-08 bc: 6.637e-05 ic: 5.770e-04
[ 5000/10000] Loss: 5.266e-04 pde: 7.585e-07 bc: 5.118e-05 ic: 4.747e-04
[ 6000/10000] Loss: 4.456e-04 pde: 1.826e-09 bc: 4.236e-05 ic: 4.032e-04
[ 7000/10000] Loss: 1.558e-03 pde: 1.053e-03 bc: 1.209e-04 ic: 3.843e-04
[ 8000/10000] Loss: 3.362e-04 pde: 2.344e-09 bc: 3.052e-05 ic: 3.056e-04
[ 9000/10000] Loss: 3.182e-04 pde: 2.140e-05 bc: 2.478e-05 ic: 2.720e-04
[10000/10000] Loss: 2.658e-04 pde: 3.318e-06 bc: 2.470e-05 ic: 2.378e-04
