In [1]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms

In [2]:
from fastai.vision import *

In [3]:
# MNIST training split (60,000 samples). The files are fetched into ../data
# when missing, and ToTensor converts every image to a tensor on access.
train_ds = torchvision.datasets.MNIST(
    '../data', train=True, download=True, transform=transforms.ToTensor()
)
train_ds

Dataset MNIST
    Number of datapoints: 60000
    Root location: ../data
    Split: Train

In [4]:
# MNIST test split, using the same root directory and transform as the
# training split. download=True keeps this cell runnable on its own: the
# files are fetched only when they are missing from ../data, so it no longer
# depends on the training cell having downloaded them first.
test_ds = torchvision.datasets.MNIST(
    root='../data',
    train=False,
    transform=transforms.ToTensor(),
    download=True
)
test_ds

Dataset MNIST
    Number of datapoints: 10000
    Root location: ../data
    Split: Test

In [5]:
# Mini-batches of 32 samples: the training loader shuffles, the test loader
# keeps the dataset order.
train_dl = torch.utils.data.DataLoader(train_ds, batch_size=32, shuffle=True)
test_dl = torch.utils.data.DataLoader(test_ds, batch_size=32, shuffle=False)

In [6]:
# Number of scalar values in one sample tensor: train_ds[0] is the first
# sample and its element [0] is the image tensor. Not referenced by any later
# cell in the notebook as shown.
input_size = train_ds[0][0].numel()

In [7]:
# Shape of a single image tensor (presumably channels, height, width).
train_ds[0][0].shape

torch.Size([1, 28, 28])

In [8]:
def conv(ni, nf):
    """Return a 3x3 ``nn.Conv2d`` with stride 2 and padding 1.

    Args:
        ni: number of input channels.
        nf: number of output channels (filters).
    """
    layer = nn.Conv2d(
        in_channels=ni,
        out_channels=nf,
        kernel_size=3,
        stride=2,
        padding=1,
    )
    return layer

In [9]:
def conv2(ni, nf):
    """Return a stride-2 fastai ``conv_layer`` from ``ni`` to ``nf`` channels.

    NOTE(review): ``conv_layer`` arrives via ``from fastai.vision import *``;
    what it bundles besides the convolution is not visible in this notebook.
    Confirm against the installed fastai version.
    """
    block = conv_layer(ni, nf, stride=2)
    return block

In [10]:
class Flatten(nn.Module):
    """Collapse every dimension after the first into one.

    The result is a view of the input with shape ``(x.size(0), -1)``.
    """

    def forward(self, x):
        batch = x.shape[0]
        return x.view(batch, -1)

In [11]:
# Plain CNN: four conv -> batch-norm -> ReLU stages, then a 10-channel
# conv + batch-norm head. With 28x28 inputs the five stride-2 convs shrink the
# feature map 28 -> 14 -> 7 -> 4 -> 2 -> 1, so Flatten yields (batch, 10).
# NOTE: the next cell rebinds `model`, so this variant is not the one trained.
model = nn.Sequential(
    *[layer
      for ni, nf in [(1, 8), (8, 16), (16, 32), (32, 16)]
      for layer in (conv(ni, nf), nn.BatchNorm2d(nf), nn.ReLU())],
    conv(16, 10),
    nn.BatchNorm2d(10),
    Flatten(),
)

In [12]:
# Residual variant, replacing the plain CNN bound to `model` in the previous
# cell: four (stride-2 conv2, res_block) stages followed by a 10-channel
# stride-2 conv2 head and Flatten. This is the network trained below.
model = nn.Sequential(
    *[layer
      for ni, nf in [(1, 8), (8, 16), (16, 32), (32, 16)]
      for layer in (conv2(ni, nf), res_block(nf))],
    conv2(16, 10),
    Flatten(),
)

In [13]:
# Pull one batch from the training loader. Despite the name, `images` holds
# the whole batch (the training loop unpacks loader items as images, labels);
# index 0 is the image tensor.
images = next(iter(train_dl))

In [14]:
# Smoke test: run one forward pass on the batch's image tensor to check that
# the layers fit together. The trailing `;` suppresses the cell output.
model(images[0]);

