<a href="https://colab.research.google.com/github/mshadloo/Resnet_cifar10_pytorch/blob/main/resnet_cigar10_pytorch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

 ## Residual Networks
In theory, very deep networks can represent very complex functions; but in practice, because of vanishing/exploding gradient they are hard to train. Residual Networks, introduced by [He et al.](https://arxiv.org/pdf/1512.03385.pdf), allow you to train much deeper networks than were previously practically feasible.

In this notebook, I implemented the basic building blocks of ResNets and the plain/residual networks for CIFAR10 as described in the [original paper](https://arxiv.org/abs/1512.03385).

In plain/residual network for CIFAR10, the network inputs are $ 32 \times 32 \times 3 $ images. The first layer is $3 \times 3 $ convolutions. Then, there are 3 stages. Each stage is a stack of $2n$ layers, where $n$ is the number of basic blocks of each stage. Each basic block is a stack of 2 layers of $3 \times 3$ convolutions. The output of 3 stages are feature maps of sizes $\{ 32, 16, 8\}$ respectively, and the number of filters in each layer of 3 stages are $\{ 16, 32, 64\}$ respectively. In residual network, there is shortcut path from input of the basic block to its output, while in plain network, there is only main path.

In the cell below, you can see the implementation of the basic building block of ResNets in the class "BasicBlock", and the implementation of ResNets for CIFAR10 in the class "ResNet". The parameter $n$, the number of basic blocks in each stage of the ResNet, determines the number of layers of the model. At the end of the following cell, I provide ResNet20, ResNet32, ResNet44, ResNet56 and ResNet110 by setting this parameter. To get the corresponding plain networks, you only need to call these functions with the argument "plain" set to True.

In [None]:

import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init




class Lambda(nn.Module):
    """Adapter that lets an arbitrary callable be used as an ``nn.Module``.

    The wrapped callable is stored on ``self.func`` and invoked with the
    module's input; the module itself registers no parameters.
    """

    def __init__(self, func):
        super().__init__()
        self.func = func

    def forward(self, x):
        """Return ``self.func(x)``."""
        result = self.func(x)
        return result



class BasicBlock(nn.Module):
    """Two stacked 3x3 conv + batch-norm layers with an optional shortcut.

    Args:
        in_channels: number of channels of the block input.
        out_channels: number of channels produced by both convolutions.
        stride: stride of the first convolution (an int); the second
            convolution always uses stride 1.
        projection: when the shortcut has to change shape, use a strided 1x1
            convolution + batch norm (option B in the paper) instead of the
            parameter-free subsample-and-zero-pad shortcut (option A).
        plain: if True the shortcut is not added in ``forward`` and the block
            behaves as a plain (non-residual) stack.

    Raises:
        ValueError: if a residual (non-plain) option-A shortcut would have to
            reduce the channel count, which zero-padding cannot do.
    """

    def __init__(self, in_channels, out_channels, stride=1, projection=False, plain=False):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        # Empty Sequential acts as the identity when input and output shapes agree.
        self.shortcut = nn.Sequential()
        self.plain = plain
        if stride != 1 or in_channels != out_channels:
            if projection:
                # projection shortcut, as option B in paper
                self.shortcut = nn.Sequential(
                    nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                    nn.BatchNorm2d(out_channels))
            else:
                # identity shortcut, as option A in paper
                if out_channels < in_channels and not plain:
                    raise ValueError(
                        "option A shortcut cannot reduce channels "
                        "({} -> {}); use projection=True".format(in_channels, out_channels))
                # Derive the padding from the actual channel gap and the
                # subsampling from the actual stride, so the shortcut matches
                # the main path for every (stride, channels) combination.
                # For the 16->32 and 32->64 stride-2 transitions this equals
                # the previous hard-coded ``out_channels // 4`` / ``::2``.
                extra = out_channels - in_channels
                pad_front = extra // 2
                pad_back = extra - pad_front
                self.shortcut = Lambda(
                    lambda x: F.pad(x[:, :, ::stride, ::stride],
                                    (0, 0, 0, 0, pad_front, pad_back), "constant", 0))

    def forward(self, x):
        """Apply conv-bn-relu, conv-bn, add the shortcut unless plain, then ReLU."""
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        if not self.plain:
            out += self.shortcut(x)
        out = F.relu(out)
        return out



def init_weights(m):
    """Re-initialise the weight of a Linear or Conv2d module in place.

    Uses ``init.kaiming_normal_`` with its default arguments. Any other
    module type is left untouched, and biases are never modified.
    Intended to be passed to ``nn.Module.apply``.
    """
    if not isinstance(m, (nn.Linear, nn.Conv2d)):
        return
    init.kaiming_normal_(m.weight)


class ResNet(nn.Module):
    """Plain/residual network for CIFAR-10 sized inputs.

    Layout: a 3x3 stem convolution with 16 filters, three stages of
    ``num_blocks`` BasicBlocks with 16, 32 and 64 filters (the second and
    third stage start with a stride-2 block), average pooling over the final
    feature map and a linear classifier.

    Args:
        num_blocks: number of BasicBlocks per stage.
        num_classes: size of the classifier output.
        plain: forwarded to every BasicBlock; True disables the shortcuts.
    """

    def __init__(self, num_blocks = 5, num_classes=10, plain = False):
        super().__init__()
        # Channel count of the tensor entering the next block; updated by
        # residual_layer() as stages are built.
        self.in_channels = 16
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(16)
        # Stage 1: 16 filters, resolution kept (16 x 32 x 32 for 32 x 32 inputs)
        self.layer1 = self.residual_layer(16, num_blocks, stride=1, plain=plain)
        # Stage 2: 32 filters, resolution halved (32 x 16 x 16)
        self.layer2 = self.residual_layer(32, num_blocks, stride=2, plain=plain)
        # Stage 3: 64 filters, resolution halved again (64 x 8 x 8)
        self.layer3 = self.residual_layer(64, num_blocks, stride=2, plain=plain)
        self.linear = nn.Linear(64, num_classes)
        # Re-initialise every Linear/Conv2d weight in the network.
        self.apply(init_weights)

    def forward(self, x):
        """Return the class scores for a batch of images."""
        feats = F.relu(self.bn1(self.conv1(x)))
        for stage in (self.layer1, self.layer2, self.layer3):
            feats = stage(feats)
        # Pool with a window as wide as the feature map, then flatten per sample.
        pooled = F.avg_pool2d(feats, feats.size(3))
        flat = pooled.reshape(pooled.size(0), -1)
        return self.linear(flat)

    def residual_layer(self, out_channels, num_blocks, stride, plain = False):
        """Build one stage: a first block with ``stride``, the rest with stride 1.

        Reads and updates ``self.in_channels`` so consecutive stages chain.
        At least one block is always created.
        """
        block_strides = [stride] + [1] * (num_blocks - 1)
        blocks = []
        for block_stride in block_strides:
            blocks.append(BasicBlock(self.in_channels, out_channels, block_stride, plain=plain))
            self.in_channels = out_channels
        return nn.Sequential(*blocks)
def resnet20(plain=False):
    """Build the network with 3 blocks per stage (20 layers); plain=True disables shortcuts."""
    return ResNet(num_blocks=3, plain=plain)


def resnet32(plain=False):
    """Build the network with 5 blocks per stage (32 layers); plain=True disables shortcuts."""
    return ResNet(num_blocks=5, plain=plain)


def resnet44(plain=False):
    """Build the network with 7 blocks per stage (44 layers); plain=True disables shortcuts."""
    return ResNet(num_blocks=7, plain=plain)


def resnet56(plain=False):
    """Build the network with 9 blocks per stage (56 layers); plain=True disables shortcuts."""
    return ResNet(num_blocks=9, plain=plain)


def resnet110(plain=False):
    """Build the network with 18 blocks per stage (110 layers); plain=True disables shortcuts."""
    return ResNet(num_blocks=18, plain=plain)


## Training and Evaluating

### Google Colab
To train the provided models, you need a GPU. You can use the GPUs of Google Colab. If you are using Google Colab and the runtime restarts during training, you will lose your trained model and have to start again from scratch, which is not optimal. Instead, you should save your model checkpoint to Google Drive and reload it the next time you start. To do so, you need to mount your Google Drive and give Google Colab permission to access it. You also need to create a directory in your Google Drive for this project.

In [None]:
from google.colab import drive
# Mount Google Drive at /content/drive so the files written below are kept
# on Drive rather than on the Colab runtime's local disk.
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# Project directory on Google Drive; the dataset and save directories are created under it.
# NOTE(review): this is a relative path, so it points into the /content/drive
# mount only when the current working directory is /content — confirm.
working_dir = "./drive/MyDrive/res_cifar10_pytorch"

### Hyperparameters

In [None]:
import easydict
# Run configuration, accessed with attribute syntax (args.lr, args.epochs, ...).
args = easydict.EasyDict({
    "res_model": "resnet32",    # key into the model registry
    "epochs": 200,
    "batch_size": 128,          # training batch size
    "lr": 0.1,                  # initial SGD learning rate
    "momentum": 0.9,
    "weight_decay": 5e-4,
    "save_dir": "save_dir",     # sub-directory for checkpoints and results
    "data_dir": "data",         # sub-directory for the dataset
    "plain_mode": False,        # True builds the plain (no-shortcut) network
    "resume": "",               # NOTE(review): not read by any cell visible here
})

Let's run the cell below to load the required packages and set the directories for saving data, models and results.

In [None]:
import argparse
import os
import pickle
import time

import torch
import torchvision
import torchvision.transforms as transforms
import torch.backends.cudnn as cudnn
import torch.nn as nn

# Registry mapping each supported model name to its constructor.
models = {
    builder.__name__: builder
    for builder in (resnet20, resnet32, resnet44, resnet56, resnet110)
}

# Dataset and output directories, both rooted at the working directory.
data_dir = os.path.join(working_dir, args.data_dir)
save_dir = os.path.join(working_dir, args.save_dir)

# Base name for the result and checkpoint files; plain runs get a suffix so
# they do not overwrite the residual run of the same depth.
model_name = args.res_model + ("_plain5" if args.plain_mode else "")
checkpoint_name = model_name + "_checkpoint.th"

### Dataset
I load and normalize CIFAR10 training and test datasets using torchvision. The training dataset contains 50K images and the test dataset contains 10K images.

In [None]:
def load_data():
    """Create the CIFAR10 training and test data loaders.

    Training images are randomly cropped (32x32 after 4-pixel padding) and
    randomly flipped horizontally; both splits are converted to tensors and
    normalised with the same per-channel constants. The dataset is downloaded
    into ``data_dir`` if needed.

    Returns:
        (train_loader, test_loader): the training loader shuffles and uses
        ``args.batch_size``; the test loader keeps order and uses batches of 100.
    """
    print('==> Loading data..')

    normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
    train_transform = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    test_transform = transforms.Compose([transforms.ToTensor(), normalize])

    def build_loader(is_train, transform, batch_size):
        # Only the training split is shuffled.
        dataset = torchvision.datasets.CIFAR10(
            root=data_dir, train=is_train, download=True, transform=transform)
        return torch.utils.data.DataLoader(
            dataset, batch_size=batch_size, shuffle=is_train, num_workers=4)

    train_loader = build_loader(True, train_transform, args.batch_size)
    test_loader = build_loader(False, test_transform, 100)
    return train_loader, test_loader

### One step of training:

In [None]:
def make_train_step(train_loader, model, loss_fn, optimizer):
    """Return a ``train_step(epoch)`` closure that trains for one epoch.

    The closure iterates once over ``train_loader``, performing a
    forward/backward pass and an optimizer step per batch, and prints a
    progress report every 50 batches.

    Returns (from the closure):
        (mean_loss, accuracy_percent) over all samples seen in the epoch.
    """
    def train_step(epoch):
        model.train()
        loss_sum = 0
        num_correct = 0
        num_seen = 0
        use_cuda = torch.cuda.is_available()

        for step_ix, (inputs, targets) in enumerate(train_loader):
            # Timer starts after the batch has been fetched from the loader.
            batch_start = time.time()

            if use_cuda:
                inputs = inputs.cuda()
                targets = targets.cuda()

            optimizer.zero_grad()
            logits = model(inputs)
            batch_loss = loss_fn(logits, targets)
            batch_loss.backward()
            optimizer.step()

            predicted = logits.float().max(1)[1]
            n_samples = inputs.size(0)
            # Weight the batch loss by the batch size so the epoch value is a
            # per-sample average even if the last batch is smaller.
            loss_sum += batch_loss.float().item() * n_samples
            num_seen += n_samples
            num_correct += predicted.eq(targets).sum().item()

            if step_ix % 50 != 0:
                continue
            print('Epoch: [{0}][{1}/{2}]\t'.format(
                epoch, step_ix, len(train_loader)))
            print("It took {:.3f}s".format(
                time.time() - batch_start))
            print("  training loss:\t\t{:.6f}".format(loss_sum / (num_seen)))

        mean_loss = loss_sum / num_seen
        accuracy = 100. * num_correct / num_seen
        return mean_loss, accuracy
    return train_step

### Evaluation

In [None]:
def evaluate(test_loader, model, loss_fn):
    """Measure loss and accuracy of ``model`` over ``test_loader``.

    The model is switched to eval mode and gradients are disabled for the
    whole pass.

    Returns:
        (mean_loss, accuracy_percent) over all samples in the loader.
    """
    model.eval()
    loss_sum = 0
    num_correct = 0
    num_seen = 0
    use_cuda = torch.cuda.is_available()

    with torch.no_grad():
        for inputs, targets in test_loader:
            if use_cuda:
                inputs = inputs.cuda()
                targets = targets.cuda()

            logits = model(inputs).float()
            batch_loss = loss_fn(logits, targets).float()
            n_samples = targets.size(0)

            # Batch loss is weighted by batch size to get a per-sample mean.
            loss_sum += batch_loss.item() * n_samples
            predicted = logits.max(1)[1]
            num_seen += n_samples
            num_correct += predicted.eq(targets).sum().item()

    mean_loss = loss_sum / num_seen
    accuracy = 100. * num_correct / num_seen
    return mean_loss, accuracy

### Saving results

In [None]:
def save_results(test_acc, train_acc, best_acc):
    """Pickle the accuracy results to ``<save_dir>/<model_name>.pkl``.

    Args:
        test_acc: stored under the key 'test_acc'.
        train_acc: stored under the key 'train_acc'.
        best_acc: stored under the key 'best_acc'.

    The file is written inside a ``with`` block so the handle is closed even
    if pickling raises.
    """
    results = {
        'test_acc': test_acc,
        'train_acc': train_acc,
        'best_acc': best_acc,
    }
    out_path = os.path.join(save_dir, model_name) + ".pkl"
    with open(out_path, "wb") as handle:
        pickle.dump(results, handle)

### Train

In [None]:
  # Full training run: build data loaders and model, train for args.epochs
  # epochs, checkpoint whenever test accuracy improves, then pickle the
  # accuracy curves.
  best_acc = 0

  train_loader, test_loader = load_data()
  print("==>Building model...")
  model = models[args.res_model](plain = args.plain_mode)
  if torch.cuda.is_available():
      model.cuda()
      model = torch.nn.DataParallel(model)
      cudnn.benchmark = True
  # exist_ok avoids the separate existence check before creating the directory.
  os.makedirs(save_dir, exist_ok=True)

  loss_fn = nn.CrossEntropyLoss()
  if torch.cuda.is_available():
      loss_fn = loss_fn.cuda()


  optimizer = torch.optim.SGD(model.parameters(), args.lr,
                              momentum=args.momentum,
                              weight_decay=args.weight_decay)

  # Learning rate is multiplied by 0.1 at each listed epoch.
  lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[60,100,150,180],gamma=0.1, last_epoch=-1)
  # Alternative schedules:
  # lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200)
  # lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.1, last_epoch=-1)


  loss_train, acc_train, loss_test, acc_test = [], [], [], []
  train_step = make_train_step(train_loader, model, loss_fn, optimizer)
  print("==>Starting training...")
  for epoch in range(args.epochs):

      curr_time = time.time()
      loss, acc = train_step(epoch)
      train_time = time.time() - curr_time


      loss_train.append(loss)
      acc_train.append(acc)
      # Advance the schedule once per epoch, after that epoch's optimizer steps.
      lr_scheduler.step()

      # From here on loss/acc refer to the test set.
      loss, acc = evaluate(test_loader, model, loss_fn)


      loss_test.append(loss)
      acc_test.append(acc)

      # Keep only the checkpoint with the best test accuracy so far.
      if acc > best_acc:
          print('Saving..')
          state = {
              'state_dict': model.state_dict(),
              'acc': acc,
              'epoch': epoch,
              'optimizer_state_dict': optimizer.state_dict(),
          }
          filename = os.path.join(save_dir, checkpoint_name)
          torch.save(state, filename)
          best_acc = acc
      print("Epoch {} of {} took {:.3f}s".format(
          epoch + 1, args.epochs, train_time))
      print("  training loss:\t\t{:.6f}".format(loss_train[-1]))
      print("  test loss:\t\t{:.6f}".format(loss_test[-1]))
      print("  test accuracy:\t\t{:.2f} %".format(
          acc_test[-1]))
  save_results(acc_test, acc_train, best_acc)
  # best_acc is a percentage; report the matching error as a fraction.
  print("best error:", 1 - (best_acc / 100.))