In [1]:
import time
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms

random_seed = 123
learning_rate = 0.01
num_epochs = 10
batch_size = 128

# Architecture
num_classes = 10


##########################
### MNIST DATASET
##########################

# Note transforms.ToTensor() scales input images
# to 0-1 range
train_dataset = datasets.MNIST(root='data', 
                               train=True, 
                               transform=transforms.ToTensor(),
                               download=True)

test_dataset = datasets.MNIST(root='data', 
                              train=False, 
                              transform=transforms.ToTensor())


train_loader = DataLoader(dataset=train_dataset, 
                          batch_size=batch_size, 
                          shuffle=True)

test_loader = DataLoader(dataset=test_dataset, 
                         batch_size=batch_size, 
                         shuffle=False)

# Checking the dataset
for images, labels in train_loader:  
    print('Image batch dimensions:', images.shape)
    print('Image label dimensions:', labels.shape)
    break


Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to data/MNIST/raw/train-images-idx3-ubyte.gz


9913344it [00:00, 25517945.53it/s]                             


Extracting data/MNIST/raw/train-images-idx3-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to data/MNIST/raw/train-labels-idx1-ubyte.gz


29696it [00:00, 50426741.53it/s]         

Extracting data/MNIST/raw/train-labels-idx1-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz





Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to data/MNIST/raw/t10k-images-idx3-ubyte.gz


1649664it [00:00, 11104144.50it/s]                           


Extracting data/MNIST/raw/t10k-images-idx3-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to data/MNIST/raw/t10k-labels-idx1-ubyte.gz


5120it [00:00, 14882076.56it/s]         


Extracting data/MNIST/raw/t10k-labels-idx1-ubyte.gz to data/MNIST/raw

Image batch dimensions: torch.Size([128, 1, 28, 28])
Image label dimensions: torch.Size([128])


In [3]:
# ResNet with conv blocks for resizing
class ConvNet(torch.nn.Module):
    def __init__(self, num_classes):
        super(ConvNet, self).__init__()
        
        # 1st residual block 
        # 28x28x1 -> 14x14x4
        self.conv_1 = torch.nn.Conv2d(in_channels=1,
                                      out_channels=4,
                                      kernel_size=(3,3),
                                      stride=(2,2),  #reduce size by half
                                      padding=1)
        
        self.conv_1_bn = torch.nn.BatchNorm2d(4)
        
        # 14x14x4 -> 14x14x8
        self.conv_2 = torch.nn.Conv2d(in_channels=4,
                                      out_channels=8,
                                      kernel_size=(1,1),
                                      stride=(1, 1),
                                      padding=0
                                      )
        self.conv_2_bn = torch.nn.BatchNorm2d(8)
        
        # 28x28x1 -> 14x14x8
        self.conv_shortcut_1 = torch.nn.Conv2d(in_channels=1,
                                               out_channels=8,
                                               kernel_size=(1, 1),
                                               stride=(2, 2),
                                               padding=0)
        self.conv_shortcut_1_bn = torch.nn.BatchNorm2d(8)
        
        # 2nd residual block
        # 14x14x8 -> 7x7x16
        self.conv_3 = torch.nn.Conv2d(in_channels=9,
                                      out_channels=16,
                                      kernel_size=(3, 3),
                                      stride=(2,2),
                                      padding=1)
        self.conv_3_bn = torch.nn.BatchNorm2d(16)
        
        # 7x7x16 -> 7x7x32
        self.conv_4 = torch.nn.Conv2d(in_channels=16,
                                      out_channels=32,
                                      kernel_size=(1,1),
                                      stride=(1,1),
                                      padding=0)
        self.conv_4_bn = torch.nn.BatchNorm2d(32)
        
        # 14x14x8 -> 7x7x32
        self.conv_shortcut_2 = torch.nn.Conv2d(in_channels=8,
                                               out_channels=32,
                                               kernel_size=(1,1),
                                               stride=(2,2),
                                               padding=0)
        self.conv_shortcut_2_bn = torch.nn.BatchNorm2d(32)
        
        self.linear_1 = torch.nn.Linear(7*7*32, num_classes)
        
    def forward(self,x):
        # 1st residual block
        shortcut = x
        out = self.conv_1(x)  #28x28x1 -> 14x14x4
        out = self.conv_1_bn(out)
        out = F.relu(out)
        
        out = self.conv_2(out)  #14x14x4 -> 714x14x8
        out = self.conv_2_bn
        
        # match up dimensions using a linear function (no relu)
        shortcut = self.conv_shortcut_1(shortcut)
        shortcut = self.conv_shortcut_1_bn(shortcut)
        
        out += shortcut
        out = F.relu(out)
        
        # 2nd residual block
        shortcut = out
        out = self.conv_3(out)  # 14x14x8 -> 7x7x16
        out = self.conv_3_bn(out)
        out = F.relu(out)
        
        out = self.conv_4(out)  # 7x7x16 -> 7x7x32
        out = self.conv_4_bn(out)
        
        # match up dimensions using a linear function (no relu)
        shortcut = self.conv_shortcut_2(shortcut)
        shortcut = self.conv_shortcut_2_bn(shortcut)
        
        out += shortcut
        out = F.relu(out)
        
        # Fully connected
        logits = self.linear_1(out.view(-1, 7*7*32))
        probas = F.softmax(logits, dim=1)
        return logits, probas

torch.manual_seed(random_seed)
model = ConvNet(num_classes=num_classes)

optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)



     

In [4]:
# Training 
def compute_accuracy(model, data_loader):
    correct_pred, num_examples = 0, 0
    for i, (features, targets) in enumerate(data_loader):
        logits, probas = model(features)
        _, predicted_labels = torch.max(probas, 1)
        num_examples += targets.size(0)
        correct_pred += (predicted_labels==targets).sum()
    return correct_pred.float()/num_examples * 100

for epoch in range(num_epochs):
    model = model.train()
    for batch_idx, (features, targets) in enumerate(train_loader):
        # Forward and back prop
        logits, probas = model(features)
        cost = F.cross_entropy(logits, targets)
        optimizer.zero_grad()
        
        # Update model parameters
        optimizer.step()
        
        # Logging
        if not batch_idx % 50:
            print ('Epoch: %03d/%03d | Batch: %03d/%03d | Cost: %.4f' % (epoch+1, num_epochs, batch_idx, len(train_loader), cost))
            
    model = model.eval()  # eval mode to prevent update batchnorm params ent upd. batchnorm params during inference
    with torch.set_grad_enabled(False): # save memory during inference
        print('Epoch: %03d/%03d training accuracy: %.2d%%' %(epoch+1, num_epochs, compute_accuracy(model, train_loader)))

TypeError: unsupported operand type(s) for +=: 'BatchNorm2d' and 'Tensor'

In [None]:
# Evaluation
print('Test accuracy: %.2f%%' % (compute_accuracy(model, test_loader)))