# TensorFlow Dataset

[TensorFlow Datasets](https://www.tensorflow.org/datasets/overview)

### tf.data.Dataset

The tf.data.Dataset API is the first and foremost API you should understand when using TensorFlow. When I started using TensorFlow, I found it hard to understand what it is, and I got stuck on modeling and testing. However, if you don't understand tf.data.Dataset, you cannot create your own datasets for modeling or testing.

According to the official documentation, the tf.data.Dataset API lets you do the following three things:

1. Create a source dataset from your input data.
2. Apply dataset transformations to preprocess the data.
3. Iterate over the dataset and process the elements.

### keras.preprocessing

The keras.preprocessing API is used to load and preprocess data.

keras.preprocessing.image is a set of tools for real-time data augmentation on image data. 

[Load using keras.preprocessing](https://www.tensorflow.org/tutorials/load_data/images#load_using_keraspreprocessing)


In [35]:
import tensorflow_datasets as tfds
import tensorflow.compat.v2 as tf
from tensorflow import keras

from IPython.display import display, HTML
import matplotlib.pyplot as plt
from PIL import Image
import pandas as pd
import numpy as np
import requests
import pathlib
import csv
import os


def import_mnist_dataset(log=False):
    """Load the Fashion-MNIST dataset that ships with Keras.

    Note: despite the function name, this loads
    ``keras.datasets.fashion_mnist``, not the handwritten-digit MNIST.

    Arguments
    ---------
    log : bool
        If True, print the type and shape of every returned array (labelled
        per split) and plot the first training image.

    Return
    ------
    (train_images, train_labels), (test_images, test_labels)
        Exactly what ``fashion_mnist.load_data()`` returns.
    """
    fashion_mnist = keras.datasets.fashion_mnist
    (train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()

    if log:
        # Prefix each line with its split so train and test output are
        # distinguishable (they previously printed identical captions).
        for split, images, labels in (
            ("train", train_images, train_labels),
            ("test", test_images, test_labels),
        ):
            print(f"[{split}] type of images  : ", type(images))
            print(f"[{split}] shape of images : ", images.shape)
            print(f"[{split}] type of label   : ", type(labels))
            print(f"[{split}] shape of label  : ", labels.shape)

        print("sample image (train_images[0])")
        plt.imshow(train_images[0])

    return (train_images, train_labels), (test_images, test_labels)


def download_iris_dataset(log=False):
    """Download the Iris training CSV via Keras and return its local file path.

    Arguments
    ---------
    log : bool
        If True, print the local path and display the first rows of the CSV.

    Return
    ------
    str : path returned by ``tf.keras.utils.get_file``.
    """
    url = "https://storage.googleapis.com/download.tensorflow.org/data/iris_training.csv"
    local_path = tf.keras.utils.get_file(fname=os.path.basename(url), origin=url)

    if not log:
        return local_path

    print("Local copy of the dataset file: {}".format(local_path))
    display(pd.read_csv(local_path).head())
    return local_path


def parse_iris_dataset(train_dataset_file_path, log=False, image_display=True, batch_size=32):
    """Build a batched tf.data.Dataset from the Iris training CSV.

    Uses tf.data.experimental.make_csv_dataset()
    https://www.tensorflow.org/api_docs/python/tf/data/experimental/make_csv_dataset

    Arguments
    ---------
    train_dataset_file_path : str
        Path to the Iris CSV, e.g. the value returned by download_iris_dataset().
    log : bool
        If True, print the feature/label names, the dataset type, and the
        first batch's petal lengths and labels.
    image_display : bool
        If True, scatter-plot petal length vs. sepal length of the first batch.
    batch_size : int
        Rows per batch. Defaults to 32, the value that used to be hard-coded.

    Return
    ------
    train_dataset : tf.data.Dataset
        Yields (features, labels) batches where ``features`` is a dict keyed
        by column name. Configured for a single epoch (num_epochs=1).
    """
    # Column order in the CSV file; the last column is the label.
    column_names = ['sepal_length', 'sepal_width', 'petal_length', 'petal_width', 'species']

    feature_names = column_names[:-1]
    label_name = column_names[-1]

    train_dataset = tf.data.experimental.make_csv_dataset(
        train_dataset_file_path,
        batch_size,
        column_names=column_names,
        label_name=label_name,
        num_epochs=1)

    # Peek at the first batch: iterating the dataset yields one
    # (features, labels) pair per batch of `batch_size` rows.
    features, labels = next(iter(train_dataset))

    # Display scatter plot of the first batch.
    if image_display:
        plt.scatter(features['petal_length'],
                    features['sepal_length'],
                    c=labels,
                    cmap='viridis')

        plt.xlabel("Petal length")
        plt.ylabel("Sepal length")
        plt.show()

    if log:
        print("Features: {}".format(feature_names))
        print("Label: {}".format(label_name))

        print("Type  : ", type(train_dataset))

        print("features : ", features["petal_length"])
        print("labels   : ", labels)

    return train_dataset


def execute():
    """Download the Iris CSV and parse it into a dataset with logging enabled."""
    csv_path = download_iris_dataset()
    parse_iris_dataset(csv_path, log=True)


def import_mnist_dataset_info(log=False):
    """Load the shuffled MNIST training split from TFDS along with its info.

    Arguments
    ---------
    log : bool
        If True, print the info object returned by ``tfds.load``.
        Previously this flag was accepted but ignored.

    Return
    ------
    ds, info
        The training split and the object from ``with_info=True``.
    """
    ds, info = tfds.load('mnist', split='train', shuffle_files=True, with_info=True)
    if log:
        print(info)
    return ds, info


def display_dataset(ds):
    """Print the keys, then the image shape and label, of the first element of ``ds``.

    ``ds`` is expected to yield dict elements with "image" and "label" keys.
    """
    for example in ds.take(1):
        keys = list(example.keys())
        print(keys)
        print(example["image"].shape, example["label"])

        
def import_dataset_as_numpy():
    """Print the NumPy types and the label of the first MNIST training example."""
    # as_supervised=True yields (image, label) 2-tuples instead of feature dicts.
    supervised = tfds.load("mnist", split="train", as_supervised=True)
    first = supervised.take(1)
    for img, lbl in tfds.as_numpy(first):
        print(type(img), type(lbl), lbl)


def visualize_dataset():
    """Show a grid of example images from the MNIST training split.

    tfds.visualization.show_examples()
    https://www.tensorflow.org/datasets/api_docs/python/tfds/visualization/show_examples
    """
    mnist_train, mnist_info = tfds.load("mnist", split="train", with_info=True)
    tfds.show_examples(mnist_train, mnist_info)
    

def cats_and_dog_dataset_load(log=False):
    """Load the TFDS 'cats_vs_dogs' dataset split 80/10/10 into train/validation/test.

    Because ``as_supervised=True``, each split yields (image, label) pairs.
    Because ``with_info=True``, ``metadata`` is the TFDS dataset info object.

    Arguments
    ---------
    log : bool
        If True, print the training split and its type.

    Return
    ------
    (raw_train, raw_validation, raw_test), metadata

    Example
    -------
    (raw_train, raw_validation, raw_test), metadata = cats_and_dog_dataset_load(log=True)
    """
    # Slice the single 'train' split into three disjoint ranges.
    (raw_train, raw_validation, raw_test), metadata = tfds.load(
        'cats_vs_dogs',
        split=['train[:80%]', 'train[80%:90%]', 'train[90%:]'],
        with_info=True,
        as_supervised=True,
    )

    if log:
        print("raw_train : ", raw_train)
        print("Type      : ", type(raw_train))

    return (raw_train, raw_validation, raw_test), metadata
    

def download_flower_dataset():
    """Download and extract the TensorFlow flower photos archive.

    Return
    ------
    pathlib.Path wrapping the path returned by ``tf.keras.utils.get_file``.
    """
    archive_url = 'https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz'
    extracted_dir = tf.keras.utils.get_file(
        origin=archive_url,
        fname='flower_photos',
        untar=True,
    )
    return pathlib.Path(extracted_dir)


def load_images_with_image_generator(
    data_dir="../images",
    class_names=("dogs", "cats"),
    batch_size=32,
    img_size=224,
    log=True):
    """Image dataset loader based on ImageDataGenerator.flow_from_directory.

    Arguments
    ---------
    data_dir    : str or pathlib.Path
    class_names : sequence of str
        Subdirectory names of ``data_dir``; passed to flow_from_directory as
        ``classes``. With ("dogs", "cats") the layout must be:
        data_dir/
            /cats
            /dogs
    batch_size  : int
    img_size    : int
        Images are resized to (img_size, img_size).
    log         : bool

    Raises
    ------
    FileNotFoundError
        If a class subdirectory does not exist.

    Return
    ------
    train_data_gen : keras_preprocessing.image.directory_iterator.DirectoryIterator

    Example
    -------
    train_data_gen = load_images_with_image_generator()
    image_batch, label_batch = next(train_data_gen)
    """
    # Fail early with a clear error instead of the bare StopIteration that
    # next(os.walk(...)) used to raise for a missing directory.
    for name in class_names:
        class_path = os.path.join(data_dir, name)
        if log:
            print(class_path)
        if not os.path.isdir(class_path):
            raise FileNotFoundError(f"Class directory not found: {class_path}")

    # rescale=1./255 multiplies pixel values so they fall in [0, 1].
    image_generator = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1./255)

    train_data_gen = image_generator.flow_from_directory(directory=str(data_dir),
                                                         batch_size=batch_size,
                                                         shuffle=True,
                                                         target_size=(img_size, img_size),
                                                         classes=list(class_names))
    if log:
        # Use the iterator's own count. A raw file listing also counts
        # non-image files (e.g. .DS_Store) that the iterator skips.
        print("image_count : ", train_data_gen.samples)
        print("train_data_gen : ", type(train_data_gen))

    return train_data_gen
    

def show_batch(train_data_gen, class_names=("cats", "dogs")):
    """Plot up to 25 images of the next batch in a 5x5 grid, titled by class.

    Arguments
    ---------
    train_data_gen : keras_preprocessing.image.directory_iterator.DirectoryIterator
        Any iterator whose ``next()`` yields (image_batch, label_batch) with
        one-hot label rows.
    class_names : sequence of str
        Class name for each label index.

    Example
    -------
    train_data_gen = load_images_with_image_generator()
    show_batch(train_data_gen)
    """
    image_batch, label_batch = next(train_data_gen)

    # A batch can hold fewer than 25 images (small batch_size or the final
    # partial batch); indexing past its end would raise IndexError.
    n_images = min(25, len(image_batch))

    plt.figure(figsize=(10, 10))

    for n in range(n_images):
        plt.subplot(5, 5, n + 1)
        plt.imshow(image_batch[n])

        # Label rows are one-hot, so argmax gives the class index.
        label_index = int(np.argmax(label_batch[n]))
        plt.title(class_names[label_index])
        plt.axis('off')
        

def check_image_data_gen(train_data_gen):
    """Print the type and shape of the next image/label batch and of its first element."""
    images, labels = next(train_data_gen)
    rows = (
        ("image_batch    : ", images),
        ("label_batch    : ", labels),
        ("image_batch[0] : ", images[0]),
        ("label_batch[0] : ", labels[0]),
    )
    for caption, value in rows:
        print(caption, type(value), value.shape)


In [36]:
if __name__ == "__main__":
    # Load the cats/dogs images and print the shapes of one batch.
    CLASS_NAMES = ["cats", "dogs"]
    image_data_gen = load_images_with_image_generator(class_names=CLASS_NAMES)
    check_image_data_gen(image_data_gen)
    # To preview the batch visually instead:
    # show_batch(image_data_gen, class_names=CLASS_NAMES)

../images/cats
../images/dogs
image_count :  2001
Found 2001 images belonging to 2 classes.
train_data_gen :  <class 'keras_preprocessing.image.directory_iterator.DirectoryIterator'>
image_batch    :  <class 'numpy.ndarray'> (32, 224, 224, 3)
label_batch    :  <class 'numpy.ndarray'> (32, 2)
image_batch[0] :  <class 'numpy.ndarray'> (224, 224, 3)
label_batch[0] :  <class 'numpy.ndarray'> (2,)
