## Import torch and model

In [1]:
# Expose only GPU 0 to this process. The variable is set here, before the
# next cell imports torch.
import os
os.environ.update(CUDA_VISIBLE_DEVICES="0")

In [2]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F

import numpy as np

In [3]:
import sys
sys.path.append("../common")

from model_generator import ModelGenerator
from net import Net

## Set hyper params

In [4]:
# Mini-batch size shared by the train and test data loaders.
batch_size = 64

# Switches forwarded to ModelGenerator when the network is built.
dropout_on, batchnorm_on = True, True

# Passed to StepLR as step_size in both training phases.
scheduler_step_size = 20

# Block-wise recasting phase: Adam learning rate and epochs spent per block.
lr_recasting, num_epoch_recasting = 1e-3, 60

# Final fine-tuning phase: Adam learning rate and total number of epochs.
lr_fine_tune, num_epoch_fine_tune = 1e-3, 100

In [5]:
# Generator used for the full network and for the replacement blocks;
# the dropout / batch-norm switches come from the hyper-parameter cell.
model_gen = ModelGenerator(dropout=dropout_on, batchnorm=batchnorm_on)
model_gen.CifarWrnConfig(k=10, num_layers=28, cifar=10)

# Checkpoint locations. `pretrained_model` is loaded for both teacher and
# student; `compressed_model` is presumably where the recast student is meant
# to be stored (it is not referenced in the cells shown here).
pretrained_model = './cifar10_wrn_28_10_pretrained.pth'
compressed_model = './cifar10_wrn_28_10_to_convenet.pth'

# Blocks to recast and the block type they are recast into.
# Index 0 is the first conv layer; range(1, 13) visits indices 1..12.
# NOTE(review): the earlier comment described the residual blocks as "1-13";
# confirm whether index 13 was meant to be included.
recasting_block_indices = range(1, 13)
target_block_type = 'ResidualBlock'

# Compression ratio: channel counts of recast blocks are multiplied by this
# factor (and rounded) in the recasting loop.
compression_ratio = 0.2

## Load dataset

In [6]:
# Steps shared by both pipelines: tensor conversion followed by per-channel
# normalisation with fixed mean / std values.
to_tensor_and_normalize = [
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.4914, 0.4822, 0.4465),
                         std=(0.2023, 0.1994, 0.2010)),
]

# Evaluation pipeline: no augmentation.
transform = transforms.Compose(to_tensor_and_normalize)

# Training pipeline: random horizontal flip and a random 32x32 crop with
# padding 4, applied before the shared steps.
transform_train = transforms.Compose(
    [transforms.RandomHorizontalFlip(),
     transforms.RandomCrop(32, padding=4)] + to_tensor_and_normalize)

# Training split, shuffled on every pass.
trainset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=batch_size, shuffle=True, num_workers=2)

# Test split, iterated in a fixed order.
testset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=batch_size, shuffle=False, num_workers=2)


Files already downloaded and verified
Files already downloaded and verified


## Load pre-trained model (teacher network)

In [7]:
# Build the network described by model_gen's configuration and wrap it in Net.
model = model_gen.GetCifarWrn()
teacher = Net(model)

# Restore the pre-trained weights into the teacher.
state = torch.load(pretrained_model)
teacher.LoadFromStateDict(state)

teacher.Gpu()

# Sanity check: top-1 accuracy of the teacher on the test loader.
correct = 0
total = 0
teacher.TestMode()
for data in testloader:
    images, labels = data
    outputs = teacher(Variable(images.cuda()))
    _, predicted = torch.max(outputs.data, 1)
    total += labels.size(0)
    # int(...) keeps `correct` a plain Python integer even if .sum() hands
    # back a tensor, so the percentage below is never computed with integer
    # tensor arithmetic.
    correct += int((predicted == labels.cuda()).sum())

# 100.0 forces floating-point division, so the two decimals are not truncated.
print('Accuracy of the network on the 10000 test images: %4.2f %%' % (100.0 * correct / total))

Accuracy of the network on the 10000 test images: 95.67 %


## Define student network

In [8]:
# The student starts from the same generated architecture and the same
# pre-trained checkpoint as the teacher.
model = model_gen.GetCifarWrn()
state = torch.load(pretrained_model)

student = Net(model)
student.LoadFromStateDict(state)
student.Gpu()

