In [1]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms

In [20]:
from fastai.vision import *

In [2]:
train_ds = torchvision.datasets.MNIST(
    root='../data',
    train=True,
    transform=transforms.ToTensor(),
    download=True
)
train_ds

Dataset MNIST
    Number of datapoints: 60000
    Root location: ../data
    Split: Train

In [3]:
test_ds = torchvision.datasets.MNIST(
    root='../data',
    train=False,
    transform=transforms.ToTensor()
)
test_ds

Dataset MNIST
    Number of datapoints: 10000
    Root location: ../data
    Split: Test

In [4]:
train_dl = torch.utils.data.DataLoader(train_ds,32,True)
test_dl = torch.utils.data.DataLoader(test_ds,32,False)

In [5]:
input_size = train_ds[0][0].numel()

In [6]:
train_ds[0][0].size()

torch.Size([1, 28, 28])

In [7]:
def conv(ni,nf): return nn.Conv2d(ni,nf,kernel_size=3,stride=2,padding=1)

In [22]:
def conv2(ni,nf): return conv_layer(ni,nf,stride=2)

In [9]:
class Flatten(nn.Module):
    def __init__(self): super().__init__()
    def forward(self, x): return x.view(x.size(0), -1)

In [11]:
model = nn.Sequential(
    conv(1,8),
    nn.BatchNorm2d(8),
    nn.ReLU(),
    conv(8,16),
    nn.BatchNorm2d(16),
    nn.ReLU(),
    conv(16,32),
    nn.BatchNorm2d(32),
    nn.ReLU(),
    conv(32,16),
    nn.BatchNorm2d(16),
    nn.ReLU(),
    conv(16,10),
    nn.BatchNorm2d(10),
    Flatten()
)

In [33]:
model = nn.Sequential(
    conv2(1,8),
    conv2(8,16),
    conv2(16,32),
    conv2(32,16),
    conv2(16,10),
    Flatten()
)

In [34]:
class ResBlock(nn.Module):
    def __init__(self,nf):
        super().__init__()
        self.conv1 = conv_layer(nf,nf)
        self.conv2 = conv_layer(nf,nf)
    def forward(self,x): return x + self.conv2(self.conv1)

In [35]:
model

Sequential(
  (0): Sequential(
    (0): Conv2d(1, 8, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (1): ReLU(inplace)
    (2): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (1): ResBlock(
    (conv1): Sequential(
      (0): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (1): ReLU(inplace)
      (2): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (conv2): Sequential(
      (0): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (1): ReLU(inplace)
      (2): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (2): Sequential(
    (0): Conv2d(8, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (1): ReLU(inplace)
    (2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (3): ResBlock(
    (conv1): Sequential(
      (0): Conv2d(1

In [36]:
images = next(iter(train_dl))

In [37]:
model(images[0]);

TypeError: conv2d(): argument 'input' (position 1) must be Tensor, not Sequential

In [28]:
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(),lr=1e-3)

In [29]:
from tqdm import tqdm_notebook as tqdm

In [30]:
def one_epoch_man(model,train_dl,test_dl):
    model = model.cuda()
    for i,(images,labels) in enumerate(tqdm(train_dl)):
        model.train()
        images,labels = images.cuda(),labels.cuda()
        optimizer.zero_grad()

        preds = model(images)

        loss = criterion(preds,labels)

        loss.backward()
        optimizer.step()
        if(i%500==0):
            correct = 0
            total = 0
            print(f'loss: {loss.item()}')
            for img,y in test_dl:
                model.eval()
                img,y = img.cuda(),y.cuda()
                preds = model(img)
                preds = preds.argmax(1)
                total += preds.size(0)
                correct += (preds==y).sum()
            print(f'{correct} / {total}')

In [31]:
for i in range(10):
    one_epoch_man(model,train_dl,test_dl)

HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))

loss: 2.18306040763855
1031 / 10000
loss: 0.7011539936065674
9256 / 10000
loss: 0.350107878446579
9570 / 10000
loss: 0.280245304107666
9663 / 10000



HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))

loss: 0.15156999230384827
9737 / 10000
loss: 0.13944827020168304
9760 / 10000
loss: 0.15151086449623108
9807 / 10000
loss: 0.18694138526916504
9791 / 10000



HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))

loss: 0.12928009033203125
9791 / 10000
loss: 0.08030291646718979
9807 / 10000
loss: 0.17070308327674866
9831 / 10000
loss: 0.07393885403871536
9807 / 10000



HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))

loss: 0.13530823588371277
9807 / 10000
loss: 0.15191644430160522
9836 / 10000
loss: 0.04462313652038574
9818 / 10000
loss: 0.09935211390256882
9860 / 10000



HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))

loss: 0.08039829879999161
9840 / 10000
loss: 0.13140355050563812
9832 / 10000
loss: 0.0300586000084877
9850 / 10000
loss: 0.13933993875980377
9854 / 10000



HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))

loss: 0.02623194456100464
9849 / 10000
loss: 0.05290022864937782
9873 / 10000
loss: 0.08444135636091232
9866 / 10000
loss: 0.04485201835632324
9860 / 10000



HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))

loss: 0.06801018863916397
9867 / 10000
loss: 0.0839841291308403
9865 / 10000
loss: 0.2426808476448059
9863 / 10000
loss: 0.060018353164196014
9858 / 10000



HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))

loss: 0.13609784841537476
9863 / 10000
loss: 0.01757301390171051
9868 / 10000
loss: 0.07891058176755905
9870 / 10000
loss: 0.02109910547733307
9874 / 10000



HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))

loss: 0.12078457325696945
9883 / 10000
loss: 0.07468067854642868
9860 / 10000
loss: 0.060308732092380524
9879 / 10000
loss: 0.01157332956790924
9867 / 10000



HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))

loss: 0.008969686925411224
9885 / 10000
loss: 0.020272694528102875
9883 / 10000
loss: 0.15701386332511902
9886 / 10000
loss: 0.07180352509021759
9888 / 10000

