In [1]:
# install: tqdm (progress bars)
!pip install tqdm

import torch
import torch.nn as nn
import numpy as np
from tqdm.auto import tqdm
from torch.utils.data import DataLoader, Dataset, TensorDataset
import torchvision.datasets as ds

## Load the data (CIFAR-10)

In [2]:
def load_cifar(datadir='./data_cache'):
    ''' Loads CIFAR-10 through torchvision and returns it as in-memory tensors.

    The dataset is downloaded into `datadir` when it is not already there
    (the original note quotes ~400MB). Change the directory if necessary;
    on Paperspace it can be pointed at /storage.

    Returns:
        (X_tr, Y_tr, X_te, Y_te) -- for each split, the images as a float
        tensor with axes reordered by (0, 3, 1, 2) and divided by 255
        (i.e. in [0, 1], assuming 8-bit pixel data), and the labels as a
        long tensor.
    '''
    def as_tensors(split):
        # Move the last axis of the raw image array to position 1, then rescale.
        pixels = np.transpose(split.data, (0, 3, 1, 2))
        images = torch.Tensor(pixels).float() / 255.0
        labels = torch.Tensor(np.array(split.targets)).long()
        return images, labels

    # Train split first, then test split -- both are fetched before any conversion.
    splits = []
    for is_train in (True, False):
        splits.append(ds.CIFAR10(root=datadir, train=is_train,
                                 download=True, transform=None))
    train_split, test_split = splits

    return (*as_tensors(train_split), *as_tensors(test_split))

def make_loader(dataset, batch_size=128, shuffle=True, num_workers=4, pin_memory=True):
    ''' Wraps `dataset` in a torch DataLoader.

    Args:
        dataset: any dataset accepted by torch.utils.data.DataLoader.
        batch_size: number of samples per batch.
        shuffle: reshuffle the data every epoch. Defaults to True (the
            previously hard-coded value); pass False for evaluation loaders
            where a stable order is preferable.
        num_workers: number of loader worker processes (0 loads in the
            calling process).
        pin_memory: forwarded to DataLoader; set to False on machines
            without a GPU.

    Returns:
        The configured torch.utils.data.DataLoader.
    '''
    return torch.utils.data.DataLoader(dataset, batch_size=batch_size,
            shuffle=shuffle, num_workers=num_workers, pin_memory=pin_memory)

# Pull CIFAR-10 into memory, then wrap the train split and the test split
# (in that order) in DataLoaders using make_loader's defaults.
X_tr, Y_tr, X_te, Y_te = load_cifar()
train_dl, test_dl = (
    make_loader(TensorDataset(images, labels))
    for images, labels in ((X_tr, Y_tr), (X_te, Y_te))
)

Files already downloaded and verified
Files already downloaded and verified


## Training helper functions

In [12]:
def train_epoch(model, train_dl : DataLoader, opt, k = 50):
    ''' Trains model for one epoch on the provided dataloader, with optimizer opt. Logs stats every k batches.

    The model and every batch are moved to the GPU when one is available and
    to the CPU otherwise. The loss is nn.CrossEntropyLoss.

    Args:
        model: the network to train; it is switched to train mode and moved
            to the selected device in place.
        train_dl: iterable yielding (inputs, labels) batches.
        opt: optimizer over the model's parameters.
        k: a running loss/accuracy line is printed every k batches.

    Returns:
        (avg_loss, train_acc) as Python floats: the per-sample average loss
        and the fraction of correct predictions over the epoch.

    Raises:
        ValueError: if train_dl yields no samples.
    '''
    loss_func = nn.CrossEntropyLoss()
    # Fall back to the CPU rather than crashing inside .cuda() when no GPU exists.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.train()
    model.to(device)

    net_loss = 0.0
    n_correct = 0
    n_total = 0
    for i, (xB, yB) in enumerate(tqdm(train_dl)):
        opt.zero_grad()
        xB, yB = xB.to(device), yB.to(device)
        outputs = model(xB)
        loss = loss_func(outputs, yB)
        loss.backward()
        opt.step()

        # The loss is a batch mean, so weight it by the batch size before summing.
        net_loss += loss.item() * len(xB)
        with torch.no_grad():
            _, preds = torch.max(outputs, dim=1)
            # .item() keeps the counter a plain Python int, so the returned
            # accuracy is a float (as in evaluate) rather than a device tensor.
            n_correct += (preds == yB).sum().item()
            n_total += preds.size(0)

        if (i+1) % k == 0:
            train_acc = n_correct/n_total
            avg_loss = net_loss/n_total
            print(f'\t [Batch {i+1} / {len(train_dl)}] Train Loss: {avg_loss:.3f} \t Train Acc: {train_acc:.3f}')

    if n_total == 0:
        raise ValueError('train_epoch: train_dl yielded no samples')

    train_acc = n_correct/n_total
    avg_loss = net_loss/n_total
    return avg_loss, train_acc


