In [None]:
import pickle,gzip,math,os,time,shutil,torch,matplotlib as mpl,numpy as np,matplotlib.pyplot as plt
import fastcore.all as fc
from collections.abc import Mapping
from pathlib import Path
from operator import attrgetter,itemgetter
from functools import partial

from torch import tensor,nn,optim
from torch.utils.data import DataLoader,default_collate
import torch.nn.functional as F
import torchvision.transforms.functional as TF
from datasets import load_dataset,load_dataset_builder

from fastprogress import progress_bar,master_bar
from fastai_course.datasets import *
from fastai_course.training import *
from fastai_course.conv import *

In [None]:
from fastcore.test import test_close

# Compact tensor printing: 2 decimals, wide lines, no scientific notation.
torch.set_printoptions(precision=2, linewidth=140, sci_mode=False)
# Fix the seed of torch's global random number generator.
torch.manual_seed(1)

# Render images with the grayscale colormap by default.
mpl.rcParams['image.cmap'] = 'gray'

import logging
# Silence log records at WARNING level and below.
logging.disable(logging.WARNING)

In [None]:
# Column names used throughout the notebook: `x` is the image key, `y` the label key.
x,y = 'image','label'
name = 'fashion_mnist'
# `ignore_verifications=True` was deprecated in `datasets` 2.9.1 and is rejected by
# 3.x releases; `verification_mode='no_checks'` is the documented replacement and
# likewise skips the verification of the downloaded files.
dsd = load_dataset(name, verification_mode='no_checks')

In [None]:
dsd['train'][0]

In [None]:
show_image(dsd['train'][0]['image'], figsize=(1,1))

In [None]:
@inplace
def transformi(b):
    """Replace the images stored under key `x` of batch `b` with `TF.to_tensor` results."""
    # The function itself returns nothing and only mutates `b`.
    # NOTE(review): `inplace` presumably hands the mutated batch back to the caller — confirm.
    b[x] = list(map(TF.to_tensor, b[x]))

In [None]:
bs = 256  # batch size passed to `data_loaders` further down
# Attach `transformi` so it is applied on the fly whenever items are accessed.
tds = dsd.with_transform(transformi)

In [None]:
show_image(tds['train'][0]['image'], figsize=(1,1));

In [None]:
ds = tds['train']

In [None]:
cf = collate_dict(ds)
cf

In [None]:
# Getter that extracts one value per key of `ds.features` from a mapping.
get = itemgetter(*ds.features)
get

In [None]:
# Collates the *entire* `ds` split in one go (not a mini-batch), then pulls out the feature columns.
a = get(default_collate(ds))

In [None]:
xb, yb = a

In [None]:
xb.shape

In [None]:
def collate_(b):
    """Collate the samples in `b` with `cf`, then pass the result through `to_device`."""
    batch = cf(b)
    return to_device(batch)

def data_loaders(tds, bs, **kwargs):
    """Build one `DataLoader` per split.

    Every `(name, dataset)` pair from `tds.items()` becomes
    `name -> DataLoader(dataset, bs, **kwargs)`; `kwargs` are forwarded unchanged
    to each loader.
    """
    loaders = {}
    for split, dataset in tds.items():
        loaders[split] = DataLoader(dataset, bs, **kwargs)
    return loaders

In [None]:
tds.items()

In [None]:
dls = data_loaders(tds, bs, collate_fn=collate_)

In [None]:
dt = dls['train']  # training DataLoader
dv = dls['test']   # DataLoader over the 'test' split; passed as the validation loader below
# Grab a single batch to inspect.
xb, yb = next(iter(dt))
xb.shape, yb.shape

In [None]:
labels = ds.features[y].names
labels

In [None]:
# Getter keyed by the first 16 entries of `yb`; applied to `labels` it picks those positions.
lbl_getter = itemgetter(*yb[:16])
lbl_getter

In [None]:
titles = lbl_getter(labels)
titles

In [None]:
yb[:16]

In [None]:
mpl.rcParams['figure.dpi'] = 70  # global matplotlib setting; applies to every figure created afterwards
# Show the first 16 images of the batch with `titles` as their captions.
show_images(xb[:16], imsize=1.7, titles=titles)

In [None]:
from torch import optim  # NOTE(review): duplicate — `optim` is already imported in the first cell
bs = 256  # NOTE(review): same value as the earlier `bs`; the DataLoaders were already built with it
lr = 0.4  # learning rate for the SGD optimizer of the classifier below

In [None]:
# Small CNN trained with cross-entropy below; the per-layer comments give the spatial size after that layer.
# NOTE(review): those sizes assume `conv` halves height/width — confirm in fastai_course.conv.
cnn = nn.Sequential(
    conv(1, 4),  # 14 * 14
    conv(4, 8),  # 7 * 7
    conv(8, 16), # 4 * 4
    conv(16, 16), # 2 * 2
    conv(16, 10, act=False),  # final conv: 10 output channels, `act=False`
    nn.Flatten()
).to(def_device)

In [None]:
opt = optim.SGD(cnn.parameters(), lr=lr)
# NOTE(review): this `fit` is the one brought in by the star imports at the top (its result is
# unpacked as a `(loss, acc)` pair). The AutoEncoder section later redefines `fit` without a
# return value, so this cell only works when run before that redefinition.
loss, acc = fit(5, cnn, F.cross_entropy, opt, dt, dv)

In [None]:
type(dsd['train'][0]['image'])

In [None]:
type(tds['train'][0]['image'])

