# Sparsifying DenseNet121 from Scratch (Flower102)

In this example, we demonstrate how to sparsify an image classification model from scratch using SparseML's PyTorch integration. We fine-tune and prune [DenseNet121](https://pytorch.org/vision/main/models/generated/torchvision.models.densenet121.html) on the downstream [Oxford Flower 102 dataset](https://pytorch.org/vision/main/generated/torchvision.datasets.Flowers102.html#:~:text=Oxford%20102%20Flower%20is%20an,scale%2C%20pose%20and%20light%20variations) using the Gradual Magnitude Pruning (GMP) algorithm, applied globally across layers.

## Agenda

There are a few steps:

 1. Set up the dataset
 2. Set up the PyTorch training loop
 3. Train a dense version of DenseNet121
 4. Run the GMP pruning algorithm on the dense model
 
## Installation

Install SparseML with `pip`:

```
pip install sparseml[torchvision]
```

```python
import torch
import sparseml
import torchvision
from sparseml.pytorch.optim import ScheduledModifierManager
from sparseml.pytorch.utils import TensorBoardLogger, ModuleExporter, get_prunable_layers, tensor_sparsity

from torch.utils.data import DataLoader
from torch.nn import CrossEntropyLoss
from torch.optim import Adam

from torchvision import transforms
from tqdm.auto import tqdm
import math
```

## Step 1: Set Up the Dataset

Oxford 102 Flower is an image classification dataset consisting of 102 flower categories. The flowers were chosen to be flowers commonly occurring in the United Kingdom. Each class consists of between 40 and 258 images. The images have large scale, pose and light variations. In addition, there are categories that have large variations within the category, and several very similar categories.

We use the standard torchvision `Flowers102` dataset and PyTorch `DataLoader`s to manage the data.

```python
NUM_LABELS = 102
BATCH_SIZE = 32

# imagenet transforms
imagenet_transform = transforms.Compose([
    transforms.Resize(size=256, interpolation=transforms.InterpolationMode.BILINEAR, max_size=None, antialias=None),
    transforms.CenterCrop(size=(224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# datasets
train_dataset = torchvision.datasets.Flowers102(
    root="./data",
    split="train",
    transform=imagenet_transform,
    download=True
)

val_dataset = torchvision.datasets.Flowers102(
    root="./data",
    split="val",
    transform=imagenet_transform,
    download=True
)

# dataloaders
train_loader = DataLoader(train_dataset, BATCH_SIZE, shuffle=True, pin_memory=True, num_workers=16)
val_loader = DataLoader(val_dataset, BATCH_SIZE, shuffle=False, pin_memory=True, num_workers=16)
```

## Step 2: Set Up the PyTorch Training Loop

We use the training loop below, which is standard PyTorch.

```python
device = "cuda" if torch.cuda.is_available() else "cpu"
print(device)

def run_model_one_epoch(model, data_loader, criterion, device, train=False, optimizer=None):
    if train:
        model.train()
    else:
        model.eval()

    running_loss = 0.0
    total_correct = 0
    total_predictions = 0

    # loop through batches; only track gradients when training
    with torch.set_grad_enabled(train):
        for step, (inputs, labels) in tqdm(enumerate(data_loader), total=len(data_loader)):
            inputs = inputs.to(device)
            labels = labels.to(device)

            if train:
                optimizer.zero_grad()

            # compute loss, run backpropagation
            outputs = model(inputs)  # model returns logits
            loss = criterion(outputs, labels)
            if train:
                loss.backward()
                optimizer.step()

            running_loss += loss.item()

            # run evaluation
            predictions = outputs.argmax(dim=1)
            total_correct += torch.sum(predictions == labels).item()
            total_predictions += inputs.size(0)

    # return average loss and top-1 accuracy
    loss = running_loss / (step + 1.0)
    accuracy = total_correct / total_predictions
    return loss, accuracy
```

```
cuda
```


## Step 3: Train DenseNet121 on Flowers102

First, we fine-tune a dense version of DenseNet121 on the Flowers dataset, starting from ImageNet-pretrained weights.

```python
# download ImageNet-pretrained weights, replace the classification head
model = torchvision.models.densenet121(weights=torchvision.models.DenseNet121_Weights.DEFAULT)
model.classifier = torch.nn.Linear(model.classifier.in_features, NUM_LABELS)
model.to(device)

# setup loss function and optimizer
criterion = CrossEntropyLoss()
optimizer = Adam(model.parameters(), lr=8e-3)  # lr will be overridden by sparseml
```

Next, we use SparseML recipes to set the hyperparameters of the training loop. In this case, we use the following recipe:

```yaml
# Epoch and Learning-Rate variables
num_epochs: 10.0
init_lr: 0.0005

training_modifiers:
  - !EpochRangeModifier
    start_epoch: 0.0
    end_epoch: eval(num_epochs)

  - !LearningRateFunctionModifier
    final_lr: 0.0
    init_lr: eval(init_lr)
    lr_func: cosine
    start_epoch: 0.0
    end_epoch: eval(num_epochs)
```

The recipe includes an `!EpochRangeModifier` and a `!LearningRateFunctionModifier`, which set the number of training epochs and the learning rate schedule (cosine decay from `init_lr` to `final_lr`). Because the recipe contains no pruning modifiers, the final model remains dense.
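To make `lr_func: cosine` concrete, here is a minimal sketch of standard cosine annealing between `init_lr` and `final_lr` over a modifier's epoch range. It assumes the textbook cosine formula; SparseML's implementation may differ in details (for example, it adjusts the learning rate per optimizer step rather than once per epoch). The helper name `cosine_lr` is hypothetical.

```python
import math

def cosine_lr(epoch: float, start_epoch: float, end_epoch: float,
              init_lr: float, final_lr: float) -> float:
    """Hypothetical helper: standard cosine annealing from init_lr to final_lr."""
    # Fraction of the schedule completed, clamped to [0, 1]
    progress = (epoch - start_epoch) / (end_epoch - start_epoch)
    progress = min(max(progress, 0.0), 1.0)
    # (1 + cos(pi * progress)) / 2 falls smoothly from 1 to 0
    return final_lr + 0.5 * (init_lr - final_lr) * (1.0 + math.cos(math.pi * progress))

# Values from the dense recipe: init_lr=0.0005, final_lr=0.0, epochs 0 to 10
for epoch in [0.0, 2.5, 5.0, 7.5, 10.0]:
    print(f"epoch {epoch:4.1f}: lr = {cosine_lr(epoch, 0.0, 10.0, 0.0005, 0.0):.6f}")
```

The learning rate stays near `init_lr` early, drops fastest mid-schedule, and flattens out as it approaches `final_lr`.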

```python
dense_recipe_path = "./recipes/densenet-flowers-dense-recipe.yaml"
```

```python
!cat ./recipes/densenet-flowers-dense-recipe.yaml
```

Next, we use SparseML's `ScheduledModifierManager` to parse and apply the recipe. The `manager.modify` function modifies and wraps the `model` and `optimizer` with the instructions from the recipe. You can use the `model` and `optimizer` just like standard PyTorch objects.

```python
# create ScheduledModifierManager and Optimizer wrapper
manager = ScheduledModifierManager.from_yaml(dense_recipe_path)
optimizer = manager.modify(model, optimizer, steps_per_epoch=len(train_loader))
```
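`steps_per_epoch` is what lets SparseML translate the recipe's epoch-based settings into optimizer steps. A quick sanity check of that arithmetic, assuming the torchvision `Flowers102` train split holds 1,020 images (10 per class) and the `DataLoader` default `drop_last=False`, so a trailing partial batch still counts as a step:

```python
import math

num_train_images = 1020  # assumption: 10 images per class x 102 classes
batch_size = 32          # BATCH_SIZE from Step 1
num_epochs = 10          # num_epochs from the dense recipe

# With drop_last=False, the final partial batch is still one optimizer step
steps_per_epoch = math.ceil(num_train_images / batch_size)
total_steps = steps_per_epoch * num_epochs
print(steps_per_epoch, total_steps)
```

In the notebook itself, `len(train_loader)` computes `steps_per_epoch` for you.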

Kick off the transfer learning loop. Our run reached ~92.5% validation accuracy after 10 epochs.

```python
for epoch in range(manager.max_epochs):
    epoch_name = f"{epoch + 1}/{manager.max_epochs}"

    # run training loop
    print(f"Running Training Epoch {epoch_name}")
    train_loss, train_acc = run_model_one_epoch(model, train_loader, criterion, device, train=True, optimizer=optimizer)
    print(f"Training Epoch: {epoch_name}\nTraining Loss: {train_loss}\nTop 1 Acc: {train_acc}\n")

    # run validation loop
    print(f"Running Validation Epoch {epoch_name}")
    val_loss, val_acc = run_model_one_epoch(model, val_loader, criterion, device)
    print(f"Validation Epoch: {epoch_name}\nVal Loss: {val_loss}\nTop 1 Acc: {val_acc}\n")

# clean up
manager.finalize(model)
```

```
Running Training Epoch 1/10
Training Epoch: 1/10
Training Loss: 3.6993999630212784
Top 1 Acc: 0.2823529411764706

Running Validation Epoch 1/10
Validation Epoch: 1/10
Val Loss: 2.32899521663785
Top 1 Acc: 0.6235294117647059

Running Training Epoch 2/10
Training Epoch: 2/10
Training Loss: 1.3755246959626675
Top 1 Acc: 0.8823529411764706

Running Validation Epoch 2/10
Validation Epoch: 2/10
Val Loss: 1.0808461774140596
Top 1 Acc: 0.85

Running Training Epoch 3/10
Training Epoch: 3/10
Training Loss: 0.4065576898865402
Top 1 Acc: 0.9882352941176471

Running Validation Epoch 3/10
Validation Epoch: 3/10
Val Loss: 0.7663819249719381
Top 1 Acc: 0.8823529411764706

Running Training Epoch 4/10
Training Epoch: 4/10
Training Loss: 0.13117457507178187
Top 1 Acc: 1.0

Running Validation Epoch 4/10
Validation Epoch: 4/10
Val Loss: 0.5690533611923456
Top 1 Acc: 0.9127450980392157

Running Training Epoch 5/10
Training Epoch: 5/10
Training Loss: 0.05944129719864577
Top 1 Acc: 1.0

Running Validation Epoch 5/10
Validation Epoch: 5/10
Val Loss: 0.4922205451875925
Top 1 Acc: 0.9147058823529411

Running Training Epoch 6/10
Training Epoch: 6/10
Training Loss: 0.04032939835451543
Top 1 Acc: 1.0

Running Validation Epoch 6/10
Validation Epoch: 6/10
Val Loss: 0.47672522626817226
Top 1 Acc: 0.9205882352941176

Running Training Epoch 7/10
Training Epoch: 7/10
Training Loss: 0.03349028015509248
Top 1 Acc: 1.0

Running Validation Epoch 7/10
Validation Epoch: 7/10
Val Loss: 0.4664669381454587
Top 1 Acc: 0.9176470588235294

Running Training Epoch 8/10
Training Epoch: 8/10
Training Loss: 0.03036679228534922
Top 1 Acc: 1.0

Running Validation Epoch 8/10
Validation Epoch: 8/10
Val Loss: 0.45540827978402376
Top 1 Acc: 0.9264705882352942

Running Training Epoch 9/10
Training Epoch: 9/10
Training Loss: 0.028963472577743232
Top 1 Acc: 1.0

Running Validation Epoch 9/10
Validation Epoch: 9/10
Val Loss: 0.4513796488754451
Top 1 Acc: 0.9235294117647059

Running Training Epoch 10/10
Training Epoch: 10/10
Training Loss: 0.027131778770126402
Top 1 Acc: 1.0

Running Validation Epoch 10/10
Validation Epoch: 10/10
Val Loss: 0.45484112948179245
Top 1 Acc: 0.9245098039215687
```

Export the model so we can reload it later without retraining.

```python
save_dir = "dense_model"
exporter = ModuleExporter(model, output_dir=save_dir)
exporter.export_pytorch(name="densenet-121-dense-flowers.pth")
exporter.export_onnx(torch.randn(1, 3, 224, 224), name="dense-model.onnx", convert_qat=True)
```

```python
torch.cuda.empty_cache()
```

## Step 4: Prune the Model

With a model trained on Flowers, we are now ready to apply the GMP algorithm to prune it. GMP is an iterative pruning algorithm: at regular intervals during training (set by `update_frequency` in the recipe), we identify the lowest-magnitude weights (those closest to 0) and remove them from the network, gradually raising sparsity from an initial level to a final level. The remaining nonzero weights are fine-tuned on the training dataset between pruning steps.
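The core of a single global magnitude pruning step can be sketched in a few lines of PyTorch: pool the absolute values of all targeted weights, find the threshold that achieves the requested sparsity, and mask everything at or below it. This is an illustration of the idea only, not SparseML's implementation; the helper name `global_magnitude_masks` and the toy tensors are hypothetical.

```python
import torch

def global_magnitude_masks(weights: list[torch.Tensor], sparsity: float) -> list[torch.Tensor]:
    """Hypothetical helper: boolean keep-masks that drop the smallest-magnitude
    weights across *all* tensors together (global, not per-layer)."""
    all_magnitudes = torch.cat([w.detach().abs().flatten() for w in weights])
    num_to_prune = int(sparsity * all_magnitudes.numel())
    if num_to_prune == 0:
        return [torch.ones_like(w, dtype=torch.bool) for w in weights]
    # The k-th smallest magnitude over all tensors is the global threshold
    threshold = torch.kthvalue(all_magnitudes, num_to_prune).values
    # Keep only weights strictly above the threshold
    return [w.detach().abs() > threshold for w in weights]

torch.manual_seed(0)
layer_a = torch.randn(64, 32)        # toy layer with larger weights
layer_b = torch.randn(16, 8) * 0.1   # toy layer with much smaller weights
masks = global_magnitude_masks([layer_a, layer_b], sparsity=0.5)
for name, m in zip(["layer_a", "layer_b"], masks):
    print(name, f"sparsity = {1 - m.float().mean().item():.3f}")
```

Because the threshold is shared, the layer with smaller weights loses far more of them than the other. GMP repeats a step like this with a rising sparsity target, fine-tuning in between.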

```python
# first, load the trained model from Step 3
checkpoint = torch.load("./dense_model/training/densenet-121-dense-flowers.pth", map_location=device)
model = torchvision.models.densenet121()
model.classifier = torch.nn.Linear(model.classifier.in_features, NUM_LABELS)
model.load_state_dict(checkpoint['state_dict'])
model.to(device)

# setup loss function and optimizer, LR will be overridden by sparseml
criterion = CrossEntropyLoss()
optimizer = Adam(model.parameters(), lr=8e-3)
```

Next, we need to create a SparseML recipe that includes the GMP algorithm. The `!GlobalMagnitudePruningModifier` instructs SparseML to apply GMP at a global level: weights are ranked by magnitude across all targeted layers together and the lowest-magnitude ones are pruned, rather than pruning every layer to the same sparsity.

First, we need to identify which parameters of the model the GMP algorithm should prune. We can use the `get_prunable_layers` function to inspect the candidates:

```python
# print the names of all prunable layers
for (name, layer) in get_prunable_layers(model):
    print(f"{name}")
```

```
features.conv0
features.denseblock1.denselayer1.conv1
features.denseblock1.denselayer1.conv2
features.denseblock1.denselayer2.conv1
features.denseblock1.denselayer2.conv2
features.denseblock1.denselayer3.conv1
features.denseblock1.denselayer3.conv2
features.denseblock1.denselayer4.conv1
features.denseblock1.denselayer4.conv2
features.denseblock1.denselayer5.conv1
features.denseblock1.denselayer5.conv2
features.denseblock1.denselayer6.conv1
features.denseblock1.denselayer6.conv2
features.transition1.conv
features.denseblock2.denselayer1.conv1
features.denseblock2.denselayer1.conv2
features.denseblock2.denselayer2.conv1
features.denseblock2.denselayer2.conv2
features.denseblock2.denselayer3.conv1
features.denseblock2.denselayer3.conv2
features.denseblock2.denselayer4.conv1
features.denseblock2.denselayer4.conv2
features.denseblock2.denselayer5.conv1
features.denseblock2.denselayer5.conv2
features.denseblock2.denselayer6.conv1
features.denseblock2.denselayer6.conv2
features.denseblock2.de
```

We will prune each of the convolution layers (`conv`) and exclude the `classifier` (the final projection head). SparseML accepts regular expressions, prefixed with `re:`, to identify layers in the network, so we can use the following list to select the layers to prune:

```yaml
- 'features.conv0.weight'
- 're:features.denseblock1.*.conv1.weight'
- 're:features.denseblock1.*.conv2.weight'
- 're:features.transition1.conv.weight'
- 're:features.denseblock2.*.conv1.weight'
- 're:features.denseblock2.*.conv2.weight'
- 're:features.transition2.conv.weight'
- 're:features.denseblock3.*.conv1.weight'
- 're:features.denseblock3.*.conv2.weight'
- 're:features.transition3.conv.weight'
- 're:features.denseblock4.*.conv1.weight'
- 're:features.denseblock4.*.conv2.weight'
```
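To preview which parameter names these patterns select, you can try them with Python's `re` module. This sketch assumes full-string matching; SparseML's own matching semantics may differ slightly. The sample names are a hand-picked subset of DenseNet121's parameters, and only three of the recipe's patterns are shown.

```python
import re

# A subset of the recipe's patterns, with the "re:" prefix stripped.
# 'features.conv0.weight' is an exact name in the recipe; it also matches itself as a regex.
patterns = [
    r"features.conv0.weight",
    r"features.denseblock1.*.conv1.weight",
    r"features.transition1.conv.weight",
]

# A few DenseNet121 parameter names, including some that should NOT be pruned
param_names = [
    "features.conv0.weight",
    "features.norm0.weight",
    "features.denseblock1.denselayer1.conv1.weight",
    "features.denseblock1.denselayer1.norm1.weight",
    "features.transition1.conv.weight",
    "classifier.weight",
]

for name in param_names:
    selected = any(re.fullmatch(p, name) for p in patterns)
    print(f"{name:50s} -> {'prune' if selected else 'skip'}")
```

Note that an unescaped `.` matches any character. That is harmless for DenseNet121's names, but escaping it (`\.`) makes the intent explicit if layer names could collide.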

Here is what the recipe looks like:

```yaml
# Epoch and Learning-Rate variables
num_epochs: 13.0
pruning_epochs: 10.0
init_lr: 0.00025
final_lr: 0.0001
inter_func: cubic
mask_type: unstructured

training_modifiers:
  - !EpochRangeModifier
    start_epoch: 0.0
    end_epoch: eval(num_epochs)

  - !LearningRateFunctionModifier
    final_lr: eval(final_lr)
    init_lr: eval(init_lr)
    lr_func: cosine
    start_epoch: 0.0
    end_epoch: eval(pruning_epochs)
    
  - !LearningRateFunctionModifier
    final_lr: eval(final_lr)
    init_lr: eval(init_lr)
    lr_func: cosine
    start_epoch: eval(pruning_epochs)
    end_epoch: eval(num_epochs)

# Pruning
pruning_modifiers:
  - !GlobalMagnitudePruningModifier
    init_sparsity: 0.05
    final_sparsity: 0.85
    start_epoch: 0.0
    end_epoch: eval(pruning_epochs)
    update_frequency: 0.5
    params: 
        - 'features.conv0.weight'
        - 're:features.denseblock1.*.conv1.weight'
        - 're:features.denseblock1.*.conv2.weight'
        - 're:features.transition1.conv.weight'
        - 're:features.denseblock2.*.conv1.weight'
        - 're:features.denseblock2.*.conv2.weight'
        - 're:features.transition2.conv.weight'
        - 're:features.denseblock3.*.conv1.weight'
        - 're:features.denseblock3.*.conv2.weight'
        - 're:features.transition3.conv.weight'
        - 're:features.denseblock4.*.conv1.weight'
        - 're:features.denseblock4.*.conv2.weight'
    leave_enabled: True
    inter_func: eval(inter_func)
    mask_type: eval(mask_type)
```

This recipe runs the GMP algorithm for the first 10 epochs (`pruning_epochs`). Sparsity starts at an `init_sparsity` of 5% and rises to a `final_sparsity` of 85% along a `cubic` curve, with the pruning masks updated every half epoch (`update_frequency: 0.5`). The pruning is `unstructured`, meaning any individual weight can be pruned, with no block or channel structure imposed.
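The `cubic` interpolation front-loads pruning: sparsity rises quickly at first and levels off near the end, giving the network time to recover at high sparsity. A minimal sketch of the schedule, assuming the common cubic form `s = s_init + (s_final - s_init) * (1 - (1 - p)**3)`, where `p` is the fraction of the pruning window completed, and assuming a new target every `update_frequency` epochs. SparseML's exact interpolation may differ; `cubic_sparsity` is a hypothetical helper.

```python
def cubic_sparsity(epoch: float, start_epoch: float, end_epoch: float,
                   init_sparsity: float, final_sparsity: float) -> float:
    """Hypothetical helper: cubic interpolation from init_sparsity to final_sparsity."""
    progress = (epoch - start_epoch) / (end_epoch - start_epoch)
    progress = min(max(progress, 0.0), 1.0)
    # 1 - (1 - p)^3 rises steeply near p=0 and flattens near p=1
    return init_sparsity + (final_sparsity - init_sparsity) * (1.0 - (1.0 - progress) ** 3)

# Recipe values: 5% -> 85% over epochs 0-10, with an update every 0.5 epochs
update_frequency = 0.5
schedule = [
    (k * update_frequency, cubic_sparsity(k * update_frequency, 0.0, 10.0, 0.05, 0.85))
    for k in range(int(10.0 / update_frequency) + 1)
]
# Print every 4th update (every 2 epochs)
for epoch, sparsity in schedule[::4]:
    print(f"epoch {epoch:4.1f}: target sparsity = {sparsity:.3f}")
```

Under this form, the target is already 75% halfway through the pruning window, so most of the remaining epochs are spent on small increments.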

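Over the final 3 epochs, we fine-tune the 85%-sparse model further with a fresh cosine learning rate schedule (the second `!LearningRateFunctionModifier`). Because we set `leave_enabled: True`, the pruning masks stay applied, so the sparsity level is maintained during fine-tuning.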
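Keeping the masks applied matters because a plain optimizer step updates every weight, including pruned ones, so zeros would drift back to nonzero values during fine-tuning. The idea can be sketched by re-applying a fixed mask after each step. SparseML handles this internally; this toy example with a hypothetical 50% mask only illustrates the mechanism.

```python
import torch

torch.manual_seed(0)
weight = torch.nn.Parameter(torch.randn(4, 4))

# Hypothetical mask: prune the 8 smallest-magnitude entries (50% sparsity)
threshold = weight.detach().abs().flatten().kthvalue(8).values
mask = weight.detach().abs() > threshold
mask_f = mask.to(weight.dtype)
with torch.no_grad():
    weight.mul_(mask_f)

optimizer = torch.optim.SGD([weight], lr=0.1)
target = torch.randn(4, 4)

for step in range(3):
    optimizer.zero_grad()
    loss = ((weight - target) ** 2).sum()
    loss.backward()
    optimizer.step()           # the update touches pruned entries too
    with torch.no_grad():
        weight.mul_(mask_f)    # re-apply the mask so pruned entries stay zero

sparsity = (weight == 0).float().mean().item()
print(f"sparsity after fine-tuning: {sparsity:.2f}")
```

Without the `weight.mul_(mask_f)` line inside the loop, the gradient on each pruned entry would pull it toward `target` and the sparsity would disappear after the first step.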



```python
pruning_recipe_path = "./recipes/densenet-flowers-pruning-recipe.yaml"
```

```python
!cat ./recipes/densenet-flowers-pruning-recipe.yaml
```

```python
# create ScheduledModifierManager and Optimizer wrapper
manager = ScheduledModifierManager.from_yaml(pruning_recipe_path)
logger = TensorBoardLogger(log_path="./tensorboard_outputs")
optimizer = manager.modify(model, optimizer, loggers=[logger], steps_per_epoch=len(train_loader))
```

Next, kick off the GMP training loop.

We use the wrapped `optimizer` and `model` exactly as before. SparseML parsed the recipe and updated the `optimizer` with the logic of the GMP algorithm, so you can use the `optimizer` and `model` as usual while all of the pruning-related logic is specified by the declarative recipe.

Our 85% sparsified model reaches ~91.5% validation accuracy (vs ~92.5% for the dense model).

```python
# run GMP algorithm
for epoch in range(manager.max_epochs):
    epoch_name = f"{epoch + 1}/{manager.max_epochs}"

    # run training loop
    print(f"Running Training Epoch {epoch_name}")
    train_loss, train_acc = run_model_one_epoch(model, train_loader, criterion, device, train=True, optimizer=optimizer)
    print(f"Training Epoch: {epoch_name}\nTraining Loss: {train_loss}\nTop 1 Acc: {train_acc}\n")

    # run validation loop
    print(f"Running Validation Epoch {epoch_name}")
    val_loss, val_acc = run_model_one_epoch(model, val_loader, criterion, device)
    print(f"Validation Epoch: {epoch_name}\nVal Loss: {val_loss}\nTop 1 Acc: {val_acc}\n")

    # log metrics to TensorBoard
    logger.log_scalar("Metrics/Loss (Train)", train_loss, epoch)
    logger.log_scalar("Metrics/Accuracy (Train)", train_acc, epoch)
    logger.log_scalar("Metrics/Loss (Validation)", val_loss, epoch)
    logger.log_scalar("Metrics/Accuracy (Validation)", val_acc, epoch)

manager.finalize(model)
```

```
Running Training Epoch 1/13
Training Epoch: 1/13
Training Loss: 0.043860748992301524
Top 1 Acc: 0.9990196078431373

Running Validation Epoch 1/13
Validation Epoch: 1/13
Val Loss: 0.49546412169001997
Top 1 Acc: 0.8950980392156863

Running Training Epoch 2/13
Training Epoch: 2/13
Training Loss: 0.031713149510324
Top 1 Acc: 0.9980392156862745

Running Validation Epoch 2/13
Validation Epoch: 2/13
Val Loss: 0.380035838810727
Top 1 Acc: 0.9117647058823529

Running Training Epoch 3/13
Training Epoch: 3/13
Training Loss: 0.016048584191594273
Top 1 Acc: 1.0

Running Validation Epoch 3/13
Validation Epoch: 3/13
Val Loss: 0.3518236926756799
Top 1 Acc: 0.9156862745098039

Running Training Epoch 4/13
Training Epoch: 4/13
Training Loss: 0.013376500370213762
Top 1 Acc: 1.0

Running Validation Epoch 4/13
Validation Epoch: 4/13
Val Loss: 0.37410336383618414
Top 1 Acc: 0.9098039215686274

Running Training Epoch 5/13
Training Epoch: 5/13
Training Loss: 0.01488300270284526
Top 1 Acc: 1.0

Running Validation Epoch 5/13
Validation Epoch: 5/13
Val Loss: 0.3517874537501484
Top 1 Acc: 0.9156862745098039

Running Training Epoch 6/13
Training Epoch: 6/13
Training Loss: 0.020687898708274588
Top 1 Acc: 1.0

Running Validation Epoch 6/13
Validation Epoch: 6/13
Val Loss: 0.38692406262271106
Top 1 Acc: 0.9137254901960784

Running Training Epoch 7/13
Training Epoch: 7/13
Training Loss: 0.029207701765699312
Top 1 Acc: 0.9990196078431373

Running Validation Epoch 7/13
Validation Epoch: 7/13
Val Loss: 0.4347131696995348
Top 1 Acc: 0.8990196078431373

Running Training Epoch 8/13
Training Epoch: 8/13
Training Loss: 0.05598480207845569
Top 1 Acc: 0.9950980392156863

Running Validation Epoch 8/13
Validation Epoch: 8/13
Val Loss: 0.5028206340502948
Top 1 Acc: 0.8813725490196078

Running Training Epoch 9/13
Training Epoch: 9/13
Training Loss: 0.030482763570034876
Top 1 Acc: 0.9990196078431373

Running Validation Epoch 9/13
Validation Epoch: 9/13
Val Loss: 0.4592747155111283
Top 1 Acc: 0.8941176470588236

Running Training Epoch 10/13
Training Epoch: 10/13
Training Loss: 0.017036149831255898
Top 1 Acc: 0.9990196078431373

Running Validation Epoch 10/13
Validation Epoch: 10/13
Val Loss: 0.40747271745931357
Top 1 Acc: 0.9019607843137255

Running Training Epoch 11/13
Training Epoch: 11/13
Training Loss: 0.008050348522374406
Top 1 Acc: 1.0

Running Validation Epoch 11/13
Validation Epoch: 11/13
Val Loss: 0.36801655124872923
Top 1 Acc: 0.9068627450980392

Running Training Epoch 12/13
Training Epoch: 12/13
Training Loss: 0.005260431091301143
Top 1 Acc: 1.0

Running Validation Epoch 12/13
Validation Epoch: 12/13
Val Loss: 0.3571971161291003
Top 1 Acc: 0.9117647058823529

Running Training Epoch 13/13
Training Epoch: 13/13
Training Loss: 0.004477048831176944
Top 1 Acc: 1.0

Running Validation Epoch 13/13
Validation Epoch: 13/13
Val Loss: 0.3466116476338357
Top 1 Acc: 0.9137254901960784
```

Here is a sample of the TensorBoard output, showing the validation accuracy, a particular layer's sparsity level, and the learning rate over time.

![tensorboard output](./images/densenet-flowers-tensorboard-output.png)

We can print layer-by-layer sparsity as well.

```python
for (name, layer) in get_prunable_layers(model):
    print(f"{name}.weight: {tensor_sparsity(layer.weight).item():.4f}")
```

```
features.conv0.weight: 0.4184
features.denseblock1.denselayer1.conv1.weight: 0.6761
features.denseblock1.denselayer1.conv2.weight: 0.7671
features.denseblock1.denselayer2.conv1.weight: 0.6769
features.denseblock1.denselayer2.conv2.weight: 0.8130
features.denseblock1.denselayer3.conv1.weight: 0.6814
features.denseblock1.denselayer3.conv2.weight: 0.7528
features.denseblock1.denselayer4.conv1.weight: 0.7571
features.denseblock1.denselayer4.conv2.weight: 0.7544
features.denseblock1.denselayer5.conv1.weight: 0.8184
features.denseblock1.denselayer5.conv2.weight: 0.8309
features.denseblock1.denselayer6.conv1.weight: 0.7914
features.denseblock1.denselayer6.conv2.weight: 0.7735
features.transition1.conv.weight: 0.6437
features.denseblock2.denselayer1.conv1.weight: 0.9088
features.denseblock2.denselayer1.conv2.weight: 0.8509
features.denseblock2.denselayer2.conv1.weight: 0.8362
features.denseblock2.denselayer2.conv2.weight: 0.7972
features.denseblock2.denselayer3.conv1.weight: 0.8374
features.de
```
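Because the pruning target is global, per-layer sparsity varies widely (`features.conv0` ends near 42% while `features.denseblock2.denselayer1.conv1` reaches about 91%). The 85% target applies to the pooled set of targeted weights, so checking it means weighting each tensor by its size. A minimal sketch; the helper name `pooled_sparsity` is hypothetical and the toy tensors stand in for real layer weights.

```python
import torch

def pooled_sparsity(tensors: list[torch.Tensor]) -> float:
    """Hypothetical helper: fraction of zeros across all tensors combined."""
    num_zeros = sum(int((t == 0).sum().item()) for t in tensors)
    num_total = sum(t.numel() for t in tensors)
    return num_zeros / num_total

# Toy check: a 100-element tensor at 90% sparsity and a 300-element tensor at 80%
a = torch.ones(100)
a[:90] = 0
b = torch.ones(300)
b[:240] = 0
print(f"{pooled_sparsity([a, b]):.4f}")  # size-weighted: (90 + 240) / 400
```

On the notebook's model, you could pass `[layer.weight for name, layer in get_prunable_layers(model) if name != "classifier"]`, which covers the convolution weights targeted by the recipe; the result should land close to the recipe's `final_sparsity` of 0.85.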

Finally, export the sparse model in PyTorch and ONNX formats.

```python
save_dir = "experiment-0"
exporter = ModuleExporter(model, output_dir=save_dir)
exporter.export_pytorch(name="densenet-121-sparse-flowers.pth")
exporter.export_onnx(torch.randn(1, 3, 224, 224), name="sparse-model.onnx", convert_qat=True)
```

## Wrap Up

The resulting model is 85% sparse and achieves a validation accuracy of ~91.5% (vs ~92.5% for the unoptimized dense model) without much hyperparameter search.

Key hyperparameter experiments you may want to run include:
- Learning rate
- Learning rate schedule
- Sparsity level
- Number of pruning epochs


DeepSparse can accelerate inference using both pruning and quantization. To reach maximum performance, check out our examples of quantizing a model!