Mount drive for checkpoints

In [None]:
import os
from google.colab import drive
# Mount Google Drive at /content/gdrive.
drive.mount('/content/gdrive')

# IPython shell escape: list the project folder on the mounted drive.
!ls "/content/gdrive/MyDrive/cnn-architectures/resnet"
# NOTE(review): root_path lacks the '/content' prefix used by the mount point and by `path`,
# and it is not referenced anywhere else in the visible notebook — confirm it is still needed.
root_path = '/gdrive/MyDrive/cnn-architectures/resnet'
path = '/content/gdrive/MyDrive/cnn-architectures/resnet'
# Make the project folder the working directory, so relative paths used later
# (e.g. './data' and the default save_dir) resolve inside it.
os.chdir(path)

Standard imports

In [None]:
import numpy as np
import time

import torch
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F

from torchsummary import summary

import torchvision
import torchvision.transforms as transforms
import torchvision.datasets as datasets

import matplotlib.pyplot as plt


Initialize weights

In [None]:
#use kaiming initialization
def _weights_init(m):
    """
        Initialization of CNN weights
    """
    classname = m.__class__.__name__
    if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d):
        init.kaiming_normal_(m.weight)

Identity layer (Lambda layer)

In [None]:
class IdentityConn(nn.Module):
    """
      Parameter-free module that wraps a callable so it can be used as a layer.

      Serves as the shortcut mapping between ResNet blocks whose feature maps
      differ in size: the wrapped callable is applied to the input as-is.
    """
    def __init__(self, lambd):
        super().__init__()
        # Named ``lambd`` because ``lambda`` is a reserved word in Python.
        self.lambd = lambd

    def forward(self, x):
        shortcut_fn = self.lambd
        return shortcut_fn(x)

Basic block

In [None]:
class BasicBlock(nn.Module):
    """
    Residual block: two 3x3 convolutions, each followed by a batch-norm layer,
    "short-circuited" by a shortcut that is added to the output before the
    final ReLU.

    When the block keeps both the spatial size and the channel count
    (``stride == 1`` and ``in_planes == planes``) the shortcut is the identity.
    Otherwise the shortcut is parameter-free: the input is subsampled by
    ``stride`` in both spatial dimensions and zero-padded along the channel
    dimension up to ``planes`` channels.

    Args:
        in_planes: number of input channels.
        planes: number of output channels.
        stride: stride of the first convolution (and of the shortcut subsampling).

    Raises:
        ValueError: if ``planes < in_planes``; zero-padding can only add channels.
    """
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != planes:
            extra_planes = planes - in_planes
            if extra_planes < 0:
                raise ValueError(
                    "zero-padding shortcut requires planes >= in_planes "
                    "(got in_planes={}, planes={})".format(in_planes, planes))
            # Split the padding around the existing channels. For the doubling case
            # (planes == 2 * in_planes) this equals planes // 4 on each side.
            pad_front = extra_planes // 2
            pad_back = extra_planes - pad_front
            # x[..., ::stride, ::stride] has the same spatial size as conv1's output
            # (3x3 kernel, padding 1, same stride).
            self.shortcut = IdentityConn(
                lambda x: F.pad(x[:, :, ::stride, ::stride],
                                (0, 0, 0, 0, pad_front, pad_back), "constant", 0))

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out

ResNet class

