# Training with Weight Pruning

## Import the required modules

In [1]:
import numpy as np
import os
import pandas as pd
import random
import torch
import torchaudio
import torchaudio.transforms as T
import zipfile

from time import time
from torch import nn
from torch.nn.utils import prune
from torch.utils.data import Dataset

from msc_dataset_lab3 import MSCDataset

## Define the Hyperparameters

In [2]:
CFG = dict(
    # Audio front-end (MFCC) settings
    sampling_rate=16000,
    frame_length_in_s=0.04,
    frame_step_in_s=0.02,
    n_mels=40,
    f_min=0,
    f_max=8000,
    n_mfcc=40,
    # Reproducibility and training schedule
    seed=0,
    train_steps=2000,
    train_batch_size=32,
    # Pruning schedule (steps are the 0-based indices from enumerate()):
    start_pruning=499,      # first step eligible for pruning
    end_pruning=1499,       # last step eligible for pruning (inclusive)
    prune_amount=0.1,       # fraction of the still-unpruned weights removed per pruning step
    prune_every_steps=100,  # prune on steps where (step + 1) is a multiple of this
)

## Define the target classes

In [3]:
# TODO: Define the set of target classes
# CLASSES = [...]

## Set Deterministic Behaviour

In [4]:
# Seed NumPy, PyTorch and Python's built-in RNG from the same CFG value so
# that runs are repeatable.
np.random.seed(CFG['seed'])
torch.manual_seed(CFG['seed'])
random.seed(CFG['seed'])

## Create Datasets and Dataloaders for train/test

In [5]:
# Window and hop sizes in samples, derived from the durations in CFG.
# round() (rather than int()) avoids truncating a float product that lands
# just below an integer.
frame_length = round(CFG['frame_length_in_s'] * CFG['sampling_rate'])
frame_step = round(CFG['frame_step_in_s'] * CFG['sampling_rate'])

transform = T.MFCC(
    sample_rate=CFG['sampling_rate'],  # keep in sync with the configured rate
    n_mfcc=CFG['n_mfcc'],
    log_mels=True,
    melkwargs=dict(
        # Spectrogram parameters
        n_fft=frame_length,
        win_length=frame_length,
        hop_length=frame_step,
        center=False,
        # Mel spectrogram parameters
        f_min=CFG['f_min'],
        f_max=CFG['f_max'],
        n_mels=CFG['n_mels'],
    ),
)

# TODO: instantiate train_ds and test_ds objects
# train_ds = ...
# test_ds = ...

# Sample with replacement exactly train_steps * train_batch_size examples, so
# one pass over train_loader yields train_steps batches.
sampler = torch.utils.data.RandomSampler(
    train_ds,
    replacement=True,
    num_samples=CFG['train_steps'] * CFG['train_batch_size'],
)
train_loader = torch.utils.data.DataLoader(
    train_ds,
    batch_size=CFG['train_batch_size'],
    sampler=sampler,
    num_workers=2,
)

test_loader = torch.utils.data.DataLoader(
    test_ds, batch_size=100, num_workers=2
)

## Create the Model

In [5]:
# TODO: Write the model code
# model = ...

In [None]:
# Display the model's structure before pruning/training starts.
print(model)

## Get parameters to prune

In [None]:
# (module, parameter-name) pairs handed to prune.global_unstructured: the
# 'weight' tensor of every Conv2d and Linear layer in the model.
parameters_to_prune = [
    (layer, 'weight')
    for layer in model.modules()
    if isinstance(layer, (nn.Conv2d, nn.Linear))
]
parameters_to_prune  # shown as the cell output

## Define function to inspect model sparsity

In [8]:
def print_sparsity(model: nn.Module) -> None:
    """Print per-layer and global sparsity of the model's Conv2d/Linear layers.

    While a layer is still re-parametrized by ``torch.nn.utils.prune`` (it
    has a ``weight_mask`` attribute), zeros are counted in the mask; after
    ``prune.remove`` they are counted directly in ``weight``.

    Args:
        model: Module whose submodules are scanned.
    """
    total_zeros = 0
    total_params = 0
    print('\nSparsity Report:')
    for name, module in model.named_modules():
        if not isinstance(module, (nn.Conv2d, nn.Linear)):
            continue
        # Before prune.remove: use the binary mask.
        # After prune.remove: use the (dense) weight itself.
        if hasattr(module, 'weight_mask'):
            tensor = module.weight_mask
        else:
            tensor = module.weight
        num_zeros = torch.sum(tensor == 0).item()
        num_params = tensor.numel()

        layer_sparsity = 100. * num_zeros / num_params
        print(f'Layer {name:<30} | Sparsity: {layer_sparsity:.2f}%')
        total_zeros += num_zeros
        total_params += num_params

    if total_params == 0:
        # No prunable layers: report it instead of dividing by zero.
        print('No Conv2d/Linear layers found.\n')
        return

    global_sparsity = 100. * total_zeros / total_params
    print(f'Global Sparsity: {global_sparsity:.2f}%\n')

