# Instructions
* Open the Colab notebook that you want to import code snippets into
* Go to ```Tools``` --> ```Settings```
* Go to ```Site``` --> ```Custom snippet notebook URL```
  * Paste https://colab.research.google.com/github/google/flax/blob/flax_docs/docs/code_snippets.ipynb
  * Click ```Save```
* Go to ```Help``` --> ```Search code snippets```
* Enter a header title from this Colab project (e.g. 'Imports' or 'CNN') in the ```Filter code snippets``` field, then import the matching code snippet into your Colab project

# Imports

In [None]:
# Install Flax with pip (IPython shell escape); -q keeps pip's output quiet.
!pip install -q flax

In [None]:
import jax
import jax.numpy as jnp
from jax import random

from flax import linen as nn
from flax.training import train_state, checkpoints

import numpy as np

import optax

import tensorflow_datasets as tfds

import os
import shutil

# TensorFlow Datasets (TFDS)

In [None]:
# List every TFDS builder whose name contains `keyword`.
keyword = 'image'
list(filter(lambda builder_name: keyword in builder_name, tfds.list_builders()))

In [None]:
# load MNIST data
def get_datasets():
  """Return the MNIST (train, test) splits as in-memory arrays.

  Each split is loaded as one full batch (`batch_size=-1`), and its 'image'
  entry is converted to float32 and divided by 255.
  """
  builder = tfds.builder('mnist')
  builder.download_and_prepare()
  loaded = []
  for split_name in ('train', 'test'):
    split = tfds.as_numpy(builder.as_dataset(split=split_name, batch_size=-1))
    split['image'] = jnp.float32(split['image']) / 255.
    loaded.append(split)
  train_ds, test_ds = loaded
  return train_ds, test_ds

# Module Template

In [None]:
from jax import random
import jax.numpy as jnp
from flax import linen as nn

class Foo(nn.Module):
  """Skeleton Linen module; replace the body of `__call__` with real layers."""

  @nn.compact
  def __call__(self, x):
    # Placeholder: no layers are defined yet, so the call returns None.
    return None

# Initialise the template module on a dummy input, then apply it once.
m = Foo()
x = jnp.zeros(shape=(3, 4))
init_key = random.PRNGKey(0)
variables = m.init(init_key, x)
output = m.apply(variables, x)

# MLP

In [None]:
from typing import Sequence  # needed by the `features` annotation below


class SimpleMLP(nn.Module):
  """Multi-layer perceptron built from a stack of `nn.Dense` layers.

  A ReLU follows every layer except the last one.

  Attributes:
    features: output size of each Dense layer, in order.
  """
  features: Sequence[int]

  @nn.compact
  def __call__(self, inputs):
    """Apply the Dense stack to `inputs` and return the last layer's output."""
    x = inputs
    last_index = len(self.features) - 1
    for i, feat in enumerate(self.features):
      # providing a name is optional though!
      # the default autonames would be "Dense_0", "Dense_1", ...
      x = nn.Dense(feat, name=f'layers_{i}')(x)
      if i != last_index:
        x = nn.relu(x)
    return x

# CNN

In [None]:
class CNN(nn.Module):
  """Small convolutional classifier: two conv/pool stages, then two dense layers."""

  @nn.compact
  def __call__(self, x):
    # Two identical stages that differ only in the number of conv filters.
    for n_filters in (32, 64):
      x = nn.Conv(features=n_filters, kernel_size=(3, 3))(x)
      x = nn.relu(x)
      x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    batch_size = x.shape[0]
    x = x.reshape((batch_size, -1))  # flatten everything but the leading axis
    hidden = nn.relu(nn.Dense(features=256)(x))
    return nn.Dense(features=10)(hidden)

# ResNet

In [None]:
# https://github.com/google/flax/tree/main/examples/imagenet/models.py
# Clone the Flax repository into the current directory (skipped when an entry
# named 'flax' is already listed there), then open the ImageNet example's
# models.py with Colab's `files.view`.
import os
from google.colab import files
if 'flax' not in os.listdir():
  !git clone https://github.com/google/flax.git
files.view('flax/examples/imagenet/models.py')

# LSTM

In [None]:
# https://github.com/google/flax/tree/main/examples/seq2seq/models.py
# Clone the Flax repository into the current directory (skipped when an entry
# named 'flax' is already listed there), then open the seq2seq example's
# models.py with Colab's `files.view`.
import os
from google.colab import files
if 'flax' not in os.listdir():
  !git clone https://github.com/google/flax.git
files.view('flax/examples/seq2seq/models.py')

# Transformer

In [None]:
# https://github.com/google/flax/tree/main/examples/nlp_seq/models.py
# Clone the Flax repository into the current directory (skipped when an entry
# named 'flax' is already listed there), then open the nlp_seq example's
# models.py with Colab's `files.view`.
import os
from google.colab import files
if 'flax' not in os.listdir():
  !git clone https://github.com/google/flax.git
files.view('flax/examples/nlp_seq/models.py')

# Loss functions

