In [None]:
#|default_exp dataset

In [None]:
%cd ..
%load_ext autoreload
%autoreload 2

In [None]:
# Lesson 14: https://www.youtube.com/watch?v=veqj0DsZSXU
# Lesson 15: https://www.youtube.com/watch?v=0Hi2r4CaHvk

In [None]:
import functools
import logging
from typing import Any, Callable

import datasets
import matplotlib as mpl
import matplotlib.pyplot as plt
import torch
import torchvision.transforms.functional as TF
from datasets import load_dataset, load_dataset_builder

from tensorviewer import tv
from tensorviewer.config import set_notebook

In [None]:
# Suppress log records at WARNING level and below (keeps library chatter out of cell output).
logging.disable(logging.WARNING)
# Fixed seed so torch's random ops are reproducible on Restart & Run All.
torch.manual_seed(1)
# Compact tensor printing: 2 decimals, wide lines, no scientific notation.
torch.set_printoptions(precision=2, linewidth=140, sci_mode=False)
# Reversed grayscale colormap as the default for image plots.
mpl.rcParams.update({"image.cmap": "gray_r"})
# tensorviewer setup for notebook use (project helper; see tensorviewer.config).
set_notebook()

In [None]:
# Hub dataset id; the builder exposes dataset metadata (`builder.info`) before the data is loaded.
name = "fashion_mnist"
builder = load_dataset_builder(name)

In [None]:
# Human-readable dataset description from the builder metadata (print renders its newlines).
print(builder.info.description)

In [None]:
# Load all splits. `ignore_verifications=True` is deprecated in favour of
# `verification_mode` and no longer accepted by datasets 3.x; "no_checks" is the
# equivalent setting (skip checksum / split-size verification).
fashion = load_dataset(name, verification_mode="no_checks")

In [None]:
# Peek at one raw training example (no transform attached yet).
fashion["train"][0]

In [None]:
# Column names taken straight from the `features` mapping (iterating a dict-like
# yields its keys in order). Expected order: image column, then label column —
# matching the "image"/"label" keys used elsewhere in this notebook.
X_KEY, Y_KEY = fashion["train"].features

In [None]:
def inplace(func: Callable) -> Callable:
    """Turn a mutating function into one that returns the object it mutated.

    The wrapper calls ``func(obj, *args, **kwargs)``, discards its return value,
    and returns ``obj``. This lets a function written as "modify the argument in
    place" be used where a transform that returns its input is expected.

    Args:
        func: Callable whose first positional argument is mutated.

    Returns:
        The wrapped callable. ``functools.wraps`` keeps ``func``'s ``__name__``,
        ``__doc__`` and other metadata, so the wrapper shows up correctly in
        reprs and tracebacks.
    """
    @functools.wraps(func)
    def _inner(obj: Any, *args: Any, **kwargs: Any) -> Any:
        func(obj, *args, **kwargs)  # return value intentionally ignored
        return obj
    return _inner

In [None]:
@inplace
def transform(batch: dict):
    """Replace ``batch[X_KEY]`` with a list of tensors made by ``TF.to_tensor``.

    Mutates ``batch``; the ``@inplace`` decorator makes the call return ``batch``.
    """
    images = batch[X_KEY]
    batch[X_KEY] = list(map(TF.to_tensor, images))

In [None]:
BATCH_SIZE = 256  # batch size constant for building the DataLoaders in this notebook

In [None]:
# Attach `transform` to every split. Per the HF datasets docs, `with_transform` applies it
# when rows are accessed and returns a new object, so `fashion` itself still yields raw images.
tds = fashion.with_transform(transform)

In [None]:
# Preview the first 10 transformed training images with tensorviewer's `tv`.
# stack -> (10, C, H, W); squeeze drops the singleton channel dim (assumes grayscale, C == 1).
# Index by X_KEY (the image column name read from the dataset) rather than a hard-coded "image".
tv(torch.stack(tds["train"][:10][X_KEY]).squeeze(), axes_visible=False)

In [None]:
from torch.utils.data import DataLoader

In [None]:
# One batch from a DataLoader with default settings (batch_size=1, default collate),
# to see what structure default collation produces for dict-shaped samples.
x = next(iter(DataLoader(tds["train"])))

In [None]:
# Default collation keeps the mapping structure: expect a dict keyed by column name.
type(x), x.keys()

In [None]:
from operator import itemgetter

In [None]:
# itemgetter with several keys returns a tuple of the values in key order -> (1, 2).
itemgetter("image", "label")({"image": 1, "label": 2})

In [None]:
# Toy batch of dict samples, built from (image, label) pairs.
data = [{"image": img, "label": lbl} for img, lbl in [(1, 0), (2, 0), (4, 1), (3, 0)]]

get = itemgetter("image", "label")

# Transpose: per-sample (image, label) tuples -> one tuple per field.
list(zip(*map(get, data)))

