In [68]:
import jax
import jax.numpy as jnp
import flax.linen as nn
import optax

# 모델 정의 (Linen 사용)
class PINN(nn.Module):
    hidden_size: int

    @nn.compact
    def __call__(self, t, x):
        input_data = jnp.concatenate([t, x], axis=-1)
        x = nn.Dense(self.hidden_size)(input_data)
        x = nn.tanh(x)
        x = nn.Dense(self.hidden_size)(x)
        x = nn.tanh(x)
        x = nn.Dense(1)(x)
        return x

In [69]:
# 하이퍼파라미터
alpha = 1.0
hidden_size = 20

# 난수 키 생성 및 모델 초기화
key = jax.random.PRNGKey(0)
model = PINN(hidden_size=hidden_size)

x = jnp.linspace(0, 1, 100).reshape(-1, 1)
t = jnp.linspace(0, 1, 100).reshape(-1, 1)
params = model.init(key, x, t)

# 옵티마이저 초기화
optimizer = optax.adam(learning_rate=0.001)
opt_state = optimizer.init(params)

# 손실 함수
@jax.jit
def loss_fn(params, t, x, alpha):
    # pde = lambda t, x : model(params, t, x)
    u_t = jax.grad(lambda *args: model.apply(params, *args).sum())(x, t)[0]
    u_x = jax.grad(lambda *args: model.apply(params, *args).sum())(x, t)[1]
    u_xx = jax.grad(lambda *args: u_x.sum())(x, t)[1]

    # u = model.apply(params, t, x)
    # u_t  = jax.grad(lambda t, x: u.sum(), argnums=0)(x, t)
    # u_x  = jax.grad(lambda t, x: u.sum(), argnums=1)(x, t)
    # u_xx = jax.grad(lambda t, x: u_x.sum(), argnums=1)(x, t)

    residual = u_t - alpha * u_xx
    pde_loss = jnp.mean(residual**2)

    # Boundary conditions
    x_bc = jnp.array([[0.0], [1.0]])
    t_bc = jnp.array([[0.0], [0.0]])
    u_bc = model.apply(params, x_bc, t_bc)
    bc_loss = jnp.mean(u_bc**2)

    # Initial condition
    n_ic = 100
    x_ic = jnp.linspace(0, 1, n_ic).reshape(-1, 1)
    t_ic = jnp.zeros_like(x_ic)
    u_ic = model.apply(params, x_ic, t_ic)
    ic_loss = jnp.mean((u_ic - jnp.sin(jnp.pi * x_ic))**2)

    loss = pde_loss + bc_loss + ic_loss
    return loss, (pde_loss, bc_loss, ic_loss)

# 학습 스텝 (JIT 컴파일)
@jax.jit
def train_step(params, opt_state, t, x, alpha):
    # loss, grads = jax.value_and_grad(loss_fn)(params, t, x, alpha)
    (loss, aux), grads = jax.value_and_grad(loss_fn, has_aux=True)(params, t, x, alpha)
    updates, opt_state = optimizer.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss, aux

# 학습 루프
n_epochs = 10000   # epoch 수 증가
for epoch in range(1, n_epochs + 1):
    params, opt_state, loss, aux = train_step(params, opt_state, t, x, alpha)
    pde_loss, bc_loss, ic_loss = aux

    if epoch % (n_epochs // 10) == 0:
        print(f"[{epoch:5d}/{n_epochs}] Loss: {loss:.3e} "
              f"pde: {pde_loss:.3e} bc: {bc_loss:.3e} ic: {ic_loss:.3e}")

[ 1000/10000] Loss: 7.516e-03 pde: 2.258e-05 bc: 1.437e-03 ic: 6.057e-03
[ 2000/10000] Loss: 2.846e-03 pde: 2.645e-06 bc: 4.630e-04 ic: 2.381e-03
[ 3000/10000] Loss: 1.452e-03 pde: 4.011e-05 bc: 1.795e-04 ic: 1.232e-03
[ 4000/10000] Loss: 8.993e-04 pde: 1.213e-06 bc: 1.151e-04 ic: 7.830e-04
[ 5000/10000] Loss: 6.111e-04 pde: 1.174e-07 bc: 7.181e-05 ic: 5.392e-04
[ 6000/10000] Loss: 7.218e-04 pde: 2.150e-04 bc: 7.523e-05 ic: 4.316e-04
[ 7000/10000] Loss: 3.919e-04 pde: 4.028e-08 bc: 3.672e-05 ic: 3.552e-04
[ 8000/10000] Loss: 3.288e-04 pde: 4.315e-08 bc: 2.923e-05 ic: 2.995e-04
[ 9000/10000] Loss: 2.844e-04 pde: 8.141e-08 bc: 2.475e-05 ic: 2.596e-04
[10000/10000] Loss: 2.547e-04 pde: 1.248e-08 bc: 2.183e-05 ic: 2.328e-04
