# PyTorch Model Pruning with Adam Optimizer

Neural networks are generally overparameterized for given tasks (i.e., the number of parameters far exceeds the number of training points), yet they still [generalize very well](https://arxiv.org/abs/1611.03530). This runs contrary to conventional ML wisdom, where overparameterizing a model would traditionally lead to overfitting. The behavior is closely tied to the phenomenon known as [double descent](https://openai.com/blog/deep-double-descent/), which is a very active area of research.

A side effect of this overparameterization is that a large number of weights in deep learning networks can be pruned away (set to 0). This was discovered [early on](http://yann.lecun.com/exdb/publis/pdf/lecun-90b.pdf) by Yann LeCun, but interest died off due to a lack of applications at the time. Song Han's [2015 paper](https://arxiv.org/abs/1510.00149) reinvigorated the area in pursuit of compressing model size for mobile applications. This renewed interest has produced numerous papers on weight pruning, filter pruning, channel pruning, and ultimately block pruning. [This paper](https://arxiv.org/abs/1902.09574) out of Google gives a good overview of the current state of kernel sparsity (model pruning).

When pruning to increase kernel sparsity, we iteratively remove weights based on their absolute magnitude, pruning the smallest weights first. Two properties generally make this possible: the [self-regularizing effect of gradient descent](https://www.nature.com/articles/s41467-020-14663-9) and the L1 or L2 regularization applied to the weights. Weights that do not help the optimization process are quickly driven toward zero. In this way, pruning can be thought of as an [architecture search](https://arxiv.org/abs/1905.09717).
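A single magnitude-pruning step can be sketched in plain PyTorch as below. This is an illustration of the idea, not the neuralmagicML implementation; the helper name `magnitude_prune` is hypothetical. It zeros the given fraction of entries with the smallest absolute value and returns the mask of kept weights.

```python
import torch


def magnitude_prune(weight: torch.Tensor, sparsity: float) -> torch.Tensor:
    """Zero the `sparsity` fraction of entries with the smallest |w|, in place.

    Hypothetical helper for illustration. Returns a 0/1 mask (1 = kept).
    Entries tied with the threshold are also pruned, so ties may remove a few more.
    """
    if not 0.0 <= sparsity < 1.0:
        raise ValueError("sparsity must be in [0, 1)")
    num_prune = int(sparsity * weight.numel())
    if num_prune == 0:
        return torch.ones_like(weight)
    # the k-th smallest absolute value acts as the pruning threshold
    threshold = weight.abs().flatten().kthvalue(num_prune).values
    mask = (weight.abs() > threshold).to(weight.dtype)
    with torch.no_grad():
        weight.mul_(mask)
    return mask


# Example: prune 50% of a small weight matrix
w = torch.tensor([[0.1, -2.0, 0.05, 3.0], [-0.3, 0.2, 1.5, -0.01]])
mask = magnitude_prune(w, 0.5)
print(w)  # only -2.0, 3.0, -0.3 and 1.5 survive
```

Iterative pruning repeats a step like this at increasing sparsity levels, with training in between so the remaining weights can compensate.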

Well, what does pruning get us? We now have a model with a lot of multiplications by zero that we don't need to run. If we're smart about how we structure this compute (a surprisingly tricky problem), we can run the model much faster than before! That's where the [Neural Magic](http://neuralmagic.com/) engine can help us.
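To make the payoff concrete, here is a toy sketch in pure Python (not how Neural Magic or any production engine is implemented) that stores only the nonzero weights of one row and counts the multiplications actually performed in a dot product.

```python
from typing import List, Tuple


def to_sparse(row: List[float]) -> Tuple[List[int], List[float]]:
    """Keep only the nonzero weights as (column indices, values)."""
    indices = [i for i, w in enumerate(row) if w != 0.0]
    return indices, [row[i] for i in indices]


def sparse_dot(indices: List[int], values: List[float], x: List[float]) -> Tuple[float, int]:
    """Dot product that skips pruned (zero) weights; returns (result, multiplies)."""
    total = 0.0
    for i, w in zip(indices, values):
        total += w * x[i]
    return total, len(indices)


weights = [0.0, 2.0, 0.0, 0.0, -1.0, 0.0, 0.0, 0.5]  # 5 of 8 weights pruned
x = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
idx, vals = to_sparse(weights)
result, mults = sparse_dot(idx, vals, x)
print(result, mults)  # 3.0 3 (a dense dot product needs 8 multiplies)
```

Skipping zeros this naively rarely helps on real hardware, because irregular memory access defeats vectorization; arranging sparse compute efficiently is the tricky part.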

This notebook provides a step-by-step walkthrough for pruning an already trained (dense) model to enable better performance at inference time using the Neural Magic Inference Engine. By default, we prune a simple [CNN](https://en.wikipedia.org/wiki/Convolutional_neural_network) on the [MNIST dataset](http://yann.lecun.com/exdb/mnist/) using an [Adam optimizer](https://arxiv.org/abs/1412.6980); however, it is designed to be easily extendable to your model and dataset. Note that the Adam optimizer is easier to use than a [Stochastic Gradient Descent (SGD) optimizer](https://en.wikipedia.org/wiki/Stochastic_gradient_descent); however, SGD is the preferred method for pruning to ensure the resulting model generalizes well. See our other notebooks for pruning with SGD. In this notebook, we walk through the following items:
1. Environment Setup
2. Model and Dataset Setup
3. Loss Sensitivity Analysis
4. Hyperparameter Selection
5. Recalibration Using Pruning
6. Export to [ONNX](https://onnx.ai/)

Reading through this notebook should be reasonably quick and give you an intuition for what is happening. Rough time estimates for fully pruning the default model are given below; note that training with the PyTorch CPU implementation is much slower than on a GPU:
- 5 minutes on a GPU
- 30 minutes on a laptop CPU

## Environment Setup

Below we add the project folder to the Python module search path (`sys.path`) for this session. If this does not work, install neuralmagicML into your environment by running `pip install ./` from the root of the repository.

Additionally, please be sure to install from the requirements.txt file located at the root before running: `pip install -r ./requirements.txt`

In [None]:
import sys
import os

notebook_name = "pruning_adam_pytorch"

# environment setup for ease of use (puts neuralmagicML into the python package path)
if "WORKBOOK_DIR" not in globals():
    WORKBOOK_DIR = os.getcwd()

package_path = os.path.abspath(
    os.path.join(os.path.expanduser(WORKBOOK_DIR), os.pardir)
)
sys.path.extend([package_path])

print("added {} to sys.path".format(package_path))
print("working out of {}".format(WORKBOOK_DIR))
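To confirm the path setup worked before going further, a small standard-library check like the following can help. The helper name `is_importable` is hypothetical; it only asks `importlib` whether a top-level module can be found on the current `sys.path`.

```python
import importlib.util


def is_importable(module_name: str) -> bool:
    """Return True if the top-level module can be found on the current sys.path."""
    return importlib.util.find_spec(module_name) is not None


if not is_importable("neuralmagicML"):
    print("neuralmagicML not found; run `pip install ./` from the repository root")
```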

## Model and Dataset Setup

By default, we create a simple CNN to prune on the MNIST dataset. The CNN is already pretrained, and its weights are downloaded from the Neural Magic Model Repo. The MNIST dataset is also downloaded automatically through PyTorch.

If you would like to prune your own model, replace the appropriate lines for your model and dataset. Specifically:
- `model = mnist_net(pretrained=True)`
- `train_dataset = MNISTDataset(dataset_root, train=True)`
- `val_dataset = MNISTDataset(dataset_root, train=False)`

Take care to keep the variable names the same, as the rest of the notebook is set up according to those.
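As a hypothetical illustration, the replacements might look like the following, using a small stand-in CNN and random tensors wrapped in `torch.utils.data.TensorDataset`. `MyNet` and the random data are placeholders for your own model and data; the parts that matter are keeping the names `model`, `model_name`, `train_dataset`, and `val_dataset`, and having each dataset item be an `(input, label)` pair. Your model's output format must also match what the loss you configure later expects.

```python
import torch
from torch import nn
from torch.utils.data import TensorDataset


class MyNet(nn.Module):
    """Hypothetical stand-in model: a small CNN for 1x28x28 inputs and 10 classes."""

    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 8, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),  # 28x28 -> 14x14
        )
        self.classifier = nn.Linear(8 * 14 * 14, 10)

    def forward(self, x):
        return self.classifier(torch.flatten(self.features(x), 1))


model = MyNet()
model_name = model.__class__.__name__

# random placeholder data; substitute your real training and validation datasets
train_dataset = TensorDataset(torch.randn(256, 1, 28, 28), torch.randint(0, 10, (256,)))
val_dataset = TensorDataset(torch.randn(64, 1, 28, 28), torch.randint(0, 10, (64,)))
```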

In [None]:
from neuralmagicML.pytorch.datasets import MNISTDataset
from neuralmagicML.pytorch.models import mnist_net
from neuralmagicML.utils import clean_path

#######################################################
# Define your model below
#######################################################
print("loading model...")
model = mnist_net(pretrained=True)
model_name = model.__class__.__name__
print(model)

#######################################################
# Define your train and validation datasets below
#######################################################
dataset_root = clean_path(os.path.join(".", notebook_name, "datasets"))

print("\nloading train dataset...")
train_dataset = MNISTDataset(dataset_root, train=True)
print(train_dataset)

print("\nloading val dataset...")
val_dataset = MNISTDataset(dataset_root, train=False)
print(val_dataset)

## Loss Sensitivity Analysis

One of the hyperparameters we need to control is how sparse (percentage of zeros) to make each fully connected or convolutional layer in a network. Not all layers are created equal, so we want to be careful about how we assign sparsity. Generally, the more parameters a layer has relative to its input, the less sensitive (and therefore more prunable) it will be. For example, a 3x3 convolution is much less sensitive than a 1x1 convolution with the same number of channels. Likewise, increasing the stride of a convolution increases its sensitivity.
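One rough way to see this intuition is to count, for each layer, how many weights feed each output value (its fan-in) and at how many spatial positions each weight is applied. The hypothetical helper below is a heuristic illustration, not the metric neuralmagicML uses; it compares a 3x3 and a 1x1 convolution with 64 input and output channels on a 32x32 feature map, plus a 3x3 convolution with stride 2.

```python
def conv_stats(c_in, c_out, k, stride, h, w, padding):
    """Return (params, fan_in per output value, output positions) for a 2D convolution."""
    params = c_out * c_in * k * k
    fan_in = c_in * k * k  # weights that contribute to each output value
    h_out = (h + 2 * padding - k) // stride + 1
    w_out = (w + 2 * padding - k) // stride + 1
    return params, fan_in, h_out * w_out


for name, k, s, p in [("3x3 s1", 3, 1, 1), ("1x1 s1", 1, 1, 0), ("3x3 s2", 3, 2, 1)]:
    params, fan_in, positions = conv_stats(64, 64, k, s, 32, 32, p)
    print(f"{name}: params={params}, fan_in={fan_in}, output positions={positions}")
```

Each output of the 3x3 layer draws on 9x as many weights (576 vs. 64), so removing any single weight perturbs it less. With stride 2, each weight is applied at 4x fewer positions (256 vs. 1024), which is one plausible reading of why strided layers tend to be more sensitive.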

To give better visibility into this, we provide a quick, one-shot approach to approximating sensitivity. The `one_shot_ks_loss_sensitivity()` function goes layer by layer and prunes each layer to different levels of sparsity without retraining. This makes it cheap to run while still giving a reasonable approximation. For display, the piecewise integral (area under the curve) of sparsity vs. loss is calculated for each layer, so higher sensitivity values mean more loss for a given amount of sparsity.
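The idea behind the one-shot analysis can be sketched in plain PyTorch: prune one layer at a time by magnitude to several sparsity levels without retraining, record the loss, restore the dense weights, and take the area under the sparsity vs. loss curve with the trapezoidal rule. This is a re-implementation for intuition under stated assumptions (the sparsity grid, the trapezoidal rule, and the helper name `one_shot_sensitivity` are all assumptions), not the code behind `one_shot_ks_loss_sensitivity()`.

```python
import torch
from torch import nn


@torch.no_grad()
def one_shot_sensitivity(model, inputs, targets, loss_fn, levels=(0.0, 0.2, 0.4, 0.6, 0.8, 0.95)):
    """Per-layer area under the sparsity vs. loss curve, without retraining."""
    model.eval()
    results = {}
    for name, module in model.named_modules():
        if not isinstance(module, (nn.Conv2d, nn.Linear)):
            continue
        original = module.weight.detach().clone()
        losses = []
        for sparsity in levels:
            module.weight.copy_(original)
            num_prune = int(sparsity * original.numel())
            if num_prune > 0:
                # magnitude prune only this layer to the current sparsity level
                threshold = original.abs().flatten().kthvalue(num_prune).values
                module.weight.mul_((original.abs() > threshold).to(original.dtype))
            losses.append(loss_fn(model(inputs), targets).item())
        module.weight.copy_(original)  # restore the dense weights
        # trapezoidal rule over the (sparsity, loss) points
        results[name] = sum(
            (levels[i + 1] - levels[i]) * (losses[i] + losses[i + 1]) / 2
            for i in range(len(levels) - 1)
        )
    return results


torch.manual_seed(0)
net = nn.Sequential(nn.Linear(20, 32), nn.ReLU(), nn.Linear(32, 4))
x, y = torch.randn(128, 20), torch.randint(0, 4, (128,))
sens = one_shot_sensitivity(net, x, y, nn.CrossEntropyLoss())
print(sens)  # higher value = more loss for a given amount of sparsity
```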

Note: if you changed the model and/or dataset above, you will likely need to change the `loss`, `batch_size`, and `samples_per_measurement` variables below. The number of samples per measurement can be relatively small; one or a few items per class is typically enough for a useful analysis.

Finally, after running, the results will be saved to a JSON file and plotted in this notebook for easy viewing.


In [None]:
import torch
from neuralmagicML.pytorch.utils import CrossEntropyLossWrapper
from neuralmagicML.pytorch.recal import one_shot_ks_loss_sensitivity
from neuralmagicML.utils import clean_path

device = "cuda" if torch.cuda.is_available() else "cpu"
print("running ks loss sensitivity analysis for model on {}".format(device))

#######################################################
# Edit parameters below
#######################################################
loss = CrossEntropyLossWrapper()
batch_size = 1024
samples_per_measurement = 1024

loss_analysis = one_shot_ks_loss_sensitivity(
    model, val_dataset, loss, device, batch_size, samples_per_measurement
)

save_path = clean_path(
    os.path.join(".", notebook_name, model_name, "ks-loss-sensitivity.json")
)
loss_analysis.save_json(save_path)
print("saved analysis to {}".format(save_path))
print("plotting...")
fig, axes = loss_analysis.plot(path=None, normalize=False)

## Hyperparameter Selection

In addition to the per-layer sparsity hyperparameters, there are a few more for pruning. The most important are:
- When to start pruning `(stabilization period)`. Letting the model stabilize a bit before beginning pruning is generally a good idea. Edits to the training setup can make the initial epoch or two unstable. So, before cutting out weights, we want to make sure the model is stable.
- How long to prune `(pruning period)`; pruning for more epochs is preferred up to a point. The shorter we make our pruning period, the less likely it is that the model has converged to a stable position before pruning again. A good rule of thumb is to prune over roughly 1/6 - 1/3 of the number of epochs it took to train.
- How long to train after pruning `(fine-tuning period)`. Generally, the model will not have fully recovered after pruning has stopped. In this case, training should continue for a bit longer until the validation loss has stabilized. A good rule of thumb is to fine-tune for roughly 1/6 - 1/3 of the number of epochs it took to train.
- How often to apply pruning steps during the pruning period. The general convention is to prune once per epoch. For different setups, it may be beneficial to prune more often (e.g., once every tenth of an epoch, a frequency of 0.1). The right choice depends on how many weight updates have happened since the last pruning step and how stable the loss currently is.
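The rules of thumb above can be turned into a rough schedule. The sketch below is illustrative only: the helper names are hypothetical, the 40-epoch example is made up, and the cubic ramp from an initial to a final sparsity is the gradual pruning schedule from Zhu and Gupta ("To prune, or not to prune", 2017). That ramp is an assumption here, since this notebook does not state how sparsity is interpolated between pruning steps.

```python
def plan_epochs(train_epochs, stabilization_epochs, fraction=1 / 4):
    """Size the pruning and fine-tuning periods as 1/6 to 1/3 of the original training epochs."""
    if not 1 / 6 <= fraction <= 1 / 3:
        raise ValueError("fraction should be between 1/6 and 1/3")
    period = round(train_epochs * fraction)
    start = stabilization_epochs
    end = start + period  # pruning period
    return start, end, end + period  # plus an equally long fine-tuning period


def pruning_step_epochs(start_epoch, end_epoch, update_frequency):
    """Epochs (possibly fractional) at which a pruning step is applied."""
    num_steps = round((end_epoch - start_epoch) / update_frequency)
    return [round(start_epoch + i * update_frequency, 6) for i in range(num_steps + 1)]


def cubic_sparsity(epoch, start_epoch, end_epoch, init_sparsity, final_sparsity):
    """Zhu and Gupta style ramp: prune quickly at first, then more slowly."""
    if epoch <= start_epoch:
        return init_sparsity
    if epoch >= end_epoch:
        return final_sparsity
    progress = (epoch - start_epoch) / (end_epoch - start_epoch)
    return final_sparsity + (init_sparsity - final_sparsity) * (1 - progress) ** 3


# e.g., a model originally trained for 40 epochs, with 2 stabilization epochs
start, end, total = plan_epochs(train_epochs=40, stabilization_epochs=2)
for epoch in pruning_step_epochs(start, end, update_frequency=1.0):
    print(epoch, round(cubic_sparsity(epoch, start, end, 0.05, 0.9), 3))
```

The cubic shape front-loads pruning while many redundant weights remain, and slows down near the target so the model has time to recover.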

To manage all of these hyperparameters, a config file is created and then loaded at training time. A simple UI is given below to make editing the config easy; all of the parameters mentioned above can be adjusted in it. Soon we will be replacing this with a more advanced UI with more features to make this selection even easier! For now, we recommend using this notebook and its UI to generate config files. If you are curious, you can inspect the config after the next step, which saves it to a local file.

Defaults are given for the MNIST network and dataset. These will likely need to be changed to fit your application better.

In [None]:
from neuralmagicML.nbutils import (
    KSWidgetContainer,
    PruningEpochWidget,
    PruningLayersWidget,
)
from neuralmagicML.pytorch.utils import get_prunable_layers
from neuralmagicML.pytorch.models import MnistNet

if "loss_analysis" not in globals():
    loss_analysis = None

prune_layers = get_prunable_layers(model)
not_mnist = not isinstance(model, MnistNet)
widget_container = KSWidgetContainer(
    PruningEpochWidget(start_epoch=2, end_epoch=20, total_epochs=25, max_epochs=100),
    PruningLayersWidget(
        layer_names=[layer[0] for layer in prune_layers],
        layer_descs=[str(layer[1]) for layer in prune_layers],
        layer_enables=None if not_mnist else [False, True, True, True, True],
        layer_sparsities=None if not_mnist else [0.0, 0.8, 0.9, 0.9, 0.9],
        loss_sens_analysis=loss_analysis,
    ),
)
print("creating ui...")
display(widget_container.create())

## Recalibration Using Pruning

Now that the hyperparameters are chosen, we will use them to recalibrate the given model on the given dataset. The library is designed to plug easily into nearly any PyTorch training setup. Below we provide an example of what an integration looks like. Note that only a handful of these lines are needed for a full integration.

1. Create a `ScheduledModifierManager`; it handles loading the config into PyTorch objects that modify the training process.
2. Create a `ScheduledOptimizer`; it handles updating the PyTorch objects that modify the training process. It wraps the original optimizer and should be used in its place, i.e., `step()` must be called on the `ScheduledOptimizer` and not on the original optimizer.
3. Use `max_epochs` on the `ScheduledModifierManager` to know how many epochs are needed for training.
4. Call `epoch_start()` on the `ScheduledOptimizer` before training each epoch and `epoch_end()` after each epoch finishes; these calls mark the start and end of every epoch.

Once the training objects are created (optimizer, loss function, etc.), a `ScheduledModifierManager` and `ScheduledOptimizer` are instantiated from the config. Almost all logging and updates in this notebook go through TensorBoard. TensorBoard is optional: other loggers are available, and logging can be left out entirely. Finally, regular training and testing code runs the process.
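Why `step()` has to go through the wrapper can be shown with a much-simplified, hypothetical wrapper (`MaskedOptimizer` is not part of neuralmagicML). After every update it re-applies the pruning masks, because the gradient for a pruned weight is generally nonzero (and Adam also carries moment estimates from earlier steps), so a plain optimizer step would move it away from zero. The epoch hooks only mimic the shape of `epoch_start()`/`epoch_end()`; the real `ScheduledOptimizer` does far more.

```python
import torch
from torch import nn


class MaskedOptimizer:
    """Hypothetical wrapper: delegates to an optimizer, then re-applies pruning masks."""

    def __init__(self, optimizer, masks):
        self.optimizer = optimizer
        self.masks = masks  # {parameter: 0/1 tensor of the same shape}
        self.epoch = 0

    def zero_grad(self):
        self.optimizer.zero_grad()

    @torch.no_grad()
    def step(self):
        self.optimizer.step()
        for param, mask in self.masks.items():
            param.mul_(mask)  # keep pruned weights at exactly zero

    def epoch_start(self):
        pass  # a real scheduler would update sparsity targets here

    def epoch_end(self):
        self.epoch += 1


torch.manual_seed(0)
layer = nn.Linear(4, 2)
mask = torch.tensor([[1.0, 0.0, 1.0, 0.0], [0.0, 1.0, 0.0, 1.0]])
with torch.no_grad():
    layer.weight.mul_(mask)  # initial one-shot prune
wrapped_optim = MaskedOptimizer(torch.optim.Adam(layer.parameters(), lr=0.1), {layer.weight: mask})

x, y = torch.randn(16, 4), torch.randn(16, 2)
for epoch in range(3):
    wrapped_optim.epoch_start()
    for _ in range(5):
        wrapped_optim.zero_grad()
        nn.functional.mse_loss(layer(x), y).backward()
        wrapped_optim.step()  # calling the inner Adam directly would un-prune weights
    wrapped_optim.epoch_end()
```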

In [None]:
import math
from tqdm import auto
import torch
from torch.utils.data import DataLoader
from torch.optim import Adam
from neuralmagicML.pytorch.utils import (
    CrossEntropyLossWrapper,
    TopKAccuracy,
    ModuleTrainer,
    ModuleTester,
    TensorboardLogger,
)
from neuralmagicML.utils import create_unique_dir, clean_path

# save the config locally for use in this flow
config_path = clean_path(os.path.join(".", notebook_name, model_name, "config.yaml"))
print("saving config to {}".format(config_path))
widget_container.get_manager("pytorch").save(config_path)


#######################################################
# Necessary imports for running recalibration
#######################################################
from neuralmagicML.pytorch.recal import ScheduledModifierManager, ScheduledOptimizer


# setup device, data loaders, loss, optimizer
device = "cuda" if torch.cuda.is_available() else "cpu"
# move the model before creating the optimizer, as the torch.optim docs recommend
model = model.to(device)
batch_size = 1024
train_data_loader = DataLoader(train_dataset, batch_size, shuffle=True, pin_memory=True)
val_data_loader = DataLoader(val_dataset, batch_size, shuffle=False, pin_memory=True)
loss = CrossEntropyLossWrapper(extras={"top1acc": TopKAccuracy(1)})
optim = Adam(model.parameters())
print("device:{} batch_size:{} loss:{}".format(device, batch_size, loss))

tensorboard_model_path = create_unique_dir(
    os.path.join(".", "tensorboard-logs", notebook_name, model_name)
)
loggers = [TensorboardLogger(tensorboard_model_path)]
print("logging at {}".format(tensorboard_model_path))


#######################################################
# First lines that must be substituted in training code
# We create a modifier manager as well as a scheduled optimizer
# These will apply the config we created while running the training process
# The loggers can be left out if desired
#######################################################
manager = ScheduledModifierManager.from_yaml(config_path)
optim = ScheduledOptimizer(
    optim,
    model,
    manager,
    steps_per_epoch=math.ceil(len(train_dataset) / batch_size),
    loggers=loggers,
)
print("created manager and optimizer from config at {}".format(config_path))


# we use prewritten trainers and testers to make the code more concise
trainer = ModuleTrainer(model, device, loss, optim, loggers=loggers)
tester = ModuleTester(model, device, loss, loggers=loggers)

# startup tensorboard
%load_ext tensorboard
%tensorboard --logdir ./tensorboard-logs

# run initial validation for comparison
tester.run_epoch(val_data_loader, epoch=-1, show_progress=False)


#######################################################
# Final lines that must be substituted in training code
# We continue training for as long as requested in the config
# Additionally, we tell the ScheduledOptimizer when each epoch
# has started and ended
#######################################################
for epoch in auto.tqdm(range(manager.max_epochs), desc="training"):
    optim.epoch_start()

    trainer.run_epoch(train_data_loader, epoch, show_progress=False)
    tester.run_epoch(val_data_loader, epoch, show_progress=False)

    optim.epoch_end()

# delete so all modifiers are cleaned up before exporting
del optim
print("training completed")
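Before exporting, it can be worth confirming that each layer actually reached the sparsity set in the config. A minimal sketch, assuming pruned weights are stored as exact zeros in each module's `weight` tensor; `layer_sparsity` is a hypothetical helper, and in the notebook you would call `layer_sparsity(model)` instead of using the stand-in model.

```python
import torch
from torch import nn


def layer_sparsity(model: nn.Module) -> dict:
    """Fraction of exactly-zero weights in each Conv2d/Linear layer."""
    return {
        name: (module.weight == 0).float().mean().item()
        for name, module in model.named_modules()
        if isinstance(module, (nn.Conv2d, nn.Linear))
    }


# stand-in model with one half-zeroed layer; in the notebook, pass `model` instead
demo = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 2))
with torch.no_grad():
    demo[0].weight[:, :2] = 0.0
for name, sparsity in layer_sparsity(demo).items():
    print(f"{name}: {sparsity:.1%} zeros")
```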

## Export to ONNX

Now that the model is fully recalibrated, we need to export it to the ONNX format, which is the format the Neural Magic Inference Engine uses. PyTorch natively supports exporting to ONNX; below we use a convenience class, `ModuleExporter`, to handle the export. Once the model has been saved as an ONNX file, it is ready to be used for inference with Neural Magic.

In [None]:
from neuralmagicML.utils import clean_path
from neuralmagicML.pytorch.utils import ModuleExporter

export_path = clean_path(os.path.join(".", notebook_name, model_name))
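# keep the model and the sample input on the same device (CPU) for export
model = model.cpu()
exporter = ModuleExporter(model, export_path)
# use the first batch of validation inputs as the sample input for export
sample_input = next(iter(val_data_loader))[0]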
exporter.export_onnx(sample_input)
print("exported onnx to {}".format(export_path))