<sub>&copy; 2021 Neuralmagic, Inc. // [Neural Magic Legal](https://neuralmagic.com/legal)</sub> 

# Keras Classification Model Pruning using SparseML

This notebook provides a step-by-step walkthrough for pruning an already trained (dense) model to enable better performance at inference time using the DeepSparse Inference Engine. You will:
- Set up the model and dataset
- Integrate the Keras training flow with SparseML
- Prune the model using the Keras+SparseML flow
- Export to [ONNX](https://onnx.ai/)

Reading through this notebook should be reasonably quick and will give you an intuition for how to plug SparseML into your Keras training flow.


## Step 1 - Requirements
To run this notebook, you will need the following packages already installed:
* SparseML and SparseZoo
* Keras
* TensorBoard

You can install any package that is not already present via `pip`.
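
Before continuing, it can help to check which of these packages are importable. The sketch below uses only the standard library. The module names (`sparseml`, `sparsezoo`, `tensorflow`, `tensorboard`) are assumed to be the import names of the packages listed above, and `missing_packages` is a hypothetical helper written for this check.

```python
import importlib.util

# Assumed import names for the packages listed above
REQUIRED_MODULES = ["sparseml", "sparsezoo", "tensorflow", "tensorboard"]

def missing_packages(module_names):
    """Return the module names that cannot be found in this environment."""
    return [name for name in module_names if importlib.util.find_spec(name) is None]

missing = missing_packages(REQUIRED_MODULES)
if missing:
    print("Missing packages:", ", ".join(missing))
else:
    print("All required packages are installed")
```

Any name reported as missing can then be installed with `pip` before running the remaining cells.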

## Step 2 - Setting Up the Model and Dataset

In this notebook, you will prune a simple Convolutional Neural Network model trained on the MNIST dataset. The pretrained model's architecture and weights are downloaded from the SparseZoo model repo. The dataset is downloaded directly from the Keras datasets library.

### Set Up the Model

The following cell defines a function that downloads a model from the SparseZoo; for convenience, it also returns the path to an optimization recipe. You then construct a Keras model instance from the pretrained model file, which you will prune in a later step.

In [None]:
import os
from tensorflow import keras
from sparsezoo.models import Zoo

# Root directory for the notebook artifacts
root_dir = "./notebooks/keras"

def download_model_and_recipe(root_dir: str):
    """
    Download pretrained model and a pruning recipe
    """
    model_dir = os.path.join(root_dir, "mnist")
    zoo_model = Zoo.load_model(
            domain="cv",
            sub_domain="classification",
            architecture="mnist",
            sub_architecture=None,
            framework="keras",
            repo="sparseml",
            dataset="mnist",
            training_scheme=None,
            optim_name="pruned",
            optim_category="conservative",
            optim_target=None,
            override_parent_path=model_dir,
        )
    zoo_model.download()

    model_file_path = zoo_model.framework_files[0].downloaded_path()
    if not os.path.exists(model_file_path) or not model_file_path.endswith(".h5"):
        raise RuntimeError("Model file not found: {}".format(model_file_path))
    recipe_file_path = zoo_model.recipes[0].downloaded_path()
    if not os.path.exists(recipe_file_path):
        raise RuntimeError("Recipe file not found: {}".format(recipe_file_path))
    return model_file_path, recipe_file_path

model_file_path, recipe_file_path = download_model_and_recipe(root_dir)

print("Loading model {}".format(model_file_path))
model = keras.models.load_model(model_file_path)
model.summary()

### Set Up the Dataset

You will download the MNIST dataset from the Keras datasets library as follows. You will also scale the pixel values to the range [0, 1] and one-hot encode the labels before using the data for training and evaluation.

In [None]:
# Number of classes
num_classes = 10

# Load MNIST dataset
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

input_shape = (28, 28, 1)

# Normalize data
x_train = x_train.astype('float32') / 255
x_test = x_test.astype('float32') / 255

y_train = keras.utils.to_categorical(y_train, num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes)
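
To make the preprocessing concrete, the following pure-Python sketch mirrors what the cell above does to a single pixel value and a single label. `normalize_pixel` and `one_hot` are hypothetical helpers written for illustration; the notebook itself relies on array arithmetic and `keras.utils.to_categorical`.

```python
def normalize_pixel(value, max_value=255):
    """Scale an 8-bit pixel intensity into the range [0.0, 1.0]."""
    return value / max_value

def one_hot(label, num_classes=10):
    """Return a list with 1.0 at index `label` and 0.0 elsewhere."""
    if not 0 <= label < num_classes:
        raise ValueError("label {} is outside [0, {})".format(label, num_classes))
    return [1.0 if i == label else 0.0 for i in range(num_classes)]

print(normalize_pixel(255))  # 1.0
print(one_hot(3))            # [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
```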

Before pruning the model, you can run the cell below to verify the accuracy of the pretrained model on the test set.

In [None]:
res = model.evaluate(x_test, y_test)
print("Test loss, accuracy: ", res)

## Step 3 - Pruning the Pretrained Model

In this section, you will prune the above pretrained Keras model using the SparseML model optimization library. Recall that a common training workflow in Keras is first to compile the model with the appropriate losses, metrics and an optimizer, then to train the model using the `fit()` method of the `Model` class. The SparseML library makes it easy to extend this training workflow to perform gradual pruning based on weight magnitudes.
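
To build intuition for gradual magnitude pruning, here is a minimal pure-Python sketch. It is not SparseML's implementation: the linear sparsity schedule, the helper names `magnitude_mask` and `linear_sparsity`, and the toy weights are all assumptions chosen for illustration. The idea is that the target sparsity rises over training steps, and at each update the smallest-magnitude weights are zeroed until the target is met.

```python
def magnitude_mask(weights, sparsity):
    """Return a 0/1 mask that zeroes the `sparsity` fraction of smallest-magnitude weights."""
    num_to_prune = int(round(sparsity * len(weights)))
    # Indices sorted by absolute value, smallest first
    order = sorted(range(len(weights)), key=lambda i: abs(weights[i]))
    pruned = set(order[:num_to_prune])
    return [0 if i in pruned else 1 for i in range(len(weights))]

def linear_sparsity(step, total_steps, init_sparsity, final_sparsity):
    """Linearly interpolate the target sparsity for a given step (illustrative schedule)."""
    progress = min(max(step / total_steps, 0.0), 1.0)
    return init_sparsity + (final_sparsity - init_sparsity) * progress

# Toy weights standing in for one layer's parameters
weights = [0.9, -0.05, 0.4, 0.01, -0.7, 0.2, -0.3, 0.08]
for step in (0, 2, 4):
    target = linear_sparsity(step, total_steps=4, init_sparsity=0.0, final_sparsity=0.5)
    mask = magnitude_mask(weights, target)
    # Apply the mask; in real training, weight updates would happen between pruning steps
    weights = [w * m for w, m in zip(weights, mask)]
    print(step, target, weights)
```

In an actual training run, the remaining weights continue to be updated between pruning steps, which is what lets the model recover accuracy as sparsity increases.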

Given a pretrained model, the pruning workflow can be summarized as follows:
1. Create a recipe for pruning, which could be done effectively using the Sparsify toolkit
2. Instantiate a Keras optimizer instance (such as SGD or Adam)
3. Instantiate a `ScheduledModifierManager` object from the recipe
4. Enhance the model and optimizer with pruning data structures by calling the manager's `modify` method. At this step, you have options to define the loggers used during the pruning process. The results of this step are a model to be pruned, an optimizer that should be used and a list of callbacks
5. [Optional] Add to the callback list any additional callbacks such as model checkpoint and the SparseML built-in LossesAndMetricsLogging callback
6. Compile and fit the modified model using Keras built-in APIs, using the optimizer and callback list
7. Erase the pruning information in the enhanced model, and get back the original model with pruned weights


Next, you will set up a directory path for logging and the frequency for the logging update.

In [None]:
from datetime import datetime

# Logging directory
log_dir = os.path.join(root_dir, "tensorboard", "mnist")
log_dir += "-" + datetime.now().strftime("%Y%m%d-%H%M%S")
print("Logging directory: {}".format(log_dir))

# Number of steps before the next logging should take place
# Use "epoch" or "batch" to log at every training epoch or batch (respectively)
update_freq = 100

The following cell contains steps described in the pruning workflow.

In [None]:
import math
from tensorflow.keras.optimizers import Adam

from sparseml.keras.optim import ScheduledModifierManager
from sparseml.keras.utils import TensorBoardLogger

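# Batch size used for pruning; 32 is an assumed value (the Keras default), adjust as needed
batch_size = 32

# Number of steps per epoch
steps_per_epoch = math.ceil(len(x_train) / batch_size)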

# Create a manager from the recipe
manager = ScheduledModifierManager.from_yaml(recipe_file_path)

# Create an optimizer
optimizer = Adam()

# Create a TensorBoardLogger instance
loggers = TensorBoardLogger(log_dir=log_dir, update_freq=update_freq)

# Modify the model and optimizer
model_for_pruning, optimizer, callbacks = manager.modify(model, optimizer, steps_per_epoch, loggers=loggers)

# Compile the modified model
model_for_pruning.compile(
    loss=keras.losses.categorical_crossentropy,
    optimizer=optimizer,
    metrics=['accuracy'],
    run_eagerly=True
)

# Prune the model
model_for_pruning.fit(x_train, y_train, batch_size=batch_size, epochs=manager.max_epochs,
                      validation_data=(x_test, y_test), shuffle=True, callbacks=callbacks)

print("Pruning finished")

# Verify the pruned model's accuracy
res = model_for_pruning.evaluate(x_test, y_test)
print("Validation loss, accuracy: ", res)

# Erase the enhanced information used for pruning 
pruned_model = manager.finalize(model_for_pruning)
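
As a worked example of the `steps_per_epoch` computation in the cell above: the MNIST training split has 60,000 images, so the number of steps per epoch depends only on the batch size. The batch sizes below are illustrative assumptions, and `steps_per_epoch_for` is a hypothetical helper.

```python
import math

def steps_per_epoch_for(num_samples, batch_size):
    """Number of batches needed to see every sample once; the last batch may be partial."""
    return math.ceil(num_samples / batch_size)

num_train_samples = 60000  # size of the MNIST training split
for batch_size in (32, 128):  # illustrative batch sizes
    print(batch_size, steps_per_epoch_for(num_train_samples, batch_size))
```

Rounding up matters when the batch size does not divide the dataset evenly: with a batch size of 128, the final batch of each epoch holds only 96 images but still counts as a step.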

You can view the logged information in TensorBoard.

In [None]:
%load_ext tensorboard
%tensorboard --logdir {log_dir}

## Step 4 - Examine the Pruned Model

You can observe the layers of the pruned Keras model using its `layers` property and `get_weights()` method.

In [None]:
pruned_model.layers

In [None]:
# Change the layer index to examine the layers
layer_index = 1
pruned_model.layers[layer_index].get_weights()
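
One quick way to confirm that pruning took effect is to measure the fraction of zero-valued weights in a layer. The helper below, `sparsity_of`, is a hypothetical illustration that works on a flat sequence of numbers; the example values are made up and do not come from the MNIST model.

```python
def sparsity_of(values):
    """Return the fraction of entries in a flat sequence that are exactly zero."""
    values = list(values)
    if not values:
        return 0.0
    return sum(1 for v in values if v == 0) / len(values)

# Made-up weights standing in for a flattened pruned kernel
example_kernel = [0.0, 0.31, 0.0, -0.12, 0.0, 0.0, 0.57, 0.0]
print("Sparsity: {:.1%}".format(sparsity_of(example_kernel)))  # Sparsity: 62.5%
```

With the pruned model, a call such as `sparsity_of(pruned_model.layers[layer_index].get_weights()[0].ravel().tolist())` would report the sparsity of that layer's first weight array, assuming the chosen layer actually has weights.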

## Step 5 - Exporting to ONNX

Now that the model is pruned, you need to export it to ONNX, which is the format used by the DeepSparse Engine. SparseML supports exporting Keras models to ONNX. In the cell below, a convenience class, `ModelExporter`, is used to handle the export.

Once the model is saved as an ONNX file, it is ready to be used for inference with Neural Magic. When saving a custom model, you can override the sample batch used for ONNX graph freezing and the locations to save to.

In [None]:
from sparseml.keras.utils import ModelExporter

save_dir = "keras_classification"

# Export the finalized pruned model rather than the original model object
exporter = ModelExporter(pruned_model, output_dir=save_dir)
exporter.export_onnx(name="pruned_mnist.onnx")
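
After the export, you may want to confirm that the ONNX file was written. The sketch below assumes the exporter writes the file to `save_dir` under the given name; that location and the helper `describe_file` are assumptions for illustration rather than documented SparseML behavior.

```python
import os

def describe_file(path):
    """Return (exists, size_in_bytes) for a path; size is 0 when the file is missing."""
    if not os.path.isfile(path):
        return False, 0
    return True, os.path.getsize(path)

# Assumed output location: <save_dir>/<name>
onnx_path = os.path.join("keras_classification", "pruned_mnist.onnx")
exists, size = describe_file(onnx_path)
print("{}: exists={}, size={} bytes".format(onnx_path, exists, size))
```

If the file is reported as missing, check the exporter's output directory for the actual location before moving on to inference.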

## Next Steps

Congratulations, you have pruned a model and exported it to ONNX for inference!  Next steps you can pursue include:
* Pruning different models using SparseML
* Trying different pruning and optimization recipes
* Running your model on the DeepSparse Inference Engine