In [1]:
import torch, torchvision
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import torch.nn.functional as F
import time

In [2]:
# Hyperparameters and runtime configuration shared by the cells below
bs = 64           # mini-batch size used by both data loaders
lr = 0.001        # learning rate handed to the optimizer
num_epoch = 5     # number of passes over the training set
num_classes = 10  # width of the classifier output layer

# Prefer the GPU when PyTorch can see one; otherwise stay on the CPU
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'
print(f'Training using {device}')

Training using cuda


In [4]:
# MNIST preprocessing: convert each image to a tensor, then normalise its
# single channel with mean 0.1307 and standard deviation 0.3081.
mnist_transforms = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
)

# Training split; downloaded under the root directory when not already there.
train_dataset = torchvision.datasets.MNIST(
    root='./data/mnist/', train=True, download=True, transform=mnist_transforms
)

# Test split, preprocessed exactly like the training split.
test_dataset = torchvision.datasets.MNIST(
    root='./data/mnist', train=False, download=True, transform=mnist_transforms
)

# Only the training batches are shuffled; the test loader keeps dataset order.
train_loader = DataLoader(dataset=train_dataset, batch_size=bs, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=bs)

0it [00:00, ?it/s]

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data/mnist/MNIST/raw/train-images-idx3-ubyte.gz


 79%|███████▉  | 7864320/9912422 [00:16<00:12, 159372.95it/s] 

Extracting ./data/mnist/MNIST/raw/train-images-idx3-ubyte.gz to ./data/mnist/MNIST/raw



