## Import torch and model

In [1]:
import os

# Expose only GPU 0 to this process. This lives in the first cell so that it
# runs before torch is imported in the next one.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

In [2]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F

import numpy as np

In [3]:
import sys
# Put "../common" on the module search path; this has to run before the two
# imports below (presumably model_generator.py and net.py live there — confirm).
# The path is relative, so it is resolved against the kernel's working directory.
sys.path.append("../common")

from model_generator import ModelGenerator
from net import Net

## Set hyperparameters

In [4]:
# --- Data loading ---------------------------------------------------------
# Mini-batch size used by both the training and the test DataLoader.
batch_size = 64

# --- Network options (forwarded to ModelGenerator) --------------------------
dropout_on = False
batchnorm_on = True

# --- Learning-rate schedule -------------------------------------------------
# StepLR period, in scheduler steps, shared by both training phases.
scheduler_step_size = 20

# --- Phase 1: block-wise recasting ------------------------------------------
lr_recasting, num_epoch_recasting = 1e-3, 60

# --- Phase 2: end-to-end fine-tuning ----------------------------------------
lr_fine_tune, num_epoch_fine_tune = 1e-3, 100

In [5]:
# Model factory; dropout / batch-norm follow the flags from the hyperparameter cell.
model_gen = ModelGenerator(dropout=dropout_on, batchnorm=batchnorm_on)
# Presumably selects the CIFAR-10 variant of the VGG-16 layout — confirm in ModelGenerator.
model_gen.CifarVgg16Config(cifar=10)

# Blocks to recast, visited in order: indices 0..12 (the stop value 13 is excluded).
# Per the original note these are the network's conv blocks.
recasting_block_indices = range(13)
# Block type every recast block is rebuilt as.
target_block_type = 'ConvBlock'

# Fraction of each block's filters that is kept: the filter counts are
# multiplied by this ratio and rounded during recasting.
compression_ratio = 0.4

# Checkpoint paths.
pretrained_model = './cifar10_vgg16_pretrained.pth'
# NOTE(review): not referenced by any cell visible in this part of the notebook.
compressed_model = './cifar10_vgg16_to_convenet.pth'

## Load dataset

In [6]:
# Per-channel normalisation statistics, shared by the train and test pipelines.
cifar_mean = (0.4914, 0.4822, 0.4465)
cifar_std = (0.2023, 0.1994, 0.2010)
normalize = transforms.Normalize(mean=cifar_mean, std=cifar_std)

# Test pipeline: tensor conversion and normalisation only.
transform = transforms.Compose([transforms.ToTensor(), normalize])

# Training pipeline: random horizontal flip and random 32-pixel crop (second
# argument 4 is the padding) ahead of the same conversion and normalisation.
transform_train = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32, 4),
    transforms.ToTensor(),
    normalize,
])

# CIFAR-10 is downloaded into ./data if it is not already there.
trainset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=batch_size, shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=batch_size, shuffle=False, num_workers=2)


Files already downloaded and verified
Files already downloaded and verified


## Load pre-trained model (teacher network)

In [7]:
# Build the VGG-16 teacher and restore its pretrained weights from disk.
model = model_gen.GetCifarVgg16()
teacher = Net(model)
state = torch.load(pretrained_model)
teacher.LoadFromStateDict(state)
teacher.Gpu()

# Sanity check: top-1 accuracy of the restored teacher on the test split.
teacher.TestMode()
correct, total = 0, 0
for images, labels in testloader:
    outputs = teacher(Variable(images.cuda()))
    # torch.max over dim 1 returns (values, indices); the indices are the predicted classes.
    predicted = torch.max(outputs.data, 1)[1]
    total += labels.size(0)
    correct += (predicted == labels.cuda()).sum()

print('Accuracy of the network on the 10000 test images: %4.2f %%' % (100 * correct / total))

Accuracy of the network on the 10000 test images: 93.01 %


## Define student network

In [8]:
# The student is built exactly like the teacher: same generated architecture,
# same pretrained checkpoint. Recasting later replaces its blocks one by one.
state = torch.load(pretrained_model)
model = model_gen.GetCifarVgg16()
student = Net(model)
student.LoadFromStateDict(state)
student.Gpu()

In [9]:
# Baseline before any recasting: top-1 test accuracy of the untouched student.
student.TestMode()
correct, total = 0, 0
for images, labels in testloader:
    outputs = student(Variable(images.cuda()))
    # Second element of torch.max(..., 1) holds the arg-max index per sample.
    predicted = torch.max(outputs.data, 1)[1]
    total += labels.size(0)
    correct += (predicted == labels.cuda()).sum()

