# Set up input data

In [None]:
# Data transformations.
!pip install grain transformers



In [None]:
# We'll use data from TFDS for simplicity.
!pip install tensorflow-datasets[tf-nightly]



In [None]:
# Remove dir if exists.
!rm -rf /tmp/arrayrecord/ag_news_subset

In [None]:
import json
import os
from array_record.python.array_record_module import ArrayRecordWriter
import grain
import tensorflow_datasets as tfds

# Copy the TFDS dataset to a local directory, re-serializing every example as
# UTF-8 encoded JSON inside a single ArrayRecord file.
data_path = "/tmp/arrayrecord/ag_news_subset"
# exist_ok=True: do not fail if the directory is already there (e.g. when the
# cleanup cell above was skipped before re-running this one).
os.makedirs(data_path, exist_ok=True)
writer = ArrayRecordWriter(
    os.path.join(data_path, "train.array-record"), "group_size:1"
)

source = grain.MapDataset.source(
    tfds.data_source("ag_news_subset", split="train")
)

# Make sure the writer is closed even if reading or serialization fails midway.
try:
  for idx, e in enumerate(source.to_iter_dataset()):
    if idx % 10000 == 0:
      print(f"Written {idx} examples")
    # JSON cannot encode `bytes`, so byte strings are decoded to text first;
    # every other value is passed through unchanged.
    new_e = {
        k: v.decode("utf-8") if isinstance(v, bytes) else v
        for k, v in e.items()
    }
    writer.write(json.dumps(new_e).encode("utf-8"))
finally:
  writer.close()
print("Finished")



Written 0 examples
Written 10000 examples
Written 20000 examples
Written 30000 examples
Written 40000 examples
Written 50000 examples
Written 60000 examples
Written 70000 examples
Written 80000 examples
Written 90000 examples
Written 100000 examples
Written 110000 examples
Finished


In [None]:
!ls -lh /tmp/arrayrecord/ag_news_subset

total 31M
-rw-r--r-- 1 root root 31M Nov 17 22:20 train.array-record


# Create source and examine the data

We have locally stored data serialized as JSON in an ArrayRecord file.

In [None]:
import grain
from pprint import pprint

# Path of the ArrayRecord file to read (note: `data_path` now names the file
# itself, not its parent directory).
data_path = "/tmp/arrayrecord/ag_news_subset/train.array-record"
# Open the file as a Grain data source; `len(source)` is the number of records.
source = grain.sources.ArrayRecordDataSource(data_path)
print(f"{len(source)} examples")

120000 examples


# Process the data

## Parse

Let's read, parse and inspect the data. A `MapDataset` object acts as a lazily evaluated sequence: it only processes an element when that element's index is requested.

In [None]:
import json


# Build a dataset on top of the source and decode every record with
# `json.loads`.
parsed_ds = grain.MapDataset.source(source).map(json.loads)

# Look at the first parsed example.
pprint(parsed_ds[0])

{'description': 'AMD #39;s new dual-core Opteron chip is designed mainly for '
                'corporate computing applications, including databases, Web '
                'services, and financial transactions.',
 'label': 3,
 'title': 'AMD Debuts Dual-Core Opteron Processor'}


## Tokenize

One of the central transformations in text data processing is tokenization -- splitting text into tokens and mapping them to vocabulary indices, a representation suitable for ML training. In this demo we will use the HuggingFace tokenizer APIs; any other tokenizer would work as well.

In [None]:
from transformers import AutoTokenizer
import numpy as np

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

def tokenize(element):
  """Tokenizes the "description" field of `element` with the global `tokenizer`.

  Returns a new dict that holds only "description", now a NumPy array built
  from the tokenizer's "input_ids". All other fields of `element` are dropped.
  """
  input_ids = tokenizer(element["description"])["input_ids"]
  return {"description": np.asarray(input_ids)}

tokenized_ds = parsed_ds.map(tokenize)

pprint(tokenized_ds[0])

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


