<a href="https://colab.research.google.com/github/bhartiansh/cnn_pruning_cifar10/blob/main/SNIP_ResNet20.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!git clone https://github.com/bhartiansh/cnn_pruning_cifar10.git
%cd cnn_pruning_cifar10

!pip install -q tensorflow-model-optimization

Cloning into 'cnn_pruning_cifar10'...
remote: Enumerating objects: 206, done.[K
remote: Counting objects: 100% (62/62), done.[K
remote: Compressing objects: 100% (60/60), done.[K
remote: Total 206 (delta 32), reused 2 (delta 2), pack-reused 144 (from 1)[K
Receiving objects: 100% (206/206), 124.13 KiB | 676.00 KiB/s, done.
Resolving deltas: 100% (95/95), done.
/content/cnn_pruning_cifar10


In [2]:
import tensorflow as tf
import numpy as np
from models.resnet20 import build_resnet20

def compute_snip_scores(model, x_sample, y_sample):
    """Return the flattened SNIP saliency scores of the model's kernel weights.

    Runs one forward pass in training mode on ``(x_sample, y_sample)``, takes
    the mean sparse categorical cross-entropy (called with its default
    arguments) as the loss, and differentiates it with respect to every
    trainable weight.

    Args:
        model: Callable Keras-style model exposing ``trainable_weights``.
        x_sample: Batch of inputs passed to ``model``.
        y_sample: Labels for ``x_sample``, passed as the ``y_true`` argument of
            ``sparse_categorical_crossentropy``.

    Returns:
        A 1-D tensor holding ``|w * dL/dw|`` for every element of every
        trainable weight whose name contains ``'kernel'``, concatenated in the
        order of ``model.trainable_weights``.
    """
    with tf.GradientTape() as tape:
        outputs = model(x_sample, training=True)
        batch_loss = tf.reduce_mean(
            tf.keras.losses.sparse_categorical_crossentropy(y_sample, outputs)
        )

    # Read the weight list after the forward pass, as the pass itself may be
    # what creates the variables.
    weights = model.trainable_weights
    gradients = tape.gradient(batch_loss, weights)

    # Only weights named "...kernel..." are scored; everything else is skipped.
    per_kernel_scores = [
        tf.reshape(tf.abs(weight * gradient), [-1])
        for weight, gradient in zip(weights, gradients)
        if 'kernel' in weight.name
    ]
    return tf.concat(per_kernel_scores, axis=0)

def apply_snip_mask(model, x_sample, y_sample, sparsity=0.5):
    """Zero out, in place, the kernel weights with the lowest SNIP scores.

    A single global threshold is chosen over the scores of all kernel weights
    so that roughly a ``sparsity`` fraction of them fall below it. Each kernel
    element whose *score* (``|w * dL/dw|``) is below the threshold is set to
    zero; elements scoring exactly at the threshold are kept, so ties can make
    the achieved sparsity slightly lower than requested.

    The mask is applied once and is not stored: nothing in this function stops
    later training from making the zeroed weights non-zero again.

    Args:
        model: Keras-style model whose ``trainable_weights`` are pruned.
        x_sample: Input batch used to compute the scores.
        y_sample: Labels for ``x_sample``.
        sparsity: Fraction of kernel weights to zero, in ``[0, 1]``.

    Raises:
        ValueError: If ``sparsity`` is outside ``[0, 1]`` or the score vector
            does not line up with the model's kernel weights.
    """
    if not 0.0 <= sparsity <= 1.0:
        raise ValueError(f"sparsity must be in [0, 1], got {sparsity!r}")

    scores = compute_snip_scores(model, x_sample, y_sample)
    total = int(scores.shape[0])
    num_pruned = int(sparsity * total)
    if num_pruned == 0:
        return  # Nothing to prune.

    # Same selection and order that compute_snip_scores uses, so the flat
    # score vector can be cut back into one chunk per kernel.
    kernels = [w for w in model.trainable_weights if 'kernel' in w.name]
    sizes = [int(tf.size(w)) for w in kernels]
    if sum(sizes) != total:
        raise ValueError(
            f"score vector has {total} entries but kernel weights have "
            f"{sum(sizes)} elements"
        )

    # Ascending sort: the first `num_pruned` entries are the ones to drop, so
    # the smallest score that survives sits at index `num_pruned`.
    prune_all = num_pruned >= total
    threshold = None if prune_all else tf.sort(scores)[num_pruned]

    for w, flat_scores in zip(kernels, tf.split(scores, sizes)):
        if prune_all:
            keep = tf.zeros(tf.shape(w), dtype=w.dtype)
        else:
            # Mask on the SNIP score of each element, not on |w| itself.
            kernel_scores = tf.reshape(flat_scores, tf.shape(w))
            keep = tf.cast(kernel_scores >= threshold, w.dtype)
        w.assign(w * keep)

def quick_train(model, x_train, y_train, x_test, y_test, epochs=10, batch_size=64):
    """Compile ``model`` and fit it, validating on ``(x_test, y_test)``.

    The model is compiled with the ``'adam'`` optimizer, the
    ``'sparse_categorical_crossentropy'`` loss and an ``'accuracy'`` metric,
    then fitted with one summary line per epoch (``verbose=2``).

    Args:
        model: Object exposing Keras-style ``compile`` and ``fit`` methods.
        x_train: Training inputs.
        y_train: Training labels.
        x_test: Inputs used as validation data during fitting.
        y_test: Labels used as validation data during fitting.
        epochs: Number of epochs to fit for.
        batch_size: Mini-batch size passed to ``fit``.

    Returns:
        The same ``model`` object, after fitting.
    """
    compile_options = dict(
        optimizer='adam',
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy'],
    )
    fit_options = dict(
        validation_data=(x_test, y_test),
        batch_size=batch_size,
        epochs=epochs,
        verbose=2,
    )

    model.compile(**compile_options)
    model.fit(x_train, y_train, **fit_options)
    return model

In [3]:
# === Run it ===
if __name__ == "__main__":
    sparsity = 0.3        # sparsity level passed to apply_snip_mask
    snip_batch_size = 64  # number of leading training samples used for scoring

    # Load CIFAR-10 and scale pixels to [0, 1]. Cast to float32 before dividing:
    # dividing the integer image arrays by 255.0 directly produces float64
    # arrays, which doubles their memory footprint for no benefit.
    (x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
    x_train = x_train.astype("float32") / 255.0
    x_test = x_test.astype("float32") / 255.0
    # Flatten the label arrays to 1-D.
    y_train, y_test = y_train.flatten(), y_test.flatten()

    # Build model
    model = build_resnet20()

    # Run dummy forward pass to initialize weights
    model(tf.convert_to_tensor(x_train[:1]), training=False)

    # SNIP pruning using small batch
    apply_snip_mask(
        model,
        x_train[:snip_batch_size],
        y_train[:snip_batch_size],
        sparsity=sparsity,
    )

    # Train the masked network on the full training set (30 epochs), using the
    # test split as validation data.
    quick_train(model, x_train, y_train, x_test, y_test, epochs=30, batch_size=64)

Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
Epoch 1/30
782/782 - 61s - loss: 1.3401 - accuracy: 0.5109 - val_loss: 1.1363 - val_accuracy: 0.5894 - 61s/epoch - 78ms/step
Epoch 2/30
782/782 - 54s - loss: 0.9124 - accuracy: 0.6764 - val_loss: 1.2805 - val_accuracy: 0.5745 - 54s/epoch - 69ms/step
Epoch 3/30
782/782 - 57s - loss: 0.7387 - accuracy: 0.7416 - val_loss: 1.0174 - val_accuracy: 0.6621 - 57s/epoch - 73ms/step
Epoch 4/30
782/782 - 57s - loss: 0.6244 - accuracy: 0.7817 - val_loss: 0.7847 - val_accuracy: 0.7267 - 57s/epoch - 74ms/step
Epoch 5/30
782/782 - 55s - loss: 0.5460 - accuracy: 0.8086 - val_loss: 1.1357 - val_accuracy: 0.6390 - 55s/epoch - 70ms/step
Epoch 6/30
782/782 - 54s - loss: 0.4848 - accuracy: 0.8329 - val_loss: 0.7083 - val_accuracy: 0.7633 - 54s/epoch - 69ms/step
Epoch 7/30
782/782 - 54s - loss: 0.4294 - accuracy: 0.8519 - val_loss: 0.7287 - val_accuracy: 0.7584 - 54s/epoch - 69ms/step
Epoch 8/30
782/782 - 54s - loss: 0.3783 - accur

In [4]:
# === Run it ===
if __name__ == "__main__":
    sparsity = 0.5        # sparsity level passed to apply_snip_mask
    snip_batch_size = 64  # number of leading training samples used for scoring

    # Load CIFAR-10 and scale pixels to [0, 1]. Cast to float32 before dividing:
    # dividing the integer image arrays by 255.0 directly produces float64
    # arrays, which doubles their memory footprint for no benefit.
    (x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
    x_train = x_train.astype("float32") / 255.0
    x_test = x_test.astype("float32") / 255.0
    # Flatten the label arrays to 1-D.
    y_train, y_test = y_train.flatten(), y_test.flatten()

    # Build model
    model = build_resnet20()

    # Run dummy forward pass to initialize weights
    model(tf.convert_to_tensor(x_train[:1]), training=False)

    # SNIP pruning using small batch
    apply_snip_mask(
        model,
        x_train[:snip_batch_size],
        y_train[:snip_batch_size],
        sparsity=sparsity,
    )

    # Train the masked network on the full training set (30 epochs), using the
    # test split as validation data.
    quick_train(model, x_train, y_train, x_test, y_test, epochs=30, batch_size=64)

Epoch 1/30
782/782 - 61s - loss: 1.3626 - accuracy: 0.5041 - val_loss: 1.2454 - val_accuracy: 0.5578 - 61s/epoch - 79ms/step
Epoch 2/30
782/782 - 54s - loss: 0.9403 - accuracy: 0.6639 - val_loss: 1.0378 - val_accuracy: 0.6401 - 54s/epoch - 69ms/step
Epoch 3/30


KeyboardInterrupt: 

In [None]:
# === Run it ===
if __name__ == "__main__":
    sparsity = 0.7        # sparsity level passed to apply_snip_mask
    snip_batch_size = 64  # number of leading training samples used for scoring

    # Load CIFAR-10 and scale pixels to [0, 1]. Cast to float32 before dividing:
    # dividing the integer image arrays by 255.0 directly produces float64
    # arrays, which doubles their memory footprint for no benefit.
    (x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
    x_train = x_train.astype("float32") / 255.0
    x_test = x_test.astype("float32") / 255.0
    # Flatten the label arrays to 1-D.
    y_train, y_test = y_train.flatten(), y_test.flatten()

    # Build model
    model = build_resnet20()

    # Run dummy forward pass to initialize weights
    model(tf.convert_to_tensor(x_train[:1]), training=False)

    # SNIP pruning using small batch
    apply_snip_mask(
        model,
        x_train[:snip_batch_size],
        y_train[:snip_batch_size],
        sparsity=sparsity,
    )

    # Train the masked network on the full training set (30 epochs), using the
    # test split as validation data.
    quick_train(model, x_train, y_train, x_test, y_test, epochs=30, batch_size=64)

In [None]:
# 1. Force compatible versions (reset everything to Colab defaults)
!pip install -U --force-reinstall numpy==1.23.5
!pip install -U --force-reinstall tensorflow==2.14.0
!pip install -U tensorflow-model-optimization

# 2. Restart runtime automatically after install
import os
os.kill(os.getpid(), 9)

Collecting numpy==1.23.5
  Downloading numpy-1.23.5-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (2.3 kB)
Downloading numpy-1.23.5-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (17.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m17.1/17.1 MB[0m [31m98.0 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: numpy
  Attempting uninstall: numpy
    Found existing installation: numpy 2.0.2
    Uninstalling numpy-2.0.2:
      Successfully uninstalled numpy-2.0.2
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
pymc 5.21.2 requires numpy>=1.25.0, but you have numpy 1.23.5 which is incompatible.
tensorflow 2.18.0 requires numpy<2.1.0,>=1.26.0, but you have numpy 1.23.5 which is incompatible.
blosc2 3.3.0 requires numpy>=1.26, but you have numpy 1.23.5 which is incompatible.
jax 0.5.2 requires n

Collecting tensorflow==2.14.0
  Downloading tensorflow-2.14.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (4.1 kB)
Collecting absl-py>=1.0.0 (from tensorflow==2.14.0)
  Downloading absl_py-2.2.2-py3-none-any.whl.metadata (2.6 kB)
Collecting astunparse>=1.6.0 (from tensorflow==2.14.0)
  Downloading astunparse-1.6.3-py2.py3-none-any.whl.metadata (4.4 kB)
Collecting flatbuffers>=23.5.26 (from tensorflow==2.14.0)
  Downloading flatbuffers-25.2.10-py2.py3-none-any.whl.metadata (875 bytes)
Collecting gast!=0.5.0,!=0.5.1,!=0.5.2,>=0.2.1 (from tensorflow==2.14.0)
  Downloading gast-0.6.0-py3-none-any.whl.metadata (1.3 kB)
Collecting google-pasta>=0.1.1 (from tensorflow==2.14.0)
  Downloading google_pasta-0.2.0-py3-none-any.whl.metadata (814 bytes)
Collecting h5py>=2.9.0 (from tensorflow==2.14.0)
  Downloading h5py-3.13.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (2.5 kB)
Collecting libclang>=13.0.0 (from tensorflow==2.14.0)
  Downloading libcla

Collecting tensorflow-model-optimization
  Downloading tensorflow_model_optimization-0.8.0-py2.py3-none-any.whl.metadata (904 bytes)
Collecting absl-py~=1.2 (from tensorflow-model-optimization)
  Downloading absl_py-1.4.0-py3-none-any.whl.metadata (2.3 kB)
Collecting numpy~=1.23 (from tensorflow-model-optimization)
  Downloading numpy-1.26.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (61 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m61.0/61.0 kB[0m [31m3.8 MB/s[0m eta [36m0:00:00[0m
Downloading tensorflow_model_optimization-0.8.0-py2.py3-none-any.whl (242 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m242.5/242.5 kB[0m [31m8.2 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading absl_py-1.4.0-py3-none-any.whl (126 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m126.5/126.5 kB[0m [31m14.2 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading numpy-1.26.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_