print('Accuracy of the network on the 10000 test images: %4.2f %%' % (100 * correct / total))

Accuracy of the network on the 10000 test images: 93.01 %


## Sequential recasting

In [10]:
# Student block outputs are regressed onto the teacher's with a mean-squared-error loss.
MSE = nn.MSELoss()

# The teacher only supplies targets, so it is put in test mode once, up front.
teacher.TestMode()

for block_idx in recasting_block_indices:

    # ---- Recast the current block ------------------------------------------
    cur_cfg = student.GetBlockConfig(block_idx)
    # Scale entry 2 by the compression ratio (presumably the block's output
    # width — confirm against GetBlockConfig).
    cur_cfg[2] = round(cur_cfg[2] * compression_ratio)

    # A 5-entry config is treated as a bottleneck block. Its entry 4 is scaled
    # too; per the original note, the new block's output dimension follows the
    # output dimension of the 3x3 conv inside the bottleneck.
    is_bottleneck = len(cur_cfg) == 5
    if is_bottleneck:
        mid_feature = cur_cfg[4]
        cur_cfg[4] = round(mid_feature * compression_ratio)

    recast_block = model_gen.GenNewBlock([target_block_type, cur_cfg])
    source_block_type = cur_cfg[0]
    student.Recasting(block_idx, recast_block)

    # ---- Rebuild the following block so its input matches -------------------
    next_cfg = student.GetBlockConfig(block_idx + 1)
    # Entry 1 (the input dim) normally shrinks by the same ratio; after a
    # bottleneck it is derived from the bottleneck's mid feature size instead.
    next_in = mid_feature if is_bottleneck else next_cfg[1]
    next_cfg[1] = round(next_in * compression_ratio)
    student.Recasting(block_idx + 1, model_gen.GenNewBlock([next_cfg[0], next_cfg]))

    # Move the freshly created blocks to the GPU before collecting parameters.
    student.Gpu()

    # Only the parameters GetCurrParams returns for this block index are optimised.
    params = student.GetCurrParams(block_idx)
    optimizer = optim.Adam(params, lr=lr_recasting)
    scheduler = lr_scheduler.StepLR(optimizer, step_size=scheduler_step_size)

    print('\nBlock %d recasting is done (%s -> %s).' %(block_idx, source_block_type, target_block_type))
    print('Training start\n')
    for epoch in range(num_epoch_recasting):

        running_loss = 0.0
        # NOTE(review): the scheduler is stepped at the start of the epoch, before
        # any optimizer.step(); PyTorch >= 1.1 documents the opposite order —
        # confirm against the installed version before changing it.
        scheduler.step()
        student.TrainMode()

        for step, (inputs, labels) in enumerate(trainloader):
            inputs = Variable(inputs.cuda())
            labels = Variable(labels.cuda())

            optimizer.zero_grad()

            # Both networks are evaluated with next_block = block_idx + 1
            # (presumably the forward pass stops at that block — confirm in Net).
            corrects = teacher(inputs, next_block=block_idx + 1)
            outputs = student(inputs, next_block=block_idx + 1)

            # Copy the teacher's result out of its graph so it acts as a constant target.
            targets = Variable(corrects.data.clone())

            loss = MSE(outputs, targets)
            loss.backward()
            optimizer.step()

            # Incremental mean of the batch losses seen so far in this epoch.
            running_loss = (running_loss * step + loss.cpu().data.numpy()) / (step + 1)

        # End-of-epoch check: full-network top-1 accuracy on the test split.
        student.TestMode()
        correct, total = 0, 0
        for images, labels in testloader:
            outputs = student(Variable(images.cuda()))
            predicted = torch.max(outputs.data, 1)[1]
            total += labels.size(0)
            correct += (predicted == labels.cuda()).sum()

        test_acc = 100 * correct / total

        print('(%d/%d) epoch end, loss: %3.6f, Test Acc: %4.2f %%' %(epoch + 1, num_epoch_recasting, running_loss, test_acc))


print('\nSequential recasting is finished')


Block 0 recasting is done (ConvBlock -> ConvBlock).
Training start

