# Automated Gradual Pruning Schedule

Michael Zhu and Suyog Gupta, ["To prune, or not to prune: exploring the efficacy of pruning for model compression"](https://arxiv.org/pdf/1710.01878), 2017 NIPS Workshop on Machine Learning of Phones and other Consumer Devices<br>
<br>
After completing sensitivity analysis, decide on your pruning schedule.

## Table of Contents
1. [Implementation of the gradual sparsity function](#Implementation-of-the-gradual-sparsity-function)
2. [Visualize pruning schedule](#Visualize-pruning-schedule)
3. [References](#References)

In [1]:
import argparse
import math
import os
import random
import re
import sys
import time
import traceback
from collections import OrderedDict
from functools import partial

import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.nn.parallel
import torch.optim
import torch.utils.data
import torchnet.meter as tnt
# `__file__` is not available when this code runs as a notebook cell, so the
# example directory is hard-coded instead of being derived from the file path:
# script_dir = os.path.dirname(__file__)
script_dir = os.path.abspath('/host/model_compression/distiller/examples/style_transfer_compression')
# Two levels above the example directory; appended to sys.path below when
# `distiller` cannot be imported directly, and later passed as `gitroot`.
module_path = os.path.abspath(os.path.join(script_dir, '..', '..'))
try:
    import distiller
except ImportError:
    # Fall back to the source checkout located relative to this example.
    sys.path.append(module_path)
    import distiller
import apputils
from distiller.data_loggers import TensorBoardLogger, PythonLogger, ActivationSparsityCollector
import distiller.quantization as quantization
from models import ALL_MODEL_NAMES, create_model

# The path must be appended before the two wildcard imports below so that they
# can be resolved.
# NOTE(review): names used later without an explicit import here (TransformerNet,
# Vgg16, transforms, datasets, DataLoader, Adam, utils) presumably arrive through
# these wildcard imports - confirm against the fast_neural_style sources.
sys.path.append('/host/frameworks/examples/fast_neural_style/neural_style')
from utils import *
from neural_style import *

# Logger handle; stays None until the logging setup further down assigns the
# configured logger.
msglogger = None


def float_range(val_str):
    """argparse ``type=`` callback that accepts a float in the half-open range [0, 1).

    Args:
        val_str: the raw command-line token.

    Returns:
        The parsed float.

    Raises:
        argparse.ArgumentTypeError: if the value is outside [0, 1) or is NaN.
        ValueError: if ``val_str`` is not a float literal (propagated from ``float``).
    """
    val = float(val_str)
    # Phrased as a positive range test: every comparison against NaN is False,
    # so the former `val < 0 or val >= 1` check let NaN through.
    if not 0 <= val < 1:
        raise argparse.ArgumentTypeError('Must be >= 0 and < 1 (received {0})'.format(val_str))
    return val


# Command-line interface inherited from the Distiller sample application.
# Fixes relative to the previous revision are limited to help text: the
# early-exit examples now name the flags that actually exist, and a typo in the
# --param-hist help was corrected.
parser = argparse.ArgumentParser(description='Distiller image classification model compression')

# --- Data loading and optimisation hyper-parameters ---
parser.add_argument('--dataset', default='/host/dataset/COCO', metavar='DIR', help='path to dataset')
parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                    help='number of data loading workers (default: 4)')
parser.add_argument('--epochs', default=90, type=int, metavar='N',
                    help='number of total epochs to run')
parser.add_argument('-b', '--batch-size', default=256, type=int,
                    metavar='N', help='mini-batch size (default: 256)')
parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
                    metavar='LR', help='initial learning rate')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                    help='momentum')
parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float,
                    metavar='W', help='weight decay (default: 1e-4)')

# --- Run control, checkpoints and reporting ---
parser.add_argument('--print-freq', '-p', default=10, type=int,
                    metavar='N', help='print frequency (default: 10)')
parser.add_argument('--resume', default='', type=str, metavar='PATH',
                    help='path to latest checkpoint (default: none)')
parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
                    help='evaluate model on validation set')
parser.add_argument('--pretrained', dest='pretrained', action='store_true',
                    help='use pre-trained model')
parser.add_argument('--act-stats', dest='activation_stats', action='store_true', default=False,
                    help='collect activation statistics (WARNING: this slows down training)')
parser.add_argument('--param-hist', dest='log_params_histograms', action='store_true', default=False,
                    help='log the parameter tensors histograms to file (WARNING: this can use significant disk space)')
SUMMARY_CHOICES = ['sparsity', 'compute', 'model', 'modules', 'png', 'png_w_params', 'onnx']
parser.add_argument('--summary', type=str, choices=SUMMARY_CHOICES,
                    help='print a summary of the model, and exit - options: ' +
                    ' | '.join(SUMMARY_CHOICES))

# --- Compression, sensitivity analysis and quantization ---
parser.add_argument('--compress', dest='compress', type=str, nargs='?', action='store',
                    help='configuration file for pruning the model (default is to use hard-coded schedule)')
parser.add_argument('--sense', dest='sensitivity', choices=['element', 'filter', 'channel'],
                    help='test the sensitivity of layers to pruning')
parser.add_argument('--extras', default=None, type=str,
                    help='file with extra configuration information')