def evaluate(model, test_dl, loss_func=None):
    ''' Returns loss, acc

    Runs the model in eval mode, without gradients, over every batch of
    test_dl. The model and batches are moved to the GPU when one is
    available and to the CPU otherwise.

    Args:
        model: the network to evaluate; it is switched to eval mode and
            moved to the selected device in place.
        test_dl: iterable yielding (inputs, labels) batches.
        loss_func: callable(outputs, labels) returning the mean loss of a
            batch. Defaults to a fresh nn.CrossEntropyLoss().

    Returns:
        (loss, acc) as Python floats: the per-sample average loss and the
        fraction of correct predictions.

    Raises:
        ValueError: if test_dl yields no samples.
    '''
    # Build the default here instead of in the signature, so nothing is
    # instantiated (or sent to CUDA) when the function is merely defined.
    if loss_func is None:
        loss_func = nn.CrossEntropyLoss()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    model.eval()
    model.to(device)
    n_correct = 0.0
    n_total = 0
    net_loss = 0.0
    with torch.no_grad():
        for (xb, yb) in test_dl:
            xb, yb = xb.to(device), yb.to(device)
            outputs = model(xb)
            # Weight the batch-mean loss by the batch size before summing.
            loss = len(xb) * loss_func(outputs, yb)
            _, preds = torch.max(outputs, dim=1)
            n_correct += (preds == yb).float().sum()
            net_loss += loss
            n_total += preds.size(0)

    if n_total == 0:
        raise ValueError('evaluate: test_dl yielded no samples')

    # float() accepts both a 0-dim tensor (on any device) and a plain number.
    acc = float(n_correct) / float(n_total)
    loss = float(net_loss) / float(n_total)
    return loss, acc

In [7]:
## Define model

In [8]:
## 5-Layer CNN for CIFAR
## This is the Myrtle5 network by David Page (https://myrtle.ai/learn/how-to-train-your-resnet-4-architecture/)

class Flatten(nn.Module):
    ''' Collapses every dimension after the first into a single one.

    An input of shape (N, d1, d2, ...) becomes (N, d1*d2*...). For the
    (N, C, 1, 1) activations this file feeds it, that is the same (N, C)
    result as before; other trailing shapes and non-contiguous inputs are
    now handled as well instead of raising.
    '''
    def forward(self, x):
        # flatten copies only when a view is impossible, so the common path stays a view.
        return x.flatten(start_dim=1)

def make_cnn(c=64, num_classes=10):
    ''' Builds the 5-layer CNN as an nn.Sequential; `c` is the width parameter.

    Layers 0-3 are each Conv2d(3x3, stride 1, padding 1) -> BatchNorm2d -> ReLU,
    with channel counts 3 -> c -> 2c -> 4c -> 8c. Layers 1-3 are additionally
    followed by MaxPool2d(2). Layer 4 is MaxPool2d(4) -> Flatten ->
    Linear(8c, num_classes).
    '''
    channels = [3, c, c*2, c*4, c*8]

    modules = []
    for layer_idx in range(4):
        n_in, n_out = channels[layer_idx], channels[layer_idx + 1]
        modules.append(nn.Conv2d(n_in, n_out, kernel_size=3,
                                 stride=1, padding=1, bias=True))
        modules.append(nn.BatchNorm2d(n_out))
        modules.append(nn.ReLU())
        # Only the very first conv block (layer 0) keeps full resolution.
        if layer_idx > 0:
            modules.append(nn.MaxPool2d(2))

    # Layer 4: final pooling, flatten, and the classifier head.
    modules.append(nn.MaxPool2d(4))
    modules.append(Flatten())
    modules.append(nn.Linear(c*8, num_classes, bias=True))

    return nn.Sequential(*modules)

In [9]:
## Train

