# Kernel Sparsity via Pruning Tutorial

Neural networks are, in general, heavily overparameterized for their tasks (i.e. the number of parameters far exceeds the number of training points), yet they still [generalize very well](https://arxiv.org/abs/1611.03530). This flies in the face of the conventional wisdom that overparameterizing a model leads to overfitting, and putting theory behind this empirical evidence is a very active area of research.

Additionally, [early on](http://yann.lecun.com/exdb/publis/pdf/lecun-90b.pdf) it was discovered that large numbers of weights in neural networks could be pruned away (set to 0) without affecting the loss and in most cases actually improving the generalization capability of the network. This work was reinvigorated with Song Han's [2015 paper](https://arxiv.org/abs/1510.00149) in pursuit of compressing model size for mobile applications. This has resulted in numerous papers coming out on the topic of weight pruning, filter pruning, channel pruning, and ultimately block pruning. [This paper](https://arxiv.org/abs/1902.09574) out of Google gives a good overview of the current state of sparsity.

Given that models are heavily overparameterized and large numbers of weights can be effectively pruned away, what does this leave us with? Intuitively, we can think of pruning as performing an [architecture search](https://openreview.net/pdf?id=rJlnB3C5Ym) within this large, traditionally fixed weight space. What was originally important in the dense model was representing a large number of pathways for optimization. We can then remove the unused pathways in the optimization space with a fine-toothed comb.

So what does pruning get us? We now have a model with a lot of multiplications by zero that we don't need to run. If we're smart about how we structure this compute (a surprisingly tricky problem), we can run the model much faster than its dense counterpart. That's where the [Neural Magic](http://neuralmagic.com/) engine can help us.

This tutorial provides a step-by-step walkthrough for pruning an already trained (dense) model. Specifically, it is set up to work with the model trained in our [model training tutorial](model_training.ipynb), but it can be changed to support other models/datasets as needed:
1. Dataset selection
2. Model selection and loading
3. Pruning setup
4. Recalibration using pruning

In [None]:
import sys
import os

print('Python %s on %s' % (sys.version, sys.platform))

package_path = os.path.abspath(os.path.join(os.path.expanduser(os.getcwd()), os.pardir))
print(package_path)

# Add the path to the neuralmagic-pytorch extension to sys.path
# so it isn't necessary to have it installed
sys.path.extend([package_path])

print('Added current package path to sys.path')
print('Be sure to install from requirements.txt and pytorch separately')


## Dataset Selection

We are using fast.ai's [Imagenette dataset](https://github.com/fastai/imagenette) provided under the [Apache License 2.0](https://github.com/fastai/imagenette/blob/master/LICENSE) as the default dataset. The original authors, much like ourselves, were interested in a dataset that has similar properties to more complicated datasets such as the Imagenet dataset but one that would allow rapid iterations. It includes 10 of the easiest classes out of the Imagenet 1000 dataset: tench, English springer, cassette player, chain saw, church, French horn, garbage truck, gas pump, golf ball, parachute. If you are interested in visualizing the properties in this dataset see our [model training tutorial](model_training.ipynb) which also gives a more in depth breakdown for what batch size to use and the dataset splits.

The dataset can easily be changed to the desired dataset in the code given.

Below we will need to fill in the dataset path, train batch size, and test batch size.

In [None]:
import ipywidgets as widgets
import torch

print('\nEnter the local path where the dataset can be found')

dataset_text = widgets.Text(value='', placeholder='Enter local path to dataset', description='Dataset Path')
display(dataset_text)

print('\nChoose the batch size to run through the model during train and test runs')
print('(be sure to press enter if/after inputting manually)')
train_batch_size_slider = widgets.IntSlider(
    value=64, min=1, max=256, step=1, description='Train Batch Size:'
)
display(train_batch_size_slider)
test_batch_size_slider = widgets.IntSlider(
    value=64 if torch.cuda.is_available() else 1, min=1, max=256, step=1, description='Test Batch Size:'
)
display(test_batch_size_slider)


In [None]:
from neuralmagicML.datasets import ImagenetteDataset, EarlyStopDataset
from torch.utils.data import Dataset, DataLoader

dataset_root = os.path.abspath(os.path.expanduser(dataset_text.value.strip()))
print('\nLoading dataset from {}'.format(dataset_root))

if not os.path.exists(dataset_root):
    raise Exception('Folder must exist for dataset at {}'.format(dataset_root))
    
train_batch_size = train_batch_size_slider.value
test_batch_size = test_batch_size_slider.value

print('\nUsing train batch size of {} and test batch size of {}\n'
      .format(train_batch_size, test_batch_size))
    
train_dataset = ImagenetteDataset(dataset_root, train=True, rand_trans=True)
train_data_loader = DataLoader(train_dataset, batch_size=train_batch_size, shuffle=True, num_workers=4)
print('train dataset created: \n{}\n'.format(train_dataset))

val_dataset = ImagenetteDataset(dataset_root, train=False, rand_trans=False)
# evaluation loaders use the test batch size selected above
val_data_loader = DataLoader(val_dataset, batch_size=test_batch_size, shuffle=False, num_workers=4)
print('validation test dataset created: \n{}\n'.format(val_dataset))

train_test_dataset = EarlyStopDataset(ImagenetteDataset(dataset_root, train=True, rand_trans=False),
                                      early_stop=len(val_dataset))
train_test_data_loader = DataLoader(train_test_dataset, batch_size=test_batch_size, shuffle=False, num_workers=4)
print('train test dataset created: \n{}\n'.format(train_test_dataset))


## Model Selection and Loading

For this exercise we'll create the standard [ResNet50 model](https://arxiv.org/abs/1512.03385) and load the pretrained weights from our [model training tutorial](model_training.ipynb).

If you changed the dataset in the above cell, you will need to update the number of classes so the model is created appropriately, and load your own pretrained weights. The model can also be swapped out completely to suit your specific use case.

Run the code block and select the device to run on before continuing: `cpu` runs on the PyTorch CPU backend and `cuda` runs on an attached GPU.


In [None]:
import glob
import datetime
from neuralmagicML.models import resnet50, load_model

num_classes = 10
# TODO: change this to load pretrained weights from our cloud
model = resnet50(num_classes=num_classes)
model_id = '{}-{}'.format(model.__class__.__name__,
                          datetime.datetime.today().strftime('%Y-%m-%d-%H:%M:%S')
                              .replace('-', '.').replace(':', '.'))
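# pick the first matching checkpoint (sorted for a deterministic choice) and fail clearly if none exists
pretrained_paths = sorted(glob.glob('ResNet*.pth'))
if not pretrained_paths:
    raise FileNotFoundError('No pretrained weights matching ResNet*.pth found in {}'.format(os.getcwd()))
pretrained_path = pretrained_paths[0]
load_model(pretrained_path, model)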
print('Created model {}'.format(model.__class__.__name__))

print('\nChoose the device to run on')
device_choice = widgets.ToggleButtons(
    options=['cuda', 'cpu'] if torch.cuda.is_available() else ['cpu'],
    description='Device'
)
display(device_choice)


## Pruning Setup

Informally, sparsity is the degree to which a tensor consists of zeros. More formally:

Let $N^i$ be the total number of elements in a (e.g. weight) tensor $W_i$ and let $N^i_z$ be the number of elements within that tensor which are zero-valued.

The sparsity level associated with that tensor is then defined as $s_i \triangleq \dfrac{N^i_z}{N^i}$.
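As a quick illustration of the definition above, here is a minimal sketch that computes $s_i$ for a flat list of weights. A plain Python list stands in for the weight tensor, and the `sparsity` helper is a hypothetical name introduced only for this example.

```python
from typing import Sequence


def sparsity(weights: Sequence[float]) -> float:
    """Return s = N_z / N for a flat sequence of weights (hypothetical helper)."""
    total = len(weights)  # N: total number of elements
    if total == 0:
        raise ValueError('sparsity is undefined for an empty tensor')
    zeros = sum(1 for w in weights if w == 0.0)  # N_z: zero-valued elements
    return zeros / total


example_weights = [0.0, 0.5, 0.0, -1.2, 0.0, 0.3, 0.0, 0.9]
example_sparsity = sparsity(example_weights)
print(example_sparsity)  # 4 zeros out of 8 elements
```

For a PyTorch tensor `w`, the same quantity could be computed as `(w == 0).sum().item() / w.numel()`.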

Our goal now is to maximize the sparsity of each weight tensor within the model while preserving the accuracy of the densely trained version. The most robust and effective approach to this so far has been magnitude pruning. To frame the problem, let's say our goal is to go from an initial sparsity $s_i^{init}$ to a final sparsity $s_i^{final}$. We could start by swapping our usual $L_2$ weight regularization for $L_1$ [($L_2$ vs $L_1$)](https://towardsdatascience.com/l1-and-l2-regularization-methods-ce25e7fc831c) in our cost function. Following through, we would find that we did, in fact, introduce a few zeros into our weights. We created a direct pathway within the optimization problem such that the model is incentivized to reduce the magnitude of unimportant weights to 0 (unimportant weights are those that do not significantly contribute to the loss function). This is hard to balance, though. For example, how do we determine the proper weighting between the loss function and the regularization such that we reach a desired sparsity without sacrificing accuracy? Also, empirically we find that the model settles into local minima and thus fails to reduce further weights to zero.

We can take this general idea to more of an extreme, though. Based on the previous thought experiment, it's reasonable to assume that weights near zero are likely not important to the final loss function either. Taking this even further, we could say that the smallest values within a given weight tensor are the ones least likely to affect our loss function. This is, of course, assuming that we've trained long enough to reach a stable point where our neural network's cost function has minimized unimportant weights. Naturally, then, we could apply a schedule where we prune away a small number of the smallest weights every $M$ training steps. This is exactly what magnitude pruning does.
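A single magnitude pruning step can be sketched as follows. This is an illustrative toy on a plain Python list, not the implementation used by the library; `magnitude_prune` and its `target_sparsity` argument are hypothetical names.

```python
from typing import List


def magnitude_prune(weights: List[float], target_sparsity: float) -> List[float]:
    """Zero the smallest-magnitude entries until target_sparsity is reached."""
    if not 0.0 <= target_sparsity <= 1.0:
        raise ValueError('target_sparsity must be in [0, 1]')
    num_to_prune = int(round(target_sparsity * len(weights)))
    # indices ordered from smallest to largest absolute value
    order = sorted(range(len(weights)), key=lambda i: abs(weights[i]))
    pruned_idx = set(order[:num_to_prune])
    return [0.0 if i in pruned_idx else w for i, w in enumerate(weights)]


dense = [0.8, -0.05, 0.3, 0.01, -0.6, 0.02, 0.1, -0.9]
pruned = magnitude_prune(dense, target_sparsity=0.5)
print(pruned)
```

Only the magnitude of each weight matters for the ranking, so large negative weights such as `-0.9` survive while small positive ones are removed.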

Below, we go through a UI to set up a magnitude pruning schedule for our model. While pruning, some layers within a model will be more sensitive than others. A general rule is that the initial input layer and the final output layer are the most sensitive, as they are an absolute bottleneck to the information flow. Because of this, we offer a UI capable of creating very intricate schedules, to the point that each layer could be pruned with a different schedule to a different final sparsity. In general, though, we can get by with pruning the initial and final layers minimally and the other layers equally. At Neural Magic we are actively working on automating this process to find the best possible pruning schedules to maximize performance while minimizing accuracy loss.

The default for the below model will disable pruning for the initial and final layers and prune all other layers to 90% sparsity over the course of 10 epochs. The options available are described below:
- Prune Epochs: number of epochs to apply the pruning over
- Add New Group: add a new pruning group to the UI, used to prune selected layers to different sparsity levels at different rates
- Delete Current Group: remove the current group from the pruning schedule
- Tabs: the pruning groups setup so far, click to switch between
- Sparsity: a range slider where the left value defines the sparsity level to initially start pruning at and the right value represents the final sparsity level
- Start Epoch: the epoch to start pruning the selected layers at the initial sparsity
- End Epoch: the epoch to finish pruning the selected layers to the final sparsity
- Update Freq: the update frequency, in epochs, at which to prune the layers; i.e. 1.0 will prune every epoch
- Selectable Layers: a dropdown of the layers available in the model that can be pruned along with their FLOPS and param counts; select the desired layers to prune using the checkboxes
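The options above describe a ramp from a starting sparsity to a final sparsity between two epochs, updated at a fixed frequency. Below is a minimal sketch of such a schedule. It assumes linear interpolation, a value held constant between updates, and zero sparsity before the start epoch; `scheduled_sparsity` is a hypothetical helper and not necessarily how `KSModifierWidgets` computes its targets.

```python
import math


def scheduled_sparsity(epoch: float, start_epoch: float, end_epoch: float,
                       init_sparsity: float, final_sparsity: float,
                       update_freq: float = 1.0) -> float:
    """Target sparsity at `epoch` under a stepped linear ramp (assumed behaviour)."""
    if epoch < start_epoch:
        return 0.0  # assumption: nothing is pruned before the start epoch
    if epoch >= end_epoch:
        return final_sparsity
    # snap to the most recent update point so sparsity changes only every update_freq epochs
    updates_done = math.floor((epoch - start_epoch) / update_freq)
    last_update = start_epoch + updates_done * update_freq
    progress = (last_update - start_epoch) / (end_epoch - start_epoch)
    return init_sparsity + progress * (final_sparsity - init_sparsity)


# the notebook defaults: 5% to 90% sparsity over 10 epochs, updated every epoch
schedule = [round(scheduled_sparsity(e, 0.0, 10.0, 0.05, 0.9), 3) for e in range(11)]
print(schedule)
```

With an update frequency of 1.0, each epoch raises the target by the same increment, which gives retraining a chance to recover accuracy between pruning steps.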

In [None]:
from neuralmagicML.notebooks import KSModifierWidgets


device = device_choice.value
print('running on device {}'.format(device))

print('\ncreating kernel sparsity analyzer widgets (need to execute the model, so may take some time)...')
ks_widget, ks_modifiers = KSModifierWidgets.interactive_module(
    model, device, inp_dim=(1, 3, 224, 224), init_start_sparsity=0.05, init_final_sparsity=0.9, disable_all=False
)
display(ks_widget)


### Learning Rate Selection

With our magnitude pruning schedule complete, we need to define the hyperparameters for training while pruning. The most important of these is the learning rate. Too high and we will diverge from our initial dense solution. Too low and we will have to train for too many epochs. Below we run the learning rate sensitivity analysis described in the [cyclic LR paper](https://arxiv.org/abs/1506.01186).

To run the sensitivity analysis we begin training the model at a very small learning rate ($10^{-7}$), where the weight updates will be lost in floating point precision errors (i.e. we won't learn). After each batch we exponentially increase the learning rate until we reach a very high learning rate ($10^0$), where we are guaranteed to diverge from our trained model.
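The exponential sweep described above can be sketched as a list of learning rates, one per step, each a constant multiple of the previous one. `exponential_lr_sweep` is a hypothetical helper for illustration; it is not the `lr_analysis` function used in the notebook.

```python
from typing import List


def exponential_lr_sweep(init_lr: float, final_lr: float, num_steps: int) -> List[float]:
    """Learning rates growing by a constant factor from init_lr to final_lr."""
    if num_steps < 2:
        raise ValueError('need at least two steps to sweep a range')
    if init_lr <= 0.0 or final_lr <= 0.0:
        raise ValueError('learning rates must be positive')
    # constant multiplicative factor applied after every step
    ratio = (final_lr / init_lr) ** (1.0 / (num_steps - 1))
    return [init_lr * ratio ** step for step in range(num_steps)]


# 8 steps from 1e-7 to 1e0 means one decade per step
lrs = exponential_lr_sweep(1e-7, 1e0, num_steps=8)
for lr in lrs:
    print('{:.1e}'.format(lr))
```

In practice the sweep would use many more steps, so that the loss curve is sampled densely enough to locate the point where it starts to rise.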

For a fully trained model, a flat region will be apparent in the graph, starting from the left and extending to some point on the right. After this we reach a critical learning rate where the loss begins to rise rapidly. This is the critical point where we begin diverging from the local minimum in our optimization space. Using this information, we can find the maximum and minimum learning rates to use in a cyclic learning rate with our [SGD + Nesterov momentum optimizer](https://towardsdatascience.com/stochastic-gradient-descent-with-momentum-a84097641a5d) for training. Cyclic learning rates allow quicker exploration and convergence within an optimization space than a standard stepped learning rate. See the cyclic LR paper linked above for more information. Another optimizer or a non-cyclic approach can be used as well. In that case the learning rate sensitivity analysis gives a good range for the initial learning rate to set.

Good defaults for minimum and maximum learning rates are given for the default model and dataset. If changing either, you will need to update the learning rates. A good rule of thumb is to set the minimum to $10^{-4}$ and the maximum to a value a little before the critical point found in the sensitivity analysis.
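To make the cyclic schedule concrete, here is a minimal sketch of the triangular policy from the cyclic LR paper: the learning rate climbs linearly from the minimum to the maximum over `step_size_up` steps and then descends again. The function name is hypothetical, and the notebook's `CyclicLRModifier` may differ in detail.

```python
def triangular_cyclic_lr(step: int, base_lr: float, max_lr: float, step_size_up: int) -> float:
    """Learning rate at `step` for a symmetric triangular cycle."""
    if step_size_up < 1:
        raise ValueError('step_size_up must be at least 1')
    cycle_len = 2 * step_size_up
    pos = step % cycle_len  # position within the current cycle
    if pos <= step_size_up:
        frac = pos / step_size_up                # rising half: 0 -> 1
    else:
        frac = (cycle_len - pos) / step_size_up  # falling half: 1 -> 0
    return base_lr + frac * (max_lr - base_lr)


# the default minimum and maximum from the sliders, with a short cycle for readability
cycle = [triangular_cyclic_lr(s, base_lr=1e-4, max_lr=5e-3, step_size_up=4) for s in range(9)]
print(cycle)
```

The parameter names mirror those passed to `CyclicLRModifier` in this notebook; PyTorch's built-in `torch.optim.lr_scheduler.CyclicLR` uses the same names for its triangular mode.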

In [None]:
import torch
from neuralmagicML.utils import lr_analysis, lr_analysis_figure, CrossEntropyLossWrapper
%matplotlib inline
import matplotlib.pyplot as plt

### optimizer definitions
momentum = 0.9
weight_decay = 1e-4
###

# print('\nrunning learning rate analysis...')
# batches_per_sample = round(500 / train_batch_size)  # make sure we have enough sample points per learning rate
# analysis = lr_analysis(model, device, train_data_loader, CrossEntropyLossWrapper(), batches_per_sample,
#                        init_lr=1e-7, final_lr=1e0, sgd_momentum=momentum, sgd_weight_decay=weight_decay)
# lr_analysis_figure(analysis)
# plt.show()

print('\nselect the proper learning rates')
lr_min_slider = widgets.FloatLogSlider(
    value=1e-4, min=-7, max=1, step=0.01, description='LR Min:'
)
display(lr_min_slider)
lr_max_slider = widgets.FloatLogSlider(
    value=0.005, min=-7, max=1, step=0.01, description='LR Max:'
)
display(lr_max_slider)



## Recalibration using pruning


In [None]:
from torch import optim
from torch.nn import Conv2d, Linear
from tensorboardX import SummaryWriter

from neuralmagicML.sparsity import CyclicLRModifier, ScheduledModifierManager, ScheduledOptimizer, KSAnalyzerLayer
from neuralmagicML.utils import TopKAccuracy, CrossEntropyLossWrapper


model = model.to(device)

conv_layers_names = []
for name, mod in model.named_modules():
    if isinstance(mod, Conv2d): #to add the FC layers: isinstance(mod, Conv2d) or isinstance(mod, Linear) 
        conv_layers_names.append(name)
analyzed_layers = KSAnalyzerLayer.analyze_layers(
    model, layers=[name for name, mod in model.named_modules()
                   if isinstance(mod, Conv2d) or isinstance(mod, Linear)]
)

prune_epochs = max([modifier.end_epoch for modifier in ks_modifiers])
print(prune_epochs)
lr_modifier = CyclicLRModifier(
    base_lr=lr_min_slider.value, max_lr=lr_max_slider.value, step_size_up=len(train_dataset) // 2
)
modifier_manager = ScheduledModifierManager([lr_modifier, *ks_modifiers])
print('Created ScheduledModifierManager with lr_modifier and {} ks modifiers'.format(len(ks_modifiers)))

optimizer = optim.SGD(model.parameters(), lr_min_slider.value, momentum=momentum,
                      weight_decay=weight_decay, nesterov=True)
optimizer = ScheduledOptimizer(optimizer, model, modifier_manager, steps_per_epoch=len(train_dataset))
print('\nCreated scheduled optimizer with initial lr: {}, momentum: {}, weight decay: {}'
      .format(lr_min_slider.value, momentum, weight_decay))

loss = CrossEntropyLossWrapper(extras={'top1acc': TopKAccuracy(1)})
print('\nCreated loss wrapper\n{}'.format(loss))

logs_dir = os.path.abspath(os.path.expanduser(os.path.join('.', 'model_training_logs', model_id)))

if not os.path.exists(logs_dir):
    os.makedirs(logs_dir)

writer = SummaryWriter(logdir=logs_dir, comment='imagenette training')
print('\nCreated summary writer logging to \n{}'.format(logs_dir))


### Pruning with Schedule

In [None]:
import math
from tqdm import tqdm


def test_epoch(model, data_loader, loss, device, epoch):
    model.eval()
    results = {}
    
    with torch.no_grad():
        for batch, (*x_feature, y_lab) in enumerate(tqdm(data_loader)):
            y_lab = y_lab.to(device)
            x_feature = tuple([dat.to(device) for dat in x_feature])
            batch_size = y_lab.shape[0]
            y_pred = model(*x_feature)
            losses = loss(x_feature, y_lab, y_pred)
            
            for key, val in losses.items():
                if key not in results:
                    results[key] = []
                result = val.detach_().cpu()
                result = result.repeat(batch_size)
                results[key].append(result)
                
    return results

def train_epoch(model, data_loader, optimizer, loss, device, epoch, writer):
    model.train()
    init_batch_size = None
    batches_per_epoch = len(data_loader)
    
    for batch, (*x_feature, y_lab) in enumerate(tqdm(data_loader)):
        y_lab = y_lab.to(device)
        x_feature = tuple([dat.to(device) for dat in x_feature])
        batch_size = y_lab.shape[0]
        if init_batch_size is None:
            init_batch_size = batch_size
        optimizer.zero_grad()
        y_pred = model(*x_feature)
        losses = loss(x_feature, y_lab, y_pred)
        losses['loss'].backward()
        optimizer.step(closure=None)
        
        step_count = init_batch_size * (epoch * batches_per_epoch + batch)
        for _loss, _value in losses.items():
            writer.add_scalar('Train/{}'.format(_loss), _value.item(), step_count)
        # log the learning rate once per step rather than once per loss entry
        writer.add_scalar('Train/Learning Rate', optimizer.learning_rate, step_count)
            
print('Pruning model')
            
for epoch in range(math.ceil(modifier_manager.max_epochs)):
    print('Starting epoch {}'.format(epoch))
    optimizer.epoch_start()
    train_epoch(model, train_data_loader, optimizer, loss, device, epoch, writer)
    
    print('Completed training for epoch {}, testing validation dataset'.format(epoch))
    val_losses = test_epoch(model, val_data_loader, loss, device, epoch)
    for _loss, _values in val_losses.items():
        _value = torch.mean(torch.cat(_values))
        last_value = _value
        writer.add_scalar('Test/Validation/{}'.format(_loss), _value, epoch)
        print('{}: {}'.format(_loss, _value))
        
    print('Completed testing validation dataset for epoch {}, testing training dataset'.format(epoch))
    train_losses = test_epoch(model, train_test_data_loader, loss, device, epoch)
    for _loss, _values in train_losses.items():
        _value = torch.mean(torch.cat(_values))
        writer.add_scalar('Test/Training/{}'.format(_loss), _value, epoch)
        print('{}: {}'.format(_loss, _value))
        
    optimizer.epoch_end()
    
    for ks_layer in analyzed_layers:
        writer.add_scalar('Kernel Sparsity/{}'.format(ks_layer.name), ks_layer.param_sparsity.item(), epoch + 1)
        
    print('Completed testing training dataset for epoch {}'.format(epoch))

In [None]:
from typing import List
from tensorboardX import SummaryWriter


print('Setting up model for kernel sparsity tracking...')
conv_layers_names = []
for name, mod in model.named_modules():
    if isinstance(mod, Conv2d): #to add the FC layers: isinstance(mod, Conv2d) or isinstance(mod, Linear) 
        conv_layers_names.append(name)
analyzed_layers = KSAnalyzerLayer.analyze_layers(model, conv_layers_names)

def _record_kernel_sparsity(analyzed_layers: List[KSAnalyzerLayer], writer: SummaryWriter, epoch: int):
    # log each layer's parameter sparsity to TensorBoard, then print a per-layer summary
    for ks_layer in analyzed_layers:
        tag = 'Kernel Sparsity/{}'.format(ks_layer.name)
        writer.add_scalar(tag, ks_layer.param_sparsity.item(), epoch)
    print('sparsity per layer [%]: '
          + str([round(ks_layer.param_sparsity.item() * 100.0, 0) for ks_layer in analyzed_layers]))


## Optimizer, Loss, Logging, etc.

In [None]:
import torch
from torch import optim
from torch.nn import DataParallel
from neuralmagicML.utils import CrossEntropyLossCalc, TopKAccuracy
import os

init_lr = 0.01
momentum = 0.9
weight_decay = 1e-4

print('Creating optimizer with initial lr: {}, momentum: {}, weight_decay: {}'
          .format(init_lr, momentum, weight_decay))
optimizer = optim.SGD(
    model.parameters(), init_lr, momentum=momentum, weight_decay=weight_decay, nesterov=True
)
loss_extras = {
    'top1acc': TopKAccuracy(1),
    'top5acc': TopKAccuracy(5)
}
loss_calc = CrossEntropyLossCalc(loss_extras)
print('Created loss calc {} with extras {}'.format(loss_calc, ', '.join(loss_extras.keys())))


logs_dir = './logs'
model_dir = '../pruned'
if not os.path.exists(model_dir):
    os.makedirs(model_dir)

if not os.path.exists(logs_dir):
    os.makedirs(logs_dir)
    
save_rate = 5
print('Creating summary writer in {}'.format(logs_dir))
writer = SummaryWriter(logdir=logs_dir, comment='imagenet training')
if isinstance(model, DataParallel):
    model = model.module


## Scheduling the pruning process:
Pruning involves two intertwined processes:
1. sparsification - i.e. the selection of weights to zero out.
2. re-training the model post sparsification.

In practice, a gradual increase of the sparsity level allows for the recovery of accuracy by retraining (up to high levels of sparsity).

To simplify the control of these two processes, we introduce 'Modifier' classes which manage the schedules of the associated hyperparameters (i.e. learning_rate, sparsity per layer) throughout the epochs. For added convenience, below is a simple GUI to set these hyperparameters.
The GUI allows for controlling the target sparsity on an individual-layer, global, or mixed basis.
Try setting all layers to a sparsity level of 80%.
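The interplay between the two processes can be sketched on a toy problem: pick a mask by magnitude, retrain with the mask held fixed so pruned weights stay at zero, then raise the sparsity and repeat. Everything here is an illustrative assumption (the quadratic objective, the helper names, the mask re-application after each step) and not the behaviour of the 'Modifier' classes themselves.

```python
from typing import List


def make_mask(weights: List[float], sparsity: float) -> List[float]:
    """1.0 for weights that are kept, 0.0 for the smallest-magnitude ones."""
    num_to_prune = int(round(sparsity * len(weights)))
    order = sorted(range(len(weights)), key=lambda i: abs(weights[i]))
    pruned = set(order[:num_to_prune])
    return [0.0 if i in pruned else 1.0 for i in range(len(weights))]


def masked_sgd_step(weights: List[float], grads: List[float],
                    mask: List[float], lr: float) -> List[float]:
    # ordinary SGD update, then re-apply the mask so pruned weights stay at zero
    return [(w - lr * g) * m for w, g, m in zip(weights, grads, mask)]


# toy objective: 0.5 * (w - t)^2 per weight, whose gradient is (w - t)
targets = [1.0, 0.1, -0.8, 0.05]
weights = [0.9, 0.2, -0.7, 0.1]

for sparsity in (0.25, 0.5):             # gradually raise the sparsity level
    mask = make_mask(weights, sparsity)  # 1. sparsification
    for _ in range(50):                  # 2. retraining with the mask held fixed
        grads = [w - t for w, t in zip(weights, targets)]
        weights = masked_sgd_step(weights, grads, mask, lr=0.1)

print(weights)
```

The surviving weights converge to their targets while the pruned ones remain exactly zero, which is the property a sparse inference engine relies on.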

In [None]:
import ipywidgets as widgets

############################################################
## configuration of sparsity levels / enables per layer ####
############################################################
c0 = widgets.VBox([widgets.Checkbox(description=ks_layer.name, value=True) for ks_layer in analyzed_layers])
c1 = widgets.VBox([widgets.FloatSlider(value=0.5,min=0.05,max=0.99) for _ in range(len(analyzed_layers))])
layer_ctrl = widgets.HBox([c0,c1])
global_ctrl = widgets.HBox([widgets.Checkbox(description='enable/disable all', value = True), 
                            widgets.FloatSlider(value=0.5,min=0.05,max=0.99,description='sparsity [%]')
                           ])
output2 = widgets.Output()

activated_layers = [child.value for child in layer_ctrl.children[0].children]

def global_enable_change(change):
    with output2:
        state = change['new']
        print(state)
        if state is not None:
            for ckbx_child in layer_ctrl.children[0].children:
                ckbx_child.value = state
                
global_ctrl.children[0].observe(global_enable_change, names='value')   

def global_sparsity_set(change):
    with output2:
        val = change['new']
        print(val)
        if val is not None:
            for ckbx_child, sldr_child in zip(layer_ctrl.children[0].children, layer_ctrl.children[1].children):
                if ckbx_child.value:
                    sldr_child.value = val
    
global_ctrl.children[1].observe(global_sparsity_set, names='value')   

###############################################
## configuration of learning rate schedule ####
###############################################

lr_class_dict = {   #TODO: read from CONSTRUCTORS in modifier_lr.py instead
                    #to include all supported methods in the GUI
    'ExponentialLR': {'gamma': [0.95, widgets.BoundedFloatText]}, #bound by 0.0
    'StepLR': {'step_size': [20, widgets.BoundedIntText], #bound by 1
              'gamma': [0.2, widgets.BoundedFloatText]}
}

lr_mod_args_field_initval = {
    'start_epoch': 25.0,# 'start epoch:'],
    'end_epoch': 35.0,# 'end epoch  :'],
    'update_frequency': 1.0,# 'update freq:'],
    'init_lr': 0.001# 'initial learning rate :']
}

style = {'description_width': 'initial'}
# lr_cfg_list = [widgets.Text(value='learning rate schedule', disabled=True)]
lr_section_title = widgets.Text(value='learning rate schedule', disabled=True)
lr_cfg_list =[]
for fld, val in lr_mod_args_field_initval.items():
    lr_cfg_list.append(widgets.BoundedFloatText(value=val, description=fld, disabled=False, min=0, style=style,))

lr_slct = widgets.Dropdown(
    options=[key for key in lr_class_dict.keys()],  
    value=[key for key in lr_class_dict.keys()][0],
    description='lr_class',
)
# lr_cfg_list.append(lr_slct)

def create_lr_slct_list():
    lr_slct_params = [] #create new widgets
    for param, val in lr_class_dict[lr_slct.value].items():
        lr_slct_params.append(val[1](value=val[0],description=param))
    return lr_slct_params
slct_param = widgets.VBox(children=create_lr_slct_list())
# lr_cfg_list.append(slct_param)

def refresh_lr_param(change):
    if change['new']:
        val = lr_slct.value
        slct_param.children = create_lr_slct_list()

lr_slct.observe(refresh_lr_param, names='value')   
lr_cfg = widgets.VBox([lr_section_title, *lr_cfg_list, lr_slct, slct_param])

##########################################
## configuration of pruning schedule #####
##########################################

prunning_mod_args_field_initval = {
    'start_epoch': 0.0,# 'start epoch:'],
    'end_epoch': 25.0,#'end epoch  :'],
    'update_frequency': 1.0#,'update freq:'],


}

style = {'description_width': 'initial'}
prn_section_title = widgets.Text(value='pruning schedule', disabled=True)
prn_cfg_list =[]
for fld, val in prunning_mod_args_field_initval.items():
    prn_cfg_list.append(widgets.BoundedFloatText(value=val, description=fld, disabled=False, min=0, style=style,))


prn_cfg = widgets.VBox([prn_section_title,*prn_cfg_list])
schd_cfg = widgets.VBox([lr_cfg, prn_cfg])#,prn_cfg_list])
display(widgets.HBox([widgets.VBox([global_ctrl,layer_ctrl]),schd_cfg]))


In [None]:
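# NOTE: assumes LearningRateModifier and GradualKSModifier are exported by neuralmagicML.sparsity,
# alongside the ScheduledModifierManager and ScheduledOptimizer imported earlier
from neuralmagicML.sparsity import LearningRateModifier, GradualKSModifier

print('Creating learning rate schedule...')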
lr_mod_args = {}

for child in lr_cfg_list: 
    lr_mod_args[child.description] = child.value
assert(lr_slct.description == 'lr_class')
lr_mod_args['lr_class'] = lr_slct.value
lr_mod_args['lr_kwargs'] = {}
for child in slct_param.children:
    lr_mod_args['lr_kwargs'][child.description] = child.value

lr_mod = LearningRateModifier(**lr_mod_args)

print('Creating sparsification schedule...')

def create_ks_mod_args(layer_name, final_sparsity):
    ks_mod_args ={
        'param': 'weight',
        'init_sparsity': 0.05,
        'inter_func': 'linear',
        'layers': [layer_name],
        'final_sparsity': final_sparsity
    }
    # add common fields
    for child in prn_cfg_list:
        ks_mod_args[child.description] = child.value
    return ks_mod_args

ks_mod_args_list = []
for ckbx_child, sldr_child in zip(layer_ctrl.children[0].children, layer_ctrl.children[1].children):
    if ckbx_child.value:  # layer is sparsified
        layer_name = ckbx_child.description
        final_sparsity = sldr_child.value
        ks_mod_args_list.append(create_ks_mod_args(layer_name, final_sparsity))
            
ks_mod_list = [GradualKSModifier(**ks_mod_args) for ks_mod_args in ks_mod_args_list]
modifiers = [lr_mod, *ks_mod_list]

modifier_manager = ScheduledModifierManager(modifiers)
optimizer = ScheduledOptimizer(optimizer, model, modifier_manager, steps_per_epoch=len(train_dataset))

## Setting up training
The following should look very familiar - it is in fact the exact same code from our previous tutorial (NB1). 
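One detail of `test_epoch` is worth isolating: each batch's result is repeated `batch_size` times before concatenation, so the final mean is weighted by batch size and a small last batch does not skew it. The sketch below reproduces that idea with plain Python lists, assuming each batch result is a per-batch mean; `weighted_mean` is a hypothetical helper.

```python
from typing import List, Tuple


def weighted_mean(batch_results: List[Tuple[float, int]]) -> float:
    """Average (per-batch mean, batch size) pairs, weighting by batch size."""
    repeated: List[float] = []
    for batch_mean, batch_size in batch_results:
        # mirrors result.repeat(batch_size) in test_epoch
        repeated.extend([batch_mean] * batch_size)
    return sum(repeated) / len(repeated)


# three full batches of 64 and a final partial batch of 8
batches = [(0.50, 64), (0.40, 64), (0.60, 64), (1.00, 8)]
naive = sum(mean for mean, _ in batches) / len(batches)
weighted = weighted_mean(batches)
print(naive, weighted)
```

The naive average treats the 8-sample batch as equal to the 64-sample ones, while the weighted version matches what averaging over all 200 samples would give.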

In [None]:
from tqdm import tqdm
from torch.utils.data import DataLoader

from typing import Tuple, Dict
from torch import Tensor
import torch
from torch.nn import Module


def _test_datasets(model, train_data_loader: DataLoader, val_data_loader: DataLoader,
                   writer: SummaryWriter, epoch: int) -> Tuple[Dict[str, float], Dict[str, float]]:
    val_losses, train_losses = None, None
    if val_data_loader is not None:
        print('Running test for validation dataset for epoch {}'.format(epoch))
        val = test_epoch(model, val_data_loader, loss_calc, device, epoch)
        print('Completed test for validation dataset for epoch {}'.format(epoch))
        val_losses = {}
        for loss, _ in val.items():
            val_losses[loss] = torch.mean(torch.cat(val[loss])).item()
            val_tag = 'Test/validation/{}'.format(loss)
            writer.add_scalar(val_tag, val_losses[loss], epoch)
        
        val_loss_str = 'validation set - epoch: {} '.format(epoch)
        for loss, value in val_losses.items():
            val_loss_str += (loss + ': {0:.2f} '.format(value))
        print(val_loss_str)
        
        
    if train_data_loader is not None:
        print('Running test for train dataset for epoch {}'.format(epoch))
        train = test_epoch(model, train_data_loader, loss_calc, device, epoch)
        print('Completed test for train dataset for epoch {}'.format(epoch))
        train_losses = {}

        for loss, _ in train.items():
            train_losses[loss] = torch.mean(torch.cat(train[loss])).item()
            train_tag = 'Test/training/{}'.format(loss)
            writer.add_scalar(train_tag, train_losses[loss], epoch)

        
        train_loss_str = 'training set - epoch: {} '.format(epoch)
        for loss, value in train_losses.items():
            train_loss_str += (loss + ': {0:.2f} '.format(value))
        print(train_loss_str)


    return val_losses, train_losses


def test_epoch(model: Module, data_loader: DataLoader, loss, device, epoch: int) -> Dict:
    model.eval()
    results = {}#ModuleTestResults()
    with torch.no_grad():
        for batch, (*x_feature, y_lab) in enumerate(tqdm(data_loader)):
            y_lab = y_lab.to(device)
            x_feature = tuple([dat.to(device) for dat in x_feature])
            batch_size = y_lab.shape[0]
            
            y_pred = model(*x_feature)

            losses = loss(x_feature, y_lab, y_pred)  # type: Dict[str, Tensor]
            for key, val in losses.items():
                if key not in results:
                    results[key] = []

                result = val.detach_().cpu()
                result = result.repeat(batch_size) #repeat tensor so that there is no dependency on batch size
                results[key].append(result)
#             results.append(losses, batch_size)
    return results

def train_epoch(model: Module, data_loader: DataLoader, optimizer, loss, device, data_counter: int):
    model.train()
    
    for batch, (*x_feature, y_lab) in enumerate(tqdm(data_loader)):
        # copy next batch to the device we are using
        y_lab = y_lab.to(device)
        x_feature = tuple([dat.to(device) for dat in x_feature])
        batch_size = y_lab.shape[0]

        # Zero the parameter gradients
        optimizer.zero_grad()

        # forward 
        y_pred = model(*x_feature)
        
        # update losses
        losses = loss(x_feature, y_lab, y_pred)  # type: Dict[str, Tensor]
        
        # backward
        losses['loss'].backward()
        
        # take SGD step
        optimizer.step(closure=None)
        
        # log loss and accuracy
        data_counter += batch_size
        for _loss, _value in losses.items():
            writer.add_scalar('Train/{}'.format(_loss), _value.item(), data_counter)



## Pruning main loop:
This too is very similar to the training main loop previously introduced; the main differences are:
1. we are tracking sparsity
2. we are following the schedules as orchestrated by the modifiers above

In [None]:
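import math
# NOTE: assumes save_model is exported by neuralmagicML.models, next to the load_model imported earlier
from neuralmagicML.models import save_model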
print('Running baseline test...')
epoch = -1
_record_kernel_sparsity(analyzed_layers, writer, epoch)
_test_datasets(model, None, val_data_loader, writer, epoch=-1)

print('Training model')
num_epochs = int(math.ceil(modifier_manager.max_epochs))
data_counter = 0

for epoch in range(num_epochs):
    print('Starting epoch {}'.format(epoch))
    optimizer.epoch_start()
    _record_kernel_sparsity(analyzed_layers, writer, epoch)



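    train_epoch(model, train_data_loader, optimizer, loss_calc, device, data_counter)
    # train_epoch only increments its local copy of the counter, so advance it here
    data_counter += len(train_dataset)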
    optimizer.epoch_end()
    val_losses, train_losses = _test_datasets(model, None, val_data_loader, writer, epoch)

    if save_rate > 0 and epoch % save_rate == 0:
        save_path = os.path.join(model_dir, 'resnet50-epoch={:03d}-val={:.4f}.pth'
                                 .format(epoch, val_losses['loss']))
        save_model(save_path, model, optimizer, epoch)
        print('saved model checkpoint at {}'.format(save_path))

_record_kernel_sparsity(analyzed_layers, writer, num_epochs)


scalars_json_path = os.path.join(logs_dir, 'all_scalars.json')
writer.export_scalars_to_json(scalars_json_path)
writer.close()

save_path = os.path.join(model_dir, 'resnet50-pruned.pth')
print('Finished training, saving model to {}'.format(save_path))
save_model(save_path, model)
print('Saved model')

In [None]:
modifier_manager.max_epochs