## References
- [Annotated MNIST](https://flax.readthedocs.io/en/latest/notebooks/annotated_mnist.html)
- [Training a Simple Neural Network, with tensorflow/datasets Data Loading](https://jax.readthedocs.io/en/latest/notebooks/neural_network_with_tfds_data.html)

## Setup

In [1]:
# Install Flax; later cells import `flax.linen` and `flax.training` from it.
!pip install flax

Collecting flax
  Downloading flax-0.3.6-py3-none-any.whl (207 kB)
[?25l[K     |█▋                              | 10 kB 14.9 MB/s eta 0:00:01[K     |███▏                            | 20 kB 18.3 MB/s eta 0:00:01[K     |████▊                           | 30 kB 20.4 MB/s eta 0:00:01[K     |██████▎                         | 40 kB 10.9 MB/s eta 0:00:01[K     |████████                        | 51 kB 6.9 MB/s eta 0:00:01[K     |█████████▌                      | 61 kB 7.1 MB/s eta 0:00:01[K     |███████████                     | 71 kB 5.6 MB/s eta 0:00:01[K     |████████████▋                   | 81 kB 6.2 MB/s eta 0:00:01[K     |██████████████▎                 | 92 kB 6.6 MB/s eta 0:00:01[K     |███████████████▉                | 102 kB 5.8 MB/s eta 0:00:01[K     |█████████████████▍              | 112 kB 5.8 MB/s eta 0:00:01[K     |███████████████████             | 122 kB 5.8 MB/s eta 0:00:01[K     |████████████████████▋           | 133 kB 5.8 MB/s eta 0:00:01[K     

## Dataset

In [2]:
# Download the flower photos archive, unpack it into the current directory,
# then delete the archive.
# See: https://www.tensorflow.org/datasets/catalog/tf_flowers
!curl -o flower_photos.tgz https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz
!tar -xzf flower_photos.tgz
!rm flower_photos.tgz

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100  218M  100  218M    0     0   135M      0  0:00:01  0:00:01 --:--:--  134M


In [3]:
import os
from typing import Dict, List, Tuple

from tensorflow.keras import utils as tfku

class Flower:
    """One dataset sample: an image path paired with its label."""

    def __init__(self, image_path: str, label: str):
        # Plain record; both values are stored exactly as given.
        self.image_path, self.label = image_path, label

def collect_data(
    data_path: str = "flower_photos",
    val_every: int = 5,
) -> Tuple[Tuple[List["Flower"], ...], List[str]]:
    """Collect labelled image paths and split them into train/validation lists.

    Each immediate sub-directory of ``data_path`` is treated as one label and
    every file directly inside it as one sample of that label. Directory and
    file names are sorted, so the split is deterministic.

    Args:
        data_path: Root directory holding one sub-directory per label.
        val_every: Within each label, files at indices 0, ``val_every``,
            ``2 * val_every``, ... go to the validation list; all other files
            go to the training list.

    Returns:
        ``((train_list, val_list), labels)`` where ``labels`` is the sorted
        list of sub-directory names.

    Raises:
        ValueError: If ``val_every`` is smaller than 1.
        FileNotFoundError: If ``data_path`` or a label directory cannot be
            listed.
    """
    if val_every < 1:
        raise ValueError(f"val_every must be >= 1, got {val_every!r}")

    def walk_children(path: str, full_path: bool, walk_dirs: bool = True) -> List[str]:
        # os.walk yields nothing for a missing or unreadable directory, so a
        # bare next() would leak StopIteration; fail with a clear error instead.
        top = next(os.walk(path), None)
        if top is None:
            raise FileNotFoundError(f"Cannot list directory: {path!r}")
        names = top[1] if walk_dirs else top[2]
        return sorted(
            os.path.join(path, name) if full_path else name
            for name in names
        )

    train_list: List["Flower"] = []
    val_list: List["Flower"] = []
    labels = walk_children(data_path, False, walk_dirs=True)
    for label in labels:
        image_paths = walk_children(os.path.join(data_path, label), True, walk_dirs=False)
        for i, image_path in enumerate(image_paths):
            flower = Flower(image_path, label)
            # Every `val_every`-th file of a label is held out for validation.
            if i % val_every != 0:
                train_list.append(flower)
            else:
                val_list.append(flower)
    return (train_list, val_list), labels

def make_slices_dict(flowers_list: List[Flower], labels: List[str]) -> Dict[str, List]:
    """Build the column dict later passed to ``tf.data.Dataset.from_tensor_slices``.

    ``"image"`` holds each flower's image path. ``"output"`` holds the result
    of ``to_categorical`` applied to each flower's position in ``labels``,
    with ``len(labels)`` classes.
    """
    image_paths = []
    class_ids = []
    for flower in flowers_list:
        image_paths.append(flower.image_path)
        class_ids.append(labels.index(flower.label))
    return {
        "image": image_paths,
        "output": tfku.to_categorical(class_ids, num_classes=len(labels)),
    }

In [4]:
import tensorflow as tf

class Reader:
    """Map one element of the source dict to model input and output tensors.

    Args:
        image_size: Target ``(height, width)`` passed to ``tf.image.resize``.
        augment: When true, random brightness, contrast, crop and horizontal
            flip are applied to each image before it is resized.
    """

    def __init__(self, image_size: Tuple[int, int], augment: bool):
        self.image_size = image_size  # (height, width)
        self.augment = augment

    def read_input(self, sources: Dict[str, tf.Tensor]) -> tf.Tensor:
        """Load and preprocess the image whose path is stored under ``"image"``."""
        return self._read_image(sources["image"])

    def read_output(self, sources: Dict[str, tf.Tensor]) -> tf.Tensor:
        """Return the value stored under ``"output"`` unchanged."""
        return sources["output"]

    def _read_image(self, source: tf.Tensor) -> tf.Tensor:
        """Read, optionally augment, resize and rescale a single image.

        Args:
            source: Path of the image file to read.

        Returns:
            The 3-channel image resized to ``self.image_size`` and divided by 255.
        """
        image = tf.io.read_file(source)
        # expand_animations=False keeps animated files as a single 3-D frame.
        image = tf.image.decode_image(image, channels=3, expand_animations=False)
        if self.augment:
            image = tf.image.random_brightness(image, 0.5)
            image = tf.image.random_contrast(image, 0.2, 0.8)
            # Crop a random window covering 80-100% of each spatial dimension.
            # random_crop expects an integer size, so cast explicitly rather
            # than relying on implicit float -> int32 conversion.
            full_shape = tf.shape(image)
            crop_height = tf.cast(
                tf.random.uniform([], minval=0.8, maxval=1.0) * tf.cast(full_shape[0], tf.float32),
                tf.int32,
            )
            crop_width = tf.cast(
                tf.random.uniform([], minval=0.8, maxval=1.0) * tf.cast(full_shape[1], tf.float32),
                tf.int32,
            )
            image = tf.image.random_crop(image, [crop_height, crop_width, 3])
            image = tf.image.random_flip_left_right(image)
        image = tf.image.resize(image, self.image_size)
        image /= 255
        return image

def generate_dataset(slices: Dict[str, List], reader: Reader, batch_size: int, shuffle: bool) -> tf.data.Dataset:
    """Build a batched, prefetched ``tf.data`` pipeline from column slices.

    Args:
        slices: Columns passed to ``tf.data.Dataset.from_tensor_slices``; each
            element of the resulting dataset is handed to ``reader``.
        reader: Provides ``read_input`` / ``read_output`` for one element.
        batch_size: Number of elements per batch.
        shuffle: When true, shuffle with a buffer as large as the longest
            column, i.e. over the whole dataset.

    Returns:
        A dataset yielding ``(input, output)`` batches.
    """
    def _read_data(sources: Dict[str, tf.Tensor]) -> Tuple[tf.Tensor, tf.Tensor]:
        input_data = reader.read_input(sources)
        output_data = reader.read_output(sources)
        # Unwrap only a genuine one-element list/tuple of inputs. len() on a
        # tensor measures its first dimension, so testing the length of the
        # image tensor itself would strip a dimension from a height-1 image.
        if isinstance(input_data, (list, tuple)) and len(input_data) == 1:
            input_data = input_data[0]
        return (input_data, output_data)

    dataset = tf.data.Dataset.from_tensor_slices(slices)
    if shuffle:
        dataset = dataset.shuffle(buffer_size=max(len(data) for data in slices.values()))
    dataset = dataset.map(
        _read_data,
        num_parallel_calls=tf.data.experimental.AUTOTUNE
    ).batch(batch_size).prefetch(tf.data.experimental.AUTOTUNE)
    return dataset

## Model

In [5]:
from flax import linen as fl

class Cnn(fl.Module):
    """Small convolutional classifier that outputs softmax class probabilities.

    Three (3x3 conv -> ReLU -> 2x2 average pool) stages with 32, 64 and 128
    features, followed by a 256-unit dense layer and a ``labels_num``-way
    softmax head.
    """
    # Number of output classes.
    labels_num: int

    @fl.compact
    def __call__(self, x):
        # Convolutional feature extractor; each stage halves the spatial size.
        for features in (32, 64, 128):
            x = fl.Conv(features, (3, 3))(x)
            x = fl.relu(x)
            x = fl.avg_pool(x, (2, 2), strides=(2, 2))
        # Flatten everything except the leading (batch) axis.
        x = x.reshape((x.shape[0], -1))
        hidden = fl.relu(fl.Dense(256)(x))
        logits = fl.Dense(self.labels_num)(hidden)
        return fl.softmax(logits)

In [6]:
import jax
import jax.numpy as jnp
import numpy as np
from flax.training import train_state as ft

def _calculate_loss(pred: jnp.DeviceArray, truth: np.ndarray) -> jnp.DeviceArray:
    # Cross-entropy
    return jnp.mean(-jnp.sum(truth * jnp.log(pred), axis=-1))

def _calculate_acc(pred: jnp.DeviceArray, truth: np.ndarray) -> jnp.DeviceArray:
    return jnp.mean(jnp.argmax(pred, axis=-1) == jnp.argmax(truth, axis=-1))

def _compute_epoch_metric(metrics: List[Dict[str, jnp.DeviceArray]]) -> Dict[str, np.float32]:
    return {
        k: np.mean([m[k] for m in jax.device_get(metrics)])
        for k in ["loss", "acc"]
    }

@jax.jit
def _train_batch(state, batch):
    """Apply one gradient step on a single batch; return new state and metrics."""
    image, truth = batch

    def loss_fn(params):
        pred = state.apply_fn({"params": params}, image)
        # pred is returned as auxiliary output so accuracy can reuse it.
        return _calculate_loss(pred, truth), pred

    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (loss, pred), grads = grad_fn(state.params)
    new_state = state.apply_gradients(grads=grads)
    metric = {
        "loss": loss,
        "acc": _calculate_acc(pred, truth)
    }
    return new_state, metric


def train_epoch(state: ft.TrainState, train_dataset: tf.data.Dataset) -> Tuple[ft.TrainState, Dict[str, np.float32]]:
    """Train for one pass over ``train_dataset``.

    The jitted step lives at module level so its compilation cache is reused
    across epochs; a jitted function defined inside this function would be a
    new object on every call and would be traced again each epoch.

    Args:
        state: Current training state (parameters and optimizer state).
        train_dataset: Iterable yielding ``(image, truth)`` batches.

    Returns:
        The updated state and the epoch-averaged ``"loss"`` / ``"acc"``.
    """
    batch_metrics = []
    for batch in train_dataset:
        state, metric = _train_batch(state, batch)
        batch_metrics.append(metric)
    epoch_metric = _compute_epoch_metric(batch_metrics)
    return state, epoch_metric

@jax.jit
def _eval_batch(state, batch):
    """Compute loss and accuracy for one batch without updating parameters."""
    image, truth = batch
    pred = state.apply_fn({"params": state.params}, image)
    return {
        "loss": _calculate_loss(pred, truth),
        "acc": _calculate_acc(pred, truth)
    }


def eval_epoch(state: ft.TrainState, val_dataset: tf.data.Dataset) -> Dict[str, np.float32]:
    """Evaluate ``state`` over one pass of ``val_dataset``.

    The per-batch step is jitted, matching the training step, and is defined
    at module level so its compilation cache is reused between epochs.

    Args:
        state: Training state whose parameters are evaluated.
        val_dataset: Iterable yielding ``(image, truth)`` batches.

    Returns:
        The epoch-averaged ``"loss"`` / ``"acc"``.
    """
    # No explicit device_get here: _compute_epoch_metric already moves the
    # metrics to host memory before averaging.
    batch_metrics = [_eval_batch(state, batch) for batch in val_dataset]
    epoch_metric = _compute_epoch_metric(batch_metrics)
    return epoch_metric

## Training

In [7]:
import optax
import tensorflow_datasets as tfds

# Hyperparameters
IMAGE_SIZE = (128, 128)  # (height, width) every image is resized to
BATCH_SIZE = 8
EPOCHS = 10

# Dataset: augmentation and shuffling are enabled for training only.
(train_list, val_list), labels = collect_data()
train_slices = make_slices_dict(train_list, labels)
val_slices = make_slices_dict(val_list, labels)
train_reader = Reader(IMAGE_SIZE, True)
val_reader = Reader(IMAGE_SIZE, False)
# tfds.as_numpy makes the pipelines yield NumPy batches for the JAX steps.
train_dataset = tfds.as_numpy(generate_dataset(train_slices, train_reader, BATCH_SIZE, True))
val_dataset = tfds.as_numpy(generate_dataset(val_slices, val_reader, BATCH_SIZE, False))

# Model and state: parameters are initialised from a dummy one-image batch.
model = Cnn(labels_num=len(labels))
init_rng = jax.random.PRNGKey(0)
dummy_batch = jnp.ones([1, *IMAGE_SIZE, 3])
initial_variables = model.init(init_rng, dummy_batch)
state = ft.TrainState.create(
    apply_fn=model.apply,
    params=initial_variables["params"],
    tx=optax.adam(0.001),
)

# Training: one training pass and one validation pass per epoch.
for epoch in range(EPOCHS):
    state, train_metric = train_epoch(state, train_dataset)
    val_metric = eval_epoch(state, val_dataset)
    # First loss/acc pair is training, second pair is validation.
    report = (
        f"Epoch {epoch}, "
        f"loss: {train_metric['loss']:.4f}, acc: {train_metric['acc']:.2f}, "
        f"loss: {val_metric['loss']:.4f}, acc: {val_metric['acc']:.2f}"
    )
    print(report)

Epoch 0, loss: 1.5488, acc: 0.32, loss: 1.3275, acc: 0.39
Epoch 1, loss: 1.3704, acc: 0.39, loss: 1.2021, acc: 0.49
Epoch 2, loss: 1.3167, acc: 0.43, loss: 1.2521, acc: 0.47
Epoch 3, loss: 1.2637, acc: 0.47, loss: 1.2226, acc: 0.50
Epoch 4, loss: 1.2271, acc: 0.49, loss: 1.2293, acc: 0.46
Epoch 5, loss: 1.1857, acc: 0.51, loss: 1.2572, acc: 0.52
Epoch 6, loss: 1.1659, acc: 0.53, loss: 1.2973, acc: 0.48
Epoch 7, loss: 1.1267, acc: 0.55, loss: 1.2715, acc: 0.54
Epoch 8, loss: 1.1183, acc: 0.55, loss: 1.1361, acc: 0.58
Epoch 9, loss: 1.0664, acc: 0.58, loss: 1.2287, acc: 0.57
