# Model compression

Several methods can compress a model without a significant loss in performance:

- Pruning: Removing weights with a low magnitude (absolute value close to zero)
- Quantization: Reducing the numerical precision of weights and activations
- Knowledge distillation: Training a smaller (student) model to mimic the behaviour of a larger (teacher) model

This notebook applies pruning.


In [6]:
import torch
from torchvision import models
import torch.nn.utils.prune as prune


data_dir = "../data/test"
labels = ['fresh','blackspot','canker','grenning']
model_file_input = '../models/cross_validation_final.pth'
pruning_amount = 0.2
model_file_output = '../models/cross_validation_final_pruned.pth'


## Load model

In [7]:
model = models.mobilenet_v2()
model.classifier[1] = torch.nn.Linear(model.classifier[1].in_features, len(labels))
model.load_state_dict(torch.load(model_file_input, weights_only=True))
model = model.eval()


## Apply pruning

Apply pruning with the predefined pruning_amount (here 0.2, i.e. 20%). In each Conv2d layer, the 20% of weights with the smallest absolute value are set to 0.

In [8]:
# Define a function to apply pruning to all Conv2d layers
def apply_pruning(model, amount=pruning_amount):
    for name, module in model.named_modules():
        if isinstance(module, torch.nn.Conv2d):
            prune.l1_unstructured(module, name='weight', amount=amount)
apply_pruning(model)
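
As a small illustration of what prune.l1_unstructured does, the following sketch prunes a toy linear layer with hand-picked weights. The layer and its values are hypothetical and not part of this project. With amount=0.25, the two of eight weights with the smallest absolute value should be masked.

```python
import torch
import torch.nn.utils.prune as prune

# Toy layer with hand-picked weights (hypothetical values, for illustration only)
layer = torch.nn.Linear(4, 2, bias=False)
with torch.no_grad():
    layer.weight.copy_(torch.tensor([[0.1, -0.9, 0.5, -0.05],
                                     [0.7, -0.3, 0.2, 0.8]]))

# Mask the 25% of weights (2 of 8) with the smallest absolute value
prune.l1_unstructured(layer, name='weight', amount=0.25)

# weight_orig keeps the original values, weight_mask marks the kept entries,
# and weight is recomputed as weight_orig * weight_mask
print(layer.weight_mask)
print(layer.weight)
```

The original values stay available in weight_orig until the pruning is made permanent.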

## Evaluate pruning

Count the weights whose mask entry is zero. The pruned share should match the pruning_amount specified above.

In [9]:
# Function to count the pruned weights
def count_pruned_weights(model):
    total_weights = 0
    pruned_weights = 0
    for name, module in model.named_modules():
        if isinstance(module, torch.nn.Conv2d):
            # The mask is stored in 'weight_mask' after pruning
            if hasattr(module, 'weight_mask'):
                weight_mask = module.weight_mask
                total_weights += weight_mask.numel()
                pruned_weights += (weight_mask == 0).sum().item()
    return pruned_weights, total_weights

pruned_weights, total_weights = count_pruned_weights(model)
pruned_percentage = 100 * pruned_weights / total_weights

print(f"Total weights: {total_weights}")
print(f"Pruned weights: {pruned_weights}")
print(f"Pruned percentage: {pruned_percentage:.2f}%")


Total weights: 2189760
Pruned weights: 437952
Pruned percentage: 20.00%
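
The overall percentage results from pruning each layer separately. Assuming the usual behaviour of torch.nn.utils.prune, a fractional amount prunes round(amount * number_of_weights) entries per layer, so small layers can deviate slightly from the target. A minimal sketch with a hypothetical two-layer toy model (not the MobileNetV2 used here):

```python
import torch
import torch.nn.utils.prune as prune

# Hypothetical toy model, used only to illustrate per-layer pruning counts
toy = torch.nn.Sequential(
    torch.nn.Conv2d(3, 8, kernel_size=3),   # 8 * 3 * 3 * 3 = 216 weights
    torch.nn.Conv2d(8, 4, kernel_size=1),   # 4 * 8 * 1 * 1 = 32 weights
)

amount = 0.2
per_layer = []
for name, module in toy.named_modules():
    if isinstance(module, torch.nn.Conv2d):
        prune.l1_unstructured(module, name='weight', amount=amount)
        mask = module.weight_mask
        pruned = int((mask == 0).sum().item())
        per_layer.append((name, pruned, mask.numel()))
        print(f"layer {name}: {pruned}/{mask.numel()} pruned "
              f"({100 * pruned / mask.numel():.2f}%)")
```

A per-layer report like this can help spot layers that lose a larger or smaller share of weights than intended.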


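## Remove pruning masks and store model to file

Make the pruning permanent: prune.remove drops the weight_orig and weight_mask reparametrization and leaves a plain weight tensor in which the pruned entries are zero. The model is then saved to file. Note that the tensors are still stored densely, so the zeros alone do not shrink the saved file.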

In [10]:
def remove_pruning(model):
    for name, module in model.named_modules():
        if isinstance(module, torch.nn.Conv2d):
            if hasattr(module, 'weight_mask'):
                prune.remove(module, 'weight')

remove_pruning(model)

torch.save(model.state_dict(), model_file_output)
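
Once prune.remove has been called, weight_mask no longer exists, so the mask-based count above cannot be reused on the saved model. One possible check, sketched here on a hypothetical toy model with a temporary file path, is to reload the state dict and count zero entries in the weight tensors directly. The helper count_zero_weights is an illustrative name, not part of this project.

```python
import os
import tempfile

import torch
import torch.nn.utils.prune as prune


def count_zero_weights(model):
    """Count zero-valued entries in the weight tensors of all Conv2d layers."""
    zeros, total = 0, 0
    for module in model.modules():
        if isinstance(module, torch.nn.Conv2d):
            zeros += int((module.weight == 0).sum().item())
            total += module.weight.numel()
    return zeros, total


# Hypothetical toy model standing in for the pruned network
toy = torch.nn.Sequential(torch.nn.Conv2d(3, 8, kernel_size=3))
prune.l1_unstructured(toy[0], name='weight', amount=0.5)
prune.remove(toy[0], 'weight')  # make the pruning permanent

# After prune.remove the state dict has a plain 'weight' entry again
keys = list(toy.state_dict().keys())

# Round trip through a temporary file (the path is an assumption for this sketch)
with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, 'toy_pruned.pth')
    torch.save(toy.state_dict(), path)
    reloaded = torch.nn.Sequential(torch.nn.Conv2d(3, 8, kernel_size=3))
    reloaded.load_state_dict(torch.load(path, weights_only=True))

zeros, total = count_zero_weights(reloaded)
print(keys)
print(f"{zeros}/{total} Conv2d weights are zero after reloading")
```

Because the saved state dict contains only plain weight entries, it can be loaded into an unpruned model definition, as the model_validation notebook would need.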



## Conclusion

After this notebook, the resulting model was evaluated with the model_validation notebook. A pruning amount of 20% led to only a small performance loss while setting 20% of the Conv2d weights to zero. Depending on the application the model should run on, further compression methods may be required.