In [79]:
import jax
import jax.numpy as jnp
from jax import grad, value_and_grad
import flax.linen as nn
import optax

# 모델 정의 (Linen 사용)
class PINN(nn.Module):
    hidden_size: int

    @nn.compact
    def __call__(self, t, x):
        input_data = jnp.concatenate([t, x], axis=-1)
        x = nn.Dense(self.hidden_size)(input_data)
        x = nn.tanh(x)
        x = nn.Dense(self.hidden_size)(x)
        x = nn.tanh(x)
        x = nn.Dense(1)(x)
        return x

In [80]:
# Collocation points
x = jnp.linspace(0, 1, 100).reshape(-1, 1)
t = jnp.linspace(0, 1, 100).reshape(-1, 1)

# Boundary conditions
x_bc = jnp.array([[0.0], [1.0]])
t_bc = jnp.array([[0.0], [0.0]])
u_bc = jnp.array([[0.0], [0.0]])

# Initial condition
n_ic = 100
x_ic = jnp.linspace(0, 1, n_ic).reshape(-1, 1)
t_ic = jnp.zeros_like(x_ic)
u_ic = jnp.sin(jnp.pi * x_ic)

data = {}
data['bc'] = t_bc, x_bc, u_bc
data['ic'] = t_ic, x_ic, u_ic

In [122]:
# 하이퍼파라미터
alpha = 1.0
hidden_size = 20

# 난수 키 생성 및 모델 초기화
key = jax.random.PRNGKey(0)
model = PINN(hidden_size=hidden_size)
optimizer = optax.adam(learning_rate=0.001)

params = model.init(key, x, t)
opt_state = optimizer.init(params)

# 손실 함수
def loss_fn(params, t, x, alpha, data):
    func   = lambda *args: model.apply(params, *args).sum()     # args = (t, x)
    grad_0 = lambda *args: grad(func, argnums=0)(*args).sum()   # args = (t, x)
    grad_1 = lambda *args: grad(func, argnums=1)(*args).sum()   # args = (t, x)

    u = model.apply(params, t, x)
    u_t = grad(func, argnums=0)(t, x)
    u_x = grad(func, argnums=1)(t, x)
    u_xx = grad(grad_1, argnums=1)(t, x)

    residual = u_t - alpha * u_xx
    loss = pde_loss = jnp.mean(residual**2)
    aux = {"pde": pde_loss}

    for name in data:
        t_data, x_data, u_data = data[name]
        u_pred = model.apply(params, t_data, x_data)
        aux[name] = jnp.mean((u_pred - u_data)**2)
        loss += aux[name]

    return loss, aux

# 학습 스텝 (JIT 컴파일)
@jax.jit
def train_step(params, opt_state, t, x, alpha, data={}):
    # loss, grads = jax.value_and_grad(loss_fn)(params, t, x, alpha)
    (loss, aux), grads = jax.value_and_grad(loss_fn, has_aux=True)(params, t, x, alpha, data)
    updates, opt_state = optimizer.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss, aux

In [123]:
%%time
n_epochs = 10000   # epoch 수 증가
for epoch in range(1, n_epochs + 1):
    params, opt_state, loss, aux = train_step(params, opt_state, t, x, alpha, data)

    desc = f"[{epoch:5d}/{n_epochs}] Loss: {loss:.3e} "
    desc += " ".join([f"{name}: {aux[name]:.3e}" for name in aux])

    if epoch % (n_epochs // 10) == 0:
        print(desc)

[ 1000/10000] Loss: 1.956e-03 bc: 2.252e-04 ic: 1.388e-03 pde: 3.431e-04
[ 2000/10000] Loss: 1.608e-04 bc: 7.807e-06 ic: 8.058e-05 pde: 7.241e-05
[ 3000/10000] Loss: 3.912e-05 bc: 1.542e-06 ic: 1.837e-05 pde: 1.921e-05
[ 4000/10000] Loss: 1.765e-05 bc: 8.272e-07 ic: 8.452e-06 pde: 8.373e-06
[ 5000/10000] Loss: 1.009e-05 bc: 4.204e-07 ic: 4.064e-06 pde: 5.608e-06
[ 6000/10000] Loss: 2.251e-04 bc: 8.542e-06 ic: 1.721e-05 pde: 1.993e-04
[ 7000/10000] Loss: 1.964e-04 bc: 1.682e-05 ic: 1.829e-05 pde: 1.612e-04
[ 8000/10000] Loss: 4.596e-06 bc: 8.214e-08 ic: 1.204e-06 pde: 3.309e-06
[ 9000/10000] Loss: 4.148e-05 bc: 2.843e-06 ic: 4.990e-06 pde: 3.365e-05
[10000/10000] Loss: 3.474e-06 bc: 4.524e-08 ic: 9.580e-07 pde: 2.470e-06
CPU times: total: 3.34 s
Wall time: 2.59 s