In [15]:
# Adam over all parameters of `model` with learning rate 3e-2, paired with a
# cross-entropy loss on the model's outputs.
optimizer = torch.optim.Adam(params=model.parameters(), lr=3e-2)
criterion = nn.CrossEntropyLoss()

In [17]:
from tqdm import tqdm_notebook as tqdm

In [18]:
def one_epoch_man(model, train_dl, test_dl, loss_fn=None, opt=None, device=None):
    """Train ``model`` for one pass over ``train_dl``, then score it on ``test_dl``.

    Prints the training loss of every 500th batch (starting with the first)
    and finally one line of the form ``"<correct> / <total>"`` for the test
    set. Returns ``None``.

    Args:
        model: network to train; it is moved to ``device``.
        train_dl: iterable of ``(images, labels)`` training batches.
        test_dl: iterable of ``(images, labels)`` evaluation batches.
        loss_fn: callable ``loss_fn(preds, labels)`` returning the loss.
            Defaults to the notebook-level ``criterion``.
        opt: optimizer that updates ``model``'s parameters. Defaults to the
            notebook-level ``optimizer``.
        device: device to run on. Defaults to CUDA when available, else CPU.
    """
    if loss_fn is None:
        loss_fn = criterion
    if opt is None:
        opt = optimizer
    if device is None:
        # Fall back to the CPU instead of crashing on machines without a GPU.
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    model = model.to(device)

    # ---- training pass ----
    model.train()
    for i, (images, labels) in enumerate(tqdm(train_dl)):
        images, labels = images.to(device), labels.to(device)
        opt.zero_grad()

        preds = model(images)
        loss = loss_fn(preds, labels)

        loss.backward()
        opt.step()
        if i % 500 == 0:
            print(f'loss: {loss.item()}')

    # ---- evaluation pass ----
    # The counters are initialised here rather than inside the training loop,
    # so they exist even when ``train_dl`` yields no batches.
    correct = 0
    total = 0
    model.eval()
    # No gradients are needed for scoring; skipping them saves memory and time.
    with torch.no_grad():
        for img, y in test_dl:
            img, y = img.to(device), y.to(device)
            preds = model(img).argmax(1)
            total += preds.size(0)
            correct += (preds == y).sum()
    # ``correct`` is a 0-dim tensor once any batch was seen; print a plain int.
    print(f'{int(correct)} / {total}')

In [19]:
# Train for 10 epochs. Each call prints the training loss every 500 batches
# and then the number of correctly classified test samples.
for i in range(10):
    one_epoch_man(model,train_dl,test_dl)

HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))

loss: 2.6801140308380127
loss: 0.11926175653934479
loss: 0.12449122965335846
loss: 0.06954637914896011

9775 / 10000


HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))

loss: 0.018819577991962433
loss: 0.01750928908586502
loss: 0.0361352264881134
loss: 0.07246145606040955

9848 / 10000


HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))

loss: 0.16146010160446167
loss: 0.00641825795173645
loss: 0.23205575346946716
loss: 0.013303600251674652

9832 / 10000


HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))

loss: 0.07015632092952728
loss: 0.04091142863035202
loss: 0.0431848019361496
loss: 0.00877080112695694

9872 / 10000


HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))

loss: 0.04702369496226311
loss: 0.015446983277797699
loss: 0.1592365801334381
loss: 0.03721669316291809

9878 / 10000


HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))

loss: 0.006892001256346703
loss: 0.022715725004673004
loss: 0.005583405494689941
loss: 0.0028140731155872345

9883 / 10000


HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))

loss: 0.0026420727372169495
loss: 0.03084160014986992
loss: 0.09192574769258499
loss: 0.012433307245373726

9867 / 10000


HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))

loss: 0.07179925590753555
loss: 0.0022024735808372498
loss: 0.0019204914569854736
loss: 0.012230418622493744

9893 / 10000


HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))

loss: 0.08792822808027267
loss: 0.021009014919400215
loss: 0.2804306149482727
loss: 0.0027601998299360275

9901 / 10000


HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))

loss: 0.010432697832584381
loss: 0.032034505158662796
loss: 0.003544360399246216
loss: 0.000725671648979187

9891 / 10000