In [None]:
class ResNet(nn.Module):
    """
    ResNet made of a 3x3 stem convolution (3 -> 16 channels), three stacks of
    residual blocks with 16, 32 and 64 planes, global average pooling and a
    linear classifier.

    The first block of the second and third stacks uses a stride of 2 for
    subsampling; every other block uses stride 1.

    Args:
        block: block class, built as ``block(in_planes, planes, stride)``; it
            must expose an ``expansion`` class attribute.
        num_blocks: sequence of three ints, the number of blocks per stack.
        num_classes: size of the classifier output.
    """
    def __init__(self, block, num_blocks, num_classes=10):
        super(ResNet, self).__init__()
        self.in_planes = 16
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(16)
        self.layer1 = self._make_layer(block, 16, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 32, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 64, num_blocks[2], stride=2)
        # The last stack outputs 64 * expansion channels (64 when expansion == 1).
        self.linear = nn.Linear(64 * block.expansion, num_classes)
        self.apply(_weights_init)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Build one stack: the first block uses ``stride``, the remaining ones stride 1."""
        block_strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for block_stride in block_strides:
            layers.append(block(self.in_planes, planes, block_stride))
            # The next block consumes what this one produces.
            self.in_planes = planes * block.expansion

        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        # Global average pooling over the whole feature map, whatever its height/width
        # (a kernel equal to the width alone is only correct for square maps).
        out = F.adaptive_avg_pool2d(out, 1)
        out = torch.flatten(out, 1)
        out = self.linear(out)
        return out


ResNet 56 definition

In [None]:
def resnet56():
    """Build a ResNet of BasicBlocks with 9 blocks in each of its three stacks."""
    blocks_per_stack = [9] * 3
    return ResNet(BasicBlock, blocks_per_stack)

Set hyperparameters

In [None]:
class MyResNetArgs:
    """
    Plain container for the experiment's hyperparameters and run options.

    Every constructor argument is stored unchanged as an attribute of the same
    name. ``half`` is appended as the last parameter with a default of 0, so
    existing positional and keyword calls keep working; it provides the
    ``args.half`` attribute that the train/validate functions read.
    """
    def __init__(self, arch='resnet56', epochs=200, start_epoch=0, batch_size=128, lr=0.1, momentum=0.9,
                 weight_decay=1e-4, print_freq=55, evaluate=0, pretrained=0, save_dir='save_temp',
                 save_every=10, half=0):
        self.save_every = save_every        #Saves checkpoints at every specified number of epochs
        self.save_dir = save_dir            #The directory used to save the trained models
        self.evaluate = evaluate            #evaluate model on the validation set
        self.pretrained = pretrained        #evaluate the pretrained model on the validation set
        self.print_freq = print_freq        #print frequency
        self.weight_decay = weight_decay
        self.momentum = momentum
        self.lr = lr                        #Learning rate
        self.batch_size = batch_size
        self.start_epoch = start_epoch
        self.epochs = epochs
        self.arch = arch                    #ResNet model used
        self.half = half                    #cast inputs to half precision when truthy

Model summary

In [None]:
# Run options for the experiment (defaults except for the explicit values below).
args = MyResNetArgs(arch='resnet56', pretrained=0)

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build the network, move it to the chosen device and print a layer-by-layer
# summary for a 3x32x32 input.
model = resnet56()
model = model.to(device)
summary(model, (3, 32, 32))

# Best validation precision seen so far; updated by main().
best_prec1 = 0

Train function

In [None]:
def train(train_loader, model, criterion, optimizer, epoch):
    """
      Run one training epoch and periodically print the running loss and top-1 precision.

      Reads the module-level ``args`` for ``print_freq`` and the optional ``half`` flag
      (treated as off when ``args`` has no such attribute).

      Args:
          train_loader: iterable yielding ``(input, target)`` batches; must support ``len``
              for the progress message.
          model: network to train; batches are moved to the device of its parameters.
          criterion: loss called as ``criterion(output, target)``.
          optimizer: optimizer stepped once per batch.
          epoch: epoch index, used only in the progress message.
    """
    batch_time = AvgCalc()
    data_time = AvgCalc()
    losses = AvgCalc()
    top1 = AvgCalc()

    # Follow the model's own device instead of assuming CUDA is present.
    first_param = next(model.parameters(), None)
    model_device = first_param.device if first_param is not None else torch.device('cpu')
    use_half = getattr(args, 'half', False)

    model.train()

    end = time.time()
    for i, (input, target) in enumerate(train_loader):

        # measure data loading time
        data_time.update(time.time() - end)
        target = target.to(model_device)
        input_var = input.to(model_device)
        if use_half:
            input_var = input_var.half()

        # compute output
        output = model(input_var)
        loss = criterion(output, target)

        # compute gradient and perform one iteration of SGD
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        output = output.float()
        loss = loss.float()

        # measure accuracy and record loss (top-1%)
        prec1 = accuracy(output.detach(), target)[0]
        losses.update(loss.item(), input.size(0))
        top1.update(prec1.item(), input.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            # batch_time / data_time are passed along but the template below does not print them.
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format(
                      epoch, i, len(train_loader), batch_time=batch_time,
                      data_time=data_time, loss=losses, top1=top1))

Validation function

In [None]:
def validate(val_loader, model, criterion):
    '''
      Evaluate the model on a loader and print the top-1 classification accuracy and error.

      Reads the module-level ``args`` for the optional ``half`` flag (treated as off
      when ``args`` has no such attribute). Gradients are disabled during evaluation.

      Args:
          val_loader: iterable yielding ``(input, target)`` batches.
          model: network to evaluate; batches are moved to the device of its parameters.
          criterion: loss called as ``criterion(output, target)``.

      Returns:
          The sample-weighted average top-1 precision (in percent) over the loader.
    '''

    batch_time = AvgCalc()
    losses = AvgCalc()
    top1 = AvgCalc()

    # Follow the model's own device instead of assuming CUDA is present.
    first_param = next(model.parameters(), None)
    model_device = first_param.device if first_param is not None else torch.device('cpu')
    use_half = getattr(args, 'half', False)

    model.eval()

    end = time.time()
    with torch.no_grad():
        for i, (input, target) in enumerate(val_loader):
            target = target.to(model_device)
            input_var = input.to(model_device)

            if use_half:
                input_var = input_var.half()

            # compute output
            output = model(input_var)
            loss = criterion(output, target)
            output = output.float()
            loss = loss.float()

            # measure accuracy and record loss
            prec1 = accuracy(output.detach(), target)[0]
            losses.update(loss.item(), input.size(0))
            top1.update(prec1.item(), input.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

    print('Test\t  Prec@1: {top1.avg:.3f} (Err: {error:.3f} )\n'
          .format(top1=top1,error=100-top1.avg))

    return top1.avg


Save progress

In [None]:
def save_checkpoint(state, filename='checkpoint.th'):
    """Serialize ``state`` to ``filename`` with ``torch.save``."""
    torch.save(obj=state, f=filename)

Running average of mini-batch metrics (loss, accuracy, timing)

In [None]:
class AvgCalc(object):
    """
    Keeps the most recent value of a metric together with its weighted running average.

    Attributes:
        val: value passed to the latest ``update`` call.
        sum: weighted sum of all values seen since the last reset.
        count: total weight seen since the last reset.
        avg: ``sum / count``.
    """
    def __init__(self):
        self.reset()

    def reset(self):
        """Zero every statistic."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` (e.g. a batch mean and the batch size)."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count