In [None]:
# Stacking four (1, 3) tensors adds a new leading dim -> torch.Size([4, 1, 3]).
torch.stack([torch.tensor([[1, 2, 3]]) for _ in range(4)]).shape

In [None]:
from torch.utils.data import default_collate

In [None]:
from operator import itemgetter
from typing import Mapping

# Default target device for `place_on_device`. Prefer the second GPU when there is one,
# otherwise the first GPU, otherwise CPU, so the notebook runs on any machine
# (a hard-coded "cuda:1" fails wherever fewer than two GPUs are present).
DEFAULT_DEVICE = (
    "cuda:1" if torch.cuda.device_count() > 1
    else "cuda" if torch.cuda.is_available()
    else "cpu"
)

def collate_dict(keys: list[str]):
    """Build a ``collate_fn`` that batches dict samples into a tuple of fields.

    Args:
        keys: Field names to pull from every sample, in output order. Must be non-empty.

    Returns:
        A function mapping ``batch`` (a list of dicts, one per sample) to a tuple with
        one entry per key; each entry is ``default_collate`` applied to that key's
        values across the batch.

    Raises:
        TypeError: if ``keys`` is empty.
        KeyError: (from the returned function) if a sample lacks one of ``keys``.
    """
    # Snapshot the keys so later changes to the caller's list don't affect the collate.
    keys = tuple(keys)
    if not keys:
        raise TypeError("collate_dict requires at least one key")

    def _collate(batch: list[dict]):
        # Always build a tuple per sample. itemgetter(k) with a single key returns the
        # bare value, which would make zip(*...) iterate *inside* each value
        # (e.g. over a tensor's first dimension) instead of over the fields.
        rows = [tuple(sample[k] for k in keys) for sample in batch]
        return tuple(default_collate(column) for column in zip(*rows))

    return _collate

def place_on_device(device: str = DEFAULT_DEVICE):
    """Return a decorator that sends a collate function's output to ``device``.

    ``place_on_device(dev)(collate)`` is a new collate function that calls
    ``collate(batch)`` and passes the result through ``to_device(..., dev)``.
    The default device is bound when this function is defined.
    """
    def decorator(collate: Callable):
        def collate_then_move(batch: tuple):
            collated = collate(batch)
            return to_device(collated, device)
        return collate_then_move
    return decorator

def to_device(x, device: str):
    """Move the tensors in ``x`` to ``device``, rebuilding containers recursively.

    Handling, checked in this order:
      * mappings  -> plain ``dict`` with every value moved
      * objects with a ``.to`` method (e.g. tensors) -> ``x.to(device)``
      * ``str`` / ``bytes`` -> returned unchanged
      * namedtuples -> the same namedtuple type with every field moved
      * other iterables (tuple, list, ...) -> ``type(x)`` rebuilt from moved items

    Args:
        x: A tensor, or a (possibly nested) mapping/sequence containing tensors.
        device: Target device string, e.g. ``"cpu"`` or ``"cuda:0"``.

    Returns:
        A structure like ``x`` with its tensors on ``device``.
    """
    if isinstance(x, Mapping):
        return {k: to_device(v, device) for k, v in x.items()}
    if hasattr(x, "to"):
        return x.to(device)
    if isinstance(x, (str, bytes)):
        # Strings are iterable but hold no tensors; a 1-char str iterates to itself,
        # so recursing into it would never terminate.
        return x
    moved = (to_device(o, device) for o in x)
    if isinstance(x, tuple) and hasattr(x, "_fields"):
        # namedtuple constructors take fields positionally, not a single iterable.
        return type(x)(*moved)
    return type(x)(moved)

def make_dls(datasets: dict, batch_size: int, **kwargs):
    """Wrap every split of ``datasets`` in its own ``DataLoader``.

    Args:
        datasets: Mapping of split name -> dataset (e.g. a ``DatasetDict``).
            Note: inside this function the name shadows the ``datasets`` module.
        batch_size: Batch size used for every split.
        **kwargs: Forwarded unchanged to every ``DataLoader`` (e.g. ``collate_fn``).
            The same options apply to all splits.

    Returns:
        dict mapping each split name to its ``DataLoader``, in the input's key order.
    """
    loaders = {}
    for split, ds in datasets.items():
        loaders[split] = DataLoader(ds, batch_size, **kwargs)
    return loaders

In [None]:
# One DataLoader per split. Each batch is collated into an (X_KEY, Y_KEY) tuple and moved
# to DEFAULT_DEVICE. Use the configured BATCH_SIZE and the dataset's own column names
# instead of repeating a magic number and string literals.
dls = make_dls(tds, BATCH_SIZE, collate_fn=place_on_device()(collate_dict([X_KEY, Y_KEY])))

In [None]:
# First training batch; the collate_fn returns a 2-tuple (one entry per key), so it unpacks into x, y.
x, y = next(iter(dls["train"]))

In [None]:
# Sanity check: batch shapes, and that both tensors landed on the collate_fn's target device.
x.shape, y.shape, x.device, y.device