## How to load datasets in JAX with TensorFlow
Click the image below to read the post online.

<a target="_blank" href="https://www.machinelearningnuggets.com/how-to-load-dataset-in-jax-with-tensorflow"><img src="https://digitalpress.fra1.cdn.digitaloceanspaces.com/mhujhsj/2022/07/logo.png" alt="Open in ML Nuggets"></a>

In [None]:
!pip install -q wget

## 1. Imports

Import JAX, [JAX NumPy](https://jax.readthedocs.io/en/latest/jax.numpy.html),
Flax, and Optax. Ordinary NumPy and the TensorFlow utilities are imported
later, in the data-loading section. Flax can use any data-loading pipeline;
this example uses Keras's `ImageDataGenerator` from TensorFlow.

In [67]:
!pip install -q flax

In [68]:
import jax
import jax.numpy as jnp                # JAX NumPy
from flax import linen as nn           # The Linen API
from flax.training import train_state  # Useful dataclass to keep train state
import optax                           # Optimizers

## 2. Define network

Create a convolutional neural network with the Linen API by subclassing
[Module](https://flax.readthedocs.io/en/latest/flax.linen.html#core-module-abstraction).
Because the architecture in this example is relatively simple—you're just
stacking layers—you can define the inlined submodules directly within the
`__call__` method and wrap it with the
[@compact](https://flax.readthedocs.io/en/latest/flax.linen.html#compact-methods)
decorator.

In [69]:
class CNN(nn.Module):
  """A simple CNN model."""

  @nn.compact
  def __call__(self, x):
    x = nn.Conv(features=64, kernel_size=(3, 3))(x)
    x = nn.relu(x)
    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = nn.Conv(features=64, kernel_size=(3, 3))(x)
    x = nn.relu(x)
    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = x.reshape((x.shape[0], -1))  # flatten
    x = nn.Dense(features=256)(x)
    x = nn.relu(x)
    x = nn.Dense(features=2)(x)
    return x

## 3. Define loss

We simply use `optax.softmax_cross_entropy()`. Note that this function expects both `logits` and `labels` to have shape `[batch, num_classes]`. Since the data generator yields the labels as class indices (0 or 1) rather than as vectors, we first need to convert them to a one-hot encoding.

The loss function must return a scalar that is ready for optimization, so we take the mean of the `[batch]`-shaped vector returned by Optax's loss function.

In [70]:
def cross_entropy_loss(*, logits, labels):
  labels_onehot = jax.nn.one_hot(labels, num_classes=2)
  return optax.softmax_cross_entropy(logits=logits, labels=labels_onehot).mean()
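
To make the shapes concrete, the following NumPy-only sketch reproduces the same computation by hand: one-hot encode the labels, take a log-softmax of the logits, and average the per-example losses. The function name `softmax_cross_entropy_mean` and the toy inputs are illustrative assumptions; they are not part of the notebook or of Optax.

```python
import numpy as np

def softmax_cross_entropy_mean(logits, labels, num_classes=2):
    """Mean softmax cross-entropy, written out step by step (illustrative)."""
    logits = np.asarray(logits, dtype=np.float64)
    # One-hot encode by comparing each label with the class indices 0..num_classes-1.
    onehot = (np.asarray(labels)[:, None] == np.arange(num_classes)).astype(np.float64)
    # Numerically stable log-softmax over the class axis.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    per_example = -(onehot * log_probs).sum(axis=-1)  # shape [batch]
    return per_example.mean()                          # scalar

# Toy batch: uniform logits give a loss of log(2) for two classes.
toy_logits = [[0.0, 0.0], [0.0, 0.0]]
toy_labels = [0.0, 1.0]  # assumed: the binary generator yields float 0/1 labels
loss = softmax_cross_entropy_mean(toy_logits, toy_labels)
print(round(float(loss), 4))
```

The uniform-logits case is a handy sanity check: an untrained two-class model should start with a loss close to `log(2)`.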

## 4. Metric computation

For loss and accuracy metrics, create a separate function:

In [71]:
def compute_metrics(*, logits, labels):
  loss = cross_entropy_loss(logits=logits, labels=labels)
  accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
  metrics = {
      'loss': loss,
      'accuracy': accuracy,
  }
  return metrics
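
As a quick illustration of the accuracy line, here is a small NumPy sketch with made-up logits and labels. The values are hypothetical, and the labels are assumed to be float 0/1 values like those a binary-mode generator yields; comparing them with integer predictions still works element-wise.

```python
import numpy as np

# Hypothetical logits for a batch of four images and two classes.
logits = np.array([[2.0, -1.0], [0.1, 0.3], [-0.5, 0.5], [1.0, 0.0]])
labels = np.array([0.0, 1.0, 0.0, 0.0])  # assumed float 0/1 labels

predictions = np.argmax(logits, axis=-1)   # predicted class per example
accuracy = np.mean(predictions == labels)  # fraction of correct predictions
print(accuracy)  # 0.75
```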

## 5. Loading data

Download and extract the training images (cats and dogs), then build Keras
generators that resize the images, rescale the pixel values to floating-point
numbers in [0, 1], and yield batches.

In [72]:
import wget # pip install wget
import zipfile

wget.download("https://ml.machinelearningnuggets.com/train.zip")
with zipfile.ZipFile('train.zip', 'r') as zip_ref:
  zip_ref.extractall('.')

In [73]:
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import os
import numpy as np

In [74]:
import pandas as pd
base_dir = 'train'
filenames = os.listdir(base_dir)
categories = []
for filename in filenames:
    category = filename.split('.')[0]
    if category == 'dog':
        categories.append("dog")
    else:
        categories.append("cat")

df = pd.DataFrame({'filename': filenames,'category': categories})
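
The labeling rule in the loop above can be sketched as a small helper. The name `label_from_filename` is hypothetical, and the file names are assumed to follow a `dog.<id>.jpg` / `cat.<id>.jpg` pattern, which is what the `split('.')` call implies.

```python
def label_from_filename(filename):
    """Return 'dog' or 'cat' from a name like 'dog.42.jpg' (assumed pattern)."""
    prefix = filename.split('.')[0]
    # Same rule as the loop above: anything that is not 'dog' is labeled 'cat'.
    return 'dog' if prefix == 'dog' else 'cat'

examples = ['cat.0.jpg', 'dog.42.jpg']
print([label_from_filename(name) for name in examples])  # ['cat', 'dog']
```

Note that the fallback branch also labels any unexpected file in the directory as `cat`, so it is worth checking that the folder contains only images.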

In [75]:
df = df.sample(800)

In [76]:
train_datagen = ImageDataGenerator(rescale=1./255, 
                                   shear_range=0.2,
                                   zoom_range=0.2, 
                                   horizontal_flip=True,
                                   width_shift_range=0.1,
                                   height_shift_range=0.1,
                                   validation_split=0.2
                                   )


In [77]:
validation_gen = ImageDataGenerator(rescale=1./255,validation_split=0.2)
image_size = (128, 128)
batch_size = 64
training_set = train_datagen.flow_from_dataframe(df,base_dir,
                                                seed=101,                                                 
                                                target_size=image_size,
                                                batch_size=batch_size,
                                                x_col='filename',
                                                y_col='category',
                                                subset = "training",
                                                class_mode='binary')
validation_set = validation_gen.flow_from_dataframe(df,base_dir, 
                                              target_size=image_size,
                                              batch_size=batch_size, 
                                              x_col='filename',
                                                y_col='category',
                                              subset = "validation",
                                              class_mode='binary')

Found 640 validated image filenames belonging to 2 classes.
Found 160 validated image filenames belonging to 2 classes.
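
The counts printed above follow from the sample size and `validation_split`. The arithmetic below is a sketch; how Keras rounds the split internally is an assumption here, but the result matches the 640/160 output.

```python
import math

num_samples = 800
validation_split = 0.2
batch_size = 64

num_validation = int(num_samples * validation_split)   # images held out
num_training = num_samples - num_validation            # images left for training
train_steps = math.ceil(num_training / batch_size)     # batches to cover the training subset once
val_steps = math.ceil(num_validation / batch_size)     # batches to cover the validation subset once
print(num_training, num_validation, train_steps, val_steps)  # 640 160 10 3
```

The last validation batch is smaller than the others, since 160 is not a multiple of 64.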


In [78]:
train_images, train_labels = next(training_set)
test_images, test_labels = next(validation_set)

print('Train:', train_images.shape, train_labels.shape)
print('Test:', test_images.shape, test_labels.shape)

Train: (64, 128, 128, 3) (64,)
Test: (64, 128, 128, 3) (64,)


In [79]:
for train_images, train_labels in training_set:
  print('Train:', train_images.shape, train_labels.shape)
  break

Train: (64, 128, 128, 3) (64,)


## 6. Create train state

A common pattern in Flax is to create a single dataclass that represents the
entire training state, including step number, parameters, and optimizer state.

Keeping the optimizer and model in this state as well means we only need
to pass a single argument to functions like `train_step()` (see below).

Because this is such a common pattern, Flax provides the class
[flax.training.train_state.TrainState](https://flax.readthedocs.io/en/latest/flax.training.html#train-state)
that serves most basic use cases. Usually one would subclass it to add more data
to be tracked, but in this example we can use it without any modifications.

In [80]:
def create_train_state(rng, learning_rate, momentum):
  """Creates initial `TrainState`."""
  cnn = CNN()
  params = cnn.init(rng, jnp.ones([1, image_size[0], image_size[1], 3]))['params']
  tx = optax.sgd(learning_rate, momentum)
  return train_state.TrainState.create(
      apply_fn=cnn.apply, params=params, tx=tx)
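
The dummy input passed to `init` fixes the parameter shapes, so the image size directly determines how large the first dense layer is. The helper below is a rough sketch with a hypothetical name, `flattened_features`; it assumes the convolutions use `SAME` padding (believed to be the `nn.Conv` default) and that each 2x2 pooling with stride 2 halves the height and width.

```python
def flattened_features(height, width, channels=64, num_pools=2, pool=2):
    """Number of features entering the first Dense layer (illustrative helper)."""
    for _ in range(num_pools):
        # Assumption: SAME-padded convs keep the size, each pooling halves it.
        height, width = height // pool, width // pool
    return height * width * channels

image_size = (128, 128)
flat = flattened_features(*image_size)
dense_kernel_params = flat * 256  # weights in the Dense(256) kernel, bias excluded
print(flat, dense_kernel_params)  # 65536 16777216
```

Under these assumptions most of the model's parameters sit in that one dense kernel, which is worth keeping in mind before increasing `image_size`.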

## 7. Training step

Define a function that:

- Evaluates the neural network given the parameters and a batch of input images
  with the
  [Module.apply](https://flax.readthedocs.io/en/latest/flax.linen.html#flax.linen.Module.apply)
  method.
- Computes the `cross_entropy_loss` loss function.
- Evaluates the loss function and its gradient using
  [jax.value_and_grad](https://jax.readthedocs.io/en/latest/jax.html#jax.value_and_grad).
- Applies a
  [pytree](https://jax.readthedocs.io/en/latest/pytrees.html#pytrees-and-jax-functions)
  of gradients to the optimizer to update the model's parameters.
- Computes the metrics using `compute_metrics` (defined earlier).

Use JAX's [@jit](https://jax.readthedocs.io/en/latest/jax.html#jax.jit)
decorator to trace the entire `train_step` function and just-in-time compile
it with [XLA](https://www.tensorflow.org/xla) into fused device operations
that run faster and more efficiently on hardware accelerators.

In [81]:
@jax.jit
def train_step(state, train_images, train_labels):
  """Train for a single step."""
  def loss_fn(params):
    logits = CNN().apply({'params': params}, train_images)
    loss = cross_entropy_loss(logits=logits, labels=train_labels)
    return loss, logits
  grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
  (_, logits), grads = grad_fn(state.params)
  state = state.apply_gradients(grads=grads)
  metrics = compute_metrics(logits=logits, labels=train_labels)
  return state, metrics
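
The `has_aux=True` pattern used in `train_step` can be seen in isolation on a toy problem. In this sketch the linear model, the data, and the names `w`, `x`, and `y` are all made up for illustration; only `jax.value_and_grad` itself comes from the notebook.

```python
import jax
import jax.numpy as jnp

def loss_fn(w, x, y):
    """Squared error of a toy linear model; predictions are returned as aux."""
    preds = x @ w
    loss = jnp.mean((preds - y) ** 2)
    return loss, preds  # (value to differentiate, auxiliary output)

x = jnp.array([[1.0, 2.0], [3.0, 4.0]])
y = jnp.array([1.0, 2.0])
w = jnp.zeros(2)

# has_aux=True: differentiate only the first output and pass the second through.
(loss, preds), grads = jax.value_and_grad(loss_fn, has_aux=True)(w, x, y)
print(loss, preds.shape, grads.shape)
```

Returning the logits as the auxiliary output is what lets `train_step` compute metrics without a second forward pass.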

## 8. Evaluation step

Create a function that evaluates your model on a held-out batch with
[Module.apply](https://flax.readthedocs.io/en/latest/flax.linen.html#flax.linen.Module.apply).

In [82]:
@jax.jit
def eval_step(params, test_images, test_labels):
  logits = CNN().apply({'params': params}, test_images)
  return compute_metrics(logits=logits, labels=test_labels)

## 9. Train function

Define a training function that:

- Runs one optimization step on the batch it receives. The loop over
  `training_set` is commented out, so each call to `train_epoch` processes a
  single batch rather than the whole training set.
- Retrieves the training metrics from the device with `jax.device_get` and
  computes their mean over the collected batches.
- Prints the training loss and accuracy and returns the updated train state.

In [90]:
def train_epoch(state, train_images, train_labels, epoch):
    """Train for 1 epoch on the training set."""
    batch_metrics = []
    # for train_images, train_labels in training_set:
    state, metrics = train_step(state, train_images, train_labels)
    batch_metrics.append(metrics)

    batch_metrics_np = jax.device_get(batch_metrics)  
    epoch_metrics_np = {
        k: np.mean([metrics[k] for metrics in batch_metrics_np])
        for k in batch_metrics_np[0]}
    print('train epoch: %d, loss: %.4f, accuracy: %.2f' % (epoch, epoch_metrics_np['loss'], epoch_metrics_np['accuracy'] * 100))

    return state
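
If you want `train_epoch` to cover more than one batch, the iteration has to be bounded explicitly, because Keras iterators of this kind are generally understood to cycle indefinitely (hence the `break` in the earlier cell). The sketch below shows one way to do that with `itertools.islice`. Everything in it is a stand-in: `endless_batches` imitates the generator, `fake_step` imitates `train_step`, and `run_epoch` is a hypothetical helper, not part of the notebook.

```python
import itertools
import numpy as np

def endless_batches():
    """Stand-in for a data iterator that never stops on its own."""
    i = 0
    while True:
        images = np.full((2, 4, 4, 3), float(i))
        labels = np.array([0.0, 1.0])
        yield images, labels
        i += 1

def fake_step(state, images, labels):
    """Hypothetical stand-in for `train_step`: counts steps, returns metrics."""
    metrics = {'loss': float(images.mean()),
               'accuracy': float(np.mean(labels == 1.0))}
    return state + 1, metrics

def run_epoch(step_fn, state, batches, steps_per_epoch):
    """Run exactly `steps_per_epoch` steps and average the per-batch metrics."""
    batch_metrics = []
    for images, labels in itertools.islice(batches, steps_per_epoch):
        state, metrics = step_fn(state, images, labels)
        batch_metrics.append(metrics)
    epoch_metrics = {k: float(np.mean([m[k] for m in batch_metrics]))
                     for k in batch_metrics[0]}
    return state, epoch_metrics

state, epoch_metrics = run_epoch(fake_step, 0, endless_batches(), steps_per_epoch=3)
print(state, epoch_metrics)
```

With the real generator, `steps_per_epoch` would be the number of batches needed to see the training subset once.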

## 10. Eval function

Create a model evaluation function that:

- Retrieves the evaluation metrics from the device with `jax.device_get`.
- Copies the metrics
  [data stored](https://flax.readthedocs.io/en/latest/design_notes/linen_design_principles.html#how-are-parameters-represented-and-how-do-we-handle-general-differentiable-algorithms-that-update-stateful-variables)
  in a JAX
  [pytree](https://jax.readthedocs.io/en/latest/pytrees.html#pytrees-and-jax-functions).

In [91]:
def eval_model(params, test_images, test_labels ):
  metrics = eval_step(params, test_images, test_labels )
  metrics = jax.device_get(metrics)
  summary = jax.tree_util.tree_map(lambda x: x.item(), metrics)
  return summary['loss'], summary['accuracy']

## 12. Seed randomness

- Get one
  [PRNGKey](https://jax.readthedocs.io/en/latest/_autosummary/jax.random.PRNGKey.html#jax.random.PRNGKey)
  and
  [split](https://jax.readthedocs.io/en/latest/_autosummary/jax.random.split.html#jax.random.split)
  it to get a second key that you'll use for parameter initialization. (Learn
  more about
  [PRNG chains](https://flax.readthedocs.io/en/latest/design_notes/linen_design_principles.html#how-are-parameters-represented-and-how-do-we-handle-general-differentiable-algorithms-that-update-stateful-variables)
  and
  [JAX PRNG design](https://github.com/google/jax/blob/main/design_notes/prng.md).)

In [92]:
rng = jax.random.PRNGKey(0)
rng, init_rng = jax.random.split(rng)
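
A short sketch of why the key is split rather than reused: in JAX the same key always produces the same random numbers, while keys obtained from a split produce different ones. The sample shapes here are arbitrary choices for illustration.

```python
import jax
import jax.numpy as jnp

rng = jax.random.PRNGKey(0)
rng, init_rng = jax.random.split(rng)  # two new, distinct keys

# The same key reproduces the same sample; a different key does not.
a = jax.random.normal(init_rng, (3,))
a_again = jax.random.normal(init_rng, (3,))
b = jax.random.normal(rng, (3,))
print(bool(jnp.array_equal(a, a_again)), bool(jnp.array_equal(a, b)))
```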

## 13. Initialize train state

Remember that `create_train_state` initializes both the model parameters and
the optimizer and puts both into the training state dataclass that is returned.

In [93]:
learning_rate = 0.1
momentum = 0.9

In [94]:
state = create_train_state(init_rng, learning_rate, momentum)
del init_rng  # Must not be used anymore.
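
To get a feel for what `learning_rate` and `momentum` do, here is a plain-Python sketch of the textbook momentum update applied to a one-dimensional quadratic. The helper `sgd_momentum_step` and the toy objective are assumptions for illustration; the update is believed to match the non-Nesterov form used by `optax.sgd`, but treat that as an assumption rather than a statement about Optax internals.

```python
def sgd_momentum_step(w, velocity, grad, learning_rate=0.1, momentum=0.9):
    """One textbook momentum step: accumulate gradients, then move against them."""
    velocity = grad + momentum * velocity
    w = w - learning_rate * velocity
    return w, velocity

# Minimize f(w) = (w - 3)^2, whose gradient is 2 * (w - 3).
w, velocity = 0.0, 0.0
for _ in range(300):
    grad = 2.0 * (w - 3.0)
    w, velocity = sgd_momentum_step(w, velocity, grad)
print(round(w, 4))
```

On this toy problem the iterate overshoots and oscillates around the minimum before settling, which is the typical behavior of a high momentum value.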

## 14. Train and evaluate

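With `num_epochs = 1`, the loop below runs one optimization step on a single training batch and then evaluates on a single validation batch, so the printed accuracy is only a smoke test of the pipeline. Because `train_epoch` reuses the same batch, raising `num_epochs` repeats that step on the same 64 images rather than covering more of the dataset.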

In [95]:
num_epochs = 1

In [None]:
for epoch in range(1, num_epochs + 1):
  # Run one optimization step on the training batch loaded earlier
  state = train_epoch(state, train_images, train_labels, epoch)
  # Evaluate on the held-out validation batch after each training epoch
  test_loss, test_accuracy = eval_model(state.params, test_images, test_labels)
  print(' test epoch: %d, loss: %.2f, accuracy: %.2f' % (
      epoch, test_loss, test_accuracy * 100))

## Where to go from here
Follow us on [LinkedIn](https://www.linkedin.com/company/mlnuggets), [Twitter](https://twitter.com/ml_nuggets), [GitHub](https://github.com/mlnuggets) and subscribe to our [blog](https://www.machinelearningnuggets.com/#/portal) so that you don't miss a new issue.