In [None]:
#|default_exp dataset

In [None]:
# Switch the working directory to the parent folder (presumably the project
# root, so that project-local imports resolve — confirm).
# NOTE(review): the path is relative, so re-running this cell without a kernel
# restart moves up one more level each time.
%cd ..
# Reload edited modules automatically before each cell is executed.
%load_ext autoreload
%autoreload 2

In [None]:
import logging
from functools import wraps
from typing import Any, Callable

import datasets
import matplotlib as mpl
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms.functional as TF
from datasets import load_dataset, load_dataset_builder

from tensorviewer import tv, opts
from tensorviewer.config import set_notebook

In [None]:
# Suppress log records of severity WARNING and below for the whole process.
logging.disable(logging.WARNING)
# Compact tensor printing: two decimals, wide lines, no scientific notation.
torch.set_printoptions(precision=2, linewidth=140, sci_mode=False)
# Seed PyTorch's RNG so runs from a fresh kernel are repeatable.
torch.manual_seed(1)
# Default matplotlib image colormap: reversed grayscale.
mpl.rcParams["image.cmap"] = "gray_r"
# tensorviewer notebook setup (project helper; its behaviour is not shown here).
set_notebook()

## Fashion MNIST

In [None]:
# Hub identifier of the dataset used throughout the notebook.
name = "fashion_mnist"
# The builder exposes dataset metadata (used below for the description).
builder = load_dataset_builder(name)

In [None]:
print(builder.info.description)

In [None]:
# NOTE(review): `ignore_verifications` is deprecated in newer `datasets`
# releases in favour of `verification_mode` — confirm against the installed
# version before upgrading.
fashion = load_dataset(name, ignore_verifications=True)

In [None]:
# Peek at one raw training example.
fashion["train"][0]

In [None]:
# Feature names in declaration order. Presumably ("image", "label"): later
# cells index with those literal strings, so verify that they match.
X_KEY, Y_KEY = list(fashion["train"].features)

In [None]:
def inplace(func: Callable) -> Callable:
    """Decorator: make a mutating function return the object it mutated.

    Args:
        func: A one-argument callable that modifies its argument in place.
            Whatever it returns is discarded.

    Returns:
        A wrapper that calls ``func(obj)`` and then returns ``obj`` itself,
        so the decorated function can be used where a "transform and return"
        callable is expected. The wrapper keeps ``func``'s name and docstring.
    """
    @wraps(func)  # preserve __name__/__doc__ of the decorated function
    def _inner(obj: Any) -> Any:
        func(obj)  # executed for its side effect only
        return obj
    return _inner

In [None]:
@inplace
def transform(batch: dict):
    """Convert every entry of the `X_KEY` column of `batch` with `TF.to_tensor`.

    The column is overwritten in the same dict; the `inplace` decorator makes
    the call return that dict.
    """
    raw_images = batch[X_KEY]
    batch[X_KEY] = [TF.to_tensor(image) for image in raw_images]

In [None]:
# NOTE(review): not referenced by any other cell in the notebook as shown; the
# training section defines its own `bs`.
BATCH_SIZE = 256

In [None]:
# Attach `transform` so it is applied when examples are accessed.
tds = fashion.with_transform(transform)

In [None]:
# Preview the first ten transformed training images.
tv(torch.stack(tds["train"][:10]["image"]).squeeze(), axes_visible=False)

In [None]:
from torch.utils.data import DataLoader

In [None]:
# First item from a DataLoader with all-default settings.
# `x` is rebound by a later cell.
x = next(iter(DataLoader(tds["train"])))

In [None]:
from operator import itemgetter
from typing import Mapping
from torch.utils.data import default_collate

# Prefer the second GPU (the originally hard-coded choice), but fall back to
# the first GPU or the CPU so the notebook still runs on machines where
# "cuda:1" does not exist.
DEFAULT_DEVICE = (
    "cuda:1" if torch.cuda.device_count() > 1
    else "cuda:0" if torch.cuda.is_available()
    else "cpu"
)