## Define pruning control variables

In [9]:
# Short names for the pruning schedule, read once from CFG for the loop below.
start_pruning, end_pruning = CFG['start_pruning'], CFG['end_pruning']
prune_amount, prune_every_steps = CFG['prune_amount'], CFG['prune_every_steps']

## Define the Training Loss and Optimizer

In [10]:
# TODO: instantiate the loss and optimizer objects
# loss_module = ...
# optimizer = ...

## Define the Training Loop with Iterative Weight Pruning

### Pruning Neural Network

The most critical step in neural network pruning is identifying the unimportant synapse connections (i.e., weights) and setting them to exactly zero. This step is also called weight pruning.

PyTorch uses binary mask tensors to indicate which synapse connections (i.e., which weights) are unimportant and should be pruned. The optimizer does not update a mask. It stays fixed during training or fine-tuning and changes only when a new pruning step is applied.

There are many ways to prune weights. Some straightforward methods use the magnitude of the weights to decide which ones are unimportant. For example, the `prune.L1Unstructured` method prunes a tensor by zeroing out the entries with the lowest L1-norm, i.e., the smallest absolute values.

### Sparsity for **Iterative** Pruning

The `prune.global_unstructured` function takes an `amount` argument. It can be either the fraction of connections to prune (a float between 0 and 1) or the absolute number of connections to prune (a non-negative integer). A fraction is relative to the number of parameters that are still unpruned. For example, in **iterative** pruning we prune the weights of a layer with `amount=0.2` in the first **iteration**, then prune the same layer again with `amount=0.2` in the second **iteration**. The fraction of parameters that remain after pruning is $1 \times (1 - 0.2) \times (1 - 0.2) = 0.64$. The sparsity of the parameters in this module, i.e., the prune rate, is therefore $1 - 1 \times (1 - 0.2) \times (1 - 0.2) = 0.36$.

Formally, the final prune rate could be calculated using the following equation. Suppose the relative prune rate for each **iteration** is $\gamma$, the final prune rate, after $n$ **iterations**, will be

$$1 - (1 - \gamma)^n$$

The final prune rate is just as easy to derive when $\gamma$ differs between **iterations**: with rate $\gamma_i$ at iteration $i$, it is $1 - \prod_{i=1}^{n} (1 - \gamma_i)$.

### One-Time VS Multi-Time Iterative Pruning + Fine-Tuning

**One-time iterative pruning + fine-tuning** reaches the desired prune rate by pruning and fine-tuning once. **Multi-time iterative pruning + fine-tuning** instead reaches it by pruning and fine-tuning multiple times. For example, to reach a prune rate of 68.62% with $\gamma = 0.1$, we can run pruning and fine-tuning for 11 iterations. The prune rate after each iteration is 10.00%, 19.00%, 27.10%, 34.39%, 40.95%, 46.86%, 52.17%, 56.95%, 61.26%, 65.13% and 68.62%.

Usually multi-time iterative pruning + fine-tuning is better than one-time iterative pruning + fine-tuning. 

### Local Pruning VS Global Pruning
**Local pruning** is to prune the parameters module by module. The parameters from other modules do not affect the parameters being pruned. We could specify the prune rate for each layer in the network explicitly.

**Global pruning** groups many different modules and prunes their parameters as if they came from a single module. As a result, the prune rate of each individual layer can differ.

Typically, global pruning performs better than local pruning.

In [None]:
# Write the training loop code
# Write the training loop code
for step, batch in enumerate(train_loader):
    # TODO:
    # ...
    # loss = ...
    # ...

    # Iterative global pruning: within [start_pruning, end_pruning] (inclusive),
    # prune on every step where (step + 1) is a multiple of prune_every_steps.
    if start_pruning <= step <= end_pruning and (step + 1) % prune_every_steps == 0:
        prune.global_unstructured(
            parameters_to_prune,
            pruning_method=prune.L1Unstructured,
            amount=prune_amount,
        )

    # Log on the very first step and then every 100 steps.
    if step == 0 or (step + 1) % 100 == 0:
        print(f'Step={step}; Training Loss={loss.item():.3f}')
        print_sparsity(model)

## Remove pruning re-parametrization

Once fine-tuning is finished, the model weights can be finalized. The `prune.remove` function combines each mask with its weight tensor and removes the pruning re-parametrization.