parser.add_argument('--deterministic', '--det', action='store_true',
                    help='Ensure deterministic execution for re-producible results.')
parser.add_argument('--quantize', action='store_true',
                    help='Apply 8-bit quantization to model before evaluation')
parser.add_argument('--gpus', metavar='DEV_ID', default=None,
                    help='Comma-separated list of GPU device IDs to be used (default is to use all available devices)')

# --- Experiment naming and output ---
parser.add_argument('--name', '-n', metavar='NAME', default=None, help='Experiment name')
parser.add_argument('--out-dir', '-o', dest='output_dir', default='logs', help='Path to dump logs and checkpoints')
# float_range restricts the value to [0, 1).
parser.add_argument('--validation-size', '--vs', type=float_range, default=0.1,
                    help='Portion of training dataset to set aside for validation')
parser.add_argument('--adc', dest='ADC', action='store_true', help='temp HACK')
parser.add_argument('--adc-params', dest='ADC_params', default=None, help='temp HACK')
parser.add_argument('--confusion', dest='display_confusion', default=False, action='store_true',
                    help='Display the confusion matrix')

# --- Early exit ---
parser.add_argument('--earlyexit_lossweights', type=float, nargs='*', dest='earlyexit_lossweights', default=None,
                    help='List of loss weights for early exits (e.g. --earlyexit_lossweights 0.1 0.3)')
parser.add_argument('--earlyexit_thresholds', type=float, nargs='*', dest='earlyexit_thresholds', default=None,
                    help='List of EarlyExit thresholds (e.g. --earlyexit_thresholds 1.2 0.9)')

# Let Distiller register its knowledge-distillation options on the same parser.
distiller.knowledge_distillation.add_distillation_args(parser, ALL_MODEL_NAMES, True)

def check_pytorch_version():
    """Exit with status 1 when the installed PyTorch is older than 0.4.0.

    The leading ``major.minor[.patch]`` numbers of ``torch.__version__`` are
    compared as integers.  A plain string comparison orders versions
    lexicographically (for example ``'0.10.0' < '0.4.0'``), which is not a
    valid version ordering.

    If the version string does not start with ``<digits>.<digits>`` it cannot
    be ranked; it is then treated as acceptable rather than aborting the run.

    Side effects:
        Prints either upgrade instructions (followed by ``sys.exit(1)``) or the
        installed version.
    """
    required = (0, 4, 0)
    parsed = re.match(r'(\d+)\.(\d+)(?:\.(\d+))?', torch.__version__)
    # A missing patch component counts as 0; an unparsable string passes the check.
    installed = tuple(int(part or 0) for part in parsed.groups()) if parsed else required
    if installed < required:
        print("\nNOTICE:")
        print("The Distiller \'master\' branch now requires at least PyTorch version 0.4.0 due to "
              "PyTorch API changes which are not backward-compatible.\n"
              "Please install PyTorch 0.4.0 or its derivative.\n"
              "If you are using a virtual environment, do not forget to update it:\n"
              "  1. Deactivate the old environment\n"
              "  2. Install the new environment\n"
              "  3. Activate the new environment")
        # sys.exit instead of the builtin `exit`, which is injected by the
        # `site` module and is meant for interactive use.
        sys.exit(1)
    else:
        print("torch version: " + torch.__version__)
        
        
        
        
        
        
        
# global msglogger
# check_pytorch_version()
# args = parser.parse_args()

In [2]:
class Object(object):
    """Empty attribute container used in place of the argparse namespace."""


# Hand-written run configuration for this notebook session.  Attributes are
# installed straight into the instance dictionary, in the order listed.
args = Object()
vars(args).update(
    dataset='/host/dataset/COCO',
    epochs=4,
    batch_size=2,
    output_dir='./logs',
    name=None,
    deterministic=False,
    cuda=1,
    earlyexit_thresholds=None,
    resume='/host/converter_trial/coreml/candy.pth',
    lr=1e-4,
    style_size=None,
    log_interval=2,
    log_params_histograms=False,
    activation_stats=False,
    image_size=256,
    pretrained='./pretrained/mosaic.pth',
    content_weight=1e5,
    style_weight=1e10,
    style_image="/host/frameworks/examples/fast_neural_style/images/style-images/candy.jpg",
    compress=None,  # '../sensitivity-pruning/alexnet.schedule_sensitivity.yaml'
)
print(args.resume)

/host/converter_trial/coreml/candy.pth


In [4]:
# Create the output directory.  exist_ok=True replaces the former
# "check, then create" pair, which could fail if the directory appeared between
# the two calls.
os.makedirs(args.output_dir, exist_ok=True)
msglogger = apputils.config_pylogger(os.path.join(script_dir, 'logging.conf'), args.name, args.output_dir)

# Log various details about the execution environment.  It is sometimes useful
# to refer to past experiment executions and this information may be useful.
apputils.log_execution_env_state(sys.argv, gitroot=module_path)
msglogger.debug("Distiller: %s", distiller.__version__)

# Bookkeeping consumed by the epoch loop further down.
start_epoch = 0
best_top1 = 0
best_epoch = 0

