<a href="https://colab.research.google.com/github/guzey/hse_dl/blob/master/fixup_works_final.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
# ---------------------------------------------------------------------------- #
# An implementation of https://arxiv.org/pdf/1512.03385.pdf                    #
# See section 4.2 for the model architecture on CIFAR-10                       #
# Parts of the code are adapted from:                                          #
# https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py   #
# ---------------------------------------------------------------------------- #

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import random
from IPython.display import clear_output

In [0]:
# Device configuration: use the GPU when PyTorch can see one, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Image preprocessing modules (augmentation applied to the training split only)
transform = transforms.Compose([
    transforms.Pad(4),                  # pad each border by 4 pixels
    transforms.RandomHorizontalFlip(),  # random left-right flip
    transforms.RandomCrop(32),          # random 32x32 crop of the padded image
    transforms.ToTensor(),
])

# CIFAR-10 dataset. Only the training split asks for a download; the test split
# is read from the same root directory and is only converted to tensors.
train_dataset = torchvision.datasets.CIFAR10('../../data/', train=True,
                                             transform=transform, download=True)
test_dataset = torchvision.datasets.CIFAR10('../../data/', train=False,
                                            transform=transforms.ToTensor())

# Data loaders: batches of 100; only the training data is shuffled.
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=100,
                                           shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=100,
                                          shuffle=False)

In [0]:
# For updating learning rate
def update_lr(optimizer, lr):
    """Set the learning rate of every parameter group of ``optimizer`` to ``lr``.

    Args:
        optimizer: object exposing ``param_groups``, an iterable of dicts.
        lr: the value stored under the ``'lr'`` key of each group.
    """
    for group in optimizer.param_groups:
        group.update(lr=lr)
        
def plot_history(train_history, val_history, title='loss'):
    """Draw a training curve with validation markers on top and show it.

    Args:
        train_history: sequence of training values, plotted against its index.
        val_history: sequence of ``(x, y)`` pairs; it is converted with
            ``np.array`` and columns 0 and 1 are used as marker coordinates.
        title: text used as the figure title.
    """
    plt.figure()
    plt.xlabel('train steps')
    plt.title('{}'.format(title))

    # Training values form the line (lower zorder, drawn underneath).
    plt.plot(train_history, label='train', zorder=1)

    # Validation points are scattered above the line as large orange crosses.
    val_points = np.array(val_history)
    val_x, val_y = val_points[:, 0], val_points[:, 1]
    plt.scatter(val_x, val_y, marker='+', s=180, c='orange', label='val', zorder=2)

    plt.legend(loc='best')
    plt.grid()
    plt.show()

