In [25]:
import jax
import jax.numpy as jnp
from flax import linen as nn
import optax

from typing import Sequence
import numpy as np

class MLP(nn.Module):
    features: Sequence[int]

    @nn.compact
    def __call__(self, x):
        for feat in self.features:
            x = nn.relu(nn.Dense(feat)(x))
        return x

In [26]:
# 가상 데이터 생성
key = jax.random.PRNGKey(0)
x_train = np.random.randn(100, 2)
y_train = np.sin(x_train)

# JAX로 변환
x_train = jnp.array(x_train)
y_train = jnp.array(y_train)

In [27]:
model = MLP(features=[2, 100, 100, 2])
params = model.init(key, x_train)['params']

def loss(params, x, y):
    pred = model.apply({'params': params}, x)
    return jnp.mean((pred - y) ** 2)

optimizer = optax.adam(learning_rate=0.01)
opt_state = optimizer.init(params)

@jax.jit
def update(params, opt_state, x, y):
    loss_value, grads = jax.value_and_grad(loss)(params, x, y)
    updates, opt_state = optimizer.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss_value

n_epochs = 1000
for epoch in range(1, n_epochs + 1):
    params, opt_state, loss_value = update(params, opt_state, x_train, y_train)
    if epoch % (n_epochs // 10) == 0:
        print(f"[{epoch:4d}/{n_epochs}] Loss: {loss_value:.3f}")

[ 100/1000] Loss: 0.446
[ 200/1000] Loss: 0.446
[ 300/1000] Loss: 0.446
[ 400/1000] Loss: 0.446
[ 500/1000] Loss: 0.446
[ 600/1000] Loss: 0.446
[ 700/1000] Loss: 0.446
[ 800/1000] Loss: 0.446
[ 900/1000] Loss: 0.446
[1000/1000] Loss: 0.446
