In [1]:
import argparse
import os, sys
import time
import datetime

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim
from tqdm import tqdm_notebook as tqdm
import matplotlib.pyplot as plt

# Data Preprocessing

In [2]:
# Per-channel normalisation statistics (reference values; not computed in this notebook).
# NOTE(review): these are the values commonly quoted for CIFAR-10 — confirm they fit the data actually loaded.
NORM_MEAN = (0.4914, 0.4822, 0.4465)
NORM_STD = (0.2023, 0.1994, 0.2010)


def _to_normalized_tensor():
    """Return fresh [ToTensor, Normalize] steps shared by the train and val pipelines.

    ToTensor maps the image to a float tensor in [0, 1]; Normalize then standardises
    each channel with NORM_MEAN / NORM_STD.
    """
    return [transforms.ToTensor(), transforms.Normalize(NORM_MEAN, NORM_STD)]


# Training: light augmentation (random 32x32 crop from a 4-pixel padded image, random
# horizontal flip with probability 0.5) followed by tensor conversion + normalisation.
_train_augmentation = [
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(p=0.5),
]
transform_train = transforms.Compose(_train_augmentation + _to_normalized_tensor())

# Validation: no augmentation, identical tensor conversion + normalisation.
transform_val = transforms.Compose(_to_normalized_tensor())

# ResNet20 Architecture

In [3]:
class Block(nn.Module):
    """Basic two-convolution residual block.

    Main path: 3x3 conv (stride ``stride``) -> BN -> ReLU -> 3x3 conv (stride 1) -> BN.
    The main-path result is added to the shortcut and passed through a final ReLU.

    Args:
        in_channel: channels of the incoming feature map.
        out_channel: channels produced by both convolutions.
        stride: stride of the first convolution (2 halves the spatial size, e.g. 32 -> 16).
        downsize: optional module applied to the input on the shortcut path so it matches
            the main path's shape. ``None`` means an identity shortcut, which requires the
            input and output shapes to already match.
    """

    def __init__(self, in_channel, out_channel, stride=1, downsize=None):
        super().__init__()
        # Stored for inspection; the stride is actually applied by conv1 below.
        self.stride = stride

        # First 3x3 conv carries the stride.
        self.conv1 = nn.Conv2d(in_channel, out_channel, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channel)
        # ReLU is stateless, so one instance serves both activations in forward().
        self.relu = nn.ReLU()

        # Second 3x3 conv keeps the spatial size (stride 1, padding 1).
        self.conv2 = nn.Conv2d(out_channel, out_channel, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channel)

        # Assigned after the main-path layers so the state_dict key order stays stable.
        self.downsize = downsize

    def forward(self, x):
        main = self.relu(self.bn1(self.conv1(x)))
        main = self.bn2(self.conv2(main))

        # Identity shortcut unless a reshaping module was supplied.
        shortcut = x if self.downsize is None else self.downsize(x)

        main += shortcut
        return self.relu(main)

