

Open source PyTorch library to facilitate development and standardized evaluation of neural network pruning methods.


This repo contains the analysis and benchmark results from the paper What is the State of Neural Network Pruning?.


First, install the dependencies. This repo depends on:

  • PyTorch
  • Torchvision
  • NumPy
  • Pandas
  • Matplotlib

To install the dependencies:

# Create a python virtualenv or conda env as necessary

# With conda
conda install numpy matplotlib pandas
conda install pytorch torchvision -c pytorch

# With pip
pip install numpy matplotlib pandas torch torchvision

Then, to install the module itself, clone the repo and add its parent directory to your PYTHONPATH. For example:

git clone https://github.com/JJGO/shrinkbench.git

# Run these from the directory where you cloned (the parent of shrinkbench/)

# Bash
echo "export PYTHONPATH=\"$PWD:\$PYTHONPATH\"" >> ~/.bashrc

# Zsh
echo "export PYTHONPATH=\"$PWD:\$PYTHONPATH\"" >> ~/.zshrc
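If you prefer not to edit your shell profile, the same effect can be achieved from Python itself. A minimal sketch, where `/path/to/parent` is a hypothetical placeholder for the directory that contains the shrinkbench/ clone:

```python
# Alternative to the shell export: extend sys.path at the top of a
# script or notebook instead of setting PYTHONPATH globally.
import sys

parent = "/path/to/parent"  # hypothetical: directory containing shrinkbench/
if parent not in sys.path:
    sys.path.insert(0, parent)
# After this, `import shrinkbench` resolves if the clone exists there.
```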


ShrinkBench not only facilitates evaluation of pruning methods, but also their development. Here's the code for simple implementations of Global Magnitude Pruning and Layerwise Magnitude Pruning. As you can see, it is quite succinct; you only need to implement `model_masks`, a function that returns the masks for the model's weight tensors. If you want to prune your model layer by layer, you instead implement `layer_masks`. For more examples, see the source code for the provided baselines.

import numpy as np

# Pruning base classes and mask helpers ship with the library;
# the exact import paths shown here are indicative
from shrinkbench.pruning import LayerPruning, VisionPruning
from shrinkbench.strategies.utils import (fraction_mask, fraction_threshold,
                                          flatten_importances, importance_masks,
                                          map_importances)

class GlobalMagWeight(VisionPruning):

    def model_masks(self):
        # Importance of each weight is its absolute magnitude
        importances = map_importances(np.abs, self.params())
        flat_importances = flatten_importances(importances)
        # Single global threshold so `fraction` of all weights is pruned
        threshold = fraction_threshold(flat_importances, self.fraction)
        masks = importance_masks(importances, threshold)
        return masks

class LayerMagWeight(LayerPruning, VisionPruning):

    def layer_masks(self, module):
        params = self.module_params(module)
        importances = {param: np.abs(value) for param, value in params.items()}
        # Per-layer threshold: prune `fraction` of each layer independently
        masks = {param: fraction_mask(importances[param], self.fraction)
                 for param, value in params.items() if value is not None}
        return masks
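The mask helpers used above (`fraction_threshold`, `fraction_mask`, etc.) are provided by the library. Conceptually they are thin wrappers around NumPy; a rough sketch of the two core ones, not the library's actual implementation:

```python
import numpy as np

def fraction_threshold(flat_importances, fraction):
    # Importance value below which `fraction` of the entries fall;
    # pruning everything under it removes roughly that fraction of weights.
    return np.quantile(flat_importances, fraction)

def fraction_mask(importances, fraction):
    # Binary mask keeping the (1 - fraction) highest-importance entries.
    threshold = fraction_threshold(importances.flatten(), fraction)
    return (importances > threshold).astype(int)

# Toy example: prune half of a 2x2 weight tensor by magnitude
weights = np.array([[0.1, -0.9], [0.5, -0.2]])
mask = fraction_mask(np.abs(weights), 0.5)
# mask keeps the two largest-magnitude weights (-0.9 and 0.5)
```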


See the notebooks in the jupyter/ folder for examples of how to run pruning experiments and plot their results.


The modules are organized as follows:

| Submodule | Description |
| --- | --- |
| `analysis/` | Aggregated survey results over 80 pruning papers |
| `datasets/` | Standardized dataloaders for supported datasets |
| `experiment/` | Main experiment class with the data loading, pruning, finetuning & evaluation |
| `metrics/` | Utils for measuring accuracy, model size, FLOPs & memory footprint |
| `models/` | Custom architectures not included in torchvision |
| `plot/` | Utils for plotting across the logged dimensions |
| `pruning/` | General pruning and masking API |
| `scripts/` | Executable scripts for running experiments (see `experiment/`) |
| `strategies/` | Baseline pruning methods, mainly magnitude-pruning based |