# Notebook #2 - Model pruning
This tutorial walks through pruning a model, assuming we already have a trained (dense) model:

1. prepare and load data
2. load model
3. sparsity: why/how to prune a model

## Load data
This example uses the same dataset as notebook #1: the imagenette/imagewoof datasets from fast.ai, provided under the Apache License 2.0.

In [1]:
from neuralmagicML.datasets import *
from torch.utils.data import Dataset, DataLoader

dataset_type = 'imagenette'
dataset_path = '../data/imagenette-160/'
device = 'cuda:2'  # device to run on, e.g. 'cpu' or 'cuda:0'
    
train_batch_size = 128
test_batch_size = 256
dataset_early_stop = -1

train_dataset = ImagenetteDataset(dataset_path, train=True, rand_trans=True)
train_data_loader = DataLoader(train_dataset, batch_size=train_batch_size, shuffle=True, num_workers=8)

val_dataset = ImagenetteDataset(dataset_path, train=False, rand_trans=False)
val_data_loader = DataLoader(val_dataset, batch_size=test_batch_size, shuffle=True, num_workers=8)

train_test_dataset = ImagenetteDataset(dataset_path, train=True, rand_trans=False)
train_test_data_loader = DataLoader(train_test_dataset, batch_size=test_batch_size, shuffle=True, num_workers=8)


already downloaded imagenette of size ImagenetteSize.s160
already downloaded imagenette of size ImagenetteSize.s160
already downloaded imagenette of size ImagenetteSize.s160


## Load model
Let us load the ResNet-50 model we trained in the previous notebook.

In [None]:
from neuralmagicML.models import *
from neuralmagicML.sparsity import *
from torch.nn import Conv2d

model_path = '../checkpoints/resnet50-epoch=030-val=0.3995.pth'
print('initializing model...')
model = resnet50(num_classes = 10)
model = model.to(device)

print('loading model...')
load_model(model_path, model)



initializing model...


## Sparsity, pruning and high-level motivation
### Sparsity:

Informally, sparsity is the degree to which a tensor is composed of zeros.

Slightly more formally:

Let $N^i$ be the total number of elements in a tensor $W_i$ (e.g. a weight tensor).

Let $N^i_z$ be the number of zero-valued elements in that tensor.

The sparsity level associated with that tensor is defined as $s_i \triangleq \dfrac{N^i_z}{N^i}$ 
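
As a minimal sketch of the definition above, the following computes $s_i$ for a small, made-up weight matrix. It uses plain Python lists so that it runs without any dependencies; the function name `sparsity` and the example values are illustrative and not part of neuralmagicML.

```python
def sparsity(tensor):
    """Fraction of zero-valued elements in a (possibly nested) list of numbers."""
    flat = []
    stack = [tensor]
    while stack:
        item = stack.pop()
        if isinstance(item, (list, tuple)):
            stack.extend(item)
        else:
            flat.append(item)
    if not flat:
        raise ValueError("tensor has no elements")
    num_zero = sum(1 for value in flat if value == 0)  # N_z
    return num_zero / len(flat)                        # N_z / N


# A 2x4 weight matrix in which 6 of the 8 entries are zero.
weights = [[0.0, 0.3, 0.0, 0.0],
           [0.0, 0.0, -1.2, 0.0]]
layer_sparsity = sparsity(weights)
print(layer_sparsity)  # 6 / 8 = 0.75

# For a torch tensor w, the same quantity is (w == 0).sum().item() / w.numel().
```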


### Pruning:

Pruning is the process of selectively setting weights in a model to zero. The choice of how many and which weights to zero affects the model's accuracy, its required FLOPs, and its memory footprint. **Critically, attaining high levels of sparsity while preserving accuracy is possible, as we will demonstrate in this notebook.**
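
To make the idea of "selecting which weights to zero" concrete, here is a small sketch of one common selection rule, magnitude pruning, which zeroes the weights with the smallest absolute value. This is an assumption made for illustration: the notebook does not state which criterion the library uses, and `magnitude_prune` is a hypothetical helper, not a neuralmagicML function.

```python
def magnitude_prune(weights, target_sparsity):
    """Return a copy of `weights` with the smallest-magnitude entries set to zero."""
    if not 0.0 <= target_sparsity <= 1.0:
        raise ValueError("target_sparsity must be in [0, 1]")
    num_to_zero = int(round(target_sparsity * len(weights)))
    # indices ordered from smallest to largest absolute value
    order = sorted(range(len(weights)), key=lambda i: abs(weights[i]))
    pruned = list(weights)
    for i in order[:num_to_zero]:
        pruned[i] = 0.0
    return pruned


weights = [0.9, -0.05, 0.4, 0.01, -0.7, 0.2, -0.03, 0.6, 0.1, -0.3]
pruned = magnitude_prune(weights, target_sparsity=0.5)
print(pruned)  # the five smallest-magnitude weights are now zero
```

The large weights survive untouched, which is the intuition behind why accuracy can often be recovered with some retraining.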

### Sparsity --> fewer FLOPs -?-> accelerated performance:
The fact that models can be heavily sparsified with little or no accuracy loss is well known in the research community. Intuitively, the higher the sparsity level, the fewer theoretical FLOPs are required, and hence the larger the expected speedup. However, while the first part of that intuitive reasoning is true (higher sparsity --> fewer theoretical FLOPs), the second is not necessarily true (fewer theoretical FLOPs -/-> higher performance). The reason is that typical hardware such as GPUs is ill-equipped to take advantage of that sparsity, so FLOP savings are hard to realize in practice. On CPUs, in contrast, algorithms that exploit sparsity can be developed flexibly.
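
The first half of that reasoning can be illustrated with simple arithmetic. The sketch below counts multiply-accumulates (MACs) for a single dense convolution and then scales the count by the fraction of non-zero weights. The layer shape is hypothetical, and the sparse figure is an idealized lower bound that assumes every zero weight can be skipped at no cost, which is exactly what real hardware often fails to deliver.

```python
def conv2d_macs(in_channels, out_channels, kernel_size, out_height, out_width):
    """Multiply-accumulates for one dense 2D convolution (bias ignored)."""
    return in_channels * out_channels * kernel_size * kernel_size * out_height * out_width


def theoretical_sparse_macs(dense_macs, weight_sparsity):
    """MACs left if every zero weight could be skipped for free (an idealization)."""
    return dense_macs * (1.0 - weight_sparsity)


# Hypothetical 3x3 convolution, 64 -> 64 channels, 56x56 output feature map.
dense = conv2d_macs(in_channels=64, out_channels=64, kernel_size=3,
                    out_height=56, out_width=56)
sparse = theoretical_sparse_macs(dense, weight_sparsity=0.8)
print('dense MACs: {}, theoretical MACs at 80% sparsity: {:.0f}'.format(dense, sparse))
```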

Armed with this insight we are ready (and motivated!) to start looking at model sparsity with the aim of increasing it (via pruning).

In [None]:
from typing import List
from tensorboardX import SummaryWriter


print('Setting up model for kernel sparsity tracking...')
conv_layers_names = []
for name, mod in model.named_modules():
    if isinstance(mod, Conv2d): #to add the FC layers: isinstance(mod, Conv2d) or isinstance(mod, Linear) 
        conv_layers_names.append(name)
analyzed_layers = KSAnalyzerLayer.analyze_layers(model, conv_layers_names)

def _record_kernel_sparsity(analyzed_layers: List[KSAnalyzerLayer], writer: SummaryWriter, epoch: int):
    # log each layer's parameter sparsity to TensorBoard, then print all of them as percentages
    for ks_layer in analyzed_layers:
        tag = 'Kernel Sparsity/{}'.format(ks_layer.name)
        writer.add_scalar(tag, ks_layer.param_sparsity.item(), epoch)
    print('sparsity per layer [%]: ' + str([int(ks_layer.param_sparsity.item() * 100.0) for ks_layer in analyzed_layers]))


## Optimizer, loss, logging etc.

In [None]:
import torch
from torch import optim
from torch.nn import DataParallel
from neuralmagicML.utils import CrossEntropyLossCalc, TopKAccuracy
import os

init_lr = 0.01
momentum = 0.9
weight_decay = 1e-4

print('Creating optimizer with initial lr: {}, momentum: {}, weight_decay: {}'
          .format(init_lr, momentum, weight_decay))
optimizer = optim.SGD(
    model.parameters(), init_lr, momentum=momentum, weight_decay=weight_decay, nesterov=True
)
loss_extras = {
    'top1acc': TopKAccuracy(1),
    'top5acc': TopKAccuracy(5)
}
loss_calc = CrossEntropyLossCalc(loss_extras)
print('Created loss calc {} with extras {}'.format(loss_calc, ', '.join(loss_extras.keys())))


logs_dir = './logs'
model_dir = '../pruned'
if not os.path.exists(model_dir):
    os.makedirs(model_dir)

if not os.path.exists(logs_dir):
    os.makedirs(logs_dir)
    
save_rate = 5
print('Creating summary writer in {}'.format(logs_dir))
writer = SummaryWriter(logdir=logs_dir, comment='imagenet training')
if isinstance(model, DataParallel):
    model = model.module


## Scheduling the pruning process:
Pruning involves two intertwined processes:
1. sparsification - i.e. the selection of weights to zero out.
2. re-training the model post sparsification.

In practice, increasing the sparsity level gradually allows retraining to recover accuracy (up to high levels of sparsity).

To simplify the control of these two processes, we introduce 'Modifier' classes which manage the schedules of the associated hyperparameters (i.e. learning rate, sparsity per layer) throughout the epochs.
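
To give a feel for what a gradual sparsity schedule looks like, here is a sketch of a cubic ramp from an initial to a final sparsity. The formula is an assumption (a commonly used cubic form), not a confirmed description of how the library interpolates, and `cubic_sparsity` is a hypothetical helper. The default values mirror the ones configured for this notebook's sparsification schedule (0.05 to 0.8 over epochs 0 to 35).

```python
def cubic_sparsity(epoch, start_epoch=0.0, end_epoch=35.0,
                   init_sparsity=0.05, final_sparsity=0.8):
    """Target sparsity at `epoch` under an assumed cubic ramp."""
    if epoch <= start_epoch:
        return init_sparsity
    if epoch >= end_epoch:
        return final_sparsity
    progress = (epoch - start_epoch) / (end_epoch - start_epoch)
    # rises quickly at first, then flattens as it approaches final_sparsity
    return final_sparsity + (init_sparsity - final_sparsity) * (1.0 - progress) ** 3


schedule = [cubic_sparsity(e) for e in range(0, 36, 5)]
for epoch, target in zip(range(0, 36, 5), schedule):
    print('epoch {:2d}: target sparsity {:.3f}'.format(epoch, target))
```

Under this form most of the pruning happens early, when the model still has plenty of redundant weights, and the final epochs of the ramp change sparsity only slightly.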

In [None]:
print('Creating learning rate schedule...')
lr_mod_args = {
    'start_epoch': 40.0,
    'end_epoch': 100.0,
    'update_frequency': 1.0,
    'lr_class': 'ExponentialLR',
    'lr_kwargs':
      {'gamma': 0.95},
    'init_lr': 0.001
}
lr_mod = LearningRateModifier(**lr_mod_args)

print('Creating sparsification schedule...')
ks_mod_args ={
    'start_epoch': 0.0,
    'end_epoch': 35.0,
    'update_frequency': 1.0,
    'param': 'weight',
    'init_sparsity': 0.05,
    'final_sparsity': 0.8,
    'inter_func': 'cubic',
    'layers': conv_layers_names
}
ks_mod = GradualKSModifier(**ks_mod_args)
modifiers = [lr_mod, ks_mod]

modifier_manager = ScheduledModifierManager(modifiers)
optimizer = ScheduledOptimizer(optimizer, model, modifier_manager, steps_per_epoch=len(train_dataset))
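
The learning rate settings above can be read as an exponential decay that starts at epoch 40. The sketch below shows the arithmetic under the assumption that the rate is set to `init_lr` at `start_epoch` and multiplied by `gamma` once per epoch after that; how the library treats epochs outside the window is not stated here, so the hypothetical helper `exponential_lr` returns `None` for them.

```python
def exponential_lr(epoch, start_epoch=40, end_epoch=100, init_lr=0.001, gamma=0.95):
    """Assumed learning rate at `epoch`: init_lr * gamma ** (epochs since start)."""
    if not start_epoch <= epoch <= end_epoch:
        return None  # outside the scheduled window; this sketch makes no claim
    steps = int(epoch - start_epoch)
    return init_lr * gamma ** steps


for epoch in (40, 41, 50, 100):
    print('epoch {:3d}: lr {:.6f}'.format(epoch, exponential_lr(epoch)))
```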


## Setting up training
The following should look very familiar - it is in fact the exact same code from our previous tutorial (NB1). 

In [None]:
from tqdm import tqdm
from torch.utils.data import DataLoader

from typing import Tuple, Dict
from torch import Tensor
import torch
from torch.nn import Module


def _test_datasets(model, train_data_loader: DataLoader, val_data_loader: DataLoader,
                   writer: SummaryWriter, epoch: int) -> Tuple[Dict[str, float], Dict[str, float]]:
    print('Running test for validation dataset for epoch {}'.format(epoch))
    val = test_epoch(model, val_data_loader, loss_calc, device, epoch)
    print('Completed test for validation dataset for epoch {}'.format(epoch))

    print('Running test for train dataset for epoch {}'.format(epoch))
    train = test_epoch(model, train_data_loader, loss_calc, device, epoch)
    print('Completed test for train dataset for epoch {}'.format(epoch))

    val_losses = {}
    train_losses = {}

    for loss, _ in val.items():
        val_losses[loss] = torch.mean(torch.cat(val[loss])).item()
        val_tag = 'Test/validation/{}'.format(loss)
        writer.add_scalar(val_tag, val_losses[loss], epoch)
        train_losses[loss] = torch.mean(torch.cat(train[loss])).item()
        train_tag = 'Test/training/{}'.format(loss)
        writer.add_scalar(train_tag, train_losses[loss], epoch)
    val_loss_str = 'validation set - epoch: {} '.format(epoch)
    for loss, value in val_losses.items():
        val_loss_str += (loss + ': {0:.2f} '.format(value))
    print(val_loss_str)
    train_loss_str = 'training set - epoch: {} '.format(epoch)
    for loss, value in train_losses.items():
        train_loss_str += (loss + ': {0:.2f} '.format(value))
    print(train_loss_str)


    return val_losses, train_losses


def test_epoch(model: Module, data_loader: DataLoader, loss, device, epoch: int) -> Dict:
    model.eval()
    results = {}#ModuleTestResults()
    with torch.no_grad():
        for batch, (*x_feature, y_lab) in enumerate(tqdm(data_loader)):
            y_lab = y_lab.to(device)
            x_feature = tuple([dat.to(device) for dat in x_feature])
            batch_size = y_lab.shape[0]
            
            y_pred = model(*x_feature)

            losses = loss(x_feature, y_lab, y_pred)  # type: Dict[str, Tensor]
            for key, val in losses.items():
                if key not in results:
                    results[key] = []

                result = val.detach_().cpu()
                result = result.repeat(batch_size) #repeat tensor so that there is no dependency on batch size
                results[key].append(result)
#             results.append(losses, batch_size)
    return results

def train_epoch(model: Module, data_loader: DataLoader, optimizer, loss, device, data_counter: int):
    model.train()
    
    for batch, (*x_feature, y_lab) in enumerate(tqdm(data_loader)):
        # copy next batch to the device we are using
        y_lab = y_lab.to(device)
        x_feature = tuple([dat.to(device) for dat in x_feature])
        batch_size = y_lab.shape[0]

        # Zero the parameter gradients
        optimizer.zero_grad()

        # forward 
        y_pred = model(*x_feature)
        
        # update losses
        losses = loss(x_feature, y_lab, y_pred)  # type: Dict[str, Tensor]
        
        # backward
        losses['loss'].backward()
        
        # take SGD step
        optimizer.step(closure=None)
        
        # log loss and accuracy
        data_counter += batch_size
        for _loss, _value in losses.items():
            writer.add_scalar('Train/{}'.format(_loss), _value.item(), data_counter)
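
The `repeat(batch_size)` step in `test_epoch` deserves a short illustration. Each batch yields one mean value, and the last batch is usually smaller than the rest, so averaging the per-batch values directly would over-weight it. Repeating each value once per sample before averaging gives the per-sample mean instead. The sketch below uses made-up numbers and a hypothetical helper `per_sample_mean` to show the difference.

```python
def per_sample_mean(batch_means, batch_sizes):
    """Average per-batch means after repeating each one once per sample."""
    repeated = []
    for mean, size in zip(batch_means, batch_sizes):
        repeated.extend([mean] * size)
    return sum(repeated) / len(repeated)


batch_means = [0.5, 0.25, 1.0]  # mean loss of each batch
batch_sizes = [4, 4, 2]         # the last batch is smaller

naive = sum(batch_means) / len(batch_means)
weighted = per_sample_mean(batch_means, batch_sizes)
print('naive mean of batch means: {:.4f}'.format(naive))
print('per-sample mean: {:.4f}'.format(weighted))
```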



## Pruning main loop:
This too is very similar to the training main loop previously introduced; the main differences are:
1. we are tracking sparsity
2. we are following the schedules as orchestrated by the modifiers above

In [None]:
import math
print('Running baseline test...')
epoch = -1
_record_kernel_sparsity(analyzed_layers, writer, epoch)
_test_datasets(model, train_test_data_loader, val_data_loader, writer, epoch=-1)

print('Training model')
num_epochs = int(math.ceil(modifier_manager.max_epochs))
data_counter = 0

for epoch in range(num_epochs):
    print('Starting epoch {}'.format(epoch))
    optimizer.epoch_start()
    _record_kernel_sparsity(analyzed_layers, writer, epoch)



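    train_epoch(model, train_data_loader, optimizer, loss_calc, device, data_counter)
    # train_epoch only advances its local copy of the counter, so advance it here as well;
    # otherwise the 'Train/...' scalars of every epoch would be logged from step 0 again
    data_counter += len(train_dataset)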
    optimizer.epoch_end()
    val_losses, train_losses = _test_datasets(model, train_test_data_loader, val_data_loader, writer, epoch)

    if save_rate > 0 and epoch % save_rate == 0:
        save_path = os.path.join(model_dir, 'resnet50-epoch={:03d}-val={:.4f}.pth'
                                 .format(epoch, val_losses['loss']))
        save_model(save_path, model, optimizer, epoch)
        print('saved model checkpoint at {}'.format(save_path))

_record_kernel_sparsity(analyzed_layers, writer, num_epochs)


scalars_json_path = os.path.join(logs_dir, 'all_scalars.json')
writer.export_scalars_to_json(scalars_json_path)
writer.close()

save_path = os.path.join(model_dir, 'resnet50-pruned.pth')
print('Finished training, saving model to {}'.format(save_path))
save_model(save_path, model)
print('Saved model')

In [None]:
modifier_manager.max_epochs