In [4]:
import torch

import torch.nn as nn
import numpy as np
import datetime
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.data import sampler

import torchvision.datasets as dset
import torchvision.transforms as T
import torch.nn.functional as F

Load the data and apply transform as proposed in the paper (although this is for ImageNet instead of CIFAR10): horizontal flip and subtraction of per pixel means.

In [9]:
NUM_TRAIN = 49000

dtype = torch.float32

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)

bs=64
transform = T.Compose([
                T.ToTensor(),
                T.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
                T.RandomHorizontalFlip()
            ])

cifar10_train = dset.CIFAR10('/data', train=True, download=True,
                             transform=transform)
loader_train = DataLoader(cifar10_train, batch_size=bs, 
                          sampler=sampler.SubsetRandomSampler(range(NUM_TRAIN)))

cifar10_val = dset.CIFAR10('/data', train=True, download=True,
                           transform=transform)
loader_val = DataLoader(cifar10_val, batch_size=bs, 
                        sampler=sampler.SubsetRandomSampler(range(NUM_TRAIN, 50000)))

cifar10_test = dset.CIFAR10('/data', train=False, download=True, 
                            transform=transform)
loader_test = DataLoader(cifar10_test, batch_size=bs)

cuda
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


We use the renet structure, composing our network using `resblocks` which skip certain connections to the next layer. The standardstrategy is to define a building block, such as a (Conv2d, ReLU, Conv2d) + skip
connection block, and then build the network using the `nn.sequential` api. 

In [2]:

class ResBlock(nn.Module):
    '''
    Creating one resblock.
    With Conv1+bn+ReLU+Conv2+bn+ReLU
    '''
    def __init__(self, in_chans, out_chans, stride=1, downsample=None):

        super().__init__()

        self.conv1 = nn.Conv2d(in_chans,out_chans,kernel_size=3,padding=1,stride=stride)

        torch.nn.init.kaiming_normal_(self.conv1.weight,\
           nonlinearity='relu')
        
        self.batch_norm1 = nn.BatchNorm2d(num_features=out_chans)
        # Note the second conv layer does not have stride to avoid reducing too much
        self.conv2 = nn.Conv2d(out_chans,out_chans,kernel_size=3,padding=1)
        torch.nn.init.kaiming_normal_(self.conv2.weight,\
           nonlinearity='relu')

        self.batch_norm2 = nn.BatchNorm2d(num_features=out_chans)

        self.downsample = downsample

    def forward(self, x):
        residual = x
       
        out = self.conv1(x)
        
        out = self.batch_norm1(out)
        out = torch.relu(out)
        out = self.conv2(out)
        
        if self.downsample:
            # Downsampling the residual if necessary
            residual = self.downsample(x)
        out += residual
        out = torch.relu(out)
        
        return out


The structure is $Conv2d+Subsampling+ResBlock \ Layers+AvgPool+FullyConnected$. We follow the similar structure proposed in the seminal paper by Kaiming He et al., which adopts conv layers with increasing filters and residual blocks.

In [6]:
class NetResDeep(nn.Module):
    def __init__(self, block, layers,in_channel=16):
        '''
        Initialize the NetResDeep network.

        Inputs:
        - block: a nn.Module object like a ResBlock 
        - layers: a list of integers specifying the number of extra resblocks
        - in_channel: an integer specifying the first downsampled size
        '''
        super().__init__()

        self.in_channel = in_channel
        
        self.conv1 = nn.Conv2d(3,self.in_channel,kernel_size=3,padding=1)
        self.bn = nn.BatchNorm2d(num_features=in_channel)
        
        self.layer1 = self.make_layer(block, 16, layers[0])
        self.layer2 = self.make_layer(block, 32, layers[1], 2) # Increasing stride but also increasing out channels
        self.layer3 = self.make_layer(block, 64, layers[2], 2)

        self.avg_pool = nn.AvgPool2d(kernel_size=8)

        self.fc1 = nn.Linear(64, 10)
        #self.fc2 = nn.Linear(32, 10)
        self.relu = nn.ReLU()

    def forward(self, x):
        out = self.conv1(x)
        out = self.bn(out)
        out = self.relu(out)

        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)

        out = self.avg_pool(out)
        out = out.view(out.size(0), -1)

        out = self.fc1(out)
        #out = self.fc2(self.relu(out))
        return out

    def make_layer(self, block, out_channels, blocks, stride=1):
        '''
        Making ResBlock layers.
        Inputs:
        - block: a nn.Module object like a ResBlock 
        - out_channels: an integer of the no. of channels that object should have 
        - blocks: an integer specifying how many such blocks is needed further
                  as the first block is used to make in_chans = out_chans
        - stride: an integer specifying stride, default 1
    
        Returns: nn.sequential object 
        '''
        downsample = None # Default no down sampling
        if (stride!=1) or (self.in_channel!=out_channels):
            # With downsampling as striding with stride specified
            downsample = nn.Sequential(\
                nn.Conv2d(self.in_channel, out_channels,stride=stride,kernel_size=3,padding=1),\
                    nn.BatchNorm2d(out_channels)) 
        block_layers = []
        # The blocks are the residual blocks above and can be any other types of blocks
        # Making the first ResBlock that downsample the original input
        
        block_layers.append(block(self.in_channel, out_channels,stride,downsample))

        # Next, set the in_channels to be the same as out_channels for consecutive blocks
        self.in_channel = out_channels
        # Concantenate the blocks
        for i in range(1, blocks):
            block_layers.append(block(out_channels, out_channels))
        return nn.Sequential(*block_layers)