if args.deterministic:
    # Experiment reproducibility is sometimes important.  Pete Warden expounded about this
    # in his blog: https://petewarden.com/2018/03/19/the-machine-learning-reproducibility-crisis/
    # In Pytorch, support for deterministic execution is still a bit clunky.
    #
    # The hand-built `args` object in this notebook does not define `workers`,
    # so a missing attribute is read as 0 instead of raising AttributeError.
    if getattr(args, 'workers', 0) > 1:
        msglogger.error('ERROR: Setting --deterministic requires setting --workers/-j to 0 or 1')
        sys.exit(1)
    # Use a well-known seed, for repeatability of experiments.
    # (These statements used to sit after the exit call inside the branch above
    # and therefore never ran.)
    torch.manual_seed(0)
    random.seed(0)
    np.random.seed(0)
    cudnn.deterministic = True
else:
    # This issue: https://github.com/pytorch/pytorch/issues/3659
    # Implies that cudnn.benchmark should respect cudnn.deterministic, but empirically we see that
    # results are not re-produced when benchmark is set. So enabling only if deterministic mode disabled.
    cudnn.benchmark = True

# if args.gpus is not None:
#     try:
#         args.gpus = [int(s) for s in args.gpus.split(',')]
#     except ValueError:
#         msglogger.error('ERROR: Argument --gpus must be a comma-separated list of integers only')
#         exit(1)
#     available_gpus = torch.cuda.device_count()
#     for dev_id in args.gpus:
#         if dev_id >= available_gpus:
#             msglogger.error('ERROR: GPU device ID {0} requested, but only {1} devices available'
#                             .format(dev_id, available_gpus))
#             exit(1)
#     # Set default device in case the first one on the list != 0
#     torch.cuda.set_device(args.gpus[0])
    
# # Infer the dataset from the model name
# args.dataset = 'cifar10' if 'cifar' in args.arch else 'imagenet'
# args.num_classes = 10 if args.dataset == 'cifar10' else 1000

# Early-exit bookkeeping is only prepared when thresholds were supplied.
if args.earlyexit_thresholds:
    exit_count = 1 + len(args.earlyexit_thresholds)
    args.num_exits = exit_count
    args.loss_exits = [0 for _ in range(exit_count)]
    args.losses_exits, args.exiterrors = [], []

# Build the network being trained and the VGG16 instance used for the losses,
# and place both on the selected device.  `args.cuda` is a 1-based GPU number;
# a falsy value selects the CPU.
model = TransformerNet()
device = torch.device("cuda:{}".format(args.cuda-1) if args.cuda else "cpu")
model.to(device)
vgg = Vgg16(requires_grad=False).to(device)

# Style-image preprocessing: convert to a tensor, then multiply by 255.
style_transform = transforms.Compose([transforms.ToTensor(),
                                      transforms.Lambda(lambda img: img.mul(255))])

compression_scheduler = None
# Two logging backends: TensorBoardLogger writes files readable by Google's
# Tensor Board, PythonLogger forwards to the Python logger.
tflogger = TensorBoardLogger(msglogger.logdir)
pylogger = PythonLogger(msglogger)

# Record the thresholds used for early-exit training.
if args.earlyexit_thresholds:
    msglogger.info('=> using early-exit threshold values of %s', args.earlyexit_thresholds)

# We can optionally resume from a checkpoint
if args.resume:
    # map_location moves the stored tensors onto the device selected above, so
    # the load does not depend on the device the checkpoint was written from.
    # NOTE: torch.load unpickles the file - only point args.resume at trusted
    # checkpoints.
    resumed_state_dict = torch.load(args.resume, map_location=device)
    # remove saved deprecated running_* keys in InstanceNorm from the checkpoint
    for k in list(resumed_state_dict.keys()):
        if re.search(r'in\d+\.running_(mean|var)$', k):
            del resumed_state_dict[k]
    model.load_state_dict(resumed_state_dict)
#     model, compression_scheduler, start_epoch = apputils.load_checkpoint(
#         model, chkpt_file=args.resume)

# Define the optimizer (the losses themselves are built inside train()).
optimizer = Adam(model.parameters(), args.lr)
msglogger.info('Optimizer Type: %s', type(optimizer))
msglogger.info('Optimizer Args: %s', optimizer.defaults)

# This sample application can be invoked to produce various summary reports.
# if args.summary:
    # return summarize_model(model, args.dataset, which_summary=args.summary)
    
# Load the training images from the folder given by args.dataset.  Each image
# is resized and centre-cropped to args.image_size, converted to a tensor and
# multiplied by 255.
transform = transforms.Compose([
    transforms.Resize(size=args.image_size),
    transforms.CenterCrop(size=args.image_size),
    transforms.ToTensor(),
    transforms.Lambda(lambda img: img.mul(255)),
])
train_dataset = datasets.ImageFolder(root=args.dataset, transform=transform)
train_loader = DataLoader(dataset=train_dataset, batch_size=args.batch_size)
msglogger.info('Dataset sizes:\n\ttraining=%d\n', len(train_loader.sampler))

if not args.compress:
    # No schedule file given: use a scheduler that is not configured from a file.
    compression_scheduler = distiller.CompressionScheduler(model)