{'description': array([  101,  2572,  2094,  1001,  4464,  1025,  1055,  2047,  7037,
        1011,  4563, 23569, 26534,  9090,  2003,  2881,  3701,  2005,
        5971,  9798,  5097,  1010,  2164, 17881,  1010,  4773,  2578,
        1010,  1998,  3361, 11817,  1012,   102])}


## Shuffle, repeat & shard

In order to prevent order bias we globally shuffle all records. We then repeat the dataset for multiple epochs so that the model sees every record several times (each epoch is shuffled differently), which helps it generalize better.

In case of distributed training, we split the dataset into as many shards as there are hosts.

In [None]:
# Global shuffle.
shuffled_ds = tokenized_ds.shuffle(seed=42)

# Repeat dataset 100 times, each epoch is shuffled differently.
repeated_ds = shuffled_ds.repeat(num_epochs=100)

# Shard for distributed training.
shard_index = 1  # this will typically be jax.process_index()
shard_count = 16  # this will typically be jax.process_count()
# Strided slice: keep every `shard_count`-th element, starting at `shard_index`.
sharded_ds = repeated_ds[shard_index::shard_count]

pprint(sharded_ds[0])

{'description': array([  101,  3688,  3517,  2574,  2008,  2097,  7532,  3081,  1037,
        2047,  2188, 14048, 12827,  6153,  2011,  2070,  1997,  1996,
        2088,  1005,  1055,  2922,  7325,  8139,  1998,  3274,  3316,
        1012,   102])}


## Pack

Text and multimedia data naturally vary in size. To enable batched processing of such data in ML training, Grain provides several bin packing algorithms. They fit variable-size data into fixed-length sequences while minimizing the necessary padding.

Since packing needs to fetch a varying number of elements to fill the fixed-size bins, it can no longer preserve the indexing of the original dataset. It therefore requires conversion to a `grain.IterDataset`, which is a Python `Iterable` producing a `grain.DatasetIterator` that fetches elements and supports checkpointing.

In [None]:
sequence_length = 128

def trim_values(element):
  """Returns a dict whose "description" is cut to `sequence_length` items.

  Only the "description" key of `element` is read; `sequence_length` is the
  notebook-level global.
  """
  description = element["description"]
  return {"description": description[:sequence_length]}

# Trim every example, then switch from `MapDataset` to `IterDataset`, which is
# what the packing transformation below consumes.
trimmed_ds = sharded_ds.map(trim_values).to_iter_dataset(
    grain.ReadOptions(num_threads=0)
)
# Pack the variable-length examples into bins of `sequence_length` tokens.
packed_ds = grain.experimental.FirstFitPackIterDataset(
    trimmed_ds,
    length_struct={"description": sequence_length},
    num_packing_bins=30,
)

pprint(next(iter(packed_ds)))