In [14]:
# Move the model to its training device BEFORE building the optimizer: the
# torch.optim docs recommend this so the optimizer is constructed over the
# parameters that will actually be trained.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = make_cnn().to(device)
opt = torch.optim.SGD(model.parameters(), lr=0.1)
epochs = 20
for i in range(epochs):
    print(f'Starting Epoch {i}')
    # One full pass over the training data, then a full pass over the test data.
    train_loss, train_acc = train_epoch(model, train_dl, opt)
    test_loss, test_acc = evaluate(model, test_dl)

    print(f'Epoch {i}:\t Train Loss: {train_loss:.3f} \t Train Acc: {train_acc:.3f}\t Test Acc: {test_acc:.3f}')

Starting Epoch 0


HBox(children=(FloatProgress(value=0.0, max=391.0), HTML(value='')))

	 [Batch 50 / 391] Train Loss: 7.103 	 Train Acc: 0.140
	 [Batch 100 / 391] Train Loss: 4.647 	 Train Acc: 0.176
	 [Batch 150 / 391] Train Loss: 3.766 	 Train Acc: 0.205
	 [Batch 200 / 391] Train Loss: 3.300 	 Train Acc: 0.228
	 [Batch 250 / 391] Train Loss: 2.999 	 Train Acc: 0.249
	 [Batch 300 / 391] Train Loss: 2.781 	 Train Acc: 0.271
	 [Batch 350 / 391] Train Loss: 2.624 	 Train Acc: 0.288

Epoch 0:	 Train Loss: 2.517 	 Train Acc: 0.303	 Test Acc: 0.383
Starting Epoch 1


HBox(children=(FloatProgress(value=0.0, max=391.0), HTML(value='')))

	 [Batch 50 / 391] Train Loss: 1.542 	 Train Acc: 0.427
	 [Batch 100 / 391] Train Loss: 1.545 	 Train Acc: 0.435
	 [Batch 150 / 391] Train Loss: 1.509 	 Train Acc: 0.448
	 [Batch 200 / 391] Train Loss: 1.494 	 Train Acc: 0.453
	 [Batch 250 / 391] Train Loss: 1.470 	 Train Acc: 0.461
	 [Batch 300 / 391] Train Loss: 1.451 	 Train Acc: 0.470
	 [Batch 350 / 391] Train Loss: 1.428 	 Train Acc: 0.480

Epoch 1:	 Train Loss: 1.411 	 Train Acc: 0.486	 Test Acc: 0.424
Starting Epoch 2


HBox(children=(FloatProgress(value=0.0, max=391.0), HTML(value='')))

	 [Batch 50 / 391] Train Loss: 1.203 	 Train Acc: 0.565
	 [Batch 100 / 391] Train Loss: 1.205 	 Train Acc: 0.568
	 [Batch 150 / 391] Train Loss: 1.185 	 Train Acc: 0.580
	 [Batch 200 / 391] Train Loss: 1.172 	 Train Acc: 0.584
	 [Batch 250 / 391] Train Loss: 1.158 	 Train Acc: 0.588
	 [Batch 300 / 391] Train Loss: 1.147 	 Train Acc: 0.592
	 [Batch 350 / 391] Train Loss: 1.137 	 Train Acc: 0.595

Epoch 2:	 Train Loss: 1.120 	 Train Acc: 0.602	 Test Acc: 0.554
Starting Epoch 3


HBox(children=(FloatProgress(value=0.0, max=391.0), HTML(value='')))

	 [Batch 50 / 391] Train Loss: 0.987 	 Train Acc: 0.656
	 [Batch 100 / 391] Train Loss: 0.972 	 Train Acc: 0.658
	 [Batch 150 / 391] Train Loss: 0.970 	 Train Acc: 0.659
	 [Batch 200 / 391] Train Loss: 0.962 	 Train Acc: 0.663
	 [Batch 250 / 391] Train Loss: 0.954 	 Train Acc: 0.666
	 [Batch 300 / 391] Train Loss: 0.947 	 Train Acc: 0.669
	 [Batch 350 / 391] Train Loss: 0.942 	 Train Acc: 0.670

Epoch 3:	 Train Loss: 0.933 	 Train Acc: 0.674	 Test Acc: 0.644
Starting Epoch 4


HBox(children=(FloatProgress(value=0.0, max=391.0), HTML(value='')))

	 [Batch 50 / 391] Train Loss: 0.830 	 Train Acc: 0.710
	 [Batch 100 / 391] Train Loss: 0.826 	 Train Acc: 0.714
	 [Batch 150 / 391] Train Loss: 0.806 	 Train Acc: 0.719
	 [Batch 200 / 391] Train Loss: 0.807 	 Train Acc: 0.718
	 [Batch 250 / 391] Train Loss: 0.794 	 Train Acc: 0.722
	 [Batch 300 / 391] Train Loss: 0.792 	 Train Acc: 0.724
	 [Batch 350 / 391] Train Loss: 0.789 	 Train Acc: 0.725

