This code is a slightly altered version of the Keras image-dataset utilities (`tf.keras.preprocessing`).

original source: https://github.com/keras-team/keras/blob/v3.3.3/keras/src/utils/image_dataset_utils.py#L12-L351

imports

In [None]:

import os
import random
import time
import warnings
from multiprocessing.pool import ThreadPool
import numpy as np

import tensorflow as tf

from tensorflow import keras

from keras.src.utils import dataset_utils
from keras.src.utils import image_utils
from keras.src.utils import io_utils



constants and variables

In [None]:
# File extensions accepted as images. They are passed as `formats` when the
# directory is indexed and compared against the lower-cased file name.
ALLOWLIST_FORMATS = ('.bmp', '.gif', '.jpeg', '.jpg', '.png')

data functions

In [None]:
def standardize_data_format(data_format):
    '''Return `data_format`, falling back to 'channels_last' when it is None.'''
    return 'channels_last' if data_format is None else data_format

def labels_to_dataset(labels):
    '''Wrap `labels` in a `tf.data.Dataset` built with `from_tensor_slices`.'''
    return tf.data.Dataset.from_tensor_slices(labels)

def iter_valid_files(directory, follow_links, formats):
    '''Yield `(root, filename)` for every file under `directory` matching `formats`.

    Args:
        directory: Root of the tree to walk.
        follow_links: When truthy the tree is walked with `os.walk` and
            symlinked directories are followed; otherwise
            `tf.io.gfile.walk` is used.
        formats: Suffix (or tuple of suffixes) tested against the
            lower-cased file name with `str.endswith`.

    Yields:
        `(root, filename)` tuples, ordered by directory path and then by
        file name within each directory.
    '''
    if follow_links:
        tree = os.walk(directory, followlinks=follow_links)
    else:
        tree = tf.io.gfile.walk(directory)

    # Sort the walk results by directory path so the output order is stable.
    entries = sorted(tree, key=lambda entry: entry[0])
    for root, _subdirs, names in entries:
        matching = (
            name for name in sorted(names) if name.lower().endswith(formats)
        )
        for name in matching:
            yield root, name

def index_subdirectory(directory, class_indices, follow_links, formats):
    '''Collect the matching files of one class directory with their labels.

    Args:
        directory: Path of the class directory; its base name is used as
            the key into `class_indices`.
        class_indices: Mapping from directory base name to class index.
        follow_links: Forwarded to `iter_valid_files`.
        formats: Forwarded to `iter_valid_files`.

    Returns:
        A `(filenames, labels)` tuple. Each filename is the directory's base
        name joined with the file's path relative to `directory`; each label
        is `class_indices[<base name>]`.
    '''
    class_dir = os.path.basename(directory)
    found_files = []
    found_labels = []
    for root, fname in iter_valid_files(directory, follow_links, formats):
        # The class index is looked up per file, so a directory without any
        # matching file never touches `class_indices`.
        found_labels.append(class_indices[class_dir])
        full_path = tf.io.gfile.join(root, fname)
        path_in_class = os.path.relpath(full_path, directory)
        found_files.append(tf.io.gfile.join(class_dir, path_in_class))
    return found_files, found_labels

def index_directory(
    directory,
    labels,
    formats,
    class_names=None,
    shuffle=True,
    seed=None,
    follow_links=False,
    verbose=True,
):
    '''List the files under `directory` and optionally shuffle them with `labels`.

    `directory` itself is indexed as a single unnamed class (subdirectory
    `''`), so the returned class names are always `['']`.

    Args:
        directory: Root directory that is scanned for files.
        labels: Sequence of labels, or None. It is only shuffled here, never
            validated against the number of files found.
        formats: Allowed file suffixes, forwarded to `index_subdirectory`.
        class_names: Accepted but not used; it is overwritten with `['']`.
        shuffle: Whether to shuffle the file paths (and `labels`).
        seed: Seed for the shuffle. A random one is drawn when None.
        follow_links: Forwarded to `index_subdirectory`.
        verbose: Whether to print the number of files found.

    Returns:
        A `(file_paths, labels, class_names)` tuple. When shuffling, `labels`
        is a shuffled copy; the sequence passed in by the caller is not
        modified.
    '''
    subdirs = ['']
    class_names = subdirs
    class_indices = dict(zip(class_names, range(len(class_names))))

    filenames = []
    pool = ThreadPool()
    try:
        pending = [
            pool.apply_async(
                index_subdirectory,
                (
                    tf.io.gfile.join(directory, subdir),
                    class_indices,
                    follow_links,
                    formats,
                ),
            )
            for subdir in subdirs
        ]
        for result in pending:
            # The per-file class indices are not needed: labels are supplied
            # by the caller.
            partial_filenames, _ = result.get()
            filenames += partial_filenames
    finally:
        # Release the worker threads even when a worker raised.
        pool.close()
        pool.join()

    if verbose:
        io_utils.print_msg(f'Found {len(filenames)} files.')

    file_paths = [tf.io.gfile.join(directory, fname) for fname in filenames]

    if shuffle:
        # Shuffle globally to erase macro-structure
        if seed is None:
            seed = np.random.randint(1e6)
        np.random.RandomState(seed).shuffle(file_paths)
        if labels is not None:
            # Work on a copy so the caller's sequence is left untouched.
            # Lists and arrays keep their type; anything without `.copy()`
            # (e.g. a tuple) becomes a list so it can be shuffled.
            labels = labels.copy() if hasattr(labels, 'copy') else list(labels)
            # Re-seeding is intended to apply the same permutation as for
            # `file_paths`, which only holds when both have the same length.
            np.random.RandomState(seed).shuffle(labels)
    return file_paths, labels, class_names

