In [1]:
import os, sys
import time
import math
import shutil

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, datasets
from torch.autograd import Variable
import torch.optim.lr_scheduler as lr_scheduler

import logging
# Route INFO-and-above log records to stdout so they show up in the notebook
# output; each record is prefixed with a "month-day hour:minute" timestamp.
logging.basicConfig(stream=sys.stdout, level=logging.INFO,
                    format="%(asctime)s %(message)s", datefmt="%m-%d %H:%M")

In [10]:
class AlexNet(nn.Module):
    """AlexNet-style CNN with 3x3 convolutions, sized for small RGB images.

    The network is a five-convolution feature extractor followed by an
    adaptive average pool to 2x2 and a three-layer fully connected classifier.

    Args:
        num_classes: number of output logits produced by the last Linear
            layer (default 10).
    """

    def __init__(self, num_classes=10):
        super(AlexNet, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(64, 192, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(192, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2),
        )
        self.classifier = nn.Sequential(
            nn.Dropout(),
            # 256 channels x (2 x 2) spatial positions left by self.avg_pool.
            nn.Linear(256 * 2 * 2, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            # Honour the constructor argument instead of a hard-coded 10.
            nn.Linear(4096, num_classes)
        )
        # Forces a fixed 2x2 spatial map so the classifier input size does
        # not depend on the input resolution.
        self.avg_pool = nn.AdaptiveAvgPool2d((2,2))

    def forward(self, x, mode=None):
        """Return class logits for a batch of images.

        Args:
            x: input batch with 3 channels (first conv expects 3 input channels).
            mode: accepted for interface compatibility; not used.

        Returns:
            Tensor of shape (batch, num_classes).
        """
        x = self.features(x)
        x = self.avg_pool(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x


def converged(old, new, tol=0.001):
    """Tell whether two score sequences agree element-wise within `tol`.

    Args:
        old: previous scores (any iterable of numbers or 0-dim tensors).
        new: current scores, paired positionally with `old`; extra items in
            the longer sequence are ignored (zip semantics).
        tol: strict upper bound on the absolute difference of each pair.

    Returns:
        A plain Python bool: True when every paired difference is below
        `tol` (also True when there are no pairs), False otherwise.
    """
    # bool(...) keeps the result a real bool even when the scores are
    # tensors or numpy scalars, whose comparisons yield non-bool objects.
    return all(bool(abs(old_score - new_score) < tol)
               for old_score, new_score in zip(old, new))

In [3]:
# Data Loader
def get_data():
    """Build the CIFAR-10 training and test data loaders.

    The dataset is downloaded into ./data if it is not already there.

    Returns:
        (train_loader, test_loader): the training loader applies random
        horizontal flip and padded random crop and is shuffled; the test
        loader applies only tensor conversion and normalization.
    """
    # NOTE(review): these look like the usual ImageNet channel statistics
    # rather than CIFAR-10's own — confirm this is intentional.
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    train_transform = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.RandomCrop(32, padding=4),
        transforms.ToTensor(),
        normalize,
    ])
    valid_transform = transforms.Compose([transforms.ToTensor(), normalize])

    train_loader = torch.utils.data.DataLoader(
        datasets.CIFAR10(root="./data", train=True, transform=train_transform, download=True),
        batch_size=256, num_workers=4, shuffle=True)
    # Evaluation must not see random augmentation, otherwise the reported
    # accuracy changes from run to run; shuffling is pointless for averaging.
    test_loader = torch.utils.data.DataLoader(
        datasets.CIFAR10(root="./data", train=False, transform=valid_transform, download=True),
        batch_size=256, num_workers=4, shuffle=False)

    return train_loader, test_loader

In [4]:
# Utilities
class AverageMeter(object):
    """Running tracker of the latest value and the weighted mean of a metric.

    Attributes:
        name: label used when the meter is rendered as text.
        fmt: format spec (e.g. ':6.2f') applied to both `val` and `avg`.
        val: most recent value passed to `update`.
        sum: accumulated `val * n` over all updates.
        count: accumulated `n` over all updates.
        avg: `sum / count` after the latest update.
    """

    def __init__(self, name, fmt=':f'):
        self.name, self.fmt = name, fmt
        self.reset()

    def reset(self):
        """Zero every statistic."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record `val` as observed `n` times and refresh the mean."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        # Produces e.g. '{name} {val:6.2f} ({avg:6.2f})' and fills it from
        # the instance attributes.
        template = '{name} {val%s} ({avg%s})' % (self.fmt, self.fmt)
        return template.format(**vars(self))


class ProgressMeter(object):
    """Prints one tab-separated status line: '<prefix>[batch/total]' + meters.

    Args:
        num_batches: total number of batches, shown as the denominator.
        *meters: objects rendered with str() after the batch counter.
        prefix: text placed immediately before the batch counter.
    """

    def __init__(self, num_batches, *meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def print(self, batch):
        """Write the status line for batch index `batch` to stdout."""
        columns = [self.prefix + self.batch_fmtstr.format(batch)]
        columns.extend(str(meter) for meter in self.meters)
        print('\t'.join(columns))

    def _get_batch_fmtstr(self, num_batches):
        """Build a template such as '[{:3d}/100]' for the batch counter."""
        # Pad the running index to the width of the total so lines align.
        width = len(str(num_batches // 1))
        slot = '{:%dd}' % width
        return '[%s/%s]' % (slot, slot.format(num_batches))

def adjust_learning_rate(optimizer, lr, verbose=False):
    """Overwrite the 'lr' entry of every parameter group of `optimizer`.

    Args:
        optimizer: object exposing a `param_groups` list of dicts.
        lr: learning rate written into each group.
        verbose: when true, print the parameter groups afterwards.
    """
    for group in optimizer.param_groups:
        group['lr'] = lr
    if not verbose:
        return
    print(optimizer.param_groups)

def accuracy(output, target, topk=(1,)):
    """Compute the top-k accuracy (in percent) for each k in `topk`.

    Args:
        output: score tensor; the top-k search runs along dimension 1, so
            rows are samples and columns are classes.
        target: tensor of true class indices, one per sample.
        topk: iterable of k values to evaluate.

    Returns:
        List with one single-element tensor per k, holding the percentage of
        samples whose target is among the k highest-scoring classes.
    """
    maxk = max(topk)
    batch_size = target.size(0)

    # Indices of the maxk best classes per sample, transposed to (maxk, batch)
    # so that row r holds everybody's rank-r prediction.
    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.reshape(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # `correct` inherits the transposed (non-contiguous) layout, so
        # .view(-1) can raise; reshape copies only when it has to.
        correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res

def save_checkpoint(state, filename='checkpoint.pth.tar', dir=None, is_best=False):
    """Serialize `state` with torch.save, optionally keeping a 'best' copy.

    Args:
        state: object to serialize.
        filename: name of the checkpoint file.
        dir: directory to write into (created if missing); when None the
            file is written relative to the current working directory.
        is_best: when true, the saved file is also copied to
            'model_best.pth.tar' in the same location.
    """
    in_directory = dir is not None
    if in_directory:
        if not os.path.exists(dir):
            os.makedirs(dir)
        filename = os.path.join(dir, filename)

    torch.save(state, filename)

    if not is_best:
        return
    bestname = 'model_best.pth.tar'
    if in_directory:
        bestname = os.path.join(dir, bestname)
    shutil.copyfile(filename, bestname)

def load_checkpoint(filename='checkpoint.pth.tar', dir=None):
    """Load and return a checkpoint written with torch.save.

    Args:
        filename: checkpoint file name (or path when `dir` is not given).
        dir: optional directory containing the file; it must exist.

    Returns:
        Whatever object torch.load reconstructs from the file.

    Raises:
        AssertionError: if `dir` is given but does not exist.
    """
    assert dir is None or os.path.exists(dir)
    # NOTE: torch.load unpickles the file — only load checkpoints you trust.
    path = os.path.join(dir, filename) if dir else filename
    return torch.load(path)

In [5]:
# Train function
def train(train_loader, **kwargs):
    """Train `model` for one pass over `train_loader`.

    Args:
        train_loader: iterable yielding (input, target) batches.
        **kwargs: expects the keys
            epoch: epoch number, used only in the progress prefix.
            model: network to optimize (put into training mode here).
            criterion: loss callable taking (output, target).
            optimizer: optimizer stepped once per batch.
            cuda: when truthy, batches are moved to the GPU.

    Returns:
        (top1.avg, top5.avg): running top-1 / top-5 accuracy averages, in
        percent, as accumulated by the meters over the whole pass.
    """
    epoch = kwargs.get('epoch')
    model = kwargs.get('model')
    criterion = kwargs.get('criterion')
    optimizer = kwargs.get('optimizer')
    cuda = kwargs.get('cuda')

    batch_time = AverageMeter('Time', ':6.3f')
    data_time = AverageMeter('Data', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    progress = ProgressMeter(len(train_loader), batch_time, data_time, losses, top1, top5, prefix="Epoch:[{}]".format(epoch))

    model.train()
    end = time.time()
    for i, (input, target) in enumerate(train_loader):
        # Time spent waiting for the loader since the previous batch finished.
        data_time.update(time.time() - end)

        # Tensors are used directly; the Variable wrapper is a deprecated no-op.
        if cuda:
            input = input.cuda()
            target = target.cuda()

        output = model(input)
        loss = criterion(output, target)

        acc1, acc5 = accuracy(output, target, topk=(1,5))
        losses.update(loss.item(), input.size(0))
        top1.update(acc1[0], input.size(0))
        top5.update(acc5[0], input.size(0))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Full iteration time (loading + compute); the reference point is
        # reset only here so the next data_time starts from this instant.
        batch_time.update(time.time() - end)
        end = time.time()

        if i % 100==0:
            progress.print(i)
    logging.info('====> Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'.format(top1=top1, top5=top5))

    return top1.avg, top5.avg

In [6]:
# Validation function
def validate(val_loader, model, criterion, cuda=False):
    """Evaluate `model` on `val_loader` without tracking gradients.

    Args:
        val_loader: iterable yielding (input, target) batches.
        model: network to evaluate (put into eval mode here).
        criterion: loss callable taking (output, target).
        cuda: when truthy, batches are moved to the GPU.

    Returns:
        (top1.avg, top5.avg): running top-1 / top-5 accuracy averages, in
        percent, as accumulated by the meters over the whole pass.
    """
    batch_time = AverageMeter('Time', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    progress = ProgressMeter(len(val_loader), batch_time, losses, top1, top5, prefix='Test: ')

    model.eval()
    with torch.no_grad():
        end = time.time()
        for i, (input, target) in enumerate(val_loader):
            # Tensors are used directly; the Variable wrapper is a deprecated no-op.
            if cuda:
                input = input.cuda()
                target = target.cuda()

            output = model(input)
            loss = criterion(output, target)

            acc1, acc5 = accuracy(output, target, topk=(1,5))
            losses.update(loss.item(), input.size(0))
            top1.update(acc1[0], input.size(0))
            top5.update(acc5[0], input.size(0))

            # Measure the whole iteration (loader wait + forward pass) and
            # restart the clock afterwards, mirroring the training loop.
            batch_time.update(time.time() - end)
            end = time.time()

            if i % 100 == 0:
                progress.print(i)
        logging.info('====> Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'.format(top1=top1, top5=top5))

    return top1.avg, top5.avg

In [None]:
# Training Configuration
# Use the GPU when one is present; a hard-coded True makes model.cuda()
# fail on CPU-only machines.
cuda = torch.cuda.is_available()
model = AlexNet()
# SGD with momentum; the learning rate is divided by 10 every 30 scheduler steps.
optimizer = optim.SGD(model.parameters(), lr=0.01, weight_decay=5e-3, momentum=0.9)
scheduler = lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)
criterion = nn.CrossEntropyLoss()

train_loader, test_loader = get_data()

if cuda:
    model = model.cuda()
    criterion = criterion.cuda()

## Phase 1. Learn important connections

In [None]:
# Best validation accuracies observed so far.
top1, top5 = 0, 0
for epoch in range(100):
    current_lr = optimizer.param_groups[0]['lr']
    print("Epoch : {}, lr : {}".format(epoch, current_lr))

    print('===> [ Training ]')
    acc1_train, acc5_train = train(
        train_loader,
        epoch=epoch,
        model=model,
        criterion=criterion,
        optimizer=optimizer,
        cuda=cuda,
    )

    print('===> [ Validation ]')
    acc1_valid, acc5_valid = validate(test_loader, model, criterion, cuda)

    # Snapshot everything needed to resume; the copy marked "best" is
    # chosen by validation top-5 accuracy.
    chk = {
        'state_dict' : model.state_dict(),
        'epoch' : epoch,
        'optimizer' : optimizer.state_dict()
    }
    improved = top5 < acc5_valid
    save_checkpoint(chk, dir='checkpoint', is_best=improved)

    top1 = max(acc1_valid, top1)
    top5 = max(acc5_valid, top5)

    # Advance the learning-rate schedule once per epoch.
    scheduler.step()

## Phase 2~3. Prune and Retrain

In [None]:
class Compressor(object):
    """Magnitude-based pruner for the Conv2d and Linear layers of a model.

    `prune()` zeroes every weight whose magnitude is below a per-layer
    multiple of that layer's weight standard deviation, remembers the
    resulting binary masks, and rescales Dropout rates. `set_grad()` applies
    the stored masks to the gradients so pruned entries receive no update.

    Args:
        model: the nn.Module to prune in place.
        cuda: stored on the instance for backward compatibility; masks are
            always created on the device of the parameter they belong to.
    """

    def __init__(self, model, cuda=False):
        self.model = model
        self.num_layers = 0
        self.num_dropout_layers = 0
        # Dropout rates captured before any pruning, keyed by dropout order.
        self.dropout_rates = {}

        self.count_layers()

        self.weight_masks = [None for _ in range(self.num_layers)]
        self.bias_masks = [None for _ in range(self.num_layers)]

        self.cuda = cuda

    def count_layers(self):
        """Count prunable layers and record the original Dropout rates."""
        for m in self.model.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                self.num_layers += 1
            elif isinstance(m, nn.Dropout):
                self.dropout_rates[self.num_dropout_layers] = m.p
                self.num_dropout_layers += 1

    def _threshold_scale(self, module, index):
        """Return the std multiplier used as pruning threshold for a layer."""
        if isinstance(module, nn.Conv2d):
            # The first convolution is pruned far more gently than the rest.
            return 0.015 if index == 0 else 0.2
        # Linear layers: the last prunable layer is pruned less aggressively.
        return 0.25 if index == self.num_layers - 1 else 1

    @staticmethod
    def _count_nonzero(mask):
        """Number of non-zero entries in `mask`."""
        return torch.nonzero(mask).size(0)

    def prune(self):
        '''
        Prune every Conv2d/Linear layer in place and refresh the masks.

        :return: fraction (0..1) of weights pruned across the network;
                 0.0 when the model has no prunable layer
        '''
        index = 0
        dropout_index = 0

        num_pruned, num_weights = 0, 0

        for m in self.model.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                weight = m.weight.data
                num = torch.numel(weight)
                alpha = self._threshold_scale(m, index)

                # 1 where the weight survives, 0 where it is pruned; built on
                # the weight's own device/dtype so the in-place multiply below
                # never mixes devices.
                weight_mask = torch.ge(weight.abs(), alpha * weight.std()).to(weight.dtype)
                self.weight_masks[index] = weight_mask

                # An output unit (conv kernel / linear row) keeps its bias
                # only if at least one of its weights survived.
                bias_mask = None
                if m.bias is not None:
                    alive = weight_mask.reshape(weight_mask.size(0), -1).sum(dim=1) > 0
                    bias_mask = alive.to(device=m.bias.data.device, dtype=m.bias.data.dtype)
                self.bias_masks[index] = bias_mask

                index += 1

                layer_pruned = num - self._count_nonzero(weight_mask)
                logging.info('number pruned in weight of layer %d: %.3f %%' % (index, 100 * (layer_pruned / num)))
                if bias_mask is not None:
                    bias_num = torch.numel(bias_mask)
                    bias_pruned = bias_num - self._count_nonzero(bias_mask)
                    logging.info('number pruned in bias of layer %d: %.3f %%' % (index, 100 * (bias_pruned / bias_num)))

                num_pruned += layer_pruned
                num_weights += num

                m.weight.data *= weight_mask
                if bias_mask is not None:
                    m.bias.data *= bias_mask

            elif isinstance(m, nn.Dropout):
                # Scale the ORIGINAL rate by sqrt(kept fraction) of the
                # closest preceding prunable layer, so repeated prune() calls
                # do not compound. A Dropout that precedes every prunable
                # layer has no mask to refer to and keeps its rate.
                if index > 0:
                    mask = self.weight_masks[index - 1]
                    kept_fraction = self._count_nonzero(mask) / torch.numel(mask)
                    m.p = self.dropout_rates[dropout_index] * math.sqrt(kept_fraction)
                    logging.info("new Dropout rate: %4e" % m.p)
                dropout_index += 1

        if num_weights == 0:
            return 0.0
        return num_pruned / num_weights


    def set_grad(self):
        """Zero the gradients of pruned weights and biases.

        Meant to be called between backward() and optimizer.step(). Layers
        without a gradient yet, or without a mask (prune() not run), are
        left untouched.
        """
        index = 0
        for m in self.model.modules():
            if isinstance(m, (nn.Linear, nn.Conv2d)):
                weight_mask = self.weight_masks[index]
                bias_mask = self.bias_masks[index]
                if weight_mask is not None and m.weight.grad is not None:
                    m.weight.grad.data *= weight_mask
                if bias_mask is not None and m.bias is not None and m.bias.grad is not None:
                    m.bias.grad.data *= bias_mask
                index += 1

In [None]:
compressor = Compressor(model, cuda=cuda)
lr = 1e-4 # fixed retraining rate: 1/100 of the initial training lr (0.01)
adjust_learning_rate(optimizer, lr, verbose=False)

model.train()
pruned_percentage = 0.0
prune_score = [AverageMeter(f'Acc@{k}', ':6.2f') for k in (1,5)]
validation_score = [0.0 for _ in (1,5)]

# Iterative Pruning: prune -> validate -> retrain, until a prune step removes
# (almost) nothing new and the validation accuracies stop moving.
# NOTE(review): train() never calls compressor.set_grad(), so the masks are
# not enforced during retraining and pruned weights can become non-zero
# again until the next prune() — confirm whether that is intended.
epoch = 1
max_epoch = 5
while True:
    if epoch == 1:
        new_pruned_percentage = compressor.prune()
        logging.info('Pruned %.3f %%' % (100 * new_pruned_percentage))
        acc = validate(test_loader, model, criterion, cuda=cuda)
        if new_pruned_percentage - pruned_percentage <= 0.001 and converged(validation_score, acc):
            break
        pruned_percentage = new_pruned_percentage
        validation_score = acc
    # Retrain for max_epoch epochs before the next pruning round.
    for e in range(epoch, max_epoch+1):
        acc1_train, acc5_train = train(train_loader,
                            epoch=e, model=model,
                            criterion=criterion, optimizer=optimizer, cuda=cuda)
    epoch = 1
final_top1, final_top5 = validate(test_loader, model, criterion, cuda=cuda)
# Both values must be passed to % as one tuple; `"..." % a, b` formats with
# `a` alone and raises "not enough arguments for format string".
logging.info("Iterative Pruning Final Result : Top1 %.3f / Top5 %.3f" % (float(final_top1), float(final_top5)))
