<a href="https://colab.research.google.com/github/klinime/Intro_to_CNN/blob/master/PyramidNet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
import numpy as np
import pandas as pd
import torch
import torchvision
from torch import nn, optim
from torchvision import transforms
import matplotlib.pyplot as plt
import glob, csv
import time, datetime

# Enlarge matplotlib's default figure size for the plots drawn in this notebook.
plt.rcParams.update({'figure.figsize': [20, 10]})
# Colab-only: mount Google Drive under /content/gdrive.
from google.colab import drive
drive.mount('/content/gdrive')

In [0]:
# Credit PyramidNet:     https://arxiv.org/pdf/1610.02915.pdf
# Credit DSConv:         https://arxiv.org/pdf/1610.02357.pdf
# Credit ShakeDrop:      https://arxiv.org/pdf/1802.02375.pdf
# Credit Cutout:         https://arxiv.org/pdf/1708.04552.pdf
# Credit Implementation: https://github.com/osmr/imgclsmob/blob/master/pytorch/pytorchcv/models
# Credit Implementation: https://github.com/uoguelph-mlrg/Cutout

class ShakeDrop(torch.autograd.Function):
  """ShakeDrop perturbation implemented as a custom autograd function.

  Forward multiplies ``x`` by ``b + alpha - b * alpha``. Backward multiplies
  the incoming gradient by ``b + beta - b * beta``, where ``beta`` is drawn
  fresh from U[0, 1) for every sample. With ``b == 1`` both factors are 1,
  so the op is the identity in both directions.

  Forward args:
    x: residual-branch activations. Backward reshapes ``beta`` to
      (N, 1, 1, 1), so a 4-D (N, C, H, W) tensor is assumed.
    b: gate tensor (expected to hold 0 or 1), broadcastable against ``x``.
    alpha: forward perturbation, broadcastable against ``x``.
  """

  @staticmethod
  def forward(ctx, x, b, alpha):
    ctx.save_for_backward(b)
    scale = b + alpha - b * alpha
    return scale * x

  @staticmethod
  def backward(ctx, dy):
    batch_beta = torch.rand(dy.size(0), dtype=dy.dtype, device=dy.device)
    batch_beta = batch_beta.view(-1, 1, 1, 1)
    (gate,) = ctx.saved_tensors
    grad_x = (gate + batch_beta - gate * batch_beta) * dy
    # b and alpha are sampled, not learned, so they receive no gradient.
    return grad_x, None, None

class DSConv(nn.Module):
  """Depthwise-separable 3x3 convolution without biases.

  A per-channel 3x3 convolution (``groups=in_channels``, padding 1, the
  given stride) is followed by a 1x1 pointwise convolution that maps
  ``in_channels`` to ``out_channels``.
  """

  def __init__(self, in_channels, out_channels, stride):
    super(DSConv, self).__init__()
    # Attribute names are part of the state_dict keys; keep them stable.
    self.depthwise = nn.Conv2d(
        in_channels, in_channels, 3,
        stride=stride, padding=1, groups=in_channels, bias=False)
    self.pointwise = nn.Conv2d(in_channels, out_channels, 1, bias=False)

  def forward(self, x):
    per_channel = self.depthwise(x)
    return self.pointwise(per_channel)