# Class names of the "label" feature; `get_labels` indexes this by label id.
LABELS = fashion["train"].features["label"].names


class CollateDict:
    """Callable collate function: batch selected dict fields, then move them.

    Wraps the function produced by `collate_dict(keys)` and passes every
    collated batch through `to_device` with the configured device.
    """

    def __init__(self, keys: list[str], device: str = "cpu"):
        """Store the target device and build the per-key collate function."""
        self.device = device
        self.fn = collate_dict(keys)

    def __call__(self, batch: list[dict]):
        """Collate `batch` and return the result moved to `self.device`."""
        collated = self.fn(batch)
        return to_device(collated, self.device)

def collate_dict(keys: list[str]):
    """Build a collate function that batches the given keys of dict samples.

    Args:
        keys: Field names to pull out of every sample, in output order.

    Returns:
        A function mapping a list of dict samples to a tuple with one entry
        per key, each entry being ``default_collate`` applied to that key's
        values across the batch. An empty batch yields an empty tuple.
    """
    keys = tuple(keys)  # snapshot, so later mutation of the caller's list is harmless

    def _collate(batch: list[dict]):
        if not batch:
            return ()
        # Index every key explicitly. `itemgetter(*keys)` returns a bare value
        # (not a 1-tuple) when there is a single key, which made a zip-based
        # transpose iterate over the value itself instead of over the fields.
        return tuple(
            default_collate(tuple(sample[key] for sample in batch))
            for key in keys
        )
    return _collate

def to_device(x, device: str):
    """Move a tensor, or every tensor inside a (nested) container, to `device`.

    Args:
        x: An object with a ``.to(device)`` method (e.g. a tensor), a mapping,
            or an iterable container (list, tuple, namedtuple, set, ...) whose
            items are themselves accepted by this function.
        device: Target device, passed unchanged to ``.to``.

    Returns:
        For a mapping, a plain ``dict`` with the same keys; for any other
        container, a new container of the same type; otherwise ``x.to(device)``.

    Raises:
        TypeError: if a leaf is neither movable nor a container.
    """
    if isinstance(x, Mapping):
        return {k: to_device(v, device) for k, v in x.items()}
    # Leaf case: a bare tensor must not be iterated and rebuilt row by row.
    if hasattr(x, "to"):
        return x.to(device)
    # Strings are iterable but never hold tensors; recursing would never end.
    if isinstance(x, (str, bytes)) or not hasattr(x, "__iter__"):
        raise TypeError(f"to_device cannot move object of type {type(x).__name__}")
    moved = (to_device(item, device) for item in x)
    # namedtuples take their fields as positional arguments, not one iterable.
    return type(x)(*moved) if hasattr(x, "_fields") else type(x)(moved)

def get_dls(datasets: dict, batch_size: int, **kwargs):
    """Create one DataLoader per entry of `datasets`, all with the same settings.

    Args:
        datasets: Mapping from split name to dataset. (Inside this function the
            parameter shadows the imported `datasets` module.)
        batch_size: Batch size handed to every DataLoader.
        **kwargs: Extra keyword arguments forwarded to every DataLoader.

    Returns:
        A dict with the same keys as `datasets`, holding the DataLoaders.
    """
    loaders = {}
    for split, ds in datasets.items():
        loaders[split] = DataLoader(ds, batch_size, **kwargs)
    return loaders

def get_labels(y):
    """Map an iterable of integer class ids to a tuple of label names.

    Always returns a tuple: one name per element of `y`, looked up in `LABELS`.
    (`itemgetter(*y)` would return a bare string for a single id and raise for
    an empty `y`.)
    """
    return tuple(LABELS[i] for i in y)

In [None]:
# Small batches of 16; the collate function returns (image, label) batches
# already moved to DEFAULT_DEVICE.
dls = get_dls(tds, 16, collate_fn=CollateDict(["image", "label"], DEFAULT_DEVICE))

In [None]:
# First training batch: images `x` and labels `y`.
x, y = next(iter(dls["train"]))

