# Deep learning

In [2]:
import torch

from tqdm.auto import tqdm
import numpy as np


# Dataset root; subdirectories of it (e.g. "amazon") are loaded with ImageDataset.
DATA_DIR: str = "data/office"

In [33]:
import sklearn.preprocessing
import imageio.v2 as imageio

import os
from typing import Callable


class ImageDataset(torch.utils.data.Dataset):
    """
    Lazily loads images from a root directory.

    Directory is assumed to be of shape "<root>/<class_name>/<instance_file>".
    Only sub-directories of the root are treated as classes, and only regular
    files inside them are treated as samples. Stray entries such as
    ".DS_Store" or nested folders therefore do not break indexing. Classes and
    files are visited in sorted order so that sample indices are reproducible
    regardless of the filesystem's directory-listing order.

    Args:
        data_dir: Root directory of the dataset.
        parser_func: Callable mapping a file path to image data that
            ``torch.tensor`` accepts; defaults to ``imageio.imread``.

    Attributes:
        parser_func: The callable used to read each image.
        label_encoder: ``LabelEncoder`` fitted on the class directory names.
        samples: List of ``(image_path, encoded_label)`` tuples.
    """

    def __init__(self, data_dir: str, parser_func: Callable = imageio.imread):
        self.parser_func = parser_func
        self.label_encoder = sklearn.preprocessing.LabelEncoder()
        self.samples = self._load_dataset_paths(data_dir)

    def __len__(self):
        """Return the number of indexed (image path, label) samples."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image_tensor, label)``.

        ``parser_func`` is invoked on every call; nothing is cached.
        """
        image_path, label = self.samples[idx]
        image = self.parser_func(image_path)
        image_tensor = torch.tensor(image)

        return image_tensor, label

    def _load_dataset_paths(self, data_dir):
        """Index ``data_dir`` and fit ``self.label_encoder`` on its class names.

        Args:
            data_dir: Root directory laid out as "<root>/<class_name>/<file>".

        Returns:
            List of ``(file_path, encoded_label)`` tuples, ordered by sorted
            class name and then by sorted file name.
        """
        # os.listdir order is filesystem-dependent: sort for reproducible
        # indices, and skip root entries that are not class directories.
        class_names = sorted(
            name
            for name in os.listdir(data_dir)
            if os.path.isdir(os.path.join(data_dir, name))
        )
        self.label_encoder.fit(class_names)

        samples = []
        for class_name in tqdm(class_names):
            class_data_dir = os.path.join(data_dir, class_name)
            # Encode the label once per class rather than once per file.
            label = self.label_encoder.transform([class_name])[0]

            for file_name in sorted(os.listdir(class_data_dir)):
                file_path = os.path.join(class_data_dir, file_name)
                # Nested directories or other non-files cannot be parsed as images.
                if os.path.isfile(file_path):
                    samples.append((file_path, label))

        return samples


# Index the "amazon" subset of the dataset root.
dataset = ImageDataset(data_dir=os.path.join(DATA_DIR, "amazon"))

  0%|          | 0/31 [00:00<?, ?it/s]

In [31]:
# Inspect the first sample to check the shape of the image tensor.
(sample_image, label) = dataset[0]
print(sample_image.size())

torch.Size([300, 300, 3])


In [34]:
# Total number of indexed samples.
num_samples = len(dataset)
print(num_samples)

2817