(1/60) epoch end, loss: 0.017315, Test Acc: 92.67 %
(2/60) epoch end, loss: 0.002007, Test Acc: 92.75 %
(3/60) epoch end, loss: 0.001264, Test Acc: 92.86 %
(4/60) epoch end, loss: 0.000794, Test Acc: 92.82 %
(5/60) epoch end, loss: 0.000622, Test Acc: 92.90 %
(6/60) epoch end, loss: 0.000568, Test Acc: 92.90 %
(7/60) epoch end, loss: 0.000529, Test Acc: 92.88 %
(8/60) epoch end, loss: 0.000502, Test Acc: 92.92 %
(9/60) epoch end, loss: 0.000484, Test Acc: 92.93 %
(10/60) epoch end, loss: 0.000475, Test Acc: 93.09 %
(11/60) epoch end, loss: 0.000460, Test Acc: 93.07 %
(12/60) epoch end, loss: 0.000447, Test Acc: 92.91 %
(13/60) epoch end, loss: 0.000455, Test Acc: 92.81 %
(14/60) epoch end, loss: 0.000435, Test Acc: 92.89 %
(15/60) epoch end, loss: 0.000441, Test Acc: 93.09 %
(16/60) epoch end, loss: 0.000436, Test Acc: 92.97 %
(17/60) epoch end, loss: 0.000428, Test Acc: 92.98 %
(18/60) epoch end, loss: 0.000424, Test

(33/60) epoch end, loss: 0.002855, Test Acc: 92.27 %
(34/60) epoch end, loss: 0.002853, Test Acc: 92.32 %
(35/60) epoch end, loss: 0.002852, Test Acc: 92.36 %
(36/60) epoch end, loss: 0.002852, Test Acc: 92.34 %
(37/60) epoch end, loss: 0.002852, Test Acc: 92.39 %
(38/60) epoch end, loss: 0.002852, Test Acc: 92.33 %
(39/60) epoch end, loss: 0.002851, Test Acc: 92.43 %
(40/60) epoch end, loss: 0.002850, Test Acc: 92.35 %
(41/60) epoch end, loss: 0.002843, Test Acc: 92.32 %
(42/60) epoch end, loss: 0.002840, Test Acc: 92.40 %
(43/60) epoch end, loss: 0.002843, Test Acc: 92.34 %
(44/60) epoch end, loss: 0.002842, Test Acc: 92.32 %
(45/60) epoch end, loss: 0.002843, Test Acc: 92.47 %
(46/60) epoch end, loss: 0.002843, Test Acc: 92.31 %
(47/60) epoch end, loss: 0.002842, Test Acc: 92.36 %
(48/60) epoch end, loss: 0.002846, Test Acc: 92.33 %
(49/60) epoch end, loss: 0.002844, Test Acc: 92.38 %
(50/60) epoch end, loss: 0.002843, Test Acc: 92.37 %
(51/60) epoch end, loss: 0.002839, Test Acc: 9

(5/60) epoch end, loss: 0.002483, Test Acc: 89.62 %
(6/60) epoch end, loss: 0.002405, Test Acc: 89.36 %
(7/60) epoch end, loss: 0.002357, Test Acc: 89.55 %
(8/60) epoch end, loss: 0.002330, Test Acc: 89.87 %
(9/60) epoch end, loss: 0.002303, Test Acc: 89.73 %
(10/60) epoch end, loss: 0.002274, Test Acc: 90.05 %
(11/60) epoch end, loss: 0.002252, Test Acc: 89.91 %
(12/60) epoch end, loss: 0.002237, Test Acc: 90.10 %
(13/60) epoch end, loss: 0.002217, Test Acc: 89.77 %
(14/60) epoch end, loss: 0.002228, Test Acc: 89.96 %
(15/60) epoch end, loss: 0.002189, Test Acc: 89.87 %
(16/60) epoch end, loss: 0.002171, Test Acc: 90.17 %
(17/60) epoch end, loss: 0.002164, Test Acc: 90.09 %
(18/60) epoch end, loss: 0.002160, Test Acc: 90.40 %
(19/60) epoch end, loss: 0.002140, Test Acc: 89.59 %
(20/60) epoch end, loss: 0.002131, Test Acc: 90.51 %
(21/60) epoch end, loss: 0.002076, Test Acc: 90.79 %
(22/60) epoch end, loss: 0.002070, Test Acc: 90.81 %
(23/60) epoch end, loss: 0.002063, Test Acc: 90.75 