class Bottleneck(nn.Module):
  """Pyramidal bottleneck residual unit with ShakeDrop on the residual branch.

  Residual branch: BN -> 1x1 conv (in -> neck) -> BN -> ReLU ->
  DSConv(neck -> neck, stride) -> BN -> ReLU -> 1x1 conv (neck -> out) -> BN,
  where ``neck = out_channels // neck_ratio``. The shortcut is optionally
  downsampled spatially. It is then zero-padded along the channel axis from
  ``in_channels`` to ``out_channels`` and added to the residual.

  Args:
    in_channels: channels of the input tensor.
    out_channels: channels of the output tensor. Expected to be
      >= in_channels because the shortcut is zero-padded up to it.
    stride: stride of the 3x3 depthwise convolution.
    p: ShakeDrop keep probability for this unit.
    downsample: optional module applied to the shortcut so its spatial size
      matches the strided residual branch. ``None`` means identity.
  """
  neck_ratio = 4

  def __init__(self, in_channels, out_channels, stride=1, p=1.0, downsample=None):
    super(Bottleneck, self).__init__()
    neck_channels = out_channels // Bottleneck.neck_ratio
    self.features = nn.Sequential(
        nn.BatchNorm2d(in_channels),
        nn.Conv2d(in_channels, neck_channels, 1, bias=False), # downsample dim d
        nn.BatchNorm2d(neck_channels),
        nn.ReLU(inplace=True),
        DSConv(neck_channels, neck_channels, stride), # convolve with small dim
        nn.BatchNorm2d(neck_channels),
        nn.ReLU(inplace=True),
        nn.Conv2d(neck_channels, out_channels, 1, bias=False), # upsample dim d
        nn.BatchNorm2d(out_channels)
    )

    # F.pad pairs run from the last dim backwards: (W, W, H, H, C, C). For an
    # (N, C, H, W) input this appends zero channels at the end of dim 1.
    self.pad = (0, 0, 0, 0, 0, out_channels - in_channels)
    self.stride = stride
    self.shakedrop = ShakeDrop.apply
    self.p = p # probability for shakedrop
    self.downsample = downsample # dim w x h

  def forward(self, x):
    r = self.features(x)
    if self.training:
      # One Bernoulli gate for the whole batch; alpha ~ U[-1, 1] per sample.
      b = torch.bernoulli(torch.full((1,), self.p, dtype=r.dtype, device=r.device))
      alpha = torch.empty(x.size(0), dtype=x.dtype, device=x.device).view(-1, 1, 1, 1).uniform_(-1.0, 1.0)
      r = self.shakedrop(r, b, alpha)
    else:
      # E[b + alpha - b * alpha] = p because E[alpha] = 0, so scale by p.
      r = self.p * r

    # Explicit None check: truth-testing an nn.Module can call __len__
    # (an empty nn.Sequential is falsy), which is not what "no downsample" means.
    if self.downsample is not None:
      x = self.downsample(x)
    x = nn.functional.pad(x, self.pad)
    return x + r

class PyramidNet(nn.Module):
  """Bottleneck PyramidNet with ShakeDrop, laid out for 32x32 inputs.

  Structure: 3x3 conv -> BN -> three groups of Bottleneck units (strides
  1, 2, 2) -> BN -> ReLU -> global average pool -> linear classifier.
  Each unit grows the bottleneck width by ``alpha / block_count``. The
  running width is rounded, and the unit outputs ``Bottleneck.neck_ratio``
  times that many channels.

  Args:
    num_classes: number of output logits.
    in_size: accepted for interface compatibility; not used.
    in_channels: channels of the input images.
    init_channels: width of the stem convolution and the starting
      bottleneck width.
    alpha: total bottleneck-width increase summed over all units.
    model: nominal depth. ``(model - 2) // 3`` units are split evenly over
      the three groups, and the last group takes any remainder.
    final_drop_p: one minus the ShakeDrop keep probability of the last
      unit. Keep probabilities decay linearly across units and reach
      ``1 - final_drop_p`` on the last one.
  """

  def __init__(self, num_classes, in_size=(32, 32), in_channels=3,
               init_channels=16, alpha=200, model=272, final_drop_p=0.5):
    super(PyramidNet, self).__init__()
    self.in_channels = init_channels
    self.neck_channels = init_channels
    block_count = (model - 2) // 3 # minus init conv + linear, 3 layers per block
    group_depth = block_count // 3 # 3 groups in total, 32x32 -> 16x16 -> 8x8
    self.add_rate = alpha / block_count # additive increase for bottleneck
    drop_ps = [1 - (i + 1) / block_count * final_drop_p for i in range(block_count)]
    # Argument order matters: the three _pyr_group calls update
    # self.in_channels before the final BatchNorm2d reads it.
    self.features = nn.Sequential(
        nn.Conv2d(in_channels, init_channels, 3, padding=1, bias=False),
        nn.BatchNorm2d(init_channels),
        self._pyr_group(1, drop_ps[:group_depth]),
        self._pyr_group(2, drop_ps[group_depth: group_depth*2]),
        self._pyr_group(2, drop_ps[group_depth*2:]),
        nn.BatchNorm2d(self.in_channels),
        nn.ReLU(inplace=True),
        nn.AdaptiveAvgPool2d(1)
    )
    self.classifier = nn.Linear(self.in_channels, num_classes)

    for m in self.modules():
      if isinstance(m, nn.Conv2d):
        nn.init.kaiming_uniform_(m.weight, nonlinearity='relu')
      elif isinstance(m, nn.BatchNorm2d):
        nn.init.ones_(m.weight)
        nn.init.zeros_(m.bias)
      elif isinstance(m, nn.Linear):
        nn.init.zeros_(m.bias)

  def _pyr_group(self, stride, ps):
    """Build one group of Bottleneck units, one per keep probability in ``ps``.

    Only the first unit is strided and gets the average-pool shortcut
    downsample. Updates ``self.neck_channels`` and ``self.in_channels`` as a
    side effect so that the next group continues the pyramid.
    """
    downsample = None if stride == 1 else nn.AvgPool2d(2, stride=stride, ceil_mode=True)
    layers = []
    self.neck_channels += self.add_rate
    layers.append(Bottleneck(self.in_channels, int(round(self.neck_channels)) * Bottleneck.neck_ratio,
                             stride=stride, p=ps[0], downsample=downsample))
    self.in_channels = int(round(self.neck_channels)) * Bottleneck.neck_ratio
    for p in ps[1:]:
      self.neck_channels += self.add_rate
      layers.append(Bottleneck(self.in_channels, int(round(self.neck_channels)) * Bottleneck.neck_ratio, p=p))
      self.in_channels = int(round(self.neck_channels)) * Bottleneck.neck_ratio
    return nn.Sequential(*layers)

  def forward(self, x):
    x = self.features(x)  # (N, C, 1, 1) after AdaptiveAvgPool2d(1)
    # flatten keeps the batch dim even when N == 1. The former
    # torch.squeeze(x) dropped it, giving 1-D logits for a single image.
    x = torch.flatten(x, 1)
    return self.classifier(x)