Epoch 4:	 Train Loss: 0.781 	 Train Acc: 0.728	 Test Acc: 0.605
Starting Epoch 5


HBox(children=(FloatProgress(value=0.0, max=391.0), HTML(value='')))

	 [Batch 50 / 391] Train Loss: 0.717 	 Train Acc: 0.747
	 [Batch 100 / 391] Train Loss: 0.703 	 Train Acc: 0.757
	 [Batch 150 / 391] Train Loss: 0.690 	 Train Acc: 0.761
	 [Batch 200 / 391] Train Loss: 0.688 	 Train Acc: 0.763
	 [Batch 250 / 391] Train Loss: 0.690 	 Train Acc: 0.763
	 [Batch 300 / 391] Train Loss: 0.685 	 Train Acc: 0.764
	 [Batch 350 / 391] Train Loss: 0.685 	 Train Acc: 0.763

Epoch 5:	 Train Loss: 0.682 	 Train Acc: 0.764	 Test Acc: 0.592
Starting Epoch 6


HBox(children=(FloatProgress(value=0.0, max=391.0), HTML(value='')))

	 [Batch 50 / 391] Train Loss: 0.574 	 Train Acc: 0.803
	 [Batch 100 / 391] Train Loss: 0.569 	 Train Acc: 0.802
	 [Batch 150 / 391] Train Loss: 0.570 	 Train Acc: 0.804
	 [Batch 200 / 391] Train Loss: 0.572 	 Train Acc: 0.802
	 [Batch 250 / 391] Train Loss: 0.574 	 Train Acc: 0.801
	 [Batch 300 / 391] Train Loss: 0.580 	 Train Acc: 0.800
	 [Batch 350 / 391] Train Loss: 0.583 	 Train Acc: 0.799

Epoch 6:	 Train Loss: 0.583 	 Train Acc: 0.798	 Test Acc: 0.573
Starting Epoch 7


HBox(children=(FloatProgress(value=0.0, max=391.0), HTML(value='')))

	 [Batch 50 / 391] Train Loss: 0.501 	 Train Acc: 0.832
	 [Batch 100 / 391] Train Loss: 0.499 	 Train Acc: 0.832
	 [Batch 150 / 391] Train Loss: 0.506 	 Train Acc: 0.830
	 [Batch 200 / 391] Train Loss: 0.518 	 Train Acc: 0.825
	 [Batch 250 / 391] Train Loss: 0.512 	 Train Acc: 0.827
	 [Batch 300 / 391] Train Loss: 0.515 	 Train Acc: 0.825
	 [Batch 350 / 391] Train Loss: 0.509 	 Train Acc: 0.827

Epoch 7:	 Train Loss: 0.509 	 Train Acc: 0.826	 Test Acc: 0.614
Starting Epoch 8


HBox(children=(FloatProgress(value=0.0, max=391.0), HTML(value='')))

	 [Batch 50 / 391] Train Loss: 0.426 	 Train Acc: 0.851
	 [Batch 100 / 391] Train Loss: 0.420 	 Train Acc: 0.854
	 [Batch 150 / 391] Train Loss: 0.426 	 Train Acc: 0.853
	 [Batch 200 / 391] Train Loss: 0.422 	 Train Acc: 0.855
	 [Batch 250 / 391] Train Loss: 0.424 	 Train Acc: 0.854
	 [Batch 300 / 391] Train Loss: 0.427 	 Train Acc: 0.854
	 [Batch 350 / 391] Train Loss: 0.433 	 Train Acc: 0.851

Epoch 8:	 Train Loss: 0.429 	 Train Acc: 0.853	 Test Acc: 0.742
Starting Epoch 9


HBox(children=(FloatProgress(value=0.0, max=391.0), HTML(value='')))

	 [Batch 50 / 391] Train Loss: 0.391 	 Train Acc: 0.865
	 [Batch 100 / 391] Train Loss: 0.373 	 Train Acc: 0.871
	 [Batch 150 / 391] Train Loss: 0.381 	 Train Acc: 0.868
	 [Batch 200 / 391] Train Loss: 0.379 	 Train Acc: 0.869
	 [Batch 250 / 391] Train Loss: 0.379 	 Train Acc: 0.870
	 [Batch 300 / 391] Train Loss: 0.379 	 Train Acc: 0.871
	 [Batch 350 / 391] Train Loss: 0.382 	 Train Acc: 0.870

