<sub>&copy; 2021-present Neuralmagic, Inc. // [Neural Magic Legal](https://neuralmagic.com/legal)</sub> 

# PyTorch Classification Model Pruning using SparseML

This notebook provides a step-by-step walkthrough for pruning an already trained (dense) model to enable better performance at inference time using the [DeepSparse Engine](https://github.com/neuralmagic/deepsparse). You will:
- Set up the model and dataset
- Define a generic PyTorch training flow
- Integrate the PyTorch flow with SparseML
- Prune the model using the PyTorch+SparseML flow
- Export to [ONNX](https://onnx.ai/)

Reading through this notebook should be reasonably quick and will give you an intuition for how to plug SparseML into your PyTorch training flow. Rough time estimates for fully pruning the default model are given below. Note that training on a CPU with PyTorch will be much slower than training on a GPU:
- 15 minutes on a GPU
- 45 minutes on a laptop CPU

## Step 1 - Requirements
To run this notebook, you will need the following packages already installed:
* SparseML and SparseZoo
* PyTorch and torchvision

You can install any package that is not already present via `pip`.

In [None]:
import sparseml
import sparsezoo
import torch
import torchvision
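
If one of the imports above fails, it can help to see every missing package at once. The following is a minimal sketch using only the standard library; `find_missing` is a hypothetical helper name, and the sketch assumes each `pip` package name matches its import name.

```python
import importlib.util

# Packages this notebook imports; names taken from the import cell above.
REQUIRED = ["sparseml", "sparsezoo", "torch", "torchvision"]


def find_missing(packages):
    """Return the subset of `packages` that cannot be found on the import path."""
    return [name for name in packages if importlib.util.find_spec(name) is None]


missing = find_missing(REQUIRED)
if missing:
    print("Missing packages, install with: pip install " + " ".join(missing))
else:
    print("All required packages are available.")
```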

## Step 2 - Setting Up the Model and Dataset

By default, you will prune a [ResNet50](https://arxiv.org/abs/1512.03385) model trained on the [Imagenette dataset](https://github.com/fastai/imagenette). The model's pretrained weights are downloaded from the SparseZoo model repo.   The Imagenette dataset is downloaded from its repository via a helper class from SparseML.

If you would like to prune your own model, modify the appropriate lines for your model and dataset, specifically:
- `model = ModelRegistry.create(...)`
- `train_dataset = ImagenetteDataset(...)`
- `val_dataset = ImagenetteDataset(...)`

Keep the variable names the same, since the rest of the notebook depends on them, and update any parts of the training flow as needed.

In [None]:
from sparseml.pytorch.models import ModelRegistry
from sparseml.pytorch.datasets import ImagenetteDataset, ImagenetteSize

#######################################################
# Define your model below
#######################################################
print("loading model...")
model = ModelRegistry.create(
    key="resnet50",
    pretrained=True,
    pretrained_dataset="imagenette",
    num_classes=10,
)
input_shape = ModelRegistry.input_shape("resnet50")
input_size = input_shape[-1]
print(model)
#######################################################
# Define your train and validation datasets below
#######################################################

print("\nloading train dataset...")
train_dataset = ImagenetteDataset(
    train=True, dataset_size=ImagenetteSize.s320, image_size=input_size
)
print(train_dataset)

print("\nloading val dataset...")
val_dataset = ImagenetteDataset(
    train=False, dataset_size=ImagenetteSize.s320, image_size=input_size
)
print(val_dataset)

## Step 3 - Set Up a PyTorch Training Loop
SparseML can plug directly into your existing PyTorch training flow by overriding the Optimizer object. To demonstrate this, the cell below defines a simple PyTorch training loop adapted from the [PyTorch CIFAR-10 tutorial](https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html). To prune your existing models using SparseML, you can use your own training flow instead.

In [None]:
from tqdm.auto import tqdm
import math
import torch


def run_model_one_epoch(model, data_loader, criterion, device, train=False, optimizer=None):
    if train:
        model.train()
    else:
        model.eval()

    running_loss = 0.0
    total_correct = 0
    total_predictions = 0

    for step, (inputs, labels) in tqdm(enumerate(data_loader), total=len(data_loader)):
        inputs = inputs.to(device)
        labels = labels.to(device)

        if train:
            optimizer.zero_grad()

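        # track gradients only when training; this saves memory during validation
        with torch.set_grad_enabled(train):
            outputs, _ = model(inputs)  # model returns logits and softmax as a tuple
            loss = criterion(outputs, labels)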

        if train:
            loss.backward()
            optimizer.step()

        running_loss += loss.item()

        predictions = outputs.argmax(dim=1)
        total_correct += torch.sum(predictions == labels).item()
        total_predictions += inputs.size(0)

    loss = running_loss / (step + 1.0)
    accuracy = total_correct / total_predictions
    return loss, accuracy
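
The training loop above unpacks `outputs, _ = model(inputs)`, so it expects the model to return a tuple. If your own model returns only logits, one option is a thin wrapper. This is a minimal sketch under that assumption; `LogitsAndProbs` and the toy linear model are hypothetical names for illustration and are not part of SparseML.

```python
import torch
from torch import nn


class LogitsAndProbs(nn.Module):
    """Hypothetical wrapper: makes a logits-only classifier return (logits, softmax)."""

    def __init__(self, base_model):
        super().__init__()
        self.base_model = base_model

    def forward(self, inputs):
        logits = self.base_model(inputs)
        return logits, torch.softmax(logits, dim=1)


# Toy stand-in for a custom classifier: 4 input features, 3 classes.
toy_model = LogitsAndProbs(nn.Linear(4, 3))
logits, probs = toy_model(torch.randn(2, 4))
print(logits.shape, probs.sum(dim=1))
```

Note that wrapping a model this way prefixes its parameter names with `base_model.`, which matters if a recipe targets parameters by name. Alternatively, you can change the unpacking line in the loop to `outputs = model(inputs)`.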

## Step 4 - Set Up PyTorch Training Objects
In this step, you will select a device to train your model on and set up DataLoader objects, a loss function, and an optimizer. All of these variables and objects can be replaced to fit your training flow.

In [None]:
from torch.utils.data import DataLoader
from torch.nn import CrossEntropyLoss
from torch.optim import SGD

# setup device
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
print("Using device: {}".format(device))

# setup data loaders
batch_size = 128
train_loader = DataLoader(
    train_dataset, batch_size, shuffle=True, pin_memory=True, num_workers=8
)
val_loader = DataLoader(
    val_dataset, batch_size, shuffle=False, pin_memory=True, num_workers=8
)

# setup loss function and optimizer, LR will be overridden by sparseml
criterion = CrossEntropyLoss()
optimizer = SGD(model.parameters(), lr=0.001, momentum=0.9)
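
The data loaders above request `num_workers=8`, which may be more than the cores available on a laptop. The following is a minimal sketch of one possible heuristic, capping the worker count at the visible CPU count; `pick_num_workers` is a hypothetical helper, not part of PyTorch or SparseML, and the best value depends on your machine.

```python
import os


def pick_num_workers(requested=8):
    """Cap the requested DataLoader worker count at the number of visible CPU cores."""
    available = os.cpu_count() or 1  # os.cpu_count() can return None
    return max(0, min(requested, available))


num_workers = pick_num_workers(8)
print("num_workers:", num_workers)
```

The result can be passed as `num_workers=num_workers` when constructing the `DataLoader` objects.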

## Step 5 - Apply a SparseML Recipe and Prune Model

To prune a model with SparseML, you will download a recipe from SparseZoo and use it to create a `ScheduledModifierManager` object.  This manager will be used to wrap the optimizer object to gradually prune the model using unstructured weight magnitude pruning after each optimizer step.
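
To build intuition for unstructured weight magnitude pruning, the sketch below zeroes the smallest-magnitude entries of a single tensor in one shot. It is an illustration only and not SparseML's implementation; `magnitude_prune` and the example values are hypothetical.

```python
import torch


def magnitude_prune(weight, sparsity):
    """Return a copy of `weight` with the smallest-magnitude fraction set to zero."""
    num_to_prune = int(sparsity * weight.numel())
    if num_to_prune == 0:
        return weight.clone()
    # Threshold is the magnitude of the num_to_prune-th smallest entry.
    threshold = weight.abs().flatten().kthvalue(num_to_prune).values
    # Keep only entries strictly above the threshold; ties at the
    # threshold are pruned too, so sparsity can slightly exceed the target.
    mask = weight.abs() > threshold
    return weight * mask


weights = torch.tensor([[0.1, -0.5, 0.03], [0.8, -0.02, 0.4]])
pruned = magnitude_prune(weights, sparsity=0.5)
print(pruned)
```

In this example, half of the six weights are removed, and the three with the largest magnitudes (`-0.5`, `0.8`, `0.4`) survive.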

You can create SparseML recipes to perform various model pruning schedules, quantization aware training, sparse transfer learning, and more.  If you are using a different model than the default, you will have to modify the recipe YAML file to target the new model's parameters.
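
A gradual pruning schedule raises the target sparsity over the course of training instead of pruning all at once. The sketch below shows one commonly used shape, a cubic ramp from an initial to a final sparsity. The epoch range and sparsity values are hypothetical, and whether the downloaded recipe follows exactly this curve is an assumption.

```python
def cubic_sparsity(epoch, start_epoch, end_epoch, init_sparsity, final_sparsity):
    """Target sparsity at `epoch` under a cubic gradual-pruning schedule."""
    if epoch <= start_epoch:
        return init_sparsity
    if epoch >= end_epoch:
        return final_sparsity
    progress = (epoch - start_epoch) / (end_epoch - start_epoch)
    # Sparsity rises quickly at first, then levels off near the final value.
    return final_sparsity + (init_sparsity - final_sparsity) * (1.0 - progress) ** 3


for epoch in range(0, 11, 2):
    target = cubic_sparsity(
        epoch, start_epoch=0, end_epoch=10, init_sparsity=0.05, final_sparsity=0.85
    )
    print(f"epoch {epoch:2d}: target sparsity {target:.3f}")
```

Pruning aggressively early and slowly late gives the network more training steps to recover at the highest sparsity levels.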

Using the wrapped optimizer object, you will call the training function to prune your model. After training, finalize the model by calling the manager's `finalize(...)` method.

If the kernel shuts down during training, the cause may be an out-of-memory error. To resolve this, try lowering the `batch_size` in the cell above.

#### Downloading a Recipe from SparseZoo
The [SparseZoo](https://github.com/neuralmagic/sparsezoo) API provides preconfigured recipes for its optimized models. In the cell below, you will download a recipe for pruning ResNet50 on the Imagenette dataset and record its saved path.

In [None]:
from sparsezoo import Model, search_models

zoo_model = search_models(
    domain="cv",
    sub_domain="classification",
    architecture="resnet_v1",
    sub_architecture="50",
    framework="pytorch",
    repo="sparseml",
    dataset="imagenette",
    sparse_name="pruned",
)[0]  # unwrap search result
recipe_path = zoo_model.recipes.default.path
print(f"Recipe downloaded to: {recipe_path}")

In [None]:
from sparseml.pytorch.optim import (
    ScheduledModifierManager,
)

# create ScheduledModifierManager and Optimizer wrapper
manager = ScheduledModifierManager.from_yaml(recipe_path)
optimizer = manager.modify(model, optimizer, steps_per_epoch=len(train_loader))



# Run model pruning
epoch = manager.min_epochs
while epoch < manager.max_epochs:
    # run training loop
    epoch_name = "{}/{}".format(epoch + 1, manager.max_epochs)
    print("Running Training Epoch {}".format(epoch_name))
    train_loss, train_acc = run_model_one_epoch(
        model, train_loader, criterion, device, train=True, optimizer=optimizer
    )
    print(
        "Training Epoch: {}\nTraining Loss: {}\nTop 1 Acc: {}\n".format(
            epoch_name, train_loss, train_acc
        )
    )
    
    # run validation loop
    print("Running Validation Epoch {}".format(epoch_name))
    val_loss, val_acc = run_model_one_epoch(
        model, val_loader, criterion, device
    )
    print(
        "Validation Epoch: {}\nVal Loss: {}\nTop 1 Acc: {}\n".format(
            epoch_name, val_loss, val_acc
        )
    )
    
    epoch += 1

manager.finalize(model)

## Step 6 - View Model Sparsity
To see the effects of the model pruning, in this step, you will print out the sparsity of each Conv and FC layer in your model.

In [None]:
from sparseml.pytorch.utils import get_prunable_layers, tensor_sparsity

# print sparsities of each layer
for (name, layer) in get_prunable_layers(model):
    print("{}.weight: {:.4f}".format(name, tensor_sparsity(layer.weight).item()))
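
The sparsity of a tensor is the fraction of its entries that are exactly zero. The sketch below computes that value by hand on a small example tensor; `fraction_zero` is a hypothetical helper for illustration and is not claimed to match the internals of `tensor_sparsity`.

```python
import torch


def fraction_zero(tensor):
    """Fraction of entries in `tensor` that are exactly zero."""
    return (tensor == 0).sum().item() / tensor.numel()


# 5 of the 8 entries are zero, so the sparsity is 5 / 8.
example_weight = torch.tensor([[0.0, 1.5, 0.0, -2.0], [0.0, 0.0, 0.3, 0.0]])
print("sparsity: {:.4f}".format(fraction_zero(example_weight)))
```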

## Step 7 - Exporting to ONNX

Now that the model is fully recalibrated, you need to export it to the ONNX format, which is the format used by the [DeepSparse Engine](https://github.com/neuralmagic/deepsparse). For PyTorch, exporting to ONNX is natively supported. In the cell below, a convenience class, `ModuleExporter`, is used to handle exporting.

Once the model is saved as an ONNX file, it is ready to be used for inference with the DeepSparse Engine. To save a custom model, you can override the sample batch used for ONNX graph freezing and the locations to save to.

In [None]:
from sparseml.pytorch.utils import ModuleExporter

save_dir = "pytorch_classification"

exporter = ModuleExporter(model, output_dir=save_dir)
exporter.export_pytorch(name="resnet50_imagenette_pruned.pth")
exporter.export_onnx(torch.randn(1, 3, 224, 224), name="resnet50_imagenette_pruned.onnx")
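
For a custom model, the sample batch can be built from the model's input shape rather than hard-coded. This is a minimal sketch; the `input_shape` value below is an assumption standing in for the `(channels, height, width)` tuple obtained in Step 2.

```python
import torch

# Assumed stand-in for the input_shape from Step 2: (channels, height, width).
input_shape = (3, 224, 224)

# Prepend a batch dimension of 1 to form the sample batch for export.
sample_batch = torch.randn(1, *input_shape)
print(sample_batch.shape)
```

The resulting `sample_batch` can then be passed as the first argument to `exporter.export_onnx(...)` in place of `torch.randn(1, 3, 224, 224)`.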

## Next Steps

Congratulations, you have pruned a model and exported it to ONNX for inference!  Next steps you can pursue include:
* Pruning different models using SparseML
* Trying different pruning and optimization recipes
* Running your model on the [DeepSparse Engine](https://github.com/neuralmagic/deepsparse)