Top-k precision at specified k

In [None]:
def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k

    Args:
        output: score tensor; the top-k indices are taken along dimension 1
            (assumed to be laid out as (batch, classes) — TODO confirm with callers).
        target: tensor of class indices with one entry per sample; its first
            dimension is used as the batch size.
        topk: iterable of k values to report.

    Returns:
        A list with one 0-dim tensor per k, each the percentage of samples whose
        target is among the k highest-scoring classes.
    """
    maxk = max(topk)
    batch_size = target.size(0)

    # Pure bookkeeping: no gradients are needed here.
    with torch.no_grad():
        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            # `correct` inherits the transposed layout of `pred`, so a multi-row slice
            # is not contiguous; reshape copies when needed, whereas view raises.
            correct_k = correct[:k].reshape(-1).float().sum(0)
            res.append(correct_k.mul_(100.0 / batch_size))
    return res


Preprocess data

In [None]:
# Per-channel normalization applied to every image after conversion to a tensor.
# NOTE(review): these values look like the widely used ImageNet channel statistics
# rather than statistics computed on CIFAR-10 — confirm this is intended.

normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

# Training data loader: CIFAR-10 train split under ./data (downloaded if missing),
# augmented with a random horizontal flip and a random 32x32 crop with padding 4,
# shuffled, with the batch size taken from args.
train_loader = torch.utils.data.DataLoader(
        datasets.CIFAR10(root='./data', train=True, transform=transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(32, 4),
            transforms.ToTensor(),
            normalize,
        ]), download=True),
        batch_size=args.batch_size, shuffle=True,
        num_workers=4)

# Validation data loader: CIFAR-10 test split with no augmentation and no shuffling.
# It does not pass download=True, so it relies on the files fetched by the training
# dataset above. The batch size is fixed at 128 rather than read from args.
val_loader = torch.utils.data.DataLoader(
        datasets.CIFAR10(root='./data', train=False, transform=transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ])),
        batch_size=128, shuffle=False,
        num_workers=4)

# Human-readable class names (not referenced elsewhere in the visible notebook).
classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

Experiment...

In [None]:
def main():
    """
      Train ResNet-56 on the module-level data loaders, or only evaluate a saved model.

      Reads the module-level ``args`` and updates the module-level ``best_prec1``.
      With ``args.evaluate`` set, the weights stored in ``<save_dir>/model.th`` are
      loaded and validated once. Otherwise the model is trained from
      ``args.start_epoch`` to ``args.epochs``; the best-scoring weights are written to
      ``<save_dir>/model.th`` and a periodic snapshot to ``<save_dir>/checkpoint.th``.

      Returns:
          The best top-1 validation precision (in percent) that was observed.
    """
    global args, best_prec1

    # Create save_dir if needed; exist_ok avoids the check-then-create race.
    os.makedirs(args.save_dir, exist_ok=True)

    # Use the GPU when available instead of failing on CPU-only runtimes.
    run_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    use_half = getattr(args, 'half', False)

    model = resnet56()
    model.to(run_device)

    # define loss function and optimizer
    criterion = nn.CrossEntropyLoss().to(run_device)

    if use_half:
        # Inputs are cast to half precision in train/validate when this flag is set,
        # so the model and the loss must use the same precision.
        model.half()
        criterion.half()

    optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # NOTE(review): with start_epoch > 0 the scheduler is built in resume mode —
    # confirm the optimizer carries the state the scheduler expects in that case.
    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                        milestones=[100, 150], last_epoch=args.start_epoch - 1)

    if args.evaluate:
        print('evaluate')
        # map_location lets weights saved on a GPU be loaded on whatever device is in use.
        saved_weights = torch.load(os.path.join(args.save_dir, 'model.th'), map_location=run_device)
        model.load_state_dict(saved_weights)
        best_prec1 = validate(val_loader, model, criterion)
        return best_prec1

    for epoch in range(args.start_epoch, args.epochs):

        print('Learning rate {:.5e}'.format(optimizer.param_groups[0]['lr']))
        train(train_loader, model, criterion, optimizer, epoch)
        lr_scheduler.step()

        #evaluate on validation set
        prec1 = validate(val_loader, model, criterion)

        #best precision and save checkpoint
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)

        if epoch > 0 and epoch % args.save_every == 0:
            save_checkpoint(model.state_dict(), filename=os.path.join(args.save_dir, 'checkpoint.th'))
        if is_best:
            save_checkpoint(model.state_dict(), filename=os.path.join(args.save_dir, 'model.th'))

    return best_prec1

In [None]:
if __name__ == '__main__':
    # Run the experiment, then report the lowest error (100 minus the best precision).
    best_prec1 = main()
    print('The lowest error from {} model after {} epochs is {error:.3f}'.format(
        args.arch, args.epochs, error=100 - best_prec1))