class Cutout:
  """Zero out ``n`` square patches of side ``size`` in a (C, H, W) tensor.

  Each patch centre is drawn uniformly over the image with NumPy's global
  RNG. Patches are clipped at the borders, so edge patches can be smaller.
  One spatial mask is shared by all channels.
  """

  def __init__(self, n, size):
    self.n = n        # number of patches per image
    self.size = size  # patch side length in pixels

  def __call__(self, img):
    height, width = img.size(1), img.size(2)
    half = self.size // 2
    mask = np.ones((height, width), np.float32)
    for _ in range(self.n):
      # Column first, then row: same RNG draw order as before.
      cx = np.random.randint(width)
      cy = np.random.randint(height)
      left, right = np.clip([cx - half, cx + half], 0, width)
      top, bottom = np.clip([cy - half, cy + half], 0, height)
      mask[top:bottom, left:right] = 0
    return img * torch.from_numpy(mask).expand_as(img)

In [0]:
class_count = 10
dataset = 'cifar' + str(class_count)
# Per-channel (mean, std) passed to transforms.Normalize below; presumably the
# training-set statistics of each dataset.
norm_stats = {
    'cifar10': ([0.4914, 0.4822, 0.4465], [0.2470, 0.2435, 0.2616]),
    'cifar100': ([0.5071, 0.4867, 0.4408], [0.2675, 0.2565, 0.2761]),
}
if dataset not in norm_stats:
  # ValueError, not AttributeError: the problem is a bad configuration value.
  raise ValueError('Invalid Dataset')
mean, std = norm_stats[dataset]

# Training augmentation: random crop with 4px padding, horizontal flip,
# normalization, then one 8x8 Cutout patch on the normalized tensor.
train_transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
    Cutout(1, 8)
])
test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean, std)
])

path = '/content/gdrive/My Drive/cifar_cnn/'
# Only the first CSV row is read; each of its fields is one class name.
# `with` closes the file (the previous bare open() leaked the handle).
with open(path + dataset + '_classes.csv', newline='') as classes_file:
  classes = tuple(next(csv.reader(classes_file)))

