# Load Dataset

# In [2]:
from kfp.components import create_component_from_func
from typing import (
    List,
    NamedTuple
)

# Container image handed to create_component_from_func as `base_image`;
# the generated component runs inside it.
BASE_IMAGE: str = "quay.io/ibm/kubeflow-notebook-image-ppc64le:latest"


def load_dataset(
    path: str,
    blackboard_dir: str,
    configuration: str = "",
    label_column: str = "",
    dataset_dir: str = "/blackboard/dataset",
) -> NamedTuple(
        'LoadDatasetOutput', [
            ('dataset_dir', str),
            ('labels', List[str])
        ]):
    '''
    Load a Huggingface Dataset.

            Parameters:
                    path: Path from which to load the dataset. Huggingfaces hub for datasets is supported. Example: "Lehrig/Monkey-Species-Collection".
                    blackboard_dir: Target directory of all data of a pipeline. Example: "/blackboard".
                    configuration: Name of the dataset configuration to load. Empty string means no explicit configuration. Example: "downsized".
                    label_column: Optional name of a label column to be fetched as optional, additional output. Example: "label".
                    dataset_dir: Target directory where the dataset will be loaded to. Should be available as a mount from a PVC. Example: "/blackboard/dataset".

            Returns:
                    dataset_dir: Target directory where the dataset will be loaded to. Same value as input dataset_dir. Example: "/blackboard/dataset".
                    labels: List of labels, if available (i.e. label_column is set and has label names). Empty list otherwise. Example: ["cat", "dog"]

            Raises:
                    ValueError: If blackboard_dir does not exist (the persistent volume is not mounted).
    '''

    # create_component_from_func packages only this function's source, so
    # every import has to live inside the function body.
    from collections import namedtuple
    import logging
    import os
    import sys

    logging.basicConfig(
        stream=sys.stdout,
        level=logging.INFO,
        format='%(levelname)s %(asctime)s: %(message)s'
    )

    # Fail fast, before importing the heavy third-party libraries below.
    if not os.path.exists(blackboard_dir):
        raise ValueError(
            f"Missing blackboard mount of a persistent volume (PV) into '{blackboard_dir}'."
        )

    from datasets import load_dataset
    from datasets.dataset_dict import DatasetDict
    from PIL.Image import Image

    # An empty string means "no explicit configuration".
    if not configuration:
        configuration = None
    logging.info(f"Loading dataset from '{path}' using configuration '{configuration}'...")
    dataset = load_dataset(path, configuration)

    logging.info("Reading image files into bytes...")

    # For local image files, save_to_disk stores paths rather than contents,
    # so the raw file bytes are embedded explicitly.
    # see: https://huggingface.co/docs/datasets/v2.4.0/en/package_reference/main_classes#datasets.Dataset.save_to_disk
    def read_image_file(example):
        for column, value in example.items():
            if not isinstance(value, Image):
                continue
            # Images decoded from in-memory bytes have an empty filename (and
            # images derived via PIL operations may lack the attribute). They
            # are left untouched: the Image feature then encodes their bytes
            # itself instead of storing a path.
            filename = getattr(value, "filename", "")
            if filename:
                with open(filename, "rb") as f:
                    # Replacing an existing key does not resize the dict,
                    # so this is safe while iterating.
                    example[column] = {"bytes": f.read()}
        return example

    # note: batching in map caused caching issues, so not using it for now
    dataset = dataset.map(read_image_file)

    logging.info(f"Saving dataset to '{dataset_dir}'...")
    os.makedirs(dataset_dir, exist_ok=True)
    dataset.save_to_disk(dataset_dir)

    logging.info(f"Dataset saved. Contents of '{dataset_dir}':")
    logging.info(os.listdir(dataset_dir))

    labels = []
    if label_column:
        logging.info(f"Fetching labels from column '{label_column}'...")
        # For a DatasetDict, read the features of the first split; this
        # presumes all splits share the same features.
        split = next(iter(dataset.values())) if isinstance(dataset, DatasetDict) else dataset
        feature = split.features[label_column]
        # Only label-type features (e.g. ClassLabel) expose `names`; for any
        # other column type return the documented empty list.
        names = getattr(feature, "names", None)
        if names is None:
            logging.warning(
                f"Column '{label_column}' has no label names; returning an empty label list."
            )
        else:
            labels = list(names)

    output = namedtuple('LoadDatasetOutput', ['dataset_dir', 'labels'])

    logging.info("Finished.")
    return output(dataset_dir, labels)


# Wrap load_dataset as a Kubeflow Pipelines component that runs in
# BASE_IMAGE, also writing the component definition to component.yaml.
load_dataset_comp = create_component_from_func(
    load_dataset,
    base_image=BASE_IMAGE,
    output_component_file='component.yaml',
)