In [11]:
import sparseml
from sparseml.pytorch.optim import ScheduledModifierManager
from sparseml.pytorch.utils import TensorBoardLogger, ModuleExporter, get_prunable_layers, tensor_sparsity
from sparseml.pytorch.utils.helpers import get_optim_learning_rate

import torch
from torch.utils.data import DataLoader
from torch.nn import CrossEntropyLoss
from torch.optim import Adam

import torchvision
from torchvision import transforms

from tqdm.auto import tqdm
import math

## **Setup Dataset**

Oxford 102 Flower is an image classification dataset consisting of 102 flower categories. The flowers were chosen to be flowers commonly occurring in the United Kingdom. Each class consists of between 40 and 258 images.

The images have large scale, pose and light variations. In addition, there are categories that have large variations within the category, and several very similar categories.

In [2]:
# dataset / loader constants
NUM_LABELS = 102  # number of flower categories in Flowers102
BATCH_SIZE = 32

# ImageNet-style preprocessing: resize, center-crop to 224x224, convert to a
# tensor, then normalize with the per-channel mean/std listed below
imagenet_transform = transforms.Compose(
    [
        transforms.Resize(
            size=256,
            interpolation=transforms.InterpolationMode.BILINEAR,
            max_size=None,
            antialias=None,
        ),
        transforms.CenterCrop(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

# train / validation splits, both rooted at ./data and sharing the same transform
train_dataset = torchvision.datasets.Flowers102(
    root="./data", split="train", transform=imagenet_transform, download=True
)
val_dataset = torchvision.datasets.Flowers102(
    root="./data", split="val", transform=imagenet_transform, download=True
)

# batch iterators: only the training split is shuffled
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=16,
    pin_memory=True,
)
val_loader = DataLoader(
    dataset=val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=16,
    pin_memory=True,
)

## Setup Training Loop

We will use this training loop below. This is standard PyTorch functionality.

In [3]:
# pick the GPU when CUDA is available, otherwise fall back to the CPU
device = "cuda" if torch.cuda.is_available() else "cpu"
print(device)

def run_model_one_epoch(model, data_loader, criterion, device, train=False, optimizer=None):
    """Run one full pass over ``data_loader`` and report mean loss and top-1 accuracy.

    Args:
        model: module called as ``model(inputs)``; its output is fed to
            ``criterion`` and arg-maxed over dim 1 to obtain predictions.
        data_loader: sized iterable yielding ``(inputs, labels)`` batches.
        criterion: callable ``criterion(outputs, labels)`` returning a scalar loss tensor.
        device: device every batch is moved to before the forward pass.
        train: when True the model is put in train mode and the weights are
            updated after every batch; otherwise the model is put in eval mode
            and the pass runs with autograd disabled.
        optimizer: optimizer used for the updates; required when ``train`` is True.

    Returns:
        ``(loss, accuracy)``: the mean of the per-batch losses and the fraction
        of samples whose arg-max prediction equals the label.

    Raises:
        ValueError: if ``train`` is True without an optimizer, or if
            ``data_loader`` yields no batches.
    """
    if train and optimizer is None:
        raise ValueError("optimizer must be provided when train=True")

    if train:
        model.train()
    else:
        model.eval()

    running_loss = 0.0
    total_correct = 0
    total_predictions = 0
    num_batches = 0

    # autograd is only needed when the weights are updated; disabling it for
    # evaluation avoids building (and holding in memory) graphs that are never used
    with torch.set_grad_enabled(bool(train)):
        # loop through batches
        for inputs, labels in tqdm(data_loader, total=len(data_loader)):
            inputs = inputs.to(device)
            labels = labels.to(device)

            if train:
                optimizer.zero_grad()

            # compute loss, run backpropagation
            outputs = model(inputs)  # model returns logits
            loss = criterion(outputs, labels)
            if train:
                loss.backward()
                optimizer.step()

            running_loss += loss.item()
            num_batches += 1

            # run evaluation
            predictions = outputs.argmax(dim=1)
            total_correct += torch.sum(predictions == labels).item()
            total_predictions += inputs.size(0)

    if num_batches == 0:
        # without this guard the averages below would divide by zero
        raise ValueError("data_loader produced no batches")

    # return loss and evaluation metric
    loss = running_loss / num_batches
    accuracy = total_correct / total_predictions
    return loss, accuracy

cuda


## **Part 1: Train DenseNet121 Model as Usual**

We download pretrained DenseNet121 from torchvision, setting it to use 102 classes.

In [16]:
# start from torchvision's default pretrained DenseNet121 weights
model = torchvision.models.densenet121(
    weights=torchvision.models.DenseNet121_Weights.DEFAULT
)
# replace the classifier head with a fresh linear layer sized for NUM_LABELS classes
model.classifier = torch.nn.Linear(
    in_features=model.classifier.in_features,
    out_features=NUM_LABELS,
)
model = model.to(device)
print(model)

DenseNet(
  (features): Sequential(
    (conv0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (norm0): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu0): ReLU(inplace=True)
    (pool0): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (denseblock1): _DenseBlock(
      (denselayer1): _DenseLayer(
        (norm1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu1): ReLU(inplace=True)
        (conv1): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu2): ReLU(inplace=True)
        (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
      (denselayer2): _DenseLayer(
        (norm1): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu

In [17]:
# set up the loss function and optimizer; the LR given here will be overridden by SparseML
criterion = CrossEntropyLoss()
optimizer = Adam(model.parameters(), lr=8e-3)

In [18]:
# print the recipe file that the manager cell below loads for the dense fine-tuning run
!cat ./dense_model/dense-recipe.yaml


# Epoch and Learning-Rate variables
num_epochs: 10.0
init_lr: 0.0005

training_modifiers:
  - !EpochRangeModifier
    start_epoch: 0.0
    end_epoch: eval(num_epochs)

  - !LearningRateFunctionModifier
    final_lr: 0.0
    init_lr: eval(init_lr)
    lr_func: cosine
    start_epoch: 0.0
    end_epoch: eval(num_epochs)


Update the Optimizer and Model with the logic from the recipe.

In [19]:
# create the ScheduledModifierManager from the dense recipe and wrap the optimizer
manager = ScheduledModifierManager.from_yaml("./dense_model/dense-recipe.yaml")
logger = TensorBoardLogger(log_path="./dense_model/training/tensorboard_outputs")
# NOTE: `optimizer` is rebound to the object returned by manager.modify, so re-running
# this cell without first re-creating the plain Adam optimizer passes the already
# modified optimizer back in
optimizer = manager.modify(model, optimizer, loggers=[logger], steps_per_epoch=len(train_loader))

Kick off the transfer learning loop. We fine-tune on the Flowers dataset, reaching 91.6% validation accuracy after 10 epochs.

In [20]:
epoch = 0
for epoch in range(manager.max_epochs):
    epoch_name = f"{epoch + 1}/{manager.max_epochs}"

    # training pass: weights are updated through the recipe-managed optimizer
    print(f"Running Training Epoch {epoch_name}")
    train_loss, train_acc = run_model_one_epoch(
        model,
        train_loader,
        criterion,
        device,
        train=True,
        optimizer=optimizer,
    )
    print(f"Training Epoch: {epoch_name}\nTraining Loss: {train_loss}\nTop 1 Acc: {train_acc}\n")

    # validation pass: metrics only, no weight updates
    print(f"Running Validation Epoch {epoch_name}")
    val_loss, val_acc = run_model_one_epoch(model, val_loader, criterion, device)
    print(f"Validation Epoch: {epoch_name}\nVal Loss: {val_loss}\nTop 1 Acc: {val_acc}\n")

    # record the four epoch-level metrics, in the same order as they were computed
    for tag, value in (
        ("Metrics/Loss (Train)", train_loss),
        ("Metrics/Accuracy (Train)", train_acc),
        ("Metrics/Loss (Validation)", val_loss),
        ("Metrics/Accuracy (Validation)", val_acc),
    ):
        logger.log_scalar(tag, value, epoch)

# all epochs done: let the manager finalize the model
manager.finalize(model)

Running Training Epoch 1/10


  0%|          | 0/32 [00:00<?, ?it/s]

Training Epoch: 1/10
Training Loss: 3.618388310074806
Top 1 Acc: 0.3

Running Validation Epoch 1/10


  0%|          | 0/32 [00:00<?, ?it/s]

Validation Epoch: 1/10
Val Loss: 2.134434787556529
Top 1 Acc: 0.6274509803921569

Running Training Epoch 2/10


  0%|          | 0/32 [00:00<?, ?it/s]

Training Epoch: 2/10
Training Loss: 1.3103822488337755
Top 1 Acc: 0.9019607843137255

Running Validation Epoch 2/10


  0%|          | 0/32 [00:00<?, ?it/s]

Validation Epoch: 2/10
Val Loss: 1.1688878536224365
Top 1 Acc: 0.8598039215686275

Running Training Epoch 3/10


  0%|          | 0/32 [00:00<?, ?it/s]

Training Epoch: 3/10
Training Loss: 0.3958540456369519
Top 1 Acc: 0.9931372549019608

Running Validation Epoch 3/10


  0%|          | 0/32 [00:00<?, ?it/s]

Validation Epoch: 3/10
Val Loss: 0.7494302792474627
Top 1 Acc: 0.8813725490196078

Running Training Epoch 4/10


  0%|          | 0/32 [00:00<?, ?it/s]

Training Epoch: 4/10
Training Loss: 0.12264440068975091
Top 1 Acc: 0.9990196078431373

Running Validation Epoch 4/10


  0%|          | 0/32 [00:00<?, ?it/s]

Validation Epoch: 4/10
Val Loss: 0.6009533200412989
Top 1 Acc: 0.9068627450980392

Running Training Epoch 5/10


  0%|          | 0/32 [00:00<?, ?it/s]

Training Epoch: 5/10
Training Loss: 0.06395778199657798
Top 1 Acc: 1.0

Running Validation Epoch 5/10


  0%|          | 0/32 [00:00<?, ?it/s]

Validation Epoch: 5/10
Val Loss: 0.5246470430865884
Top 1 Acc: 0.9166666666666666

Running Training Epoch 6/10


  0%|          | 0/32 [00:00<?, ?it/s]

Training Epoch: 6/10
Training Loss: 0.041063663142267615
Top 1 Acc: 1.0

Running Validation Epoch 6/10


  0%|          | 0/32 [00:00<?, ?it/s]

Validation Epoch: 6/10
Val Loss: 0.5043302299454808
Top 1 Acc: 0.9156862745098039

Running Training Epoch 7/10


  0%|          | 0/32 [00:00<?, ?it/s]

Training Epoch: 7/10
Training Loss: 0.03354773804312572
Top 1 Acc: 1.0

Running Validation Epoch 7/10


  0%|          | 0/32 [00:00<?, ?it/s]

Validation Epoch: 7/10
Val Loss: 0.490260950056836
Top 1 Acc: 0.9166666666666666

Running Training Epoch 8/10


  0%|          | 0/32 [00:00<?, ?it/s]

Training Epoch: 8/10
Training Loss: 0.03196021361509338
Top 1 Acc: 1.0

Running Validation Epoch 8/10


  0%|          | 0/32 [00:00<?, ?it/s]

Validation Epoch: 8/10
Val Loss: 0.4794748453423381
Top 1 Acc: 0.9156862745098039

Running Training Epoch 9/10


  0%|          | 0/32 [00:00<?, ?it/s]

Training Epoch: 9/10
Training Loss: 0.027999908081255853
Top 1 Acc: 1.0

Running Validation Epoch 9/10


  0%|          | 0/32 [00:00<?, ?it/s]

Validation Epoch: 9/10
Val Loss: 0.4790585918817669
Top 1 Acc: 0.9137254901960784

Running Training Epoch 10/10


  0%|          | 0/32 [00:00<?, ?it/s]

Training Epoch: 10/10
Training Loss: 0.027712288312613964
Top 1 Acc: 1.0

Running Validation Epoch 10/10


  0%|          | 0/32 [00:00<?, ?it/s]

Validation Epoch: 10/10
Val Loss: 0.4804669988807291
Top 1 Acc: 0.9156862745098039



Export the model in case we want to reload in the future, so we do not have to rerun.

In [21]:
# export the fine-tuned dense model with "dense_model" as the exporter's output directory
save_dir = "dense_model"
exporter = ModuleExporter(model, output_dir=save_dir)
# PyTorch checkpoint; Part 2 reloads it from ./dense_model/training/ and reads its 'state_dict' entry
exporter.export_pytorch(name="densenet-121-dense-flowers.pth")
# ONNX export, using a random 1x3x224x224 tensor as the sample input
exporter.export_onnx(torch.randn(1, 3, 224, 224), name="dense-model.onnx", convert_qat=True)



## Part 2: Prune The Model

We load the model trained in the prior step as the starting point.

In [4]:
# ask PyTorch to release its cached, currently unused GPU memory before starting Part 2
torch.cuda.empty_cache()

In [5]:
# load the dense checkpoint onto the CPU first so this cell also works on a machine
# without CUDA (matching the CPU fallback used for `device`); the weights are moved
# to `device` together with the model below
checkpoint = torch.load(
    "./dense_model/training/densenet-121-dense-flowers.pth",
    map_location="cpu",
)
# rebuild the architecture with the NUM_LABELS-class head, then restore the fine-tuned weights
model = torchvision.models.densenet121()
model.classifier = torch.nn.Linear(model.classifier.in_features, NUM_LABELS)
model.load_state_dict(checkpoint['state_dict'])
model.to(device)

# set up the loss function and optimizer; the LR given here will be overridden by SparseML
criterion = CrossEntropyLoss()
optimizer = Adam(model.parameters(), lr=8e-3)

In [6]:
# show the layer names; the `params` entries of the pruning recipe refer to these names
print(model)

DenseNet(
  (features): Sequential(
    (conv0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (norm0): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu0): ReLU(inplace=True)
    (pool0): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (denseblock1): _DenseBlock(
      (denselayer1): _DenseLayer(
        (norm1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu1): ReLU(inplace=True)
        (conv1): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu2): ReLU(inplace=True)
        (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
      (denselayer2): _DenseLayer(
        (norm1): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu

We create a pruning recipe. We will start with the `GlobalMagnitudePruning` algorithm, pruning the targeted layers to 90% sparsity on average. We target all of the Conv layers (per the regexes in the `params` list below), which correspond to the layer names printed above. We will run GMP for 10 epochs and then fine-tune for the final 3 epochs.

In [7]:
# print the pruning recipe that the manager cell below loads
!cat recipe-0.yaml

# Epoch and Learning-Rate variables
num_epochs: 13.0
pruning_epochs: 10.0
init_lr: 0.0003
final_lr: 0.0001
inter_func: cubic
mask_type: unstructured

training_modifiers:
  - !EpochRangeModifier
    start_epoch: 0.0
    end_epoch: eval(num_epochs)

  - !LearningRateFunctionModifier
    final_lr: eval(final_lr)
    init_lr: eval(init_lr)
    lr_func: cosine
    start_epoch: 0.0
    end_epoch: eval(pruning_epochs)
    
  - !LearningRateFunctionModifier
    final_lr: eval(final_lr)
    init_lr: eval(init_lr)
    lr_func: cosine
    start_epoch: eval(pruning_epochs)
    end_epoch: eval(num_epochs)

# Pruning
pruning_modifiers:
  - !GlobalMagnitudePruningModifier
    init_sparsity: 0.05
    final_sparsity: 0.90
    start_epoch: 0.0
    end_epoch: eval(pruning_epochs)
    update_frequency: 0.5
    params: 
        - 'features.conv0.weight'
        - 're:features.denseblock1.*.conv1.weight'
        - 're:features.denseblock1.*.conv2.weight'
        - 're:

In [8]:
# create the ScheduledModifierManager from the pruning recipe and wrap the optimizer
manager = ScheduledModifierManager.from_yaml("./recipe-0.yaml")
logger = TensorBoardLogger(log_path="./tensorboard_outputs")
# NOTE: `optimizer` is rebound to the object returned by manager.modify, so re-running
# this cell without first re-creating the plain Adam optimizer passes the already
# modified optimizer back in
optimizer = manager.modify(model, optimizer, loggers=[logger], steps_per_epoch=len(train_loader))

In [9]:
epoch = 0
for epoch in range(manager.max_epochs):
    epoch_name = f"{epoch + 1}/{manager.max_epochs}"

    # training pass: weights are updated through the recipe-managed optimizer
    print(f"Running Training Epoch {epoch_name}")
    train_loss, train_acc = run_model_one_epoch(
        model,
        train_loader,
        criterion,
        device,
        train=True,
        optimizer=optimizer,
    )
    print(f"Training Epoch: {epoch_name}\nTraining Loss: {train_loss}\nTop 1 Acc: {train_acc}\n")

    # validation pass: metrics only, no weight updates
    print(f"Running Validation Epoch {epoch_name}")
    val_loss, val_acc = run_model_one_epoch(model, val_loader, criterion, device)
    print(f"Validation Epoch: {epoch_name}\nVal Loss: {val_loss}\nTop 1 Acc: {val_acc}\n")

    # record the four epoch-level metrics, in the same order as they were computed
    for tag, value in (
        ("Metrics/Loss (Train)", train_loss),
        ("Metrics/Accuracy (Train)", train_acc),
        ("Metrics/Loss (Validation)", val_loss),
        ("Metrics/Accuracy (Validation)", val_acc),
    ):
        logger.log_scalar(tag, value, epoch)

# all epochs done: let the manager finalize the model
manager.finalize(model)

Running Training Epoch 1/13


  0%|          | 0/32 [00:00<?, ?it/s]

Training Epoch: 1/13
Training Loss: 0.05704788229195401
Top 1 Acc: 0.9970588235294118

Running Validation Epoch 1/13


  0%|          | 0/32 [00:00<?, ?it/s]

Validation Epoch: 1/13
Val Loss: 0.5254904199391603
Top 1 Acc: 0.8970588235294118

Running Training Epoch 2/13


  0%|          | 0/32 [00:00<?, ?it/s]

Training Epoch: 2/13
Training Loss: 0.045509724208386615
Top 1 Acc: 0.996078431372549

Running Validation Epoch 2/13


  0%|          | 0/32 [00:00<?, ?it/s]

Validation Epoch: 2/13
Val Loss: 0.49433585489168763
Top 1 Acc: 0.888235294117647

Running Training Epoch 3/13


  0%|          | 0/32 [00:00<?, ?it/s]

Training Epoch: 3/13
Training Loss: 0.05249260214623064
Top 1 Acc: 0.9941176470588236

Running Validation Epoch 3/13


  0%|          | 0/32 [00:00<?, ?it/s]

Validation Epoch: 3/13
Val Loss: 0.4971598202828318
Top 1 Acc: 0.8921568627450981

Running Training Epoch 4/13


  0%|          | 0/32 [00:00<?, ?it/s]

Training Epoch: 4/13
Training Loss: 0.04789798497222364
Top 1 Acc: 0.9970588235294118

Running Validation Epoch 4/13


  0%|          | 0/32 [00:00<?, ?it/s]

Validation Epoch: 4/13
Val Loss: 0.405253738630563
Top 1 Acc: 0.9098039215686274

Running Training Epoch 5/13


  0%|          | 0/32 [00:00<?, ?it/s]

Training Epoch: 5/13
Training Loss: 0.03558145748684183
Top 1 Acc: 1.0

Running Validation Epoch 5/13


  0%|          | 0/32 [00:00<?, ?it/s]

Validation Epoch: 5/13
Val Loss: 0.40998171758838
Top 1 Acc: 0.9098039215686274

Running Training Epoch 6/13


  0%|          | 0/32 [00:00<?, ?it/s]

Training Epoch: 6/13
Training Loss: 0.043919957941398025
Top 1 Acc: 1.0

Running Validation Epoch 6/13


  0%|          | 0/32 [00:00<?, ?it/s]

Validation Epoch: 6/13
Val Loss: 0.44851411995477974
Top 1 Acc: 0.9009803921568628

Running Training Epoch 7/13


  0%|          | 0/32 [00:00<?, ?it/s]

Training Epoch: 7/13
Training Loss: 0.06865048292092979
Top 1 Acc: 0.9990196078431373

Running Validation Epoch 7/13


  0%|          | 0/32 [00:00<?, ?it/s]

Validation Epoch: 7/13
Val Loss: 0.500796954613179
Top 1 Acc: 0.8813725490196078

Running Training Epoch 8/13


  0%|          | 0/32 [00:00<?, ?it/s]

Training Epoch: 8/13
Training Loss: 0.08089102717349306
Top 1 Acc: 0.9950980392156863

Running Validation Epoch 8/13


  0%|          | 0/32 [00:00<?, ?it/s]

Validation Epoch: 8/13
Val Loss: 0.4999986342154443
Top 1 Acc: 0.8911764705882353

Running Training Epoch 9/13


  0%|          | 0/32 [00:00<?, ?it/s]

Training Epoch: 9/13
Training Loss: 0.05980767833534628
Top 1 Acc: 0.9990196078431373

Running Validation Epoch 9/13


  0%|          | 0/32 [00:00<?, ?it/s]

Validation Epoch: 9/13
Val Loss: 0.4825127385556698
Top 1 Acc: 0.8960784313725491

Running Training Epoch 10/13


  0%|          | 0/32 [00:00<?, ?it/s]

Training Epoch: 10/13
Training Loss: 0.029828229744452983
Top 1 Acc: 1.0

Running Validation Epoch 10/13


  0%|          | 0/32 [00:00<?, ?it/s]

Validation Epoch: 10/13
Val Loss: 0.45201362296938896
Top 1 Acc: 0.9009803921568628

Running Training Epoch 11/13


  0%|          | 0/32 [00:00<?, ?it/s]

Training Epoch: 11/13
Training Loss: 0.017061863880371675
Top 1 Acc: 1.0

Running Validation Epoch 11/13


  0%|          | 0/32 [00:00<?, ?it/s]

Validation Epoch: 11/13
Val Loss: 0.40311497752554715
Top 1 Acc: 0.9058823529411765

Running Training Epoch 12/13


  0%|          | 0/32 [00:00<?, ?it/s]

Training Epoch: 12/13
Training Loss: 0.01124472642550245
Top 1 Acc: 1.0

Running Validation Epoch 12/13


  0%|          | 0/32 [00:00<?, ?it/s]

Validation Epoch: 12/13
Val Loss: 0.38544022431597114
Top 1 Acc: 0.907843137254902

Running Training Epoch 13/13


  0%|          | 0/32 [00:00<?, ?it/s]

Training Epoch: 13/13
Training Loss: 0.009475693295826204
Top 1 Acc: 1.0

Running Validation Epoch 13/13


  0%|          | 0/32 [00:00<?, ?it/s]

Validation Epoch: 13/13
Val Loss: 0.38191458326764405
Top 1 Acc: 0.907843137254902



In [12]:
# report the measured sparsity of the weight tensor of every prunable layer
print("Checking Sparsity Level:")
for name, layer in get_prunable_layers(model):
    sparsity = tensor_sparsity(layer.weight).item()
    print(f"{name}.weight: {sparsity:.4f}")

Checking Sparsity Level:
features.conv0.weight: 0.4670
features.denseblock1.denselayer1.conv1.weight: 0.7366
features.denseblock1.denselayer1.conv2.weight: 0.8294
features.denseblock1.denselayer2.conv1.weight: 0.7461
features.denseblock1.denselayer2.conv2.weight: 0.8658
features.denseblock1.denselayer3.conv1.weight: 0.7488
features.denseblock1.denselayer3.conv2.weight: 0.8242
features.denseblock1.denselayer4.conv1.weight: 0.8116
features.denseblock1.denselayer4.conv2.weight: 0.8263
features.denseblock1.denselayer5.conv1.weight: 0.8702
features.denseblock1.denselayer5.conv2.weight: 0.8846
features.denseblock1.denselayer6.conv1.weight: 0.8377
features.denseblock1.denselayer6.conv2.weight: 0.8414
features.transition1.conv.weight: 0.7090
features.denseblock2.denselayer1.conv1.weight: 0.9321
features.denseblock2.denselayer1.conv2.weight: 0.8885
features.denseblock2.denselayer2.conv1.weight: 0.8816
features.denseblock2.denselayer2.conv2.weight: 0.8591
features.denseblock2.denselayer3.conv1.w

In [13]:
# export the pruned model with "experiment-0" as the exporter's output directory
save_dir = "experiment-0"
exporter = ModuleExporter(model, output_dir=save_dir)
# PyTorch checkpoint of the pruned weights
exporter.export_pytorch(name="densenet-121-sparse-flowers-experiment-0.pth")
# ONNX export, using a random 1x3x224x224 tensor as the sample input
exporter.export_onnx(torch.randn(1, 3, 224, 224), name="sparse-gmp-model.onnx", convert_qat=True)



The resulting model is 90% sparse and achieves a validation accuracy of 90.8% (vs. 91.6% for the unoptimized dense model) without much hyperparameter search.

Key hyperparameter experiments you may want to run include:
- Learning rate
- Learning rate schedule
- Sparsity level
- Skipping layers
- Number of pruning epochs