In [None]:
#|default_exp conv

In [None]:
#|export
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import default_collate
from typing import Mapping

from fastai_course.training import *
from fastai_course.datasets import *

In [None]:
import pickle,gzip,math,os,time,shutil,torch,matplotlib as mpl, numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from pathlib import Path
from torch import tensor

from torch.utils.data import DataLoader

In [None]:
mpl.rcParams['image.cmap'] = 'gray'

In [None]:
path_gz = Path('data') / 'mnist.pkl.gz'
path_gz

In [None]:
# Load the gzipped pickle holding the (train, valid, test) splits, each an (inputs, labels) pair.
# NOTE(review): pickle.load can execute arbitrary code while deserialising —
# only point path_gz at a file obtained from a trusted source.
with gzip.open(path_gz, 'rb') as f:
    ((x_train, y_train), (x_valid, y_valid), (x_test, y_test)) = pickle.load(f, encoding='latin-1')

In [None]:
x_train.shape

In [None]:
# Turn the train/valid splits into torch tensors (the test split is left as loaded).
x_train, y_train, x_valid, y_valid = (
    tensor(arr) for arr in (x_train, y_train, x_valid, y_valid)
)

In [None]:
x_train.shape

In [None]:
# View each flat row as a single-channel 28x28 image: (N, 1, 28, 28).
img_shape = (-1, 1, 28, 28)
x_imgs, xv_imgs = x_train.view(*img_shape), x_valid.view(*img_shape)

In [None]:
type(x_train)

In [None]:
mpl.rcParams['figure.dpi'] = 30

In [None]:
# x_imgs is (N, 1, 28, 28); index the channel axis as well so im3 is a plain
# 2-D 28x28 image. The 2-D slices used below (e.g. im3[3:6, 14:17]) rely on that.
im3 = x_imgs[7, 0]
show_image(im3)

In [None]:
top_edge = tensor([[-1, -1, -1],
                   [0, 0, 0],
                   [1, 1, 1]]).float()

In [None]:
show_image(top_edge, noFrames=False);

In [None]:
df = pd.DataFrame(im3[:13, :23])
df.style.format(precision=2).set_properties(**{'font-size': '7pt'}).background_gradient('Greys')

In [None]:
(im3[3:6,14:17] * top_edge).sum()

In [None]:
show_image(im3[7:10,14:17])

In [None]:
(im3[7:10,14:17] * top_edge).sum()

In [None]:
def apply_kernel(row, col, kernel, img=None):
    """Apply a 3x3 `kernel` to the 3x3 window of an image centred on (`row`, `col`).

    Args:
        row, col: centre of the window; the window spans index-1 .. index+1 on
            both axes, so the centre must be at least one pixel from the border.
        kernel: tensor multiplied element-wise with the window.
        img: 2-D image tensor to read from. Defaults to the notebook-level `im3`
            so existing three-argument calls keep working.

    Returns:
        0-dim tensor holding the sum of the element-wise product.
    """
    if img is None: img = im3
    window = img[row - 1: row + 2, col - 1: col + 2]
    return (window * kernel).sum()

In [None]:
apply_kernel(4, 15, top_edge)

In [None]:
[[(i, j) for j in range(5)] for i in range(5)]

In [None]:
# apply_kernel reads a window from index-1 to index+1, so centring on 1..26
# keeps every window inside indices 0..27; the resulting map is 26x26.
rng = range(1, 27)
top_edge_3 = tensor([[apply_kernel(i, j, top_edge) for j in rng] for i in rng])
show_image(top_edge_3);

In [None]:
show_image(im3)

In [None]:
# Left-edge detector: every row is [-1, 0, 1], so the response is positive where
# intensity increases from left to right and zero on flat regions.
left_edge = tensor([[-1, 0, 1],
                    [-1, 0, 1],
                    [-1, 0, 1]]).float()

In [None]:
# Same sliding-window pass as before, this time with the left_edge kernel.
# NOTE(review): the result is still stored under the name `top_edge_3` even though
# it now holds the left-edge response; a later cell reads `top_edge_3.shape`.
rng = range(1, 27)
top_edge_3 = tensor([[apply_kernel(i, j, left_edge) for j in rng] for i in rng])
show_image(top_edge_3);

In [None]:
top_edge_3.shape

In [None]:
im3.shape

In [None]:
diag1_edge = tensor([[ 0,-1, 1],
                     [-1, 1, 0],
                     [ 1, 0, 0]]).float()

In [None]:
# Diagonal detector for the top-left -> bottom-right direction: +1 on the main
# diagonal, -1 just above it, 0 elsewhere (no stray weight in the bottom-left corner).
diag2_edge = tensor([[ 1,-1, 0],
                     [ 0, 1,-1],
                     [ 0, 0, 1]]).float()

In [None]:
# x_imgs already has the channel axis (N, 1, 28, 28), so the first 16 images are
# a ready-made (16, 1, 28, 28) batch. Adding another axis with [:, None] would
# produce a 5-D tensor, which conv2d does not accept.
xb = x_imgs[:16]
xb.shape