In [None]:
torch.equal(TF.to_tensor(dsd['train'][0]['image']), tds['train'][0]['image'])

### AutoEncoder

In [None]:
def deconv(ni, nf, ks=3, act=True):
    """Build an upsampling block: 2x nearest-neighbour upsample, then a stride-1 conv.

    Args:
        ni: number of input channels.
        nf: number of output channels.
        ks: kernel size of the convolution; padding is `ks//2`, so for odd `ks`
            the convolution keeps the upsampled height/width.
        act: when true, append an `nn.ReLU` after the convolution.

    Returns:
        An `nn.Sequential` of the upsample, the convolution and, optionally, the ReLU.
    """
    layers = [
        # `nn.UpsamplingNearest2d` is documented as deprecated; `nn.Upsample` in
        # 'nearest' mode is the equivalent module.
        nn.Upsample(scale_factor=2, mode='nearest'),
        nn.Conv2d(ni, nf, stride=1, kernel_size=ks, padding=ks//2),
    ]
    if act: layers.append(nn.ReLU())
    return nn.Sequential(*layers)

In [None]:
xb.shape

In [None]:
# Encoder half on its own, run on the sample batch to check the output shape.
conv_layer = nn.Sequential(
    nn.ZeroPad2d(2),  # pad 2 pixels of zeros on every side
    conv(1, 2),
    conv(2, 4),
).to(def_device)
encoded_out = conv_layer(xb)
encoded_out.shape

In [None]:
# Decoder half: two `deconv` blocks, a negative pad that crops, then a sigmoid.
deconv_layer = nn.Sequential(
    deconv(4, 2),  # (2, 16, 16)
    deconv(2, 1, act=False), # (1, 32, 32)
    nn.ZeroPad2d(-2),  # negative padding crops 2 pixels per side: (1, 28, 28)
    nn.Sigmoid()
).to(def_device)

In [None]:
deconv_layer(encoded_out).shape

In [None]:
def eval(model, loss_func, valid_dl, epoch=0):
    """Print the mean reconstruction loss of `model` over `valid_dl`.

    Each batch's input is also its target (`loss_func(pred, xb)`); the second element
    of every batch is ignored. The per-batch loss is weighted by the batch size before
    averaging over all samples.

    NOTE(review): the name shadows the builtin `eval`; kept because other cells call it.

    Args:
        model: module to evaluate; it is switched to eval mode and left in it.
        loss_func: callable `(pred, target)` returning a scalar tensor.
        valid_dl: iterable of `(xb, _)` batches.
        epoch: label printed in front of the loss.
    """
    model.eval()
    total_loss, count = 0., 0
    with torch.no_grad():
        for xb, _ in valid_dl:
            n = len(xb)
            pred = model(xb)
            total_loss += loss_func(pred, xb).item() * n
            count += n

    # An empty loader has no mean; print `nan` instead of raising ZeroDivisionError.
    mean_loss = total_loss / count if count else float('nan')
    print(epoch, f'{mean_loss:.3f}')

In [None]:
def fit(epochs, model, loss_func, opt, train_dl, valid_dl):
    """Train `model` to reproduce its own input for `epochs` passes over `train_dl`.

    The second element of each batch is ignored; the batch input doubles as the
    target. After every epoch the loss on `valid_dl` is reported through `eval`.
    """
    for epoch in range(epochs):
        model.train()
        for inputs, _ in train_dl:
            # Clear stale gradients before computing this batch's gradients.
            opt.zero_grad()
            recon_loss = loss_func(model(inputs), inputs)
            recon_loss.backward()
            opt.step()
        eval(model, loss_func, valid_dl, epoch)

In [None]:
# Full autoencoder: zero-pad, two `conv` blocks, two `deconv` blocks, crop, sigmoid.
ae = nn.Sequential(
    nn.ZeroPad2d(2),
    conv(1, 2),
    conv(2, 4), # (4, 8, 8)
    # Disabled: extra encoder layer that would narrow the bottleneck further.
    # conv(4, 8),  # (8, 4, 4)
    
    # Disabled: matching decoder layer for the encoder layer above.
    # deconv(8, 4), # (4, 8, 8)
    deconv(4, 2),  # (2, 16, 16)
    deconv(2, 1, act=False), # (1, 32, 32)
    nn.ZeroPad2d(-2),  # negative padding crops 2 pixels per side: (1, 28, 28)
    nn.Sigmoid()
).to(def_device)

In [None]:
# Baseline: loss of the freshly created (untrained) autoencoder on `dv`.
eval(ae, F.mse_loss, dv)

In [None]:
# First training run: 5 epochs of SGD at learning rate 0.01.
opt = optim.SGD(ae.parameters(), lr=0.01)
fit(5, ae, F.mse_loss, opt, dt, dv)

In [None]:
# Continue training the same model: 15 more epochs with a fresh optimizer at a 10x larger learning rate.
opt = optim.SGD(ae.parameters(), lr=0.1)
fit(15, ae, F.mse_loss, opt, dt, dv)

In [None]:
# Reconstruct the sample batch `xb`. This is inference only, so skip autograd tracking
# rather than stripping it afterwards with the legacy `.data` attribute.
with torch.no_grad():
    p = ae(xb)
show_images(p[:16].cpu(), imsize=1.5)

In [None]:
# The original inputs, for comparison with the reconstructions `p` shown above.
show_images(xb[:16], imsize=1.5)

In [None]:
p[:16].shape