In [1]:
from typing import Sequence

import numpy as np
import jax
import jax.numpy as jnp
import flax.linen as nn

class MLP(nn.Module):
    features: Sequence[int]

    @nn.compact
    def __call__(self, x):
        for feat in self.features[:-1]:
            x = nn.relu(nn.Dense(feat)(x))
        x = nn.Dense(self.features[-1])(x)
        return x

# 모델 선언
model = MLP([12, 8, 4])

# 데이터 준비
batch = jnp.ones((32, 10))

# 모델 준비
variables = model.init(jax.random.PRNGKey(0), batch)

# 모델에 데이터 적용
output = model.apply(variables, batch)
output.shape

(32, 4)

https://wandb.ai/wandb_fc/tips/reports/How-To-Create-an-Image-Classification-Model-in-JAX-Flax--VmlldzoyMjA0Mjk1

In [30]:
import jax
import jax.numpy as jnp               # JAX NumPy

from flax import linen as nn          # The Linen API
from flax.training import train_state
import optax                          # The Optax gradient processing and optimization library

import numpy as np                    # Ordinary NumPy
import tensorflow_datasets as tfds    # TFDS for MNIST

  from .autonotebook import tqdm as notebook_tqdm


In [21]:
class CNN(nn.Module):
    n_classes: int

    @nn.compact
    # Provide a constructor to register a new parameter 
    # and return its initial value
    def __call__(self, x):
        x = nn.Conv(features=32, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = nn.Conv(features=64, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = x.reshape((x.shape[0], -1)) # Flatten
        x = nn.Dense(features=256)(x)
        x = nn.relu(x)
        x = nn.Dense(features=10)(x)    # There are 10 classes in MNIST
        return x

In [22]:
@jax.vmap
def cross_entropy_loss(logits, label):
    return -logits[label]


In [32]:
def compute_metrics(logits, labels):
    loss = jnp.mean(optax.softmax_cross_entropy(logits, jax.nn.one_hot(labels, num_classes=10)))
    accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
    metrics = {
      'loss': loss,
      'accuracy': accuracy
    }
    return metrics

In [33]:
def get_datasets():
    ds_builder = tfds.builder('mnist')
    ds_builder.download_and_prepare()
    # Split into training/test sets
    train_ds = tfds.as_numpy(ds_builder.as_dataset(split='train', batch_size=-1))
    test_ds = tfds.as_numpy(ds_builder.as_dataset(split='test', batch_size=-1))
    # Convert to floating-points
    train_ds['image'] = jnp.float32(train_ds['image']) / 255.0
    test_ds['image'] = jnp.float32(test_ds['image']) / 255.0
    return train_ds, test_ds

In [50]:
@jax.jit
def train_step(state, batch):
    def loss_fn(params):
        logits = CNN(n_classes=10).apply({'params': params}, batch['image'])
        loss = jnp.mean(optax.softmax_cross_entropy(
            logits=logits, 
            labels=jax.nn.one_hot(batch['label'], num_classes=10)))
        return loss, logits
    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (_, logits), grads = grad_fn(state.params)
    state = state.apply_gradients(grads=grads)
    metrics = compute_metrics(logits, batch['label'])
    return state, metrics

In [61]:
@jax.jit
def eval_step(params, batch):
    logits = CNN(n_classes=10).apply({'params': params}, batch['image'])
    return compute_metrics(logits, batch['label'])

In [62]:
def train_epoch(state, train_ds, batch_size, epoch, rng):
    train_ds_size = len(train_ds['image'])
    steps_per_epoch = train_ds_size // batch_size

    perms = jax.random.permutation(rng, len(train_ds['image']))
    perms = perms[:steps_per_epoch * batch_size]  # Skip an incomplete batch
    perms = perms.reshape((steps_per_epoch, batch_size))

    batch_metrics = []

    for perm in perms:
        batch = {k: v[perm, ...] for k, v in train_ds.items()}
        state, metrics = train_step(state, batch)
        batch_metrics.append(metrics)

    training_batch_metrics = jax.device_get(batch_metrics)
    training_epoch_metrics = {
        k: np.mean([metrics[k] for metrics in training_batch_metrics])
        for k in training_batch_metrics[0]}

    print('Training - epoch: %d, loss: %.4f, accuracy: %.2f' % (epoch, training_epoch_metrics['loss'], training_epoch_metrics['accuracy'] * 100))

    return state, training_epoch_metrics

In [63]:
def eval_model(model, test_ds):
    metrics = eval_step(model, test_ds)
    metrics = jax.device_get(metrics)
    eval_summary = jax.tree_map(lambda x: x.item(), metrics)
    return eval_summary['loss'], eval_summary['accuracy']

In [64]:
train_ds, test_ds = get_datasets()

In [65]:
rng = jax.random.PRNGKey(0)
rng, init_rng = jax.random.split(rng)

In [66]:
cnn = CNN(n_classes = 10)
params = cnn.init(init_rng, jnp.ones([1, 28, 28, 1]))['params']

In [67]:
nesterov_momentum = 0.9
learning_rate = 0.001
tx = optax.sgd(learning_rate=learning_rate, nesterov=nesterov_momentum)

In [68]:
state = train_state.TrainState.create(apply_fn=cnn.apply, params=params, tx=tx)

In [69]:
num_epochs = 10
batch_size = 32

In [70]:
for epoch in range(1, num_epochs + 1):
    # Use a separate PRNG key to permute image data during shuffling
    rng, input_rng = jax.random.split(rng)
    # Run an optimization step over a training batch
    state, train_metrics = train_epoch(state, train_ds, batch_size, epoch, input_rng)
    # Evaluate on the test set after each training epoch
    test_loss, test_accuracy = eval_model(state.params, test_ds)
    print('Testing - epoch: %d, loss: %.2f, accuracy: %.2f' % (epoch, test_loss, test_accuracy * 100))

Training - epoch: 1, loss: 1.7937, accuracy: 62.74
Testing - epoch: 1, loss: 0.93, accuracy: 82.31
Training - epoch: 2, loss: 0.6111, accuracy: 85.10
Testing - epoch: 2, loss: 0.44, accuracy: 88.47
Training - epoch: 3, loss: 0.4127, accuracy: 88.41
Testing - epoch: 3, loss: 0.36, accuracy: 89.90
Training - epoch: 4, loss: 0.3597, accuracy: 89.69
Testing - epoch: 4, loss: 0.32, accuracy: 90.80
Training - epoch: 5, loss: 0.3279, accuracy: 90.49
Testing - epoch: 5, loss: 0.30, accuracy: 91.56
Training - epoch: 6, loss: 0.3046, accuracy: 91.19
Testing - epoch: 6, loss: 0.28, accuracy: 91.98
Training - epoch: 7, loss: 0.2851, accuracy: 91.72
Testing - epoch: 7, loss: 0.26, accuracy: 92.24
Training - epoch: 8, loss: 0.2679, accuracy: 92.16
Testing - epoch: 8, loss: 0.24, accuracy: 92.89
Training - epoch: 9, loss: 0.2520, accuracy: 92.73
Testing - epoch: 9, loss: 0.23, accuracy: 93.18
Training - epoch: 10, loss: 0.2383, accuracy: 92.99
Testing - epoch: 10, loss: 0.22, accuracy: 93.54