(38/60) epoch end, loss: 0.000298, Test Acc: 91.75 %
(39/60) epoch end, loss: 0.000299, Test Acc: 91.92 %
(40/60) epoch end, loss: 0.000297, Test Acc: 91.76 %
(41/60) epoch end, loss: 0.000294, Test Acc: 91.89 %
(42/60) epoch end, loss: 0.000294, Test Acc: 91.68 %
(43/60) epoch end, loss: 0.000292, Test Acc: 91.74 %
(44/60) epoch end, loss: 0.000293, Test Acc: 91.79 %
(45/60) epoch end, loss: 0.000294, Test Acc: 91.90 %
(46/60) epoch end, loss: 0.000294, Test Acc: 91.72 %
(47/60) epoch end, loss: 0.000294, Test Acc: 91.81 %
(48/60) epoch end, loss: 0.000293, Test Acc: 91.61 %
(49/60) epoch end, loss: 0.000294, Test Acc: 91.91 %
(50/60) epoch end, loss: 0.000294, Test Acc: 91.80 %
(51/60) epoch end, loss: 0.000291, Test Acc: 91.90 %
(52/60) epoch end, loss: 0.000292, Test Acc: 91.83 %
(53/60) epoch end, loss: 0.000292, Test Acc: 91.87 %
(54/60) epoch end, loss: 0.000291, Test Acc: 91.74 %
(55/60) epoch end, loss: 0.000292, Test Acc: 91.84 %
(56/60) epoch end, loss: 0.000292, Test Acc: 9

(10/60) epoch end, loss: 0.000073, Test Acc: 89.45 %
(11/60) epoch end, loss: 0.000079, Test Acc: 89.10 %
(12/60) epoch end, loss: 0.000066, Test Acc: 90.12 %
(13/60) epoch end, loss: 0.000059, Test Acc: 89.78 %
(14/60) epoch end, loss: 0.000060, Test Acc: 88.58 %
(15/60) epoch end, loss: 0.000123, Test Acc: 89.19 %
(16/60) epoch end, loss: 0.000066, Test Acc: 89.56 %
(17/60) epoch end, loss: 0.000060, Test Acc: 89.18 %
(18/60) epoch end, loss: 0.000055, Test Acc: 88.77 %
(19/60) epoch end, loss: 0.000055, Test Acc: 78.83 %
(20/60) epoch end, loss: 0.000069, Test Acc: 88.11 %
(21/60) epoch end, loss: 0.000047, Test Acc: 91.15 %
(22/60) epoch end, loss: 0.000042, Test Acc: 90.73 %
(23/60) epoch end, loss: 0.000040, Test Acc: 91.41 %
(24/60) epoch end, loss: 0.000039, Test Acc: 91.28 %
(25/60) epoch end, loss: 0.000037, Test Acc: 91.05 %
(26/60) epoch end, loss: 0.000039, Test Acc: 90.92 %
(27/60) epoch end, loss: 0.000037, Test Acc: 91.39 %
(28/60) epoch end, loss: 0.000036, Test Acc: 9

(43/60) epoch end, loss: 0.003856, Test Acc: 90.42 %
(44/60) epoch end, loss: 0.003912, Test Acc: 90.17 %
(45/60) epoch end, loss: 0.003879, Test Acc: 89.86 %
(46/60) epoch end, loss: 0.003875, Test Acc: 90.29 %
(47/60) epoch end, loss: 0.003773, Test Acc: 90.35 %
(48/60) epoch end, loss: 0.003797, Test Acc: 90.45 %
(49/60) epoch end, loss: 0.003759, Test Acc: 90.63 %
(50/60) epoch end, loss: 0.003791, Test Acc: 90.61 %
(51/60) epoch end, loss: 0.003752, Test Acc: 90.44 %
(52/60) epoch end, loss: 0.003812, Test Acc: 90.64 %
(53/60) epoch end, loss: 0.003751, Test Acc: 90.63 %
(54/60) epoch end, loss: 0.003751, Test Acc: 90.67 %
(55/60) epoch end, loss: 0.003815, Test Acc: 90.22 %
(56/60) epoch end, loss: 0.003804, Test Acc: 90.51 %
(57/60) epoch end, loss: 0.003766, Test Acc: 90.08 %
(58/60) epoch end, loss: 0.003753, Test Acc: 90.41 %
(59/60) epoch end, loss: 0.003780, Test Acc: 90.42 %
(60/60) epoch end, loss: 0.003785, Test Acc: 90.30 %

Sequential recasting is finished


## Fine-tuning (KD + Cross-entropy)

In [11]:
# optim and lr_scheduler come from the import cell at the top of the notebook;
# they are not re-imported here.

# Loss functions: MSE pulls the student's outputs towards the teacher's
# (knowledge distillation); cross-entropy fits the ground-truth labels.
MSE = nn.MSELoss()
criterion = nn.CrossEntropyLoss()

teacher.TestMode()

# Move the student to the GPU *before* building the optimizer, the order
# PyTorch's optimizer documentation asks for and the one the recasting loop
# already uses. Presumably a no-op here, because that loop leaves the student
# on the GPU — confirm in Net.Gpu().
student.Gpu()

