<sub>&copy; 2020 Neuralmagic, Inc., Confidential // [Neural Magic Evaluation License Agreement](https://neuralmagic.com/evaluation-license-agreement/)</sub> 

# TensorFlow Model Pruning with Adam Optimizer

This notebook provides a step-by-step walkthrough for training and pruning a model to enable better performance at inference time using the Neural Magic Inference Engine. You will:
1. Set up the environment
2. Set up the model and dataset
3. Analyze loss sensitivity
4. Select hyperparameters
5. Recalibrate using pruning
6. Export to [ONNX](https://onnx.ai/)

Reading through this notebook is a reasonably quick way to gain an intuition for what is happening. Rough time estimates for fully pruning the default model are given below. Note that training with the TensorFlow CPU implementation is much slower than training on a GPU:
- 10 minutes on a GPU
- 45 minutes on a laptop CPU

## Background

Neural networks are generally overparameterized for given tasks (i.e., the number of parameters far exceeds the number of training points), yet they [still generalize well](https://arxiv.org/abs/1611.03530). This is contrary to conventional ML wisdom, where overparameterizing a model would traditionally lead to overfitting. The overall term for this behavior is [double descent](https://openai.com/blog/deep-double-descent/), and it is a very active area of research.

A side effect of this overparameterization is that a large number of weights in deep learning networks can be pruned away (set to 0). This was [discovered early on](http://yann.lecun.com/exdb/publis/pdf/lecun-90b.pdf) by Yann LeCun, but interest waned due to a lack of applications at that time. [Song Han's 2015 paper](https://arxiv.org/abs/1510.00149) reinvigorated the area in pursuit of compressing model size for mobile applications. This renewed interest has resulted in numerous papers on the topic of weight pruning, filter pruning, channel pruning, and ultimately, block pruning. A [Google paper](https://arxiv.org/abs/1902.09574) gives a good overview of the current state of kernel sparsity (model pruning).

While pruning to increase kernel sparsity, we iteratively go through and remove weights based on their absolute magnitude. The smallest weights are the ones pruned first. Generally, two properties enable us to do this: the self-regularizing [effect of gradient descent](https://www.nature.com/articles/s41467-020-14663-9) as well as the L1 or L2 regularization functions applied to the weights. Weights that do not help in the optimization process are quickly reduced in absolute value. In this way, pruning can be thought of as an [architecture search](https://arxiv.org/abs/1905.09717).
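The magnitude-based selection described above can be sketched in a few lines of NumPy. This is only an illustration of the idea, not the neuralmagicML implementation: the function name `magnitude_prune` and the random 8x8 "layer" are hypothetical, and a real pruning run would train the model between steps so the remaining weights can recover.

```python
import numpy as np


def magnitude_prune(weights, sparsity):
    """Zero out the `sparsity` fraction of entries with the smallest |w|.

    Returns the pruned copy and the boolean keep-mask.
    """
    num_prune = int(round(sparsity * weights.size))
    mask = np.ones(weights.size, dtype=bool)
    # indices sorted from smallest to largest absolute value
    order = np.argsort(np.abs(weights), axis=None)
    mask[order[:num_prune]] = False
    mask = mask.reshape(weights.shape)
    return np.where(mask, weights, 0.0), mask


# hypothetical layer weights, for illustration only
rng = np.random.default_rng(0)
layer = rng.normal(size=(8, 8))

# iterative pruning: raise the target step by step
# (in real training, weight updates would happen between these steps)
pruned = layer
for target in (0.25, 0.5, 0.75):
    pruned, mask = magnitude_prune(pruned, target)
    print("target {:.2f} -> actual sparsity {:.2f}".format(target, 1.0 - mask.mean()))
```

Because weights that are already zero have the smallest magnitude, each later step keeps the earlier zeros and only removes additional weights.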

What does pruning get us? We now have a model with a lot of multiplications by zero that we don't need to run. If we're smart about how we structure this compute (a surprisingly tricky problem), we can run the model much faster than before! The pruned model plus the ability to run it quickly in the Neural Magic Inference Engine helps to optimize performance. [Neural Magic](http://neuralmagic.com/) makes it easier to apply the algorithm, giving you more information so you can apply the algorithm with better results.
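To make the "multiplications by zero" point concrete, here is a toy sketch comparing a dot product that multiplies every weight with one that only visits the nonzero weights. The helper names and the example vector are hypothetical, and this says nothing about how the Neural Magic Inference Engine structures its compute; it only shows that the skipped multiplications do not change the result.

```python
import numpy as np


def dense_dot(weights, x):
    """Multiply every weight, zero or not; returns (result, multiply count)."""
    total, mults = 0.0, 0
    for w, xi in zip(weights, x):
        total += w * xi
        mults += 1
    return total, mults


def sparse_dot(weights, x):
    """Multiply only the nonzero weights; returns (result, multiply count)."""
    nonzero_idx = np.flatnonzero(weights)
    total = 0.0
    for i in nonzero_idx:
        total += weights[i] * x[i]
    return total, len(nonzero_idx)


# hypothetical pruned weight vector: 5 of 8 entries are zero
weights = np.array([0.0, 1.5, 0.0, 0.0, -2.0, 0.0, 0.0, 0.5])
x = np.arange(1.0, 9.0)

print(dense_dot(weights, x))   # same result, 8 multiplications
print(sparse_dot(weights, x))  # same result, 3 multiplications
```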

In this notebook, you prune a simple CNN on the [MNIST dataset](http://yann.lecun.com/exdb/mnist/) using an [Adam optimizer](https://arxiv.org/abs/1412.6980). However, the notebook is designed to be easily extendable for your model and dataset. Guided instructions are provided in the notebook code comments.

Note that the Adam optimizer is easier to use than a [Stochastic Gradient Descent (SGD) optimizer](https://en.wikipedia.org/wiki/Stochastic_gradient_descent); however, SGD is the preferred method for pruning to ensure the resulting model will generalize well. See our other notebooks for pruning with SGD.

## Before you begin…
Be sure to read through the README found in the Neural Magic ML Tooling (neuralmagicML) package.



## Step 1 - Setting Up the Environment

In this step, Neural Magic checks your environment setup to ensure the rest of the notebook will flow smoothly.
Before running, install the neuralmagicML package into the system using the following command:

`pip install neuralmagicML-python/`


In [None]:
notebook_name = "pruning_adam_tensorflow"
print("checking setup for {}...".format(notebook_name))

# filter because of tensorboard future warnings
import warnings

warnings.filterwarnings("ignore", category=FutureWarning)

try:
    # make sure neuralmagicML is installed
    import neuralmagicML
except ImportError as ex:
    # re-raise with install instructions, keeping the original error attached
    raise ImportError(
        "please install neuralmagicML using the setup.py file before continuing"
    ) from ex

from neuralmagicML.utilsnb import check_tensorflow_notebook_setup

check_tensorflow_notebook_setup()

## Step 2 - Setting Up the Model and Dataset

By default, you will create a simple CNN to prune on the MNIST dataset. The CNN is already pre-trained, and its weights are downloaded from the Neural Magic Model Repo. The MNIST dataset is downloaded automatically as well through TensorFlow.

If you would like to try out your own model for pruning, modify the appropriate lines for your model and dataset, specifically:
- graph_creator():  # replace the function with your desired model
- dataset_root = clean_path(os.path.join(".", notebook_name, "datasets", "mnist"))
- dataset = input_data.read_data_sets(dataset_root, one_hot=True)

Take care to keep the variable names the same, as the rest of the notebook is set up according to those.


In [None]:
import os
from tensorflow.examples.tutorials.mnist import input_data
from neuralmagicML.tensorflow.models import ModelRegistry
from neuralmagicML.tensorflow.utils import tf_compat, batch_cross_entropy_loss, accuracy
from neuralmagicML.utils import clean_path

# ignore tf v1 vs v2 deprecation warnings
tf_compat.logging.set_verbosity(tf_compat.logging.ERROR)

#######################################################
# Define your graph below
#######################################################
model_name = "mnistnet"


def graph_creator():
    inputs = tf_compat.placeholder(tf_compat.float32, [None, 28, 28, 1], name="inputs")
    labels = tf_compat.placeholder(tf_compat.float32, [None, 10])
    logits = ModelRegistry.create("mnistnet", inputs)
    loss = batch_cross_entropy_loss(logits, labels)
    acc = accuracy(logits, labels)

    return inputs, labels, logits, loss, acc


print(graph_creator)

#######################################################
# Define your dataset below
#######################################################
print("\nloading dataset...")
dataset_root = clean_path(os.path.join(".", notebook_name, "datasets", "mnist"))
dataset = input_data.read_data_sets(dataset_root, one_hot=True)
print(dataset)

## Step 3 - Analyzing Loss Sensitivity

One of the hyperparameters you need to control is how sparse (percentage of zeros) to make each fully connected or convolutional layer in a network. Not all layers are created equal, so you will want to be careful about how you assign sparsity. Generally, the more parameters there are per input data, the less sensitive (and therefore more prunable) the layer will be. For example, a 3x3 convolution is much less sensitive than an equivalent channel sized 1x1 convolution. Likewise, increasing stride for convolutions will increase sensitivity.
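As a rough illustration of the 3x3 versus 1x1 comparison, the sketch below counts the weights in two convolutions with the same channel sizes. The helper `conv_weight_count` and the channel count of 64 are hypothetical choices made for the example; biases are ignored.

```python
def conv_weight_count(kernel_size, in_channels, out_channels):
    """Number of weights in a square 2D convolution kernel (no bias)."""
    return kernel_size * kernel_size * in_channels * out_channels


channels = 64  # hypothetical channel size, same for both layers

weights_3x3 = conv_weight_count(3, channels, channels)
weights_1x1 = conv_weight_count(1, channels, channels)

print("3x3 conv weights:", weights_3x3)
print("1x1 conv weights:", weights_1x1)
print("ratio:", weights_3x3 // weights_1x1)
```

With nine times as many weights covering each input position, the 3x3 layer has far more redundancy to draw on, which is consistent with it tolerating higher sparsity.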

To give more visibility into this, we provide a quick way to approximate sensitivity. The `approx_ks_loss_sensitivity()` function goes layer by layer and checks the magnitude of the remaining weights at various levels of sparsity. The result is a reasonable approximation that is inexpensive to run because it does not require significant compute resources. For display, the piecewise integral of the sparsity versus loss curve is calculated for each layer. Higher values therefore mean that a layer is more sensitive to a given amount of sparsity.
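The piecewise integral mentioned above can be pictured with the trapezoid rule over a per-layer curve. This is a sketch under assumptions: the notebook does not show how neuralmagicML computes the integral, and the function name, layer names, and curve values below are made-up placeholders rather than real analysis output.

```python
import numpy as np


def sensitivity_integral(sparsities, values):
    """Trapezoid-rule area under a sparsity vs. sensitivity curve."""
    s = np.asarray(sparsities, dtype=float)
    v = np.asarray(values, dtype=float)
    widths = np.diff(s)                 # distance between sparsity levels
    heights = (v[:-1] + v[1:]) / 2.0    # average value on each segment
    return float(np.sum(widths * heights))


# hypothetical curves for two layers, measured at the same sparsity levels
sparsity_levels = [0.0, 0.5, 0.9]
curves = {
    "layer_a": [0.0, 0.1, 0.4],  # degrades slowly: less sensitive
    "layer_b": [0.0, 0.4, 1.0],  # degrades quickly: more sensitive
}

for name, curve in curves.items():
    print(name, round(sensitivity_integral(sparsity_levels, curve), 3))
```

A layer whose curve rises faster accumulates a larger area, so it would be assigned a lower target sparsity than a layer with a flatter curve.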

To run the analysis, we create the graph and then pass it in to `approx_ks_loss_sensitivity()`.

Finally, after running, the results will be saved to a JSON file and plotted in this notebook for easy viewing.


In [None]:
from neuralmagicML.utils import clean_path
from neuralmagicML.tensorflow.utils import tf_compat
from neuralmagicML.tensorflow.recal import approx_ks_loss_sensitivity
from neuralmagicML.tensorflow.models import ModelRegistry

with tf_compat.Graph().as_default() as graph:
    print("Creating graph...")
    inputs, labels, logits, loss, acc = graph_creator()
    
    with tf_compat.Session() as sess:
        print("loading pre-trained weights...")
        ModelRegistry.load_pretrained("mnistnet")

        print("running ks loss sensitivity analysis...")
        loss_analysis = approx_ks_loss_sensitivity(graph)

save_path = clean_path(
    os.path.join(".", notebook_name, model_name, "ks-loss-sensitivity.json")
)
loss_analysis.save_json(save_path)
print("saved analysis to {}".format(save_path))
print("plotting...")
fig, axes = loss_analysis.plot(path=None, plot_integral=True, normalize=False)

## Step 4 - Hyperparameters

In addition to the sparsity per layer hyperparameter, there are a few more for pruning. The most significant are:
- When to start pruning (stabilization period). Letting the model stabilize a bit before beginning pruning is generally a good idea. Edits to the training setup can make the initial epoch or two unstable. So, before cutting out weights, you want to make sure the model is stable.
- How long to prune (pruning period). Pruning for more epochs is preferred up to a point. The shorter the pruning period, the less likely it is that the model has converged to a stable position before pruning again. A good rule is to prune over roughly 1/6 to 1/3 the number of epochs it took to train.
- How long to train after pruning (fine-tuning period). Generally, the model will not have fully recovered after pruning has stopped. In this case, training should continue a bit longer until the validation loss has stabilized. A good rule is to fine-tune for roughly 1/6 to 1/3 the number of epochs it took to train.
- How often to update pruning steps while in the pruning period. The general convention is to apply pruning steps once per epoch. For different setups, it may be beneficial to prune more often (e.g., once every tenth of an epoch -- 0.1). It depends on how many weight updates have happened since the last pruning step and how stable the loss function is currently.
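The hyperparameters above can be turned into concrete numbers with a little arithmetic. The sketch below is illustrative only: the helper names are hypothetical, the linear ramp from zero to the final sparsity is an assumption (the notebook does not state how the library interpolates sparsity between the start and end epochs), and the 60 training epochs and 53 steps per epoch are example values.

```python
import math


def rule_of_thumb_epochs(train_epochs):
    """(shortest, longest) period from the 1/6 to 1/3 rule of thumb."""
    return math.ceil(train_epochs / 6), math.floor(train_epochs / 3)


def sparsity_at(epoch, start_epoch, end_epoch, final_sparsity):
    """Target sparsity at `epoch`, ASSUMING a linear ramp (illustration only)."""
    if epoch <= start_epoch:
        return 0.0  # stabilization period: nothing pruned yet
    if epoch >= end_epoch:
        return final_sparsity  # fine-tuning period: sparsity held constant
    progress = (epoch - start_epoch) / (end_epoch - start_epoch)
    return final_sparsity * progress


def steps_between_updates(steps_per_epoch, update_frequency):
    """Convert an update frequency in epochs into a whole number of steps."""
    return max(1, int(round(steps_per_epoch * update_frequency)))


# a model that took a hypothetical 60 epochs to train
print(rule_of_thumb_epochs(60))

# start at epoch 2, stop pruning at epoch 20, train until epoch 25
for epoch in (0, 2, 11, 20, 25):
    print(epoch, round(sparsity_at(epoch, 2, 20, 0.85), 3))

# pruning every tenth of an epoch with a hypothetical 53 steps per epoch
print(steps_between_updates(53, 0.1))
```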

In support of all these different hyperparameters, a configuration file is used and then loaded at training time. A simple UI is given in the cell block below to enable easy editing of the configuration. The parameters mentioned above can all be adjusted. Soon, Neural Magic will replace this with a more advanced UI with more features to make this selection even easier! For now, we recommend using this notebook and the UI inside to generate the configuration file. You can look at the output after the next step as it saves the configuration to a file locally.

Defaults are given for the MNIST network and dataset. You may need to change these to better fit your application.


In [None]:
from neuralmagicML.utilsnb import (
    KSWidgetContainer,
    PruningEpochWidget,
    PruningParamsWidget,
)
from neuralmagicML.tensorflow.utils import get_ops_and_inputs_by_name_or_regex

if "loss_analysis" not in globals():
    loss_analysis = None

with tf_compat.Graph().as_default() as graph:
    print("Creating graph...")
    graph_creator()

    # match all model weights
    prune_ops_and_tens = get_ops_and_inputs_by_name_or_regex(
        ["re:.*/conv./weight", "re:.*/fc/weight"]
    )

    not_mnist = model_name != "mnistnet"
    widget_container = KSWidgetContainer(
        PruningEpochWidget(
            start_epoch=2,
            end_epoch=20,
            total_epochs=25,
            max_epochs=100,
            update_frequency=0.0,
        ),
        PruningParamsWidget(
            param_names=[tens.name for _, tens in prune_ops_and_tens],
            param_descs=[op.type for op, _ in prune_ops_and_tens],
            param_enables=None if not_mnist else [True, True, True, True, False],
            param_sparsities=None if not_mnist else [0.85, 0.8, 0.85, 0.85, 0.0],
            loss_sens_analysis=loss_analysis,
        ),
    )

print("Creating ui...")
display(widget_container.create())

## Step 5 - Recalibrating Using Pruning

Now that the hyperparameters are chosen, you will use them to recalibrate the given model and dataset. The library is designed to plug easily into nearly any training setup for TensorFlow. The cell block below shows an example integration. Note that only five steps are needed to integrate fully:
- Create a `ScheduledModifierManager()`. This loads the config into TensorFlow objects that modify the training process.
- Invoke `manager.create_ops()` for the desired graph. This updates the TensorFlow graph with the proper operators that modify the training process.
- Use `manager.max_epochs` to know how many epochs are needed for training.
- Invoke `sess.run(mod_ops)` on each optimizer step. This updates the modifying operators and variables in the TensorFlow graph.
- Invoke `manager.complete_graph()` once training has completed. This will clean up the graph and set any final state for graph export and saving.

Once the training objects are created (optimizer, loss function, etc.), a `ScheduledModifierManager` is instantiated from the configuration. Most logging and updates are done through TensorBoard for this notebook. The use of TensorBoard is entirely optional. Finally, regular training and testing code is used to go through the process.

Note, for convenience a TensorBoard instance is launched in the cell below pointed at `localhost`. If you are running this notebook on a remote server, then you will need to update TensorBoard accordingly.


In [None]:
import math
from tqdm import auto
from neuralmagicML.utils import create_unique_dir, clean_path, create_dirs
from neuralmagicML.tensorflow.utils import eval_tensor_sparsity
from neuralmagicML.tensorflow.models import ModelRegistry
from neuralmagicML.tensorflow.recal import ScheduledModifierManager

# save the config locally for use in this flow
config_path = clean_path(os.path.join(".", notebook_name, model_name, "config.yaml"))
print("saving config to {}".format(config_path))
widget_container.get_manager("tensorflow").save(config_path)

# startup tensorboard
%load_ext tensorboard
%tensorboard --logdir ./tensorboard-logs


def calc_test_metrics(acc, loss, inputs, labels):
    test_xs = dataset.test.images.reshape([-1, 28, 28, 1])
    test_acc, test_loss = sess.run(
        [acc, loss], feed_dict={inputs: test_xs, labels: dataset.test.labels}
    )
    return test_acc, test_loss


with tf_compat.Graph().as_default() as graph:
    batch_size = 1024
    steps_per_epoch = int(len(dataset.train.images) / batch_size)
    inputs, labels, logits, loss, acc = graph_creator()
    global_step = tf_compat.train.get_or_create_global_step()

    epoch_ph = tf_compat.placeholder(dtype=tf_compat.float32, name="epoch")
    test_loss_ph = tf_compat.placeholder(dtype=tf_compat.float32, name="test_loss")
    test_acc_ph = tf_compat.placeholder(dtype=tf_compat.float32, name="test_acc")
    tf_compat.summary.scalar("Train/epoch", epoch_ph)
    tf_compat.summary.scalar("Train/loss", loss)
    tf_compat.summary.scalar("Train/accuracy", acc)
    tf_compat.summary.scalar("Test/loss", test_loss_ph)
    tf_compat.summary.scalar("Test/accuracy", test_acc_ph)

    #######################################################
    # Create a manager for recalibration and create the ops
    #######################################################
    manager = ScheduledModifierManager.from_yaml(config_path)
    mod_ops, mod_extras = manager.create_ops(steps_per_epoch, global_step)

    train_op = tf_compat.train.AdamOptimizer(learning_rate=1e-4).minimize(
        loss, global_step=global_step
    )

    summaries = tf_compat.summary.merge_all()
    tensorboard_path = create_unique_dir(
        os.path.join(".", "tensorboard-logs", notebook_name, model_name)
    )
    # the session is not created until below, so pass the graph directly
    summary_writer = tf_compat.summary.FileWriter(tensorboard_path, graph)

    with tf_compat.Session() as sess:
        sess.run(tf_compat.global_variables_initializer())
        print("loading pre-trained weights...")
        ModelRegistry.load_pretrained("mnistnet")
        test_acc, test_loss = calc_test_metrics(acc, loss, inputs, labels)

        for epoch in auto.tqdm(range(manager.max_epochs), desc="training"):
            for batch in range(steps_per_epoch):
                batch_xs, batch_ys = dataset.train.next_batch(batch_size)
                batch_xs = batch_xs.reshape([-1, 28, 28, 1])
                sess.run(train_op, feed_dict={inputs: batch_xs, labels: batch_ys})
                sess.run(mod_ops)

                # log summaries every 5% of an epoch
                if batch % max(1, int(steps_per_epoch * 0.05)) == 0:
                    step = sess.run(global_step)
                    summary_str = sess.run(
                        summaries,
                        feed_dict={
                            inputs: batch_xs,
                            labels: batch_ys,
                            test_acc_ph: test_acc,
                            test_loss_ph: test_loss,
                            epoch_ph: epoch,
                        },
                    )
                    summary_writer.add_summary(summary_str, step)
                    summary_writer.flush()

            test_acc, test_loss = calc_test_metrics(acc, loss, inputs, labels)

        manager.complete_graph()
        print("final accuracy: {}".format(test_acc))

        checkpoint_path = create_unique_dir(
            os.path.join(".", notebook_name, model_name, "checkpoint")
        )
        checkpoint_path = os.path.join(checkpoint_path, "model")
        create_dirs(checkpoint_path)
        saver = tf_compat.train.Saver(
            tf_compat.get_collection(tf_compat.GraphKeys.TRAINABLE_VARIABLES)
        )
        saver.save(sess, checkpoint_path)
        print("saved model checkpoint to {}".format(checkpoint_path))

## Step 6 - Exporting to ONNX

Now that the model is fully recalibrated, you need to export it to the ONNX format, which is the format used by the Neural Magic Inference Engine. For TensorFlow, exporting to ONNX is not natively supported. To add support, you will use the `tf2onnx` Python package. In the cell block below, a convenience class, `GraphExporter()`, is used to handle exporting. It wraps the somewhat complicated API for `tf2onnx` in an easy-to-use interface.

Note, for some configurations, the tf2onnx code does not work properly in a Jupyter Notebook. To remedy this, you should run the `exporter.export_onnx()` function call in a Python console or script.

Once the model is saved as an ONNX file, it is ready to be used for inference with Neural Magic.


In [None]:
from neuralmagicML.utils import create_unique_dir, clean_path, create_dirs
from neuralmagicML.tensorflow.utils import GraphExporter

export_path = clean_path(os.path.join(".", notebook_name, model_name, "exported"))
exporter = GraphExporter(export_path)

with tf_compat.Graph().as_default() as graph:
    print('Recreating graph...', flush=True)
    inputs, labels, logits, loss, acc = graph_creator()
    input_names = [inputs.name]
    output_names = [logits.name]

    with tf_compat.Session() as sess:
        sess.run(tf_compat.global_variables_initializer())
        print("Restoring previous weights...", flush=True)
        saver = tf_compat.train.Saver(
            tf_compat.get_collection(tf_compat.GraphKeys.TRAINABLE_VARIABLES)
        )
        saver.restore(sess, checkpoint_path)
        
        print("Exporting to pb...", flush=True)
        exporter.export_pb(outputs=[logits])
        print("Exported pb file to {}".format(exporter.pb_path), flush=True)
        
print("Exporting to onnx...", flush=True)
exporter.export_onnx(inputs=input_names, outputs=output_names)
print("Exported onnx file to {}".format(exporter.onnx_path))


## Next Step

Run your model (ONNX file) through the Neural Magic Inference Engine. The following is an example of code that you can run in your Python console. Be sure to enter your ONNX file path and batch size.

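```
import numpy
from neuralmagic import create_model

model = create_model(onnx_file_path="some/path/to/model.onnx", batch_size=1)
# the input shape must match your model; the MNIST graph defined in Step 2
# takes inputs shaped [batch, 28, 28, 1] rather than the shape shown here
inp = [numpy.random.rand(1, 3, 224, 224).astype(numpy.float32)]
out = model.forward(inp)
print(out)
```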