In [1]:
import tensorflow as tf
import matplotlib.pyplot as plt
import json
import os
from functools import partial
import time
import logging
import random
import re

# raise the minimum level for TensorFlow's native (C++) log output
# NOTE(review): this runs after `import tensorflow` above; the variable is usually
# recommended to be set before the import — confirm it still takes effect here.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' 

# funcs
def batch_annotations(annotations, batch_size):
    """
    Split annotations into consecutive chunks and shuffle each chunk.

    Every chunk holds at most `batch_size` annotations. Within a chunk the
    (file_name, category_id) pairs are shuffled together, so a file name always
    stays at the same index as its own label.

    Returns a list of (file_names, labels) tuples, both members being lists.
    """
    def _shuffled_columns(chunk):
        # keep name and label in one tuple so the shuffle cannot separate them
        pairs = [(item["file_name"], item["category_id"]) for item in chunk]
        random.shuffle(pairs)
        # transpose the pairs back into two parallel lists
        names, category_ids = (list(column) for column in zip(*pairs))
        return names, category_ids

    chunk_starts = range(0, len(annotations), batch_size)
    return [
        _shuffled_columns(annotations[start:start + batch_size])
        for start in chunk_starts
    ]

def resize_image(file_name, label, directory_prefix, resize_method, target_size=224):
    """
    Load a JPEG, resize it to a square and re-encode it as JPEG bytes.

    Args:
        file_name: Image path relative to `directory_prefix`.
        label: Passed through unchanged.
        directory_prefix: String joined in front of `file_name` to build the full path.
        resize_method: "crop_or_pad" uses tf.image.resize_with_crop_or_pad,
            "pad" uses tf.image.resize_with_pad, any other value uses tf.image.resize.
        target_size: Height and width of the resized image (default 224,
            the value that was previously hard-coded).

    Returns:
        Tuple of (re-encoded JPEG, label).

    Raises:
        tf.errors.OpError: re-raised when reading or decoding fails.
    """
    try:
        # load and decode jpeg
        path = tf.strings.join([tf.constant(directory_prefix), file_name])
        image = tf.io.decode_jpeg(tf.io.read_file(path), channels=3)
    except tf.errors.OpError:
        # hand the tensor to tf.print as its own argument so the path itself is
        # printed; formatting it into an f-string embeds the tensor's repr instead
        # NOTE(review): when this function is traced in graph mode (e.g. by
        # Dataset.map), read/decode failures presumably surface while the dataset
        # is iterated rather than inside this try block — confirm.
        tf.print("Error loading or decoding image:", file_name)
        raise

    # resize
    if resize_method == "crop_or_pad":
        image = tf.image.resize_with_crop_or_pad(image, target_size, target_size)
    elif resize_method == "pad":
        image = tf.image.resize_with_pad(image, target_size, target_size)
    else:
        image = tf.image.resize(image, [target_size, target_size])

    # inspect value in graph mode (alternatively can turn on eager mode for debugging)
    # tf.print("Range", tf.reduce_min(image), "Max value:", tf.reduce_max(image))

    # encode again for later serializing into TFRecord
    image = tf.io.encode_jpeg(tf.cast(image, tf.uint8))

    return image, label

def check_safe_shuffle(example_batches, categories):
    """
    Verify that every file name is still paired with its own category.

    For each batch, asserts that file names and labels have the same length and
    that the category name looked up for each label (with the multiplication
    sign removed) occurs inside the file name at the same position.
    """
    for example_batch in example_batches:
        paths = example_batch[0]
        category_ids = example_batch[1]
        assert len(paths) == len(category_ids)

        for path, category_id in zip(paths, category_ids):
            # category names may contain a multiplication sign, e.g. id 2061
            expected_name = categories[str(category_id)].replace("\u00D7", "")
            assert expected_name in path

def _int64_feature(value):
    """Wrap a single value in a tf.train.Feature backed by an Int64List."""
    int_list = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=int_list)

def _bytes_feature(value):
    """Wrap a single value in a tf.train.Feature backed by a BytesList."""
    eager_tensor_type = type(tf.constant(0))
    # BytesList won't unpack a string from an EagerTensor.
    payload = value.numpy() if isinstance(value, eager_tensor_type) else value
    bytes_list = tf.train.BytesList(value=[payload])
    return tf.train.Feature(bytes_list=bytes_list)