In [None]:
def cross_entropy_loss(*, logits, labels):
  """Return the mean softmax cross-entropy between `logits` and integer `labels`.

  Args:
    logits: unnormalised class scores; the last axis indexes the classes.
    labels: integer class ids, one-hot encoded here before the loss is taken.

  Returns:
    The cross-entropy averaged over all examples.
  """
  # Take the class count from the logits instead of hard-coding 10, so the
  # helper works for any classifier; for 10-way logits nothing changes.
  num_classes = logits.shape[-1]
  labels_onehot = jax.nn.one_hot(labels, num_classes=num_classes)
  return optax.softmax_cross_entropy(logits=logits, labels=labels_onehot).mean()

# Training loop

In [None]:
def create_train_state(rng, learning_rate, momentum):
  """Build the starting `TrainState` for the CNN with an SGD optimizer."""
  model = CNN()
  # A single dummy 28x28x1 input is enough to initialise the parameters.
  dummy_batch = jnp.ones([1, 28, 28, 1])
  initial_variables = model.init(rng, dummy_batch)
  optimizer = optax.sgd(learning_rate, momentum)
  return train_state.TrainState.create(
      apply_fn=model.apply,
      params=initial_variables['params'],
      tx=optimizer,
  )

In [None]:
@jax.jit
def train_step(state, batch):
  """Run one gradient update on `batch`; return the new state and its metrics."""
  images, labels = batch['image'], batch['label']

  def loss_and_logits(params):
    logits = CNN().apply({'params': params}, images)
    return cross_entropy_loss(logits=logits, labels=labels), logits

  # has_aux=True: differentiate only the loss and pass the logits through.
  grads, logits = jax.grad(loss_and_logits, has_aux=True)(state.params)
  updated_state = state.apply_gradients(grads=grads)
  return updated_state, compute_metrics(logits=logits, labels=labels)

In [None]:
@jax.jit
def eval_step(params, batch):
  """Apply the CNN to one batch and return its loss/accuracy metrics."""
  model_variables = {'params': params}
  predictions = CNN().apply(model_variables, batch['image'])
  return compute_metrics(logits=predictions, labels=batch['label'])

In [None]:
# training loop: one `train_step` on `train_ds` per epoch, then evaluate on `test_ds`
for epoch in range(1, n_epochs+1):
  state, metrics = train_step(state, train_ds)
  test_loss, test_accuracy = eval_model(state.params, test_ds)
  # Adjacent f-strings are concatenated by the parser, so no `+` is needed.
  print(
      f"epoch: {epoch}, train loss: {metrics['loss']}, "
      f"train accuracy: {metrics['accuracy']*100}%, "
      f"test loss: {test_loss}, test accuracy: {test_accuracy*100}%"
  )

# Metrics

In [None]:
def compute_metrics(*, logits, labels):
  """Return a dict holding the cross-entropy 'loss' and the 'accuracy'."""
  predicted_classes = jnp.argmax(logits, -1)
  return {
      'loss': cross_entropy_loss(logits=logits, labels=labels),
      'accuracy': jnp.mean(predicted_classes == labels),
  }

In [None]:
def eval_model(params, test_ds):
  """Evaluate `params` on `test_ds` and return (loss, accuracy) as Python scalars."""
  # Fetch the metrics from the device, then convert each leaf with `.item()`.
  host_metrics = jax.device_get(eval_step(params, test_ds))
  scalars = jax.tree_util.tree_map(lambda leaf: leaf.item(), host_metrics)
  return scalars['loss'], scalars['accuracy']

# Checkpoints

In [None]:
# https://flax.readthedocs.io/en/latest/guides/use_checkpointing.html#

In [None]:
# A simple model with one linear layer.
root_key = random.PRNGKey(0)
key1, key2 = random.split(root_key)
x1 = random.normal(key1, shape=(5,))      # A simple JAX array.
model = nn.Dense(features=3)
variables = model.init(key2, x1)

# Flax's TrainState is a pytree dataclass and is supported in checkpointing.
# Define your class with `@flax.struct.dataclass` decorator to make it compatible.
tx = optax.sgd(learning_rate=0.1)      # An Optax SGD optimizer.
state = train_state.TrainState.create(
    apply_fn=model.apply, params=variables['params'], tx=tx)

# Some arbitrary nested pytree with a dictionary, a string, and a NumPy array.
config = dict(dimensions=np.array([5, 3]), name='dense')

# Bundle everything together.
ckpt = dict(model=state, config=config, data=[x1])

In [None]:
# Save the bundled `ckpt` pytree as step 0, starting from a clean directory.
ckpt_dir = 'tmp/flax-checkpointing'

# Remove any existing checkpoints from the last notebook run.
if os.path.exists(ckpt_dir):
  shutil.rmtree(ckpt_dir)

checkpoints.save_checkpoint(
    ckpt_dir, target=ckpt, step=0, keep=2, overwrite=False)

'tmp/flax-checkpointing/checkpoint_0'

