In [None]:
from typing import Tuple

import matplotlib.pyplot as plt
import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow.keras.applications.efficientnet import preprocess_input
from tensorflow.keras.layers.experimental.preprocessing import (
    CenterCrop,
    RandomContrast,
    RandomRotation,
    RandomTranslation,
    RandomZoom,
)

from config_loader import load_setting

In [None]:
# YAML settings file handed to `load_setting`; being a relative path, it
# resolves against the kernel's current working directory.
config_file = "../configs/CASIA-maxpy-clean_align.yml"

In [None]:
# Load the settings; the result is later unpacked as keyword arguments
# (`**config`) into `get_dataset`, so it must behave like a mapping.
config = load_setting(config_file)

In [None]:
# Shorthand for tf.data.AUTOTUNE, used as `num_parallel_calls` when mapping
# over the dataset in `get_dataset`.
AUTOTUNE = tf.data.AUTOTUNE

@tfds.decode.make_decoder()
def onehot_encoding(example, feature, depth):
    """One-hot encode ``example`` as a ``tf.int32`` tensor.

    Wrapped with ``tfds.decode.make_decoder`` so that it can be supplied
    through the ``decoders`` argument of ``as_dataset`` (see ``get_dataset``,
    which applies it to the ``"label"`` feature).

    Args:
        example: Index (or indices) to encode.
        feature: Unused here; kept because the decoder is invoked with it
            (presumably the TFDS feature connector -- confirm against the
            ``make_decoder`` documentation).
        depth: Size of the one-hot axis.

    Returns:
        The one-hot encoded tensor with dtype ``tf.int32``.
    """
    one_hot_label = tf.one_hot(indices=example, depth=depth, dtype=tf.int32)
    return one_hot_label

def get_dataset(
    root_dir: str,
    split: str,
    input_shape: Tuple[int, int, int],
    n_classes: int,
    batch_size: int,
    seed: int,
    **kwargs
) -> tf.data.Dataset:
    """Build an augmented image dataset with one-hot labels from an image folder.

    The data is read in batches so that preprocessing and augmentation run on
    whole batches, and is then unbatched again, so the returned dataset yields
    single ``(image, one_hot_label)`` pairs.

    Args:
        root_dir: Directory handed to ``tfds.ImageFolder``.
        split: Name of the split to read.
        input_shape: ``(height, width, n_channels)``; only ``height`` and
            ``width`` are used, as the size of the final center crop.
        n_classes: Depth of the one-hot label vectors.
        batch_size: Number of examples augmented together in one batch.
        seed: Seed used both for shuffling and for every random augmentation
            layer.
        **kwargs: Ignored. Accepted so that a settings mapping containing
            additional keys can be unpacked directly into this function.

    Returns:
        An unbatched ``tf.data.Dataset`` of ``(image, one_hot_label)`` pairs.
    """
    read_config = tfds.ReadConfig(shuffle_seed=seed)
    builder = tfds.ImageFolder(root_dir)
    batched_ds = builder.as_dataset(
        split=split,
        batch_size=batch_size,
        shuffle_files=True,
        decoders={"label": onehot_encoding(depth=n_classes)},
        read_config=read_config,
        as_supervised=True,
    )

    # Only the spatial size is needed; the channel count is not used here.
    height, width, _ = input_shape
    data_augmentation = tf.keras.Sequential(
        [
            RandomRotation(factor=0.05, fill_mode="nearest", seed=seed),
            RandomTranslation(
                height_factor=0.1, width_factor=0.1, fill_mode="wrap", seed=seed
            ),
            RandomZoom(height_factor=0.1, fill_mode="reflect", seed=seed),
            RandomContrast(factor=0.3, seed=seed),
            CenterCrop(height=height, width=width),
        ]
    )

    def _preprocess_and_augment(images, labels):
        """Apply model preprocessing, then random augmentation, to one batch."""
        images = preprocess_input(images)
        # The layers are called outside of any model.fit() loop, so request
        # training behaviour explicitly; otherwise whether the random layers
        # are active depends on layer defaults / the global learning phase.
        return data_augmentation(images, training=True), labels

    # A single fused map avoids a second pass over the batches.
    return batched_ds.map(
        _preprocess_and_augment, num_parallel_calls=AUTOTUNE
    ).unbatch()

In [None]:
# Build the dataset from the loaded settings. `config` has to provide every
# required `get_dataset` argument; keys it does not name land in `**kwargs`.
ds = get_dataset(**config)

In [None]:
# Preview grid: n_rows x n_cols samples drawn from the dataset.
n_rows = 5
n_cols = 4
n_samples = n_rows * n_cols

# figsize is (width, height) in inches: one inch per cell means the width
# follows the number of columns and the height the number of rows.
fig, axes = plt.subplots(
    n_rows, n_cols, figsize=(n_cols, n_rows), dpi=300, squeeze=False
)

# Hide every frame up front so cells stay blank if the dataset yields fewer
# than n_samples examples.
for ax in axes.flat:
    ax.axis("off")

for ax, (x, _) in zip(axes.flat, ds.take(n_samples)):
    # imshow expects RGB data as floats in [0, 1] or integers in [0, 255].
    # The pipeline output may be float on a 0..255 scale (and can leave that
    # range after contrast jitter), so clip and convert to uint8 for display.
    image = tf.cast(tf.clip_by_value(x, 0, 255), tf.uint8)
    ax.imshow(image.numpy())

plt.show()