def custom_image_dataset_from_directory(
    directory,
    labels=None,
    class_names=None,
    color_mode='rgb',
    batch_size=32,
    image_size=(200, 200),
    shuffle=False,
    seed=None,
    validation_split=None,
    subset=None,
    interpolation='bilinear',
    follow_links=False,
    crop_to_aspect_ratio=False,
    pad_to_aspect_ratio=False,
    data_format=None,
    verbose=True,
):
    '''Build a `tf.data.Dataset` of `(image, label)` pairs from `directory`.

    Args:
        directory: Directory scanned for image files.
        labels: Sequence of labels paired with the indexed image paths.
            None (the default) is treated as an empty list.
            NOTE(review): paths and labels are zipped together, so
            presumably one label per indexed image is expected - confirm.
        class_names: Forwarded to `index_directory`.
        color_mode: One of 'rgb', 'rgba' or 'grayscale' (3, 4 or 1 channels).
        batch_size: Batch size, or None to leave the dataset unbatched.
        image_size: `(height, width)` the images are resized to.
        shuffle: Whether to shuffle the data.
        seed: Random seed. A random one is drawn when None.
        validation_split: Accepted but not used by this function.
        subset: Accepted but not used by this function.
        interpolation: Resize method; matched case-insensitively against the
            supported names.
        follow_links: Forwarded to `index_directory`.
        crop_to_aspect_ratio: Forwarded to `paths_and_labels_to_dataset`.
        pad_to_aspect_ratio: Forwarded to `paths_and_labels_to_dataset`.
        data_format: 'channels_last' or 'channels_first'; None is
            standardized by `standardize_data_format`.
        verbose: Forwarded to `index_directory`.

    Returns:
        A `(dataset, image_paths)` tuple. The dataset additionally carries
        `class_names` and `file_paths` attributes.

    Raises:
        ValueError: If `color_mode` or `interpolation` is unsupported, if
            both aspect-ratio flags are set, or if no image is found.
    '''
    # A fresh list per call instead of a shared mutable default argument.
    if labels is None:
        labels = []

    if color_mode == 'rgb':
        num_channels = 3
    elif color_mode == 'rgba':
        num_channels = 4
    elif color_mode == 'grayscale':
        num_channels = 1
    else:
        raise ValueError(
            '`color_mode` must be one of {"rgb", "rgba", "grayscale"}. '
            f'Received: color_mode={color_mode}'
        )

    interpolation = interpolation.lower()
    supported_interpolations = (
        'bilinear',
        'nearest',
        'bicubic',
        'area',
        'lanczos3',
        'lanczos5',
        'gaussian',
        'mitchellcubic',
    )
    if interpolation not in supported_interpolations:
        raise ValueError(
            'Argument `interpolation` should be one of '
            f'{supported_interpolations}. '
            f'Received: interpolation={interpolation}'
        )

    # Fail fast: reject the conflicting flags before scanning the directory
    # rather than while the image-loading function is being mapped.
    if pad_to_aspect_ratio and crop_to_aspect_ratio:
        raise ValueError(
            'Only one of `pad_to_aspect_ratio`, `crop_to_aspect_ratio`'
            ' can be set to `True`.'
        )

    if seed is None:
        seed = np.random.randint(1e6)

    image_paths, labels, class_names = index_directory(
        directory,
        labels,
        formats=ALLOWLIST_FORMATS,
        class_names=class_names,
        shuffle=shuffle,
        seed=seed,
        follow_links=follow_links,
        verbose=verbose,
    )

    data_format = standardize_data_format(data_format=data_format)
    if batch_size is not None:
        shuffle_buffer_size = batch_size * 8
    else:
        shuffle_buffer_size = 1024

    if not image_paths:
        raise ValueError(
            f'No images found in directory {directory}. '
            f'Allowed formats: {ALLOWLIST_FORMATS}'
        )

    dataset = paths_and_labels_to_dataset(
        image_paths=image_paths,
        image_size=image_size,
        num_channels=num_channels,
        labels=labels,
        interpolation=interpolation,
        crop_to_aspect_ratio=crop_to_aspect_ratio,
        pad_to_aspect_ratio=pad_to_aspect_ratio,
        data_format=data_format,
        shuffle=shuffle,
        shuffle_buffer_size=shuffle_buffer_size,
        seed=seed,
    )

    if batch_size is not None:
        dataset = dataset.batch(batch_size)

    dataset = dataset.prefetch(tf.data.AUTOTUNE)

    # Users may need to reference `class_names`.
    dataset.class_names = class_names

    # Include file paths for images as attribute.
    dataset.file_paths = image_paths

    return dataset, image_paths

