# Homework 3, exercise 2 - Residual Neural Network on CIFAR10

In this exercise we implement a (slightly modified) ResNet as introduced in [this paper](https://arxiv.org/pdf/1512.03385.pdf).

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import time

For this exercise it is recommended to use the GPU!

In [2]:

use_cuda = True

if use_cuda and torch.cuda.is_available():
  device = torch.device('cuda:0')
else:
  device = torch.device('cpu')

device

device(type='cuda', index=0)

### Load the CIFAR10 dataset

In [3]:
import torchvision
import torchvision.transforms as transforms

transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

trainset = torchvision.datasets.CIFAR10(root='./data_cifar', train=True,
                                        download=True, transform=transform_train)

testset = torchvision.datasets.CIFAR10(root='./data_cifar', train=False,
                                       download=True, transform=transform_test)

batch_size = 128

c, w, h = 3, 32, 32

trainloader = torch.utils.data.DataLoader(trainset,
                                          batch_size=batch_size,
                                          shuffle=True)

testloader = torch.utils.data.DataLoader(testset,
                                         batch_size=batch_size,
                                         shuffle=True)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

Files already downloaded and verified
Files already downloaded and verified


## Exercise - Implement a Residual Block

Residual neural networks mainly consist of components called Residual Blocks. One residual block can be expressed as **y** = *F*(**x**) + **x** where **x** and **y** are the input and output of the block, respectively. So the input **x** is added to the result of *F*(**x**) using a *skip connection*. In this exercise, *F* consists of:
* a convolutional layer with `in_channels` input channels, `hidden_channels` output channels, a kernel size of (3, 3), a stride of 1, padding of 1 and no bias parameter.
* a batch normalisation layer 
* ReLU activation
* a convolutional layer with `hidden_channels` input channels, `out_channels` output channels, a kernel size of (3, 3), a stride of 1, padding of 1 and no bias parameter.
* a batch normalisation layer

After this the `skip_connection` is applied. If the dimensions of *F*(**x**) and **x** don't match an extra linear projection is applied to **x** so the dimensions do match. This has already been implemented for you. You only need to call it at the right place. 
Finally, a ReLU activation is applied on the output **y**


In [18]:
class ResidualBlock(nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, hidden_channels, kernel_size=(3,3), stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(hidden_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(hidden_channels, out_channels, kernel_size=(3,3), stride = 1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        
        if in_channels != out_channels:  # F(x) and x dimensions do not match! Define a projection for input x
            self.skip_connection = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=False),
                nn.BatchNorm2d(out_channels)
                )
        else:
            self.skip_connection = lambda x: x  # The dimensions already match! No need to do a projection on x

    def forward(self, x):
        residual = x
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        
        x = self.conv2(x)
        x = self.bn2(x)
        x += self.skip_connection(residual)
        x = self.relu(x)
        return (x)
    


  

## Exercise - Implement a Residual Neural Network
Now you can use the previously defined Residual Block to create your ResNet.

The network consists of:
* a convolutional layer with `in_channels` input channels, 64 output channels, a stride of 1, padding of 1 and no bias parameter,
* a batch normalisation layer
* ReLU activation
* a max pooling layer with kernel size (3, 3), a stride of 2 and padding of 1,
* eight residual blocks, with (64, 64, 128, 128, 256, 256, 512, 512) channels, respectively (see code below) 
* an average pooling layer over all feature maps (already present)
* a dense layer to form the output distribution (already present)

In [22]:
class ResNet(nn.Module):

    def __init__(self, in_channels, out_size):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels,64, kernel_size=1, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.res_blocks = nn.ModuleList(
            [
             ResidualBlock(64, 64, 64),
             ResidualBlock(64, 64, 64),
         
             ResidualBlock(64, 128, 128),
             ResidualBlock(128, 128, 128),
         
             ResidualBlock(128, 256, 256),
             ResidualBlock(256, 256, 256),

             ResidualBlock(256, 512, 512),
             ResidualBlock(512, 512, 512),
            ]
        )

        self.dense_layer = nn.Linear(512, out_size)
    
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.res_blocks[0](x)
        x = self.res_blocks[1](x)
        x = self.res_blocks[2](x)
        x = self.res_blocks[3](x)
        x = self.res_blocks[4](x)
        x = self.res_blocks[5](x)
        x = self.res_blocks[6](x)
        x = self.res_blocks[7](x)
        
        x = F.avg_pool2d(x, x.shape[2:])
    
        x = x.view(x.size(0), -1)
        x = self.dense_layer(x)

        return x



### Initialize the network, Loss function and Optimizer

In [26]:
net = ResNet(c, len(classes)).to(device)

criterion = nn.CrossEntropyLoss()

optimizer = torch.optim.Adam(net.parameters(), lr=0.001)

## Exercise - Train/evaluate the network
Train the network you built using the code below. Add the following answers in your report:
* What test accuracy were you able to get?
* How many layers does your network have? (counting only convolutional and dense layers)
* Why do the skip connections help for training deep neural networks?

In [27]:
start=time.time()

for epoch in range(0,20):

  net.train()  # Put the network in train mode
  for i, (x_batch, y_batch) in enumerate(trainloader):
    x_batch, y_batch = x_batch.to(device), y_batch.to(device)  # Move the data to the device that is used
    
    optimizer.zero_grad()  # Set all currenly stored gradients to zero 

    y_pred = net(x_batch)

    loss = criterion(y_pred, y_batch)

    loss.backward()

    optimizer.step()

    # Compute relevant metrics
    
    y_pred_max = torch.argmax(y_pred, dim=1)  # Get the labels with highest output probability

    correct = torch.sum(torch.eq(y_pred_max, y_batch)).item()  # Count how many are equal to the true labels

    elapsed = time.time() - start  # Keep track of how much time has elapsed

    # Show progress every 20 batches 
    if not i % 20:
      print(f'epoch: {epoch}, time: {elapsed:.3f}s, loss: {loss.item():.3f}, train accuracy: {correct / batch_size:.3f}')
    
    correct_total = 0

  net.eval()  # Put the network in eval mode
  for i, (x_batch, y_batch) in enumerate(testloader):
    x_batch, y_batch = x_batch.to(device), y_batch.to(device)  # Move the data to the device that is used

    y_pred = net(x_batch)
    y_pred_max = torch.argmax(y_pred, dim=1)

    correct_total += torch.sum(torch.eq(y_pred_max, y_batch)).item()

  print(f'Accuracy on the test set: {correct_total / len(testset):.3f}')




epoch: 0, time: 0.283s, loss: 2.451, train accuracy: 0.055
epoch: 0, time: 5.644s, loss: 1.931, train accuracy: 0.289
epoch: 0, time: 11.135s, loss: 1.738, train accuracy: 0.359
epoch: 0, time: 16.427s, loss: 1.636, train accuracy: 0.359
epoch: 0, time: 21.653s, loss: 1.754, train accuracy: 0.289
epoch: 0, time: 26.913s, loss: 1.699, train accuracy: 0.359
epoch: 0, time: 32.210s, loss: 1.502, train accuracy: 0.430
epoch: 0, time: 37.531s, loss: 1.512, train accuracy: 0.461
epoch: 0, time: 42.820s, loss: 1.581, train accuracy: 0.367
epoch: 0, time: 48.069s, loss: 1.346, train accuracy: 0.453
epoch: 0, time: 53.309s, loss: 1.525, train accuracy: 0.469
epoch: 0, time: 58.550s, loss: 1.527, train accuracy: 0.469
epoch: 0, time: 63.797s, loss: 1.385, train accuracy: 0.516
epoch: 0, time: 69.048s, loss: 1.406, train accuracy: 0.398
epoch: 0, time: 74.301s, loss: 1.341, train accuracy: 0.516
epoch: 0, time: 79.558s, loss: 1.277, train accuracy: 0.555
epoch: 0, time: 84.810s, loss: 1.497, trai

epoch: 6, time: 720.742s, loss: 0.646, train accuracy: 0.766
epoch: 6, time: 725.963s, loss: 0.840, train accuracy: 0.727
epoch: 6, time: 731.194s, loss: 0.660, train accuracy: 0.805
epoch: 6, time: 736.427s, loss: 0.611, train accuracy: 0.781
epoch: 6, time: 741.648s, loss: 0.803, train accuracy: 0.727
epoch: 6, time: 746.863s, loss: 0.646, train accuracy: 0.797
epoch: 6, time: 752.077s, loss: 0.757, train accuracy: 0.750
epoch: 6, time: 757.291s, loss: 0.781, train accuracy: 0.742
Accuracy on the test set: 0.713
epoch: 7, time: 766.960s, loss: 0.521, train accuracy: 0.852
epoch: 7, time: 772.185s, loss: 0.691, train accuracy: 0.742
epoch: 7, time: 777.414s, loss: 0.646, train accuracy: 0.750
epoch: 7, time: 782.633s, loss: 0.677, train accuracy: 0.789
epoch: 7, time: 787.868s, loss: 0.706, train accuracy: 0.750
epoch: 7, time: 793.103s, loss: 0.574, train accuracy: 0.812
epoch: 7, time: 798.342s, loss: 0.635, train accuracy: 0.781
epoch: 7, time: 803.560s, loss: 0.667, train accuracy

epoch: 13, time: 1425.820s, loss: 0.374, train accuracy: 0.875
epoch: 13, time: 1431.067s, loss: 0.425, train accuracy: 0.844
epoch: 13, time: 1436.295s, loss: 0.489, train accuracy: 0.836
epoch: 13, time: 1441.531s, loss: 0.409, train accuracy: 0.859
epoch: 13, time: 1446.769s, loss: 0.564, train accuracy: 0.805
epoch: 13, time: 1452.008s, loss: 0.737, train accuracy: 0.758
epoch: 13, time: 1457.245s, loss: 0.721, train accuracy: 0.773
epoch: 13, time: 1462.495s, loss: 0.563, train accuracy: 0.805
epoch: 13, time: 1467.738s, loss: 0.362, train accuracy: 0.883
epoch: 13, time: 1472.979s, loss: 0.530, train accuracy: 0.766
epoch: 13, time: 1478.214s, loss: 0.382, train accuracy: 0.867
epoch: 13, time: 1483.499s, loss: 0.599, train accuracy: 0.805
epoch: 13, time: 1488.758s, loss: 0.476, train accuracy: 0.844
epoch: 13, time: 1494.023s, loss: 0.459, train accuracy: 0.812
epoch: 13, time: 1499.293s, loss: 0.423, train accuracy: 0.844
epoch: 13, time: 1504.607s, loss: 0.548, train accuracy

epoch: 19, time: 2117.071s, loss: 0.285, train accuracy: 0.883
epoch: 19, time: 2122.291s, loss: 0.297, train accuracy: 0.867
epoch: 19, time: 2127.528s, loss: 0.351, train accuracy: 0.883
epoch: 19, time: 2132.753s, loss: 0.277, train accuracy: 0.891
epoch: 19, time: 2137.982s, loss: 0.373, train accuracy: 0.883
epoch: 19, time: 2143.214s, loss: 0.319, train accuracy: 0.883
epoch: 19, time: 2148.436s, loss: 0.327, train accuracy: 0.906
epoch: 19, time: 2153.667s, loss: 0.330, train accuracy: 0.891
epoch: 19, time: 2158.886s, loss: 0.295, train accuracy: 0.898
epoch: 19, time: 2164.100s, loss: 0.298, train accuracy: 0.891
epoch: 19, time: 2169.319s, loss: 0.261, train accuracy: 0.930
epoch: 19, time: 2174.552s, loss: 0.291, train accuracy: 0.906
Accuracy on the test set: 0.828


In [28]:
correct_total = 0

for i, (x_batch, y_batch) in enumerate(testloader):
  x_batch, y_batch = x_batch.to(device), y_batch.to(device)  # Move the data to the device that is used

  y_pred = net(x_batch)
  y_pred_max = torch.argmax(y_pred, dim=1)

  correct_total += torch.sum(torch.eq(y_pred_max, y_batch)).item()

print(f'Accuracy on the test set: {correct_total / len(testset):.3f}')

Accuracy on the test set: 0.828
