In [1]:
# |default_exp utils.data
# |export
from typing import Callable, Optional, Tuple
import torch
from torch.utils import data
from torchvision import datasets, transforms

## Data

In [2]:
# |export


def load_iris(*, force_single_precision=False) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Load scikit-learn's Iris dataset as a (features, labels) pair of tensors.

    Args:
        force_single_precision: If True, the feature tensor is cast to float32;
            otherwise it keeps the dtype of the array sklearn returns.
    Returns:
        Tuple[torch.Tensor, torch.Tensor]: The feature tensor and the class
        labels converted to int64 (``torch.long``).
    """
    # Imported inside the function so sklearn is only needed when Iris is used.
    # The alias avoids shadowing torchvision's `datasets` within this scope.
    from sklearn import datasets as sk_datasets

    bunch = sk_datasets.load_iris()
    features = torch.from_numpy(bunch["data"])
    labels = torch.from_numpy(bunch["target"]).to(torch.long)
    return (features.float() if force_single_precision else features), labels

In [3]:
# NOTE(review): this cell has no "# |export" directive, so nbdev will not write
# this function to the `utils.data` module, even though the exported import cell
# pulls in torchvision and `data` for it. Confirm whether that is intended.
def get_mnist_datasets(
    cache_path: str,
    *,
    transform: Optional[Callable] = None,
    download: bool = True,
) -> Tuple[data.Dataset, data.Dataset]:
    """
    Get the MNIST train and test datasets.

    Args:
        cache_path: Root directory in which torchvision stores (and looks for)
            the MNIST files.
        transform: Optional transform applied to every image of both splits.
            If None (the default), ``ToTensor()`` followed by
            ``Normalize((0.1307,), (0.3081,))`` is used, which matches the
            previous hard-coded behavior.
        download: Whether torchvision may download the data when it is not
            already present under ``cache_path``. Defaults to True.
    Returns:
        Tuple[data.Dataset, data.Dataset]: The MNIST train and test datasets.
    """
    if transform is None:
        transform = transforms.Compose(
            [
                transforms.ToTensor(),
                # Normalize(mean, std); values taken from the PyTorch MNIST
                # example, see https://discuss.pytorch.org/t/normalization-in-the-mnist-example/457
                transforms.Normalize((0.1307,), (0.3081,)),
            ]
        )

    mnist_train_set = datasets.MNIST(
        root=cache_path, train=True, download=download, transform=transform
    )
    mnist_test_set = datasets.MNIST(
        root=cache_path, train=False, download=download, transform=transform
    )
    return mnist_train_set, mnist_test_set

In [4]:
# NOTE(review): this cell has no "# |export" directive, so nbdev will not write
# this function to the `utils.data` module, even though the exported import cell
# pulls in torchvision and `data` for it. Confirm whether that is intended.
def get_fashion_mnist_datasets(
    cache_path: str,
    *,
    transform: Optional[Callable] = None,
    download: bool = True,
) -> Tuple[data.Dataset, data.Dataset]:
    """
    Get the Fashion MNIST train and test datasets.

    Args:
        cache_path: Root directory in which torchvision stores (and looks for)
            the Fashion MNIST files.
        transform: Optional transform applied to every image of both splits.
            If None (the default), only ``ToTensor()`` is applied (no
            normalization), which matches the previous hard-coded behavior.
        download: Whether torchvision may download the data when it is not
            already present under ``cache_path``. Defaults to True.
    Returns:
        Tuple[data.Dataset, data.Dataset]: The Fashion MNIST train and test datasets.
    """
    if transform is None:
        # ToTensor is stateless, so one instance can serve both splits.
        transform = transforms.ToTensor()

    fmnist_train_set = datasets.FashionMNIST(
        root=cache_path, train=True, download=download, transform=transform
    )
    fmnist_test_set = datasets.FashionMNIST(
        root=cache_path, train=False, download=download, transform=transform
    )
    return fmnist_train_set, fmnist_test_set

In [5]:
import os

# Both torchvision helpers share a single cache root: a "datasets" folder under
# the kernel's current working directory, so the location depends on where the
# kernel was started.
cwd = os.getcwd()
cache_path = os.path.join(cwd, "datasets")

# Each helper returns a (train, test) pair; Fashion MNIST is fetched first.
(fmnist_trainset, fmnist_testset) = get_fashion_mnist_datasets(cache_path)
(mnist_trainset, mnist_testset) = get_mnist_datasets(cache_path)