In [9]:
# Top-1 accuracy of the (still unmodified) student on the test loader.
correct = 0
total = 0
student.TestMode()
for data in testloader:
    images, labels = data
    outputs = student(Variable(images.cuda()))
    _, predicted = torch.max(outputs.data, 1)
    total += labels.size(0)
    # int(...) keeps `correct` a plain Python integer even if .sum() hands
    # back a tensor, so the percentage below is never computed with integer
    # tensor arithmetic.
    correct += int((predicted == labels.cuda()).sum())

# 100.0 forces floating-point division, so the two decimals are not truncated.
print('Accuracy of the network on the 10000 test images: %4.2f %%' % (100.0 * correct / total))

Accuracy of the network on the 10000 test images: 95.67 %


## Sequential recasting

In [10]:
# Sequential recasting: for every index in `recasting_block_indices`, replace
# that student block with a compressed block of type `target_block_type`,
# rebuild the following block so its input width matches, and then train the
# student so that its activations match the teacher's (MSE loss).

# define MSE loss
MSE = nn.MSELoss()

# The teacher only supplies regression targets; it stays in test mode.
teacher.TestMode()

for block_idx in recasting_block_indices:

    ################################################    Recasting process ######################################################
    # current block recasting

    config = student.GetBlockConfig(block_idx)

    # NOTE(review): config[0] is used as the block type below; config[1] and
    # config[2] look like the input and output widths -- confirm against
    # ModelGenerator / Net.
    config[2] = round(config[2] * compression_ratio)    # apply compression ratio

    # Handling corner case: bottleneck block recasting
    is_bottleneck = len(config) == 5
    if is_bottleneck:
        mid_feature = config[4]
        # We reduce the output dimension of bottleneck block.
        # output dimension of new block is the same with output dimension of 3x3 conv in bottleneck block
        config[4] = round(mid_feature * compression_ratio)

    new_block = model_gen.GenNewBlock([target_block_type, config])
    source_block_type = config[0]    # type of the block being replaced, used in the log line

    student.Recasting(block_idx, new_block)


    # next block recasting: same block type, but a compressed input width

    config = student.GetBlockConfig(block_idx + 1)

    config[1] = round(config[1] * compression_ratio)    # apply compression ratio

    # Handling corner case: bottleneck block recasting
    if is_bottleneck:
        # Change next input dim to output dim of target block
        config[1] = round(mid_feature * compression_ratio)

    new_block = model_gen.GenNewBlock([config[0], config])
    student.Recasting(block_idx + 1, new_block)

    ################################################    Recasting process end ##################################################

    # Move the student (including the freshly created blocks) to the GPU.
    student.Gpu()

    # A fresh optimizer and LR schedule are created for every recast block.
    # NOTE(review): GetCurrParams presumably returns the parameters of the
    # blocks that were just replaced -- confirm in Net.
    params = student.GetCurrParams(block_idx)

    optimizer = optim.Adam(params, lr = lr_recasting)
    scheduler = lr_scheduler.StepLR(optimizer, step_size = scheduler_step_size)

    print('\nBlock %d recasting is done (%s -> %s).' %(block_idx, source_block_type, target_block_type))
    print('Training start\n')
    for epoch in range(num_epoch_recasting):  # loop over the dataset multiple times

        running_loss = 0.0
        # NOTE(review): the scheduler is stepped at the start of the epoch.
        # Newer PyTorch releases expect scheduler.step() after
        # optimizer.step(); the order is kept so the schedule is unchanged.
        scheduler.step()

        student.TrainMode()

        for i, data in enumerate(trainloader, 0):
            # get the inputs
            inputs, labels = data

            # wrap them in Variable
            inputs = Variable(inputs.cuda())
            labels = Variable(labels.cuda())

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            # Both networks are called with next_block = block_idx + 1.
            # NOTE(review): presumably this returns the activation at that
            # block rather than the final output -- confirm in Net.
            corrects = teacher(inputs, next_block= block_idx + 1)
            outputs = student(inputs, next_block = block_idx + 1)

            # Copy of the teacher output's data, so it acts as a constant target.
            targets = Variable(corrects.data.clone())

            loss = MSE(outputs, targets)
            loss.backward()
            optimizer.step()

            # Incremental mean of the batch losses seen so far in this epoch.
            running_loss = (running_loss * i + loss.cpu().data.numpy()) / (i+1)


        # End-of-epoch evaluation of the whole student on the test loader.
        correct = 0
        total = 0
        student.TestMode()
        for data in testloader:
            images, labels = data
            outputs = student(Variable(images.cuda()))
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            # int(...) keeps `correct` a plain Python integer even if .sum()
            # hands back a tensor.
            correct += int((predicted == labels.cuda()).sum())

        # 100.0 forces floating-point division, so the decimals are kept.
        test_acc = 100.0 * correct / total


        print('(%d/%d) epoch end, loss: %3.6f, Test Acc: %4.2f %%' %(epoch + 1, num_epoch_recasting, running_loss, test_acc))