Epoch 9:	 Train Loss: 0.377 	 Train Acc: 0.872	 Test Acc: 0.775
Starting Epoch 10


HBox(children=(FloatProgress(value=0.0, max=391.0), HTML(value='')))

	 [Batch 50 / 391] Train Loss: 0.287 	 Train Acc: 0.906
	 [Batch 100 / 391] Train Loss: 0.300 	 Train Acc: 0.900
	 [Batch 150 / 391] Train Loss: 0.297 	 Train Acc: 0.899
	 [Batch 200 / 391] Train Loss: 0.297 	 Train Acc: 0.899
	 [Batch 250 / 391] Train Loss: 0.293 	 Train Acc: 0.900
	 [Batch 300 / 391] Train Loss: 0.294 	 Train Acc: 0.901
	 [Batch 350 / 391] Train Loss: 0.298 	 Train Acc: 0.900

Epoch 10:	 Train Loss: 0.298 	 Train Acc: 0.900	 Test Acc: 0.688
Starting Epoch 11


HBox(children=(FloatProgress(value=0.0, max=391.0), HTML(value='')))

	 [Batch 50 / 391] Train Loss: 0.226 	 Train Acc: 0.934
	 [Batch 100 / 391] Train Loss: 0.220 	 Train Acc: 0.932
	 [Batch 150 / 391] Train Loss: 0.221 	 Train Acc: 0.931
	 [Batch 200 / 391] Train Loss: 0.227 	 Train Acc: 0.929
	 [Batch 250 / 391] Train Loss: 0.232 	 Train Acc: 0.927
	 [Batch 300 / 391] Train Loss: 0.237 	 Train Acc: 0.925
	 [Batch 350 / 391] Train Loss: 0.235 	 Train Acc: 0.925

Epoch 11:	 Train Loss: 0.241 	 Train Acc: 0.923	 Test Acc: 0.757
Starting Epoch 12


HBox(children=(FloatProgress(value=0.0, max=391.0), HTML(value='')))

	 [Batch 50 / 391] Train Loss: 0.114 	 Train Acc: 0.964
	 [Batch 100 / 391] Train Loss: 0.157 	 Train Acc: 0.955
	 [Batch 150 / 391] Train Loss: 0.183 	 Train Acc: 0.949
	 [Batch 200 / 391] Train Loss: 0.180 	 Train Acc: 0.947
	 [Batch 250 / 391] Train Loss: 0.197 	 Train Acc: 0.942
	 [Batch 300 / 391] Train Loss: 0.199 	 Train Acc: 0.940
	 [Batch 350 / 391] Train Loss: 0.195 	 Train Acc: 0.940

Epoch 12:	 Train Loss: 0.196 	 Train Acc: 0.939	 Test Acc: 0.530
Starting Epoch 13


HBox(children=(FloatProgress(value=0.0, max=391.0), HTML(value='')))

	 [Batch 50 / 391] Train Loss: 0.203 	 Train Acc: 0.956
	 [Batch 100 / 391] Train Loss: 0.141 	 Train Acc: 0.967
	 [Batch 150 / 391] Train Loss: 0.123 	 Train Acc: 0.969
	 [Batch 200 / 391] Train Loss: 0.129 	 Train Acc: 0.965
	 [Batch 250 / 391] Train Loss: 0.132 	 Train Acc: 0.963
	 [Batch 300 / 391] Train Loss: 0.132 	 Train Acc: 0.961
	 [Batch 350 / 391] Train Loss: 0.130 	 Train Acc: 0.962

Epoch 13:	 Train Loss: 0.157 	 Train Acc: 0.956	 Test Acc: 0.792
Starting Epoch 14


HBox(children=(FloatProgress(value=0.0, max=391.0), HTML(value='')))

	 [Batch 50 / 391] Train Loss: 0.067 	 Train Acc: 0.981
	 [Batch 100 / 391] Train Loss: 0.059 	 Train Acc: 0.984
	 [Batch 150 / 391] Train Loss: 0.071 	 Train Acc: 0.981
	 [Batch 200 / 391] Train Loss: 0.064 	 Train Acc: 0.983
	 [Batch 250 / 391] Train Loss: 0.063 	 Train Acc: 0.983
	 [Batch 300 / 391] Train Loss: 0.059 	 Train Acc: 0.985
	 [Batch 350 / 391] Train Loss: 0.065 	 Train Acc: 0.984

