<sub>&copy; 2021-present Neuralmagic, Inc. // [Neural Magic Legal](https://neuralmagic.com/legal)</sub> 

# TensorFlow v1 Classification Model Pruning Using SparseML

This notebook provides a step-by-step walkthrough for pruning an already trained (dense) model to enable better performance at inference time using the DeepSparse Engine. You will:
- Set up the model and dataset
- Define a TensorFlow training flow with a simple SparseML integration
- Prune the model using the TensorFlow+SparseML flow
- Export to [ONNX](https://onnx.ai/)

Reading through this notebook should be reasonably quick and will give you an intuition for how to plug SparseML into your TensorFlow training flow. Rough time estimates for fully pruning the default model are given below. Note that training on a CPU will be much slower than on a GPU:
- 20 minutes on a GPU
- 60 minutes on a laptop CPU

## Step 1 - Requirements
To run this notebook, you will need the following packages already installed:
* SparseML and SparseZoo
* TensorFlow v1 and tf2onnx

You can install any package that is not already present via `pip`.
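The import cell below checks the TensorFlow version with a string comparison (`tf.__version__ < "2"`). That works for 1.x releases but is fragile in general, because strings compare character by character (`"10.0" < "2"` is `True`). A minimal standard-library sketch of a sturdier check follows: it parses the major version as an integer and lists any required packages that cannot be found. The helpers `major_version` and `missing_packages` are hypothetical, and the import names are assumed to match the packages listed above.

```python
import importlib.util
from typing import Iterable, List

# Import names assumed for the packages listed in the requirements above
REQUIRED_MODULES = ["sparseml", "sparsezoo", "tensorflow", "tf2onnx"]


def major_version(version: str) -> int:
    """Return the integer major component of a version string such as '1.15.5'."""
    return int(version.split(".", 1)[0])


def missing_packages(modules: Iterable[str]) -> List[str]:
    """Return the top-level module names that cannot be found on the import path."""
    return [name for name in modules if importlib.util.find_spec(name) is None]


print("missing:", missing_packages(REQUIRED_MODULES) or "none")
# String comparison gets this case wrong; integer comparison does not
print("10.0" < "2", major_version("10.0") < 2)  # True False
```

With this helper, `assert major_version(tf.__version__) < 2` expresses the same TensorFlow v1 requirement without relying on string ordering.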

In [None]:
import sparseml
import sparsezoo
import tensorflow as tf
import tf2onnx

assert tf.__version__ < "2"

# suppress warnings
import warnings

warnings.filterwarnings("ignore", category=FutureWarning)

## Step 2 - Setting Up the Model and Dataset

By default, you will prune a [ResNet-50](https://arxiv.org/abs/1512.03385) model trained on the [Imagenette dataset](https://github.com/fastai/imagenette). The model's pretrained weights are downloaded from the SparseZoo model repo.   The Imagenette dataset is downloaded from its repository via a helper class from SparseML.

The cell below defines functions that load the dataset, the model, and the training objects; they are called during training from within the `Graph` context.

To prune your own model instead, modify the appropriate function to load your model or dataset.
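`load_dataset` computes the number of steps per epoch as `math.ceil(len / float(BATCH_SIZE))`, so a final partial batch still counts as one step. The same arithmetic is sketched below with integer ceiling division, which avoids any float conversion. `steps_per_epoch` and `last_batch_size` are hypothetical helpers, and the sample count used is illustrative, not the actual Imagenette size.

```python
def steps_per_epoch(num_samples: int, batch_size: int) -> int:
    """Number of batches needed to cover num_samples, counting a final partial batch."""
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    # -(-a // b) is integer ceiling division
    return -(-num_samples // batch_size)


def last_batch_size(num_samples: int, batch_size: int) -> int:
    """Size of the final batch (equal to batch_size when it divides evenly)."""
    if num_samples == 0:
        return 0
    remainder = num_samples % batch_size
    return remainder or batch_size


print(steps_per_epoch(1000, 128), last_batch_size(1000, 128))  # 8 104
```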

In [None]:
import math

from sparseml.tensorflow_v1.models import ModelRegistry
from sparseml.tensorflow_v1.datasets import (
    ImagenetteDataset,
    ImagenetteSize,
)
from sparseml.tensorflow_v1.utils import (
    batch_cross_entropy_loss,
    accuracy,
)

MODEL_NAME = "resnet50"
BATCH_SIZE = 128
INPUT_SIZE = 224
NUM_CLASSES = 10
SAVE_DIR = "tensorflow_v1_classification_pruning"


def load_dataset():
    with tf.device("/cpu:0"):
        print("loading datasets")
        train_dataset = ImagenetteDataset(
            train=True, dataset_size=ImagenetteSize.s320, image_size=INPUT_SIZE
        )
        train_len = len(train_dataset)
        train_steps = math.ceil(train_len / float(BATCH_SIZE))
        train_dataset = train_dataset.build(
            BATCH_SIZE,
            shuffle_buffer_size=1000,
            prefetch_buffer_size=BATCH_SIZE,
            num_parallel_calls=4,
        )

        val_dataset = ImagenetteDataset(
            train=False, dataset_size=ImagenetteSize.s320, image_size=INPUT_SIZE
        )
        val_len = len(val_dataset)
        val_steps = math.ceil(val_len / float(BATCH_SIZE))
        val_dataset = val_dataset.build(
            BATCH_SIZE,
            shuffle_buffer_size=1000,
            prefetch_buffer_size=BATCH_SIZE,
            num_parallel_calls=4,
        )

    return train_dataset, val_dataset, (train_steps, val_steps)


def create_model(sample_input, training):
    print("Creating model graph for {}".format(MODEL_NAME))
    logits = ModelRegistry.create(
        MODEL_NAME,
        inputs=sample_input,
        training=training,
        num_classes=NUM_CLASSES,
    )
    return logits


def create_training_objects(labels):
    # NOTE: `logits` is not a parameter; this uses the notebook-level `logits`
    # tensor that Step 4 creates with `create_model` in the same graph
    print("Creating loss, accuracy, and optimizer in graph")
    loss = batch_cross_entropy_loss(logits, labels)
    acc = accuracy(logits, labels)
    global_step = tf.train.get_or_create_global_step()
    train_op = tf.train.AdamOptimizer(learning_rate=0.00008).minimize(
        loss, global_step=global_step
    )
    return loss, acc, global_step, train_op


def load_pretrained():
    print("loading pre-trained model weights")
    ModelRegistry.load_pretrained(
        MODEL_NAME, pretrained="base", remove_dynamic_tl_vars=True,
    )

## Step 3 - Create a SparseML Modifier Manager

To prune a model with SparseML, you will download a recipe from SparseZoo and use it to create a `ScheduledModifierManager` object.  This manager will be used to create operations that modify the training process.

You can create SparseML recipes to perform various model pruning schedules, quantization-aware training, sparse transfer learning, and more. If you are using a different model than the default, you will have to modify the recipe YAML file to target the new model's parameters.

Using the operators generated from this manager object, you will be able to prune your model.
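The manager's `create_ops` call in Step 4 receives the steps per epoch and the `global_step` tensor, which is what allows epoch-based recipe settings to be evaluated at every training step. Gradual magnitude pruning recipes commonly ramp sparsity with a cubic schedule: fast early, then flattening toward the final target. The sketch below computes such a target from a global step. The parameter values are hypothetical, whether the downloaded recipe uses exactly this interpolation is an assumption, and `cubic_sparsity` is not part of SparseML's API.

```python
def cubic_sparsity(
    global_step: int,
    steps_per_epoch: int,
    init_sparsity: float,
    final_sparsity: float,
    start_epoch: float,
    end_epoch: float,
) -> float:
    """Target sparsity at global_step under a cubic gradual-pruning ramp.

    Before start_epoch the target is init_sparsity; after end_epoch it is
    final_sparsity; in between it rises quickly at first and then flattens.
    """
    if end_epoch <= start_epoch:
        raise ValueError("end_epoch must be greater than start_epoch")
    epoch = global_step / steps_per_epoch  # fractional epoch from the step counter
    progress = (epoch - start_epoch) / (end_epoch - start_epoch)
    progress = min(max(progress, 0.0), 1.0)  # clamp to [0, 1]
    return final_sparsity + (init_sparsity - final_sparsity) * (1.0 - progress) ** 3


# Hypothetical recipe values: ramp from 5% to 85% sparsity over epochs 1 to 9
for step in (0, 100, 500, 900, 1000):
    print(step, round(cubic_sparsity(step, 100, 0.05, 0.85, 1.0, 9.0), 4))
```

The cubic shape prunes aggressively while many redundant weights remain and slows down near the target, giving the network more steps to recover accuracy as sparsity approaches its final value.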

In [None]:
from sparseml.tensorflow_v1.optim import (
    ScheduledModifierManager,
)
from sparsezoo import Zoo


def create_sparseml_manager():
    recipe = Zoo.search_recipes(
        domain="cv",
        sub_domain="classification",
        architecture="resnet_v1",
        sub_architecture="50",
        framework="tensorflow_v1",
        repo="sparseml",
        dataset="imagenette",
        sparse_name="pruned",
    )[0]  # unwrap search result
    recipe.download()
    recipe_path = recipe.downloaded_path()
    print(f"Recipe downloaded to: {recipe_path}")

    return ScheduledModifierManager.from_yaml(recipe_path)

## Step 4 - Prune Your Model Using a TensorFlow Training Loop
SparseML can plug directly into your existing TensorFlow training flow by creating additional operators to run. To demonstrate this, in the cell below, prune the model using a standard TensorFlow training loop while also running the operators created by the manager object.  To prune your existing models using SparseML, you can use your own training flow with the additional operators added.
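Conceptually, for magnitude-based recipes, the pruning operators keep a binary mask for each pruned weight tensor and zero out the smallest-magnitude weights until the current target sparsity is reached. A minimal pure-Python sketch of that masking step for a flat list of weights follows. `magnitude_prune` is a hypothetical helper; SparseML's actual ops act on TensorFlow variables and may differ in details such as tie-breaking.

```python
from typing import List, Sequence, Tuple


def magnitude_prune(
    weights: Sequence[float], sparsity: float
) -> Tuple[List[float], List[int]]:
    """Zero the smallest-magnitude weights so that a `sparsity` fraction are zero.

    Returns the pruned weights and the binary mask (1 = kept, 0 = pruned).
    """
    if not 0.0 <= sparsity <= 1.0:
        raise ValueError("sparsity must be in [0, 1]")
    num_pruned = int(round(sparsity * len(weights)))
    # Indices ordered from smallest to largest absolute value (stable for ties)
    order = sorted(range(len(weights)), key=lambda i: abs(weights[i]))
    pruned_idx = set(order[:num_pruned])
    mask = [0 if i in pruned_idx else 1 for i in range(len(weights))]
    pruned = [w if m else 0.0 for w, m in zip(weights, mask)]
    return pruned, mask


weights = [0.5, -0.1, 0.05, -0.8, 0.3, 0.02]
pruned, mask = magnitude_prune(weights, 0.5)
print(pruned)  # [0.5, 0.0, 0.0, -0.8, 0.3, 0.0]
print(mask)    # [1, 0, 0, 1, 1, 0]
```

During training, a mask like this is typically reapplied after each optimizer step, so pruned weights stay at zero while the remaining weights continue to learn.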

For your convenience, the lines needed for integrating with SparseML are preceded by large comment blocks.

If the kernel shuts down during training, the cause may be an out-of-memory error. To resolve it, try lowering the `BATCH_SIZE` in Step 2 above.
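The epoch summaries printed by the loop below are unweighted means of per-batch loss and accuracy (`numpy.mean`). If the final batch is smaller than `BATCH_SIZE`, that mean gives the partial batch the same weight as a full one. When exact per-sample figures matter, a sample-weighted mean is a small change; this sketch assumes you also record each batch's size (a hypothetical `batch_sizes` list that the original loop does not collect), and the numbers are illustrative.

```python
from typing import Sequence


def weighted_mean(values: Sequence[float], weights: Sequence[int]) -> float:
    """Mean of per-batch values, weighting each by its batch size."""
    if len(values) != len(weights):
        raise ValueError("values and weights must have the same length")
    total = sum(weights)
    if total == 0:
        raise ValueError("weights must not sum to zero")
    return sum(v * w for v, w in zip(values, weights)) / total


# Two full batches of 128 and a final partial batch of 16 (illustrative numbers)
batch_acc = [0.75, 0.80, 0.25]
batch_sizes = [128, 128, 16]
print("unweighted:", sum(batch_acc) / len(batch_acc))
print("weighted:  ", weighted_mean(batch_acc, batch_sizes))
```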

In [None]:
import numpy
import os
from tqdm.auto import tqdm
from sparseml.utils import create_unique_dir, create_dirs
from sparseml.tensorflow_v1.datasets import create_split_iterators_handle

with tf.Graph().as_default() as graph:
    # create dataset
    train_dataset, val_dataset, (train_steps, val_steps) = load_dataset()
    handle, iterator, (train_iter, val_iter) = create_split_iterators_handle(
        [train_dataset, val_dataset]
    )
    images, labels = iterator.get_next()

    # create base training objects
    training = tf.placeholder(dtype=tf.bool, shape=[])
    logits = create_model(images, training)
    loss, acc, global_step, train_op = create_training_objects(labels)
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)

    #######################################################
    # create sparseml training manager
    #######################################################
    manager = create_sparseml_manager()
    mod_ops, mod_extras = manager.create_ops(train_steps, global_step, graph=graph)

    with tf.Session() as sess:
        print("initializing session")
        sess.run(
            [
                tf.global_variables_initializer(),
                tf.local_variables_initializer(),
            ]
        )
        train_iter_handle, val_iter_handle = sess.run(
            [train_iter.string_handle(), val_iter.string_handle()]
        )

        # initialize sparseml manager after pretrained weights loaded
        load_pretrained()
        manager.initialize_session()

        num_epochs = manager.max_epochs
        for epoch in tqdm(range(num_epochs), desc="pruning"):
            print("training for epoch {}...".format(epoch))
            sess.run(train_iter.initializer)
            train_losses = []
            train_acc = []

            for step in range(train_steps):
                _, __, meas_step, meas_loss, meas_acc = sess.run(
                    [train_op, update_ops, global_step, loss, acc],
                    feed_dict={handle: train_iter_handle, training: True},
                )
                train_losses.append(meas_loss)
                train_acc.append(meas_acc)

                #######################################################
                # Run the SparseML modifier ops after each training step to apply the pruning schedule
                #######################################################
                sess.run(mod_ops)
            print(
                "completed epoch {} training with: loss {} / acc {}".format(
                    epoch,
                    numpy.mean(train_losses).item(),
                    numpy.mean(train_acc).item() * 100,
                )
            )

            print("validating for epoch {}...".format(epoch))
            sess.run(val_iter.initializer)
            val_losses = []
            val_acc = []

            for step in range(val_steps):
                meas_loss, meas_acc = sess.run(
                    [loss, acc],
                    feed_dict={handle: val_iter_handle, training: False},
                )
                val_losses.append(meas_loss)
                val_acc.append(meas_acc)

            print(
                "completed epoch {} validation with: loss {} / acc {}".format(
                    epoch,
                    numpy.mean(val_losses).item(),
                    numpy.mean(val_acc).item() * 100,
                )
            )

        #######################################################
        # Final line for sparseml training in TensorFlow, complete the graph
        #######################################################
        manager.complete_graph()

        NAME = "resnet50-imagenette-pruned"
        checkpoint_path = create_unique_dir(
            os.path.join(".", SAVE_DIR, NAME, "checkpoint")
        )
        checkpoint_path = os.path.join(checkpoint_path, "model")
        create_dirs(checkpoint_path)
        saver = ModelRegistry.saver(MODEL_NAME)
        saver.save(sess, checkpoint_path)
        print("saved model checkpoint to {}".format(checkpoint_path))

## Step 5 - Exporting to ONNX

Now that the model is pruned, you need to export it to ONNX, which is the format used by the DeepSparse Engine. TensorFlow does not natively support exporting to ONNX, so you will use the `tf2onnx` Python package. In the cell below, a convenience class, `GraphExporter()`, handles exporting by wrapping the somewhat complicated `tf2onnx` API in an easy-to-use interface.

Note that for some configurations, the tf2onnx code does not work properly in a Jupyter Notebook. If that happens, run the `exporter.export_onnx()` call in a Python console or script instead.

Once the model is saved as an ONNX file, it is ready to be used for inference with the DeepSparse Engine.
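Before running inference, a quick sanity check is to confirm the fraction of zero-valued weights in the pruned layers. The sketch below computes per-layer and overall sparsity from plain Python lists. How you obtain the weight values (for example, by evaluating the model's variables in a session, or by reading the initializers of the exported ONNX file) is left as an assumption, and the layer names are hypothetical.

```python
from typing import Dict, List, Tuple


def sparsity_report(layers: Dict[str, List[float]]) -> Tuple[Dict[str, float], float]:
    """Return (per-layer fraction of zeros, overall fraction of zeros)."""
    per_layer = {}
    zeros_total = 0
    count_total = 0
    for name, weights in layers.items():
        zeros = sum(1 for w in weights if w == 0.0)
        per_layer[name] = zeros / len(weights) if weights else 0.0
        zeros_total += zeros
        count_total += len(weights)
    overall = zeros_total / count_total if count_total else 0.0
    return per_layer, overall


# Hypothetical layer names and values
per_layer, overall = sparsity_report(
    {"conv1": [0.0, 0.3, 0.0, -0.2], "fc": [0.0, 0.0, 0.0, 0.5, 0.1, 0.0]}
)
print(per_layer, overall)  # {'conv1': 0.5, 'fc': 0.6666666666666666} 0.6
```

Comparing the per-layer figures against the recipe's targets is a simple way to catch layers that were not matched by the recipe's parameter names.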

In [None]:
from sparseml.utils import clean_path
from sparseml.tensorflow_v1.utils import GraphExporter


export_path = clean_path(os.path.join(".", SAVE_DIR, NAME))
exporter = GraphExporter(export_path)

with tf.Graph().as_default() as graph:
    print("Recreating graph...", flush=True)

    input_placeholder = tf.placeholder(
        tf.float32, [None, INPUT_SIZE, INPUT_SIZE, 3], name="inputs"
    )
    logits = create_model(input_placeholder, training=False)

    input_names = [input_placeholder.name]
    output_names = [logits.name]

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        print("Restoring previous weights...", flush=True)
        saver = ModelRegistry.saver(MODEL_NAME)
        saver.restore(sess, checkpoint_path)

        print("Exporting to pb...", flush=True)
        exporter.export_pb(outputs=[logits])
        print("Exported pb file to {}".format(exporter.pb_path), flush=True)

print("Exporting to onnx...", flush=True)
exporter.export_onnx(inputs=input_names, outputs=output_names)
print("Exported onnx file to {}".format(exporter.onnx_path))

## Next Steps

Congratulations, you have pruned a model and exported it to ONNX for inference!  Next steps you can pursue include:
* Pruning different models using SparseML
* Trying different pruning and optimization recipes
* Running your model on the DeepSparse Engine