# Pick the torchvision dataset class once instead of branching twice.
dataset_cls = {
    'cifar10': torchvision.datasets.CIFAR10,
    'cifar100': torchvision.datasets.CIFAR100,
}[dataset]
train_dataset = dataset_cls(root=path, train=True, transform=train_transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=4)

test_dataset = dataset_cls(root=path, train=False, transform=test_transform)
# Shuffled copy for drawing random validation batches; the unshuffled
# copy gives a fixed order for the final full evaluation.
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=True, num_workers=4)
final_test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=False, num_workers=4)

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
torch.backends.cudnn.benchmark = True
model = PyramidNet(class_count).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.05, momentum=0.9, weight_decay=5e-5)

def save_checkpoint(state, filename=path + dataset + '_checkpoint.pt'):
  """Write ``state`` to ``filename`` using ``torch.save``.

  The default filename is fixed when the function is defined, from the
  ``path`` and ``dataset`` globals.
  """
  torch.save(obj=state, f=filename)

def load_checkpoint(model, optimizer, filename=path + dataset + '_checkpoint.pt'):
  """Restore model and optimizer state from a checkpoint written by save_checkpoint.

  Args:
    model: module whose ``state_dict`` is loaded in place.
    optimizer: optimizer whose ``state_dict`` is loaded in place.
    filename: checkpoint path. The default is fixed at definition time from
      the ``path`` and ``dataset`` globals.

  Returns:
    ``(epoch, train_loss, test_loss, test_acc)`` as stored in the checkpoint.
  """
  print('Loading Checkpoint...')
  # map_location='cpu' lets a checkpoint saved on GPU load on a CPU-only
  # runtime. load_state_dict then copies the tensors onto the devices of the
  # existing parameters. torch.load unpickles its input: load trusted files only.
  checkpoint = torch.load(filename, map_location='cpu')
  start_epoch = checkpoint['epoch']
  model.load_state_dict(checkpoint['state_dict'])
  optimizer.load_state_dict(checkpoint['optimizer'])
  print('Loaded Checkpoint (Epoch {})'.format(start_epoch))
  return start_epoch, checkpoint['train_loss'], checkpoint['test_loss'], checkpoint['test_acc']

In [0]:
load = True
num_epochs = 10
total_step = len(train_loader)
mark = 60 # log every `mark` training steps
start_epoch, train_loss, test_loss, test_acc = load_checkpoint(model, optimizer) if load else (0, [], [], [])


def _cycle(loader):
  """Yield batches from ``loader`` forever, re-iterating it after each pass.

  Stops (StopIteration on next()) if a full pass yields nothing, instead of
  spinning forever on an empty loader.
  """
  while True:
    produced = False
    for batch in loader:
      produced = True
      yield batch
    if not produced:
      return


# One long-lived validation iterator. Calling next(iter(test_loader)) every
# other step built a brand-new multi-worker DataLoader iterator each time.
val_batches = _cycle(test_loader)

print('Start Training...')
start = time.time()
benchmark = start
for e in range(start_epoch, start_epoch + num_epochs):
  epoch = e + 1
  print('Epoch {}'.format(epoch))
  for i, (data, labels) in enumerate(train_loader):
    model.train() # training mode: ShakeDrop perturbs, BatchNorm uses batch stats
    data, labels = data.to(device), labels.to(device)
    outputs = model(data)
    loss = criterion(outputs, labels)
    train_loss.append(loss.item())

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if (i + 1) % mark == 0:
      delta, benchmark = time.time() - benchmark, time.time()
      print('Epoch [{}/{}], Step [{}/{}], Training Loss: {:.4f}, Time: {:.2f}s'
            .format(epoch, start_epoch + num_epochs, i + 1, total_step, train_loss[-1], delta))

    if i % 2: # evaluate one random validation batch every other step
      with torch.no_grad():
        model.eval() # eval mode
        val_data, val_labels = next(val_batches)
        val_data, val_labels = val_data.to(device), val_labels.to(device)
        val_outputs = model(val_data)
        val_loss = criterion(val_outputs, val_labels).item()
        predicted = val_outputs.argmax(dim=1)
        acc = (predicted == val_labels).sum().item() / val_labels.size(0)
        test_loss.append(val_loss)
        test_acc.append(acc)
      if (i + 1) % mark == 0:
        print('Epoch [{}/{}], Step [{}/{}], Validation Loss: {:.4f}, Accuracy: {:.4f}'
              .format(epoch, start_epoch + num_epochs, i + 1, total_step, val_loss, acc))

  save_checkpoint({
      'epoch': epoch,
      'state_dict': model.state_dict(),
      'optimizer' : optimizer.state_dict(),
      'train_loss': train_loss,
      'test_loss': test_loss,
      'test_acc': test_acc
  })
  benchmark = time.time()