# Train epoch
def train_epoch(model):
    """Train ``model`` for one pass over ``train_loader``.

    Relies on the notebook-level globals ``train_loader``, ``device``,
    ``criterion`` and ``optimizer``.

    Args:
        model: network to optimise; it is switched to training mode.

    Returns:
        A ``(loss_log, acc_log)`` pair of lists with one entry per batch:
        ``loss_log`` holds the batch loss as a Python float and ``acc_log``
        holds the running accuracy (correct so far / samples seen so far)
        within this epoch.
    """
    loss_log, acc_log = [], []
    model.train()
    correct = 0
    total = 0
    for images, labels in train_loader:
        images = images.to(device)
        labels = labels.to(device)

        # Forward pass
        output = model(images)
        loss = criterion(output, labels)

        # Running-accuracy bookkeeping; detach() keeps it out of the autograd
        # graph without going through the legacy ``.data`` attribute.
        _, predicted = torch.max(output.detach(), 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        acc_log.append(correct / total)

        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        loss_log.append(loss.item())

    # NOTE: no learning-rate decay is applied here; the rate stays at whatever
    # the optimizer was created with (update_lr can change it between epochs).
    return loss_log, acc_log
          
# train the model

def train(model):
    """Train ``model`` for ``num_epochs`` epochs, plotting progress after each.

    Relies on the notebook-level globals ``num_epochs`` and ``train_loader``
    and on the helpers ``train_epoch``, ``test``, ``plot_history`` and
    ``clear_output``.

    After every epoch the cell output is cleared and two figures are drawn:
    loss and accuracy, each showing the per-step training curve with one
    validation marker per finished epoch. The validation error of the epoch is
    printed, and the last one is printed again as the final error.

    Args:
        model: network to train; passed to ``train_epoch`` and ``test``.
    """
    train_log, train_acc_log = [], []
    val_log, val_acc_log = [], []

    # Number of optimisation steps per epoch. Taken from the loader (rather
    # than a hard-coded batch size) so validation markers line up with the
    # per-step training curve for any batch size.
    steps_per_epoch = len(train_loader)

    for epoch in range(num_epochs):
        print("Epoch {} of {}".format(epoch, num_epochs))
        train_loss, train_acc = train_epoch(model)
        val_loss, val_acc = test(model)

        train_log.extend(train_loss)
        train_acc_log.extend(train_acc)

        step = steps_per_epoch * (epoch + 1)
        val_log.append((step, np.mean(val_loss)))
        # test() logs the *running* accuracy after each batch, so the last
        # entry is the accuracy over the whole test set; averaging the running
        # values would over-weight the first batches.
        final_val_acc = val_acc[-1] if val_acc else float('nan')
        val_acc_log.append((step, final_val_acc))

        clear_output()
        plot_history(train_log, val_log)
        plot_history(train_acc_log, val_acc_log, title='accuracy')
        print("Epoch {} error = {:.2%}".format(epoch, 1 - val_acc_log[-1][1]))

    # Nothing to report when no epoch was run (num_epochs == 0).
    if val_acc_log:
        print("Final error: {:.2%}".format(1 - val_acc_log[-1][1]))

# Test the model
def test(model):
    """Evaluate ``model`` on ``test_loader`` with gradients disabled.

    Relies on the notebook-level globals ``test_loader``, ``device`` and
    ``criterion``.

    Args:
        model: network to evaluate; it is switched to evaluation mode.

    Returns:
        A ``(loss_log, acc_log)`` pair of lists with one entry per batch:
        the batch loss as a Python float, and the running accuracy (correct so
        far / samples seen so far) after that batch.
    """
    batch_losses, running_acc = [], []
    model.eval()
    with torch.no_grad():
        n_correct, n_seen = 0, 0
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            logits = model(images)
            batch_loss = criterion(logits, labels)

            # Index of the largest logit along dim 1 is the predicted class.
            predicted = torch.max(logits.data, 1)[1]
            n_seen += labels.size(0)
            n_correct += (predicted == labels).sum().item()
            running_acc.append(n_correct / n_seen)

            batch_losses.append(batch_loss.item())
        return batch_losses, running_acc

#       print('Accuracy of the model on the test images: {} %'.format(100 * correct / total))

# Save the model checkpoint
# torch.save(model.state_dict(), 'resnet.ckpt')

In [0]:
# 3x3 convolution
def conv3x3(in_channels, out_channels, stride=1):
    """Return a 3x3 ``nn.Conv2d`` with padding 1 and a learnable bias.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        stride: convolution stride (default 1).
    """
    conv_options = dict(kernel_size=3, stride=stride, padding=1, bias=True)
    return nn.Conv2d(in_channels, out_channels, **conv_options)

# 3. add a scalar multiplier (initialized at 1) in every branch 
# mult layer
# https://discuss.pytorch.org/t/is-scale-layer-available-in-pytorch/7954/8
class ScaleLayer(nn.Module):
    """Multiply the input by a single learnable scalar.

    The scalar is the float32 parameter ``scale`` of shape ``(1,)``, created
    from ``init_value`` (default 1).
    """

    def __init__(self, init_value=1):
        super().__init__()
        initial = torch.tensor([init_value], dtype=torch.float32)
        self.scale = nn.Parameter(initial)

    def forward(self, input):
        """Return ``input`` multiplied by the ``scale`` parameter."""
        scaled = input * self.scale
        return scaled
    
# bias layer
class BiasLayer(nn.Module):
    """Add a single learnable scalar to the input.

    The scalar is the float32 parameter ``bias`` of shape ``(1,)``, created
    from ``init_value`` (default 0).
    """

    def __init__(self, init_value=0):
        super().__init__()
        initial = torch.tensor([init_value], dtype=torch.float32)
        self.bias = nn.Parameter(initial)

    def forward(self, input):
        """Return ``input`` with the ``bias`` parameter added."""
        shifted = input + self.bias
        return shifted

# Residual block
class ResidualBlock(nn.Module):
    """Normalisation-free residual block initialised with the Fixup rules.

    Residual branch:
        bias1 -> conv1 -> bias2 -> ReLU -> bias3 -> conv2 -> mult -> bias4
    The branch output is added to the shortcut (the input itself, or
    ``downsample(x)`` when a downsample module is given) and passed through a
    final ReLU.

    Initialisation (Fixup, with m = 2 weight layers per branch):
      1. the last layer of the residual branch (``conv2``) starts at zero, so
         the block initially computes ``relu(shortcut)``;
      2. the other weight layer in the branch (``conv1``) keeps its default
         initialisation scaled by ``L ** -0.5``, where ``L`` is the number of
         residual branches in the network;
      3. scalar biases start at 0 and the scalar multiplier at 1.
    The shortcut (``downsample``) is not part of the residual branch and is
    left at its own initialisation.

    ``L`` is read from the notebook-level global ``num_layers`` as
    ``num_layers * 3`` (three stages of ``num_layers`` blocks each), so that
    global must be defined before a block is instantiated.

    Args:
        in_channels: channels of the block input.
        out_channels: channels produced by both convolutions.
        stride: stride of ``conv1`` (default 1).
        downsample: optional module applied to the input to build the
            shortcut; ``None`` means the identity shortcut.
    """

    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        super(ResidualBlock, self).__init__()
        self.bias1 = BiasLayer()
        self.conv1 = conv3x3(in_channels, out_channels, stride)
        self.bias2 = BiasLayer()
        self.relu = nn.ReLU(inplace=True)
        self.bias3 = BiasLayer()
        self.conv2 = conv3x3(out_channels, out_channels)
        # Fixup rule 3: a scalar multiplier (initialised at 1) in every branch.
        self.mult = ScaleLayer()
        self.bias4 = BiasLayer()
        self.downsample = downsample

        # Fixup rule 1: zero the last layer of the residual branch. This is
        # always conv2, whether or not the shortcut is a projection.
        torch.nn.init.constant_(self.conv2.weight, 0)
        # The bias is absent when conv3x3 builds bias-free convolutions (a
        # later notebook cell redefines it that way), so guard against None.
        if self.conv2.bias is not None:
            torch.nn.init.constant_(self.conv2.bias, 0)

        # Fixup rule 2: scale the remaining weight layer inside the branch by
        # L^(-1/(2m-2)); with m = 2 this is L^(-1/2), i.e. a *division* by
        # sqrt(L), which shrinks the branch as the network gets deeper.
        num_branches = num_layers * 3
        with torch.no_grad():
            self.conv1.weight.mul_(num_branches ** -0.5)

    def forward(self, x):
        """Return ``relu(shortcut(x) + branch(x))``."""
        residual = x
        # Fixup rule 3: a scalar bias (initialised at 0) before each
        # convolution and element-wise activation.
        out = self.bias1(x)
        out = self.conv1(out)
        out = self.bias2(out)
        out = self.relu(out)
        out = self.bias3(out)
        out = self.conv2(out)
        out = self.mult(out)
        out = self.bias4(out)
        if self.downsample is not None:
            residual = self.downsample(x)
        out += residual
        out = self.relu(out)
        return out

# ResNet
class ResNet(nn.Module):
    """Three-stage residual network without batch normalisation.

    A 3x3 stem convolution to 16 channels is followed by three stages of
    ``block`` modules with 16, 32 and 64 output channels (the last two built
    with stride 2), an 8x8 average pool and a linear classifier.

    Args:
        block: callable building one residual block; called as
            ``block(in_channels, out_channels, stride, downsample)`` for the
            first block of a stage and ``block(channels, channels)`` after.
        layers: indexable with the number of blocks for stages 0, 1 and 2.
        num_classes: size of the classifier output (default 10).
    """

    def __init__(self, block, layers, num_classes=10):
        super().__init__()
        self.in_channels = 16
        self.conv = conv3x3(3, 16)
        self.relu = nn.ReLU(inplace=True)
        self.layer1 = self.make_layer(block, 16, layers[0])
        self.layer2 = self.make_layer(block, 32, layers[1], stride=2)
        self.layer3 = self.make_layer(block, 64, layers[2], stride=2)
        self.avg_pool = nn.AvgPool2d(8)
        self.fc = nn.Linear(64, num_classes)

        # 1. initialize the classification layer to 0
        for tensor in (self.fc.weight, self.fc.bias):
            torch.nn.init.constant_(tensor, 0)

        # 2. initialize the stem convolution with He-normal weights
        torch.nn.init.kaiming_normal_(self.conv.weight, mode='fan_in',
                                      nonlinearity='leaky_relu')

    def make_layer(self, block, out_channels, blocks, stride=1):
        """Build one stage of ``blocks`` residual blocks as an ``nn.Sequential``.

        The first block receives ``stride`` and, when the stride is not 1 or
        the channel count changes, a 3x3 convolution as its shortcut. The
        remaining blocks keep the channel count and use default arguments.
        ``self.in_channels`` is updated to ``out_channels``.
        """
        needs_projection = stride != 1 or self.in_channels != out_channels
        if needs_projection:
            shortcut = conv3x3(self.in_channels, out_channels, stride=stride)
        else:
            shortcut = None

        first_block = block(self.in_channels, out_channels, stride, shortcut)
        self.in_channels = out_channels
        remaining = [block(out_channels, out_channels) for _ in range(1, blocks)]
        return nn.Sequential(first_block, *remaining)

    def forward(self, x):
        """Return the classifier output for the input batch ``x``."""
        out = self.relu(self.conv(x))
        for stage in (self.layer1, self.layer2, self.layer3):
            out = stage(out)
        out = self.avg_pool(out)
        # Flatten everything except the batch dimension for the classifier.
        out = out.view(out.size(0), -1)
        return self.fc(out)

In [0]:
# Zhang
def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding 1 and no bias term."""
    kernel_size, padding = 3, 1
    return nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding,
                     bias=False)


class FixupBasicBlock(nn.Module):
    """Basic residual block with Fixup-style scalar biases and multiplier.

    Residual branch: ``conv1(x + bias1a)`` -> ``relu(. + bias1b)`` ->
    ``conv2(. + bias2a)`` -> ``. * scale + bias2b``. The result is added to
    the shortcut and passed through a ReLU. When ``downsample`` is given, the
    shortcut is ``downsample(x + bias1a)`` concatenated along dim 1 with a
    zero tensor of the same shape (which doubles its channel count).

    Args:
        inplanes: channels of the block input.
        planes: channels produced by both convolutions.
        stride: stride of ``conv1`` (default 1).
        downsample: optional module producing the shortcut; ``None`` means
            the identity shortcut.
    """

    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
        self.bias1a = nn.Parameter(torch.zeros(1))
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bias1b = nn.Parameter(torch.zeros(1))
        self.relu = nn.ReLU(inplace=True)
        self.bias2a = nn.Parameter(torch.zeros(1))
        self.conv2 = conv3x3(planes, planes)
        self.scale = nn.Parameter(torch.ones(1))
        self.bias2b = nn.Parameter(torch.zeros(1))
        self.downsample = downsample

    def forward(self, x):
        """Return ``relu(branch(x) + shortcut(x))``."""
        branch = self.conv1(x + self.bias1a)
        branch = self.relu(branch + self.bias1b)
        branch = self.conv2(branch + self.bias2a)
        branch = branch * self.scale + self.bias2b

        if self.downsample is None:
            shortcut = x
        else:
            pooled = self.downsample(x + self.bias1a)
            # Zero-pad the channel dimension to match the branch output.
            shortcut = torch.cat((pooled, torch.zeros_like(pooled)), 1)

        branch += shortcut
        return self.relu(branch)


class FixupResNet(nn.Module):
    """Three-stage Fixup ResNet without normalisation layers.

    A bias-free 3x3 stem convolution to 16 channels is followed by three
    stages of ``block`` modules with 16, 32 and 64 planes (the last two with
    stride 2), global average pooling and a linear classifier. Scalar biases
    ``bias1`` and ``bias2`` are added before the first ReLU and before the
    classifier.

    Initialisation performed in ``__init__``:
      * every ``FixupBasicBlock``: ``conv1`` weights are drawn from a normal
        distribution with std ``sqrt(2 / (out_channels * kernel_area))``
        scaled by ``num_layers ** -0.5``; ``conv2`` weights are set to zero;
      * every ``nn.Linear``: weight and bias are set to zero.

    Args:
        block: callable building one residual block.
        layers: iterable with the number of blocks per stage; its sum is
            stored as ``num_layers``.
        num_classes: size of the classifier output (default 10).
    """

    def __init__(self, block, layers, num_classes=10):
        super().__init__()
        self.num_layers = sum(layers)
        self.inplanes = 16
        self.conv1 = conv3x3(3, 16)
        self.bias1 = nn.Parameter(torch.zeros(1))
        self.relu = nn.ReLU(inplace=True)
        self.layer1 = self._make_layer(block, 16, layers[0])
        self.layer2 = self._make_layer(block, 32, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 64, layers[2], stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.bias2 = nn.Parameter(torch.zeros(1))
        self.fc = nn.Linear(64, num_classes)

        for module in self.modules():
            if isinstance(module, FixupBasicBlock):
                weight = module.conv1.weight
                # Output channels times the spatial size of the kernel.
                fan = weight.shape[0] * np.prod(weight.shape[2:])
                std = np.sqrt(2 / fan) * self.num_layers ** (-0.5)
                nn.init.normal_(weight, mean=0, std=std)
                nn.init.constant_(module.conv2.weight, 0)
            elif isinstance(module, nn.Linear):
                nn.init.constant_(module.weight, 0)
                nn.init.constant_(module.bias, 0)

    def _make_layer(self, block, planes, blocks, stride=1):
        """Build one stage of ``blocks`` blocks as an ``nn.Sequential``.

        The first block receives ``stride`` and, when the stride is not 1, a
        1x1 average pool with that stride as its ``downsample`` module. The
        remaining blocks are built as ``block(planes, planes)``.
        ``self.inplanes`` is updated to ``planes``.
        """
        shortcut = nn.AvgPool2d(1, stride=stride) if stride != 1 else None

        first_block = block(self.inplanes, planes, stride, shortcut)
        self.inplanes = planes
        remaining = [block(planes, planes) for _ in range(1, blocks)]
        return nn.Sequential(first_block, *remaining)

    def forward(self, x):
        """Return the classifier output for the input batch ``x``."""
        x = self.relu(self.conv1(x) + self.bias1)

        for stage in (self.layer1, self.layer2, self.layer3):
            x = stage(x)

        pooled = self.avgpool(x)
        # Flatten everything except the batch dimension for the classifier.
        flat = pooled.view(pooled.size(0), -1)
        return self.fc(flat + self.bias2)

In [0]:
# Hyper-parameters
num_layers = 4        # residual blocks per stage; the model is built with three stages
num_epochs = 10       # number of epochs run by train()
learning_rate = 1e-3  # learning rate the Adam optimizer is created with

In [0]:
# model = FixupResNet(FixupBasicBlock, [num_layers, num_layers, num_layers]).to(device)
# Three stages of num_layers residual blocks each, moved to the selected device
# before the optimizer collects the parameters.
model = ResNet(ResidualBlock, [num_layers] * 3)
model = model.to(device)

# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

In [231]:
# Run the training loop; after each epoch train() clears the cell output and
# redraws the loss and accuracy curves.
train(model)

Epoch 0 of 10


KeyboardInterrupt: ignored