else:
    # The main use-case for this sample application is CNN compression. Compression
    # requires a compression schedule configuration file in YAML.
    compression_scheduler = distiller.file_config(model, optimizer, args.compress)
    # Model is re-transferred to the device in case parameters were added
    # (e.g. PACTQuantizer).
    model.to(device)
        
# The activation-statistics collector is only needed when requested.
# NOTE(review): assumes ActivationSparsityCollector takes the model as its only
# constructor argument, as in the Distiller sample app of this vintage - confirm.
activations_sparsity = ActivationSparsityCollector(model) if args.activation_stats else None
# The hand-built `args` has no `arch`; fall back to the model's class name for
# the checkpoint metadata.
checkpoint_arch = getattr(args, 'arch', type(model).__name__)

for epoch in range(start_epoch, start_epoch + args.epochs):
    # This is the main training loop.
    msglogger.info('\n')
    if compression_scheduler:
        compression_scheduler.on_epoch_begin(epoch)

    # Train for one epoch.  This runs whether or not a scheduler exists; it was
    # previously nested under the scheduler check above.
    train(train_loader, model, optimizer, vgg, epoch, compression_scheduler, [tflogger, pylogger],
          args.log_interval, args.style_image, args.style_size, args.content_weight, args.style_weight)
    distiller.log_weights_sparsity(model, epoch, loggers=[tflogger, pylogger])
    if args.activation_stats:
        distiller.log_activation_sparsity(epoch, loggers=[tflogger, pylogger],
                                          collector=activations_sparsity)

    if compression_scheduler:
        compression_scheduler.on_epoch_end(epoch, optimizer)

    # No validation pass is run in this notebook, so there is no top-1 score to
    # compare (the previous code read an undefined `top1`).  A checkpoint is
    # written every epoch and never flagged as best; best_top1 keeps its
    # initial value.
    is_best = False
    apputils.save_checkpoint(epoch, checkpoint_arch, model, optimizer, compression_scheduler, best_top1, is_best,
                             args.name, msglogger.logdir)

Log file for this run: /host/model_compression/distiller/examples/style_transfer_compression/logs/2018.11.08-090749/2018.11.08-090749.log
Optimizer Type: <class 'torch.optim.adam.Adam'>
Optimizer Args: {'lr': 0.0001, 'eps': 1e-08, 'weight_decay': 0, 'amsgrad': False, 'betas': (0.9, 0.999)}



--------------------------------------------------------
Logging to TensorBoard - remember to execute the server:
> tensorboard --logdir='./logs'



Dataset sizes:
	training=82783



Training epoch: 82783 samples (2 per mini-batch)
Epoch: [0][    2/41392]    Overall Loss 3868320.500000    Objective Loss 3868320.500000    LR 0.000100    
Epoch: [0][    4/41392]    Overall Loss 3743740.625000    Objective Loss 3743740.625000    LR 0.000100    
Epoch: [0][    6/41392]    Overall Loss 3634853.958333    Objective Loss 3634853.958333    LR 0.000100    
Epoch: [0][    8/41392]    Overall Loss 3531091.468750    Objective Loss 3531091.468750    LR 0.000100    
Epoch: [0][   10/41392]    Overall Loss 3438809.625000    Objective Loss 3438809.625000    LR 0.000100    
Epoch: [0][   12/41392]    Overall Loss 3343157.312500    Objective Loss 3343157.312500    LR 0.000100    
Epoch: [0][   14/41392]    Overall Loss 3250093.464286    Objective Loss 3250093.464286    LR 0.000100    
Epoch: [0][   16/41392]    Overall Loss 3179808.875000    Objective Loss 3179808.875000    LR 0.000100    
Epoch: [0][   18/41392]    Overall Loss 3105655.291667    Obj

Epoch: [0][  154/41392]    Overall Loss 1271787.946023    Objective Loss 1271787.946023    LR 0.000100    
Epoch: [0][  156/41392]    Overall Loss 1263964.281250    Objective Loss 1263964.281250    LR 0.000100    
Epoch: [0][  158/41392]    Overall Loss 1257231.805380    Objective Loss 1257231.805380    LR 0.000100    
Epoch: [0][  160/41392]    Overall Loss 1249811.459375    Objective Loss 1249811.459375    LR 0.000100    
Epoch: [0][  162/41392]    Overall Loss 1242198.939815    Objective Loss 1242198.939815    LR 0.000100    
Epoch: [0][  164/41392]    Overall Loss 1234983.033155    Objective Loss 1234983.033155    LR 0.000100    
Epoch: [0][  166/41392]    Overall Loss 1227419.565136    Objective Loss 1227419.565136    LR 0.000100    
Epoch: [0][  168/41392]    Overall Loss 1220812.448289    Objective Loss 1220812.448289    LR 0.000100    
Epoch: [0][  170/41392]    Overall Loss 1214924.297794    Objective Loss 1214924.297794    LR 0.000100    
Epoch: [0][  172/41392]    Overall Lo