print('\nSequential recasting is finished')


Block 1 recasting is done (ResidualBlock -> ResidualBlock).
Training start

(1/60) epoch end, loss: 0.005611, Test Acc: 95.13 %
(2/60) epoch end, loss: 0.000529, Test Acc: 95.48 %
(3/60) epoch end, loss: 0.000298, Test Acc: 95.45 %
(4/60) epoch end, loss: 0.000210, Test Acc: 95.53 %
(5/60) epoch end, loss: 0.000179, Test Acc: 95.60 %
(6/60) epoch end, loss: 0.000145, Test Acc: 95.54 %
(7/60) epoch end, loss: 0.000144, Test Acc: 95.56 %
(8/60) epoch end, loss: 0.000118, Test Acc: 95.63 %
(9/60) epoch end, loss: 0.000169, Test Acc: 95.56 %
(10/60) epoch end, loss: 0.000108, Test Acc: 95.54 %
(11/60) epoch end, loss: 0.000098, Test Acc: 95.60 %
(12/60) epoch end, loss: 0.000092, Test Acc: 95.66 %
(13/60) epoch end, loss: 0.000088, Test Acc: 95.58 %
(14/60) epoch end, loss: 0.000084, Test Acc: 95.46 %
(15/60) epoch end, loss: 0.000082, Test Acc: 95.56 %
(16/60) epoch end, loss: 0.000082, Test Acc: 95.44 %
(17/60) epoch end, loss: 0.000080, Test Acc: 95.48 %
(18/60) epoch end, loss: 0.0000

(32/60) epoch end, loss: 0.000315, Test Acc: 94.35 %
(33/60) epoch end, loss: 0.000314, Test Acc: 94.44 %
(34/60) epoch end, loss: 0.000312, Test Acc: 94.69 %
(35/60) epoch end, loss: 0.000312, Test Acc: 94.90 %
(36/60) epoch end, loss: 0.000310, Test Acc: 94.43 %
(37/60) epoch end, loss: 0.000311, Test Acc: 94.49 %
(38/60) epoch end, loss: 0.000310, Test Acc: 94.25 %
(39/60) epoch end, loss: 0.000310, Test Acc: 94.59 %
(40/60) epoch end, loss: 0.000309, Test Acc: 94.37 %
(41/60) epoch end, loss: 0.000306, Test Acc: 95.60 %
(42/60) epoch end, loss: 0.000307, Test Acc: 95.58 %
(43/60) epoch end, loss: 0.000306, Test Acc: 95.55 %
(44/60) epoch end, loss: 0.000307, Test Acc: 95.58 %
(45/60) epoch end, loss: 0.000306, Test Acc: 95.59 %
(46/60) epoch end, loss: 0.000306, Test Acc: 95.57 %
(47/60) epoch end, loss: 0.000306, Test Acc: 95.59 %
(48/60) epoch end, loss: 0.000306, Test Acc: 95.59 %
(49/60) epoch end, loss: 0.000306, Test Acc: 95.59 %
(50/60) epoch end, loss: 0.000306, Test Acc: 9