def serialize_example(image, label):
    """
    Serialize one (image, label) pair into a tf.train.Example byte string.

    Args:
        image: Value stored under the "image" key as a bytes feature.
        label: Object exposing `.numpy()`; its value is stored under the "label"
            key as an int64 feature (assumed to be a scalar integer — TODO confirm).

    Returns:
        The serialized tf.train.Example.
    """
    feature = {
        "image": _bytes_feature(image),
        "label": _int64_feature(label.numpy()),
    }

    example = tf.train.Example(features=tf.train.Features(feature=feature))
    return example.SerializeToString()

def write_compressed_tfrecord(dataset, record_file):
    """
    Write every (image, label) pair of `dataset` to a GZIP-compressed TFRecord file.

    Each pair is serialized with `serialize_example` and appended to `record_file`.
    """
    gzip_options = tf.io.TFRecordOptions(compression_type="GZIP")
    with tf.io.TFRecordWriter(record_file, gzip_options) as record_writer:
        for image, label in dataset:
            record_writer.write(serialize_example(image, label))

In [4]:
# config
seed = 42 # seeds the shuffles below so a re-run produces the same record files
records_per_file = 1000 # also acts as batch size / buffer size
dataset_type = "val2017"
directory_prefix = "/Volumes/t7/"
tfrecord_file_prefix = f"{directory_prefix}train_val_images-processed(Aves)/{dataset_type}/"
resize_method = "pad"

random.seed(seed)

# read processed json batch annotations
with open(f"./train_val2017/{dataset_type}-processed.json", "r") as f:
    annotations = json.load(f)

# create Aves subset (done after the file is closed; only the parsed list is needed)
annotations = [anno for anno in annotations if anno['file_name'].startswith('train_val_images/Aves')]

random.shuffle(annotations) # shuffle

example_batches = batch_annotations(annotations, records_per_file)
example_batches_len = len(example_batches)
random.shuffle(example_batches) # shuffle

# ensure shuffling kept each file name paired with its own category id
with open("./train_val2017/categories.json", "r") as f:
    categories = json.load(f)
check_safe_shuffle(example_batches, categories)

# ensure directories exist
os.makedirs(os.path.dirname(tfrecord_file_prefix), exist_ok=True)

# prepare data, preprocess and shuffle batches
dataset_map_func = partial(resize_image, directory_prefix=directory_prefix, resize_method=resize_method)
for i, example_batch in enumerate(example_batches):
    start_time = time.time()

    # construct dataset, resize and shuffle batch
    dataset = tf.data.Dataset.from_tensor_slices(example_batch)
    dataset = dataset.map(dataset_map_func, num_parallel_calls=tf.data.AUTOTUNE)
    # a different seed per file keeps the order reproducible without repeating
    # the same permutation in every file
    dataset = dataset.shuffle(buffer_size=records_per_file, seed=seed + i)

    # write as .tfrecord
    tfrecord_file = f"{tfrecord_file_prefix}inat17_batch-{i+1}-of-{example_batches_len}.tfrecord"
    write_compressed_tfrecord(dataset, tfrecord_file)

    total_time = time.time() - start_time
    print(f"Batch preprocessing complete for {tfrecord_file} (total time {total_time:.2f}s)")


Batch preprocessing complete for /Volumes/t7/train_val_images-processed(Aves)/val2017/inat17_batch-1-of-22.tfrecord (total time 2.28s)
Batch preprocessing complete for /Volumes/t7/train_val_images-processed(Aves)/val2017/inat17_batch-2-of-22.tfrecord (total time 2.00s)
Batch preprocessing complete for /Volumes/t7/train_val_images-processed(Aves)/val2017/inat17_batch-3-of-22.tfrecord (total time 2.02s)
Batch preprocessing complete for /Volumes/t7/train_val_images-processed(Aves)/val2017/inat17_batch-4-of-22.tfrecord (total time 2.05s)
Batch preprocessing complete for /Volumes/t7/train_val_images-processed(Aves)/val2017/inat17_batch-5-of-22.tfrecord (total time 2.15s)
Batch preprocessing complete for /Volumes/t7/train_val_images-processed(Aves)/val2017/inat17_batch-6-of-22.tfrecord (total time 0.53s)
Batch preprocessing complete for /Volumes/t7/train_val_images-processed(Aves)/val2017/inat17_batch-7-of-22.tfrecord (total time 2.03s)
Batch preprocessing complete for /Volumes/t7/train_val_