In [None]:
# Show the batch with class names as titles, using a temporary figure size/dpi.
with plt.rc_context({"figure.figsize": (7, 7), "figure.dpi": 70}):
    tv(x.cpu().squeeze(), axes_titles=get_labels(y), axes_visible=False)

In [None]:
# Single image used for the hand-written convolution below.
img = x[0].cpu()

In [None]:
# 3x3 kernel: -1 in the first column, +1 in the second, 0 in the third.
left_edge = torch.tensor([
    [-1., 1., 0.],
    [-1., 1., 0.],
    [-1., 1., 0.],
])

In [None]:
# Convolution as a matrix product: the kernel is flattened to 9 values and
# multiplied with the 3x3 patches that `F.unfold` lays out as columns.
result = left_edge.view(-1) @ F.unfold(img, (3, 3))

In [None]:
# 26 = 28 - 3 + 1 positions per side — assumes `img` is 1x28x28; TODO confirm.
tv(result.view(26, 26))

## Classifier

In [None]:
def conv(ni: int, nf: int, ks: int = 3, stride: int = 2, relu: bool = True):
    """Build a 2-D convolution layer, optionally followed by a ReLU.

    Args:
        ni: Number of input channels.
        nf: Number of output channels.
        ks: Kernel size; padding is set to ``ks // 2``.
        stride: Convolution stride.
        relu: When true, return ``nn.Sequential(conv, nn.ReLU())``;
            otherwise return the bare ``nn.Conv2d``.
    """
    layer = nn.Conv2d(ni, nf, kernel_size=ks, stride=stride, padding=ks // 2)
    return nn.Sequential(layer, nn.ReLU()) if relu else layer

In [None]:
# Classifier: five stride-2 convolutions, each halving the spatial size
# (28 -> 14 -> 7 -> 4 -> 2 -> 1, assuming 28x28 inputs — TODO confirm).
# The last conv has 10 output channels and no ReLU; Flatten then gives one
# row of 10 values per image.
net = nn.Sequential(
    conv(1, 4),
    conv(4, 8),
    conv(8, 16),
    conv(16, 16),
    conv(16, 10, relu=False),
    nn.Flatten()
).to(DEFAULT_DEVICE)

In [None]:
import torch.optim as optim
from FastAI2022p2.core import fit

In [None]:
# Batch size and learning rate for the classifier run.
bs = 256
lr = 0.4

In [None]:
# Rebuild the loaders with the training batch size (rebinds `dls`).
# num_workers=0 keeps loading in the main process — presumably because the
# collate function moves tensors to DEFAULT_DEVICE; confirm.
dls = get_dls(tds, bs, collate_fn=CollateDict(["image", "label"], device=DEFAULT_DEVICE), num_workers=0)

In [None]:
# `fit` here is the function imported from FastAI2022p2.core. A later cell
# defines a notebook-local `fit`, so this call relies on top-to-bottom execution.
fit(5, net, F.cross_entropy, optim.SGD(net.parameters(), lr=lr), dls["train"], dls["test"])

## Autoencoder

In [None]:
def deconv(ni: int, nf: int, ks: int = 3, relu: bool = True):
    """Build an upsampling block: 2x nearest upsample, then a stride-1 conv.

    Args:
        ni: Number of input channels of the convolution.
        nf: Number of output channels of the convolution.
        ks: Kernel size; padding is set to ``ks // 2``.
        relu: When true, a ReLU is appended after the convolution.

    Returns:
        An ``nn.Sequential`` with two layers (or three, with the ReLU).
    """
    upsample = nn.UpsamplingNearest2d(scale_factor=2)
    convolution = nn.Conv2d(ni, nf, stride=1, kernel_size=ks, padding=ks // 2)
    if not relu:
        return nn.Sequential(upsample, convolution)
    return nn.Sequential(upsample, convolution, nn.ReLU())

In [None]:
def validate(model, loss_fn, data_loader, epoch=0):
    """Print the mean reconstruction loss of `model` over `data_loader`.

    The model is switched to eval mode and evaluated without gradients. Each
    batch's input is also its target (``loss_fn(model(xb), xb)``); the second
    element of every batch is ignored. Batch losses are weighted by batch
    length so the printed value is a per-sample mean.

    Args:
        model: Module to evaluate; left in eval mode afterwards.
        loss_fn: Callable ``(prediction, target) -> scalar tensor``.
        data_loader: Iterable yielding ``(xb, _)`` pairs.
        epoch: Label printed in front of the loss.

    Prints ``"<epoch> <loss>"`` with three decimals; the loss is ``nan`` when
    `data_loader` yields no samples (instead of raising ZeroDivisionError).
    """
    model.eval()
    total, count = 0.0, 0
    with torch.no_grad():
        for xb, _ in data_loader:
            n = len(xb)
            total += loss_fn(model(xb), xb).item() * n
            count += n
    mean_loss = total / count if count else float("nan")
    print(epoch, f"{mean_loss:.3f}")

In [None]:
def fit(epochs, model, loss_func, opt, train_dl, valid_dl):
    """Train `model` to reproduce its own inputs, validating after every epoch.

    NOTE: this definition shadows the `fit` imported from FastAI2022p2.core
    earlier in the notebook.

    Args:
        epochs: Number of passes over `train_dl`.
        model: Module being trained; put in train mode at the start of each epoch.
        loss_func: Callable ``(prediction, target) -> scalar tensor``; the
            target passed is the input batch itself.
        opt: Optimizer stepped and zeroed after every batch.
        train_dl: Iterable of ``(xb, _)`` pairs; the second element is ignored.
        valid_dl: Loader handed to `validate` at the end of each epoch.
    """
    for epoch_idx in range(epochs):
        model.train()
        for inputs, _ in train_dl:
            reconstruction = model(inputs)
            batch_loss = loss_func(reconstruction, inputs)
            batch_loss.backward()
            opt.step()
            opt.zero_grad()
        validate(model, loss_func, valid_dl, epoch_idx)

In [None]:
# Autoencoder. Spatial sizes below assume 28x28 inputs — TODO confirm:
#   ZeroPad2d(2)   pads to 32x32 so two stride-2 halvings divide evenly
#   conv, conv     32 -> 16 -> 8 (encoder)
#   deconv, deconv 8 -> 16 -> 32 (decoder; the last one has no ReLU)
#   ZeroPad2d(-2)  negative padding crops back to 28x28
#   Sigmoid        squashes outputs into (0, 1)
auto_encoder = nn.Sequential(
    nn.ZeroPad2d(2),
    conv(1, 2),
    conv(2, 4),
    deconv(4, 2),
    deconv(2, 1, relu=False),
    nn.ZeroPad2d(-2),
    nn.Sigmoid()
).to(DEFAULT_DEVICE)

In [None]:
# Reconstruction loss of the untrained autoencoder, as a baseline.
validate(auto_encoder, F.mse_loss, dls["test"])

In [None]:
# First training run: 5 epochs of SGD with lr=0.01.
opt = optim.SGD(auto_encoder.parameters(), lr=0.01)
fit(5, auto_encoder, F.mse_loss, opt, dls["train"], dls["test"])

In [None]:
# One training batch used for the visual comparisons below.
xb, _ = next(iter(dls["train"]))

In [None]:
pred = auto_encoder(xb)

In [None]:
# Reconstructions ...
tv(pred.squeeze(), axes_visible=False)

In [None]:
# ... versus the original inputs.
tv(xb.squeeze(), axes_visible=False)

In [None]:
# Continue training the same model for 5 more epochs with a fresh optimizer
# and a 10x larger learning rate.
opt = optim.SGD(auto_encoder.parameters(), lr=0.1)
fit(5, auto_encoder, F.mse_loss, opt, dls["train"], dls["test"])

In [None]:
# Reconstructions of the same batch after the second run.
tv(auto_encoder(xb).squeeze(), axes_visible=False)