Epoch: [0][  308/41392]    Overall Loss 950886.046469    Objective Loss 950886.046469    LR 0.000100    
Epoch: [0][  310/41392]    Overall Loss 948448.489516    Objective Loss 948448.489516    LR 0.000100    
Epoch: [0][  312/41392]    Overall Loss 946436.228165    Objective Loss 946436.228165    LR 0.000100    
Epoch: [0][  314/41392]    Overall Loss 944041.170183    Objective Loss 944041.170183    LR 0.000100    
Epoch: [0][  316/41392]    Overall Loss 942709.234177    Objective Loss 942709.234177    LR 0.000100    
Epoch: [0][  318/41392]    Overall Loss 940364.474057    Objective Loss 940364.474057    LR 0.000100    
Epoch: [0][  320/41392]    Overall Loss 938224.546094    Objective Loss 938224.546094    LR 0.000100    
Epoch: [0][  322/41392]    Overall Loss 936342.519410    Objective Loss 936342.519410    LR 0.000100    
Epoch: [0][  324/41392]    Overall Loss 933841.230517    Objective Loss 933841.230517    LR 0.000100    
Epoch: [0][  326/41392]    Overall Loss 931981.277224  

Epoch: [0][  466/41392]    Overall Loss 830619.695815    Objective Loss 830619.695815    LR 0.000100    
Epoch: [0][  468/41392]    Overall Loss 829419.234375    Objective Loss 829419.234375    LR 0.000100    
Epoch: [0][  470/41392]    Overall Loss 828592.099202    Objective Loss 828592.099202    LR 0.000100    
Epoch: [0][  472/41392]    Overall Loss 827308.096398    Objective Loss 827308.096398    LR 0.000100    
Epoch: [0][  474/41392]    Overall Loss 826174.430775    Objective Loss 826174.430775    LR 0.000100    
Epoch: [0][  476/41392]    Overall Loss 825169.913603    Objective Loss 825169.913603    LR 0.000100    
Epoch: [0][  478/41392]    Overall Loss 824192.156773    Objective Loss 824192.156773    LR 0.000100    
Epoch: [0][  480/41392]    Overall Loss 823381.705599    Objective Loss 823381.705599    LR 0.000100    
Epoch: [0][  482/41392]    Overall Loss 822438.968361    Objective Loss 822438.968361    LR 0.000100    
Epoch: [0][  484/41392]    Overall Loss 821278.273760  

KeyboardInterrupt: 

In [3]:
# Keys of the two loss meters that are always reported.
OVERALL_LOSS_KEY = 'Overall Loss'
OBJECTIVE_LOSS_KEY = 'Objective Loss'


def train(train_loader, model, optimizer, vgg, epoch, compression_scheduler, loggers,
          log_interval, style_image, style_size, content_weight, style_weight):
    """Train the style-transfer network for one epoch.

    For every mini-batch the objective is ``content_loss + style_loss``:
    the content loss is ``content_weight`` times the MSE between the
    ``relu2_2`` features of the stylised output and of the input; the style
    loss is ``style_weight`` times the summed MSE between the gram matrices of
    the output features and those of the style image.

    Relies on the module-level names ``msglogger``, ``style_transform``,
    ``device``, ``utils``, ``args`` and ``distiller``.

    Args:
        train_loader: loader yielding ``(images, _)`` batches; its ``sampler``
            and ``batch_size`` determine the number of steps per epoch.
        model: the network being trained.
        optimizer: optimizer stepping ``model``'s parameters.
        vgg: feature extractor; its output must be iterable and expose
            ``relu2_2``.
        epoch: current epoch index, forwarded to the scheduler and loggers.
        compression_scheduler: Distiller scheduler, or a falsy value to train
            without scheduler callbacks.
        loggers: logging backends passed to ``distiller.log_training_progress``.
        log_interval: progress is logged every ``log_interval`` mini-batches.
        style_image: path of the style image.
        style_size: size passed to ``utils.load_image`` for the style image.
        content_weight: multiplier of the content loss.
        style_weight: multiplier of the style loss.

    Returns:
        None.
    """
    losses = OrderedDict([(OVERALL_LOSS_KEY, tnt.AverageValueMeter()),
                          (OBJECTIVE_LOSS_KEY, tnt.AverageValueMeter())])

    mse_loss = torch.nn.MSELoss()

    total_samples = len(train_loader.sampler)
    batch_size = train_loader.batch_size
    steps_per_epoch = math.ceil(total_samples / batch_size)
    msglogger.info('Training epoch: %d samples (%d per mini-batch)', total_samples, batch_size)

    # The style target is computed once per epoch: one copy of the style image
    # per batch element, then one gram matrix per VGG feature map.
    style = utils.load_image(style_image, size=style_size)
    style = style_transform(style)
    style = style.repeat(batch_size, 1, 1, 1).to(device)

    features_style = vgg(utils.normalize_batch(style))
    gram_style = [utils.gram_matrix(y) for y in features_style]

    model.train()
    for batch_id, (x, _) in enumerate(train_loader):
        n_batch = len(x)

        # Execute the forward phase, compute the output and measure loss
        if compression_scheduler:
            compression_scheduler.on_minibatch_begin(epoch, batch_id, steps_per_epoch, optimizer)

        x = x.to(device)
        y = model(x)

        y = utils.normalize_batch(y)
        x = utils.normalize_batch(x)

        features_y = vgg(y)
        features_x = vgg(x)

        content_loss = content_weight * mse_loss(features_y.relu2_2, features_x.relu2_2)

        style_loss = 0.
        for ft_y, gm_s in zip(features_y, gram_style):
            gm_y = utils.gram_matrix(ft_y)
            # The final batch can be smaller than batch_size, hence the slice.
            style_loss += mse_loss(gm_y, gm_s[:n_batch, :, :])
        style_loss *= style_weight

        loss = content_loss + style_loss

        losses[OBJECTIVE_LOSS_KEY].add(loss.item())

        if compression_scheduler:
            # Before running the backward phase, we allow the scheduler to modify the loss
            # (e.g. add regularization loss)
            agg_loss = compression_scheduler.before_backward_pass(epoch, batch_id, steps_per_epoch, loss,
                                                                  optimizer=optimizer, return_loss_components=True)
            loss = agg_loss.overall_loss
            losses[OVERALL_LOSS_KEY].add(loss.item())
            for lc in agg_loss.loss_components:
                if lc.name not in losses:
                    losses[lc.name] = tnt.AverageValueMeter()
                losses[lc.name].add(lc.value.item())
        else:
            # Without a scheduler the overall loss is the objective loss.  The
            # meter was previously left empty in this case, so its mean was
            # reported as NaN.
            losses[OVERALL_LOSS_KEY].add(loss.item())

        # Compute gradients and take an optimizer step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if compression_scheduler:
            compression_scheduler.on_minibatch_end(epoch, batch_id, steps_per_epoch, optimizer)

        if (batch_id + 1) % log_interval == 0:
            stats_dict = OrderedDict()
            for loss_name, meter in losses.items():
                stats_dict[loss_name] = meter.mean
            stats_dict['LR'] = optimizer.param_groups[0]['lr']
            # NOTE: the tag spelling ('Peformance') is kept unchanged so that
            # new runs log under the same tag as earlier ones.
            stats = ('Peformance/Training/', stats_dict)

            params = model.named_parameters() if args.log_params_histograms else None
            distiller.log_training_progress(stats,
                                            params,
                                            epoch, batch_id+1,
                                            steps_per_epoch, log_interval,
                                            loggers)