{'description': array([  101, 26665,  1011,  5522,  1005,  1055, 23205, 29501,  3062,
        1015,  1012,  5764,  3867,  2011,  1032, 22878,  2006,  6928,
        1010,  8402,  6409,  2046,  1037,  2353,  2154,  2004,  2178,
        1032, 12058,  1999,  3514,  7597, 20183, 15508,  2055,  1996,
        3795,  3171,  1032,  4254,  1998,  6573,  2091,  9167,  2545,
        2107,  2004, 11742,  5013, 13058,  1012,   102,   101,  1037,
        6816,  4457,  2010,  3954,  1998,  1037, 10563,  2012,  1037,
        2221,  4189,  1999,  2358,  1012, 14060,  1010,  2021,  4445,
        4265,  2350,  6441,  1012, 20099,  6610,  8833, 17922, 24598,
        1010,  3954,  1997,  1996,  6816,  1010,  4265, 26136, 14890,
         102,   101, 15335,  2176,  7767,  8046,  1996, 12592,  2231,
        2006,  6928,  2000,  6186,  1996,  6019,  1997,  1037,  9042,
        1011,  4427,  6543,  7450,  3832,  2000,  2562,  2343, 17127,
        2474,  6806,  6784,  1999,  2373,  2005,  2178,  2093,  2086,
    

## Batch

Now that the example sizes are fixed, we can batch the data for training!

In [None]:
# Number of packed examples per batch; also read by `time_iterator` below to
# convert batches/sec into examples/sec.
batch_size = 512

# Group the packed, fixed-length examples into batches of `batch_size`.
batched_ds = packed_ds.batch(batch_size, drop_remainder=True)

# Fetch and show the first batch.
pprint(next(iter(batched_ds)))

{'description': array([[  101, 26665,  1011, ...,  2086,  1012,   102],
       [  101, 10884,  1006, ...,     0,     0,     0],
       [  101,  9706,  1011, ...,     0,     0,     0],
       ...,
       [  101, 14497,  1010, ...,     0,     0,     0],
       [  101,  9838,  1001, ...,     0,     0,     0],
       [  101,  1996,  7794, ...,     0,     0,     0]]),
 'description_positions': array([[ 0,  1,  2, ..., 34, 35, 36],
       [ 0,  1,  2, ...,  0,  0,  0],
       [ 0,  1,  2, ...,  0,  0,  0],
       ...,
       [ 0,  1,  2, ...,  0,  0,  0],
       [ 0,  1,  2, ...,  0,  0,  0],
       [ 0,  1,  2, ...,  0,  0,  0]], dtype=int32),
 'description_segment_ids': array([[1, 1, 1, ..., 3, 3, 3],
       [1, 1, 1, ..., 0, 0, 0],
       [1, 1, 1, ..., 0, 0, 0],
       ...,
       [1, 1, 1, ..., 0, 0, 0],
       [1, 1, 1, ..., 0, 0, 0],
       [1, 1, 1, ..., 0, 0, 0]], dtype=int32)}


## Enable visualization mode

In order to understand the sequence of transformations and their outputs better, Grain offers a pipeline visualization mode.

In [None]:
from absl import flags

# Enable visualization.
# NOTE(review): absl flags are presumably never parsed in a notebook session,
# so they are marked as parsed by hand before the config is updated -- confirm.
flags.FLAGS.mark_as_parsed()
# NOTE(review): an empty output dir appears to make Grain print the graph
# inline (see the cell output) -- confirm against the Grain config docs.
grain.config.update("py_dataset_visualization_output_dir", "")

# Fetch one batch so the pipeline actually runs while visualization is on.
next(iter(batched_ds))

# Disable visualization -- we don't need it for following sections.
grain.config.update("py_dataset_visualization_output_dir", None)

Grain Dataset graph:

SourceMapDataset(source=ArrayRecordDataSource)
  ││
  ││  
  ││
  ╲╱
'bytes[]'

  ││
  ││  MapMapDataset(transform=loads @ .../python3.12/json/__init__.py:299)
  ││
  ╲╱
{'description': 'str[]', 'label': 'int[]', 'title': 'str[]'}

  ││
  ││  MapMapDataset(transform=tokenize @ ...//tmp/ipython-input-1246250811.py:6)
  ││
  ╲╱
{'description': 'int64[29]'}

  ││
  ││  ShuffleMapDataset
  ││
  ╲╱
{'description': 'int64[29]'}

  ││
  ││  RepeatMapDataset(num_epochs=100)
  ││
  ╲╱
{'description': 'int64[29]'}

  ││
  ││  SliceMapDataset[1:12000000:16]
  ││
  ╲╱
{'description': 'int64[29]'}

  ││
  ││  MapMapDataset(transform=trim_values @ ...//tmp/ipython-input-4160003364.py:3)
  ││
  ╲╱
{'description': 'int64[29]'}

  ││
  ││  PrefetchDatasetIterator(read_options=ReadOptions(num_threads=0, prefetch_buffer_size=500), allow_nones=False)
  ││
  ╲╱
{'description': 'int64[29]'}

Grain Dataset graph:

SourceMapDataset(source=ArrayRecordDataSource)
  ││
  ││  
  ││
  ╲╱
'byt

# Measure throughput

We have the necessary transformations, let's make sure that we're utilizing the training accelerator efficiently!

In [None]:
import time

num_batches = 200


def time_iterator(it) -> None:
  """Prints the first-batch latency and the subsequent throughput of `it`.

  The first `next(it)` call is timed on its own. The following `num_batches`
  calls are timed together and reported both as examples/sec (scaled by the
  global `batch_size`) and as batches/sec.

  Args:
    it: an iterator yielding batches; it must be able to produce at least
      `num_batches + 1` elements.
  """
  first_batch_start = time.perf_counter()
  next(it)
  first_batch_latency = time.perf_counter() - first_batch_start
  print(f"Time to get the first batch: {first_batch_latency:.02f} sec.")

  # The first batch is excluded from the throughput measurement.
  loop_start = time.perf_counter()
  for _ in range(num_batches):
    next(it)
  elapsed = time.perf_counter() - loop_start

  examples_per_sec = num_batches * batch_size / elapsed
  batches_per_sec = num_batches / elapsed
  print(
      f"Iterator throughput: {examples_per_sec:.02f} examples/sec; "
      f"{batches_per_sec:.02f} batches/sec."
  )

In [None]:
time_iterator(iter(batched_ds))

Time to get the first batch: 0.81 sec.
Iterator throughput: 656.87 examples/sec; 1.28 batches/sec.


Once you've measured your pipeline's throughput, there are two possible scenarios: it is either faster or slower than your training step.

## Data loading is slower than training

### Enable performance debug mode

Grain offers a debug mode in which the execution time of each transformation is tracked and periodically logged in a table.

In [None]:
# Turn on Grain's debug mode. NOTE: the setting is never reset in this
# notebook, so the cells below print the execution summary as well.
grain.config.update("py_debug_mode", True)

# Re-run the timing with debug mode on.
time_iterator(iter(batched_ds))

Time to get the first batch: 0.83 sec.
Grain Dataset Execution Summary:

NOTE: Before analyzing the `total_processing_time` for a node, please check the `percent wait time` column to ensure that the node is indicated as bottleneck. The `MapDataset` nodes are executed in multiple threads and thus, should not be compared to the `total_processing_time` of `DatasetIterator` nodes.

|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| id | name                           | inputs | percent wait time | total processing time | min processing time | max processing time | avg processing time | num produced elements | memory usage |
|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| 7  | SourceMapDataset(s

### GIL-free bottleneck

A class of bottlenecks that execute without holding Python's GIL can be dealt with relatively easily.

Some examples of such transformations: IO, numpy, JAX, PIL, C/C++ extension modules.

Increase the number of threads! But keep it lower than the number of available cores.

In [None]:
# Same pipeline as before, rebuilt end to end with 24 reader threads.
read_options = grain.ReadOptions(num_threads=24)

ds = grain.MapDataset.source(source).map(json.loads).map(tokenize)
ds = ds.shuffle(seed=42).repeat(num_epochs=100)[shard_index::shard_count]
ds = ds.map(trim_values).to_iter_dataset(read_options)
ds = grain.experimental.FirstFitPackIterDataset(
    ds, length_struct={"description": sequence_length}, num_packing_bins=30
)
ds = ds.batch(batch_size, drop_remainder=True)

time_iterator(iter(ds))

Time to get the first batch: 0.52 sec.
Grain Dataset Execution Summary:

NOTE: Before analyzing the `total_processing_time` for a node, please check the `percent wait time` column to ensure that the node is indicated as bottleneck. The `MapDataset` nodes are executed in multiple threads and thus, should not be compared to the `total_processing_time` of `DatasetIterator` nodes.

|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| id | name                           | inputs | percent wait time | total processing time | min processing time | max processing time | avg processing time | num produced elements | memory usage |
|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| 7  | SourceMapDataset(s

### Bottleneck with GIL

These bottlenecks cannot take advantage of multithreading in Python and therefore require multiprocessing.

In [None]:
# Same pipeline again, this time with only 2 reader threads ...
ds = grain.MapDataset.source(source).map(json.loads).map(tokenize)
ds = ds.shuffle(seed=42).repeat(num_epochs=100)[shard_index::shard_count]
ds = ds.map(trim_values).to_iter_dataset(grain.ReadOptions(num_threads=2))
ds = grain.experimental.FirstFitPackIterDataset(
    ds, length_struct={"description": sequence_length}, num_packing_bins=30
)
ds = ds.batch(batch_size, drop_remainder=True)
# ... and multiprocess prefetching configured with 16 workers.
ds = ds.mp_prefetch(
    grain.multiprocessing.MultiprocessingOptions(num_workers=16)
)

time_iterator(iter(ds))

Time to get the first batch: 49.06 sec.




Grain Dataset Execution Summary:

NOTE: Before analyzing the `total_processing_time` for a node, please check the `percent wait time` column to ensure that the node is indicated as bottleneck. The `MapDataset` nodes are executed in multiple threads and thus, should not be compared to the `total_processing_time` of `DatasetIterator` nodes.

|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| id | name                           | inputs | percent wait time | total processing time | min processing time | max processing time | avg processing time | num produced elements | memory usage |
|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| 0  | MultiprocessPrefetchDatasetIte | []     | N/A            

## Data loading is faster than training

Don't spend time optimizing your pipeline, just hide its latency behind the training!

In this example data fetching and training are synchronous:

In [None]:
import time

train_step_latency_s = 1
num_steps = 100


def train_step(data_batch):
  """Simulates one training step by sleeping for `train_step_latency_s` seconds."""
  # The batch itself is not used; only the step latency is simulated.
  del data_batch
  time.sleep(train_step_latency_s)


def train(dataset):
  """Runs `num_steps` simulated training steps over `dataset`.

  Each step fetches one batch and hands it to `train_step`. Afterwards the
  share of wall-clock time not covered by the simulated step latency is
  printed as the time spent waiting for data.
  """
  batches = iter(dataset)
  began = time.perf_counter()
  for _ in range(num_steps):
    train_step(next(batches))
  elapsed = time.perf_counter() - began

  # Whatever exceeds the pure step time is attributed to waiting for data.
  step_time = num_steps * train_step_latency_s
  waiting_fraction = (elapsed - step_time) / elapsed
  print(f"Spent {(waiting_fraction * 100):.2f}% of time waiting for data")

In [None]:
train(batched_ds)

Spent 46.24% of time waiting for data


### Add background prefetching

Background thread prefetching processes data asynchronously before it is requested and thus hides the majority of the data processing latency.

In [None]:
# Wrap the batched pipeline in a thread-prefetching dataset with a prefetch
# buffer of size 3.
prefetched_ds = grain.experimental.ThreadPrefetchIterDataset(
    batched_ds, prefetch_buffer_size=3
)

# Same simulated training loop as above, now fed from the prefetching dataset.
train(prefetched_ds)

Spent 1.60% of time waiting for data


### Hide first batch processing behind checkpoint recovery

Another (often overlooked) optimization is to overlap first batch processing with the model checkpoint recovery.

In [None]:
model_restore_latency_s = 5


def restore_model():
  """Simulates checkpoint recovery by sleeping for `model_restore_latency_s` seconds."""
  time.sleep(model_restore_latency_s)


def train_from_checkpoint(dataset):
  """Simulates restoring a model and then training for `num_steps` steps.

  Prefetching is started on the iterator before `restore_model` is called, so
  data processing can proceed while the restore is in progress. Afterwards the
  share of wall-clock time not covered by the simulated step and restore
  latencies is printed as the time spent waiting for data.

  Args:
    dataset: an iterable whose iterator provides `start_prefetch()`.
  """
  batches = iter(dataset)
  # Start producing data right away, before the (slow) model restore.
  batches.start_prefetch()
  began = time.perf_counter()
  restore_model()
  for _ in range(num_steps):
    train_step(next(batches))
  elapsed = time.perf_counter() - began

  step_time = num_steps * train_step_latency_s
  waiting_fraction = (elapsed - step_time - model_restore_latency_s) / elapsed
  print(f"Spent {(waiting_fraction * 100):.2f}% of time waiting for data")

In [None]:
train_from_checkpoint(prefetched_ds)

Spent 0.48% of time waiting for data
