# Sparsifying DenseNet121 from Scratch (Flower102)

In this example, we demonstrate how to sparsify an image classification model using SparseML's PyTorch integration. We train and prune [DenseNet121](https://pytorch.org/vision/main/models/generated/torchvision.models.densenet121.html) on the downstream [Oxford Flower 102 dataset](https://pytorch.org/vision/main/generated/torchvision.datasets.Flowers102.html#:~:text=Oxford%20102%20Flower%20is%20an,scale%2C%20pose%20and%20light%20variations) using the Global Magnitude Pruning (GMP) algorithm, then quantize the pruned model with quantization-aware training (QAT).

## Agenda

The example consists of the following steps:

 1. Set up the dataset
 2. Set up the PyTorch training loop
 3. Train a dense version of DenseNet121
 4. Run the GMP pruning algorithm on the dense model
 5. Quantize the pruned model with quantization-aware training (QAT)
 
## Installation

Install SparseML with `pip`:

```
pip install sparseml[torchvision]
```

In [1]:
import torch
import sparseml
import torchvision
from sparseml.pytorch.optim import ScheduledModifierManager
from sparseml.pytorch.utils import TensorBoardLogger, ModuleExporter, get_prunable_layers, tensor_sparsity

from torch.utils.data import DataLoader
from torch.nn import CrossEntropyLoss
from torch.optim import Adam

from torchvision import transforms
from tqdm.auto import tqdm
import math

## Step 1: Set Up the Dataset

Oxford 102 Flower is an image classification dataset consisting of 102 flower categories. The flowers were chosen to be flowers commonly occurring in the United Kingdom. Each class consists of between 40 and 258 images. The images have large scale, pose and light variations. In addition, there are categories that have large variations within the category, and several very similar categories.

We use the standard PyTorch `datasets` and `dataloaders` to manage the dataset.

In [2]:
NUM_LABELS = 102
BATCH_SIZE = 16

# imagenet transforms
imagenet_transform = transforms.Compose([
   transforms.Resize(size=256, interpolation=transforms.InterpolationMode.BILINEAR, max_size=None, antialias=None),
   transforms.CenterCrop(size=(224, 224)),
   transforms.ToTensor(),
   transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# datasets
train_dataset = torchvision.datasets.Flowers102(
    root="./data",
    split="train",
    transform=imagenet_transform,
    download=True
)

val_dataset = torchvision.datasets.Flowers102(
    root="./data",
    split="val",
    transform=imagenet_transform,
    download=True
)

# dataloaders
train_loader = DataLoader(train_dataset, BATCH_SIZE, shuffle=True, pin_memory=True, num_workers=16)
val_loader = DataLoader(val_dataset, BATCH_SIZE, shuffle=False, pin_memory=True, num_workers=16)

## Step 2: Setup PyTorch Training Loop

We will use this training loop below. This is standard PyTorch functionality.

In [3]:
device = "cuda" if torch.cuda.is_available() else "cpu"
print(device)

def run_model_one_epoch(model, data_loader, criterion, device, train=False, optimizer=None):
    if train:
        model.train()
    else:
        model.eval()

    running_loss = 0.0
    total_correct = 0
    total_predictions = 0

    # loop through batches
    for step, (inputs, labels) in tqdm(enumerate(data_loader), total=len(data_loader)):
        inputs = inputs.to(device)
        labels = labels.to(device)

        if train:
            optimizer.zero_grad()

        # compute loss, run backpropagation
        outputs = model(inputs)  # model returns logits
        loss = criterion(outputs, labels)
        if train:
            loss.backward()
            optimizer.step()

        running_loss += loss.item()

        # run evaluation
        predictions = outputs.argmax(dim=1)
        total_correct += torch.sum(predictions == labels).item()
        total_predictions += inputs.size(0)

    # return loss and evaluation metric
    loss = running_loss / (step + 1.0)
    accuracy = total_correct / total_predictions
    return loss, accuracy

cuda


## Step 3: Train DenseNet121 on Flowers102

First, we will train a dense version of DenseNet121 on the Flowers dataset.

In [4]:
# download pre-trained model, setup classification head
model = torchvision.models.densenet121(weights=torchvision.models.DenseNet121_Weights.DEFAULT)
model.classifier = torch.nn.Linear(model.classifier.in_features, NUM_LABELS)
model.to(device)

# setup loss function and optimizer
criterion = CrossEntropyLoss()
optimizer = Adam(model.parameters(), lr=8e-3)  # lr will be overridden by sparseml

Next, we will use a SparseML recipe to set the hyperparameters of the training loop. In this case, we will use the following recipe:

```yaml
# Epoch and Learning-Rate variables
num_epochs: 10.0
init_lr: 0.0005

training_modifiers:
  - !EpochRangeModifier
    start_epoch: 0.0
    end_epoch: eval(num_epochs)

  - !LearningRateFunctionModifier
    final_lr: 0.0
    init_lr: eval(init_lr)
    lr_func: cosine
    start_epoch: 0.0
    end_epoch: eval(num_epochs)
```

As you can see, the recipe includes an `!EpochRangeModifier` and a `!LearningRateFunctionModifier`. These modifiers set the number of epochs to train for and the learning rate schedule. Because the recipe contains no pruning modifiers, the final model will be dense.
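
To make the learning rate schedule concrete, here is a minimal sketch of cosine annealing using the values from the recipe above. The function `cosine_lr` is a hypothetical helper written for illustration, and the formula is the standard cosine annealing curve; it is an assumption that `lr_func: cosine` follows this shape, not a copy of SparseML's implementation.

```python
import math


def cosine_lr(epoch, init_lr=0.0005, final_lr=0.0, start_epoch=0.0, end_epoch=10.0):
    """Cosine-annealed learning rate between start_epoch and end_epoch.

    Hypothetical helper for illustration; defaults mirror the dense recipe.
    """
    # Clamp progress to [0, 1] so the value holds outside the scheduled range.
    progress = (epoch - start_epoch) / (end_epoch - start_epoch)
    progress = min(max(progress, 0.0), 1.0)
    # The cosine factor decays smoothly from 1 (start) to 0 (end).
    factor = 0.5 * (1.0 + math.cos(math.pi * progress))
    return final_lr + (init_lr - final_lr) * factor


for epoch in range(0, 11, 2):
    print(f"epoch {epoch:2d}: lr = {cosine_lr(epoch):.6f}")
```

Under this assumption the learning rate starts at `init_lr`, passes through the midpoint of the two values halfway through training, and reaches `final_lr` at `end_epoch`.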

In [5]:
dense_recipe_path = "./recipes/densenet-flowers-dense-recipe.yaml"

In [None]:
!cat ./recipes/densenet-flowers-dense-recipe.yaml

Next, we use SparseML's `ScheduledModifierManager` to parse and apply the recipe. The `manager.modify` function modifies and wraps the `model` and `optimizer` with the instructions from the recipe. You can use the `model` and `optimizer` just like standard PyTorch objects.

In [6]:
# create ScheduledModifierManager and Optimizer wrapper
manager = ScheduledModifierManager.from_yaml(dense_recipe_path)
optimizer = manager.modify(model, optimizer, steps_per_epoch=len(train_loader))
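
The `steps_per_epoch` argument lets fractional epochs in a recipe be mapped to optimizer steps. The arithmetic below is a rough sketch of that mapping. It assumes the Flowers102 `train` split holds 1020 images (consistent with the training accuracies in the logs, which are multiples of 1/1020), and `epoch_to_step` is a hypothetical helper, not a SparseML function.

```python
import math

NUM_TRAIN_IMAGES = 1020  # assumption: size of the Flowers102 "train" split
BATCH_SIZE = 16          # same value as in the DataLoader setup

# With drop_last=False, a DataLoader yields ceil(dataset_size / batch_size) batches.
steps_per_epoch = math.ceil(NUM_TRAIN_IMAGES / BATCH_SIZE)


def epoch_to_step(epoch, steps_per_epoch):
    """Hypothetical helper: convert a (fractional) epoch to an optimizer step index."""
    return round(epoch * steps_per_epoch)


print("steps per epoch:", steps_per_epoch)
print("steps in 10 epochs:", epoch_to_step(10.0, steps_per_epoch))
print("steps in half an epoch:", epoch_to_step(0.5, steps_per_epoch))
```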

Kick off the transfer learning loop. Our run reached ~91.9% validation accuracy after 10 epochs.

In [7]:
epoch = 0
for epoch in range(manager.max_epochs):
    # run training loop
    epoch_name = f"{epoch + 1}/{manager.max_epochs}"
    
    print(f"Running Training Epoch {epoch_name}")
    train_loss, train_acc = run_model_one_epoch(model, train_loader, criterion, device, train=True, optimizer=optimizer)
    print(f"Training Epoch: {epoch_name}\nTraining Loss: {train_loss}\nTop 1 Acc: {train_acc}\n")

    # run validation loop
    print(f"Running Validation Epoch {epoch_name}")
    val_loss, val_acc = run_model_one_epoch(model, val_loader, criterion, device)
    print(f"Validation Epoch: {epoch_name}\nVal Loss: {val_loss}\nTop 1 Acc: {val_acc}\n")

# clean up
manager.finalize(model)

Running Training Epoch 1/10
Training Epoch: 1/10
Training Loss: 3.705597158521414
Top 1 Acc: 0.25392156862745097

Running Validation Epoch 1/10
Validation Epoch: 1/10
Val Loss: 2.1234971778467298
Top 1 Acc: 0.5637254901960784

Running Training Epoch 2/10
Training Epoch: 2/10
Training Loss: 1.32985198777169
Top 1 Acc: 0.8382352941176471

Running Validation Epoch 2/10
Validation Epoch: 2/10
Val Loss: 0.956481144297868
Top 1 Acc: 0.8303921568627451

Running Training Epoch 3/10
Training Epoch: 3/10
Training Loss: 0.4278648835606873
Top 1 Acc: 0.9803921568627451

Running Validation Epoch 3/10
Validation Epoch: 3/10
Val Loss: 0.6528619016171433
Top 1 Acc: 0.8578431372549019

Running Training Epoch 4/10
Training Epoch: 4/10
Training Loss: 0.15628680447116494
Top 1 Acc: 0.9970588235294118

Running Validation Epoch 4/10
Validation Epoch: 4/10
Val Loss: 0.5102186386066023
Top 1 Acc: 0.9049019607843137

Running Training Epoch 5/10
Training Epoch: 5/10
Training Loss: 0.060788551229052246
Top 1 Acc: 1.0

Running Validation Epoch 5/10
Validation Epoch: 5/10
Val Loss: 0.42239774897461757
Top 1 Acc: 0.9107843137254902

Running Training Epoch 6/10
Training Epoch: 6/10
Training Loss: 0.0473741386376787
Top 1 Acc: 0.9990196078431373

Running Validation Epoch 6/10
Validation Epoch: 6/10
Val Loss: 0.40489213640103117
Top 1 Acc: 0.9098039215686274

Running Training Epoch 7/10
Training Epoch: 7/10
Training Loss: 0.03090776341559831
Top 1 Acc: 1.0

Running Validation Epoch 7/10
Validation Epoch: 7/10
Val Loss: 0.39284458926704247
Top 1 Acc: 0.9166666666666666

Running Training Epoch 8/10
Training Epoch: 8/10
Training Loss: 0.027710435475455597
Top 1 Acc: 1.0

Running Validation Epoch 8/10
Validation Epoch: 8/10
Val Loss: 0.3832304854586255
Top 1 Acc: 0.9176470588235294

Running Training Epoch 9/10
Training Epoch: 9/10
Training Loss: 0.026527045309194364
Top 1 Acc: 1.0

Running Validation Epoch 9/10
Validation Epoch: 9/10
Val Loss: 0.36842548519780394
Top 1 Acc: 0.9186274509803921

Running Training Epoch 10/10
Training Epoch: 10/10
Training Loss: 0.025542080067680217
Top 1 Acc: 1.0

Running Validation Epoch 10/10
Validation Epoch: 10/10
Val Loss: 0.3675568748876685
Top 1 Acc: 0.9186274509803921



Export the model so that we can reload it later without retraining.

In [None]:
save_dir = "densenet-models"
exporter = ModuleExporter(model, output_dir=save_dir)
exporter.export_pytorch(name="dense-model.pth")
exporter.export_onnx(torch.randn(1, 3, 224, 224), name="dense-model.onnx", convert_qat=True)

In [16]:
torch.cuda.empty_cache()

## Step 4: Prune The Model

With a model trained on Flowers, we are now ready to apply the GMP algorithm to prune the model. GMP is an iterative pruning algorithm. At regular intervals during training, we identify the lowest-magnitude weights (those closest to 0) and remove them from the network, starting from an initial level of sparsity and increasing until a final level of sparsity is reached. The remaining nonzero weights are then fine-tuned on the training dataset.
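
As a rough illustration of a single global magnitude pruning step, the sketch below pools weights from several layers, finds one magnitude threshold, and masks everything at or below it. The function `global_magnitude_masks` and the toy weights are made up for this example; this is plain Python rather than SparseML's implementation, which operates on tensors and updates masks during training.

```python
def global_magnitude_masks(layers, sparsity):
    """Return {name: mask} that prunes the smallest-magnitude weights across ALL layers.

    Hypothetical helper for illustration: 0 in a mask means pruned, 1 means kept.
    """
    # Pool the magnitudes of every weight in every layer and sort them.
    magnitudes = sorted(abs(w) for weights in layers.values() for w in weights)
    num_to_prune = int(sparsity * len(magnitudes))
    if num_to_prune == 0:
        return {name: [1] * len(weights) for name, weights in layers.items()}
    # One threshold for the whole network (ties at the threshold are also pruned).
    threshold = magnitudes[num_to_prune - 1]
    return {
        name: [0 if abs(w) <= threshold else 1 for w in weights]
        for name, weights in layers.items()
    }


# Toy weights, invented for illustration (not taken from DenseNet121).
layers = {
    "conv_a": [0.9, -0.8, 0.7, 0.05],
    "conv_b": [0.01, -0.02, 0.03, 0.6],
}
masks = global_magnitude_masks(layers, sparsity=0.5)
for name, mask in masks.items():
    layer_sparsity = 1 - sum(mask) / len(mask)
    print(name, mask, f"sparsity={layer_sparsity:.2f}")
```

Because the threshold is shared, a layer with many small weights ends up sparser than a layer with mostly large weights, even though the overall sparsity matches the target.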

In [17]:
# first, load the trained model from Step 3
checkpoint = torch.load("./densenet-models/training/dense-model.pth")
model = torchvision.models.densenet121()
model.classifier = torch.nn.Linear(model.classifier.in_features, NUM_LABELS)
model.load_state_dict(checkpoint['state_dict'])
model.to(device)

# setup loss function and optimizer, LR will be overridden by sparseml
criterion = CrossEntropyLoss()
optimizer = Adam(model.parameters(), lr=8e-3)

Next, we need to create a SparseML recipe which includes the GMP algorithm. The `!GlobalMagnitudePruningModifier` modifier instructs SparseML to apply the GMP algorithm at a global level (pruning the lowest magnitude weights across all layers).

First, we need to identify which parameters of the model to apply the GMP algorithm to. We can use the `get_prunable_layers` function to inspect them:

In [18]:
# print parameters
for (name, layer) in get_prunable_layers(model):
    print(f"{name}")

features.conv0
features.denseblock1.denselayer1.conv1
features.denseblock1.denselayer1.conv2
features.denseblock1.denselayer2.conv1
features.denseblock1.denselayer2.conv2
features.denseblock1.denselayer3.conv1
features.denseblock1.denselayer3.conv2
features.denseblock1.denselayer4.conv1
features.denseblock1.denselayer4.conv2
features.denseblock1.denselayer5.conv1
features.denseblock1.denselayer5.conv2
features.denseblock1.denselayer6.conv1
features.denseblock1.denselayer6.conv2
features.transition1.conv
features.denseblock2.denselayer1.conv1
features.denseblock2.denselayer1.conv2
features.denseblock2.denselayer2.conv1
features.denseblock2.denselayer2.conv2
features.denseblock2.denselayer3.conv1
features.denseblock2.denselayer3.conv2
features.denseblock2.denselayer4.conv1
features.denseblock2.denselayer4.conv2
features.denseblock2.denselayer5.conv1
features.denseblock2.denselayer5.conv2
features.denseblock2.denselayer6.conv1
features.denseblock2.denselayer6.conv2
features.denseblock2.de

We will apply pruning to each of the `convs` and exclude the `classifier` (which is the final projection head). Fortunately, SparseML allows us to pass regexes to identify layers in the network.
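
To see how such patterns select parameters, the sketch below checks a few parameter names with Python's `re` module. It assumes that a `re:` prefix marks the rest of the string as a regular expression matched from the start of the parameter name, and that other entries are compared exactly. The `matches` helper is hypothetical and only approximates what SparseML does internally.

```python
import re

# Patterns written in the style of a recipe's `params` list.
patterns = [
    "features.conv0.weight",
    "re:features.denseblock1.*.conv1.weight",
    "re:features.transition1.conv.weight",
]


def matches(pattern, name):
    """Hypothetical matcher: 're:' patterns use re.match, others compare exactly."""
    if pattern.startswith("re:"):
        return re.match(pattern[3:], name) is not None
    return pattern == name


names = [
    "features.conv0.weight",
    "features.denseblock1.denselayer3.conv1.weight",
    "features.denseblock1.denselayer3.conv2.weight",
    "features.transition1.conv.weight",
    "classifier.weight",
]

# Keep every parameter name that matches at least one pattern.
selected = [n for n in names if any(matches(p, n) for p in patterns)]
print(selected)
```

In this toy list, `classifier.weight` matches none of the patterns, so it would be left unpruned.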

Here is what the recipe looks like:

```yaml
# Epoch and Learning-Rate variables
num_epochs: 13.0
pruning_epochs: 10.0
init_lr: 0.0005
final_lr: 0.0001
inter_func: cubic
mask_type: unstructured

training_modifiers:
  - !EpochRangeModifier
    start_epoch: 0.0
    end_epoch: eval(num_epochs)

  - !LearningRateFunctionModifier
    final_lr: eval(final_lr)
    init_lr: eval(init_lr)
    lr_func: cosine
    start_epoch: 0.0
    end_epoch: eval(pruning_epochs)
    
  - !LearningRateFunctionModifier
    final_lr: eval(final_lr)
    init_lr: eval(init_lr)
    lr_func: cosine
    start_epoch: eval(pruning_epochs)
    end_epoch: eval(num_epochs)

# Pruning
pruning_modifiers:
  - !GlobalMagnitudePruningModifier
    init_sparsity: 0.05
    final_sparsity: 0.9
    start_epoch: 0.0
    end_epoch: eval(pruning_epochs)
    update_frequency: 0.5
    params: 
        - 'features.conv0.weight'
        - 're:features.denseblock1.*.conv1.weight'
        - 're:features.denseblock1.*.conv2.weight'
        - 're:features.transition1.conv.weight'
        - 're:features.denseblock2.*.conv1.weight'
        - 're:features.denseblock2.*.conv2.weight'
        - 're:features.transition2.conv.weight'
        - 're:features.denseblock3.*.conv1.weight'
        - 're:features.denseblock3.*.conv2.weight'
        - 're:features.transition3.conv.weight'
        - 're:features.denseblock4.*.conv1.weight'
        - 're:features.denseblock4.*.conv2.weight'
    leave_enabled: True
    inter_func: eval(inter_func)
    mask_type: eval(mask_type)
```

This recipe specifies that we run the GMP algorithm for the first 10 epochs. We start at an `init_sparsity` level of 5% and gradually increase sparsity to a `final_sparsity` level of 90% following a `cubic` curve, with the sparsity level updated every half epoch (`update_frequency: 0.5`). Because the pruning is global, the 90% target applies to the listed `convs` taken together, so individual layers can end up above or below 90%. The pruning is applied in an `unstructured` manner, meaning that any weight can be pruned in any pattern.
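
The sparsity schedule can be sketched in a few lines. The function below uses the cubic curve commonly used for gradual magnitude pruning, where sparsity rises quickly at first and flattens near the target. Treat the exact formula as an assumption: SparseML's `cubic` interpolation may differ in detail, and `gmp_sparsity` is a hypothetical helper.

```python
def gmp_sparsity(epoch, init_sparsity=0.05, final_sparsity=0.9,
                 start_epoch=0.0, end_epoch=10.0):
    """Target sparsity at a given epoch under a cubic schedule.

    Hypothetical helper for illustration; defaults mirror the pruning recipe.
    """
    # Clamp progress to [0, 1] so sparsity is held after end_epoch.
    progress = (epoch - start_epoch) / (end_epoch - start_epoch)
    progress = min(max(progress, 0.0), 1.0)
    # Cubic curve: large increases early, small increases near the end.
    return final_sparsity + (init_sparsity - final_sparsity) * (1.0 - progress) ** 3


update_frequency = 0.5  # epochs between sparsity updates, as in the recipe
for step in range(0, 27, 4):
    epoch = step * update_frequency
    print(f"epoch {epoch:4.1f}: target sparsity = {gmp_sparsity(epoch):.3f}")
```

Under this schedule most of the pruning happens in the early epochs, which leaves the later pruning epochs and the final fine-tuning epochs to recover accuracy.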

Over the final 3 epochs, we fine-tune the 90% pruned model further. Since we set `leave_enabled=True` the sparsity level will be maintained as the fine-tuning occurs.

In [19]:
pruning_recipe_path = "./recipes/densenet-flowers-pruning-recipe.yaml"

In [None]:
!cat ./recipes/densenet-flowers-pruning-recipe.yaml

In [22]:
# create ScheduledModifierManager and Optimizer wrapper
manager = ScheduledModifierManager.from_yaml(pruning_recipe_path)
logger = TensorBoardLogger(log_path="./tensorboard_outputs/densenet/pruning-run")
optimizer = manager.modify(model, optimizer, loggers=[logger], steps_per_epoch=len(train_loader))

Next, kick off the GMP training loop. 

As you can see, we use the wrapped `optimizer` and `model` in the same way as above. SparseML parsed the recipe and updated the `optimizer` with the logic of the GMP algorithm from the recipe. This allows you to use the `optimizer` and `model` as usual, with all of the pruning-related logic handled by SparseML.

Our 90% pruned model reaches ~91.9% validation accuracy, matching the dense model (~91.9%).

In [23]:
# run GMP algorithm
epoch = 0
for epoch in range(manager.max_epochs):
    # run training loop
    epoch_name = f"{epoch + 1}/{manager.max_epochs}"
    
    print(f"Running Training Epoch {epoch_name}")
    train_loss, train_acc = run_model_one_epoch(model, train_loader, criterion, device, train=True, optimizer=optimizer)
    print(f"Training Epoch: {epoch_name}\nTraining Loss: {train_loss}\nTop 1 Acc: {train_acc}\n")

    # run validation loop
    print(f"Running Validation Epoch {epoch_name}")
    val_loss, val_acc = run_model_one_epoch(model, val_loader, criterion, device)
    print(f"Validation Epoch: {epoch_name}\nVal Loss: {val_loss}\nTop 1 Acc: {val_acc}\n")
    
    logger.log_scalar("Metrics/Loss (Train)", train_loss, epoch)
    logger.log_scalar("Metrics/Accuracy (Train)", train_acc, epoch)
    logger.log_scalar("Metrics/Loss (Validation)", val_loss, epoch)
    logger.log_scalar("Metrics/Accuracy (Validation)", val_acc, epoch)

manager.finalize(model)

Running Training Epoch 1/13
Training Epoch: 1/13
Training Loss: 0.22910855180816725
Top 1 Acc: 0.9735294117647059

Running Validation Epoch 1/13
Validation Epoch: 1/13
Val Loss: 0.9752587381517515
Top 1 Acc: 0.7333333333333333

Running Training Epoch 2/13
Training Epoch: 2/13
Training Loss: 0.1926842465181835
Top 1 Acc: 0.9676470588235294

Running Validation Epoch 2/13
Validation Epoch: 2/13
Val Loss: 0.5537524847895838
Top 1 Acc: 0.85

Running Training Epoch 3/13
Training Epoch: 3/13
Training Loss: 0.0884120583650656
Top 1 Acc: 0.9882352941176471

Running Validation Epoch 3/13
Validation Epoch: 3/13
Val Loss: 0.458682878066611
Top 1 Acc: 0.8764705882352941

Running Training Epoch 4/13
Training Epoch: 4/13
Training Loss: 0.07008300906454679
Top 1 Acc: 0.9941176470588236

Running Validation Epoch 4/13
Validation Epoch: 4/13
Val Loss: 0.4294259559828788
Top 1 Acc: 0.8872549019607843

Running Training Epoch 5/13
Training Epoch: 5/13
Training Loss: 0.058840728845098056
Top 1 Acc: 0.9941176470588236

Running Validation Epoch 5/13
Validation Epoch: 5/13
Val Loss: 0.4169776111084502
Top 1 Acc: 0.8970588235294118

Running Training Epoch 6/13
Training Epoch: 6/13
Training Loss: 0.06827072701707948
Top 1 Acc: 0.9931372549019608

Running Validation Epoch 6/13
Validation Epoch: 6/13
Val Loss: 0.4382545033295173
Top 1 Acc: 0.8921568627450981

Running Training Epoch 7/13
Training Epoch: 7/13
Training Loss: 0.07268102490343153
Top 1 Acc: 0.9921568627450981

Running Validation Epoch 7/13
Validation Epoch: 7/13
Val Loss: 0.4186342091416009
Top 1 Acc: 0.8941176470588236

Running Training Epoch 8/13
Training Epoch: 8/13
Training Loss: 0.062260200167656876
Top 1 Acc: 0.9950980392156863

Running Validation Epoch 8/13
Validation Epoch: 8/13
Val Loss: 0.4408616948639974
Top 1 Acc: 0.8872549019607843

Running Training Epoch 9/13
Training Epoch: 9/13
Training Loss: 0.0411255351791624
Top 1 Acc: 1.0

Running Validation Epoch 9/13
Validation Epoch: 9/13
Val Loss: 0.4180755140114343
Top 1 Acc: 0.9019607843137255

Running Training Epoch 10/13
Training Epoch: 10/13
Training Loss: 0.027908188028959557
Top 1 Acc: 1.0

Running Validation Epoch 10/13
Validation Epoch: 10/13
Val Loss: 0.40239793025830295
Top 1 Acc: 0.8980392156862745

Running Training Epoch 11/13
Training Epoch: 11/13
Training Loss: 0.02035851358959917
Top 1 Acc: 1.0

Running Validation Epoch 11/13
Validation Epoch: 11/13
Val Loss: 0.3589820931883878
Top 1 Acc: 0.9088235294117647

Running Training Epoch 12/13
Training Epoch: 12/13
Training Loss: 0.012050284800352529
Top 1 Acc: 1.0

Running Validation Epoch 12/13
Validation Epoch: 12/13
Val Loss: 0.35020815586904064
Top 1 Acc: 0.9127450980392157

Running Training Epoch 13/13
Training Epoch: 13/13
Training Loss: 0.007597094325319631
Top 1 Acc: 1.0

Running Validation Epoch 13/13
Validation Epoch: 13/13
Val Loss: 0.3375227072065172
Top 1 Acc: 0.9186274509803921



Here is a sample of the TensorBoard output, showing the validation accuracy, a particular layer's sparsity level, and the learning rate over time.

![tensorboard output](./images/densenet-flowers-tensorboard-output.png)

The resulting model is 90% sparse and achieves validation accuracy of ~91.9% (on par with the unoptimized dense model at ~91.9%) without much hyperparameter search. Key hyperparameter experiments you may want to run include:
- Learning rate
- Learning rate schedule
- Sparsity level
- Number of pruning epochs

In [26]:
print("Sparsity By Layer:")
for (name, layer) in get_prunable_layers(model):
    print(f"{name}.weight: {tensor_sparsity(layer.weight).item():.4f}")

Sparsity By Layer:
features.conv0.weight: 0.4737
features.denseblock1.denselayer1.conv1.weight: 0.7423
features.denseblock1.denselayer1.conv2.weight: 0.8321
features.denseblock1.denselayer2.conv1.weight: 0.7501
features.denseblock1.denselayer2.conv2.weight: 0.8680
features.denseblock1.denselayer3.conv1.weight: 0.7543
features.denseblock1.denselayer3.conv2.weight: 0.8259
features.denseblock1.denselayer4.conv1.weight: 0.8161
features.denseblock1.denselayer4.conv2.weight: 0.8302
features.denseblock1.denselayer5.conv1.weight: 0.8723
features.denseblock1.denselayer5.conv2.weight: 0.8858
features.denseblock1.denselayer6.conv1.weight: 0.8428
features.denseblock1.denselayer6.conv2.weight: 0.8443
features.transition1.conv.weight: 0.7129
features.denseblock2.denselayer1.conv1.weight: 0.9314
features.denseblock2.denselayer1.conv2.weight: 0.8886
features.denseblock2.denselayer2.conv1.weight: 0.8830
features.denseblock2.denselayer2.conv2.weight: 0.8611
features.denseblock2.denselayer3.conv1.weight:
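
Per-layer sparsities differ because the pruning threshold is global, so the overall sparsity is a weighted average over layers, with larger layers counting for more. The sketch below combines the first three layers from the output above. The weight counts assume the standard torchvision DenseNet121 convolution shapes and are an assumption, not values printed by the notebook.

```python
# (name, number of weights, reported sparsity).
# Weight counts assume standard torchvision DenseNet121 conv shapes
# (out_channels * in_channels * kernel_h * kernel_w); sparsities are copied
# from the "Sparsity By Layer" output.
layers = [
    ("features.conv0.weight", 64 * 3 * 7 * 7, 0.4737),
    ("features.denseblock1.denselayer1.conv1.weight", 128 * 64 * 1 * 1, 0.7423),
    ("features.denseblock1.denselayer1.conv2.weight", 32 * 128 * 3 * 3, 0.8321),
]

total_weights = sum(count for _, count, _ in layers)
# Approximate number of zeroed weights in each layer, summed over layers.
total_zeros = sum(count * sparsity for _, count, sparsity in layers)
combined = total_zeros / total_weights

print(f"combined sparsity of these {len(layers)} layers: {combined:.4f}")
```

The same weighting applied over all pruned layers is what should come out near the 90% target, even though the small first convolution stays well below it.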

Finally, export your model to ONNX.

In [None]:
save_dir = "densenet-models"
exporter = ModuleExporter(model, output_dir=save_dir)
exporter.export_pytorch(name="densenet-pruned-fp32.pth")
exporter.export_onnx(torch.randn(1, 3, 224, 224), name="pruned-fp32-model.onnx", convert_qat=True)

## Step 5: Quantize the Model

With a sparse model trained on Flowers, we are now ready to apply the QAT algorithm to quantize the model. Generally, we can run QAT for 5-6 epochs. Quantizing the model together with pruning allows us to gain a speedup when running with DeepSparse.
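
For intuition about what QAT simulates during training, the sketch below applies a simple symmetric 8-bit "fake quantization" to a handful of weights: values are scaled, rounded to integers, and scaled back. This per-tensor scheme and the `fake_quantize` helper are assumptions chosen for illustration; the quantization that SparseML and PyTorch actually apply uses observers and may choose scales and zero points differently.

```python
def fake_quantize(values, num_bits=8):
    """Quantize to signed integers and dequantize again (symmetric, per-tensor).

    Hypothetical helper for illustration only.
    """
    qmax = 2 ** (num_bits - 1) - 1  # 127 for 8 bits
    max_abs = max(abs(v) for v in values)
    # One scale for the whole list; fall back to 1.0 if all values are zero.
    scale = max_abs / qmax if max_abs > 0 else 1.0
    # Round to the nearest integer level and clamp to the representable range.
    quantized = [max(-qmax, min(qmax, round(v / scale))) for v in values]
    dequantized = [q * scale for q in quantized]
    return quantized, dequantized, scale


# Toy weights, invented for illustration; the zeros stand in for pruned weights.
weights = [0.0, 0.31, -0.12, 0.0, 0.05, -0.27]
quantized, dequantized, scale = fake_quantize(weights)

print("int8 values:", quantized)
print("dequantized:", [round(v, 4) for v in dequantized])
```

In this symmetric scheme a pruned weight of exactly zero maps to the integer zero, so rounding alone does not remove sparsity; the weight updates during QAT are what the pruning masks need to guard against.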

In [38]:
torch.cuda.empty_cache()

In [40]:
# first, load the trained model from Step 4
checkpoint = torch.load("./densenet-models/training/densenet-pruned-fp32.pth")
model = torchvision.models.densenet121()
model.classifier = torch.nn.Linear(model.classifier.in_features, NUM_LABELS)
model.load_state_dict(checkpoint['state_dict'])
model.to(device)

# setup loss function and optimizer, LR will be overridden by sparseml
criterion = CrossEntropyLoss()
optimizer = Adam(model.parameters(), lr=8e-3)

Next, we will create a quantization recipe. We will use the `QuantizationModifier` (which instructs SparseML to run QAT) along with the `ConstantPruningModifier` (which instructs SparseML to maintain sparsity).

Here is what the recipe looks like:

```yaml
# Epoch and Learning-Rate variables
num_epochs: 6.0
init_lr: 0.0004

# quantization variables
quantization_epochs: 6.0

training_modifiers:
  - !EpochRangeModifier
    start_epoch: 0.0
    end_epoch: eval(num_epochs)

  - !LearningRateFunctionModifier
    final_lr: 0.0
    init_lr: eval(init_lr)
    lr_func: cosine
    start_epoch: 0.0
    end_epoch: eval(num_epochs)

# Phase 1 Sparse Transfer Learning / Recovery
sparse_transfer_learning_modifiers:
  - !ConstantPruningModifier
    start_epoch: 0.0
    params: __ALL_PRUNABLE__

# Phase 2 Apply quantization
sparse_quantized_transfer_learning_modifiers:
  - !QuantizationModifier
    start_epoch: eval(num_epochs - quantization_epochs)
```

In [41]:
quantization_recipe_path = "./recipes/quantization-recipe.yaml"

In [42]:
# create ScheduledModifierManager and Optimizer wrapper
manager = ScheduledModifierManager.from_yaml(quantization_recipe_path)
logger = TensorBoardLogger(log_path="./tensorboard_outputs/densenet/quantization-run")
optimizer = manager.modify(model, optimizer, loggers=[logger], steps_per_epoch=len(train_loader))

In [43]:
# run QAT algorithm
epoch = 0
for epoch in range(manager.max_epochs):
    # run training loop
    epoch_name = f"{epoch + 1}/{manager.max_epochs}"
    
    print(f"Running Training Epoch {epoch_name}")
    train_loss, train_acc = run_model_one_epoch(model, train_loader, criterion, device, train=True, optimizer=optimizer)
    print(f"Training Epoch: {epoch_name}\nTraining Loss: {train_loss}\nTop 1 Acc: {train_acc}\n")

    # run validation loop
    print(f"Running Validation Epoch {epoch_name}")
    val_loss, val_acc = run_model_one_epoch(model, val_loader, criterion, device)
    print(f"Validation Epoch: {epoch_name}\nVal Loss: {val_loss}\nTop 1 Acc: {val_acc}\n")
    
    logger.log_scalar("Metrics/Loss (Train)", train_loss, epoch)
    logger.log_scalar("Metrics/Accuracy (Train)", train_acc, epoch)
    logger.log_scalar("Metrics/Loss (Validation)", val_loss, epoch)
    logger.log_scalar("Metrics/Accuracy (Validation)", val_acc, epoch)

manager.finalize(model)

Running Training Epoch 1/5
Training Epoch: 1/5
Training Loss: 0.020167092072369996
Top 1 Acc: 1.0

Running Validation Epoch 1/5
Validation Epoch: 1/5
Val Loss: 0.3898960210317455
Top 1 Acc: 0.8980392156862745

Running Training Epoch 2/5
Training Epoch: 2/5
Training Loss: 0.011033405537091312
Top 1 Acc: 1.0

Running Validation Epoch 2/5
Validation Epoch: 2/5
Val Loss: 0.381910876960319
Top 1 Acc: 0.8970588235294118

Running Training Epoch 3/5
Training Epoch: 3/5
Training Loss: 0.006457825023971964
Top 1 Acc: 1.0

Running Validation Epoch 3/5
Validation Epoch: 3/5
Val Loss: 0.33566539320781885
Top 1 Acc: 0.9137254901960784

Running Training Epoch 4/5
Training Epoch: 4/5
Training Loss: 0.003771477568989212
Top 1 Acc: 1.0

Running Validation Epoch 4/5
Validation Epoch: 4/5
Val Loss: 0.34420355190604823
Top 1 Acc: 0.907843137254902

Running Training Epoch 5/5
Training Epoch: 5/5
Training Loss: 0.002937028719316004
Top 1 Acc: 1.0

Running Validation Epoch 5/5
Validation Epoch: 5/5
Val Loss: 0.3459806305090751
Top 1 Acc: 0.9088235294117647



The resulting model is 90% sparse and INT8 quantized, and achieves validation accuracy of ~90.9% (vs the unoptimized dense model at ~91.9%) without much hyperparameter search. Key hyperparameter experiments you may want to run for QAT are:
- Learning rate

Finally, export to ONNX:

In [44]:
save_dir = "densenet-models"
exporter = ModuleExporter(model, output_dir=save_dir)
exporter.export_pytorch(name="densenet-pruned-int8.pth")
exporter.export_onnx(torch.randn(1, 3, 224, 224), name="pruned-int8-model.onnx", convert_qat=True)

2023-03-14 23:27:05 sparseml.pytorch.sparsification.quantization.quantize_qat_export INFO     Converted 59 quantizable Conv ops with weight and bias to ConvInteger and Add
2023-03-14 23:27:05 sparseml.pytorch.sparsification.quantization.quantize_qat_export INFO     Converted 1 quantizable Gemm ops with weight and bias to MatMulInteger and Add