In [1]:
def validate(val_loader, model, criterion, loggers, args, epoch=-1):
    """Log a validation banner, then run the shared evaluation loop.

    The banner includes the epoch number only when ``epoch`` is greater than
    -1.  Returns whatever ``_validate`` returns.
    """
    if epoch <= -1:
        msglogger.info('--- validate ---------------------')
    else:
        msglogger.info('--- validate (epoch=%d)-----------', epoch)
    results = _validate(val_loader, model, criterion, loggers, args, epoch)
    return results

def test(test_loader, model, criterion, loggers, args):
    """Log a test banner, then run the shared evaluation loop on ``test_loader``.

    Returns whatever ``_validate`` returns; the epoch is left at its default.
    """
    msglogger.info('--- test ---------------------')
    results = _validate(test_loader, model, criterion, loggers, args)
    return results

def _validate(data_loader, model, criterion, loggers, args, epoch=-1):
    """Execute the validation/test loop.

    Args:
        data_loader: loader yielding ``(inputs, target)`` batches.
        model: the network to evaluate; switched to eval mode here.
        criterion: loss applied to ``(output, target)`` when early exit is off.
        loggers: logging backends passed to ``distiller.log_training_progress``.
        args: run configuration.  Reads ``earlyexit_thresholds``,
            ``display_confusion``, ``print_freq`` and, depending on those,
            ``num_exits`` and ``num_classes``.  In early-exit mode the
            per-exit meters and counters are (re)created on ``args``.
        epoch: epoch index used for logging only.

    Returns:
        ``(top1, top5, loss)``.  Without early exit these come from the
        classification meters; with early exit they are the statistics of the
        last exit (0 for an exit that was never taken).
    """
    losses = {'objective_loss': tnt.AverageValueMeter()}
    classerr = tnt.ClassErrorMeter(accuracy=True, topk=(1, 5))

    if args.earlyexit_thresholds:
        # for Early Exit, we have a list of errors and losses for each of the exits.
        args.exiterrors = []
        args.losses_exits = []
        for exitnum in range(args.num_exits):
            args.exiterrors.append(tnt.ClassErrorMeter(accuracy=True, topk=(1, 5)))
            args.losses_exits.append(tnt.AverageValueMeter())
        args.exit_taken = [0] * args.num_exits

    batch_time = tnt.AverageValueMeter()
    total_samples = len(data_loader.sampler)
    batch_size = data_loader.batch_size
    if args.display_confusion:
        confusion = tnt.ConfusionMeter(args.num_classes)
    # Whole number of steps, counting a final partial batch - the same formula
    # train() uses for its step count.
    total_steps = math.ceil(total_samples / batch_size)
    msglogger.info('%d samples (%d per mini-batch)', total_samples, batch_size)

    # Evaluate on the device that holds the model's parameters rather than
    # assuming a GPU; a model without parameters keeps the previous 'cuda'.
    first_param = next(model.parameters(), None)
    eval_device = first_param.device if first_param is not None else 'cuda'

    # Switch to evaluation mode
    model.eval()

    end = time.time()
    for validation_step, (inputs, target) in enumerate(data_loader):
        with torch.no_grad():
            inputs, target = inputs.to(eval_device), target.to(eval_device)
            # compute output from model
            output = model(inputs)

            if not args.earlyexit_thresholds:
                # compute loss
                loss = criterion(output, target)
                # measure accuracy and record loss
                losses['objective_loss'].add(loss.item())
                classerr.add(output.data, target)
                if args.display_confusion:
                    confusion.add(output.data, target)
            else:
                # If using Early Exit, then compute outputs at all exits - output is now a list of all exits
                # from exit0 through exitN (i.e. [exit0, exit1, ... exitN])
                # NOTE(review): earlyexit_validate_loss is not defined in the
                # visible part of this file - confirm it is in scope.
                earlyexit_validate_loss(output, target, criterion, args)

            # measure elapsed time
            batch_time.add(time.time() - end)
            end = time.time()

            steps_completed = (validation_step+1)
            if steps_completed % args.print_freq == 0:
                if not args.earlyexit_thresholds:
                    stats = ('',
                            OrderedDict([('Loss', losses['objective_loss'].mean),
                                         ('Top1', classerr.value(1)),
                                         ('Top5', classerr.value(5))]))
                else:
                    stats_dict = OrderedDict()
                    stats_dict['Test'] = validation_step
                    for exitnum in range(args.num_exits):
                        la_string = 'LossAvg' + str(exitnum)
                        stats_dict[la_string] = args.losses_exits[exitnum].mean
                        # Because of the nature of ClassErrorMeter, if an exit is never taken during the batch,
                        # then accessing the value(k) will cause a divide by zero. So we'll build the OrderedDict
                        # accordingly and we will not print for an exit error when that exit is never taken.
                        if args.exit_taken[exitnum]:
                            t1 = 'Top1_exit' + str(exitnum)
                            t5 = 'Top5_exit' + str(exitnum)
                            stats_dict[t1] = args.exiterrors[exitnum].value(1)
                            stats_dict[t5] = args.exiterrors[exitnum].value(5)
                    stats = ('Performance/Validation/', stats_dict)

                distiller.log_training_progress(stats, None, epoch, steps_completed,
                                                total_steps, args.print_freq, loggers)
    if not args.earlyexit_thresholds:
        msglogger.info('==> Top1: %.3f    Top5: %.3f    Loss: %.3f\n',
                       classerr.value()[0], classerr.value()[1], losses['objective_loss'].mean)

        if args.display_confusion:
            msglogger.info('==> Confusion:\n%s\n', str(confusion.value()))
        return classerr.value(1), classerr.value(5), losses['objective_loss'].mean
    else:
        # Print some interesting summary stats for number of data points that could exit early
        top1k_stats = [0] * args.num_exits
        top5k_stats = [0] * args.num_exits
        losses_exits_stats = [0] * args.num_exits
        sum_exit_stats = 0
        for exitnum in range(args.num_exits):
            if args.exit_taken[exitnum]:
                sum_exit_stats += args.exit_taken[exitnum]
                msglogger.info("Exit %d: %d", exitnum, args.exit_taken[exitnum])
                top1k_stats[exitnum] += args.exiterrors[exitnum].value(1)
                top5k_stats[exitnum] += args.exiterrors[exitnum].value(5)
                losses_exits_stats[exitnum] += args.losses_exits[exitnum].mean
        for exitnum in range(args.num_exits):
            if args.exit_taken[exitnum]:
                msglogger.info("Percent Early Exit %d: %.3f", exitnum,
                               (args.exit_taken[exitnum]*100.0) / sum_exit_stats)

        return top1k_stats[args.num_exits-1], top5k_stats[args.num_exits-1], losses_exits_stats[args.num_exits-1]


