<a href="https://colab.research.google.com/github/bhartiansh/cnn_pruning_cifar10/blob/main/global_magnitude_pruning_resnet20.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!git clone https://github.com/bhartiansh/cnn_pruning_cifar10.git
%cd cnn_pruning_cifar10

!pip install -q tensorflow-model-optimization

import tensorflow as tf
import numpy as np
from models.resnet20 import build_resnet20
from tensorflow.keras.datasets import cifar10

fatal: destination path 'cnn_pruning_cifar10' already exists and is not an empty directory.
/content/cnn_pruning_cifar10


In [2]:
def get_global_magnitude_scores(model):
    scores = []
    for weight in model.trainable_weights:
        if 'kernel' in weight.name:
            scores.append(tf.reshape(tf.abs(weight), [-1]))
    return tf.concat(scores, axis=0)

def apply_global_magnitude_pruning(model, sparsity):
    all_scores = get_global_magnitude_scores(model)
    k = int((1 - sparsity) * all_scores.shape[0])
    threshold = tf.sort(all_scores)[k]

    for weight in model.trainable_weights:
        if 'kernel' in weight.name:
            mask = tf.cast(tf.abs(weight) >= threshold, tf.float32)
            weight.assign(weight * mask)

In [3]:
def train_fast_pruned_resnet20(sparsity, epochs=10, batch_size=64):
    # Load and preprocess CIFAR-10
    (x_train, y_train), (x_test, y_test) = cifar10.load_data()
    x_train, x_test = x_train / 255.0, x_test / 255.0

    # Build and prune model
    model = build_resnet20()
    apply_global_magnitude_pruning(model, sparsity)

    # Compile and train
    model.compile(optimizer='adam',
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])

    history = model.fit(x_train, y_train,
                        validation_data=(x_test, y_test),
                        epochs=epochs,
                        batch_size=batch_size,
                        verbose=2)
    return model, history

In [5]:
# Example: Run with 30% pruning and 30 epochs
model, history = train_fast_pruned_resnet20(sparsity=0.3, epochs=50)

Epoch 1/50
782/782 - 62s - loss: 1.3560 - accuracy: 0.5081 - val_loss: 1.4153 - val_accuracy: 0.5122 - 62s/epoch - 80ms/step
Epoch 2/50
782/782 - 54s - loss: 0.9303 - accuracy: 0.6700 - val_loss: 0.9656 - val_accuracy: 0.6648 - 54s/epoch - 68ms/step
Epoch 3/50
782/782 - 53s - loss: 0.7513 - accuracy: 0.7375 - val_loss: 1.0316 - val_accuracy: 0.6681 - 53s/epoch - 68ms/step
Epoch 4/50
782/782 - 53s - loss: 0.6364 - accuracy: 0.7794 - val_loss: 0.8405 - val_accuracy: 0.7163 - 53s/epoch - 68ms/step
Epoch 5/50
782/782 - 54s - loss: 0.5495 - accuracy: 0.8075 - val_loss: 0.8557 - val_accuracy: 0.7132 - 54s/epoch - 69ms/step
Epoch 6/50
782/782 - 54s - loss: 0.4934 - accuracy: 0.8277 - val_loss: 0.7926 - val_accuracy: 0.7309 - 54s/epoch - 69ms/step
Epoch 7/50
782/782 - 53s - loss: 0.4379 - accuracy: 0.8485 - val_loss: 0.7421 - val_accuracy: 0.7559 - 53s/epoch - 68ms/step
Epoch 8/50
782/782 - 54s - loss: 0.3901 - accuracy: 0.8639 - val_loss: 0.8007 - val_accuracy: 0.7428 - 54s/epoch - 68ms/step


KeyboardInterrupt: 

In [6]:
# Example: Run with 50% pruning and 30 epochs
model, history = train_fast_pruned_resnet20(sparsity=0.5, epochs=30)

Epoch 1/30
782/782 - 60s - loss: 1.3320 - accuracy: 0.5135 - val_loss: 1.4734 - val_accuracy: 0.5053 - 60s/epoch - 77ms/step
Epoch 2/30
782/782 - 53s - loss: 0.9180 - accuracy: 0.6760 - val_loss: 1.0126 - val_accuracy: 0.6521 - 53s/epoch - 68ms/step
Epoch 3/30
782/782 - 54s - loss: 0.7444 - accuracy: 0.7399 - val_loss: 0.9970 - val_accuracy: 0.6713 - 54s/epoch - 68ms/step
Epoch 4/30
782/782 - 54s - loss: 0.6324 - accuracy: 0.7791 - val_loss: 0.7196 - val_accuracy: 0.7551 - 54s/epoch - 69ms/step
Epoch 5/30
782/782 - 54s - loss: 0.5494 - accuracy: 0.8083 - val_loss: 0.7149 - val_accuracy: 0.7462 - 54s/epoch - 69ms/step
Epoch 6/30
782/782 - 54s - loss: 0.4862 - accuracy: 0.8316 - val_loss: 0.8046 - val_accuracy: 0.7352 - 54s/epoch - 69ms/step
Epoch 7/30
782/782 - 54s - loss: 0.4264 - accuracy: 0.8513 - val_loss: 0.7188 - val_accuracy: 0.7606 - 54s/epoch - 68ms/step
Epoch 8/30
782/782 - 53s - loss: 0.3830 - accuracy: 0.8659 - val_loss: 0.6733 - val_accuracy: 0.7797 - 53s/epoch - 68ms/step


