In [1]:
import torch
from torch import nn
from torchvision import datasets, transforms
import torch.utils.data as datautils
import torch.optim as optim

In [2]:
class Shortcut(nn.Module):
  def __init__(self, in_channel: int, out_channel: int, stride: int):
    super().__init__()

    self.conv = nn.Conv2d(in_channels=in_channel, out_channels=out_channel, kernel_size=1, stride=stride)
    self.bn = nn.BatchNorm2d(num_features=out_channel)
  
  def forward(self, x: torch.Tensor):
    x = self.bn(self.conv(x))
    return x

In [3]:
class ResidualBlock(nn.Module):
  def __init__(self, in_channel:int, out_channel:int, stride: int):
    super().__init__()

    self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=out_channel, kernel_size=3, stride=stride, padding=1)
    self.bn1 = nn.BatchNorm2d(num_features=out_channel)
    self.act1 = nn.ReLU()
    
    self.conv2 = nn.Conv2d(in_channels=out_channel, out_channels=out_channel, kernel_size=3, stride=1, padding=1)
    self.bn2 = nn.BatchNorm2d(num_features=out_channel)
    
    if stride == 1:
      self.shortcut = nn.Identity()
    else:
      self.shortcut = Shortcut(in_channel=in_channel, out_channel=out_channel, stride=stride)

    self.act2 = nn.ReLU()
  
  def forward(self, x: torch.Tensor):

    shortcut = self.shortcut(x)

    x = self.act1(self.bn1(self.conv1(x)))
    x = self.bn2(self.conv2(x))

    x = self.act2(x + shortcut)

    return x

In [4]:
class ResnetBase(nn.Module):
  def __init__(self, n_blocks, n_channels, img_channels, first_kernel_size):
    super().__init__()

    assert len(n_blocks) == len(n_channels)

    self.conv1 = nn.Conv2d(in_channels=img_channels, out_channels=n_channels[0], stride=2, kernel_size=first_kernel_size, padding=first_kernel_size//2)
    self.bn1 = nn.BatchNorm2d(num_features=n_channels[0])

    previous_channel = n_channels[0]

    blocks = []
    for stage_i in range(len(n_blocks)):

      if len(blocks) == 0:
        stride = 1
      else:
        stride = 2
      
      # first block in stage
      blocks.append(ResidualBlock(in_channel=previous_channel, out_channel=n_channels[stage_i], stride=stride))

      previous_channel = n_channels[stage_i]

      for block_i in range(1, n_blocks[stage_i]):
          blocks.append(ResidualBlock(in_channel=n_channels[stage_i], out_channel=n_channels[stage_i], stride=1))

    self.blocks = nn.Sequential(*blocks)
  
  def forward(self, x: torch.Tensor):

    x = self.bn1(self.conv1(x))
    x = self.blocks(x)

    # global average pooling
    x = x.view(x.shape[0], x.shape[1], -1)
    x = x.mean(dim=-1)  
    
    return x

In [5]:
def train(model, batch_size, max_epoch, device):

  # transforms for dataset
  normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
  preprocessing_train = transforms.Compose([
    transforms.RandomResizedCrop(32),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    normalize,
  ])
  preprocessing_val = transforms.Compose([
    transforms.Resize(36),
    transforms.CenterCrop(32),
    transforms.ToTensor(),
    normalize,
  ])

  # Dataset
  dataset_train = datasets.CIFAR10(root='./dataset', train=True, transform=preprocessing_train, download=True)
  dataset_val = datasets.CIFAR10(root='./dataset', train=False, transform=preprocessing_val, download=True)

  # Dataloaders
  dataloader_train = datautils.DataLoader(dataset=dataset_train, batch_size=batch_size, shuffle=True, num_workers=2)
  dataloader_val = datautils.DataLoader(dataset=dataset_val, batch_size=batch_size, shuffle=False, num_workers=2)
  
  # Move model to device (GPU,...)
  model.to(device)

  # Loss function
  criterion = nn.CrossEntropyLoss()
  
  # Optimization algorithm
  optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=0.0001)
  scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.1)

  # Train for number of epochs
  for epoch in range(max_epoch):
  
      # Loss of epoch
      running_loss = 0.0
  
      # Iterate over all batches
      for i, data in enumerate(dataloader_train, 0):
  
          images, labels = data
  
          # Makes all gradients zero
          optimizer.zero_grad()
  
          # Feedforward
          images = images.to(device)
          output = model(images)
  
          # Loss
          labels = labels.to(device)
          loss = criterion(output, labels)
  
          # Backpropagation
          loss.backward()
  
          # Update parameters
          optimizer.step()
  
          running_loss += loss.item()
  
          if i % 100 == 0 or i == len(dataloader_train)-1:
              print('epoch {}, mini-batch {}, loss {}'.format(epoch, i + 1, running_loss/1000.0))
              running_loss = 0.0
      
      # end of epoch
      scheduler.step()

      if epoch % 5 == 0 or epoch == max_epoch - 1:
        # accuracy on validation set
        correct = 0
        total = 0
        with torch.no_grad():
            for data in dataloader_val:
                images, ground_truth_labels = data
                images = images.to(device)
                ground_truth_labels = ground_truth_labels.to(device)
                outputs = model(images)
                _, predicted = torch.max(outputs, dim=1)
                total += ground_truth_labels.size(0)
                correct += (ground_truth_labels==predicted).sum().item()
        print('>  Accuracy on test set is: {}%'.format(correct/total * 100))
  

In [None]:
# if gpu available use gpu otherwise use cpu
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print('Device: {}'.format(device))

# ResNet-44
model = ResnetBase(n_blocks=[15, 14, 14, 1], n_channels=[16, 32, 64, 10], img_channels=3, first_kernel_size=3)

train(model, batch_size=128, max_epoch=320, device=device)

Device: cuda:0
Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./dataset/cifar-10-python.tar.gz


  0%|          | 0/170498071 [00:00<?, ?it/s]

Extracting ./dataset/cifar-10-python.tar.gz to ./dataset
Files already downloaded and verified
epoch 0, mini-batch 1, loss 0.002348513603210449
epoch 0, mini-batch 101, loss 0.21305034899711608
epoch 0, mini-batch 201, loss 0.19723981618881226
epoch 0, mini-batch 301, loss 0.18990710651874543
epoch 0, mini-batch 391, loss 0.16701935803890228
>  Accuracy on test set is: 39.2%
epoch 1, mini-batch 1, loss 0.0019061276912689209
epoch 1, mini-batch 101, loss 0.1812601865530014
epoch 1, mini-batch 201, loss 0.17881095540523528
epoch 1, mini-batch 301, loss 0.17756730461120607
epoch 1, mini-batch 391, loss 0.15665478003025055
epoch 2, mini-batch 1, loss 0.0015686858892440796
epoch 2, mini-batch 101, loss 0.17251303172111512
epoch 2, mini-batch 201, loss 0.17028216016292572
epoch 2, mini-batch 301, loss 0.16783498585224152
epoch 2, mini-batch 391, loss 0.14829655528068542
epoch 3, mini-batch 1, loss 0.0017497577667236327
epoch 3, mini-batch 101, loss 0.16269482886791228
epoch 3, mini-batch 201