# In [None]:
# Load the images with tf.data for finer control over decoding, padding, and batching.
import os 
import tensorflow as tf 
from tensorflow.data import Dataset
from tensorflow.io import decode_png, read_file
from tensorflow.image import pad_to_bounding_box, transpose, convert_image_dtype
import matplotlib.pyplot as plt

# Root of the image dataset (keep the trailing '/'). Expected layout:
#   <img_dir>/training/<class>/<image file>
#   <img_dir>/validation/<class>/<image file>
# where each <class> folder name is one of the entries in class_names.
img_dir = "/path/to/data/data_directory/"

def read_filenames(root_dir):
    """Return the paths of all files under *root_dir*, recursively.

    The traversal is deterministic. Subdirectories are visited in sorted
    order, and files within each directory are listed in sorted order. The
    returned list, and any seeded shuffle built on it, is therefore the same
    on every machine. macOS ``.DS_Store`` metadata files are skipped.

    Args:
        root_dir: Directory to walk. A missing directory yields an empty
            list, because ``os.walk`` reports nothing for it.

    Returns:
        List of file paths, each built with ``os.path.join`` from the
        directory being walked.
    """
    files = []
    for dirpath, dirnames, filenames in os.walk(root_dir):
        # Sorting dirnames in place controls the order os.walk descends into
        # subdirectories (top-down walk).
        dirnames.sort()
        # Filenames come back in filesystem order, which varies by platform.
        # Sort them so the file list is reproducible.
        files.extend(
            os.path.join(dirpath, name)
            for name in sorted(filenames)
            if 'DS_Store' not in name
        )
    return files
    
# os.path.join works whether or not img_dir ends with a separator. Plain
# string concatenation would silently point at a nonexistent directory.
files = read_filenames(os.path.join(img_dir, 'training'))
test_files = read_filenames(os.path.join(img_dir, 'validation'))
num_examples = len(files)
num_test = len(test_files)
print(f'Found {num_examples} training examples.')
print(f'Found {num_test} testing examples.')
if not files or not test_files:
    # Fail early with a readable message. An empty list would otherwise fail
    # later inside the tf.data pipeline with a much less helpful error.
    raise FileNotFoundError(
        f"No images found under {img_dir!r}; expected 'training' and "
        "'validation' subdirectories containing one folder per class."
    )

tf_dataset = Dataset.from_tensor_slices(files)
test = Dataset.from_tensor_slices(test_files)

# Class order defines the one-hot label layout. Each name must exactly match
# a class subdirectory name under training/ and validation/.
class_names = ['F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O']
channel_num = 1  # Channels to decode from each PNG: 1 = grayscale, 3 = RGB.
target_height = 573  # Height (pixels) every image is zero-padded to.
target_width = 493  # Width (pixels) every image is zero-padded to.

def get_label(file_path):
    """Return the one-hot int32 label for an image path.

    The class is taken from the name of the file's parent directory. It is
    compared element-wise against ``class_names``, so the result has one
    entry per class. A directory name matching no class yields all zeros.
    """
    class_dir = tf.strings.split(file_path, os.path.sep)[-2]
    is_class = class_dir == class_names
    return tf.cast(is_class, tf.int32)

def decode_img(img):
    """Decode PNG bytes and zero-pad the image to a fixed size.

    Args:
        img: Scalar string tensor holding the raw PNG file contents.

    Returns:
        uint8 tensor of shape (target_height, target_width, channel_num).
        The decoded pixels sit in the top-left corner and the rest is zeros.
    """
    pixels = decode_png(img, channels=channel_num, dtype=tf.uint8)
    # transpose swaps the height and width axes before padding. The original
    # author added it to work around a suspected TensorFlow issue.
    # NOTE(review): confirm the on-disk image orientation. If the transpose is
    # ever removed, target_height/target_width may need to be swapped as well.
    pixels = transpose(pixels)
    # pad_to_bounding_box can only pad, not crop. It raises if the image is
    # larger than the target size.
    return pad_to_bounding_box(pixels, 0, 0, target_height, target_width)

def process_path(file_path):
    """Turn one image file path into an ``(image, label)`` pair.

    The label comes from ``get_label``, which uses the parent directory name.
    The image is read from disk, then decoded and padded by ``decode_img``.
    """
    label = get_label(file_path)
    raw_bytes = read_file(file_path)
    return decode_img(raw_bytes), label

def map_fn(image, label):
    """Normalize one pair: uint8 pixels become float32 in [0, 1]; label is unchanged."""
    return convert_image_dtype(image, tf.float32), label


batch_size = 32
# Number of (shuffled) training-directory images used for training. The rest
# become the validation set. If num_examples <= num_train, validation is empty.
num_train = 3000

# Shuffle the *file paths* once, in a fixed order. With the default
# reshuffle_each_iteration=True, the permutation would change on every pass
# over the data. The take/skip split would then select different images each
# epoch, and validation images would leak into training. Shuffling paths
# rather than decoded images also keeps the shuffle buffer small: strings,
# not full-size images.
shuffled_files = tf_dataset.shuffle(
    buffer_size=num_examples, seed=42, reshuffle_each_iteration=False)

# Full dataset of normalized (image, label) pairs in that fixed order.
tf_dataset = shuffled_files.map(process_path).map(map_fn)

# Training order is still reshuffled every epoch, but only among training files.
train = (shuffled_files.take(num_train)
         .shuffle(buffer_size=num_train, seed=42)
         .map(process_path)
         .map(map_fn)
         .batch(batch_size))
valid = tf_dataset.skip(num_train).batch(batch_size)
test = (test.shuffle(buffer_size=num_test, seed=42)
        .map(process_path)
        .map(map_fn)
        .batch(batch_size))

# Visualize some example images from the first training batch.
# The labels are discarded; only the image array is plotted.
sample_images, _ = next(train.take(1).as_numpy_iterator())
# A batch can hold fewer than 10 images (e.g. a small dataset), so don't
# index past its end.
num_shown = min(10, len(sample_images))
plt.figure(1, figsize=(14, 3))
for idx in range(num_shown):
    plt.subplot(1, 10, idx + 1)
    plt.imshow(sample_images[idx], cmap='gray')
    plt.xticks([])
    plt.yticks([])
plt.show()  # Needed when run as a script; harmless in a notebook.