# Sparsifying MobileNetv2 from Scratch (Beans)

In this example, we will demonstrate how to sparsify an image classification model from scratch, starting from a dense pretrained checkpoint, using SparseML's PyTorch integration. We fine-tune an ImageNet-pretrained [MobileNetv2](https://pytorch.org/vision/main/models/generated/torchvision.models.mobilenet_v2.html) on the downstream [Beans dataset](https://huggingface.co/datasets/beans) and then prune it with the Global Magnitude Pruning (GMP) algorithm.

## Agenda

There are a few steps:

 1. Set up the dataset
 2. Set up the PyTorch training loop
 3. Train a dense version of MobileNetv2
 4. Run the GMP pruning algorithm on the dense model
 
## Installation

Install SparseML with `pip`:

```
pip install sparseml[torchvision]
```

In [1]:
import torch
import sparseml
import torchvision
from sparseml.pytorch.optim import ScheduledModifierManager
from sparseml.pytorch.utils import TensorBoardLogger, ModuleExporter, get_prunable_layers, tensor_sparsity
from torch.utils.data import DataLoader
from torch.nn import CrossEntropyLoss
from torch.optim import Adam
from torchvision import transforms
from tqdm.auto import tqdm
import math
import datasets

## Step 1: Set Up the Dataset

The Beans dataset is a set of images of diseased and healthy bean leaves. Given a leaf image, the goal is to predict the disease type (Angular Leaf Spot or Bean Rust), if any.

We will use the Hugging Face `datasets` library to download the data and torchvision's `ImageFolder` to load it in the training loop.

[Check out the dataset card](https://huggingface.co/datasets/beans)
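`ImageFolder` infers labels from directory names: each subfolder of the root becomes one class, and classes are indexed in sorted order. The sketch below mimics that rule on a throwaway directory tree. The three folder names are assumed from the Beans dataset card, and `folder_classes` is a hypothetical helper, not part of torchvision. On the real data, `train_dataset.class_to_idx` holds the mapping torchvision actually built.

```python
import os
import tempfile


def folder_classes(root):
    """Hypothetical helper mimicking ImageFolder's labelling: one class per subfolder, sorted by name."""
    classes = sorted(entry.name for entry in os.scandir(root) if entry.is_dir())
    return {name: idx for idx, name in enumerate(classes)}


# Throwaway tree shaped like the Beans "train" split (folder names assumed from the dataset card).
with tempfile.TemporaryDirectory() as root:
    for name in ["healthy", "bean_rust", "angular_leaf_spot"]:
        os.makedirs(os.path.join(root, name))
    class_to_idx = folder_classes(root)

print(class_to_idx)  # {'angular_leaf_spot': 0, 'bean_rust': 1, 'healthy': 2}
```

Three folders yield three labels, which is why `NUM_LABELS = 3` below.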

In [None]:
beans_dataset = datasets.load_dataset("beans")

In [3]:
print(beans_dataset["train"][0]["image_file_path"])
print(beans_dataset["validation"][0]["image_file_path"])

/home/ubuntu/.cache/huggingface/datasets/downloads/extracted/eeb026374cf5ecfd5f40131a3159be9b9055ac21a3da11690e7eb4d117c99eee/train/bean_rust/bean_rust_train.84.jpg
/home/ubuntu/.cache/huggingface/datasets/downloads/extracted/f287261265d2f9a3e8f87a5526a54d1847b17f7c3ec5714e5719432f2b3e4a73/validation/bean_rust/bean_rust_val.36.jpg


In [4]:
import os
# ImageFolder roots sit two levels above each image file: <root>/<class>/<image>.jpg
train_path = os.path.dirname(os.path.dirname(beans_dataset["train"][0]["image_file_path"]))
val_path = os.path.dirname(os.path.dirname(beans_dataset["validation"][0]["image_file_path"]))

In [5]:
NUM_LABELS = 3
BATCH_SIZE = 32

# imagenet transforms
imagenet_transform = transforms.Compose([
   transforms.Resize(size=256, interpolation=transforms.InterpolationMode.BILINEAR, max_size=None, antialias=None),
   transforms.CenterCrop(size=(224, 224)),
   transforms.ToTensor(),
   transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# datasets
train_dataset = torchvision.datasets.ImageFolder(
    root=train_path,
    transform=imagenet_transform
)

val_dataset = torchvision.datasets.ImageFolder(
    root=val_path,
    transform=imagenet_transform
)

# dataloaders
train_loader = DataLoader(train_dataset, BATCH_SIZE, shuffle=True, pin_memory=True, num_workers=16)
val_loader = DataLoader(val_dataset, BATCH_SIZE, shuffle=False, pin_memory=True, num_workers=16)

## Step 2: Set Up the PyTorch Training Loop

We will use the training loop below, which is standard PyTorch code.

In [6]:
device = "cuda" if torch.cuda.is_available() else "cpu"
print(device)

def run_model_one_epoch(model, data_loader, criterion, device, train=False, optimizer=None):
    if train:
        model.train()
    else:
        model.eval()

    running_loss = 0.0
    total_correct = 0
    total_predictions = 0

    # loop through batches
    for step, (inputs, labels) in tqdm(enumerate(data_loader), total=len(data_loader)):
        inputs = inputs.to(device)
        labels = labels.to(device)

        if train:
            optimizer.zero_grad()

        # forward pass; track gradients only when training
        with torch.set_grad_enabled(train):
            outputs = model(inputs)  # model returns logits
            loss = criterion(outputs, labels)

        # run backpropagation and update the weights
        if train:
            loss.backward()
            optimizer.step()

        running_loss += loss.item()

        # run evaluation
        predictions = outputs.argmax(dim=1)
        total_correct += torch.sum(predictions == labels).item()
        total_predictions += inputs.size(0)

    # return loss and evaluation metric
    loss = running_loss / (step + 1.0)
    accuracy = total_correct / total_predictions
    return loss, accuracy

cuda


## Step 3: Train MobileNetv2 on Beans

First, we will train a dense version of MobileNetv2 on the Beans dataset.

In [16]:
# download ImageNet-pretrained weights, then replace the head with a NUM_LABELS-way classifier
model = torchvision.models.mobilenet_v2(weights=torchvision.models.MobileNet_V2_Weights.DEFAULT)
model.classifier[1] = torch.nn.Linear(model.classifier[1].in_features, NUM_LABELS)
model.to(device)

# set up loss function and optimizer
criterion = CrossEntropyLoss()
optimizer = Adam(model.parameters(), lr=8e-3) # lr will be overridden by the SparseML recipe

Next, we will use a SparseML recipe to set the hyperparameters of the training loop. In this case, we will use the following recipe:

```yaml
# Epoch and Learning-Rate variables
num_epochs: 10.0
init_lr: 0.0005

training_modifiers:
  - !EpochRangeModifier
    start_epoch: 0.0
    end_epoch: eval(num_epochs)

  - !LearningRateFunctionModifier
    final_lr: 0.0
    init_lr: eval(init_lr)
    lr_func: cosine
    start_epoch: 0.0
    end_epoch: eval(num_epochs)
```

As you can see, the recipe includes an `!EpochRangeModifier` and a `!LearningRateFunctionModifier`. These modifiers set the number of epochs to train for (10) and the learning rate schedule (cosine decay from 0.0005 to 0). Because the recipe contains no pruning modifiers, the final model will be dense.
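For intuition, here is what `lr_func: cosine` with `init_lr: 0.0005` and `final_lr: 0.0` plausibly computes over the 10 epochs. This is a minimal sketch assuming the standard cosine-annealing formula; `cosine_lr` is a hypothetical helper, and SparseML's own implementation and update granularity (per step vs. per epoch) may differ.

```python
import math


def cosine_lr(epoch, init_lr=0.0005, final_lr=0.0, start_epoch=0.0, end_epoch=10.0):
    """Cosine-annealed learning rate at a (possibly fractional) epoch; assumed standard form."""
    # clamp progress to [0, 1] so epochs outside the window keep the end-point values
    progress = min(max((epoch - start_epoch) / (end_epoch - start_epoch), 0.0), 1.0)
    return final_lr + 0.5 * (init_lr - final_lr) * (1.0 + math.cos(math.pi * progress))


for epoch in [0, 2.5, 5, 7.5, 10]:
    print(f"epoch {epoch:>4}: lr = {cosine_lr(epoch):.6f}")
```

The schedule stays close to `init_lr` early on, passes through half of it at the midpoint, and flattens out toward zero at the end.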

In [17]:
!cat ./recipes/mobilenetv2-beans-dense-recipe.yaml


# Epoch and Learning-Rate variables
num_epochs: 10.0
init_lr: 0.0005

training_modifiers:
  - !EpochRangeModifier
    start_epoch: 0.0
    end_epoch: eval(num_epochs)

  - !LearningRateFunctionModifier
    final_lr: 0.0
    init_lr: eval(init_lr)
    lr_func: cosine
    start_epoch: 0.0
    end_epoch: eval(num_epochs)

In [18]:
dense_recipe_path = "./recipes/mobilenetv2-beans-dense-recipe.yaml"

Next, we use SparseML's `ScheduledModifierManager` to parse and apply the recipe. The `manager.modify` function applies the recipe's modifiers to the `model` and returns a wrapped `optimizer` that carries out the recipe's instructions during training. You can use the `model` and `optimizer` just like standard PyTorch objects.

In [19]:
# create ScheduledModifierManager and Optimizer wrapper
manager = ScheduledModifierManager.from_yaml(dense_recipe_path)
optimizer = manager.modify(model, optimizer, steps_per_epoch=len(train_loader))

Kick off the transfer learning loop. Our run reached ~99% validation accuracy (132 of 133 images) after 10 epochs.

In [20]:
# run transfer learning
epoch = 0
for epoch in range(manager.max_epochs):
    # run training loop
    epoch_name = f"{epoch + 1}/{manager.max_epochs}"
    print(f"Running Training Epoch {epoch_name}")
    train_loss, train_acc = run_model_one_epoch(model, train_loader, criterion, device, train=True, optimizer=optimizer)
    print(f"Training Epoch: {epoch_name}\nTraining Loss: {train_loss}\nTop 1 Acc: {train_acc}\n")

    # run validation loop
    print(f"Running Validation Epoch {epoch_name}")
    val_loss, val_acc = run_model_one_epoch(model, val_loader, criterion, device)
    print(f"Validation Epoch: {epoch_name}\nVal Loss: {val_loss}\nTop 1 Acc: {val_acc}\n")

manager.finalize(model)

Running Training Epoch 1/10
Training Epoch: 1/10
Training Loss: 0.47380749520027277
Top 1 Acc: 0.8152804642166345

Running Validation Epoch 1/10
Validation Epoch: 1/10
Val Loss: 0.12894474416971208
Top 1 Acc: 0.9774436090225563

Running Training Epoch 2/10
Training Epoch: 2/10
Training Loss: 0.09328031384696563
Top 1 Acc: 0.9690522243713733

Running Validation Epoch 2/10
Validation Epoch: 2/10
Val Loss: 0.1064663665369153
Top 1 Acc: 0.9548872180451128

Running Training Epoch 3/10
Training Epoch: 3/10
Training Loss: 0.04011029944839803
Top 1 Acc: 0.9903288201160542

Running Validation Epoch 3/10
Validation Epoch: 3/10
Val Loss: 0.06654092976823449
Top 1 Acc: 0.9699248120300752

Running Training Epoch 4/10
Training Epoch: 4/10
Training Loss: 0.013448385923931544
Top 1 Acc: 0.9990328820116054

Running Validation Epoch 4/10
Validation Epoch: 4/10
Val Loss: 0.04685080051422119
Top 1 Acc: 0.9699248120300752

Running Training Epoch 5/10
Training Epoch: 5/10
Training Loss: 0.011129878745343762
Top 1 Acc: 0.9970986460348162

Running Validation Epoch 5/10
Validation Epoch: 5/10
Val Loss: 0.04289984308416024
Top 1 Acc: 0.9774436090225563

Running Training Epoch 6/10
Training Epoch: 6/10
Training Loss: 0.007474540554650241
Top 1 Acc: 0.9990328820116054

Running Validation Epoch 6/10
Validation Epoch: 6/10
Val Loss: 0.038813505135476586
Top 1 Acc: 0.9924812030075187

Running Training Epoch 7/10
Training Epoch: 7/10
Training Loss: 0.022992545157621586
Top 1 Acc: 0.9961315280464217

Running Validation Epoch 7/10
Validation Epoch: 7/10
Val Loss: 0.038337701372802256
Top 1 Acc: 0.9699248120300752

Running Training Epoch 8/10
Training Epoch: 8/10
Training Loss: 0.01563219685693074
Top 1 Acc: 0.9941972920696325

Running Validation Epoch 8/10
Validation Epoch: 8/10
Val Loss: 0.052491285931319
Top 1 Acc: 0.9774436090225563

Running Training Epoch 9/10
Training Epoch: 9/10
Training Loss: 0.004495378677154694
Top 1 Acc: 1.0

Running Validation Epoch 9/10
Validation Epoch: 9/10
Val Loss: 0.03483513812534511
Top 1 Acc: 0.9924812030075187

Running Training Epoch 10/10
Training Epoch: 10/10
Training Loss: 0.003678643000911865
Top 1 Acc: 1.0

Running Validation Epoch 10/10
Validation Epoch: 10/10
Val Loss: 0.04160818513482809
Top 1 Acc: 0.9924812030075187



In [None]:
from sparseml.pytorch.utils import ModuleExporter

save_dir = "dense_model"
exporter = ModuleExporter(model, output_dir=save_dir)
exporter.export_pytorch(name="mobilenet-v2-dense-beans.pth")
exporter.export_onnx(torch.randn(1, 3, 224, 224), name="dense-model.onnx", convert_qat=True)

## Step 4: Prune The Model

With a model trained on Beans, we are now ready to apply the GMP algorithm to prune it. GMP is an iterative pruning algorithm: at each pruning step (once per epoch in the recipe below), it identifies the lowest-magnitude weights (those closest to 0) and removes them from the network, raising sparsity step by step from an initial level to a final level. The remaining nonzero weights are fine-tuned on the training dataset throughout, so the network can recover from each round of pruning.
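The prune-then-fine-tune cycle can be sketched in plain PyTorch on a toy model. This is a minimal illustration under stated assumptions (a tiny MLP on random data, a linear sparsity ramp instead of the recipe's cubic one, 10 SGD steps standing in for an epoch), not SparseML's implementation; `global_masks` and `apply_masks` are hypothetical helpers.

```python
import torch
import torch.nn as nn

# Toy setup (assumption): a tiny MLP trained on random data.
torch.manual_seed(0)
model = nn.Sequential(nn.Linear(20, 64), nn.ReLU(), nn.Linear(64, 3))
weights = [model[0].weight, model[2].weight]  # parameters selected for pruning
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
x, y = torch.randn(256, 20), torch.randint(0, 3, (256,))

init_sparsity, final_sparsity = 0.05, 0.90
pruning_epochs, total_epochs = 5, 7  # prune for 5 "epochs", then fine-tune for 2 more
masks = [torch.ones_like(w) for w in weights]


def global_masks(sparsity):
    """Masks that keep only the largest-magnitude weights, ranked across all tensors."""
    mags = torch.cat([w.detach().abs().flatten() for w in weights])
    k = round(sparsity * mags.numel())  # number of weights to remove in total
    threshold = torch.kthvalue(mags, k).values
    return [(w.detach().abs() > threshold).float() for w in weights]


def apply_masks():
    """Zero out every pruned weight in place."""
    with torch.no_grad():
        for w, m in zip(weights, masks):
            w.mul_(m)


for epoch in range(total_epochs):
    if epoch < pruning_epochs:
        # raise the target linearly from init_sparsity to final_sparsity, then prune
        frac = epoch / (pruning_epochs - 1)
        masks = global_masks(init_sparsity + frac * (final_sparsity - init_sparsity))
        apply_masks()
    for _ in range(10):  # fine-tune the surviving weights
        optimizer.zero_grad()
        nn.functional.cross_entropy(model(x), y).backward()
        optimizer.step()
        apply_masks()  # pruned weights stay at zero

zeros = sum((w == 0).sum().item() for w in weights)
total = sum(w.numel() for w in weights)
print(f"final sparsity: {zeros / total:.3f}")
```

Pruning gradually, with training in between, gives the remaining weights a chance to compensate before the next cut, which is the main idea behind GMP.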

In [22]:
checkpoint = torch.load("./dense_model/training/mobilenet-v2-dense-beans.pth")
model = torchvision.models.mobilenet_v2()
model.classifier[1] = torch.nn.Linear(model.classifier[1].in_features, NUM_LABELS)
model.load_state_dict(checkpoint['state_dict'])
model.to(device)

# set up loss function and optimizer; the LR will be overridden by SparseML
criterion = CrossEntropyLoss()
optimizer = Adam(model.parameters(), lr=8e-3)

Next, we need to create a SparseML recipe that includes the GMP algorithm. The `!GlobalMagnitudePruningModifier` instructs SparseML to apply GMP at a global level: weights are ranked by magnitude across all selected layers together, rather than within each layer separately.
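To make "global" concrete: a single magnitude threshold is computed over the pooled weights of every selected layer, so overall sparsity hits the target while individual layers can land far above or below it. The sketch below assumes this simple form, using a hypothetical `global_magnitude_masks` helper on two made-up tensors; SparseML's modifier has more machinery, but this effect is likely why the per-layer sparsities printed later in this notebook vary so much.

```python
import torch


def global_magnitude_masks(named_weights, sparsity):
    """Hypothetical helper: 0/1 masks that zero the smallest |w| across ALL tensors combined."""
    magnitudes = torch.cat([w.abs().flatten() for _, w in named_weights])
    k = round(sparsity * magnitudes.numel())  # total number of weights to remove
    threshold = torch.kthvalue(magnitudes, k).values  # k-th smallest magnitude overall
    return {name: (w.abs() > threshold).float() for name, w in named_weights}


# Two made-up layers whose weights live on different scales.
torch.manual_seed(0)
named_weights = [
    ("small_layer", torch.randn(16, 8)),           # 128 weights, std 1.0
    ("large_layer", torch.randn(256, 128) * 0.1),  # 32,768 weights, std 0.1
]
masks = global_magnitude_masks(named_weights, sparsity=0.9)

for name, mask in masks.items():
    print(f"{name}: {1 - mask.mean().item():.3f} sparse")
zeros = sum((m == 0).sum().item() for m in masks.values())
total = sum(m.numel() for m in masks.values())
print(f"overall: {zeros / total:.3f} sparse")
```

The small layer with larger weights keeps most of them, while the large layer absorbs nearly all of the pruning. A layer-wise scheme would instead force every layer to exactly 90%.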

First, we need to identify which parameters of the model to apply the GMP algorithm to. We can use the `get_prunable_layers` function to inspect the candidates:

In [23]:
# print parameters
for (name, layer) in get_prunable_layers(model):
    print(f"{name}")

features.0.0
features.1.conv.0.0
features.1.conv.1
features.2.conv.0.0
features.2.conv.1.0
features.2.conv.2
features.3.conv.0.0
features.3.conv.1.0
features.3.conv.2
features.4.conv.0.0
features.4.conv.1.0
features.4.conv.2
features.5.conv.0.0
features.5.conv.1.0
features.5.conv.2
features.6.conv.0.0
features.6.conv.1.0
features.6.conv.2
features.7.conv.0.0
features.7.conv.1.0
features.7.conv.2
features.8.conv.0.0
features.8.conv.1.0
features.8.conv.2
features.9.conv.0.0
features.9.conv.1.0
features.9.conv.2
features.10.conv.0.0
features.10.conv.1.0
features.10.conv.2
features.11.conv.0.0
features.11.conv.1.0
features.11.conv.2
features.12.conv.0.0
features.12.conv.1.0
features.12.conv.2
features.13.conv.0.0
features.13.conv.1.0
features.13.conv.2
features.14.conv.0.0
features.14.conv.1.0
features.14.conv.2
features.15.conv.0.0
features.15.conv.1.0
features.15.conv.2
features.16.conv.0.0
features.16.conv.1.0
features.16.conv.2
features.17.conv.0.0
features.17.conv.1.0
features.17.conv

We will prune each of the convolutions and exclude the classifier layer (the final projection head). SparseML accepts regular expressions (prefixed with `re:`) to identify parameters in the network, so we can use the following list to select the relevant layers for pruning:
    
    - 'features.0.0.weight'
    - 'features.18.0.weight'
    - 're:features.*.conv.*.weight'
    - 're:features.*.conv.*.*.weight'
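To sanity-check which parameter names these patterns select, the sketch below assumes SparseML's convention that entries prefixed with `re:` are Python regular expressions tried with `re.match`, while other entries must equal the parameter name exactly. `matches` is a hypothetical helper, and the names come from the listing above (with `.weight` appended) plus the classifier head.

```python
import re

patterns = [
    "features.0.0.weight",
    "features.18.0.weight",
    "re:features.*.conv.*.weight",
    "re:features.*.conv.*.*.weight",
]


def matches(param_name, patterns):
    """Hypothetical matcher: 're:' entries are regexes (re.match), others are exact names."""
    for pattern in patterns:
        if pattern.startswith("re:"):
            if re.match(pattern[3:], param_name):
                return True
        elif pattern == param_name:
            return True
    return False


names = [
    "features.0.0.weight",
    "features.1.conv.1.weight",
    "features.2.conv.1.0.weight",
    "classifier.1.weight",
]
for name in names:
    print(f"{name}: {'prune' if matches(name, patterns) else 'skip'}")
```

It also shows why `features.0.0.weight` needs its own exact entry: the name has no `conv` segment, so neither regex selects it. Note that `.` in these patterns matches any character, not only a literal dot.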

Here is what the recipe looks like:

```yaml
# Epoch and Learning-Rate variables
num_epochs: 13.0
pruning_epochs: 10.0
init_lr: 0.0005
inter_func: cubic
mask_type: unstructured

training_modifiers:
  - !EpochRangeModifier
    start_epoch: 0.0
    end_epoch: eval(num_epochs)

  - !LearningRateFunctionModifier
    final_lr: 0.0
    init_lr: eval(init_lr)
    lr_func: cosine
    start_epoch: 0.0
    end_epoch: eval(num_epochs)

# Pruning
pruning_modifiers:
  - !GlobalMagnitudePruningModifier
    init_sparsity: 0.05
    final_sparsity: 0.90
    start_epoch: 0.0
    end_epoch: eval(pruning_epochs)
    update_frequency: 1.0
    params: 
        - 'features.0.0.weight'
        - 'features.18.0.weight'
        - 're:features.*.conv.*.weight'
        - 're:features.*.conv.*.*.weight'
    leave_enabled: True
    inter_func: eval(inter_func)
    mask_type: eval(mask_type)
```

This recipe runs the GMP algorithm for the first 10 epochs. Sparsity starts at `init_sparsity` (5%) and gradually increases to `final_sparsity` (90%) following a cubic curve. The pruning is unstructured, meaning that any individual weight can be pruned.

Over the final 3 epochs, we fine-tune the 90%-pruned model further. Because we set `leave_enabled: True`, the pruning masks stay in place after pruning ends, so the sparsity level is maintained during fine-tuning.
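To see what "following a cubic curve" can mean for this recipe, the sketch below prints a target sparsity per epoch. It assumes the common cubic ramp s(t) = s_f + (s_i - s_f)(1 - t)^3, where t is the fraction of the pruning window elapsed; `cubic_sparsity` is a hypothetical helper, and SparseML's exact interpolation may differ. Under this form, most of the sparsity is added early: by epoch 5 the target is already about 79%.

```python
def cubic_sparsity(epoch, init_sparsity=0.05, final_sparsity=0.90, start_epoch=0.0, end_epoch=10.0):
    """Target sparsity at `epoch` under an assumed cubic ramp (hypothetical helper)."""
    # clamp progress to [0, 1] so sparsity holds at final_sparsity after end_epoch
    progress = min(max((epoch - start_epoch) / (end_epoch - start_epoch), 0.0), 1.0)
    # 1 - (1 - progress)^3 rises steeply at first and flattens out near end_epoch
    return init_sparsity + (final_sparsity - init_sparsity) * (1.0 - (1.0 - progress) ** 3)


for epoch in range(11):
    print(f"epoch {epoch:2d}: target sparsity {cubic_sparsity(epoch):.3f}")
```

Front-loading the pruning leaves the later pruning epochs, plus the 3 fine-tuning epochs, for the network to recover at nearly constant sparsity.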

In [None]:
!cat ./recipes/mobilenetv2-beans-pruning-recipe.yaml

In [28]:
pruning_recipe_path = "./recipes/mobilenetv2-beans-pruning-recipe.yaml"

In [32]:
# create ScheduledModifierManager and Optimizer wrapper
manager = ScheduledModifierManager.from_yaml(pruning_recipe_path)
logger = TensorBoardLogger(log_path="./tensorboard_outputs")
optimizer = manager.modify(model, optimizer, loggers=[logger], steps_per_epoch=len(train_loader))

Next, kick off the GMP training loop. 

As you can see, we use the wrapped `optimizer` and `model` in the same way as above. SparseML parsed the recipe and updated the `optimizer` with the GMP algorithm's logic. This lets you use the `optimizer` and `model` as usual, with all of the pruning-related logic specified by the declarative recipe.

Our 90%-pruned model reaches ~99% validation accuracy (132 of 133 images after the final epoch), the same as the dense model.
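Conceptually, the wrapped optimizer has to do one extra thing on top of a normal update: keep pruned weights pruned. The sketch below is a hypothetical `MaskedOptimizer` showing that idea in plain PyTorch. It is not SparseML's implementation, which also updates masks on a schedule and applies the learning-rate modifiers; it simply re-applies fixed 0/1 masks after every `step()`, the behaviour `leave_enabled: True` asks for during the final fine-tuning epochs.

```python
import torch
import torch.nn as nn


class MaskedOptimizer:
    """Hypothetical wrapper (not SparseML's API) that keeps pruned weights at zero."""

    def __init__(self, optimizer, param_masks):
        self.optimizer = optimizer
        self.param_masks = param_masks  # list of (parameter, 0/1 mask) pairs

    def zero_grad(self):
        self.optimizer.zero_grad()

    def step(self):
        self.optimizer.step()
        # the inner step can move pruned weights off zero; re-apply the masks
        with torch.no_grad():
            for param, mask in self.param_masks:
                param.mul_(mask)


# Usage: prune ~90% of a toy layer at random, then train it through the wrapper.
torch.manual_seed(0)
layer = nn.Linear(10, 2)
mask = (torch.rand_like(layer.weight) > 0.9).float()
with torch.no_grad():
    layer.weight.mul_(mask)
optimizer = MaskedOptimizer(torch.optim.Adam(layer.parameters(), lr=1e-2), [(layer.weight, mask)])

x, y = torch.randn(32, 10), torch.randint(0, 2, (32,))
for _ in range(5):
    optimizer.zero_grad()
    nn.functional.cross_entropy(layer(x), y).backward()
    optimizer.step()
```

Re-masking after the step is needed because the gradients of pruned weights are generally nonzero, so Adam would otherwise move them away from zero.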

In [33]:
# run GMP algorithm
epoch = 0
for epoch in range(manager.max_epochs):
    # run training loop
    epoch_name = f"{epoch + 1}/{manager.max_epochs}"
    print(f"Running Training Epoch {epoch_name}")
    train_loss, train_acc = run_model_one_epoch(model, train_loader, criterion, device, train=True, optimizer=optimizer)
    print(f"Training Epoch: {epoch_name}\nTraining Loss: {train_loss}\nTop 1 Acc: {train_acc}\n")

    # run validation loop
    print(f"Running Validation Epoch {epoch_name}")
    val_loss, val_acc = run_model_one_epoch(model, val_loader, criterion, device)
    print(f"Validation Epoch: {epoch_name}\nVal Loss: {val_loss}\nTop 1 Acc: {val_acc}\n")
    
    logger.log_scalar("Metrics/Loss (Train)", train_loss, epoch)
    logger.log_scalar("Metrics/Accuracy (Train)", train_acc, epoch)
    logger.log_scalar("Metrics/Loss (Validation)", val_loss, epoch)
    logger.log_scalar("Metrics/Accuracy (Validation)", val_acc, epoch)
    
manager.finalize(model)

Running Training Epoch 1/13
Training Epoch: 1/13
Training Loss: 0.030314430911940606
Top 1 Acc: 0.9893617021276596

Running Validation Epoch 1/13
Validation Epoch: 1/13
Val Loss: 0.10016061989590526
Top 1 Acc: 0.9699248120300752

Running Training Epoch 2/13
Training Epoch: 2/13
Training Loss: 0.03320332106721418
Top 1 Acc: 0.9903288201160542

Running Validation Epoch 2/13
Validation Epoch: 2/13
Val Loss: 0.05205255227629095
Top 1 Acc: 0.9774436090225563

Running Training Epoch 3/13
Training Epoch: 3/13
Training Loss: 0.02863485404558367
Top 1 Acc: 0.9922630560928434

Running Validation Epoch 3/13
Validation Epoch: 3/13
Val Loss: 0.04446154618635774
Top 1 Acc: 0.9849624060150376

Running Training Epoch 4/13
Training Epoch: 4/13
Training Loss: 0.028403193393552847
Top 1 Acc: 0.9932301740812379

Running Validation Epoch 4/13
Validation Epoch: 4/13
Val Loss: 0.021256184950470925
Top 1 Acc: 0.9924812030075187

Running Training Epoch 5/13
Training Epoch: 5/13
Training Loss: 0.10820412065720919
Top 1 Acc: 0.9671179883945842

Running Validation Epoch 5/13
Validation Epoch: 5/13
Val Loss: 0.07088811788707972
Top 1 Acc: 0.9774436090225563

Running Training Epoch 6/13
Training Epoch: 6/13
Training Loss: 0.21380004890714632
Top 1 Acc: 0.9216634429400387

Running Validation Epoch 6/13
Validation Epoch: 6/13
Val Loss: 0.07981351003982126
Top 1 Acc: 0.9774436090225563

Running Training Epoch 7/13
Training Epoch: 7/13
Training Loss: 0.19059311598539352
Top 1 Acc: 0.9332688588007737

Running Validation Epoch 7/13
Validation Epoch: 7/13
Val Loss: 0.07885787561535836
Top 1 Acc: 0.9924812030075187

Running Training Epoch 8/13
Training Epoch: 8/13
Training Loss: 0.14277072397596907
Top 1 Acc: 0.965183752417795

Running Validation Epoch 8/13
Validation Epoch: 8/13
Val Loss: 0.08072680607438087
Top 1 Acc: 0.9924812030075187

Running Training Epoch 9/13
Training Epoch: 9/13
Training Loss: 0.11926381854396878
Top 1 Acc: 0.9680851063829787

Running Validation Epoch 9/13
Validation Epoch: 9/13
Val Loss: 0.08937856331467628
Top 1 Acc: 0.9849624060150376

Running Training Epoch 10/13
Training Epoch: 10/13
Training Loss: 0.08645325391130014
Top 1 Acc: 0.9854932301740812

Running Validation Epoch 10/13
Validation Epoch: 10/13
Val Loss: 0.06663795448839664
Top 1 Acc: 0.9924812030075187

Running Training Epoch 11/13
Training Epoch: 11/13
Training Loss: 0.07962510655775215
Top 1 Acc: 0.9816247582205029

Running Validation Epoch 11/13
Validation Epoch: 11/13
Val Loss: 0.0655035775154829
Top 1 Acc: 0.9924812030075187

Running Training Epoch 12/13
Training Epoch: 12/13
Training Loss: 0.08157054532432195
Top 1 Acc: 0.9825918762088974

Running Validation Epoch 12/13
Validation Epoch: 12/13
Val Loss: 0.061416388303041455
Top 1 Acc: 0.9924812030075187

Running Training Epoch 13/13
Training Epoch: 13/13
Training Loss: 0.06805486869857166
Top 1 Acc: 0.9874274661508704

Running Validation Epoch 13/13
Validation Epoch: 13/13
Val Loss: 0.06589164286851883
Top 1 Acc: 0.9924812030075187



Here is a sample of the TensorBoard output, showing the validation accuracy, a particular layer's sparsity level, and the learning rate over time.

![tensorboard output](./images/mobilenetv2-beans-tensorboard-output.png)

We can print layer-by-layer sparsity as well.

In [35]:
for (name, layer) in get_prunable_layers(model):
    print(f"{name}.weight: {tensor_sparsity(layer.weight).item():.4f}")

features.0.0.weight: 0.4861
features.1.conv.0.0.weight: 0.4688
features.1.conv.1.weight: 0.5156
features.2.conv.0.0.weight: 0.5547
features.2.conv.1.0.weight: 0.2905
features.2.conv.2.weight: 0.4865
features.3.conv.0.0.weight: 0.6102
features.3.conv.1.0.weight: 0.4961
features.3.conv.2.weight: 0.6111
features.4.conv.0.0.weight: 0.5229
features.4.conv.1.0.weight: 0.2731
features.4.conv.2.weight: 0.5187
features.5.conv.0.0.weight: 0.6901
features.5.conv.1.0.weight: 0.5880
features.5.conv.2.weight: 0.7126
features.6.conv.0.0.weight: 0.6934
features.6.conv.1.0.weight: 0.6111
features.6.conv.2.weight: 0.7402
features.7.conv.0.0.weight: 0.5495
features.7.conv.1.0.weight: 0.3634
features.7.conv.2.weight: 0.6127
features.8.conv.0.0.weight: 0.8370
features.8.conv.1.0.weight: 0.7405
features.8.conv.2.weight: 0.8673
features.9.conv.0.0.weight: 0.8272
features.9.conv.1.0.weight: 0.7384
features.9.conv.2.weight: 0.8634
features.10.conv.0.0.weight: 0.8271
features.10.conv.1.0.weight: 0.7870
features
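The per-layer numbers above vary widely, and several early layers sit well below 90%. Because the pruning is global, a single overall sparsity figure has to be weighted by each layer's parameter count. The hypothetical helper below computes that weighted figure. In this notebook you could call it as `overall_sparsity(get_prunable_layers(model))`, and it should land near the recipe's `final_sparsity` if the recipe's `params` cover the same layers (an assumption worth checking).

```python
import torch
import torch.nn as nn


def overall_sparsity(named_layers):
    """Fraction of zero weights across (name, layer) pairs, weighted by parameter count."""
    zeros = sum((layer.weight == 0).sum().item() for _, layer in named_layers)
    total = sum(layer.weight.numel() for _, layer in named_layers)
    return zeros / total


# Toy check: a small layer with low sparsity and a large layer with high sparsity.
small = nn.Conv2d(3, 4, kernel_size=3)  # 4 * 3 * 3 * 3 = 108 weights
large = nn.Linear(100, 100)             # 10,000 weights
with torch.no_grad():
    small.weight[:, :, 0, 0] = 0.0  # 12 zeros, about 11% sparse
    large.weight[:90] = 0.0         # 9,000 zeros, 90% sparse
print(f"{overall_sparsity([('small', small), ('large', large)]):.4f}")
```

A simple average of the two per-layer values would be about 51%, while the parameter-weighted figure stays close to 90% because the large layer dominates the count.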

Finally, export your model to PyTorch and ONNX formats.

In [37]:
save_dir = "experiment-0"
exporter = ModuleExporter(model, output_dir=save_dir)
exporter.export_pytorch(name="mobilenet_v2-sparse-beans.pth")
exporter.export_onnx(torch.randn(1, 3, 224, 224), name="sparse-model.onnx", convert_qat=True)

## Wrap Up

The resulting model is 90% sparse and achieves ~99% validation accuracy, the same as the unoptimized dense model, without much hyperparameter search.

Key hyperparameter experiments you may want to run include:
- Learning rate
- Learning rate schedule
- Sparsity level
- Number of pruning epochs

DeepSparse can accelerate inference using both pruning and quantization. For maximum performance, check out our examples of quantizing a model!