0it [00:00, ?it/s][A

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data/mnist/MNIST/raw/train-labels-idx1-ubyte.gz



  0%|          | 0/28881 [00:00<?, ?it/s][A
32768it [00:00, 51307.18it/s]                           [A

0it [00:00, ?it/s][A

Extracting ./data/mnist/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/mnist/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data/mnist/MNIST/raw/t10k-images-idx3-ubyte.gz



  0%|          | 0/1648877 [00:00<?, ?it/s][A
  1%|          | 16384/1648877 [00:00<00:20, 80331.10it/s][A
  2%|▏         | 40960/1648877 [00:00<00:17, 93476.54it/s][A
  6%|▌         | 98304/1648877 [00:01<00:13, 113777.87it/s][A
 11%|█         | 180224/1648877 [00:01<00:11, 130765.88it/s][A
 18%|█▊        | 294912/1648877 [00:01<00:07, 169824.21it/s][A
 25%|██▍       | 409600/1648877 [00:01<00:05, 214689.01it/s][A
 30%|██▉       | 491520/1648877 [00:02<00:05, 213173.98it/s][A
 36%|███▋      | 598016/1648877 [00:02<00:04, 255506.31it/s][A
 43%|████▎     | 704512/1648877 [00:02<00:03, 312864.49it/s][A
 48%|████▊     | 786432/1648877 [00:02<00:02, 315525.35it/s][A
 53%|█████▎    | 868352/1648877 [00:03<00:02, 339811.84it/s][A
 58%|█████▊    | 958464/1648877 [00:03<00:01, 364761.55it/s][A
 63%|██████▎   | 1040384/1648877 [00:03<00:01, 374613.28it/s][A
 69%|██████▊   | 1130496/1648877 [00:03<00:01, 309356.34it/s][A
 78%|███████▊  | 1277952/1648877 [00:04<00:00, 373323.51it/

Extracting ./data/mnist/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/mnist/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data/mnist/MNIST/raw/t10k-labels-idx1-ubyte.gz




8192it [00:00, 20501.27it/s]            [A[A


Extracting ./data/mnist/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/mnist/MNIST/raw
Processing...
Done!


9920512it [00:30, 159372.95it/s]                             
1654784it [00:22, 345895.97it/s]                             [A

In [47]:
# Training loop
def train(model, criterion, optimizer, num_epochs=10, loader=None, log_interval=None):
    """Fit ``model`` with mini-batch updates and print the loss periodically.

    Args:
        model: ``nn.Module`` to optimise; it is switched to training mode.
        criterion: callable ``criterion(outputs, labels)`` returning a scalar
            loss tensor.
        optimizer: optimizer already bound to ``model``'s parameters.
        num_epochs: number of full passes over ``loader``.
        loader: iterable of ``(images, labels)`` batches that exposes a
            ``dataset`` attribute (e.g. a ``DataLoader``). Defaults to the
            notebook-level ``train_loader``.
        log_interval: print the loss on the first batch of every
            ``log_interval`` batches. Defaults to the notebook-level ``bs``.

    Raises:
        ValueError: if ``log_interval`` is smaller than 1.
    """
    if loader is None:
        loader = train_loader
    if log_interval is None:
        log_interval = bs
    if log_interval < 1:
        raise ValueError(f'log_interval must be >= 1, got {log_interval}')

    # Send batches to wherever the model's weights live so the function does
    # not depend on a global; parameter-free models fall back to ``device``.
    first_param = next(model.parameters(), None)
    target = first_param.device if first_param is not None else device

    total = len(loader.dataset)
    model.train()
    for epoch in range(num_epochs):
        seen = 0  # samples consumed so far in this epoch
        for idx, (images, labels) in enumerate(loader):
            images, labels = images.to(target), labels.to(target)

            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, labels)

            # Backward pass and weight update
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            seen += len(images)

            # The old test `idx % bs-1 == 0` parsed as `(idx % bs) - 1 == 0`:
            # it never fired for an interval of 1 and labelled progress one
            # batch short. Log on a plain modulus with the true sample count.
            if idx % log_interval == 0:
                print(f'Epoch [{epoch+1}/{num_epochs}] | Batch [{seen}/{total}] | Loss: {loss.item():.4f}')

In [37]:
def test(model, criterion):
    model.eval()
    with torch.no_grad():
        correct = 0
        test_loss = 0
        
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            out = model(images)
            # total loss
            test_loss += criterion(out, labels)
            # get the index of the max value, calculate how many accurate predictions
            pred = out.data.max(1, keepdim=True)[1]
            correct += pred.eq(labels.data.view_as(pred)).cpu().sum()
            
        # Average loss for the whole test 10000 images    
        test_loss /= len(test_loader.dataset)
        print("==========================")
        print(f"Test set: Average Loss: {test_loss:.4f}, Accuracy: {correct}/{len(test_loader.dataset)}, {100*correct/len(test_loader.dataset):.0f}%")

In [38]:
# Run() aka main
def run(model, criterion, optimizer):
    """Train ``model`` for ``num_epoch`` epochs, evaluate it, and print timings.

    Prints the wall-clock duration of the training phase, of the test phase,
    and of both together, followed by the device name held in ``device``.
    """
    def elapsed_since(start):
        # Wall-clock time since ``start`` as (whole minutes, leftover seconds).
        return divmod(time.time() - start, 60)

    # Training phase
    started = time.time()
    train(model, criterion, optimizer, num_epoch)
    m, s = elapsed_since(started)
    print(f'Training Time: {m:.0f}m {s:.0f}s')

    # Evaluation phase
    eval_started = time.time()
    test(model, criterion)
    m, s = elapsed_since(eval_started)
    print(f'Testing Time: {m:.0f}m {s:.0f}s')

    # Both phases together, measured from the start of training
    m, s = elapsed_since(started)
    print(f'Total Time: {m:.0f}m {s:.0f}s\nTrained on {device}')

### Simple CNN model 

In [8]:
# Build the model
class SimpleCNN(nn.Module):
    """Two convolutional stages followed by a single linear classifier.

    Each stage is Conv2d(5x5, stride 1, padding 2) -> BatchNorm2d -> ReLU ->
    MaxPool2d(2, stride 2), so a stage keeps the spatial size through the
    convolution and halves it in the pool. The classifier expects 32 feature
    maps of 7x7, which is what a 1x28x28 input produces.

    Args:
        num_classes: number of output logits.
    """

    @staticmethod
    def _stage(in_channels, out_channels):
        """Build one conv -> batch-norm -> ReLU -> max-pool stage."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )

    def __init__(self, num_classes=10):
        super().__init__()
        self.layer1 = self._stage(1, 16)
        self.layer2 = self._stage(16, 32)
        self.fc = nn.Linear(7 * 7 * 32, num_classes)

    def forward(self, x):
        """Return class logits of shape (batch, num_classes) for images ``x``."""
        features = self.layer2(self.layer1(x))
        # Collapse channels and spatial dims into one vector per sample.
        flat = features.view(features.size(0), -1)
        return self.fc(flat)

### Simple Inception model

In [44]:
class InceptionA(nn.Module):
    """Inception-style block: four parallel branches concatenated on channels.

    Branches and their output widths:
      * 1x1 conv                                  -> 16 channels
      * 1x1 conv -> 5x5 conv (padding 2)          -> 24 channels
      * 1x1 conv -> 3x3 conv -> 3x3 conv (pad 1)  -> 24 channels
      * 3x3 average pool (stride 1) -> 1x1 conv   -> 24 channels

    Every branch keeps the spatial size, so the output has 88 channels at the
    input's height and width.

    Args:
        in_channels: number of channels of the incoming feature map.
    """

    def __init__(self, in_channels):
        super().__init__()
        narrow, wide = 16, 24

        # Branch 1: pointwise projection only
        self.branch1x1 = nn.Conv2d(in_channels, narrow, kernel_size=1)

        # Branch 2: pointwise reduction, then a 5x5 convolution
        self.branch5x5_1 = nn.Conv2d(in_channels, narrow, kernel_size=1)
        self.branch5x5_2 = nn.Conv2d(narrow, wide, kernel_size=5, padding=2)

        # Branch 3: pointwise reduction, then two stacked 3x3 convolutions
        self.branch3x3_1 = nn.Conv2d(in_channels, narrow, kernel_size=1)
        self.branch3x3_2 = nn.Conv2d(narrow, wide, kernel_size=3, padding=1)
        self.branch3x3_3 = nn.Conv2d(wide, wide, kernel_size=3, padding=1)

        # Branch 4: pointwise projection applied after average pooling
        self.branch_pool = nn.Conv2d(in_channels, wide, kernel_size=1)

    def forward(self, x):
        """Run all four branches on ``x`` and concatenate them along dim 1."""
        out_1x1 = self.branch1x1(x)
        out_5x5 = self.branch5x5_2(self.branch5x5_1(x))
        out_3x3 = self.branch3x3_3(self.branch3x3_2(self.branch3x3_1(x)))
        pooled = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)
        out_pool = self.branch_pool(pooled)
        return torch.cat((out_1x1, out_5x5, out_3x3, out_pool), dim=1)

In [45]:
class SimpleInception(nn.Module):
    """Small classifier that alternates plain convolutions with InceptionA blocks.

    Pipeline: conv 5x5 -> max-pool -> ReLU -> InceptionA -> conv 5x5 ->
    max-pool -> ReLU -> InceptionA -> flatten -> linear.

    The linear layer takes 1408 features, i.e. the 88 channels emitted by the
    last InceptionA block at 4x4 resolution, which is what a 1x28x28 input
    yields (28 -> 24 -> 12 -> 8 -> 4).

    Args:
        num_classes: number of output logits. Defaults to 10, the value that
            was previously hard-coded, so ``SimpleInception()`` is unchanged.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        # 88 input channels = output width of the first InceptionA block.
        self.conv2 = nn.Conv2d(88, 20, kernel_size=5)

        self.incept1 = InceptionA(in_channels=10)
        self.incept2 = InceptionA(in_channels=20)

        self.mp = nn.MaxPool2d(2)
        self.fc = nn.Linear(1408, num_classes)

    def forward(self, x):
        """Return class logits of shape (batch, num_classes) for images ``x``."""
        x = F.relu(self.mp(self.conv1(x)))
        x = self.incept1(x)
        x = F.relu(self.mp(self.conv2(x)))
        x = self.incept2(x)
        # Collapse channels and spatial dims into one vector per sample.
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

### Simple CNN: training run

In [48]:
# Loss function and the network to train
criterion = nn.CrossEntropyLoss()
model = SimpleCNN(num_classes).to(device)

# Adam optimizer over all of the network's parameters
optimizer = optim.Adam(params=model.parameters(), lr=lr)

# Train, then evaluate on the test set
run(model, criterion, optimizer)

Epoch [1/5] | Batch [64/60000] | Loss: 2.4511
Epoch [1/5] | Batch [4160/60000] | Loss: 0.3595
Epoch [1/5] | Batch [8256/60000] | Loss: 0.2785
Epoch [1/5] | Batch [12352/60000] | Loss: 0.0589
Epoch [1/5] | Batch [16448/60000] | Loss: 0.1183
Epoch [1/5] | Batch [20544/60000] | Loss: 0.1268
Epoch [1/5] | Batch [24640/60000] | Loss: 0.1576
Epoch [1/5] | Batch [28736/60000] | Loss: 0.0413
Epoch [1/5] | Batch [32832/60000] | Loss: 0.1169
Epoch [1/5] | Batch [36928/60000] | Loss: 0.1426
Epoch [1/5] | Batch [41024/60000] | Loss: 0.0296
Epoch [1/5] | Batch [45120/60000] | Loss: 0.1077
Epoch [1/5] | Batch [49216/60000] | Loss: 0.0733
Epoch [1/5] | Batch [53312/60000] | Loss: 0.1120
Epoch [1/5] | Batch [57408/60000] | Loss: 0.0141
Epoch [2/5] | Batch [64/60000] | Loss: 0.0183
Epoch [2/5] | Batch [4160/60000] | Loss: 0.0353
Epoch [2/5] | Batch [8256/60000] | Loss: 0.0695
Epoch [2/5] | Batch [12352/60000] | Loss: 0.0290
Epoch [2/5] | Batch [16448/60000] | Loss: 0.0385
Epoch [2/5] | Batch [20544/600

### Simple Inception: training run (TODO: add batch norm)

In [49]:
# Loss function and the network to train
criterion = nn.CrossEntropyLoss()
model = SimpleInception().to(device)

# Adam optimizer over all of the network's parameters
optimizer = optim.Adam(params=model.parameters(), lr=lr)

# Train, then evaluate on the test set
run(model, criterion, optimizer)

Epoch [1/5] | Batch [64/60000] | Loss: 2.2859
Epoch [1/5] | Batch [4160/60000] | Loss: 0.4193
Epoch [1/5] | Batch [8256/60000] | Loss: 0.1811
Epoch [1/5] | Batch [12352/60000] | Loss: 0.1826
Epoch [1/5] | Batch [16448/60000] | Loss: 0.1750
Epoch [1/5] | Batch [20544/60000] | Loss: 0.1108
Epoch [1/5] | Batch [24640/60000] | Loss: 0.0156
Epoch [1/5] | Batch [28736/60000] | Loss: 0.0462
Epoch [1/5] | Batch [32832/60000] | Loss: 0.0780
Epoch [1/5] | Batch [36928/60000] | Loss: 0.0953
Epoch [1/5] | Batch [41024/60000] | Loss: 0.0704
Epoch [1/5] | Batch [45120/60000] | Loss: 0.0185
Epoch [1/5] | Batch [49216/60000] | Loss: 0.1125
Epoch [1/5] | Batch [53312/60000] | Loss: 0.0138
Epoch [1/5] | Batch [57408/60000] | Loss: 0.0983
Epoch [2/5] | Batch [64/60000] | Loss: 0.1020
Epoch [2/5] | Batch [4160/60000] | Loss: 0.0524
Epoch [2/5] | Batch [8256/60000] | Loss: 0.0363
Epoch [2/5] | Batch [12352/60000] | Loss: 0.1190
Epoch [2/5] | Batch [16448/60000] | Loss: 0.2216
Epoch [2/5] | Batch [20544/600