class ResNet(nn.Module):
    """CIFAR-style ResNet: a 3x3 stem followed by three stages of residual blocks.

    Each stage holds 3 blocks. With two convs per block that is 18 block convs, plus the
    stem conv and the final linear layer: the ResNet-20 layout. Stage widths are 16/32/64,
    and the second and third stages use stride 2 to halve the spatial resolution.

    Args:
        block: residual block class, called as
            ``block(in_channel, out_channel, stride, downsize)`` for the first block of a
            stage and ``block(in_channel, out_channel)`` for the rest.
        num_classes: output size of the final linear layer (default 10).
    """

    def __init__(self, block, num_classes=10):
        super(ResNet, self).__init__()

        # Channel count feeding the next stage; build_layer advances it.
        self.in_channel = 16

        # Stem: 3 -> 16 channels, spatial size preserved (3x3 conv, stride 1, pad 1).
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(16)
        self.relu = nn.ReLU()

        # e.g. for a 32x32 input: 32x32x16 -> 16x16x32 -> 8x8x64
        self.layer1 = self.build_layer(block, 16, 3, stride=1)
        self.layer2 = self.build_layer(block, 32, 3, stride=2)
        self.layer3 = self.build_layer(block, 64, 3, stride=2)

        self.linear = nn.Linear(64, num_classes)

        self._initialize_weights()

    def _initialize_weights(self):
        """Kaiming-normal conv/linear weights, zero biases, BN set to identity (weight 1, bias 0)."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def build_layer(self, block, out_channel, num_block=3, stride=1):
        """Build one stage of ``num_block`` blocks that outputs ``out_channel`` channels.

        Only the first block can change resolution or width. It receives ``stride`` and,
        when the shape changes, a 1x1 conv + BN projection for its shortcut.
        Updates ``self.in_channel`` to ``out_channel``.
        """
        # Keep the projection in a local. Assigning an nn.Module to self.<attr> would also
        # register it as a submodule of the network and duplicate it in state_dict().
        downsize = None
        if stride != 1 or self.in_channel != out_channel:
            downsize = nn.Sequential(
                # 1x1 conv matches channels and (via stride) spatial size
                nn.Conv2d(self.in_channel, out_channel, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channel),
            )

        layers = [block(self.in_channel, out_channel, stride, downsize)]
        self.in_channel = out_channel
        # Remaining blocks keep shape, so they use stride 1 and an identity shortcut.
        layers.extend(block(self.in_channel, out_channel) for _ in range(1, num_block))

        return nn.Sequential(*layers)

    def load_state_dict(self, state_dict, strict=True, **kwargs):
        """Load weights while tolerating checkpoints that carry legacy ``downsize.*`` keys.

        Checkpoints written when ``build_layer`` stored its projection in
        ``self.downsize`` also contain a top-level ``downsize`` module. That module was an
        alias of ``layer3[0].downsize``, so those keys duplicate tensors already saved under
        ``layer3.0.downsize.*`` and can be dropped safely.
        """
        legacy_prefix = 'downsize.'
        if any(k.startswith(legacy_prefix) for k in state_dict):
            from collections import OrderedDict
            filtered = OrderedDict(
                (k, v) for k, v in state_dict.items() if not k.startswith(legacy_prefix)
            )
            # Preserve per-module version metadata (used e.g. by BatchNorm when loading).
            metadata = getattr(state_dict, '_metadata', None)
            if metadata is not None:
                filtered._metadata = metadata
            state_dict = filtered
        return super().load_state_dict(state_dict, strict=strict, **kwargs)

    def forward(self, x):
        out = self.relu(self.bn1(self.conv1(x)))   # e.g. 32x32x3 -> 32x32x16

        out = self.layer1(out)   # 32x32x16
        out = self.layer2(out)   # 16x16x32
        out = self.layer3(out)   # 8x8x64

        # Global average pool to 1x1. This also works for non-square feature maps,
        # unlike a kernel sized from a single spatial dimension.
        out = F.adaptive_avg_pool2d(out, 1)
        out = torch.flatten(out, 1)

        return self.linear(out)

# Parameter Settings

In [4]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Seed torch's RNG before building the net so weight initialisation is reproducible.
SEED = 0
torch.manual_seed(SEED)

net = ResNet(Block)
net = net.to(device)

INITIAL_LR = 0.1
MOMENTUM = 0.9
REG = 1e-4          # L2 weight decay passed to SGD
EPOCHS = 100
# The training loop multiplies the LR by DECAY after every DECAY_EPOCHS-th epoch.
# With DECAY = 0.1, a period of only a few epochs shrinks the LR by 10x that often and
# effectively stops learning long before EPOCHS. 40 gives 0.1 -> 0.01 -> 0.001 over 100 epochs.
DECAY_EPOCHS = 40
DECAY = 0.1

criterion = nn.CrossEntropyLoss()
# Start from INITIAL_LR. `current_learning_rate` is only defined by the later
# checkpoint cell, so referencing it here fails on a fresh kernel.
optimizer = optim.SGD(params=net.parameters(), lr=INITIAL_LR, momentum=MOMENTUM, weight_decay=REG, nesterov=False)

In [7]:
CHECKPOINT_PATH = "./saved_model"

TRAIN_FROM_SCRATCH = True

# Same file the training loop writes to.
# NOTE: despite the .h5 suffix, this is a torch.save pickle, not an HDF5 file.
CKPT_PATH = os.path.join(CHECKPOINT_PATH, "model.h5")


def get_checkpoint(ckpt_path, map_location=None):
    """Load a checkpoint written with torch.save, or return None if loading fails.

    Args:
        ckpt_path: path of the checkpoint file.
        map_location: forwarded to torch.load. Pass the current device so a checkpoint
            written on a GPU can be loaded on a CPU-only machine.

    Returns:
        The loaded object, or None (after printing the error) if torch.load raised.
    """
    # NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
    try:
        return torch.load(ckpt_path, map_location=map_location)
    except Exception as e:
        # Best effort: a missing or unreadable file just means "no checkpoint".
        print(e)
        return None


# Only read from disk when we actually intend to resume.
ckpt = None if TRAIN_FROM_SCRATCH else get_checkpoint(CKPT_PATH, map_location=device)
if ckpt is None:
    if not TRAIN_FROM_SCRATCH:
        print("Checkpoint not found.")
    print("Training from scratch ...")
    start_epoch = 0
    current_learning_rate = INITIAL_LR
else:
    print("Successfully loaded checkpoint: %s" %CKPT_PATH)
    net.load_state_dict(ckpt['net'])
    if 'optimizer' in ckpt:
        # Restores SGD momentum buffers when the checkpoint carries them.
        optimizer.load_state_dict(ckpt['optimizer'])
    start_epoch = ckpt['epoch'] + 1
    current_learning_rate = ckpt['lr']
    print("Starting from epoch %d " %start_epoch)

# The optimizer was built before the starting LR was known; apply it now.
for param_group in optimizer.param_groups:
    param_group['lr'] = current_learning_rate

print("Starting from learning rate %f:" %current_learning_rate)

Training from scratch ...
Starting from learning rate 0.100000:


# Start Training

In [None]:
# NOTE(review): `trainloader` and `valloader` are not defined anywhere in this notebook
# (execution counts jump from In[4] to In[7], so the cells that built them appear to be
# gone). Recreate them, e.g. from transform_train / transform_val, before running this on
# a fresh kernel.

global_step = 0
# When resuming (start_epoch > 0), continue from the stored best accuracy so the first
# resumed epoch does not overwrite a better checkpoint. Older checkpoints lack the key.
best_val_acc = ckpt.get('val_acc', 0) if (ckpt is not None and start_epoch > 0) else 0

train_loss_list = []
val_loss_list = []
train_acc_list = []
val_acc_list = []


def _run_epoch(loader, train):
    """Make one pass over ``loader`` with the global ``net``.

    When ``train`` is True, the net is put in train mode and one optimizer step is taken
    per batch. Otherwise it runs in eval mode with autograd disabled.

    Returns:
        (mean batch loss, accuracy, number of batches). Loss and accuracy are Python floats.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    net.train(train)
    total_examples = 0
    correct_examples = 0
    loss_sum = 0.0
    num_batches = 0

    with torch.set_grad_enabled(train):
        for inputs, targets in loader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            if train:
                optimizer.zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs, targets)
            if train:
                loss.backward()
                optimizer.step()

            _, predicted = outputs.max(1)
            total_examples += predicted.size(0)
            correct_examples += predicted.eq(targets).sum().item()
            # .item() gives a plain float. Summing the loss tensor itself would chain
            # every batch's autograd graph for the whole epoch and leak memory.
            loss_sum += loss.item()
            num_batches += 1

    if num_batches == 0:
        raise ValueError("loader produced no batches")
    return loss_sum / num_batches, correct_examples / total_examples, num_batches