def paths_and_labels_to_dataset(
    image_paths,
    image_size,
    num_channels,
    labels,
    interpolation,
    data_format,
    crop_to_aspect_ratio=False,
    pad_to_aspect_ratio=False,
    shuffle=False,
    shuffle_buffer_size=None,
    seed=None,
):
    '''Build a dataset of loaded images paired with their labels.

    The image paths and labels are zipped into one dataset, optionally
    shuffled (buffer size `shuffle_buffer_size`, or 1024 when that is
    falsy), and every path is then replaced by the image that `load_image`
    produces for it.
    '''

    def _load_pair(path, label):
        # Swap the path for the loaded image; the label passes through.
        image = load_image(
            path,
            image_size,
            num_channels,
            interpolation,
            data_format,
            crop_to_aspect_ratio,
            pad_to_aspect_ratio,
        )
        return image, label

    paired = tf.data.Dataset.zip(
        (
            tf.data.Dataset.from_tensor_slices(image_paths),
            labels_to_dataset(labels),
        )
    )

    if shuffle:
        paired = paired.shuffle(
            buffer_size=shuffle_buffer_size or 1024, seed=seed
        )

    return paired.map(_load_pair, num_parallel_calls=tf.data.AUTOTUNE)

def load_image(
    path,
    image_size,
    num_channels,
    interpolation,
    data_format,
    crop_to_aspect_ratio=False,
    pad_to_aspect_ratio=False,
):
    '''Read the image at `path`, decode it and resize it to `image_size`.

    Exactly one resize strategy is applied: `smart_resize` when
    `crop_to_aspect_ratio` is set, `tf.image.resize_with_pad` when
    `pad_to_aspect_ratio` is set, and plain `tf.image.resize` otherwise.
    For 'channels_first' the decoded image is transposed with the
    permutation `(2, 0, 1)`.

    Raises:
        ValueError: If both `pad_to_aspect_ratio` and `crop_to_aspect_ratio`
            are set.
    '''
    raw_bytes = tf.io.read_file(path)
    img = tf.image.decode_image(
        raw_bytes, channels=num_channels, expand_animations=False
    )

    if crop_to_aspect_ratio and pad_to_aspect_ratio:
        raise ValueError(
            'Only one of `pad_to_aspect_ratio`, `crop_to_aspect_ratio`'
            ' can be set to `True`.'
        )

    channels_first = data_format == 'channels_first'

    if crop_to_aspect_ratio:
        from keras.src.backend import tensorflow as tf_backend

        # Here the transpose happens before resizing, because
        # `smart_resize` is told the data format.
        if channels_first:
            img = tf.transpose(img, (2, 0, 1))
        img = image_utils.smart_resize(
            img,
            image_size,
            interpolation=interpolation,
            data_format=data_format,
            backend_module=tf_backend,
        )
    else:
        # Both remaining strategies resize first and transpose afterwards.
        if pad_to_aspect_ratio:
            img = tf.image.resize_with_pad(
                img, image_size[0], image_size[1], method=interpolation
            )
        else:
            img = tf.image.resize(img, image_size, method=interpolation)
        if channels_first:
            img = tf.transpose(img, (2, 0, 1))

    # Pin the static shape of the result.
    if data_format == 'channels_last':
        static_shape = (image_size[0], image_size[1], num_channels)
    else:
        static_shape = (num_channels, image_size[0], image_size[1])
    img.set_shape(static_shape)
    return img