<sub>&copy; 2021-present Neuralmagic, Inc. // [Neural Magic Legal](https://neuralmagic.com/legal)</sub> 

# Torchvision Classification Model Pruning using SparseML

This notebook provides a step-by-step walkthrough for pruning a [torchvision model](https://pytorch.org/docs/stable/torchvision/models.html) using SparseML. You will:
- Download a pre-trained torchvision model and generic dataset
- Define a generic torchvision finetuning flow
- Integrate the torchvision flow with SparseML
- Prune the model using the torchvision+SparseML flow
- Save the model and export to [ONNX](https://onnx.ai/)

Reading through this notebook is a quick way to gain an intuition for how to integrate SparseML with torchvision or, more generally, with any PyTorch training flow. Rough time estimates for fully pruning the default model are given below; note that training on a CPU with PyTorch is much slower than on a GPU:
- 15 minutes on a GPU
- 45 minutes on a laptop CPU

## Step 1 - Requirements
To run this notebook, you will need the following packages already installed:
* SparseML and SparseZoo
* PyTorch and torchvision

You can install any package that is not already present via `pip`.
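Before running the import cell, a quick check like the following reports which of the required packages are missing. It is a minimal standard-library sketch and assumes the pip package names match the import names, which holds for these four packages.

```python
import importlib.util

# Import names of the packages used in this notebook (they match the pip names)
REQUIRED_PACKAGES = ["sparseml", "sparsezoo", "torch", "torchvision"]


def missing_packages(names):
    """Return the top-level module names that cannot be found on this system."""
    return [name for name in names if importlib.util.find_spec(name) is None]


missing = missing_packages(REQUIRED_PACKAGES)
if missing:
    print("Missing packages, install with: pip install " + " ".join(missing))
else:
    print("All required packages are installed.")
```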

In [None]:
import sparseml
import sparsezoo
import torch
import torchvision

## Step 2 - Setting Up the Model and Dataset

By default, you will prune a [ResNet50](https://arxiv.org/abs/1512.03385) model while finetuning it on the [Imagenette dataset](https://github.com/fastai/imagenette). The model's pretrained weights are downloaded from torchvision. The Imagenette dataset is downloaded from its repository via a helper class from SparseML.

Additionally, you will replace the final fully connected (FC) layer of the ResNet50 model with a new one that has 10 output classes, matching Imagenette, instead of the standard 1000 ImageNet classes.
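Because the replacement layer has a different shape, it starts from a fresh random initialization and is learned during finetuning, while the rest of the network keeps its pretrained weights. The arithmetic below is a small sketch of how much smaller the new head is; it assumes the final layer receives 2048 features, which is what `model.fc.in_features` reports for torchvision's ResNet50.

```python
def linear_param_count(in_features, out_features, bias=True):
    """Trainable parameters in a fully connected layer: weights plus optional bias."""
    return in_features * out_features + (out_features if bias else 0)


RESNET50_FC_IN = 2048  # model.fc.in_features for torchvision's resnet50

imagenet_head = linear_param_count(RESNET50_FC_IN, 1000)  # original 1000-class head
imagenette_head = linear_param_count(RESNET50_FC_IN, 10)  # replacement 10-class head
print(imagenet_head, imagenette_head)  # 2049000 20490
```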

If you would like to try pruning your own model, modify the appropriate lines for your model and dataset, specifically:
- `model = resnet50(pretrained=True)`
- `train_dataset = ImagenetteDataset(...)`
- `val_dataset = ImagenetteDataset(...)`

Keep the variable names the same, since the rest of the notebook relies on them, and update any parts of the training flow as needed.

In [None]:
from torchvision.models import resnet50
from torch.nn import Linear

from sparseml.pytorch.datasets import ImagenetteDataset, ImagenetteSize

#######################################################
# Define your model below
#######################################################
print("loading model...")
model = resnet50(pretrained=True)
print(model)
#######################################################
# Define your train and validation datasets below
#######################################################

print("\nloading train dataset...")
train_dataset = ImagenetteDataset(
    train=True, dataset_size=ImagenetteSize.s320, image_size=224
)
print(train_dataset)

print("\nloading val dataset...")
val_dataset = ImagenetteDataset(
    train=False, dataset_size=ImagenetteSize.s320, image_size=224
)
print(val_dataset)

# Overriding number of classes
NUM_CLASSES = 10  # number of imagenette classes
model.fc = Linear(in_features=model.fc.in_features, out_features=NUM_CLASSES, bias=True)
print(model.fc)

## Step 3 - Set Up a Torchvision Finetuning Loop
SparseML can plug directly into your existing PyTorch training flow by wrapping the optimizer object. To demonstrate this, the cell below defines a simple PyTorch training loop adapted from the [torchvision finetuning example](https://pytorch.org/tutorials/beginner/finetuning_torchvision_models_tutorial.html). To prune your own models with SparseML, you can use your own training flow instead.
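The wrapping pattern itself can be sketched in plain Python. `HookedOptimizer` is a hypothetical class, not SparseML's `ScheduledOptimizer`; it only illustrates the idea that the training loop keeps calling `zero_grad()` and `step()` as before, while the wrapper runs extra logic (such as updating pruning masks) after each step. A dummy optimizer stands in for `torch.optim.SGD` so the sketch runs without PyTorch.

```python
class HookedOptimizer:
    """Hypothetical optimizer wrapper that runs callbacks after every step."""

    def __init__(self, optimizer, post_step_hooks=None):
        self.optimizer = optimizer
        self.post_step_hooks = list(post_step_hooks or [])
        self.step_count = 0

    def zero_grad(self):
        self.optimizer.zero_grad()

    def step(self, closure=None):
        # Take the normal optimizer step first...
        loss = self.optimizer.step() if closure is None else self.optimizer.step(closure)
        self.step_count += 1
        # ...then run the extra logic, e.g. updating and re-applying pruning masks
        for hook in self.post_step_hooks:
            hook(self.step_count)
        return loss

    def __getattr__(self, name):
        # Forward everything else (param_groups, state_dict, ...) to the wrapped optimizer
        return getattr(self.optimizer, name)


class DummyOptimizer:
    """Stand-in for torch.optim.SGD so the sketch runs without PyTorch."""

    def __init__(self):
        self.steps = 0
        self.param_groups = [{"lr": 0.001}]

    def zero_grad(self):
        pass

    def step(self):
        self.steps += 1


calls = []
opt = HookedOptimizer(DummyOptimizer(), post_step_hooks=[calls.append])
for _ in range(3):  # the training loop is unchanged: zero_grad(), backward, step()
    opt.zero_grad()
    opt.step()
print(calls)  # [1, 2, 3]
print(opt.param_groups[0]["lr"])  # 0.001
```

Delegating unknown attributes keeps the wrapper a drop-in replacement, so code that inspects `param_groups` (for example, a learning-rate schedule) still works.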

In [None]:
import time
import torch

def train_model(
    model, dataloaders, criterion, optimizer, device, num_epochs=25, is_inception=False
):
    since = time.time()

    val_acc_history = []

    best_acc = 0.0

    for epoch in range(num_epochs):
        print("Epoch {}/{}".format(epoch, num_epochs - 1))
        print("-" * 10)

        # Each epoch has a training and validation phase
        for phase in ["train", "val"]:
            if phase == "train":
                model.train()  # Set model to training mode
            else:
                model.eval()  # Set model to evaluate mode

            running_loss = 0.0
            running_corrects = 0

            # Iterate over data.
            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward
                # track history if only in train
                with torch.set_grad_enabled(phase == "train"):
                    # Get model outputs and calculate loss
                    # Special case for inception because in training it has an auxiliary output. In train
                    #   mode we calculate the loss by summing the final output and the auxiliary output
                    #   but in testing we only consider the final output.
                    if is_inception and phase == "train":
                        # From https://discuss.pytorch.org/t/how-to-optimize-inception-model-with-auxiliary-classifiers/7958
                        outputs, aux_outputs = model(inputs)
                        loss1 = criterion(outputs, labels)
                        loss2 = criterion(aux_outputs, labels)
                        loss = loss1 + 0.4 * loss2
                    else:
                        outputs = model(inputs)
                        loss = criterion(outputs, labels)

                    _, preds = torch.max(outputs, 1)

                    # backward + optimize only if in training phase
                    if phase == "train":
                        loss.backward()
                        optimizer.step()

                # statistics: the loss is a per-batch mean, so weight it by the batch size
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)

            epoch_loss = running_loss / len(dataloaders[phase].dataset)
            epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset)

            print("{} Loss: {:.4f} Acc: {:.4f}".format(phase, epoch_loss, epoch_acc))

            # record validation accuracy; the best weights are not checkpointed, so the final (pruned) model is returned
            if phase == "val" and epoch_acc > best_acc:
                best_acc = epoch_acc
            if phase == "val":
                val_acc_history.append(epoch_acc)

        print()

    time_elapsed = time.time() - since
    print(
        "Training complete in {:.0f}m {:.0f}s".format(
            time_elapsed // 60, time_elapsed % 60
        )
    )
    print("Best val Acc: {:.4f}".format(best_acc))

    # return the final model and the per-epoch validation accuracies
    return model, val_acc_history

## Step 4 - Set Up PyTorch Training Objects
In this step, you will select a device to train your model on and set up the DataLoader objects, a loss function, and an optimizer. All of these variables and objects can be replaced to fit your training flow.
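In Step 5, `len(train_loader)` is passed to SparseML as `steps_per_epoch`, presumably so the recipe's epoch-based schedule can be mapped onto optimizer steps. For a map-style dataset like the one used here, that length is the number of batches per epoch. The standard-library sketch below computes it; the 1,000-image dataset is hypothetical.

```python
def steps_per_epoch(num_samples, batch_size, drop_last=False):
    """Number of batches a DataLoader yields per epoch, i.e. len(loader)."""
    if drop_last:
        return num_samples // batch_size  # the final partial batch is discarded
    return (num_samples + batch_size - 1) // batch_size  # ceiling division


# Hypothetical dataset of 1,000 images
print(steps_per_epoch(1000, 128))  # 8 (seven full batches plus one of 104)
print(steps_per_epoch(1000, 64))   # 16
```

The loaders in this notebook use the `DataLoader` default `drop_last=False`, so the final partial batch is kept; lowering `batch_size` therefore increases the number of steps per epoch.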

In [None]:
from torch.utils.data import DataLoader
from torch.nn import CrossEntropyLoss
from torch.optim import SGD

# setup device
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
print("Using device: {}".format(device))

# setup data loaders
batch_size = 128
train_loader = DataLoader(
    train_dataset, batch_size, shuffle=True, pin_memory=True, num_workers=8
)
val_loader = DataLoader(
    val_dataset, batch_size, shuffle=False, pin_memory=True, num_workers=8
)
dataloaders = {"train": train_loader, "val": val_loader}

# setup loss function and optimizer; the LR will be overridden by SparseML
criterion = CrossEntropyLoss()
optimizer = SGD(model.parameters(), lr=0.001, momentum=0.9)

## Step 5 - Apply a SparseML Recipe and Prune Model

To prune a model with SparseML, you will download a recipe from SparseZoo and use it to create a `ScheduledModifierManager` object. The manager is then used to wrap the optimizer so that, after each optimizer step, the model is gradually pruned with unstructured weight magnitude pruning.
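Unstructured magnitude pruning zeroes the individual weights with the smallest absolute values, regardless of where they sit in the tensor. The plain-Python sketch below works on a flat list rather than a tensor and only illustrates the idea; SparseML operates on PyTorch tensors, and its exact rounding and tie-breaking may differ. The returned mask is what keeps pruned weights at zero when it is re-applied after later optimizer steps.

```python
def magnitude_prune(weights, sparsity):
    """Zero out the `sparsity` fraction of weights with the smallest magnitudes.

    Returns the pruned weights and a 0/1 mask (0 marks a pruned weight).
    """
    if not 0.0 <= sparsity <= 1.0:
        raise ValueError("sparsity must be between 0 and 1")
    num_to_prune = round(sparsity * len(weights))
    # Indices sorted from smallest to largest absolute value
    order = sorted(range(len(weights)), key=lambda i: abs(weights[i]))
    mask = [1] * len(weights)
    for i in order[:num_to_prune]:
        mask[i] = 0
    pruned = [w if m else 0.0 for w, m in zip(weights, mask)]
    return pruned, mask


weights = [0.1, -0.5, 0.05, 2.0, -0.01]
pruned, mask = magnitude_prune(weights, sparsity=0.4)
print(pruned)  # [0.1, -0.5, 0.0, 2.0, 0.0]
print(mask)    # [1, 1, 0, 1, 0]

# A later optimizer step changes every weight; re-applying the mask keeps pruned ones at zero
updated = [w - 0.02 for w in pruned]
reapplied = [w if m else 0.0 for w, m in zip(updated, mask)]
print(reapplied[2], reapplied[4])  # 0.0 0.0
```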

You can create SparseML recipes to perform various model pruning schedules, quantization aware training, sparse transfer learning, and more.  If you are using a different model than the default, you will have to modify the recipe YAML file to target the new model's parameters.
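A pruning recipe typically describes a schedule, for example ramping sparsity from an initial to a final value between a start and an end epoch. The hypothetical helper below sketches one common choice, the cubic ramp from gradual magnitude pruning (Zhu & Gupta, 2017), which prunes aggressively early and tapers off near the end. The function and parameter names are illustrative only; check the downloaded recipe for the schedule and values it actually uses.

```python
def cubic_sparsity(epoch, start_epoch, end_epoch, init_sparsity, final_sparsity):
    """Target sparsity at `epoch` for a cubic gradual pruning ramp (hypothetical helper)."""
    if epoch < start_epoch:
        return 0.0  # pruning has not started yet
    progress = min((epoch - start_epoch) / (end_epoch - start_epoch), 1.0)
    # Large jumps early on, when many redundant weights remain; small ones near the end
    return final_sparsity + (init_sparsity - final_sparsity) * (1.0 - progress) ** 3


for epoch in [0, 2, 4, 6, 8, 10]:
    target = cubic_sparsity(epoch, start_epoch=0, end_epoch=10, init_sparsity=0.05, final_sparsity=0.85)
    print("epoch {:>2}: target sparsity {:.3f}".format(epoch, target))
```

At each update, a magnitude-pruning step like the one sketched earlier would then be applied with the current target sparsity.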

Finally, using the wrapped optimizer object, you will call the training function to prune your model.

If the kernel shuts down during training, it may have run out of memory. To resolve this, try lowering `batch_size` (and, if needed, `num_workers`) in the Step 4 cell.

#### Downloading a Recipe from SparseZoo
The [SparseZoo](https://github.com/neuralmagic/sparsezoo) API provides preconfigured recipes for its optimized models. In the cell below, you will download a recipe for pruning ResNet50 on the Imagenette dataset and record its saved path.

In [None]:
from sparsezoo import Zoo

recipe = Zoo.search_recipes(
    domain="cv",
    sub_domain="classification",
    architecture="resnet_v1",
    sub_architecture="50",
    framework="pytorch",
    repo="torchvision",
    dataset="imagenette",
    sparse_name="pruned",
)[0]  # unwrap search result
recipe.download()
recipe_path = recipe.downloaded_path()
print(f"Recipe downloaded to: {recipe_path}")

In [None]:
from sparseml.pytorch.optim import (
    ScheduledModifierManager,
    ScheduledOptimizer,
)

# create ScheduledModifierManager and Optimizer wrapper
manager = ScheduledModifierManager.from_yaml(recipe_path)
optimizer = ScheduledOptimizer(
    optimizer, model, manager, steps_per_epoch=len(train_loader), loggers=[],
)

train_model(
    model,
    dataloaders,
    criterion,
    optimizer,
    device,
    num_epochs=manager.max_epochs,
    is_inception=False,
)

## Step 6 - View Model Sparsity
To see the effects of the model pruning, in this step, you will print out the sparsities of each Conv and FC layer in your model.
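`tensor_sparsity` reports the fraction of a tensor's elements that are zero. The plain-Python sketch below computes the same quantity on flat lists, plus a size-weighted overall figure, to show why model-wide sparsity is not simply the average of the per-layer numbers: larger layers contribute more weights. The layer names and values are hypothetical.

```python
def sparsity(values):
    """Fraction of entries that are exactly zero."""
    return sum(1 for v in values if v == 0) / len(values)


def overall_sparsity(layers):
    """Sparsity across all layers, weighting each layer by its number of weights."""
    total = sum(len(v) for v in layers.values())
    zeros = sum(sum(1 for w in v if w == 0) for v in layers.values())
    return zeros / total


layers = {  # hypothetical flattened weights
    "conv1.weight": [0.0, 0.3, 0.0, -0.2],
    "fc.weight": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.1, 0.2],
}
for name, layer_weights in layers.items():
    print("{}: {:.4f}".format(name, sparsity(layer_weights)))
print("overall: {:.4f}".format(overall_sparsity(layers)))
# conv1.weight: 0.5000
# fc.weight: 0.7500
# overall: 0.6667  (the unweighted mean of the two layers would be 0.6250)
```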

In [None]:
from sparseml.pytorch.utils import get_prunable_layers, tensor_sparsity

# print sparsities of each layer
for (name, layer) in get_prunable_layers(model):
    print("{}.weight: {:.4f}".format(name, tensor_sparsity(layer.weight).item()))

## Step 7 - Save Model and Export to ONNX

Now that the model is pruned, you need to export it to ONNX, the format used by the [DeepSparse Engine](https://github.com/neuralmagic/deepsparse). PyTorch supports exporting to ONNX natively; in the cell below, the convenience class `ModuleExporter` saves the PyTorch weights and handles the ONNX export.

Once the model is saved as an ONNX file, it is ready for inference with the DeepSparse Engine. To save a custom model, change the sample batch passed to `export_onnx` (used to trace the ONNX graph) so that it matches your model's input shape, and adjust `save_dir` and the file names as needed.

In [None]:
from sparseml.pytorch.utils import ModuleExporter

save_dir = "torchvision_models"

exporter = ModuleExporter(model, output_dir=save_dir)
exporter.export_pytorch(name="resnet50_imagenette_pruned.pth")
exporter.export_onnx(torch.randn(1, 3, 224, 224), name="resnet50_imagenette_pruned.onnx")

## Next Steps

Congratulations, you have pruned a model and exported it to ONNX for inference!  Next steps you can pursue include:
* Pruning different models using SparseML
* Trying different pruning and optimization recipes
* Running your model on the [DeepSparse Engine](https://github.com/neuralmagic/deepsparse)