# Introduction to Data Loaders for Multi-Device Training with JAX

[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax-ai-stack/blob/main/docs/data_loaders_for_multi_device_setups_with_jax.ipynb)

This tutorial explores various data loading strategies for **JAX** in **multi-device distributed** environments, leveraging [**TPUs**](https://jax.readthedocs.io/en/latest/pallas/tpu/details.html#what-is-a-tpu). While JAX doesn't include a built-in data loader, it seamlessly integrates with popular data loading libraries, including:
*   [**PyTorch DataLoader**](https://github.com/pytorch/data)
*   [**TensorFlow Datasets (TFDS)**](https://github.com/tensorflow/datasets)
*   [**Grain**](https://github.com/google/grain)
*   [**Hugging Face**](https://huggingface.co/docs/datasets/en/use_with_jax#data-loading)

You'll see how to use each of these libraries to efficiently load data for a simple image classification task using the MNIST dataset.

Building on the [Data Loaders on GPU](https://jax-ai-stack.readthedocs.io/en/latest/data_loaders_on_gpu_with_jax.html) tutorial, this guide introduces optimizations for distributed training across multiple GPUs or TPUs. It focuses on data sharding with `Mesh` and `NamedSharding` to efficiently partition and synchronize data across devices. By leveraging multi-device setups, you'll maximize resource utilization for large datasets in distributed environments.

Start by importing the JAX APIs used throughout this tutorial.

In [1]:
import jax
import jax.numpy as jnp
from jax import grad, jit, vmap, random, device_put
from jax.sharding import Mesh, PartitionSpec, NamedSharding

### Checking TPU Availability for JAX

In [2]:
jax.devices()

[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
 TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),
 TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),
 TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),
 TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),
 TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),
 TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),
 TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]

### Setting Hyperparameters and Initializing Parameters

You'll define hyperparameters for your model and data loading, including layer sizes, learning rate, batch size, and the data directory. You'll also initialize the weights and biases for a fully-connected neural network.

In [3]:
# A helper function to randomly initialize weights and biases
# for a dense neural network layer
def random_layer_params(m, n, key, scale=1e-2):
  w_key, b_key = random.split(key)
  return scale * random.normal(w_key, (n, m)), scale * random.normal(b_key, (n,))

# Function to initialize network parameters for all layers based on defined sizes
def init_network_params(sizes, key):
  keys = random.split(key, len(sizes))
  return [random_layer_params(m, n, k) for m, n, k in zip(sizes[:-1], sizes[1:], keys)]

layer_sizes = [784, 512, 512, 10]  # Layers of the network
step_size = 0.01                   # Learning rate
num_epochs = 8                     # Number of training epochs
batch_size = 128                   # Batch size for training
n_targets = 10                     # Number of classes (digits 0-9)
num_pixels = 28 * 28               # Each MNIST image is 28x28 pixels
data_dir = '/tmp/mnist_dataset'    # Directory for storing the dataset

# Initialize network parameters using the defined layer sizes and a random seed
params = init_network_params(layer_sizes, random.PRNGKey(0))
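As a quick sanity check on the initialization above, the following sketch repeats the two helper functions and prints the shape of every weight matrix and bias vector. It only illustrates the layout (weights are stored as `(n_out, n_in)`, which is what lets the model compute `jnp.dot(w, activations)` on a single flattened image); the `shapes` variable is introduced here for illustration.

```python
from jax import random

def random_layer_params(m, n, key, scale=1e-2):
    # Weights have shape (n, m): n outputs, m inputs
    w_key, b_key = random.split(key)
    return scale * random.normal(w_key, (n, m)), scale * random.normal(b_key, (n,))

def init_network_params(sizes, key):
    # One key per entry in `sizes`; zip stops after the last layer pair
    keys = random.split(key, len(sizes))
    return [random_layer_params(m, n, k) for m, n, k in zip(sizes[:-1], sizes[1:], keys)]

params = init_network_params([784, 512, 512, 10], random.PRNGKey(0))

# Collect (weight shape, bias shape) for each of the three layers
shapes = [(w.shape, b.shape) for w, b in params]
print(shapes)
# [((512, 784), (512,)), ((512, 512), (512,)), ((10, 512), (10,))]
```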

### Model Prediction with Auto-Batching

In this section, you'll define the `predict` function for your neural network. This function computes the output of the network for a single input image.

To efficiently process multiple images simultaneously, you'll use [`vmap`](https://jax.readthedocs.io/en/latest/_autosummary/jax.vmap.html#jax.vmap), which allows you to vectorize the `predict` function and apply it across a batch of inputs. This technique, called auto-batching, improves computational efficiency by leveraging hardware acceleration.

In [4]:
from jax.scipy.special import logsumexp

def relu(x):
  return jnp.maximum(0, x)

def predict(params, image):
  # per-example predictions
  activations = image
  for w, b in params[:-1]:
    outputs = jnp.dot(w, activations) + b
    activations = relu(outputs)

  final_w, final_b = params[-1]
  logits = jnp.dot(final_w, activations) + final_b
  return logits - logsumexp(logits)

# Make a batched version of the `predict` function
batched_predict = vmap(predict, in_axes=(None, 0))
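To see what `in_axes=(None, 0)` does in isolation, here is a minimal sketch with a toy per-example function. The function `apply_weights` and the array sizes are made up for illustration; the point is that the first argument is shared across the batch while the second is mapped over its leading axis.

```python
import jax.numpy as jnp
from jax import vmap

def apply_weights(w, x):
    # w: (3, 4) weight matrix, x: (4,) single example -> (3,) output
    return jnp.dot(w, x)

w = jnp.arange(12.0).reshape(3, 4)   # shared across the batch (in_axes=None)
xs = jnp.ones((5, 4))                # batch of 5 examples (in_axes=0)

batched_apply = vmap(apply_weights, in_axes=(None, 0))
out = batched_apply(w, xs)
print(out.shape)  # (5, 3)
```

The result matches the hand-written batched form `xs @ w.T`, which is the same transformation `vmap` applies to `predict`.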

### Multi-Device Setup Using a Mesh of Devices

In [5]:
# Get the number of available devices (GPUs/TPUs) for sharding
num_devices = len(jax.devices())

# Multi-device setup using a Mesh of devices
devices = jax.devices()
mesh = Mesh(devices, ('device',))

# Define the sharding specification - split the data along the first axis (batch)
sharding_spec = PartitionSpec('device')
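The sketch below shows how such a mesh and partition spec are typically used: a batch is placed with `device_put` and a `NamedSharding`, and each device ends up holding one slice of the leading axis. The batch size of 16 rows per device is an arbitrary choice for illustration, and the code also runs on a single device, where the one shard is simply the whole array.

```python
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec

devices = jax.devices()
mesh = Mesh(devices, ('device',))
sharding = NamedSharding(mesh, PartitionSpec('device'))

# Illustrative batch: 16 rows per device so the first axis divides evenly
rows_per_device = 16
global_batch = jnp.zeros((rows_per_device * len(devices), 784), dtype=jnp.float32)

# Place the batch so that its first axis is split across the mesh
sharded_batch = jax.device_put(global_batch, sharding)

# Each addressable shard holds the slice that lives on one device
shard_shapes = [shard.data.shape for shard in sharded_batch.addressable_shards]
print(sharded_batch.shape, shard_shapes)
```

The global shape is unchanged; only the physical placement differs, so downstream code such as a `jit`-compiled update can treat the array like any other.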

### Utility and Loss Functions

You'll now define utility functions for:
- One-hot encoding: Converts class indices to binary vectors.
- Accuracy calculation: Measures the performance of the model on the dataset.
- Loss computation: Calculates the difference between predictions and targets.

To optimize performance:
- [`grad`](https://jax.readthedocs.io/en/latest/_autosummary/jax.grad.html#jax.grad) computes gradients of the loss function with respect to network parameters.
- [`jit`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html#jax.jit) compiles the update function, enabling faster execution through JAX's [XLA](https://openxla.org/xla) compilation.
- [`device_put`](https://jax.readthedocs.io/en/latest/_autosummary/jax.device_put.html) distributes each batch across the TPU cores.

In [6]:
import time

def one_hot(x, k, dtype=jnp.float32):
  """Create a one-hot encoding of x of size k."""
  return jnp.array(x[:, None] == jnp.arange(k), dtype)

def accuracy(params, images, targets):
  """Calculate the accuracy of predictions."""
  target_class = jnp.argmax(targets, axis=1)
  predicted_class = jnp.argmax(batched_predict(params, images), axis=1)
  return jnp.mean(predicted_class == target_class)

def loss(params, images, targets):
  """Calculate the loss between predictions and targets."""
  preds = batched_predict(params, images)
  return -jnp.mean(preds * targets)

@jit
def update(params, x, y):
  """Update the network parameters using gradient descent."""
  grads = grad(loss)(params, x, y)
  return [(w - step_size * dw, b - step_size * db)
          for (w, b), (dw, db) in zip(params, grads)]

def reshape_and_one_hot(x, y):
    """Reshape and one-hot encode the inputs."""
    x = jnp.reshape(x, (len(x), num_pixels))
    y = one_hot(y, n_targets)
    return x, y

def train_model(num_epochs, params, training_generator, data_loader_type='streamed'):
    """Train for `num_epochs`, sharding each batch across the device mesh with `device_put`."""
    # Shard the leading (batch) axis across all devices in the mesh
    data_sharding = NamedSharding(mesh, sharding_spec)
    for epoch in range(num_epochs):
        start_time = time.time()
        # 'streamed' loaders are factories that are called once per epoch;
        # any other loader is iterated directly
        batches = training_generator() if data_loader_type == 'streamed' else training_generator
        for x, y in batches:
            x, y = reshape_and_one_hot(x, y)
            x, y = device_put(x, data_sharding), device_put(y, data_sharding)
            params = update(params, x, y)

        print(f"Epoch {epoch + 1} in {time.time() - start_time:.2f} sec: "
              f"Train Accuracy: {accuracy(params, train_images, train_labels):.4f},"
              f"Test Accuracy: {accuracy(params, test_images, test_labels):.4f}")
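To make the encoding and the loss concrete, the sketch below repeats `one_hot` and applies it to three hypothetical labels, then evaluates the same expression as `loss` on made-up uniform log-probabilities. Note that `jnp.mean` averages over every entry of the `(batch, classes)` matrix, so the value is the usual cross-entropy divided by the number of classes.

```python
import jax.numpy as jnp

def one_hot(x, k, dtype=jnp.float32):
    """Create a one-hot encoding of x of size k."""
    # x[:, None] has shape (N, 1); comparing with arange(k) broadcasts to (N, k)
    return jnp.array(x[:, None] == jnp.arange(k), dtype)

labels = jnp.array([0, 2, 1])        # hypothetical class indices
encoded = one_hot(labels, 3)
print(encoded)
# [[1. 0. 0.]
#  [0. 0. 1.]
#  [0. 1. 0.]]

# Uniform log-probabilities over 3 classes (illustrative values only)
log_probs = jnp.full((3, 3), jnp.log(1.0 / 3.0))

# Same expression as `loss`: mean over all 9 entries, only 3 of which are nonzero
loss_value = -jnp.mean(log_probs * encoded)
print(loss_value)  # approximately log(3) / 3
```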

## Loading Data with PyTorch DataLoader

This section shows how to load the MNIST dataset using PyTorch's DataLoader, convert the data to NumPy arrays, and apply transformations to flatten and cast images.

In [7]:
!pip install torch torchvision



In [8]:
import numpy as np
from jax.tree_util import tree_map
from torch.utils import data
from torchvision.datasets import MNIST

In [9]:
def numpy_collate(batch):
  """Collate function to convert a batch of PyTorch data into NumPy arrays."""
  return tree_map(np.asarray, data.default_collate(batch))

class NumpyLoader(data.DataLoader):
    """Custom DataLoader to return NumPy arrays from a PyTorch Dataset."""
    def __init__(self, dataset, batch_size=1,
                  shuffle=False, sampler=None,
                  batch_sampler=None, num_workers=0,
                  pin_memory=False, drop_last=False,
                  timeout=0, worker_init_fn=None):
      super().__init__(dataset,
          batch_size=batch_size,
          shuffle=shuffle,
          sampler=sampler,
          batch_sampler=batch_sampler,
          num_workers=num_workers,
          collate_fn=numpy_collate,
          pin_memory=pin_memory,
          drop_last=drop_last,
          timeout=timeout,
          worker_init_fn=worker_init_fn)

class FlattenAndCast(object):
  """Transform class to flatten and cast images to float32."""
  def __call__(self, pic):
    return np.ravel(np.array(pic, dtype=np.float32))

### Load Dataset with Transformations

Standardize the data by flattening the images, casting them to `float32`, and ensuring consistent data types.

In [10]:
mnist_dataset = MNIST(data_dir, download=True, transform=FlattenAndCast())

### Full Training Dataset for Accuracy Checks

Convert the entire training dataset to JAX arrays.

In [11]:
train_images = jnp.array(mnist_dataset.data.numpy().reshape(len(mnist_dataset.data), -1), dtype=jnp.float32)
train_labels = one_hot(np.array(mnist_dataset.targets), n_targets)

### Get Full Test Dataset

Load and process the full test dataset.

In [12]:
mnist_dataset_test = MNIST(data_dir, download=True, train=False)
test_images = jnp.array(mnist_dataset_test.data.numpy().reshape(len(mnist_dataset_test.data), -1), dtype=jnp.float32)
test_labels = one_hot(np.array(mnist_dataset_test.targets), n_targets)

In [13]:
print('Train:', train_images.shape, train_labels.shape)
print('Test:', test_images.shape, test_labels.shape)

Train: (60000, 784) (60000, 10)
Test: (10000, 784) (10000, 10)


### Training Data Generator

Define a generator function using PyTorch's DataLoader for batch training.
Setting `num_workers > 0` enables multi-process data loading, which can accelerate data loading for larger datasets or intensive preprocessing tasks. Experiment with different values to find the optimal setting for your hardware and workload.

Note: When setting `num_workers > 0`, you may see the following `RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.`
This warning can be safely ignored since data loaders do not use JAX within the forked processes.

In [14]:
def pytorch_training_generator(mnist_dataset):
    return NumpyLoader(mnist_dataset, batch_size=batch_size, num_workers=0)
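Because each batch is split along its first axis, the batch length generally has to be a multiple of the number of devices. With MNIST, a batch size of 128 and 8 devices this happens to hold even for the final batch (60000 mod 128 = 96, and 96 is divisible by 8). For other combinations, one option is to trim the trailing examples, as in this sketch. The helper `trim_to_multiple` is hypothetical and not part of the tutorial; passing `drop_last=True` to the loader or padding the batch are alternatives.

```python
import numpy as np

def trim_to_multiple(x, y, num_shards):
    """Drop trailing examples so that len(x) is a multiple of num_shards."""
    usable = (len(x) // num_shards) * num_shards
    return x[:usable], y[:usable]

# Illustrative final batch of 100 examples on an assumed 8-device mesh
x = np.zeros((100, 784), dtype=np.float32)
y = np.zeros((100,), dtype=np.int64)

x_trim, y_trim = trim_to_multiple(x, y, num_shards=8)
print(x_trim.shape, y_trim.shape)  # (96, 784) (96,)
```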

### Training Loop (PyTorch DataLoader)

The training loop uses the PyTorch DataLoader to iterate through batches and update model parameters.

In [15]:
train_model(num_epochs, params, pytorch_training_generator(mnist_dataset), data_loader_type='iterable')

Epoch 1 in 5.65 sec: Train Accuracy: 0.9159,Test Accuracy: 0.9197
Epoch 2 in 4.26 sec: Train Accuracy: 0.9371,Test Accuracy: 0.9383
Epoch 3 in 4.39 sec: Train Accuracy: 0.9493,Test Accuracy: 0.9468
Epoch 4 in 4.16 sec: Train Accuracy: 0.9568,Test Accuracy: 0.9536
Epoch 5 in 4.04 sec: Train Accuracy: 0.9632,Test Accuracy: 0.9576
Epoch 6 in 4.06 sec: Train Accuracy: 0.9674,Test Accuracy: 0.9617
Epoch 7 in 4.06 sec: Train Accuracy: 0.9708,Test Accuracy: 0.9649
Epoch 8 in 4.07 sec: Train Accuracy: 0.9737,Test Accuracy: 0.9672


## Loading Data with TensorFlow Datasets (TFDS)

This section demonstrates how to load the MNIST dataset using TFDS, fetch the full dataset for evaluation, and define a training generator for batch processing.

Ensure you have the latest versions of both TensorFlow and TensorFlow Datasets.

In [16]:
!pip install --upgrade tensorflow tensorflow-datasets


In [17]:
import tensorflow_datasets as tfds

### Fetch Full Dataset for Evaluation

Load the dataset with `tfds.load`, convert it to NumPy arrays, and process it for evaluation.

In [18]:
# tfds.load returns tf.Tensors (or tf.data.Datasets if batch_size != -1)
mnist_data, info = tfds.load(name="mnist", batch_size=-1, data_dir=data_dir, with_info=True)
mnist_data = tfds.as_numpy(mnist_data)
train_data, test_data = mnist_data['train'], mnist_data['test']

# Full train set
train_images, train_labels = train_data['image'], train_data['label']
train_images = jnp.reshape(train_images, (len(train_images), num_pixels))
train_labels = one_hot(train_labels, n_targets)

# Full test set
test_images, test_labels = test_data['image'], test_data['label']
test_images = jnp.reshape(test_images, (len(test_images), num_pixels))
test_labels = one_hot(test_labels, n_targets)


In [19]:
print('Train:', train_images.shape, train_labels.shape)
print('Test:', test_images.shape, test_labels.shape)

Train: (60000, 784) (60000, 10)
Test: (10000, 784) (10000, 10)


### Define the Training Generator

Create a generator function to yield batches of data for training.

In [20]:
def training_generator():
  # as_supervised=True gives us the (image, label) as a tuple instead of a dict
  ds = tfds.load(name='mnist', split='train', as_supervised=True, data_dir=data_dir)
  # You can build up an arbitrary tf.data input pipeline
  ds = ds.batch(batch_size).prefetch(1)
  # tfds.as_numpy converts the tf.data.Dataset into an iterable of NumPy arrays
  return tfds.as_numpy(ds)

### Training Loop (TFDS)

Use the training generator in a custom training loop.

In [21]:
train_model(num_epochs, params, training_generator)

Epoch 1 in 4.96 sec: Train Accuracy: 0.9254,Test Accuracy: 0.9271
Epoch 2 in 3.22 sec: Train Accuracy: 0.9428,Test Accuracy: 0.9418
Epoch 3 in 3.23 sec: Train Accuracy: 0.9532,Test Accuracy: 0.9517
Epoch 4 in 3.26 sec: Train Accuracy: 0.9600,Test Accuracy: 0.9557
Epoch 5 in 3.28 sec: Train Accuracy: 0.9651,Test Accuracy: 0.9605
Epoch 6 in 3.11 sec: Train Accuracy: 0.9691,Test Accuracy: 0.9628
Epoch 7 in 3.25 sec: Train Accuracy: 0.9726,Test Accuracy: 0.9648
Epoch 8 in 3.15 sec: Train Accuracy: 0.9754,Test Accuracy: 0.9665


## Loading Data with Grain

This section demonstrates how to load MNIST data using Grain, a data-loading library. You'll define a custom dataset class for Grain and set up a Grain DataLoader for efficient training.

Install Grain

In [22]:
!pip install grain


Import the required libraries. The MNIST data itself is loaded through torchvision.

In [23]:
import numpy as np
import grain.python as pygrain
from torchvision.datasets import MNIST

### Define Dataset Class

Create a custom dataset class to load MNIST data for Grain.

In [24]:
class Dataset:
    def __init__(self, data_dir, train=True):
        self.data_dir = data_dir
        self.train = train
        self.load_data()

    def load_data(self):
        # Load the MNIST dataset with torchvision; Grain reads from it as a data source
        self.dataset = MNIST(self.data_dir, download=True, train=self.train)

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        img, label = self.dataset[index]
        return np.ravel(np.array(img, dtype=np.float32)), label

### Initialize the Dataset

In [25]:
mnist_dataset = Dataset(data_dir)

### Get the Full Train and Test Datasets

In [26]:
train_images = jnp.array([mnist_dataset[i][0] for i in range(len(mnist_dataset))], dtype=jnp.float32)
train_labels = one_hot(np.array([mnist_dataset[i][1] for i in range(len(mnist_dataset))]), n_targets)

mnist_dataset_test = MNIST(data_dir, download=True, train=False)

# Convert test images to JAX arrays and encode test labels as one-hot vectors
test_images = jnp.array([np.ravel(np.array(mnist_dataset_test[i][0], dtype=np.float32)) for i in range(len(mnist_dataset_test))], dtype=jnp.float32)
test_labels = one_hot(np.array([mnist_dataset_test[i][1] for i in range(len(mnist_dataset_test))]), n_targets)

In [27]:
print("Train:", train_images.shape, train_labels.shape)
print("Test:", test_images.shape, test_labels.shape)

Train: (60000, 784) (60000, 10)
Test: (10000, 784) (10000, 10)


### Initialize PyGrain DataLoader

In [28]:
sampler = pygrain.SequentialSampler(
    num_records=len(mnist_dataset),
    shard_options=pygrain.ShardByJaxProcess())  # Shard records across JAX processes (hosts)

def pygrain_training_generator():
    return pygrain.DataLoader(
        data_source=mnist_dataset,
        sampler=sampler,
        operations=[pygrain.Batch(batch_size)],
    )

### Training Loop (Grain)

Run the training loop using the Grain DataLoader.

In [29]:
train_model(num_epochs, params, pygrain_training_generator)

Epoch 1 in 8.05 sec: Train Accuracy: 0.9159,Test Accuracy: 0.9197
Epoch 2 in 8.14 sec: Train Accuracy: 0.9371,Test Accuracy: 0.9383
Epoch 3 in 8.99 sec: Train Accuracy: 0.9493,Test Accuracy: 0.9468
Epoch 4 in 9.00 sec: Train Accuracy: 0.9568,Test Accuracy: 0.9536
Epoch 5 in 8.40 sec: Train Accuracy: 0.9632,Test Accuracy: 0.9576
Epoch 6 in 8.28 sec: Train Accuracy: 0.9674,Test Accuracy: 0.9617
Epoch 7 in 8.20 sec: Train Accuracy: 0.9708,Test Accuracy: 0.9649
Epoch 8 in 8.24 sec: Train Accuracy: 0.9737,Test Accuracy: 0.9672


## Loading Data with Hugging Face

This section demonstrates loading MNIST data using the Hugging Face `datasets` library. You'll format the dataset for JAX compatibility, prepare flattened images and one-hot-encoded labels, and define a training generator.

Install the Hugging Face `datasets` library.

In [30]:
!pip install datasets


In [31]:
from datasets import load_dataset

Load the MNIST dataset from Hugging Face and format it as NumPy arrays with `with_format("numpy")`. To get JAX arrays instead, use `with_format("jax")`.

In [32]:
mnist_dataset = load_dataset("mnist", cache_dir=data_dir).with_format("numpy")


### Extract Images and Labels

Get image shape and flatten for model input.

In [33]:
train_images = mnist_dataset["train"]["image"]
train_labels = mnist_dataset["train"]["label"]
test_images = mnist_dataset["test"]["image"]
test_labels = mnist_dataset["test"]["label"]

# Extract image shape
image_shape = train_images.shape[1:]
num_features = image_shape[0] * image_shape[1]

# Flatten the images
train_images = train_images.reshape(-1, num_features)
test_images = test_images.reshape(-1, num_features)

# One-hot encode the labels
train_labels = one_hot(train_labels, n_targets)
test_labels = one_hot(test_labels, n_targets)

In [34]:
print('Train:', train_images.shape, train_labels.shape)
print('Test:', test_images.shape, test_labels.shape)

Train: (60000, 784) (60000, 10)
Test: (10000, 784) (10000, 10)


### Define Training Generator

Set up a generator to yield batches of images and labels for training.

In [35]:
def hf_training_generator():
    """Yield batches for training."""
    for batch in mnist_dataset["train"].iter(batch_size):
        x, y = batch["image"], batch["label"]
        yield x, y
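A generator like this can be consumed only once, which is why the training loop receives the function itself and calls it at the start of every epoch rather than receiving a single generator object. The sketch below illustrates the difference with a hypothetical `make_batches` generator over dummy NumPy data; the names and sizes are assumptions for illustration only.

```python
import numpy as np

def make_batches(images, labels, batch_size):
    """Yield consecutive (images, labels) slices of at most batch_size rows."""
    for start in range(0, len(images), batch_size):
        yield images[start:start + batch_size], labels[start:start + batch_size]

images = np.zeros((10, 784), dtype=np.float32)
labels = np.zeros((10,), dtype=np.int64)

gen = make_batches(images, labels, batch_size=4)
first_pass = sum(1 for _ in gen)    # consumes the generator: 3 batches
second_pass = sum(1 for _ in gen)   # already exhausted: 0 batches

# Calling the function again creates a fresh generator for the next epoch
fresh_pass = sum(1 for _ in make_batches(images, labels, batch_size=4))
print(first_pass, second_pass, fresh_pass)  # 3 0 3
```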

### Training Loop (Hugging Face Datasets)

Run the training loop using the Hugging Face training generator.

In [36]:
train_model(num_epochs, params, hf_training_generator)

Epoch 1 in 6.24 sec: Train Accuracy: 0.9159,Test Accuracy: 0.9197
Epoch 2 in 5.76 sec: Train Accuracy: 0.9371,Test Accuracy: 0.9383
Epoch 3 in 5.70 sec: Train Accuracy: 0.9493,Test Accuracy: 0.9468
Epoch 4 in 6.36 sec: Train Accuracy: 0.9568,Test Accuracy: 0.9536
Epoch 5 in 5.89 sec: Train Accuracy: 0.9632,Test Accuracy: 0.9576
Epoch 6 in 5.78 sec: Train Accuracy: 0.9674,Test Accuracy: 0.9617
Epoch 7 in 5.74 sec: Train Accuracy: 0.9708,Test Accuracy: 0.9649
Epoch 8 in 6.21 sec: Train Accuracy: 0.9737,Test Accuracy: 0.9672


## Summary

This notebook has introduced efficient methods for multi-device distributed data loading on TPUs with JAX. You explored how to leverage popular libraries like PyTorch DataLoader, TensorFlow Datasets, Grain, and Hugging Face Datasets to streamline the data loading process for machine learning tasks. Each library offers distinct advantages, allowing you to select the best approach for your specific project needs.

For more detailed strategies on distributed data loading with JAX, including global data pipelines and per-device processing, refer to the [Distributed Data Loading Guide](https://jax.readthedocs.io/en/latest/distributed_data_loading.html).