Epoch 14:	 Train Loss: 0.065 	 Train Acc: 0.983	 Test Acc: 0.764
Starting Epoch 15


HBox(children=(FloatProgress(value=0.0, max=391.0), HTML(value='')))

	 [Batch 50 / 391] Train Loss: 0.022 	 Train Acc: 0.997
	 [Batch 100 / 391] Train Loss: 0.019 	 Train Acc: 0.998
	 [Batch 150 / 391] Train Loss: 0.018 	 Train Acc: 0.998
	 [Batch 200 / 391] Train Loss: 0.017 	 Train Acc: 0.998
	 [Batch 250 / 391] Train Loss: 0.017 	 Train Acc: 0.998
	 [Batch 300 / 391] Train Loss: 0.017 	 Train Acc: 0.998
	 [Batch 350 / 391] Train Loss: 0.016 	 Train Acc: 0.998

Epoch 15:	 Train Loss: 0.016 	 Train Acc: 0.998	 Test Acc: 0.850
Starting Epoch 16


HBox(children=(FloatProgress(value=0.0, max=391.0), HTML(value='')))

	 [Batch 50 / 391] Train Loss: 0.007 	 Train Acc: 1.000
	 [Batch 100 / 391] Train Loss: 0.006 	 Train Acc: 1.000
	 [Batch 150 / 391] Train Loss: 0.006 	 Train Acc: 1.000
	 [Batch 200 / 391] Train Loss: 0.007 	 Train Acc: 1.000
	 [Batch 250 / 391] Train Loss: 0.006 	 Train Acc: 1.000
	 [Batch 300 / 391] Train Loss: 0.006 	 Train Acc: 1.000
	 [Batch 350 / 391] Train Loss: 0.006 	 Train Acc: 1.000

Epoch 16:	 Train Loss: 0.006 	 Train Acc: 1.000	 Test Acc: 0.855
Starting Epoch 17


HBox(children=(FloatProgress(value=0.0, max=391.0), HTML(value='')))

	 [Batch 50 / 391] Train Loss: 0.004 	 Train Acc: 1.000
	 [Batch 100 / 391] Train Loss: 0.004 	 Train Acc: 1.000
	 [Batch 150 / 391] Train Loss: 0.004 	 Train Acc: 1.000
	 [Batch 200 / 391] Train Loss: 0.004 	 Train Acc: 1.000
	 [Batch 250 / 391] Train Loss: 0.004 	 Train Acc: 1.000
	 [Batch 300 / 391] Train Loss: 0.004 	 Train Acc: 1.000
	 [Batch 350 / 391] Train Loss: 0.004 	 Train Acc: 1.000

Epoch 17:	 Train Loss: 0.004 	 Train Acc: 1.000	 Test Acc: 0.858
Starting Epoch 18


HBox(children=(FloatProgress(value=0.0, max=391.0), HTML(value='')))

	 [Batch 50 / 391] Train Loss: 0.003 	 Train Acc: 1.000
	 [Batch 100 / 391] Train Loss: 0.003 	 Train Acc: 1.000
	 [Batch 150 / 391] Train Loss: 0.003 	 Train Acc: 1.000
	 [Batch 200 / 391] Train Loss: 0.003 	 Train Acc: 1.000
	 [Batch 250 / 391] Train Loss: 0.003 	 Train Acc: 1.000
	 [Batch 300 / 391] Train Loss: 0.003 	 Train Acc: 1.000
	 [Batch 350 / 391] Train Loss: 0.003 	 Train Acc: 1.000

Epoch 18:	 Train Loss: 0.003 	 Train Acc: 1.000	 Test Acc: 0.858
Starting Epoch 19


HBox(children=(FloatProgress(value=0.0, max=391.0), HTML(value='')))

	 [Batch 50 / 391] Train Loss: 0.002 	 Train Acc: 1.000
	 [Batch 100 / 391] Train Loss: 0.002 	 Train Acc: 1.000
	 [Batch 150 / 391] Train Loss: 0.002 	 Train Acc: 1.000
	 [Batch 200 / 391] Train Loss: 0.002 	 Train Acc: 1.000
	 [Batch 250 / 391] Train Loss: 0.002 	 Train Acc: 1.000
	 [Batch 300 / 391] Train Loss: 0.002 	 Train Acc: 1.000
	 [Batch 350 / 391] Train Loss: 0.002 	 Train Acc: 1.000

Epoch 19:	 Train Loss: 0.002 	 Train Acc: 1.000	 Test Acc: 0.857
