# Notebook #2 - Model pruning
This tutorial walks through model pruning, assuming we already have a trained (dense) model:

1. prepare and load data
2. load model
3. sparsity: why/how to prune a model

## Load data
This example uses the same dataset as notebook #1:
the Imagenette/Imagewoof datasets from fast.ai, provided under the Apache License 2.0.

In [1]:
from neuralmagicML.datasets import *
from torch.utils.data import Dataset, DataLoader

dataset_type = 'imagenette'
dataset_path = '../data/imagenette-160/'
device = 'cuda:0'  # device to run on: 'cpu' or 'cuda:0'

train_batch_size = 128
test_batch_size = 256
dataset_early_stop = -1

train_dataset = ImagenetteDataset(dataset_path, train=True, rand_trans=True)
train_data_loader = DataLoader(train_dataset, batch_size=train_batch_size, shuffle=True, num_workers=8)

# evaluation loaders use fixed transforms (rand_trans=False); shuffling is unnecessary for evaluation
val_dataset = ImagenetteDataset(dataset_path, train=False, rand_trans=False)
val_data_loader = DataLoader(val_dataset, batch_size=test_batch_size, shuffle=False, num_workers=8)

train_test_dataset = ImagenetteDataset(dataset_path, train=True, rand_trans=False)
train_test_data_loader = DataLoader(train_test_dataset, batch_size=test_batch_size, shuffle=False, num_workers=8)


already downloaded imagenette of size ImagenetteSize.s160
already downloaded imagenette of size ImagenetteSize.s160
already downloaded imagenette of size ImagenetteSize.s160


## Load model
Let us load the ResNet-50 model we trained in the previous notebook.

In [2]:
from neuralmagicML.models import *
from neuralmagicML.sparsity import *
from torch.nn import Conv2d

model_path = '../checkpoints/resnet50-epoch=030-val=0.3995.pth'
print('initializing model...')
model = resnet50(num_classes=10)
model = model.to(device)

print('loading model...')
load_model(model_path, model)
print('done')


initializing model...
loading model...
done


## Sparsity, pruning, and high-level motivation
### sparsity:

Informally, sparsity is the degree to which a tensor consists of zeros.

Slightly more formally:

Let $N^i$ be the total number of elements in a tensor $W_i$ (e.g. a weight tensor).

Let $N^i_z$ be the number of zero-valued elements in that tensor.

The sparsity level of that tensor is defined as $s_i \triangleq \dfrac{N^i_z}{N^i}$.
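
A direct translation of this definition computes $s_i$ for a single tensor with PyTorch. The function name `tensor_sparsity` is our own illustration, not part of `neuralmagicML`:

```python
import torch

def tensor_sparsity(w: torch.Tensor) -> float:
    """Return s = N_z / N, the fraction of zero-valued elements in w."""
    n_total = w.numel()               # N^i
    n_zero = (w == 0).sum().item()    # N^i_z
    return n_zero / n_total

# A 2x4 "weight" tensor with 3 zeros out of 8 elements
w = torch.tensor([[0.0, 1.5, 0.0, -2.0],
                  [3.0, 0.0, 0.7, 1.1]])
print(tensor_sparsity(w))  # 0.375
```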


### pruning:

Pruning is the process of selectively setting weights in a model to zero. How many weights are zeroed, and which ones, affects both the model's accuracy and its required FLOPs and memory footprint. **Critically, high levels of sparsity can be attained while preserving accuracy, as we will demonstrate in this notebook.**
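
The notebook does not spell out which weights the library selects. A common criterion, used here purely as an illustration, is magnitude pruning: zero the weights with the smallest absolute values. A minimal sketch, where the helper `magnitude_prune_` is hypothetical:

```python
import torch

def magnitude_prune_(w: torch.Tensor, sparsity: float) -> torch.Tensor:
    """Zero out, in place, the `sparsity` fraction of entries of `w` with the smallest |w|.

    Returns the 0/1 mask that was applied (1 = kept, 0 = pruned). Entries tied with
    the threshold value are pruned too, so ties can push sparsity slightly higher.
    """
    with torch.no_grad():
        k = int(round(sparsity * w.numel()))  # number of weights to zero out
        if k == 0:
            return torch.ones_like(w)
        # the k-th smallest absolute value acts as the pruning threshold
        threshold = w.abs().flatten().kthvalue(k).values
        mask = (w.abs() > threshold).to(w.dtype)
        w.mul_(mask)
    return mask

w = torch.tensor([0.1, -0.5, 0.05, 2.0, -0.01, 0.3, -1.2, 0.02])
mask = magnitude_prune_(w, sparsity=0.5)
print(w)  # the four smallest-magnitude entries are now zero
```

Returning the mask matters in practice: during retraining the same mask has to be re-applied so that pruned weights stay at zero.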

### Sparsity --> fewer FLOPs -?-> accelerated performance:
The fact that models can be heavily sparsified with little or no accuracy loss is well known in the research community. Intuitively, the higher the sparsity level, the fewer theoretical FLOPs are required, and hence the larger the performance gain. However, while the first part of that reasoning holds (higher sparsity --> fewer theoretical FLOPs), the second does not necessarily hold (fewer theoretical FLOPs -/-> higher performance). The reason is that typical hardware such as GPUs is ill-equipped to take advantage of this kind of sparsity, so FLOP savings are very hard to realize in practice. On CPUs, in contrast, algorithms that exploit sparsity can be developed flexibly.
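
To make the first half of that argument concrete: for a convolution, the theoretical FLOP count scales linearly with the fraction of non-zero weights. A back-of-the-envelope sketch, assuming a multiply-accumulate counts as 2 FLOPs and using an arbitrary ResNet-50-like layer shape (3x3 kernel, 64 to 64 channels, 56x56 output):

```python
def conv2d_flops(k: int, c_in: int, c_out: int, h_out: int, w_out: int,
                 sparsity: float = 0.0) -> float:
    """Theoretical FLOPs of a 2D convolution (no bias, no groups).

    Each multiply-accumulate counts as 2 FLOPs; with weight sparsity s only a
    (1 - s) fraction of the weights contribute, so FLOPs scale by (1 - s).
    """
    dense = 2 * k * k * c_in * c_out * h_out * w_out
    return dense * (1.0 - sparsity)

dense_flops = conv2d_flops(3, 64, 64, 56, 56)
sparse_flops = conv2d_flops(3, 64, 64, 56, 56, sparsity=0.8)
print(f"dense: {dense_flops / 1e9:.3f} GFLOPs, 80% sparse: {sparse_flops / 1e9:.3f} GFLOPs")
```

The theoretical reduction here is 5x; whether it turns into a 5x speedup depends entirely on whether the kernel executing the layer can actually skip the zeros.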

Armed with this insight, we are ready (and motivated!) to start looking at model sparsity, with the aim of increasing it via pruning.

In [3]:
from typing import List
from tensorboardX import SummaryWriter


print('Setting up model for kernel sparsity tracking...')
conv_layers_names = []
for name, mod in model.named_modules():
    if isinstance(mod, Conv2d):  # to also track FC layers: isinstance(mod, (Conv2d, Linear))
        conv_layers_names.append(name)
analyzed_layers = KSAnalyzerLayer.analyze_layers(model, conv_layers_names)

def _record_kernel_sparsity(analyzed_layers: List[KSAnalyzerLayer], writer: SummaryWriter, epoch: int):
    # log each layer's weight sparsity to TensorBoard, then print a per-layer summary
    for ks_layer in analyzed_layers:
        tag = 'Kernel Sparsity/{}'.format(ks_layer.name)
        writer.add_scalar(tag, ks_layer.param_sparsity.item(), epoch)
    print('sparsity per layer [%]: ' + str([round(ks_layer.param_sparsity.item() * 100.0, 0) for ks_layer in analyzed_layers]))


Setting up model for kernel sparsity tracking...
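
`KSAnalyzerLayer` is `neuralmagicML`'s tracker. For readers without the library, here is a rough plain-PyTorch equivalent of the per-layer measurement. The function name is ours, and we assume `param_sparsity` corresponds to the fraction of zeros in the layer's weight, as defined earlier:

```python
from typing import Dict

import torch
from torch import nn

def conv_weight_sparsities(net: nn.Module) -> Dict[str, float]:
    """Map each Conv2d layer name to the fraction of zeros in its weight."""
    sparsities = {}
    for name, mod in net.named_modules():
        if isinstance(mod, nn.Conv2d):
            w = mod.weight.detach()
            sparsities[name] = (w == 0).sum().item() / w.numel()
    return sparsities

# Toy model: zero the first output filter (27 of 108 weights) of the first conv
toy_model = nn.Sequential(nn.Conv2d(3, 4, 3), nn.ReLU(), nn.Conv2d(4, 2, 1))
with torch.no_grad():
    toy_model[0].weight[0].zero_()
print(conv_weight_sparsities(toy_model))
```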


## Optimizer, loss, logging, etc.

In [4]:
import torch
from torch import optim
from torch.nn import DataParallel
from neuralmagicML.utils import CrossEntropyLossCalc, TopKAccuracy
import os

init_lr = 0.01
momentum = 0.9
weight_decay = 1e-4

print('Creating optimizer with initial lr: {}, momentum: {}, weight_decay: {}'
          .format(init_lr, momentum, weight_decay))
optimizer = optim.SGD(
    model.parameters(), init_lr, momentum=momentum, weight_decay=weight_decay, nesterov=True
)
loss_extras = {
    'top1acc': TopKAccuracy(1),
    'top5acc': TopKAccuracy(5)
}
loss_calc = CrossEntropyLossCalc(loss_extras)
print('Created loss calc {} with extras {}'.format(loss_calc, ', '.join(loss_extras.keys())))


logs_dir = './logs'
model_dir = '../pruned'
os.makedirs(model_dir, exist_ok=True)
os.makedirs(logs_dir, exist_ok=True)

save_rate = 5
print('Creating summary writer in {}'.format(logs_dir))
writer = SummaryWriter(logdir=logs_dir, comment='imagenet training')
if isinstance(model, DataParallel):
    model = model.module


Creating optimizer with initial lr: 0.01, momentum: 0.9, weight_decay: 0.0001
Created loss calc <neuralmagicML.utils.loss_calc.CrossEntropyLossCalc object at 0x7f6f127c4a58> with extras top1acc, top5acc
Creating summary writer in ./logs


## Scheduling the pruning process:
Pruning involves two intertwined processes:
1. sparsification, i.e. selecting which weights to zero out;
2. retraining the model after sparsification.

In practice, increasing the sparsity level gradually lets retraining recover the lost accuracy along the way, up to high levels of sparsity.
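
The modifier arguments later in this notebook include `'inter_func': 'linear'` together with `init_sparsity`, `final_sparsity`, `start_epoch` and `end_epoch`. Our reading of those arguments, stated as an assumption rather than a description of the `GradualKSModifier` internals, is a linear ramp like the following (the function name is hypothetical):

```python
def linear_sparsity(epoch: float, start_epoch: float, end_epoch: float,
                    init_sparsity: float, final_sparsity: float) -> float:
    """Target sparsity at `epoch` under a linear ramp from init to final sparsity."""
    if epoch <= start_epoch:
        return init_sparsity
    if epoch >= end_epoch:
        return final_sparsity
    progress = (epoch - start_epoch) / (end_epoch - start_epoch)
    return init_sparsity + progress * (final_sparsity - init_sparsity)

for e in (0, 1, 5, 25, 30):
    print(e, round(linear_sparsity(e, 0.0, 25.0, 0.05, 0.5), 3))
```

With the values used in this run (0.05 to 0.5 over epochs 0 to 25) this gives 0.068 at epoch 1, which is consistent with the 5% and then 7% per-layer sparsity printed in the training log. The library may interpolate at a finer granularity than whole epochs.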

To simplify the control of these two processes, we introduce 'Modifier' classes that manage the schedules of the associated hyperparameters (i.e. the learning rate and the per-layer sparsity) throughout the epochs. For added convenience, below is a simple GUI for setting these hyperparameters.
The GUI lets you control the target sparsity per layer, globally, or in a mixed fashion.
Try setting all layers to a sparsity level of 80% (a slider value of 0.8).
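
The GUI's three modes boil down to a simple rule: start from a global target, override it for specific layers, and skip disabled layers. A plain-Python sketch of that rule, useful when running outside Jupyter (the `resolve_targets` helper is hypothetical; the layer names are taken from this notebook's ResNet-50):

```python
from typing import Dict, Iterable, Optional, Set

def resolve_targets(layer_names: Iterable[str], global_sparsity: float,
                    overrides: Optional[Dict[str, float]] = None,
                    disabled: Optional[Set[str]] = None) -> Dict[str, float]:
    """Per-layer final sparsity: the global value, overridden per layer, minus disabled layers."""
    overrides = overrides or {}
    disabled = disabled or set()
    return {name: overrides.get(name, global_sparsity)
            for name in layer_names if name not in disabled}

names = ['input.conv', 'sections.0.0.conv1', 'sections.0.0.conv2']
targets = resolve_targets(names, 0.8, overrides={'sections.0.0.conv2': 0.6},
                          disabled={'input.conv'})
print(targets)  # {'sections.0.0.conv1': 0.8, 'sections.0.0.conv2': 0.6}
```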

In [5]:
import ipywidgets as widgets
from IPython.display import display

############################################################
## configuration of sparsity levels / enables per layer ####
############################################################
c0 = widgets.VBox([widgets.Checkbox(description=ks_layer.name, value=True) for ks_layer in analyzed_layers])
c1 = widgets.VBox([widgets.FloatSlider(value=0.5,min=0.05,max=0.99) for _ in range(len(analyzed_layers))])
layer_ctrl = widgets.HBox([c0,c1])
# slider values are fractions (0.8 = 80% sparsity)
global_ctrl = widgets.HBox([widgets.Checkbox(description='enable/disable all', value=True),
                            widgets.FloatSlider(value=0.5, min=0.05, max=0.99, description='sparsity')
                           ])
output2 = widgets.Output()

activated_layers = [child.value for child in layer_ctrl.children[0].children]

def global_enable_change(change):
    with output2:
        state = change['new']
        print(state)
        if state is not None:
            for ckbx_child in layer_ctrl.children[0].children:
                ckbx_child.value = state
                
global_ctrl.children[0].observe(global_enable_change, names='value')   

def global_sparsity_set(change):
    with output2:
        val = change['new']
        print(val)
        if val is not None:
            for ckbx_child, sldr_child in zip(layer_ctrl.children[0].children, layer_ctrl.children[1].children):
                if ckbx_child.value:
                    sldr_child.value = val
    
global_ctrl.children[1].observe(global_sparsity_set, names='value')   

###############################################
## configuration of learning rate schedule ####
###############################################

lr_class_dict = {   #TODO: read from CONSTRUCTORS in modifier_lr.py instead
                    #to include all supported methods in the GUI
    'ExponentialLR': {'gamma': [0.95, widgets.BoundedFloatText]}, #bound by 0.0
    'StepLR': {'step_size': [20, widgets.BoundedIntText], #bound by 1
              'gamma': [0.2, widgets.BoundedFloatText]}
}

lr_mod_args_field_initval = {
    'start_epoch': 25.0,
    'end_epoch': 35.0,
    'update_frequency': 1.0,
    'init_lr': 0.001,
}

style = {'description_width': 'initial'}
lr_section_title = widgets.Text(value='learning rate schedule', disabled=True)
lr_cfg_list = []
for fld, val in lr_mod_args_field_initval.items():
    lr_cfg_list.append(widgets.BoundedFloatText(value=val, description=fld, disabled=False, min=0, style=style,))

lr_slct = widgets.Dropdown(
    options=[key for key in lr_class_dict.keys()],  
    value=[key for key in lr_class_dict.keys()][0],
    description='lr_class',
)

def create_lr_slct_list():
    # create fresh widgets for the kwargs of the currently selected lr_class
    lr_slct_params = []
    for param, val in lr_class_dict[lr_slct.value].items():
        lr_slct_params.append(val[1](value=val[0], description=param))
    return lr_slct_params

slct_param = widgets.VBox(children=create_lr_slct_list())

def refresh_lr_param(change):
    # rebuild the kwargs widgets whenever a different lr_class is selected
    if change['new']:
        slct_param.children = create_lr_slct_list()

lr_slct.observe(refresh_lr_param, names='value')
lr_cfg = widgets.VBox([lr_section_title, *lr_cfg_list, lr_slct, slct_param])

##########################################
## configuration of prunning schedule ####
##########################################

prunning_mod_args_field_initval = {
    'start_epoch': 0.0,
    'end_epoch': 25.0,
    'update_frequency': 1.0,
}

prn_section_title = widgets.Text(value='pruning schedule', disabled=True)
prn_cfg_list =[]
for fld, val in prunning_mod_args_field_initval.items():
    prn_cfg_list.append(widgets.BoundedFloatText(value=val, description=fld, disabled=False, min=0, style=style,))


prn_cfg = widgets.VBox([prn_section_title, *prn_cfg_list])
schd_cfg = widgets.VBox([lr_cfg, prn_cfg])
display(widgets.HBox([widgets.VBox([global_ctrl, layer_ctrl]), schd_cfg]))


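
The `lr_class` names offered in the GUI (`ExponentialLR`, `StepLR`) match PyTorch's `torch.optim.lr_scheduler` classes, and their kwargs (`gamma`, `step_size`) match those classes' arguments. Assuming `LearningRateModifier` delegates to them, which the matching names suggest but the notebook does not confirm, the default `ExponentialLR` setting would decay the learning rate like this between `start_epoch` and `end_epoch`:

```python
import torch
from torch import nn, optim

# A throwaway parameter so we can build an optimizer; values mirror the GUI defaults
param = nn.Parameter(torch.zeros(1))
opt = optim.SGD([param], lr=0.001)
sched = optim.lr_scheduler.ExponentialLR(opt, gamma=0.95)

lrs = []
for epoch_idx in range(25, 35):  # the epochs covered by the LR schedule in the GUI
    lrs.append(opt.param_groups[0]['lr'])
    opt.step()    # one epoch of training would happen here
    sched.step()  # decay once per epoch: lr <- lr * gamma
print([round(lr, 6) for lr in lrs])
```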

In [6]:
print('Creating learning rate schedule...')
lr_mod_args = {}

for child in lr_cfg_list: 
    lr_mod_args[child.description] = child.value
assert(lr_slct.description == 'lr_class')
lr_mod_args['lr_class'] = lr_slct.value
lr_mod_args['lr_kwargs'] = {}
for child in slct_param.children:
    lr_mod_args['lr_kwargs'][child.description] = child.value

lr_mod = LearningRateModifier(**lr_mod_args)

print('Creating sparsification schedule...')

def create_ks_mod_args(layer_name, final_sparsity):
    ks_mod_args ={
        'param': 'weight',
        'init_sparsity': 0.05,
        'inter_func': 'linear',
        'layers': [layer_name],
        'final_sparsity': final_sparsity
    }
    # add common fields
    for child in prn_cfg_list:
        ks_mod_args[child.description] = child.value
    return ks_mod_args

ks_mod_args_list = []
for ckbx_child, sldr_child in zip(layer_ctrl.children[0].children, layer_ctrl.children[1].children):
    if ckbx_child.value:  # layer is enabled for sparsification
        layer_name = ckbx_child.description
        final_sparsity = sldr_child.value
        ks_mod_args_list.append(create_ks_mod_args(layer_name, final_sparsity))
            
ks_mod_list = [GradualKSModifier(**ks_mod_args) for ks_mod_args in ks_mod_args_list]
modifiers = [lr_mod, *ks_mod_list]

modifier_manager = ScheduledModifierManager(modifiers)
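# steps_per_epoch: optimizer steps (batches) per epoch, i.e. len(train_data_loader), not the number of samples
optimizer = ScheduledOptimizer(optimizer, model, modifier_manager, steps_per_epoch=len(train_data_loader))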

Creating learning rate schedule...
Creating sparsification schedule...
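
Why wrap the optimizer at all? Once a weight has been zeroed, an ordinary SGD step (gradient plus momentum) will generally make it non-zero again, so pruned weights have to be held at zero after every step. `ScheduledOptimizer` presumably handles this alongside the schedules; the `MaskedOptimizer` class below is only a hypothetical illustration of the idea, not the library's implementation:

```python
from typing import Dict

import torch
from torch import nn, optim

class MaskedOptimizer:
    """Hypothetical wrapper: after every step, re-zero the pruned weights."""

    def __init__(self, base: optim.Optimizer, masks: Dict[nn.Parameter, torch.Tensor]):
        self.base = base
        self.masks = masks  # parameter -> 0/1 mask of the same shape

    def zero_grad(self) -> None:
        self.base.zero_grad()

    def step(self) -> None:
        self.base.step()
        with torch.no_grad():
            for p, m in self.masks.items():
                p.mul_(m)  # gradients and momentum cannot revive pruned weights

# Toy check: prune half of a linear layer's weights, then train for a few steps
layer = nn.Linear(4, 2)
mask = (torch.arange(8).reshape(2, 4) % 2 == 0).float()
with torch.no_grad():
    layer.weight.mul_(mask)
masked_opt = MaskedOptimizer(optim.SGD(layer.parameters(), lr=0.1, momentum=0.9),
                             {layer.weight: mask})
for _ in range(5):
    masked_opt.zero_grad()
    layer(torch.randn(8, 4)).pow(2).mean().backward()
    masked_opt.step()
print(layer.weight)
```

Without the `p.mul_(m)` line, the momentum buffer alone would push the pruned entries away from zero within a step or two.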


In [7]:
ks_mod_args_list

[{'param': 'weight',
  'init_sparsity': 0.05,
  'inter_func': 'linear',
  'layers': ['input.conv'],
  'final_sparsity': 0.5,
  'start_epoch': 0.0,
  'end_epoch': 25.0,
  'update_frequency': 1.0},
 {'param': 'weight',
  'init_sparsity': 0.05,
  'inter_func': 'linear',
  'layers': ['sections.0.0.conv1'],
  'final_sparsity': 0.5,
  'start_epoch': 0.0,
  'end_epoch': 25.0,
  'update_frequency': 1.0},
 {'param': 'weight',
  'init_sparsity': 0.05,
  'inter_func': 'linear',
  'layers': ['sections.0.0.conv2'],
  'final_sparsity': 0.5,
  'start_epoch': 0.0,
  'end_epoch': 25.0,
  'update_frequency': 1.0},
 {'param': 'weight',
  'init_sparsity': 0.05,
  'inter_func': 'linear',
  'layers': ['sections.0.0.conv3'],
  'final_sparsity': 0.5,
  'start_epoch': 0.0,
  'end_epoch': 25.0,
  'update_frequency': 1.0},
 {'param': 'weight',
  'init_sparsity': 0.05,
  'inter_func': 'linear',
  'layers': ['sections.0.0.identity.conv'],
  'final_sparsity': 0.5,
  'start_epoch': 0.0,
  'end_epoch': 25.0,
  'updat

## Setting up training
The following should look very familiar: it is the same code as in our previous tutorial (notebook #1).

In [8]:
from tqdm import tqdm
from torch.utils.data import DataLoader

from typing import Tuple, Dict
from torch import Tensor
import torch
from torch.nn import Module


def _test_datasets(model, train_data_loader: DataLoader, val_data_loader: DataLoader,
                   writer: SummaryWriter, epoch: int) -> Tuple[Dict[str, float], Dict[str, float]]:
    val_losses, train_losses = None, None
    if val_data_loader is not None:
        print('Running test for validation dataset for epoch {}'.format(epoch))
        val = test_epoch(model, val_data_loader, loss_calc, device, epoch)
        print('Completed test for validation dataset for epoch {}'.format(epoch))
        val_losses = {}
        for loss, _ in val.items():
            val_losses[loss] = torch.mean(torch.cat(val[loss])).item()
            val_tag = 'Test/validation/{}'.format(loss)
            writer.add_scalar(val_tag, val_losses[loss], epoch)
        
        val_loss_str = 'validation set - epoch: {} '.format(epoch)
        for loss, value in val_losses.items():
            val_loss_str += (loss + ': {0:.2f} '.format(value))
        print(val_loss_str)
        
        
    if train_data_loader is not None:
        print('Running test for train dataset for epoch {}'.format(epoch))
        train = test_epoch(model, train_data_loader, loss_calc, device, epoch)
        print('Completed test for train dataset for epoch {}'.format(epoch))
        train_losses = {}

        for loss, _ in train.items():
            train_losses[loss] = torch.mean(torch.cat(train[loss])).item()
            train_tag = 'Test/training/{}'.format(loss)
            writer.add_scalar(train_tag, train_losses[loss], epoch)

        
        train_loss_str = 'training set - epoch: {} '.format(epoch)
        for loss, value in train_losses.items():
            train_loss_str += (loss + ': {0:.2f} '.format(value))
        print(train_loss_str)


    return val_losses, train_losses


def test_epoch(model: Module, data_loader: DataLoader, loss, device, epoch: int) -> Dict:
    model.eval()
    results = {}  # loss name -> list of tensors, one per batch
    with torch.no_grad():
        for batch, (*x_feature, y_lab) in enumerate(tqdm(data_loader)):
            y_lab = y_lab.to(device)
            x_feature = tuple([dat.to(device) for dat in x_feature])
            batch_size = y_lab.shape[0]

            y_pred = model(*x_feature)

            losses = loss(x_feature, y_lab, y_pred)  # type: Dict[str, Tensor]
            for key, val in losses.items():
                if key not in results:
                    results[key] = []

                result = val.detach_().cpu()
                # assuming each value is a batch mean, repeat it batch_size times so the final
                # mean is taken over samples rather than batches (the last batch may be smaller)
                result = result.repeat(batch_size)
                results[key].append(result)
    return results

def train_epoch(model: Module, data_loader: DataLoader, optimizer, loss, device, data_counter: int):
    model.train()
    
    for batch, (*x_feature, y_lab) in enumerate(tqdm(data_loader)):
        # copy next batch to the device we are using
        y_lab = y_lab.to(device)
        x_feature = tuple([dat.to(device) for dat in x_feature])
        batch_size = y_lab.shape[0]

        # Zero the parameter gradients
        optimizer.zero_grad()

        # forward 
        y_pred = model(*x_feature)
        
        # update losses
        losses = loss(x_feature, y_lab, y_pred)  # type: Dict[str, Tensor]
        
        # backward
        losses['loss'].backward()
        
        # take SGD step
        optimizer.step(closure=None)
        
        # log loss and accuracy
        data_counter += batch_size
        for _loss, _value in losses.items():
            writer.add_scalar('Train/{}'.format(_loss), _value.item(), data_counter)



## Pruning main loop:
This too is very similar to the training main loop introduced previously. The main differences are:
1. we track sparsity;
2. we follow the schedules orchestrated by the modifiers above.

In [None]:
import math
print('Running baseline test...')
epoch = -1
_record_kernel_sparsity(analyzed_layers, writer, epoch)
_test_datasets(model, None, val_data_loader, writer, epoch=-1)

print('Training model')
num_epochs = int(math.ceil(modifier_manager.max_epochs))
data_counter = 0

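for epoch in range(num_epochs):
    print('Starting epoch {}'.format(epoch))
    optimizer.epoch_start()
    _record_kernel_sparsity(analyzed_layers, writer, epoch)

    train_epoch(model, train_data_loader, optimizer, loss_calc, device, data_counter)
    # train_epoch does not return its sample counter, so advance the global step here;
    # otherwise the per-batch TensorBoard curves restart at 0 every epoch
    data_counter += len(train_dataset)
    optimizer.epoch_end()
    val_losses, train_losses = _test_datasets(model, None, val_data_loader, writer, epoch)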

    if save_rate > 0 and epoch % save_rate == 0:
        save_path = os.path.join(model_dir, 'resnet50-epoch={:03d}-val={:.4f}.pth'
                                 .format(epoch, val_losses['loss']))
        save_model(save_path, model, optimizer, epoch)
        print('saved model checkpoint at {}'.format(save_path))

_record_kernel_sparsity(analyzed_layers, writer, num_epochs)


scalars_json_path = os.path.join(logs_dir, 'all_scalars.json')
writer.export_scalars_to_json(scalars_json_path)
writer.close()

save_path = os.path.join(model_dir, 'resnet50-pruned.pth')
print('Finished training, saving model to {}'.format(save_path))
save_model(save_path, model)
print('Saved model')

Running baseline test...



sparsity per layer [%]: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
Running test for validation dataset for epoch -1


100%|██████████| 2/2 [00:01<00:00,  1.06s/it]


Completed test for validation dataset for epoch -1
validation set - epoch: -1 loss: 0.40 top1acc: 87.20 top5acc: 98.60 
Training model
Starting epoch 0



sparsity per layer [%]: [5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0]


100%|██████████| 101/101 [02:09<00:00,  1.63it/s]

Running test for validation dataset for epoch 0


100%|██████████| 2/2 [00:01<00:00,  1.20s/it]


Completed test for validation dataset for epoch 0
validation set - epoch: 0 loss: 0.37 top1acc: 88.80 top5acc: 98.80 
saved model checkpoint at ../pruned/resnet50-epoch=000-val=0.3695.pth
Starting epoch 1


  0%|          | 0/101 [00:00<?, ?it/s]

sparsity per layer [%]: [7.0, 7.0, 7.0, 7.0, 7.0, 7.0, 7.0, 7.0, 7.0, 7.0, 7.0, 7.0, 7.0, 7.0, 7.0, 7.0, 7.0, 7.0, 7.0, 7.0, 7.0, 7.0, 7.0, 7.0, 7.0, 7.0, 7.0, 7.0, 7.0, 7.0, 7.0, 7.0, 7.0, 7.0, 7.0, 7.0, 7.0, 7.0, 7.0, 7.0, 7.0, 7.0, 7.0, 7.0, 7.0, 7.0, 7.0, 7.0, 7.0, 7.0, 7.0, 7.0, 7.0]


 12%|█▏        | 12/101 [00:08<01:01,  1.44it/s]

In [None]:
modifier_manager.max_epochs