(3/60) epoch end, loss: 0.001185, Test Acc: 13.54 %
(4/60) epoch end, loss: 0.001023, Test Acc: 24.98 %
(5/60) epoch end, loss: 0.000925, Test Acc: 80.01 %
(6/60) epoch end, loss: 0.000855, Test Acc: 56.74 %
(7/60) epoch end, loss: 0.000800, Test Acc: 36.80 %
(8/60) epoch end, loss: 0.000753, Test Acc: 61.06 %
(9/60) epoch end, loss: 0.000713, Test Acc: 60.35 %
(10/60) epoch end, loss: 0.000683, Test Acc: 47.40 %
(11/60) epoch end, loss: 0.000656, Test Acc: 55.18 %
(12/60) epoch end, loss: 0.000635, Test Acc: 83.26 %
(13/60) epoch end, loss: 0.000614, Test Acc: 87.45 %
(14/60) epoch end, loss: 0.000595, Test Acc: 37.43 %
(15/60) epoch end, loss: 0.000558, Test Acc: 90.41 %
(16/60) epoch end, loss: 0.000456, Test Acc: 62.91 %
(17/60) epoch end, loss: 0.000631, Test Acc: 92.45 %
(18/60) epoch end, loss: 0.000537, Test Acc: 75.63 %
(19/60) epoch end, loss: 0.000458, Test Acc: 88.53 %
(20/60) epoch end, loss: 0.000431, Test Acc: 91.06 %
(21/60) epoch end, loss: 0.000393, Test Acc: 95.24 %


(36/60) epoch end, loss: 0.000124, Test Acc: 94.39 %
(37/60) epoch end, loss: 0.000123, Test Acc: 94.69 %
(38/60) epoch end, loss: 0.000123, Test Acc: 94.09 %
(39/60) epoch end, loss: 0.000124, Test Acc: 90.83 %
(40/60) epoch end, loss: 0.000127, Test Acc: 94.72 %
(41/60) epoch end, loss: 0.000121, Test Acc: 94.96 %
(42/60) epoch end, loss: 0.000121, Test Acc: 95.04 %
(43/60) epoch end, loss: 0.000121, Test Acc: 94.93 %
(44/60) epoch end, loss: 0.000121, Test Acc: 94.95 %
(45/60) epoch end, loss: 0.000121, Test Acc: 94.94 %
(46/60) epoch end, loss: 0.000120, Test Acc: 94.93 %
(47/60) epoch end, loss: 0.000120, Test Acc: 94.85 %
(48/60) epoch end, loss: 0.000120, Test Acc: 94.96 %
(49/60) epoch end, loss: 0.000120, Test Acc: 94.97 %
(50/60) epoch end, loss: 0.000120, Test Acc: 94.88 %
(51/60) epoch end, loss: 0.000120, Test Acc: 94.99 %
(52/60) epoch end, loss: 0.000120, Test Acc: 94.94 %
(53/60) epoch end, loss: 0.000120, Test Acc: 94.99 %
(54/60) epoch end, loss: 0.000120, Test Acc: 9

(7/60) epoch end, loss: 0.005347, Test Acc: 16.09 %
(8/60) epoch end, loss: 0.005168, Test Acc: 26.30 %
(9/60) epoch end, loss: 0.005020, Test Acc: 39.82 %
(10/60) epoch end, loss: 0.004984, Test Acc: 72.83 %
(11/60) epoch end, loss: 0.004822, Test Acc: 11.32 %
(12/60) epoch end, loss: 0.004722, Test Acc: 59.97 %
(13/60) epoch end, loss: 0.004597, Test Acc: 17.87 %
(14/60) epoch end, loss: 0.004542, Test Acc: 33.53 %
(15/60) epoch end, loss: 0.004418, Test Acc: 68.75 %
(16/60) epoch end, loss: 0.004298, Test Acc: 26.70 %
(17/60) epoch end, loss: 0.004365, Test Acc: 34.68 %
(18/60) epoch end, loss: 0.004171, Test Acc: 32.60 %
(19/60) epoch end, loss: 0.004175, Test Acc: 78.38 %
(20/60) epoch end, loss: 0.003973, Test Acc: 49.34 %
(21/60) epoch end, loss: 0.003117, Test Acc: 93.31 %
(22/60) epoch end, loss: 0.002770, Test Acc: 93.82 %
(23/60) epoch end, loss: 0.002658, Test Acc: 93.53 %
(24/60) epoch end, loss: 0.002580, Test Acc: 93.65 %
(25/60) epoch end, loss: 0.002525, Test Acc: 93.8

## Fine-tuning (KD + Cross-entropy)

In [11]:
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler

# Fine-tuning of the student. The loss is the sum of an MSE term between
# student and teacher outputs (the "KD" term) and the cross-entropy against
# the ground-truth labels.

# define loss functions
MSE = nn.MSELoss()
criterion = nn.CrossEntropyLoss()

