Import modules.


In [None]:
import os
import time

os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

import tensorflow as tf

import chiron


Disable visibility of all GPU devices.


In [None]:
# Called with no arguments. Per this cell's heading, this is meant to hide all
# GPU devices (presumably so TensorFlow falls back to the CPU).
# NOTE(review): confirm the no-argument behavior against chiron.set_visible_gpus.
chiron.set_visible_gpus()


Load the training dataset from a TFRecord file.


In [None]:
# Fold 1 of the training split of the public brain-tumor dataset.
tfrecord_path = "../data/brain-tumor-public-dataset/tfrecord/train/fold-1.tfrecord"
dataset = chiron.load_tfrecord(tfrecord_path)

# Last expression of the cell: the notebook displays the element structure of
# the loaded dataset (presumably a tf.data.Dataset, given the .map/.batch
# calls made on it later).
dataset.element_spec


Define mapping function parameters.


In [None]:
# Buffer size passed to `Dataset.shuffle`.
shuffle_buffer_size = 512
# Seed passed to `Dataset.shuffle`.
seed = 0
# Class name -> integer class index.
label_map = {"meningioma": 0, "glioma": 1, "pituitary": 2}
# Derived from label_map (largest index + 1) instead of being hard-coded, so
# adding or removing a class cannot leave the encoder with a wrong class count.
num_classes = max(label_map.values()) + 1
# Target image size passed to chiron.Resizer (presumably [height, width]).
image_size = [256, 256]
# Number of examples per batch.
batch_size = 32


One-hot encode labels.


In [None]:
# Callable mapped over dataset elements in the pipelines below. It is built from
# label_map (class name -> index) and num_classes. Per this cell's heading it
# one-hot encodes the label.
# NOTE(review): confirm the encoding against chiron.LabelEncoder.
label_encoder = chiron.LabelEncoder(label_map, num_classes)


Resize images to a common size so they can be batched.


In [None]:
# Callable mapped over individual (unbatched) elements in the pipelines below.
# It resizes images to image_size so that all elements share one shape and can
# be combined by `.batch`.
resizer = chiron.Resizer(image_size)


Standardize images within each batch.


In [None]:
# Callable mapped over *batched* elements in the pipelines below (it is applied
# after `.batch`). Per its name and this cell's heading, it standardizes images
# using per-batch statistics.
# NOTE(review): confirm the exact statistics used in chiron.PerBatchStandardWhitener.
whitener = chiron.PerBatchStandardWhitener()


Define a naive data preprocessing pipeline that applies every transformation sequentially.


In [None]:
# Baseline pipeline: each `map` runs without `num_parallel_calls` (sequential),
# and there is no explicit cache or prefetch step.
dataset_naive = dataset.shuffle(shuffle_buffer_size, seed=seed)
dataset_naive = dataset_naive.map(label_encoder)
dataset_naive = dataset_naive.map(resizer)
dataset_naive = dataset_naive.batch(batch_size)
# Per-batch standardization must come after batching.
dataset_naive = dataset_naive.map(whitener)


Define an optimized data preprocessing pipeline that parallelizes the data transformations, caches the dataset in memory, and prefetches elements.


In [None]:
# Optimized pipeline: parallel maps, in-memory cache, and prefetching.
#
# Ordering matters. `.cache()` replays exactly the elements (and order) of the
# first pass. Caching *after* `shuffle`/`batch` would therefore freeze the first
# epoch's order and batch composition for every later epoch, defeating
# `reshuffle_each_iteration`. So we cache the per-example work first and
# shuffle, batch and whiten afterwards, on every pass.
# NOTE(review): assumes label_encoder and resizer are deterministic; caching
# would otherwise freeze their random output too.
dataset_optimized = (
    dataset.map(label_encoder, num_parallel_calls=tf.data.AUTOTUNE)
    .map(resizer, num_parallel_calls=tf.data.AUTOTUNE)
    .cache()
    .shuffle(shuffle_buffer_size, seed=seed)
    .batch(batch_size)
    # Per-batch standardization depends on which examples share a batch, so it
    # is recomputed after batching on every pass rather than cached.
    .map(whitener, num_parallel_calls=tf.data.AUTOTUNE)
    # Overlap producing the next batch with consumption of the current one.
    .prefetch(tf.data.AUTOTUNE)
)


Benchmark the performance of both pipelines.


In [None]:
num_repeats = 5
interval = 0.01


def benchmark(dataset, num_repeats=num_repeats, interval=interval):
    """Return the seconds spent waiting on ``dataset``, excluding simulated work.

    Iterates over ``dataset`` ``num_repeats`` times. After each element it sleeps
    ``interval`` seconds to simulate a training step, which gives prefetching
    something to overlap with. The time actually spent sleeping is measured and
    subtracted from the wall-clock total. Only the time spent waiting for
    elements (plus loop overhead) remains.

    Args:
        dataset: Any iterable; here a tf.data pipeline.
        num_repeats: Number of full passes over ``dataset``.
        interval: Simulated per-element work, in seconds.

    Returns:
        Elapsed seconds not spent in the simulated work.
    """
    sleep_time = 0.0
    start = time.perf_counter()
    for _ in range(num_repeats):
        for _ in dataset:
            # Measure the real sleep: ``time.sleep`` may overshoot the requested
            # duration. Subtracting only the nominal ``interval`` would charge
            # that overshoot to the pipeline being measured.
            sleep_start = time.perf_counter()
            time.sleep(interval)
            sleep_time += time.perf_counter() - sleep_start
    return time.perf_counter() - start - sleep_time


# Time the naive pipeline first, then the optimized one, under the same
# simulated per-element workload.
elapsed_naive, elapsed_optimized = (
    benchmark(dataset_naive),
    benchmark(dataset_optimized),
)
# A ratio above 1 means the optimized pipeline spent less time producing data.
speed_up = elapsed_naive / elapsed_optimized

print(f"Naive: {elapsed_naive:.2f}s")
print(f"Optimized: {elapsed_optimized:.2f}s")
print(f"Speed-up: x{speed_up:.2f}")
