# Sparsifying MobileNetv2 from Scratch (Beans)

In this example, we will demonstrate how to sparsify an image classification model from scratch using SparseML's PyTorch integration. We train and prune [MobileNetv2](https://pytorch.org/vision/main/models/generated/torchvision.models.mobilenet_v2.html) on the downstream [Beans dataset](https://huggingface.co/datasets/beans) using the Global Magnitude Pruning (GMP) algorithm.

## Agenda

This example has four steps:

 1. Set up the dataset
 2. Set up the PyTorch training loop
 3. Train a dense version of MobileNetv2
 4. Run the GMP pruning algorithm on the dense model
 
## Installation

Install SparseML with `pip`:

```
pip install "sparseml[torchvision]"
```

In [1]:
import torch
import sparseml
import torchvision
from sparseml.pytorch.optim import ScheduledModifierManager
from sparseml.pytorch.utils import TensorBoardLogger, ModuleExporter, get_prunable_layers, tensor_sparsity
from torch.utils.data import DataLoader
from torch.nn import CrossEntropyLoss
from torch.optim import Adam
from torchvision import transforms
from tqdm.auto import tqdm
import math
import datasets

## Step 1: Set Up the Dataset

The Beans dataset is a collection of images of healthy and diseased bean leaves. Given a leaf image, the goal is to predict the disease type (Angular Leaf Spot or Bean Rust), if any, which gives three classes in total.

We will use the Hugging Face `datasets` library to download the data and torchvision's `ImageFolder` to load it in the training loop.

[Check out the dataset card](https://huggingface.co/datasets/beans)

In [2]:
beans_dataset = datasets.load_dataset("beans")

Found cached dataset beans (/home/ubuntu/.cache/huggingface/datasets/beans/default/0.0.0/90c755fb6db1c0ccdad02e897a37969dbf070bed3755d4391e269ff70642d791)


  0%|          | 0/3 [00:00<?, ?it/s]

In [3]:
print(beans_dataset["train"][0]["image_file_path"])
print(beans_dataset["validation"][0]["image_file_path"])

/home/ubuntu/.cache/huggingface/datasets/downloads/extracted/7ad2d437b751e134577576a32849c44a9ade89297680ad5f6a64051e2108810b/train/angular_leaf_spot/angular_leaf_spot_train.0.jpg
/home/ubuntu/.cache/huggingface/datasets/downloads/extracted/1141f86479bc0bb56c75616d153591cc8299d1ea4edc53bb1ab65edd2c65b240/validation/angular_leaf_spot/angular_leaf_spot_val.0.jpg


In [4]:
import os

# derive the ImageFolder roots from the downloaded files instead of hardcoding cache hashes
train_path = os.path.dirname(os.path.dirname(beans_dataset["train"][0]["image_file_path"]))
val_path = os.path.dirname(os.path.dirname(beans_dataset["validation"][0]["image_file_path"]))

In [5]:
NUM_LABELS = 3
BATCH_SIZE = 32

# imagenet transforms
imagenet_transform = transforms.Compose([
   transforms.Resize(size=256, interpolation=transforms.InterpolationMode.BILINEAR, max_size=None, antialias=None),
   transforms.CenterCrop(size=(224, 224)),
   transforms.ToTensor(),
   transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# datasets
train_dataset = torchvision.datasets.ImageFolder(
    root=train_path,
    transform=imagenet_transform
)

val_dataset = torchvision.datasets.ImageFolder(
    root=val_path,
    transform=imagenet_transform
)

# dataloaders
train_loader = DataLoader(train_dataset, BATCH_SIZE, shuffle=True, pin_memory=True, num_workers=16)
val_loader = DataLoader(val_dataset, BATCH_SIZE, shuffle=False, pin_memory=True, num_workers=16)

## Step 2: Set Up the PyTorch Training Loop

We will use the training loop below for both training and evaluation. It is standard PyTorch; nothing in it is specific to SparseML.

In [6]:
device = "cuda" if torch.cuda.is_available() else "cpu"
print(device)

def run_model_one_epoch(model, data_loader, criterion, device, train=False, optimizer=None):
    if train:
        model.train()
    else:
        model.eval()

    running_loss = 0.0
    total_correct = 0
    total_predictions = 0

    # loop through batches
    for step, (inputs, labels) in tqdm(enumerate(data_loader), total=len(data_loader)):
        inputs = inputs.to(device)
        labels = labels.to(device)

        if train:
            optimizer.zero_grad()

        # compute loss, run backpropagation (gradients are tracked only when training)
        with torch.set_grad_enabled(train):
            outputs = model(inputs)  # model returns logits
            loss = criterion(outputs, labels)
        if train:
            loss.backward()
            optimizer.step()

        running_loss += loss.item()

        # run evaluation
        predictions = outputs.argmax(dim=1)
        total_correct += torch.sum(predictions == labels).item()
        total_predictions += inputs.size(0)

    # return loss and evaluation metric
    loss = running_loss / (step + 1.0)
    accuracy = total_correct / total_predictions
    return loss, accuracy

cuda


## Step 3: Train MobileNetv2 on Beans

First, we will train a dense version of MobileNetv2 on the Beans dataset.

In [7]:
# download pre-trained model, setup classification head
model = torchvision.models.mobilenet_v2(weights=torchvision.models.MobileNet_V2_Weights.DEFAULT)
model.classifier[1] = torch.nn.Linear(model.classifier[1].in_features, NUM_LABELS)
model.to(device)

# setup loss function and optimizer
criterion = CrossEntropyLoss()
optimizer = Adam(model.parameters(), lr=8e-3)  # lr will be overridden by SparseML

Next, we will use a SparseML recipe to set the hyperparameters of the training loop. In this case, we will use the following recipe:

```yaml
# Epoch and Learning-Rate variables
num_epochs: 10.0
init_lr: 0.0005

training_modifiers:
  - !EpochRangeModifier
    start_epoch: 0.0
    end_epoch: eval(num_epochs)

  - !LearningRateFunctionModifier
    final_lr: 0.0
    init_lr: eval(init_lr)
    lr_func: cosine
    start_epoch: 0.0
    end_epoch: eval(num_epochs)
```

The recipe contains an `!EpochRangeModifier`, which sets the number of training epochs, and a `!LearningRateFunctionModifier`, which decays the learning rate from `init_lr` to `final_lr` along a cosine curve. Because the recipe contains no pruning modifiers, the final model will be dense.
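To make the learning rate schedule concrete, here is a minimal sketch of cosine annealing with the recipe's values. It assumes `lr_func: cosine` follows the standard half-cosine formula; the helper name `cosine_lr` is ours for illustration and is not part of SparseML, whose exact per-step behavior may differ.

```python
import math

def cosine_lr(epoch, init_lr=0.0005, final_lr=0.0, start_epoch=0.0, end_epoch=10.0):
    """Standard cosine annealing from init_lr to final_lr (illustrative, not a SparseML API)."""
    # clamp progress to [0, 1] so epochs outside the range hold the boundary values
    progress = min(max((epoch - start_epoch) / (end_epoch - start_epoch), 0.0), 1.0)
    return final_lr + 0.5 * (init_lr - final_lr) * (1.0 + math.cos(math.pi * progress))

# print the learning rate every two epochs
for epoch in range(0, 11, 2):
    print(f"epoch {epoch:2d}: lr = {cosine_lr(epoch):.6f}")
```

The rate starts at `init_lr`, passes through the midpoint between `init_lr` and `final_lr` halfway through training, and reaches `final_lr` at `end_epoch`.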

In [8]:
!cat ./recipes/mobilenetv2-beans-dense-recipe.yaml


# Epoch and Learning-Rate variables
num_epochs: 10.0
init_lr: 0.0005

training_modifiers:
  - !EpochRangeModifier
    start_epoch: 0.0
    end_epoch: eval(num_epochs)

  - !LearningRateFunctionModifier
    final_lr: 0.0
    init_lr: eval(init_lr)
    lr_func: cosine
    start_epoch: 0.0
    end_epoch: eval(num_epochs)

In [9]:
dense_recipe_path = "./recipes/mobilenetv2-beans-dense-recipe.yaml"

Next, we use SparseML's `ScheduledModifierManager` to parse and apply the recipe. The `manager.modify` function modifies and wraps the `model` and `optimizer` with the instructions from the recipe. You can use the `model` and `optimizer` just like standard PyTorch objects.

In [10]:
# create ScheduledModifierManager and Optimizer wrapper
manager = ScheduledModifierManager.from_yaml(dense_recipe_path)
optimizer = manager.modify(model, optimizer, steps_per_epoch=len(train_loader))
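The `steps_per_epoch` argument matters because the recipe expresses its schedule in (possibly fractional) epochs, while the optimizer only sees individual steps. The sketch below shows that bookkeeping under the assumption that the current epoch is derived as completed steps divided by `steps_per_epoch`. The helper `fractional_epoch` is hypothetical and is not SparseML code.

```python
def fractional_epoch(step, steps_per_epoch):
    """Map a global optimizer step count to a fractional epoch (hypothetical helper)."""
    return step / steps_per_epoch

steps_per_epoch = 33  # len(train_loader) in the run shown in this notebook

print(fractional_epoch(0, steps_per_epoch))    # start of training
print(fractional_epoch(33, steps_per_epoch))   # one full epoch completed
print(fractional_epoch(330, steps_per_epoch))  # end of the 10-epoch dense recipe
```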

Kick off the transfer learning loop. Our run peaked at ~99% validation accuracy and finished at ~98.5% after 10 epochs.

In [11]:
# run transfer learning
epoch = 0
for epoch in range(manager.max_epochs):
    # run training loop
    epoch_name = f"{epoch + 1}/{manager.max_epochs}"
    print(f"Running Training Epoch {epoch_name}")
    train_loss, train_acc = run_model_one_epoch(model, train_loader, criterion, device, train=True, optimizer=optimizer)
    print(f"Training Epoch: {epoch_name}\nTraining Loss: {train_loss}\nTop 1 Acc: {train_acc}\n")

    # run validation loop
    print(f"Running Validation Epoch {epoch_name}")
    val_loss, val_acc = run_model_one_epoch(model, val_loader, criterion, device)
    print(f"Validation Epoch: {epoch_name}\nVal Loss: {val_loss}\nTop 1 Acc: {val_acc}\n")

manager.finalize(model)

Running Training Epoch 1/10


  0%|          | 0/33 [00:00<?, ?it/s]

Training Epoch: 1/10
Training Loss: 0.5145365252639308
Top 1 Acc: 0.8268858800773694

Running Validation Epoch 1/10


  0%|          | 0/5 [00:00<?, ?it/s]

Validation Epoch: 1/10
Val Loss: 0.13084974102675914
Top 1 Acc: 0.9398496240601504

Running Training Epoch 2/10


  0%|          | 0/33 [00:00<?, ?it/s]

Training Epoch: 2/10
Training Loss: 0.1433423740619963
Top 1 Acc: 0.9497098646034816

Running Validation Epoch 2/10


  0%|          | 0/5 [00:00<?, ?it/s]

Validation Epoch: 2/10
Val Loss: 0.06366700865328312
Top 1 Acc: 0.9774436090225563

Running Training Epoch 3/10


  0%|          | 0/33 [00:00<?, ?it/s]

Training Epoch: 3/10
Training Loss: 0.04070456024033554
Top 1 Acc: 0.9922630560928434

Running Validation Epoch 3/10


  0%|          | 0/5 [00:00<?, ?it/s]

Validation Epoch: 3/10
Val Loss: 0.06802571751177311
Top 1 Acc: 0.9774436090225563

Running Training Epoch 4/10


  0%|          | 0/33 [00:00<?, ?it/s]

Training Epoch: 4/10
Training Loss: 0.03408790949844953
Top 1 Acc: 0.9903288201160542

Running Validation Epoch 4/10


  0%|          | 0/5 [00:00<?, ?it/s]

Validation Epoch: 4/10
Val Loss: 0.0302915308624506
Top 1 Acc: 0.9924812030075187

Running Training Epoch 5/10


  0%|          | 0/33 [00:00<?, ?it/s]

Training Epoch: 5/10
Training Loss: 0.014188278910957954
Top 1 Acc: 0.9980657640232108

Running Validation Epoch 5/10


  0%|          | 0/5 [00:00<?, ?it/s]

Validation Epoch: 5/10
Val Loss: 0.036852234788239
Top 1 Acc: 0.9774436090225563

Running Training Epoch 6/10


  0%|          | 0/33 [00:00<?, ?it/s]

Training Epoch: 6/10
Training Loss: 0.008212485654053815
Top 1 Acc: 0.9990328820116054

Running Validation Epoch 6/10


  0%|          | 0/5 [00:00<?, ?it/s]

Validation Epoch: 6/10
Val Loss: 0.023899264284409583
Top 1 Acc: 0.9924812030075187

Running Training Epoch 7/10


  0%|          | 0/33 [00:00<?, ?it/s]

Training Epoch: 7/10
Training Loss: 0.003027547925188573
Top 1 Acc: 1.0

Running Validation Epoch 7/10


  0%|          | 0/5 [00:00<?, ?it/s]

Validation Epoch: 7/10
Val Loss: 0.02565102996304631
Top 1 Acc: 0.9924812030075187

Running Training Epoch 8/10


  0%|          | 0/33 [00:00<?, ?it/s]

Training Epoch: 8/10
Training Loss: 0.003363958947981397
Top 1 Acc: 1.0

Running Validation Epoch 8/10


  0%|          | 0/5 [00:00<?, ?it/s]

Validation Epoch: 8/10
Val Loss: 0.0232504392741248
Top 1 Acc: 0.9849624060150376

Running Training Epoch 9/10


  0%|          | 0/33 [00:00<?, ?it/s]

Training Epoch: 9/10
Training Loss: 0.004655612831949397
Top 1 Acc: 1.0

Running Validation Epoch 9/10


  0%|          | 0/5 [00:00<?, ?it/s]

Validation Epoch: 9/10
Val Loss: 0.027894998691044746
Top 1 Acc: 0.9849624060150376

Running Training Epoch 10/10


  0%|          | 0/33 [00:00<?, ?it/s]

Training Epoch: 10/10
Training Loss: 0.01938260115937076
Top 1 Acc: 0.9990328820116054

Running Validation Epoch 10/10


  0%|          | 0/5 [00:00<?, ?it/s]

Validation Epoch: 10/10
Val Loss: 0.027051440824288876
Top 1 Acc: 0.9849624060150376



In [12]:
from sparseml.pytorch.utils import ModuleExporter

save_dir = "dense_model"
exporter = ModuleExporter(model, output_dir=save_dir)
exporter.export_pytorch(name="mobilenet-v2-dense-beans.pth")

## Step 4: Prune The Model

With a model trained on Beans, we are now ready to prune it with the GMP algorithm. GMP is an iterative pruning algorithm: at regular intervals during training, we identify the lowest-magnitude weights (those closest to 0) and remove them from the network, raising sparsity step by step from an initial level to a final level. The remaining nonzero weights are fine-tuned on the training dataset between pruning steps.
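A single pruning step can be sketched in plain Python. This is a toy illustration of global magnitude pruning on flat lists of weights, not SparseML's implementation; the function names and the two tiny "layers" are made up for the example.

```python
def global_magnitude_prune(layers, sparsity):
    """Zero the `sparsity` fraction of weights with the smallest magnitude across all layers.

    `layers` maps a layer name to a flat list of weights (toy stand-in for tensors).
    """
    magnitudes = sorted(abs(w) for weights in layers.values() for w in weights)
    num_to_prune = int(sparsity * len(magnitudes))
    if num_to_prune == 0:
        return {name: list(weights) for name, weights in layers.items()}
    # one threshold is shared by every layer; ties at the threshold are pruned too
    threshold = magnitudes[num_to_prune - 1]
    return {
        name: [0.0 if abs(w) <= threshold else w for w in weights]
        for name, weights in layers.items()
    }

def layer_sparsity(weights):
    """Fraction of weights that are exactly zero."""
    return sum(1 for w in weights if w == 0.0) / len(weights)

layers = {
    "conv_a": [0.9, -0.8, 0.4, 0.01],
    "conv_b": [-0.02, 0.03, 0.05, -0.6],
}
pruned = global_magnitude_prune(layers, sparsity=0.5)
for name, weights in pruned.items():
    print(name, weights, f"sparsity={layer_sparsity(weights):.2f}")
```

Overall sparsity is 50%, but `conv_a` ends up 25% sparse and `conv_b` 75% sparse. Because the threshold is global rather than per layer, layers whose weights are small relative to the rest of the network are pruned more heavily.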

In [13]:
checkpoint = torch.load("./dense_model/training/mobilenet-v2-dense-beans.pth")
model = torchvision.models.mobilenet_v2()
model.classifier[1] = torch.nn.Linear(model.classifier[1].in_features, NUM_LABELS)
model.load_state_dict(checkpoint['state_dict'])
model.to(device)

# set up loss function and optimizer; LR will be overridden by SparseML
criterion = CrossEntropyLoss()
optimizer = Adam(model.parameters(), lr=8e-3)

Next, we need to create a SparseML recipe which includes the GMP algorithm. The `!GlobalMagnitudePruningModifier` modifier instructs SparseML to apply the GMP algorithm at a global level (pruning the lowest magnitude weights across all layers).

First, we need to identify which parameters of the model to apply the GMP algorithm to. We can use the `get_prunable_layers` function to inspect the candidates:

In [14]:
# print parameters
for (name, layer) in get_prunable_layers(model):
    print(f"{name}")

features.0.0
features.1.conv.0.0
features.1.conv.1
features.2.conv.0.0
features.2.conv.1.0
features.2.conv.2
features.3.conv.0.0
features.3.conv.1.0
features.3.conv.2
features.4.conv.0.0
features.4.conv.1.0
features.4.conv.2
features.5.conv.0.0
features.5.conv.1.0
features.5.conv.2
features.6.conv.0.0
features.6.conv.1.0
features.6.conv.2
features.7.conv.0.0
features.7.conv.1.0
features.7.conv.2
features.8.conv.0.0
features.8.conv.1.0
features.8.conv.2
features.9.conv.0.0
features.9.conv.1.0
features.9.conv.2
features.10.conv.0.0
features.10.conv.1.0
features.10.conv.2
features.11.conv.0.0
features.11.conv.1.0
features.11.conv.2
features.12.conv.0.0
features.12.conv.1.0
features.12.conv.2
features.13.conv.0.0
features.13.conv.1.0
features.13.conv.2
features.14.conv.0.0
features.14.conv.1.0
features.14.conv.2
features.15.conv.0.0
features.15.conv.1.0
features.15.conv.2
features.16.conv.0.0
features.16.conv.1.0
features.16.conv.2
features.17.conv.0.0
features.17.conv.1.0
features.17.conv

We will apply pruning to each of the convolutional layers and exclude the classifier layer (the final projection head). SparseML accepts regexes, marked with an `re:` prefix, to identify parameters in the network, so we can use the following list to select the relevant weights for pruning:
    
    - 'features.0.0.weight'
    - 'features.18.0.weight'
    - 're:features.*.conv.*.weight'
    - 're:features.*.conv.*.*.weight'
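To check which names a list like this selects, the matching can be approximated with Python's `re` module. The sketch assumes that entries prefixed with `re:` are treated as regular expressions anchored at the start of the name and that all other entries must match exactly. The helper `matches_any` is hypothetical, and SparseML's own matching rules may differ in detail.

```python
import re

def matches_any(param_name, patterns):
    """Return True if param_name equals a literal entry or matches an 're:' regex (assumed semantics)."""
    for pattern in patterns:
        if pattern.startswith("re:"):
            # assumption: the regex is anchored at the start of the name, as with re.match
            if re.match(pattern[len("re:"):], param_name):
                return True
        elif pattern == param_name:
            return True
    return False

patterns = [
    "features.0.0.weight",
    "features.18.0.weight",
    "re:features.*.conv.*.weight",
    "re:features.*.conv.*.*.weight",
]
candidates = [
    "features.0.0.weight",
    "features.1.conv.0.0.weight",
    "features.1.conv.1.weight",
    "features.18.0.weight",
    "classifier.1.weight",
]
selected = [name for name in candidates if matches_any(name, patterns)]
print(selected)
```

Under these assumptions, every `features...` weight in the candidate list is selected and `classifier.1.weight` is left out. Note that `.` is a regex wildcard, so in this sketch the first regex already matches both the two-level and three-level `conv` names.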

Here is what the recipe looks like:

```yaml
# Epoch hyperparams
stabilization_epochs: 1.0
pruning_epochs: 9.0
finetuning_epochs: 5.0

# Learning rate hyperparams
init_lr: 0.0005
final_lr: 0.00025

# Pruning hyperparams
init_sparsity: 0.05
final_sparsity: 0.9

# Stabilization Stage
training_modifiers:
  - !EpochRangeModifier
    start_epoch: 0.0
    end_epoch: eval(stabilization_epochs + pruning_epochs + finetuning_epochs)
  
  - !SetLearningRateModifier
    start_epoch: 0.0
    learning_rate: eval(init_lr)

# Pruning Stage
pruning_modifiers:
  - !LearningRateFunctionModifier
    init_lr: eval(init_lr)
    final_lr: eval(final_lr)
    lr_func: cosine
    start_epoch: eval(stabilization_epochs)
    end_epoch: eval(stabilization_epochs + pruning_epochs)
    
  - !GlobalMagnitudePruningModifier
    init_sparsity: eval(init_sparsity)
    final_sparsity: eval(final_sparsity)
    start_epoch: eval(stabilization_epochs)
    end_epoch: eval(stabilization_epochs + pruning_epochs)
    update_frequency: 0.5
    params:        
        - 'features.0.0.weight'
        - 'features.18.0.weight'
        - 're:features.*.conv.*.weight'
        - 're:features.*.conv.*.*.weight'
    leave_enabled: True

# Finetuning Stage
finetuning_modifiers:
  - !LearningRateFunctionModifier
    init_lr: eval(init_lr)
    final_lr: eval(final_lr)
    lr_func: cosine
    start_epoch: eval(stabilization_epochs + pruning_epochs)
    end_epoch: eval(stabilization_epochs + pruning_epochs + finetuning_epochs)
```

This recipe runs the GMP algorithm for 9 epochs after 1 stabilization epoch. Pruning starts at an `init_sparsity` of 5% and gradually increases to a `final_sparsity` of 90% following a cubic curve, with the sparsity level updated every half epoch (`update_frequency: 0.5`). The pruning is unstructured, meaning that any individual weight can be pruned.

Over the final 5 epochs, we fine-tune the 90% pruned model further. Since we set `leave_enabled: True`, the sparsity level is maintained while fine-tuning occurs.
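The sparsity targets implied by these settings can be sketched as follows. This assumes the cubic ramp commonly used for gradual magnitude pruning, in which the remaining gap to the final sparsity shrinks with the cube of the remaining progress. The helper `cubic_sparsity` is illustrative and is not taken from SparseML, whose interpolation may differ in detail.

```python
def cubic_sparsity(epoch, init_sparsity=0.05, final_sparsity=0.9,
                   start_epoch=1.0, end_epoch=10.0):
    """Target sparsity at a (fractional) epoch under a cubic ramp.

    Illustrative helper, not SparseML code. Defaults mirror the recipe above.
    """
    if epoch < start_epoch:
        return 0.0  # assumption: nothing is pruned during the stabilization epoch
    # progress through the pruning window, capped at 1 so the final level is held afterwards
    progress = min((epoch - start_epoch) / (end_epoch - start_epoch), 1.0)
    return final_sparsity + (init_sparsity - final_sparsity) * (1.0 - progress) ** 3

update_frequency = 0.5  # epochs between sparsity updates, as in the recipe
num_updates = int((10.0 - 1.0) / update_frequency) + 1
schedule = [1.0 + i * update_frequency for i in range(num_updates)]
for epoch in schedule:
    print(f"epoch {epoch:4.1f}: target sparsity = {cubic_sparsity(epoch):.3f}")
```

Under this formula the ramp is steep at first and flattens toward the end, so most of the weights are removed early and the last few updates change the sparsity only slightly.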

In [15]:
!cat ./recipes/mobilenetv2-beans-pruning-recipe.yaml

# Epoch hyperparams
stabilization_epochs: 1.0
pruning_epochs: 9.0
finetuning_epochs: 5.0

# Learning rate hyperparams
init_lr: 0.0005
final_lr: 0.00025

# Pruning hyperparams
init_sparsity: 0.05
final_sparsity: 0.9

# Stabilization Stage
training_modifiers:
  - !EpochRangeModifier
    start_epoch: 0.0
    end_epoch: eval(stabilization_epochs + pruning_epochs + finetuning_epochs)
  
  - !SetLearningRateModifier
    start_epoch: 0.0
    learning_rate: eval(init_lr)

# Pruning Stage
pruning_modifiers:
  - !LearningRateFunctionModifier
    init_lr: eval(init_lr)
    final_lr: eval(final_lr)
    lr_func: cosine
    start_epoch: eval(stabilization_epochs)
    end_epoch: eval(stabilization_epochs + pruning_epochs)
    
  - !GlobalMagnitudePruningModifier
    init_sparsity: eval(init_sparsity)
    final_sparsity: eval(final_sparsity)
    start_epoch: eval(stabilization_epochs)
    end_epoch: eval(stabilization_epochs + pruning_epochs)
    update_frequency: 

In [16]:
pruning_recipe_path = "./recipes/mobilenetv2-beans-pruning-recipe.yaml"

In [17]:
# create ScheduledModifierManager and Optimizer wrapper
manager = ScheduledModifierManager.from_yaml(pruning_recipe_path)
logger = TensorBoardLogger(log_path="./tensorboard_outputs")
optimizer = manager.modify(model, optimizer, loggers=[logger], steps_per_epoch=len(train_loader))

Next, kick off the GMP training loop. 

As you can see, we use the wrapped `optimizer` and `model` in the same way as above. SparseML parsed the recipe and updated the `optimizer` with the logic of the GMP algorithm. This allows you to use the `optimizer` and `model` as usual, with all of the pruning-related logic specified by the declarative recipe interface.

In our run, the 90% pruned model peaked at ~99% validation accuracy and finished at ~97.7% (vs. ~98.5% for the dense model at the end of its training).

In [18]:
# run GMP algorithm
epoch = 0
for epoch in range(manager.max_epochs):
    # run training loop
    epoch_name = f"{epoch + 1}/{manager.max_epochs}"
    print(f"Running Training Epoch {epoch_name}")
    train_loss, train_acc = run_model_one_epoch(model, train_loader, criterion, device, train=True, optimizer=optimizer)
    print(f"Training Epoch: {epoch_name}\nTraining Loss: {train_loss}\nTop 1 Acc: {train_acc}\n")

    # run validation loop
    print(f"Running Validation Epoch {epoch_name}")
    val_loss, val_acc = run_model_one_epoch(model, val_loader, criterion, device)
    print(f"Validation Epoch: {epoch_name}\nVal Loss: {val_loss}\nTop 1 Acc: {val_acc}\n")
    
    logger.log_scalar("Metrics/Loss (Train)", train_loss, epoch)
    logger.log_scalar("Metrics/Accuracy (Train)", train_acc, epoch)
    logger.log_scalar("Metrics/Loss (Validation)", val_loss, epoch)
    logger.log_scalar("Metrics/Accuracy (Validation)", val_acc, epoch)
    
manager.finalize(model)

Running Training Epoch 1/15


  0%|          | 0/33 [00:00<?, ?it/s]

Training Epoch: 1/15
Training Loss: 0.04981424145058334
Top 1 Acc: 0.988394584139265

Running Validation Epoch 1/15


  0%|          | 0/5 [00:00<?, ?it/s]

Validation Epoch: 1/15
Val Loss: 0.06650052629411221
Top 1 Acc: 0.9774436090225563

Running Training Epoch 2/15


  0%|          | 0/33 [00:00<?, ?it/s]

Training Epoch: 2/15
Training Loss: 0.08335415836226082
Top 1 Acc: 0.9777562862669246

Running Validation Epoch 2/15


  0%|          | 0/5 [00:00<?, ?it/s]

Validation Epoch: 2/15
Val Loss: 0.06197752933949232
Top 1 Acc: 0.9774436090225563

Running Training Epoch 3/15


  0%|          | 0/33 [00:00<?, ?it/s]

Training Epoch: 3/15
Training Loss: 0.040097551427386476
Top 1 Acc: 0.9854932301740812

Running Validation Epoch 3/15


  0%|          | 0/5 [00:00<?, ?it/s]

Validation Epoch: 3/15
Val Loss: 0.05052837722469121
Top 1 Acc: 0.9849624060150376

Running Training Epoch 4/15


  0%|          | 0/33 [00:00<?, ?it/s]

Training Epoch: 4/15
Training Loss: 0.03676384532203277
Top 1 Acc: 0.9893617021276596

Running Validation Epoch 4/15


  0%|          | 0/5 [00:00<?, ?it/s]

Validation Epoch: 4/15
Val Loss: 0.08736732564866542
Top 1 Acc: 0.9624060150375939

Running Training Epoch 5/15


  0%|          | 0/33 [00:00<?, ?it/s]

Training Epoch: 5/15
Training Loss: 0.03125915414244501
Top 1 Acc: 0.9932301740812379

Running Validation Epoch 5/15


  0%|          | 0/5 [00:00<?, ?it/s]

Validation Epoch: 5/15
Val Loss: 0.09458159506320954
Top 1 Acc: 0.9924812030075187

Running Training Epoch 6/15


  0%|          | 0/33 [00:00<?, ?it/s]

Training Epoch: 6/15
Training Loss: 0.0748459067421429
Top 1 Acc: 0.9796905222437138

Running Validation Epoch 6/15


  0%|          | 0/5 [00:00<?, ?it/s]

Validation Epoch: 6/15
Val Loss: 0.20114150643348694
Top 1 Acc: 0.9699248120300752

Running Training Epoch 7/15


  0%|          | 0/33 [00:00<?, ?it/s]

Training Epoch: 7/15
Training Loss: 0.10897209077621951
Top 1 Acc: 0.965183752417795

Running Validation Epoch 7/15


  0%|          | 0/5 [00:00<?, ?it/s]

Validation Epoch: 7/15
Val Loss: 0.3452451154589653
Top 1 Acc: 0.8421052631578947

Running Training Epoch 8/15


  0%|          | 0/33 [00:00<?, ?it/s]

Training Epoch: 8/15
Training Loss: 0.13403249407807985
Top 1 Acc: 0.9506769825918762

Running Validation Epoch 8/15


  0%|          | 0/5 [00:00<?, ?it/s]

Validation Epoch: 8/15
Val Loss: 0.20368995219469072
Top 1 Acc: 0.9323308270676691

Running Training Epoch 9/15


  0%|          | 0/33 [00:00<?, ?it/s]

Training Epoch: 9/15
Training Loss: 0.12121003064693826
Top 1 Acc: 0.9671179883945842

Running Validation Epoch 9/15


  0%|          | 0/5 [00:00<?, ?it/s]

Validation Epoch: 9/15
Val Loss: 0.11063596755266189
Top 1 Acc: 0.9774436090225563

Running Training Epoch 10/15


  0%|          | 0/33 [00:00<?, ?it/s]

Training Epoch: 10/15
Training Loss: 0.09124855725376894
Top 1 Acc: 0.9642166344294004

Running Validation Epoch 10/15


  0%|          | 0/5 [00:00<?, ?it/s]

Validation Epoch: 10/15
Val Loss: 0.0751135416328907
Top 1 Acc: 0.9699248120300752

Running Training Epoch 11/15


  0%|          | 0/33 [00:00<?, ?it/s]

Training Epoch: 11/15
Training Loss: 0.04337389501884128
Top 1 Acc: 0.9912959381044487

Running Validation Epoch 11/15


  0%|          | 0/5 [00:00<?, ?it/s]

Validation Epoch: 11/15
Val Loss: 0.034493994899094105
Top 1 Acc: 0.9774436090225563

Running Training Epoch 12/15


  0%|          | 0/33 [00:00<?, ?it/s]

Training Epoch: 12/15
Training Loss: 0.018291185736994852
Top 1 Acc: 0.9980657640232108

Running Validation Epoch 12/15


  0%|          | 0/5 [00:00<?, ?it/s]

Validation Epoch: 12/15
Val Loss: 0.04321025339886546
Top 1 Acc: 0.9774436090225563

Running Training Epoch 13/15


  0%|          | 0/33 [00:00<?, ?it/s]

Training Epoch: 13/15
Training Loss: 0.0146882157899778
Top 1 Acc: 0.9970986460348162

Running Validation Epoch 13/15


  0%|          | 0/5 [00:00<?, ?it/s]

Validation Epoch: 13/15
Val Loss: 0.03916811663657427
Top 1 Acc: 0.9924812030075187

Running Training Epoch 14/15


  0%|          | 0/33 [00:00<?, ?it/s]

Training Epoch: 14/15
Training Loss: 0.014786866712242816
Top 1 Acc: 0.9990328820116054

Running Validation Epoch 14/15


  0%|          | 0/5 [00:00<?, ?it/s]

Validation Epoch: 14/15
Val Loss: 0.040105257043614985
Top 1 Acc: 0.9849624060150376

Running Training Epoch 15/15


  0%|          | 0/33 [00:00<?, ?it/s]

Training Epoch: 15/15
Training Loss: 0.007484704270168688
Top 1 Acc: 1.0

Running Validation Epoch 15/15


  0%|          | 0/5 [00:00<?, ?it/s]

Validation Epoch: 15/15
Val Loss: 0.02973868220578879
Top 1 Acc: 0.9774436090225563



In [19]:
for (name, layer) in get_prunable_layers(model):
    print(f"{name}.weight: {tensor_sparsity(layer.weight).item():.4f}")

features.0.0.weight: 0.4896
features.1.conv.0.0.weight: 0.4722
features.1.conv.1.weight: 0.5234
features.2.conv.0.0.weight: 0.5501
features.2.conv.1.0.weight: 0.2859
features.2.conv.2.weight: 0.4813
features.3.conv.0.0.weight: 0.6137
features.3.conv.1.0.weight: 0.4977
features.3.conv.2.weight: 0.6131
features.4.conv.0.0.weight: 0.5214
features.4.conv.1.0.weight: 0.2785
features.4.conv.2.weight: 0.5184
features.5.conv.0.0.weight: 0.6903
features.5.conv.1.0.weight: 0.5804
features.5.conv.2.weight: 0.7134
features.6.conv.0.0.weight: 0.6929
features.6.conv.1.0.weight: 0.6117
features.6.conv.2.weight: 0.7383
features.7.conv.0.0.weight: 0.5493
features.7.conv.1.0.weight: 0.3588
features.7.conv.2.weight: 0.6127
features.8.conv.0.0.weight: 0.8385
features.8.conv.1.0.weight: 0.7396
features.8.conv.2.weight: 0.8687
features.9.conv.0.0.weight: 0.8278
features.9.conv.1.0.weight: 0.7376
features.9.conv.2.weight: 0.8643
features.10.conv.0.0.weight: 0.8293
features.10.conv.1.0.weight: 0.7876
features

Finally, export the pruned model as a PyTorch checkpoint and as ONNX.

In [21]:
save_dir = "experiment-0"
exporter = ModuleExporter(model, output_dir=save_dir)
exporter.export_pytorch(name="mobilenet_v2-sparse-beans.pth")
exporter.export_onnx(torch.randn(1, 3, 224, 224), name="sparse-model.onnx", convert_qat=True)

## Wrap Up

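The resulting model is 90% sparse (across the pruned convolutional weights) and finished at ~97.7% validation accuracy, peaking at ~99% during fine-tuning (vs. ~98.5% final and ~99% peak for the unoptimized dense model), without much hyperparameter search.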

Key hyperparameter experiments you may want to run include:
- Learning rate
- Learning rate schedule
- Sparsity level
- Number of pruning epochs

DeepSparse supports speedup from pruning and quantization. To reach maximum performance, check out our examples of quantizing a model!