In [None]:
# Build a restore target with the same structure as the saved `ckpt`.
# `np.zeros_like` has to be mapped over every leaf of the parameter tree:
# calling it on the whole dict does not produce a matching tree of arrays.
empty_state = train_state.TrainState.create(
    apply_fn=model.apply,
    params=jax.tree_util.tree_map(np.zeros_like, variables['params']),  # values of the tree leaves don't matter
    tx=tx,
)
target = {'model': empty_state, 'config': None, 'data': [jnp.zeros_like(x1)]}
state_restored = checkpoints.restore_checkpoint(ckpt_dir, target=target, step=0)

# Batch Norm

In [None]:
# https://flax.readthedocs.io/en/latest/guides/batch_norm.html

In [None]:
# Defining the model
class MLP(nn.Module):
  """Dense -> BatchNorm -> ReLU -> Dense network with a single output unit."""

  @nn.compact
  def __call__(self, x, train: bool):
    hidden = nn.Dense(features=4)(x)
    # BatchNorm uses its running averages exactly when `train` is False.
    use_running_average = not train
    hidden = nn.BatchNorm(use_running_average=use_running_average)(hidden)
    hidden = nn.relu(hidden)
    return nn.Dense(features=1)(hidden)

In [None]:
# batch_stats collection
mlp = MLP()
x = jnp.ones(shape=(1, 3))
init_rng = jax.random.PRNGKey(0)
variables = mlp.init(init_rng, x, train=False)
params, batch_stats = variables['params'], variables['batch_stats']

# Show the shape of every array in the initialised variables.
jax.tree_util.tree_map(jnp.shape, variables)

In [None]:
# apply
# Marking 'batch_stats' as mutable makes `apply` return the updated
# collection alongside the model output.
model_variables = dict(params=params, batch_stats=batch_stats)
y, updates = mlp.apply(model_variables, x, train=True, mutable=['batch_stats'])
batch_stats = updates['batch_stats']

In [None]:
# training and evaluation
from typing import Any  # needed by the `batch_stats` annotation below


class TrainState(train_state.TrainState):
  """`train_state.TrainState` extended with a `batch_stats` field."""
  # Contents of the model's 'batch_stats' collection.
  batch_stats: Any

# Bundle the MLP's apply function, parameters, batch statistics and an Adam
# optimizer into one training state.
adam = optax.adam(1e-3)
state = TrainState.create(
  apply_fn=mlp.apply, params=params, batch_stats=batch_stats, tx=adam)

In [None]:
@jax.jit
def train_step(state: TrainState, batch):
  """Train for a single step.

  Args:
    state: training state holding the parameters, optimizer state and the
      current `batch_stats`.
    batch: mapping with an 'image' entry (model input) and a 'label' entry
      (integer class ids).

  Returns:
    A `(state, metrics)` pair: the state after the gradient update with the
    refreshed `batch_stats`, and a dict with the mean 'loss' and 'accuracy'.
  """
  def loss_fn(params):
    logits, updates = state.apply_fn(
      {'params': params, 'batch_stats': state.batch_stats},
      x=batch['image'], train=True, mutable=['batch_stats'])
    # The Optax loss is per example; `jax.value_and_grad` needs a scalar,
    # so average it over the batch.
    loss = optax.softmax_cross_entropy_with_integer_labels(
      logits=logits, labels=batch['label']).mean()
    return loss, (logits, updates)
  grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
  (loss, (logits, updates)), grads = grad_fn(state.params)
  state = state.apply_gradients(grads=grads)
  # The running statistics are not updated by the optimizer; copy over the
  # values produced by the forward pass.
  state = state.replace(batch_stats=updates['batch_stats'])
  metrics = {
    'loss': loss,
    'accuracy': jnp.mean(jnp.argmax(logits, -1) == batch['label']),
  }
  return state, metrics

In [None]:
@jax.jit
def eval_step(state: TrainState, batch):
  """Evaluate the model on a single batch without updating anything.

  Args:
    state: training state supplying the parameters and `batch_stats`.
    batch: mapping with an 'image' entry (model input) and a 'label' entry
      (integer class ids).

  Returns:
    A `(state, metrics)` pair: the unchanged `state`, and a dict with the
    mean 'loss' and 'accuracy' for the batch.
  """
  # Read the parameters from `state` rather than from a global `params`,
  # so the step evaluates the state it was given.
  logits = state.apply_fn(
    {'params': state.params, 'batch_stats': state.batch_stats},
    x=batch['image'], train=False)
  # Average the per-example loss so it is a scalar like the accuracy.
  loss = optax.softmax_cross_entropy_with_integer_labels(
    logits=logits, labels=batch['label']).mean()
  metrics = {
    'loss': loss,
    'accuracy': jnp.mean(jnp.argmax(logits, -1) == batch['label']),
  }
  return state, metrics

# MNIST

In [None]:
# https://colab.sandbox.google.com/github/google/flax/blob/main/docs/getting_started.ipynb
# Clone the Flax repository into the current directory (skipped when an entry
# named 'flax' is already listed there), then open docs/getting_started.md
# with Colab's `files.view`.
# NOTE(review): the link above points at the .ipynb while the file viewed below
# is the .md — presumably the paired Markdown copy of the notebook; confirm.
import os
from google.colab import files
if 'flax' not in os.listdir():
  !git clone https://github.com/google/flax.git
files.view('flax/docs/getting_started.md')