In [None]:
# Example: Run with 70% pruning and 30 epochs
model, history = train_fast_pruned_resnet20(sparsity=0.7, epochs=30)

Epoch 1/30
782/782 - 60s - loss: 1.3656 - accuracy: 0.5028 - val_loss: 1.9691 - val_accuracy: 0.4219 - 60s/epoch - 77ms/step
Epoch 2/30
782/782 - 53s - loss: 0.9020 - accuracy: 0.6813 - val_loss: 1.1451 - val_accuracy: 0.6207 - 53s/epoch - 68ms/step
Epoch 3/30
782/782 - 54s - loss: 0.7191 - accuracy: 0.7481 - val_loss: 0.9738 - val_accuracy: 0.6760 - 54s/epoch - 69ms/step
Epoch 4/30
782/782 - 53s - loss: 0.6082 - accuracy: 0.7873 - val_loss: 0.8738 - val_accuracy: 0.7201 - 53s/epoch - 68ms/step
Epoch 5/30
782/782 - 54s - loss: 0.5392 - accuracy: 0.8113 - val_loss: 0.7949 - val_accuracy: 0.7276 - 54s/epoch - 69ms/step
Epoch 6/30
782/782 - 54s - loss: 0.4795 - accuracy: 0.8324 - val_loss: 0.7898 - val_accuracy: 0.7341 - 54s/epoch - 69ms/step
Epoch 7/30
782/782 - 54s - loss: 0.4287 - accuracy: 0.8522 - val_loss: 0.9360 - val_accuracy: 0.7132 - 54s/epoch - 69ms/step
Epoch 8/30
782/782 - 54s - loss: 0.3764 - accuracy: 0.8701 - val_loss: 0.7132 - val_accuracy: 0.7668 - 54s/epoch - 68ms/step


In [None]:
# 1. Force compatible versions (reset everything to Colab defaults)
!pip install -U --force-reinstall numpy==1.23.5
!pip install -U --force-reinstall tensorflow==2.14.0
!pip install -U tensorflow-model-optimization

# 2. Restart runtime automatically after install
import os
os.kill(os.getpid(), 9)

Collecting numpy==1.23.5
  Downloading numpy-1.23.5-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (2.3 kB)
Downloading numpy-1.23.5-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (17.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m17.1/17.1 MB[0m [31m103.3 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: numpy
  Attempting uninstall: numpy
    Found existing installation: numpy 1.26.4
    Uninstalling numpy-1.26.4:
      Successfully uninstalled numpy-1.26.4
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
pymc 5.21.2 requires numpy>=1.25.0, but you have numpy 1.23.5 which is incompatible.
tensorflow 2.18.0 requires numpy<2.1.0,>=1.26.0, but you have numpy 1.23.5 which is incompatible.
blosc2 3.3.0 requires numpy>=1.26, but you have numpy 1.23.5 which is incompatible.
jax 0.5.2 requir

Collecting tensorflow==2.14.0
  Downloading tensorflow-2.14.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (4.1 kB)
Collecting absl-py>=1.0.0 (from tensorflow==2.14.0)
  Downloading absl_py-2.2.2-py3-none-any.whl.metadata (2.6 kB)
Collecting astunparse>=1.6.0 (from tensorflow==2.14.0)
  Downloading astunparse-1.6.3-py2.py3-none-any.whl.metadata (4.4 kB)
Collecting flatbuffers>=23.5.26 (from tensorflow==2.14.0)
  Downloading flatbuffers-25.2.10-py2.py3-none-any.whl.metadata (875 bytes)
Collecting gast!=0.5.0,!=0.5.1,!=0.5.2,>=0.2.1 (from tensorflow==2.14.0)
  Downloading gast-0.6.0-py3-none-any.whl.metadata (1.3 kB)
Collecting google-pasta>=0.1.1 (from tensorflow==2.14.0)
  Downloading google_pasta-0.2.0-py3-none-any.whl.metadata (814 bytes)
Collecting h5py>=2.9.0 (from tensorflow==2.14.0)
  Downloading h5py-3.13.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (2.5 kB)
Collecting libclang>=13.0.0 (from tensorflow==2.14.0)
  Downloading libcla

Collecting absl-py~=1.2 (from tensorflow-model-optimization)
  Downloading absl_py-1.4.0-py3-none-any.whl.metadata (2.3 kB)
Collecting numpy~=1.23 (from tensorflow-model-optimization)
  Using cached numpy-1.26.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (61 kB)
Downloading absl_py-1.4.0-py3-none-any.whl (126 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m126.5/126.5 kB[0m [31m2.2 MB/s[0m eta [36m0:00:00[0m
[?25hUsing cached numpy-1.26.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (18.3 MB)
Installing collected packages: numpy, absl-py
  Attempting uninstall: numpy
    Found existing installation: numpy 2.2.4
    Uninstalling numpy-2.2.4:
      Successfully uninstalled numpy-2.2.4
  Attempting uninstall: absl-py
    Found existing installation: absl-py 2.2.2
    Uninstalling absl-py-2.2.2:
      Successfully uninstalled absl-py-2.2.2
[31mERROR: pip's dependency resolver does not currently take into account all the packag