# Data-loading functions for TensorFlow

In [1]:
import numpy as np
import os
import tensorflow as tf

In [18]:
def dataset_classifcation(path, resize_h, resize_w, train=True, limit=None):
    """Build a ``tf.data.Dataset`` of ``(image, label)`` pairs from a folder-per-class tree.

    ``path`` must contain one sub-directory per class; the JPEG files inside
    each sub-directory are the samples of that class. Class labels are the
    indices of the alphabetically sorted sub-directory names.

    Args:
        path: Root directory holding one sub-directory per class.
        resize_h: Target image height in pixels (Python int).
        resize_w: Target image width in pixels (Python int).
        train: Accepted for interface compatibility; currently unused.
        limit: Accepted for interface compatibility; currently unused.

    Returns:
        A ``tf.data.Dataset`` whose elements are ``(image, label)`` where
        ``image`` is a float32 tensor of shape ``(resize_h, resize_w, 3)``
        scaled to ``[0, 1]`` and ``label`` is a float32 scalar holding the
        class index. Element order is not deterministic because the map
        runs in parallel with ``deterministic=False``.
    """
    # Every non-hidden sub-directory is a class. Hidden entries (e.g. .DS_Store)
    # and stray plain files are skipped so os.listdir() below never hits a file.
    class_folders = [
        f for f in sorted(os.listdir(path))
        if not f.startswith('.') and os.path.isdir(os.path.join(path, f))
    ]

    # Collect image paths and their class index side by side.
    images = []
    classes = []
    for i, c in enumerate(class_folders):
        class_dir = os.path.join(path, c)
        for file_name in sorted(os.listdir(class_dir)):
            # Match on the real extension (case-insensitive, .jpg or .jpeg)
            # rather than a substring, and skip hidden files such as macOS
            # "._name.jpg" resource forks, which are not decodable images.
            if file_name.startswith('.'):
                continue
            if not file_name.lower().endswith(('.jpg', '.jpeg')):
                continue
            images.append(os.path.join(class_dir, file_name))
            # the index will be the class label
            classes.append(i)

    # Explicit dtypes keep the dataset well-typed even when no images were
    # found (an empty Python list would otherwise be inferred as float32 and
    # make tf.io.read_file fail).
    dataset = tf.data.Dataset.from_tensor_slices((
        tf.constant(images, dtype=tf.string),
        tf.constant(classes, dtype=tf.int32),
    ))

    @tf.function
    def read_images(image_path, class_type):
        """Load one JPEG from disk and return it resized and normalised with its label."""
        image = tf.io.read_file(image_path)
        # Force three channels so grayscale / CMYK files still match the
        # declared (h, w, 3) shape.
        image = tf.image.decode_jpeg(image, channels=3)

        # The decoded height/width are unknown while the function is traced,
        # so a Python-level size check can never skip the resize; always
        # resize instead. Nearest-neighbour at the same size leaves the
        # pixels unchanged.
        image = tf.image.resize(
            image, [resize_h, resize_w], method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
        # Static shape follows the requested size, not a fixed 224x224.
        image.set_shape((resize_h, resize_w, 3))

        # change DType of image to float32
        image = tf.cast(image, tf.float32)
        class_type = tf.cast(class_type, tf.float32)

        # normalise the image pixels to [0, 1]
        image = image / 255.0

        return image, class_type

    dataset = dataset.map(
        read_images,
        num_parallel_calls=tf.data.experimental.AUTOTUNE,
        deterministic=False)

    return dataset

In [19]:
# Location of the training images (one sub-folder per class). The default is
# the original author's local checkout; set ROOM_DETECTION_TRAINING_DATA to
# point the notebook at a different copy without editing this cell.
path = os.environ.get(
    'ROOM_DETECTION_TRAINING_DATA',
    '/Users/georgebrockman/code/georgebrockman/Autoenhance.ai/RoomDetection/images/training_data/')
# Build the training dataset with images resized to 224x224.
train_ds = dataset_classifcation(path, 224, 224)
train_ds

<ParallelMapDataset shapes: ((224, 224, 3), ()), types: (tf.float32, tf.float32)>