# Data - FashionMNIST

> FashionMNIST DataModule

In [None]:
#| default_exp fashionmnist

In [None]:
#| export
import lightning as L
import torch
import torchvision.transforms.functional as F

from slow_diffusion.data import DiffusionDataModule, noisify, show_images
from slow_diffusion.training import UnetLightning

In [None]:
# |exports
class FashionMNISTDataModule(DiffusionDataModule):
    """Fashion-MNIST datamodule for diffusion training.

    Passes the dataset name ``"fashion_mnist"`` and ``img_size=(32, 32)`` to
    :class:`DiffusionDataModule`, and noises batches with :func:`noisify`
    after converting them to zero-centered floats.
    """

    def __init__(self, bs, n_workers=0):
        """
        Args:
            bs: forwarded to ``DiffusionDataModule`` (presumably the batch size).
            n_workers: forwarded to ``DiffusionDataModule`` (presumably the
                number of data-loading workers); defaults to 0.
        """
        super().__init__(
            "fashion_mnist",
            bs,
            n_workers,
            img_size=(32, 32),
        )

    def noisify_fn(self, x_0):
        """Convert a clean image batch to float, zero-center it and noise it.

        Args:
            x_0: clean image tensor (any dtype accepted by
                ``convert_image_dtype``). It is not modified.

        Returns:
            Whatever ``noisify`` returns for the centered batch.
        """
        # convert_image_dtype returns its input object unchanged when it is
        # already float, so subtract out-of-place: an in-place `-=` would
        # silently mutate the caller's tensor.
        # Zero-center so that the mean does not change after adding noise.
        x_0 = F.convert_image_dtype(x_0, torch.float) - 0.5
        return noisify(x_0)

In [None]:
# Small batches (bs=4) are enough for the preview cells below; call setup()
# before requesting the dataloaders.
dm = FashionMNISTDataModule(bs=4)
dm.setup()

In [None]:
def preview(dataloder, n=4):
    """Show the first ``n`` images ``x_t`` of the first batch of ``dataloder``.

    Each image is titled with its timestep, formatted as ``t=<value:.2f>``.
    The batch is expected to unpack as ``((x_t, ts), _)``.
    """
    first_batch = next(iter(dataloder))
    (images, timesteps), _ignored = first_batch
    titles = [f"t={step.item():.2f}" for step in timesteps[:n]]
    show_images(images[:n], titles)

In [None]:
# A few training samples, titled by timestep (trailing `;` suppresses cell output).
preview(dm.train_dataloader(), n=4);

In [None]:
# A few validation samples, titled by timestep (trailing `;` suppresses cell output).
preview(dm.val_dataloader(), n=4);

In [None]:
# Smoke test: build a single-channel (color_channels=1) U-Net and run a
# fast_dev_run pass over the datamodule created above.
# NOTE(review): per the Lightning docs, fast_dev_run=True limits the run to a
# single batch, so max_epochs=2 is effectively unused here — confirm intent.
unet = UnetLightning(
    color_channels=1,
    nfs=(224, 448, 672, 896),
    n_blocks=(3, 2, 2, 1, 1),
)
trainer = L.Trainer(fast_dev_run=True, max_epochs=2)
trainer.fit(unet, datamodule=dm)

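For debugging only: a tiny datamodule that keeps the first 100 training images and splits them 50/50 into train and test.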

In [None]:
# |exports
class TinyFashionMNISTDataModule(FashionMNISTDataModule):
    """Tiny Fashion-MNIST datamodule, for debugging only.

    Keeps the first ``n_debug_samples`` rows of the ``"train"`` split and
    splits them 50/50 with ``train_test_split(test_size=0.5)``.
    """

    # Number of rows taken from the start of the original "train" split.
    n_debug_samples = 100
    # Seed forwarded to ``train_test_split``. None (the library default) keeps
    # the original non-reproducible split; set an int for a repeatable one.
    debug_split_seed = None

    def post_process(self, ds):
        """Shrink ``ds`` to a small train/test split.

        Args:
            ds: mapping with a ``"train"`` split supporting ``select`` and
                ``train_test_split`` (presumably a Hugging Face
                ``DatasetDict`` — TODO confirm against DiffusionDataModule).

        Returns:
            The result of ``train_test_split(test_size=0.5, ...)`` applied to
            the first ``n_debug_samples`` training rows.
        """
        subset = ds["train"].select(range(self.n_debug_samples))
        return subset.train_test_split(test_size=0.5, seed=self.debug_split_seed)

In [None]:
#| hide
import nbdev

# Regenerate the Python module from this notebook: cells tagged `#| export` /
# `#| exports` are written to the module named by `#| default_exp`
# (here `fashionmnist`).
nbdev.nbdev_export()