<sub>&copy; 2021-present Neuralmagic, Inc. // [Neural Magic Legal](https://neuralmagic.com/legal)</sub> 

# PyTorch Detection Model Pruning using SparseML

This notebook provides a step-by-step walkthrough for pruning an already trained (dense) model to enable better performance at inference time using the [DeepSparse Engine](https://github.com/neuralmagic/deepsparse). You will:
- Set up the model and dataset
- Define a generic PyTorch training flow
- Integrate the PyTorch flow with SparseML
- Prune the model using the PyTorch+SparseML flow
- Export to [ONNX](https://onnx.ai/)

Reading through this notebook should quickly give you an intuition for how to plug SparseML into your PyTorch training flow. Rough time estimates for fully pruning the default model are given below. Note that training on a CPU with PyTorch will be much slower than on a GPU:
- 30 minutes on a GPU
- 90 minutes on a laptop CPU

## Step 1 - Requirements
To run this notebook, you will need the following packages already installed:
* SparseML and SparseZoo
* PyTorch and torchvision

You can install any package that is not already present via `pip`.

In [None]:
import sparseml
import sparsezoo
import torch
import torchvision
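
If one of the imports above fails, a small check like the following can show which packages still need to be installed. This is only a sketch using the standard library; `find_missing_packages` is a hypothetical helper name, and it assumes the `pip` package names match the import names listed in this step.

```python
import importlib.util

# Import names of the packages this notebook relies on.
REQUIRED_PACKAGES = ["sparseml", "sparsezoo", "torch", "torchvision"]


def find_missing_packages(names):
    """Return the subset of `names` that cannot be located for import."""
    return [name for name in names if importlib.util.find_spec(name) is None]


missing = find_missing_packages(REQUIRED_PACKAGES)
if missing:
    print("Missing packages, install with: pip install " + " ".join(missing))
else:
    print("All required packages are available.")
```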

## Step 2 - Setting Up the Model and Dataset

By default, you will prune an [SSD](https://arxiv.org/abs/1512.02325) model trained on the [VOC detection dataset](http://host.robots.ox.ac.uk/pascal/VOC/). The model's pretrained weights are downloaded from the SparseZoo model repo. The VOC detection dataset is downloaded from its repository via a helper class from SparseML.

Note that for this notebook you will use a ResNet18 backbone for the object detector. This saves training time while still demonstrating the general pruning flow. For better accuracy, you can use a different backbone or model.

If you would like to try pruning your own model, modify the appropriate lines for your model and dataset, specifically:
- model = ModelRegistry.create(...)
- train_dataset = VOCDetectionDataset(...)
- val_dataset = VOCDetectionDataset(...)

Take care to keep the variable names the same, as the rest of the notebook depends on them, and update any parts of the training flow as needed.

In [None]:
from sparseml.pytorch.models import ModelRegistry
from sparseml.pytorch.datasets import VOCDetectionDataset
from sparseml.pytorch.utils import get_default_boxes_300

#######################################################
# Define your model below
#######################################################
print("loading model...")
model = ModelRegistry.create(
    key="ssd300_resnet18",
    pretrained=True,
    pretrained_dataset="voc",
    pretrained_backbone=False,  # no need to download initial weights
    num_classes=21,
)
model_name = model.__class__.__name__
input_shape = ModelRegistry.input_shape("ssd300_resnet18")
input_size = input_shape[-1]
print(model)
#######################################################
# Define your train and validation datasets below
#######################################################

print("\nloading train dataset...")
default_boxes = get_default_boxes_300("voc")
train_dataset = VOCDetectionDataset(
    train=True, rand_trans=True, preprocessing_type="ssd", default_boxes=default_boxes
)
print(train_dataset)

print("\nloading val dataset...")
val_dataset = VOCDetectionDataset(
    train=False, preprocessing_type="ssd", default_boxes=default_boxes
)
print(val_dataset)

## Step 3 - Set Up a PyTorch Training Loop
SparseML can plug directly into your existing PyTorch training flow by wrapping the optimizer object. To demonstrate this, in the cell below, we define a simple PyTorch training loop adapted from the [PyTorch CIFAR-10 tutorial](https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html) to work with our detection model and loss function. To prune your existing models using SparseML, you can use your own training flow.

In [None]:
from tqdm.auto import tqdm
import math
import torch

from sparseml.pytorch.utils import DEFAULT_LOSS_KEY


def run_model_one_epoch(
    model, data_loader, criterion, device, train=False, optimizer=None
):
    if train:
        model.train()
    else:
        model.eval()

    running_loss = 0.0

    for step, (inputs, labels) in tqdm(enumerate(data_loader), total=len(data_loader)):
        inputs = inputs.to(device)
        labels = [
            label.to(device) if isinstance(label, torch.Tensor) else label
            for label in labels
        ]

        if train:
            optimizer.zero_grad()

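        # only track gradients while training; this saves memory during validation
        with torch.set_grad_enabled(train):
            outputs = model(inputs)
            loss = criterion((inputs, labels), outputs)[DEFAULT_LOSS_KEY]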

        if train:
            loss.backward()
            optimizer.step()

        running_loss += loss.item()

    loss = running_loss / (step + 1.0)
    return loss
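
To make the idea of hooking into a training loop through the optimizer more concrete, here is a minimal sketch of an optimizer wrapper. `MaskedOptimizer` is a hypothetical class written for illustration and is not SparseML's implementation. It only assumes that the loop calls `zero_grad()` and `step()`, as the function above does, and it re-applies a fixed pruning mask after every step.

```python
import torch
from torch import nn
from torch.optim import SGD


class MaskedOptimizer:
    """Hypothetical wrapper: re-applies fixed pruning masks after every step."""

    def __init__(self, optimizer, param_masks):
        self._optimizer = optimizer
        self._param_masks = param_masks  # list of (parameter, 0/1 mask) pairs

    def zero_grad(self):
        self._optimizer.zero_grad()

    def step(self):
        self._optimizer.step()
        # Zero out pruned weights again so gradient updates cannot revive them.
        with torch.no_grad():
            for param, mask in self._param_masks:
                param.mul_(mask)


torch.manual_seed(0)
demo_layer = nn.Linear(4, 2)
# Illustrative mask: prune the first two input columns of the weight matrix.
demo_mask = torch.ones_like(demo_layer.weight)
demo_mask[:, :2] = 0.0

demo_optimizer = MaskedOptimizer(
    SGD(demo_layer.parameters(), lr=0.1), [(demo_layer.weight, demo_mask)]
)

# The training step itself is unchanged: zero_grad, forward, backward, step.
demo_optimizer.zero_grad()
demo_loss = demo_layer(torch.randn(8, 4)).pow(2).mean()
demo_loss.backward()
demo_optimizer.step()

print(demo_layer.weight)  # the first two columns are zero
```

Because the wrapper exposes the same `zero_grad()` and `step()` methods, a loop such as `run_model_one_epoch` does not need to know that pruning is happening.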

## Step 4 - Set Up Detection Training Objects
In this step, you will select a device to train your model with and set up DataLoader objects, a loss function, and an optimizer. All of these variables and objects can be replaced to fit your training flow. The loss function and collate function are standard for SSD training and are defined in the SparseML API.

In [None]:
from torch.utils.data import DataLoader
from torch.nn import CrossEntropyLoss
from torch.optim import SGD

from sparseml.pytorch.datasets import ssd_collate_fn
from sparseml.pytorch.utils import SSDLossWrapper

# setup device
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
print("Using device: {}".format(device))

# setup data loaders
batch_size = 64
train_loader = DataLoader(
    train_dataset,
    batch_size,
    shuffle=True,
    pin_memory=True,
    num_workers=12,
    collate_fn=ssd_collate_fn,
)
val_loader = DataLoader(
    val_dataset,
    batch_size,
    shuffle=False,
    pin_memory=True,
    num_workers=12,
    collate_fn=ssd_collate_fn,
)

# setup loss function and optimizer
criterion = SSDLossWrapper()
optimizer = SGD(model.parameters(), lr=0.001, momentum=0.9)
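
Detection datasets usually need a custom collate function because each image can carry a different number of boxes, so the targets cannot always be stacked into one tensor. The sketch below illustrates that general idea only. `detection_collate_sketch` and the `(image, boxes)` sample layout are assumptions made for this example and do not reproduce what `ssd_collate_fn` does internally.

```python
import torch


def detection_collate_sketch(batch):
    """Stack fixed-size images; keep variable-length box tensors in a list."""
    images = torch.stack([image for image, _ in batch], dim=0)
    boxes = [target for _, target in batch]
    return images, boxes


# Two fake samples: 3x300x300 images with 2 and 5 boxes (x1, y1, x2, y2).
samples = [
    (torch.zeros(3, 300, 300), torch.zeros(2, 4)),
    (torch.zeros(3, 300, 300), torch.zeros(5, 4)),
]
images, boxes = detection_collate_sketch(samples)
print(images.shape)
print([b.shape[0] for b in boxes])
```

A mixed return value like this is one reason a training loop may need to move labels to the device item by item rather than with a single `.to(device)` call.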

## Step 5 - Apply a SparseML Recipe and Prune Model

To prune a model with SparseML, you will download a recipe from SparseZoo and use it to create a `ScheduledModifierManager` object.  This manager will be used to wrap the optimizer object to gradually prune the model using unstructured weight magnitude pruning after each optimizer step.
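
As a rough illustration of gradual unstructured magnitude pruning, the sketch below raises a target sparsity over time and, at each update, zeroes the weights with the smallest absolute values. The cubic ramp is one common choice of schedule. The function names, the toy weights, and the sparsity values are assumptions made for this example; the recipe you download defines the actual schedule that SparseML applies.

```python
def cubic_sparsity(step, total_steps, init_sparsity=0.05, final_sparsity=0.85):
    """Target sparsity at `step`, rising quickly at first and then leveling off."""
    progress = min(max(step / total_steps, 0.0), 1.0)
    return final_sparsity + (init_sparsity - final_sparsity) * (1.0 - progress) ** 3


def magnitude_prune(weights, sparsity):
    """Zero the `sparsity` fraction of weights with the smallest magnitude."""
    num_to_prune = int(round(sparsity * len(weights)))
    # Indices sorted by absolute value, smallest first.
    order = sorted(range(len(weights)), key=lambda i: abs(weights[i]))
    pruned = set(order[:num_to_prune])
    return [0.0 if i in pruned else w for i, w in enumerate(weights)]


weights = [0.9, -0.05, 0.4, -0.7, 0.01, 0.3, -0.2, 0.6, -0.15, 0.8]
for step in (0, 5, 10):
    target = cubic_sparsity(step, total_steps=10, init_sparsity=0.0, final_sparsity=0.8)
    weights = magnitude_prune(weights, target)
    zeros = sum(1 for w in weights if w == 0.0)
    print(f"step {step}: target sparsity {target:.2f}, zeroed weights {zeros}/10")
```

In a real training run, the remaining weights keep training between pruning updates, which gives the model a chance to recover accuracy as sparsity increases.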

You can create SparseML recipes to perform various model pruning schedules, quantization aware training, sparse transfer learning, and more.  If you are using a different model than the default, you will have to modify the recipe YAML file to target the new model's parameters.

Finally, using the wrapped optimizer object, you will call the training function to prune your model.

If the kernel shuts down during training, the cause may be an out-of-memory error. To resolve this, try lowering the `batch_size` in the cell above.

#### Downloading a Recipe from SparseZoo
The [SparseZoo](https://github.com/neuralmagic/sparsezoo) API provides preconfigured recipes for its optimized models. In the cell below, you will download a recipe for pruning SSD-ResNet18 on the VOC dataset and record its saved path.

In [None]:
from sparsezoo import Zoo

recipe = Zoo.search_recipes(
    domain="cv",
    sub_domain="detection",
    architecture="ssd",
    sub_architecture="resnet18_300",
    framework="pytorch",
    repo="sparseml",
    dataset="voc",
    sparse_name="pruned",
)[0]  # unwrap search result
recipe.download()
recipe_path = recipe.downloaded_path()
print(f"Recipe downloaded to: {recipe_path}")

In [None]:
from sparseml.pytorch.optim import (
    ScheduledModifierManager,
    ScheduledOptimizer,
)

# create ScheduledModifierManager and Optimizer wrapper
manager = ScheduledModifierManager.from_yaml(recipe_path)
optimizer = ScheduledOptimizer(
    optimizer,
    model,
    manager,
    steps_per_epoch=len(train_loader),
    loggers=[],
)

epoch = manager.min_epochs
while epoch < manager.max_epochs:
    # run training loop
    epoch_name = "{}/{}".format(epoch + 1, manager.max_epochs)
    print("Running Training Epoch {}".format(epoch_name))
    train_loss = run_model_one_epoch(
        model, train_loader, criterion, device, train=True, optimizer=optimizer
    )
    print("Training Epoch: {}\nTraining Loss: {}\n".format(epoch_name, train_loss))

    # run validation loop
    print("Running Validation Epoch {}".format(epoch_name))
    val_loss = run_model_one_epoch(model, val_loader, criterion, device)
    print("Validation Epoch: {}\nVal Loss: {}\n".format(epoch_name, val_loss))

    epoch += 1

## Step 6 - View Model Sparsity
To see the effects of the model pruning, in this step, you will print out the sparsities of each Conv and FC layer in your model.

In [None]:
from sparseml.pytorch.utils import get_prunable_layers, tensor_sparsity

# print sparsities of each layer
for (name, layer) in get_prunable_layers(model):
    print("{}.weight: {:.4f}".format(name, tensor_sparsity(layer.weight).item()))
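
For intuition about the numbers printed above, sparsity can be thought of as the fraction of a tensor's elements that are exactly zero. The sketch below computes that fraction by hand for a small example tensor. `fraction_of_zeros` is a hypothetical helper, and the sketch assumes this is the quantity `tensor_sparsity` reports.

```python
import torch


def fraction_of_zeros(tensor):
    """Return the share of elements in `tensor` that are exactly zero."""
    return (tensor == 0).sum().item() / tensor.numel()


# 5 of the 8 entries below are zero.
demo_weight = torch.tensor([[0.0, 0.5, 0.0, -1.2], [0.0, 0.0, 0.3, 0.0]])
print("{:.4f}".format(fraction_of_zeros(demo_weight)))
```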

## Step 7 - Exporting to ONNX

Now that the model is pruned, you need to export it to the ONNX format, which is the format used by the [DeepSparse Engine](https://github.com/neuralmagic/deepsparse). For PyTorch, exporting to ONNX is natively supported. In the cell below, a convenience class, `ModuleExporter`, is used to handle exporting.

Once the model is saved as an ONNX file, it is ready to be used for inference with the DeepSparse Engine. For saving a custom model, you can override the sample batch for ONNX graph freezing and locations to save to.

In [None]:
from sparseml.pytorch.utils import ModuleExporter

save_dir = "pytorch_detection"

exporter = ModuleExporter(model, output_dir=save_dir)
exporter.export_pytorch(name="ssd_resnet18_pruned.pth")
exporter.export_onnx(torch.randn(1, 3, 300, 300), name="ssd_resnet18_pruned.onnx")

## Next Steps

Congratulations, you have pruned a model and exported it to ONNX for inference!  Next steps you can pursue include:
* Pruning different models using SparseML
* Trying different pruning and optimization recipes
* Running your model on the [DeepSparse Engine](https://github.com/neuralmagic/deepsparse)