## Implementation of the gradual sparsity function

The function ```sparsity_target``` implements the gradual sparsity schedule from [[1]](#zhu-gupta):<br><br>
<b><i>"We introduce a new automated gradual pruning algorithm in which the sparsity is increased from an initial sparsity value $s_i$ (usually 0) to a final sparsity value $s_f$ over a span of $n$ pruning steps, starting at training step $t_0$ and with pruning frequency $\Delta t$."</i></b><br>
<br>

<div id="eq:zhu_gupta_schedule"></div>
<center>
$\large
\begin{align}
s_t = s_f + (s_i - s_f) \left(1- \frac{t-t_0}{n\Delta t}\right)^3
\end{align}
\ \ for
\large \ \ t \in \{t_0, t_0+\Delta t, ..., t_0+n\Delta t\}
$
</center>
<br>
Pruning happens once at the beginning of each epoch, until the duration of the pruning (the number of epochs to prune) is exceeded.  After pruning ends, the training continues without pruning, but the pruned weights are kept at zero.

In [None]:
def sparsity_target(starting_epoch, ending_epoch, initial_sparsity, final_sparsity, current_epoch):
    """Return the AGP sparsity target for ``current_epoch``.

    Inside the schedule the value follows the cubic curve
    ``s_f + (s_i - s_f) * (1 - (t - t_0) / span) ** 3`` with
    ``span = ending_epoch - starting_epoch``.

    Outside of it the target is held constant:

    * before ``starting_epoch``, or when ``final_sparsity`` is below
      ``initial_sparsity`` (nothing to prune towards): ``initial_sparsity``;
    * at or after ``ending_epoch`` (this also covers an empty or reversed
      schedule, so no division by zero can occur): ``final_sparsity``.

    The earlier implementation returned the epoch number itself in the first
    two cases and overshot ``final_sparsity`` after the end of the schedule.

    ``current_epoch`` may be a plain number or a one-element tensor.  The held
    values are written as ``value + 0 * current_epoch`` so a tensor input still
    produces a tensor connected to it (with zero gradient), which callers that
    differentiate the schedule rely on.

    Args:
        starting_epoch: first epoch of the schedule (t_0).
        ending_epoch: epoch at which ``final_sparsity`` is reached.
        initial_sparsity: sparsity at ``starting_epoch`` (s_i).
        final_sparsity: sparsity at ``ending_epoch`` (s_f).
        current_epoch: epoch to evaluate (t).

    Returns:
        The sparsity target, in the same units as the sparsity arguments.
    """
    hold = 0 * current_epoch
    if final_sparsity < initial_sparsity or current_epoch < starting_epoch:
        return initial_sparsity + hold
    if current_epoch >= ending_epoch:
        return final_sparsity + hold

    # Here starting_epoch <= current_epoch < ending_epoch, hence span > 0.
    span = ending_epoch - starting_epoch
    progress = (current_epoch - starting_epoch) / span
    return final_sparsity + (initial_sparsity - final_sparsity) * (1.0 - progress) ** 3

## Visualize pruning schedule
When using the Automated Gradual Pruning (AGP) schedule, you may want to visualize how the pruning schedule will look as a function of the epoch number.  This is called the *sparsity function*.  The widget below will help you do this.<br>
There are four knobs you can use to change the schedule:
- ```duration```: this is the number of epochs over which to use the AGP schedule ($n\Delta t$).
- ```initial_sparsity```: $s_i$
- ```final_sparsity```: $s_f$
- ```frequency```: this is the pruning frequency ($\Delta t$).

In [None]:
def draw_pruning(duration, initial_sparsity, final_sparsity, frequency):
    """Plot the sparsity schedule and its rate of change for epochs 0-39.

    A pruning step happens at every epoch that is below ``duration`` and a
    multiple of ``frequency``; there the level comes from ``sparsity_target``
    and the rate of change from autograd.  At all other epochs the most recent
    level is repeated and the rate of change is 0.

    Args:
        duration: number of epochs over which the schedule runs.
        initial_sparsity: sparsity at epoch 0.
        final_sparsity: sparsity the schedule moves towards.
        frequency: pruning frequency in epochs, as an int or numeric text.
            Blank text and values below 1 are treated as 1.

    Raises:
        ValueError: if ``frequency`` is non-blank text that is not an integer.

    NOTE(review): ``plt`` is not imported in the visible part of this file -
    presumably matplotlib.pyplot; confirm it is in scope.
    """
    epochs = []
    sparsity_levels = []
    # The derivative of the sparsity (i.e. sparsity rate of change)
    d_sparsity = []

    # A blank text box means "every epoch"; clamping to 1 also prevents a
    # modulo-by-zero when the user types 0.
    frequency = max(1, int(frequency)) if str(frequency).strip() else 1

    # Level repeated on epochs without a pruning step.  Starting from the
    # initial sparsity avoids reading an unassigned value when the very first
    # epoch is not a pruning step (e.g. duration == 0).
    held_level = float(initial_sparsity)
    for epoch in range(0, 40):
        epochs.append(epoch)
        if epoch < duration and epoch % frequency == 0:
            # A fresh leaf tensor per step, so no gradient carries over.
            current_epoch = torch.tensor([float(epoch)], requires_grad=True)
            sparsity = sparsity_target(
                starting_epoch=0,
                ending_epoch=duration,
                initial_sparsity=initial_sparsity,
                final_sparsity=final_sparsity,
                current_epoch=current_epoch
            )
            sparsity.backward()
            # Store plain floats: tensors that track gradients cannot be
            # converted to arrays for plotting.
            held_level = sparsity.item()
            sparsity_levels.append(held_level)
            d_sparsity.append(current_epoch.grad.item())
        else:
            sparsity_levels.append(held_level)
            d_sparsity.append(0)

    plt.plot(epochs, sparsity_levels, epochs, d_sparsity)
    plt.ylabel('sparsity (%)')
    plt.xlabel('epoch')
    plt.title('Pruning Rate')
    plt.ylim(0, 100)
    plt.draw()


# Sliders for the schedule length and the starting sparsity.  final_sparsity is
# given as a (min, max, step) tuple and frequency as a text value.
duration_widget = widgets.IntSlider(value=28, min=0, max=100, step=1)
si_widget = widgets.IntSlider(value=0, min=0, max=100, step=1)
interact(
    draw_pruning,
    duration=duration_widget,
    initial_sparsity=si_widget,
    final_sparsity=(0, 100, 1),
    frequency='2',
);

<div id="toc"></div>
## References
1. <div id="zhu-gupta"></div> **Michael Zhu and Suyog Gupta**. 
    [*To prune, or not to prune: exploring the efficacy of pruning for model compression*](https://arxiv.org/pdf/1710.01878),
    NIPS Workshop on Machine Learning of Phones and other Consumer Devices,
    2017.