In [None]:
# Stack the four 3x3 kernels into (4, 3, 3), then insert an axis at position 1
# to get (4, 1, 3, 3) — the weight tensor passed to F.conv2d in the next cell.
edge_kernels = torch.stack([left_edge, top_edge, diag1_edge, diag2_edge])[:, None]
edge_kernels.shape

In [None]:
batch_features = F.conv2d(xb, edge_kernels)
batch_features.shape

In [None]:
img0 = xb[1, 0]
show_image(img0);

In [None]:
show_images([batch_features[1, i] for i in range(4)])

In [None]:
# n, m: the two dimensions of the 2-D x_train (m is used below as the Linear input width).
n, m = x_train.shape
# Largest label + 1; note this is a 0-dim tensor, not a Python int.
# NOTE(review): presumably the number of classes, assuming labels run 0..max — confirm.
c = y_train.max() + 1
# Width of the hidden layer in the MLP below.
nh = 50
n, m, nh

In [None]:
model = nn.Sequential(
    nn.Linear(m, nh),
    nn.ReLU(),
    nn.Linear(nh, 10)
)

In [None]:
#|export
def conv(ni, nf, ks=3, stride=2, act=True):
    """Create a 2-D convolution, optionally wrapped with a trailing ReLU.

    Args:
        ni: number of input channels.
        nf: number of output channels (filters).
        ks: kernel size; padding is set to `ks // 2`.
        stride: convolution stride.
        act: when true, return `nn.Sequential(conv, nn.ReLU())`; otherwise
            return the bare `nn.Conv2d`.
    """
    layer = nn.Conv2d(ni, nf, kernel_size=ks, stride=stride, padding=ks // 2)
    return nn.Sequential(layer, nn.ReLU()) if act else layer

In [None]:
xb.shape

In [None]:
cnn_1 = conv(1, 4)
cnn_1(xb).shape

In [None]:
# Each stride-2 conv halves the spatial size (rounding up); for 28x28 inputs:
# 28 -> 14 -> 7 -> 4 -> 2, and the final un-activated conv brings it to 1x1
# with 10 channels, which Flatten turns into one row of 10 values per image.
hidden_channels = [(1, 4), (4, 8), (8, 16), (16, 16)]
simple_cnn = nn.Sequential(
    *(conv(ni, nf) for ni, nf in hidden_channels),
    conv(16, 10, act=False),
    nn.Flatten(),
)

In [None]:
simple_cnn(xb).shape

In [None]:
train_ds, valid_ds = Dataset(x_imgs, y_train), Dataset(xv_imgs, y_valid)

In [None]:
#|export
# Default device: prefer Apple MPS, then CUDA, otherwise fall back to the CPU.
def_device = "mps" if torch.backends.mps.is_available() else 'cuda' if torch.cuda.is_available() else 'cpu'

def to_device(x, device=def_device):
    """Recursively move every tensor found in `x` to `device`.

    Args:
        x: a tensor, a mapping, a (named) tuple / list / other iterable container,
            or a leaf value.
        device: target device; defaults to `def_device`.

    Returns:
        Tensors are returned via `.to(device)`. Mappings come back as a plain
        `dict` with each value processed recursively. Other containers are
        rebuilt with their own type. Strings, bytes and non-iterable values are
        returned unchanged.
    """
    if isinstance(x, torch.Tensor): return x.to(device)
    if isinstance(x, Mapping): return {k: to_device(v, device) for k, v in x.items()}
    # Leaves: iterating a str would recurse forever; numbers/None are not iterable.
    if isinstance(x, (str, bytes)) or not hasattr(x, '__iter__'): return x
    moved = (to_device(o, device) for o in x)
    # namedtuple constructors take positional fields, not a single iterable.
    if isinstance(x, tuple) and hasattr(x, '_fields'): return type(x)(*moved)
    return type(x)(moved)

def collate_device(b):
    """Collate a list of samples with `default_collate` and move the batch to `def_device`."""
    return to_device(default_collate(b))

In [None]:
def_device

In [None]:
default_collate(
    [torch.randn(2, 3),
    torch.randn(2, 3)]
).shape

In [None]:
from torch import optim
# Batch size and SGD learning rate for the first training round.
bs = 256
lr = 0.4
# collate_device collates each batch and moves it to def_device.
# NOTE(review): get_dls comes from the project's star imports — presumably it returns
# (train, valid) DataLoaders and forwards collate_fn to them; confirm.
train_dl, valid_dl = get_dls(train_ds, valid_ds, bs, collate_fn=collate_device)
opt = optim.SGD(simple_cnn.parameters(), lr=lr)

In [None]:
for images in train_dl:
    print(images[0].shape)
    break

In [None]:
loss, acc = fit(5, simple_cnn.to(def_device), F.cross_entropy, opt, train_dl, valid_dl)

In [None]:
# Second round on the same model with a fresh SGD optimizer at a quarter of the learning rate.
# NOTE(review): the first argument to fit is presumably the epoch count — confirm against its definition.
opt = optim.SGD(simple_cnn.parameters(), lr=lr/4)
loss, acc = fit(5, simple_cnn.to(def_device), F.cross_entropy, opt, train_dl, valid_dl)

In [None]:
import nbdev; nbdev.nbdev_export()