def test_NetResDeep():
    x = torch.zeros((64, 3, 32, 32), dtype = dtype)  # minibatch size 64, image size [3, 32, 32]
    model = NetResDeep(ResBlock,[2,2,2])
    scores = model(x)
    print(x.shape)
    print(scores.size())  # you should see [64, 10]
test_NetResDeep()


torch.Size([64, 3, 32, 32])
torch.Size([64, 10])


We define two functions for training and accuracy checking. Functions are adopted from Stanford cs231n courses

In [10]:
def check_accuracy(loader, model):
    """
    Check accuracy of a model on CIFAR-10 using the PyTorch Module API.
    
    Inputs:
    - loader: a DataLoader object
    - model: A PyTorch Module giving the model to evaluate.
    
    Returns: Nothing, but prints model accuracies during training.
    """
    if loader.dataset.train:
        print('Checking accuracy on validation set')
    else:
        print('Checking accuracy on test set')   
    num_correct = 0
    num_samples = 0
    model.eval()  # set model to evaluation mode
    with torch.no_grad():
        for x, y in loader:
            x = x.to(device=device, dtype=dtype)  # move to device, e.g. GPU
            y = y.to(device=device, dtype=torch.long)
            scores = model(x)
            _, preds = scores.max(1)
            num_correct += (preds == y).sum()
            num_samples += preds.size(0)
        acc = float(num_correct) / num_samples
        print(f'{datetime.datetime.now()} Got {num_correct} / {num_samples} correct ({100*acc})')

In [13]:
def train(model, optimizer, epochs=1,print_every = 100):
    """
    Train a model on CIFAR-10 using the PyTorch Module API.
    
    Inputs:
    - model: A PyTorch Module giving the model to train.
    - optimizer: An Optimizer object we will use to train the model
    - epochs: (Optional) A Python integer giving the number of epochs to train for
    
    Returns: Nothing, but prints model accuracies during training.
    """
    model = model.to(device=device)  # move the model parameters to CPU/GPU
    for e in range(epochs):
        for t, (x, y) in enumerate(loader_train):
            model.train()  # put model to training mode
            x = x.to(device=device, dtype=dtype)  # move to device, e.g. GPU
            y = y.to(device=device, dtype=torch.long)

            scores = model(x)
            loss = F.cross_entropy(scores, y)

            # Zero out all of the gradients for the variables which the optimizer
            # will update.
            optimizer.zero_grad()

            # This is the backwards pass: compute the gradient of the loss with
            # respect to each  parameter of the model.
            loss.backward()
            
            # Actually update the parameters of the model using the gradients
            # computed by the backwards pass.
            optimizer.step()

            if t % print_every == 0:
                print('Epoch %d Iteration %d, loss = %.4f' % (e, t, loss.item()))
                check_accuracy(loader_val, model)
                print()

Train the model with $2$ extra blocks for each layer using optimizer $Adam$.

In [14]:
model = NetResDeep(ResBlock,[2,2,2])
optimizer = optim.Adam(model.parameters(),lr=1e-3)
train(model,optimizer)

Epoch 0 Iteration 0, loss = 2.3735
Checking accuracy on validation set
2021-12-18 22:32:59.944778 Got 105 / 1000 correct (10.5)

Epoch 0 Iteration 100, loss = 1.8104
Checking accuracy on validation set
2021-12-18 22:33:04.234409 Got 391 / 1000 correct (39.1)

Epoch 0 Iteration 200, loss = 1.6882
Checking accuracy on validation set
2021-12-18 22:33:08.926244 Got 412 / 1000 correct (41.199999999999996)

Epoch 0 Iteration 300, loss = 1.2090
Checking accuracy on validation set
2021-12-18 22:33:15.149018 Got 363 / 1000 correct (36.3)

Epoch 0 Iteration 400, loss = 1.2791
Checking accuracy on validation set
2021-12-18 22:33:20.753635 Got 465 / 1000 correct (46.5)

Epoch 0 Iteration 500, loss = 1.3510
Checking accuracy on validation set
2021-12-18 22:33:26.318245 Got 548 / 1000 correct (54.800000000000004)

Epoch 0 Iteration 600, loss = 1.0987
Checking accuracy on validation set
2021-12-18 22:33:32.172055 Got 497 / 1000 correct (49.7)

Epoch 0 Iteration 700, loss = 1.2117
Checking accuracy on

In [None]:
torch.save(model.state_dict(), 'NetResDeep.ckpt')