for i in range(start_epoch, EPOCHS):
    print(datetime.datetime.now())
    print("Epoch %d:" %i)

    # Train on the training set for one epoch.
    train_loss, train_acc, num_batches = _run_epoch(trainloader, train=True)
    global_step += num_batches
    train_loss_list.append(train_loss)
    train_acc_list.append(train_acc)
    print("Training loss: %.4f, Training accuracy: %.4f" %(train_loss, train_acc))
    print(datetime.datetime.now())

    # Validate once per epoch (no gradients, eval-mode BatchNorm).
    print("Validation...")
    val_loss, val_acc, _ = _run_epoch(valloader, train=False)
    val_loss_list.append(val_loss)
    val_acc_list.append(val_acc)
    print("Validation loss: %.4f, Validation accuracy: %.4f" % (val_loss, val_acc))

    # Step decay: multiply the LR by DECAY after every DECAY_EPOCHS-th epoch.
    if i % DECAY_EPOCHS == 0 and i != 0:
        current_learning_rate = current_learning_rate * DECAY
        for param_group in optimizer.param_groups:
            param_group['lr'] = current_learning_rate
        print("Current learning rate has decayed to %f" %current_learning_rate)

    # Keep only the best model (by validation accuracy) on disk.
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        os.makedirs(CHECKPOINT_PATH, exist_ok=True)
        print("Saving ...")
        state = {'net': net.state_dict(),
                 'optimizer': optimizer.state_dict(),
                 'epoch': i,
                 'lr': current_learning_rate,
                 'val_acc': val_acc}
        torch.save(state, os.path.join(CHECKPOINT_PATH, 'model.h5'))

print("Optimization finished.")