### Setup

In [None]:
import pickle,gzip,math,os,time,shutil,torch,matplotlib as mpl,numpy as np,matplotlib.pyplot as plt
import fastcore.all as fc
from collections.abc import Mapping
from pathlib import Path
from operator import attrgetter,itemgetter
from functools import partial

from torch import tensor,nn,optim
from torch.utils.data import DataLoader,default_collate
import torch.nn.functional as F
import torchvision.transforms.functional as TF
from datasets import load_dataset,load_dataset_builder

from fastprogress import progress_bar,master_bar
from miniai.datasets import *
from miniai.training import *
from miniai.conv import *

In [None]:
from fastcore.test import test_close

import logging

# Silence library log output at WARNING level and below.
logging.disable(logging.WARNING)

# Reproducible runs and compact tensor printing.
torch.manual_seed(1)
torch.set_printoptions(precision=2, linewidth=140, sci_mode=False)

# Default colormap for single-channel images.
mpl.rcParams['image.cmap'] = 'gray'

### Data

In [None]:
# Column names used throughout: `x` is the input image, `y` the target label.
x = 'image'
y = 'label'

# Hugging Face dataset identifier.
name = 'fashion_mnist'
dsd = load_dataset(name)
dsd

In [None]:
@inplace
def transformi(b):
    """Convert every image stored under key `x` of batch `b` with `TF.to_tensor`.

    Mutates `b` in place; `inplace` (from miniai) presumably returns the
    mutated batch so this can be used as a dataset transform — confirm.
    """
    b[x] = list(map(TF.to_tensor, b[x]))

tds = dsd.with_transform(transformi)

In [None]:
# Training split of the transformed dataset.
ds = tds['train']
# First training sample's image; presumably already a tensor because `tds`
# was built with `with_transform(transformi)` — confirm.
img = ds[0]['image']
show_image(img)

In [None]:
cf = collate_dict(ds)

def collate_(b):
    """Collate the samples in `b` with `cf`, then move the batch via `to_device`."""
    # NOTE(review): moving tensors to the device inside collate_fn likely
    # breaks if DataLoader num_workers > 0 is ever used — confirm before adding workers.
    batch = cf(b)
    return to_device(batch)

bs = 256

def data_loaders(dsd, bs, **kwargs):
    """Build one DataLoader per split of `dsd`.

    Args:
        dsd: mapping of split name -> dataset.
        bs: batch size passed positionally to every DataLoader.
        **kwargs: forwarded unchanged to each DataLoader (e.g. collate_fn).

    Returns:
        dict mapping each split name to its DataLoader, in `dsd`'s key order.
    """
    loaders = {}
    for split, dataset in dsd.items():
        loaders[split] = DataLoader(dataset, bs, **kwargs)
    return loaders

dls = data_loaders(tds, bs=bs, collate_fn=collate_)  # one loader per split, batches moved by collate_

In [None]:
# Train / validation loaders, plus one training batch to inspect below.
dt, dv = dls['train'], dls['test']

xb, yb = next(iter(dt))

In [None]:
# Class names of the label feature `y`; presumably indexed by the integer label value.
labels = ds.features[y].names
labels

In [None]:
# Map the first 16 targets of the batch to their class names (returns a tuple).
lbl_getter = itemgetter(*yb[:16].tolist())
titles = lbl_getter(labels)

In [None]:
# Lower figure DPI for the image grid below.
mpl.rcParams['figure.dpi'] = 70
show_images(xb[:16], titles=titles, imsize=1.7)

### Autoencoder

In [None]:
def deconv(ni, nf, ks=3, act=True):
    """Upsampling block: 2x nearest-neighbour upsample, then a stride-1 conv.

    Args:
        ni: input channels.
        nf: output channels.
        ks: conv kernel size; padding is ks//2, so odd ks keeps the upsampled size.
        act: append a ReLU after the conv when True.

    Returns:
        nn.Sequential of [upsample, conv] or [upsample, conv, ReLU].
    """
    upsample = nn.UpsamplingNearest2d(scale_factor=2)
    conv_layer = nn.Conv2d(ni, nf, stride=1, kernel_size=ks, padding=ks//2)
    activation = [nn.ReLU()] if act else []
    return nn.Sequential(upsample, conv_layer, *activation)

In [None]:
def _eval_loss(model, loss_func, valid_dl):
    """Return the sample-weighted mean reconstruction loss of `model` over `valid_dl`.

    The target for each batch is the input itself (autoencoder). Each batch's
    loss is weighted by ``len(xb)`` so a short final batch does not skew the
    mean; this assumes `loss_func` returns a per-batch mean (true for
    F.mse_loss's default reduction). Returns ``nan`` if `valid_dl` is empty.
    """
    model.eval()
    tot_loss, count = 0., 0
    with torch.no_grad():
        for xb, _ in valid_dl:
            n = len(xb)
            tot_loss += loss_func(model(xb), xb).item() * n
            count += n
    return tot_loss / count if count else float('nan')


def fit(epochs, model, loss_func, opt, train_dl, valid_dl):
    """Train `model` as an autoencoder (target = input) for `epochs` epochs.

    Args:
        epochs: number of passes over `train_dl`.
        model: module mapping a batch to its reconstruction.
        loss_func: called as ``loss_func(pred, xb)``.
        opt: optimizer over `model`'s parameters.
        train_dl, valid_dl: iterables of ``(xb, _)`` pairs; labels are ignored.

    After each epoch prints ``"<epoch> <valid_loss>"``, where epoch is 0-based
    and valid_loss is the mean over the whole validation set (3 decimals).
    """
    for epoch in range(epochs):
        model.train()
        for xb, _ in train_dl:
            loss = loss_func(model(xb), xb)
            loss.backward()
            opt.step()
            opt.zero_grad()
        # Previously this printed the total epoch count and only the last
        # validation batch's loss; report the current epoch and the full mean.
        print(f'{epoch} {_eval_loss(model, loss_func, valid_dl):.3f}')

In [None]:
# Convolutional autoencoder. ZeroPad2d(2) pads 2 px on each side and
# ZeroPad2d(-2) crops them back off; each deconv upsamples H and W by 2.
# NOTE(review): assumes miniai's conv() downsamples by 2 (stride 2) so the
# encoder and decoder sizes match — confirm in miniai.conv.
ae = nn.Sequential(
    nn.ZeroPad2d(2),
    # Encoder: channels 1 -> 2 -> 4 -> 8.
    *[conv(ni, nf) for ni, nf in [(1, 2), (2, 4), (4, 8)]],
    # Decoder: channels 8 -> 4 -> 2 -> 1; no ReLU on the last layer.
    *[deconv(ni, nf) for ni, nf in [(8, 4), (4, 2)]],
    deconv(2, 1, act=False),
    nn.ZeroPad2d(-2),
    nn.Sigmoid(),  # squash reconstructions into (0, 1)
).to(def_device)

In [None]:
# `optim` is already imported in the setup cell; the re-import here was redundant.
# Plain SGD on the autoencoder; fit() uses each batch's input as its target.
opt = optim.SGD(ae.parameters(), lr=0.01)
fit(5, ae, F.mse_loss, opt, dt, dv)

In [None]:
# Inference only: skip building an autograd graph, so `.data` is no longer needed.
with torch.no_grad():
    pred = ae(xb)
show_images(pred[:16].cpu(), imsize=1.5)

In [None]:
show_images(xb[:16].detach().cpu(), imsize=1.5)  # original inputs, for comparison with the reconstructions