# Note: avg 78s per 60 batches size 64 on Tesla T4
print('Training Completed on {} Epochs, Time Elapsed: {}'
      .format(num_epochs, datetime.timedelta(seconds=round(benchmark - start))))

In [0]:
train_index = 0 # first training step plotted (initial loss may corrupt average)
test_index = 30 # first validation record plotted


def _centered_rolling_mean(values, window):
  """Centered rolling mean allowing partial windows, as an (n, 1) ndarray."""
  frame = pd.DataFrame(data=values)
  return frame.rolling(window=window, min_periods=1, center=True).mean().values


training_moving_avg = _centered_rolling_mean(train_loss[train_index:], total_step)
test_moving_avg = _centered_rolling_mean(test_loss[test_index:], total_step//2)
acc_moving_avg = _centered_rolling_mean(test_acc[test_index:], total_step//2)

# Divide both loss curves by their shared maximum so they fall in [0, 1],
# the same range as the accuracy curve.
ma_max = np.amax([np.amax(training_moving_avg), np.amax(test_moving_avg)])
training_moving_avg = training_moving_avg / ma_max
test_moving_avg = test_moving_avg / ma_max

# Validation records are appended every other training step.
val_steps = np.arange(test_index*2, len(train_loss), 2)
plt.plot(np.arange(train_index, len(train_loss)), training_moving_avg, label='Training Loss')
plt.plot(val_steps, test_moving_avg, label='Validation Loss')
plt.plot(val_steps, acc_moving_avg, label='Validation Accuracy')
plt.legend(loc='upper right')
plt.show()

In [0]:
print('Start Testing...')
start = time.time()
with torch.no_grad():
  model.eval() # eval mode: ShakeDrop scales by p, BatchNorm uses running stats
  correct, total = 0, 0
  predict, actual = [], []
  for i, (data, labels) in enumerate(final_test_loader):
    if (i+1) % 16 == 0:
      print('Time elapsed: {:.2f}s'.format(time.time() - start))
    data, labels = data.to(device), labels.to(device)
    outputs = model(data)
    predicted = outputs.argmax(dim=1)

    total += labels.size(0)
    # Still one 0-d tensor per sample (as before), but already on the CPU,
    # so later per-element conversion does no device transfers.
    predict.extend(predicted.cpu())
    actual.extend(labels.cpu())
    correct += (predicted == labels).sum().item()
print('Time elapsed: {:.2f}s'.format(time.time() - start))
# Format specs were swapped: the image count is an int, the accuracy a percentage.
print('Accuracy on the {} test images: {:.2f}%'.format(total, 100 * correct / total))

In [0]:
pl = [p.cpu().numpy().tolist() for p in predict]
gt = [a.cpu().numpy().tolist() for a in actual]

# Confusion matrix: helpful when visualizing a small number of classes.
# New names for the Series keep `predict`/`actual` as lists of tensors, so
# re-running this cell does not fail on `p.cpu()`.
predict_series = pd.Series(pl, name='Predicted')
actual_series = pd.Series(gt, name='Actual')
confusion = pd.crosstab(actual_series, predict_series)
print(confusion)

# Per-class accuracy: fraction of samples of each class predicted correctly.
# A class with no samples yields nan (0 / 0).
pl_arr, gt_arr = np.array(pl), np.array(gt)
acc_dict = {}
for num, class_name in enumerate(classes):
  in_class = gt_arr == num
  acc_dict[class_name] = np.sum(pl_arr[in_class] == num) / np.sum(in_class)
# To list classes by accuracy instead, iterate over
# sorted(acc_dict.items(), key=lambda kv: kv[1]).
for class_name, acc in acc_dict.items():
  print('{}: {}'.format(class_name, acc))