In [None]:
# Make pruning permanent: prune.remove folds each mask into its weight and
# drops the re-parametrization. Layers that were never pruned (e.g. training
# ended before start_pruning) have no '<name>_mask' and are skipped, because
# prune.remove raises ValueError for them.
for module, param_name in parameters_to_prune:
    if hasattr(module, f'{param_name}_mask'):
        prune.remove(module, param_name)

print_sparsity(model)

## Evaluate the Model

In [13]:
# Write the evaluation loop code
# ...
# test_accuracy = ...
# NOTE(review): test_accuracy is printed with a '%' suffix and logged to the
# results CSV, so it is presumably a percentage in [0, 100] — confirm once the
# evaluation loop is written.

print(f'Test Accuracy: {test_accuracy:.2f}%')

Test Accuracy: 90.25%


## Save the Model

In [None]:
timestamp = int(time())

saved_model_dir = './saved_models/'
os.makedirs(saved_model_dir, exist_ok=True)

# os.path.join avoids the doubled separator that f'{saved_model_dir}/...'
# produces with the trailing slash above.
frontend_path = os.path.join(saved_model_dir, f'{timestamp}_frontend.onnx')
model_path = os.path.join(saved_model_dir, f'{timestamp}_model.onnx')

print(f'Model Timestamp: {timestamp}')

# Export in inference mode so any train/eval-dependent layers (dropout,
# batch-norm) are captured with their evaluation behaviour.
model.eval()

torch.onnx.export(
    transform,  # model to export
    # One second of dummy audio as example input; shape presumably
    # (batch, channels, samples) — matches the original 1x1x16000 at 16 kHz.
    torch.randn(1, 1, CFG['sampling_rate']),
    frontend_path,  # filename of the ONNX model
    input_names=['input'],  # input name in the ONNX model
    dynamo=True,
    optimize=True,
    report=False,
    external_data=False,
)
torch.onnx.export(
    model,  # model to export
    train_ds[0]['x'].unsqueeze(0),  # first training sample with a leading batch dim
    model_path,  # filename of the ONNX model
    input_names=['input'],  # input name in the ONNX model
    dynamo=True,
    optimize=True,
    report=False,
    external_data=False,
)

## Zip the Model

Even after pruning is complete, the model's weights are still stored in dense matrix data structures. Despite the sparsity introduced in the weight tensors, the computational cost and storage requirements therefore remain the same as those of the original dense model. Latency improvements require specialized kernels for sparse matrix multiplication. For storage savings, lossless compression such as *zip* can encode the long runs of zeros introduced by pruning into a more compact representation.

In [None]:
frontend_path = os.path.join(saved_model_dir, f'{timestamp}_frontend.onnx')
model_path = os.path.join(saved_model_dir, f'{timestamp}_model.onnx')
zip_path = f'{model_path}.zip'

frontend_size = os.path.getsize(frontend_path)
model_size = os.path.getsize(model_path)
total_size = frontend_size + model_size

# Only the model is compressed; the frontend is counted at its ONNX size.
# arcname stores the file under its bare name instead of the default full
# relative path (saved_models/...) inside the archive.
with zipfile.ZipFile(zip_path, 'w', compression=zipfile.ZIP_DEFLATED) as archive:
    archive.write(model_path, arcname=os.path.basename(model_path))

zip_model_size = os.path.getsize(zip_path)
zip_total_size = frontend_size + zip_model_size

print(f'Frontend Size: {frontend_size / 2**10:.1f}KB')
print(f'Model Size (ONNX): {model_size / 2**10:.1f}KB')
print(f'Total Size (ONNX): {total_size / 2**10:.1f}KB')
print()
print(f'Model Size (ZIP): {zip_model_size / 2**10:.1f}KB')
print(f'Total Size (ZIP): {zip_total_size / 2**10:.1f}KB')

## Save Hyperparameters & Results

In [16]:
# One results row: run timestamp, every hyperparameter, and the test accuracy.
output_dict = {
    'timestamp': timestamp,
    **CFG,
    'test_accuracy': test_accuracy,
}
df = pd.DataFrame([output_dict])

# Append to the shared CSV; write the header only when the file is new.
output_path = './mwp_results.csv'
is_new_file = not os.path.exists(output_path)
df.to_csv(output_path, mode='a', header=is_new_file, index=False)

<a style='text-decoration:none;line-height:16px;display:flex;color:#5B5B62;padding:10px;justify-content:end;' href='https://deepnote.com?utm_source=created-in-deepnote-cell&projectId=3880e510-b64c-4bb5-b488-c2122d5d9e2d' target="_blank">
Created in <span style='font-weight:600;margin-left:4px;'>Deepnote</span></a>