# Fine-tune every parameter of the student end to end.
optimizer = optim.Adam(student.GetTotalParams(), lr=lr_fine_tune)
scheduler = lr_scheduler.StepLR(optimizer, step_size=scheduler_step_size)

print('Fine tuning start\n')

for epoch in range(num_epoch_fine_tune):  # loop over the dataset multiple times

    running_loss = 0.0
    # NOTE(review): the scheduler is stepped at the start of the epoch, before any
    # optimizer.step(); PyTorch >= 1.1 documents the opposite order — confirm
    # against the installed version before changing it.
    scheduler.step()
    student.TrainMode()
    for i, (inputs, labels) in enumerate(trainloader):
        # wrap the batch in Variables on the GPU
        inputs = Variable(inputs.cuda())
        labels = Variable(labels.cuda())

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward passes of both networks on the same batch
        corrects = teacher(inputs)
        outputs = student(inputs)

        # Copy the teacher outputs out of its graph so no gradient flows into the teacher.
        targets = Variable(corrects.data.clone())
        loss_KD = MSE(outputs, targets)
        loss_CE = criterion(outputs, labels)

        # Unweighted sum of the distillation and classification terms.
        loss = loss_KD + loss_CE

        loss.backward()
        optimizer.step()

        # Incremental mean of the batch losses seen so far in this epoch.
        running_loss = (running_loss * i + loss.cpu().data.numpy()) / (i+1)

    # End-of-epoch check: top-1 accuracy on the test split.
    correct = 0
    total = 0
    student.TestMode()
    for images, labels in testloader:
        outputs = student(Variable(images.cuda()))
        predicted = torch.max(outputs.data, 1)[1]
        total += labels.size(0)
        correct += (predicted == labels.cuda()).sum()

    print('(%d/%d) epoch end, loss: %3.6f, Test Acc: %4.2f %%' %(epoch + 1, num_epoch_fine_tune, running_loss, 100 * correct / total))

print('\nFine tuning is finished')

Fine tuning start

(1/100) epoch end, loss: 2.009247, Test Acc: 87.44 %
(2/100) epoch end, loss: 1.749584, Test Acc: 89.63 %
(3/100) epoch end, loss: 1.414411, Test Acc: 90.03 %
(4/100) epoch end, loss: 1.276406, Test Acc: 90.19 %
(5/100) epoch end, loss: 1.258382, Test Acc: 90.25 %
(6/100) epoch end, loss: 1.221164, Test Acc: 90.01 %
(7/100) epoch end, loss: 1.232249, Test Acc: 90.22 %
(8/100) epoch end, loss: 1.195657, Test Acc: 90.22 %
(9/100) epoch end, loss: 1.379241, Test Acc: 90.67 %
(10/100) epoch end, loss: 1.173740, Test Acc: 90.32 %
(11/100) epoch end, loss: 1.090369, Test Acc: 90.44 %
(12/100) epoch end, loss: 1.121062, Test Acc: 90.18 %
(13/100) epoch end, loss: 1.131535, Test Acc: 90.69 %
(14/100) epoch end, loss: 1.079643, Test Acc: 89.84 %
(15/100) epoch end, loss: 1.074084, Test Acc: 90.17 %
(16/100) epoch end, loss: 1.368078, Test Acc: 90.44 %
(17/100) epoch end, loss: 1.046281, Test Acc: 90.66 %
(18/100) epoch end, loss: 0.993323, Test Acc: 90.71 %
(19/100) epoch end

In [12]:
# Inspect the student's block structure after recasting and fine-tuning,
# e.g. to check the reduced channel counts of each conv block.
student.PrintBlocksDetail()

[[Conv2d(3, 26, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
  BatchNorm2d(26, eps=1e-05, momentum=0.1, affine=True),
  ReLU(inplace),
  'ConvBlock'],
 [Conv2d(26, 26, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
  BatchNorm2d(26, eps=1e-05, momentum=0.1, affine=True),
  ReLU(inplace),
  'ConvBlock'],
 MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1), ceil_mode=False),
 [Conv2d(26, 51, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
  BatchNorm2d(51, eps=1e-05, momentum=0.1, affine=True),
  ReLU(inplace),
  'ConvBlock'],
 [Conv2d(51, 51, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
  BatchNorm2d(51, eps=1e-05, momentum=0.1, affine=True),
  ReLU(inplace),
  'ConvBlock'],
 MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1), ceil_mode=False),
 [Conv2d(51, 102, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
  BatchNorm2d(102, eps=1e-05, momentum=0.1, affine=True),
  ReLU(inplace),
  'ConvBlock'],
 [Conv2d(102, 102, kernel_size=(3, 3), str