# A single Adam optimizer over the parameters returned by GetTotalParams,
# with a StepLR schedule.
optimizer = optim.Adam(student.GetTotalParams(), lr = lr_fine_tune)
scheduler = lr_scheduler.StepLR(optimizer, step_size = scheduler_step_size)
teacher.TestMode()
student.Gpu()

# Announce the start of the phase. This line used to print the "finished"
# message before any epoch had run.
print('Fine tuning start\n')

for epoch in range(num_epoch_fine_tune):  # loop over the dataset multiple times

    running_loss = 0.0
    # NOTE(review): the scheduler is stepped at the start of the epoch. Newer
    # PyTorch releases expect scheduler.step() after optimizer.step(); the
    # order is kept so the schedule is unchanged.
    scheduler.step()
    student.TrainMode()
    for i, data in enumerate(trainloader, 0):
        # get the inputs
        inputs, labels = data

        # wrap them in Variable
        inputs = Variable(inputs.cuda())
        labels = Variable(labels.cuda())

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        corrects = teacher(inputs)
        outputs = student(inputs)

        # Copy of the teacher output's data, so it acts as a constant target.
        targets = Variable(corrects.data.clone())
        loss_KD = MSE(outputs, targets)
        loss_CE = criterion(outputs, labels)

        loss = loss_KD + loss_CE

        loss.backward()
        optimizer.step()

        # Incremental mean of the batch losses seen so far in this epoch.
        running_loss = (running_loss * i + loss.cpu().data.numpy()) / (i+1)

    # End-of-epoch evaluation on the test loader.
    correct = 0
    total = 0
    student.TestMode()
    for data in testloader:
        images, labels = data
        outputs = student(Variable(images.cuda()))
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        # int(...) keeps `correct` a plain Python integer even if .sum()
        # hands back a tensor.
        correct += int((predicted == labels.cuda()).sum())

    # 100.0 forces floating-point division, so the decimals are kept.
    print('(%d/%d) epoch end, loss: %3.6f, Test Acc: %4.2f %%' %(epoch + 1, num_epoch_fine_tune, running_loss, 100.0 * correct / total))

print('\nFine tuning is finished')

Fine tuning is finished
(1/100) epoch end, loss: 0.369223, Test Acc: 72.17 %
(2/100) epoch end, loss: 0.392479, Test Acc: 74.12 %
(3/100) epoch end, loss: 0.394845, Test Acc: 63.95 %
(4/100) epoch end, loss: 0.395872, Test Acc: 82.64 %
(5/100) epoch end, loss: 0.388932, Test Acc: 78.22 %
(6/100) epoch end, loss: 0.373884, Test Acc: 87.39 %
(7/100) epoch end, loss: 0.382953, Test Acc: 66.20 %
(8/100) epoch end, loss: 0.370945, Test Acc: 83.51 %
(9/100) epoch end, loss: 0.362885, Test Acc: 78.31 %
(10/100) epoch end, loss: 0.382576, Test Acc: 77.32 %
(11/100) epoch end, loss: 0.370378, Test Acc: 78.51 %
(12/100) epoch end, loss: 0.340747, Test Acc: 69.88 %
(13/100) epoch end, loss: 0.360994, Test Acc: 72.57 %
(14/100) epoch end, loss: 0.372912, Test Acc: 75.10 %
(15/100) epoch end, loss: 0.328891, Test Acc: 72.72 %
(16/100) epoch end, loss: 0.330091, Test Acc: 88.53 %
(17/100) epoch end, loss: 0.331560, Test Acc: 46.65 %
(18/100) epoch end, loss: 0.327259, Test Acc: 75.93 %
(19/100) epoc

In [12]:
# Inspect the student's architecture: the cell output lists, block by block,
# the layers followed by the block type name.
student.PrintBlocksDetail()

[[Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False),
  'ConvBlock'],
 [Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False),
  BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True),
  ReLU(inplace),
  Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False),
  BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True),
  [Conv2d(16, 32, kernel_size=(1, 1), stride=(1, 1), bias=False),
   BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True)],
  ReLU(inplace),
  'ResidualBlock'],
 [Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False),
  BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True),
  ReLU(inplace),
  Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False),
  BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True),
  ReLU(inplace),
  'ResidualBlock'],
 [Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False),
  BatchNorm2d(32, eps=1e-05, momentum=