In [1]:
# |default_exp utils.data
# |export
from typing import Tuple, Iterable, Literal
import torch
from torch.utils import data
from torchvision import datasets, transforms
from functools import cache

## Data

In [2]:
# |export
@cache
def load_iris(*, force_single_precision=False) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Load scikit-learn's Iris dataset as a (features, targets) pair of tensors.

    The result is memoized with ``functools.cache``: every call with the same
    ``force_single_precision`` value returns the very same tensor objects, so
    callers should not modify them in place.

    Args:
        force_single_precision: If True, cast the features with ``Tensor.float()``;
            otherwise keep the dtype produced by ``torch.from_numpy``.
    Returns:
        Tuple[torch.Tensor, torch.Tensor]: The features and the targets (as ``torch.long``).
    """
    # Imported lazily so sklearn is only required when Iris is actually requested.
    from sklearn import datasets as sk_datasets

    bunch = sk_datasets.load_iris()
    features = torch.from_numpy(bunch["data"])
    features = features.float() if force_single_precision else features
    labels = torch.from_numpy(bunch["target"]).to(torch.long)
    return features, labels

In [3]:
# |export
@cache
def get_mnist_datasets(
    cache_path: str, *, normalization: bool
) -> Tuple[data.Dataset, data.Dataset]:
    """
    Build the MNIST train and test datasets, downloading them into ``cache_path`` if needed.

    Results are memoized per ``(cache_path, normalization)`` via ``functools.cache``.

    Args:
        cache_path: Root directory passed to ``torchvision.datasets.MNIST``.
        normalization: If True, apply ``Normalize((0.1307,), (0.3081,))`` after
            ``ToTensor``; otherwise apply ``ToTensor`` only.
    Returns:
        Tuple[data.Dataset, data.Dataset]: The MNIST train and test datasets.
    """
    pipeline = [transforms.ToTensor()]
    if normalization:
        # MNIST mean/std, see https://discuss.pytorch.org/t/normalization-in-the-mnist-example/457
        pipeline.append(transforms.Normalize((0.1307,), (0.3081,)))
    transform = transforms.Compose(pipeline)

    # The generator is consumed in order: the train split is constructed before the test split.
    train_split, test_split = (
        datasets.MNIST(root=cache_path, train=is_train, download=True, transform=transform)
        for is_train in (True, False)
    )
    return train_split, test_split


def _calc_load_num(num: int | None, classes: int | Iterable[int] | None, dataset_len: int) -> int:
    if num is None:
        return dataset_len
    elif classes is None:
        # num not None, classes is None
        return num
    else:
        # num not None, classes is not None
        return dataset_len


def load_mnist_images(
    *,
    cache_path: str,
    num: int | None,
    from_subset: Literal["test", "train", "all"],
    shuffle: bool,
    normalization: bool,
    classes: int | Iterable[int] | None = None,
    return_labels: bool = False,
) -> torch.Tensor | Tuple[torch.Tensor, torch.Tensor]:
    """
    Load MNIST images (and optionally labels) from the cache directory as one batch.

    Args:
        cache_path: The path to the cache directory.
        num: Maximum number of images to return; ``None`` returns every matching image.
            Fewer than ``num`` images are returned if the class filter leaves fewer.
        from_subset: ``"train"``, ``"test"``, or ``"all"`` (train split followed by test split).
        shuffle: Whether to shuffle before selecting images. With ``"all"`` the train
            and test splits are shuffled together.
        normalization: Whether to normalize the images based on the mean and standard
            deviation of the MNIST dataset.
        classes: A single label or an iterable of labels to keep; ``None`` keeps all.
        return_labels: Whether to also return the labels.
    Returns:
        torch.Tensor: The MNIST images. If return_labels is True, a tuple of (images, labels).
    Raises:
        ValueError: If ``from_subset`` is not ``"train"``, ``"test"`` or ``"all"``.
    """
    assert num is None or num > 0
    assert classes is None or isinstance(classes, (int, Iterable)), (
        "classes must be an integer or an iterable of integers"
    )
    # Validate before get_mnist_datasets, which may download data.
    if from_subset not in ("train", "test", "all"):
        raise ValueError(f"Invalid subset: {from_subset}")

    train_set, test_set = get_mnist_datasets(cache_path, normalization=normalization)
    if from_subset == "train":
        source = train_set
    elif from_subset == "test":
        source = test_set
    else:
        # Treat train+test as a single dataset. Loading `num` images from each split
        # and concatenating meant the final `[:num]` slice kept only train images
        # (even with shuffle=True) whenever `num` was given without a class filter.
        source = data.ConcatDataset([train_set, test_set])

    batch_size = _calc_load_num(num, classes, len(source))
    dataloader = data.DataLoader(source, batch_size, shuffle=shuffle)
    # One batch is enough: batch_size covers everything that is needed.
    images, labels = next(iter(dataloader))

    if classes is not None:
        wanted = [classes] if isinstance(classes, int) else list(classes)
        # Use the label dtype so an empty `classes` does not produce a float tensor.
        mask = torch.isin(labels, torch.tensor(wanted, dtype=labels.dtype))
        images = images[mask]
        labels = labels[mask]

    if num is not None:
        images = images[:num]
        labels = labels[:num]

    return (images, labels) if return_labels else images

In [4]:
# |export
@cache
def get_fashion_mnist_datasets(cache_path: str) -> Tuple[data.Dataset, data.Dataset]:
    """
    Build the Fashion MNIST train and test datasets, downloading them into ``cache_path`` if needed.

    Both splits use ``transforms.ToTensor()`` as their only transform. Results are
    memoized per ``cache_path`` via ``functools.cache``.

    Args:
        cache_path: Root directory passed to ``torchvision.datasets.FashionMNIST``.
    Returns:
        Tuple[data.Dataset, data.Dataset]: The Fashion MNIST train and test datasets.
    """

    def make_split(is_train: bool) -> data.Dataset:
        # Each split gets its own ToTensor instance.
        return datasets.FashionMNIST(
            root=cache_path, train=is_train, download=True, transform=transforms.ToTensor()
        )

    # Tuple elements are evaluated left to right: train is built before test.
    return make_split(True), make_split(False)

In [5]:
import os

# Every dataset used by this notebook is stored under ./datasets.
cwd = os.getcwd()
cache_path = os.path.join(cwd, "datasets")

# Fetch both datasets up front; the loaders are memoized, so later calls with the
# same arguments return these same objects.
fmnist_trainset, fmnist_testset = get_fashion_mnist_datasets(cache_path)
mnist_trainset, mnist_testset = get_mnist_datasets(
    cache_path,
    normalization=True,
)

In [6]:
# Smoke-run load_mnist_images on the unshuffled, unnormalized train split for every
# combination of an explicit count (10) vs. all images (None) and a single-class
# filter (label 0) vs. no filter (None).
for num in [10, None]:
    for classes in [0, None]:
        images = load_mnist_images(
            cache_path=cache_path,
            num=num,
            from_subset="train",
            shuffle=False,
            classes=classes,
            normalization=False,
        )