<a href="https://colab.research.google.com/github/bhartiansh/cnn_pruning_cifar10/blob/main/Effective_Automatic_Channel_Pruning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import sys
import os
sys.path.append('../models')

In [2]:
!pip install -q tensorflow-model-optimization

In [3]:
import tensorflow as tf
import numpy as np
from sklearn.cluster import KMeans
from models.resnet20 import build_resnet20  # Your custom ResNet-20 model

In [4]:
# -------------------------------
# Step 1: Cluster filters using cosine similarity
# -------------------------------
def cluster_filters(layer_weights, num_clusters):
    filters = layer_weights.reshape(layer_weights.shape[-1], -1)
    filters_norm = np.linalg.norm(filters, axis=1, keepdims=True)
    filters_normalized = filters / (filters_norm + 1e-8)

    kmeans = KMeans(n_clusters=num_clusters, random_state=0, n_init='auto')
    kmeans.fit(filters_normalized)  # ✅ Fit first!
    return kmeans.labels_           # ✅ Then get labels


In [5]:
# -------------------------------
# Step 2: Apply pruning by retaining one filter per cluster
# -------------------------------
def prune_filters(layer, cluster_labels):
    weights, biases = layer.get_weights()
    num_filters = weights.shape[-1]
    keep_mask = np.zeros(num_filters, dtype=bool)

    for cluster_id in np.unique(cluster_labels):
        indices = np.where(cluster_labels == cluster_id)[0]
        keep_mask[indices[0]] = True  # Keep one representative

    weights[..., ~keep_mask] = 0
    if biases is not None:
        biases[~keep_mask] = 0

    layer.set_weights([weights, biases])


In [9]:
# -------------------------------
# Step 3: EACP pruning and training pipeline (no pre-training)
# -------------------------------
def train_eacp_model(x_train, y_train, x_val, y_val,
                     sparsity=0.5, epochs=50, batch_size=128):
    model = build_resnet20()

    print(f"Applying EACP pruning with sparsity = {sparsity:.2f}...")
    for layer in model.layers:
        if isinstance(layer, tf.keras.layers.Conv2D):
            weights = layer.get_weights()[0]
            num_filters = weights.shape[-1]
            num_clusters = max(1, int(num_filters * (1 - sparsity)))

            if num_clusters < num_filters:
                cluster_labels = cluster_filters(weights, num_clusters)
                prune_filters(layer, cluster_labels)

    print("Fine-tuning pruned model...")
    model.compile(optimizer='adam',
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])
    model.fit(x_train, y_train,
              epochs=epochs,
              batch_size=batch_size,
              validation_data=(x_val, y_val),
              callbacks=[tf.keras.callbacks.EarlyStopping(patience=30, restore_best_weights=True)],
              verbose=2)

    return model

In [7]:
# Load CIFAR-10 (unchanged)
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

# Train with EACP (ResNet-20)
model = train_eacp_model(x_train, y_train, x_test, y_test, sparsity=0.3, epochs=30, batch_size=64)

Applying EACP pruning with sparsity = 0.30...




Fine-tuning pruned model...
Epoch 1/30
782/782 - 50s - loss: 1.3551 - accuracy: 0.5105 - val_loss: 1.7057 - val_accuracy: 0.4802 - 50s/epoch - 64ms/step
Epoch 2/30
782/782 - 33s - loss: 0.9461 - accuracy: 0.6636 - val_loss: 1.0769 - val_accuracy: 0.6281 - 33s/epoch - 42ms/step
Epoch 3/30
782/782 - 33s - loss: 0.7704 - accuracy: 0.7328 - val_loss: 0.8460 - val_accuracy: 0.7119 - 33s/epoch - 43ms/step
Epoch 4/30
782/782 - 34s - loss: 0.6664 - accuracy: 0.7694 - val_loss: 0.9632 - val_accuracy: 0.6820 - 34s/epoch - 43ms/step
Epoch 5/30
782/782 - 33s - loss: 0.5922 - accuracy: 0.7933 - val_loss: 1.1416 - val_accuracy: 0.6543 - 33s/epoch - 42ms/step
Epoch 6/30
782/782 - 33s - loss: 0.5338 - accuracy: 0.8163 - val_loss: 1.0394 - val_accuracy: 0.6757 - 33s/epoch - 42ms/step
Epoch 7/30
782/782 - 33s - loss: 0.4838 - accuracy: 0.8322 - val_loss: 0.7286 - val_accuracy: 0.7562 - 33s/epoch - 42ms/step
Epoch 8/30
782/782 - 36s - loss: 0.4414 - accuracy: 0.8459 - val_loss: 0.8016 - val_accuracy: 0.7

In [10]:
0# Load CIFAR-10 (unchanged)
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

# Train with EACP (ResNet-20)
model = train_eacp_model(x_train, y_train, x_test, y_test, sparsity=0.5, epochs=30, batch_size=64)

Applying EACP pruning with sparsity = 0.50...




Fine-tuning pruned model...
Epoch 1/30
782/782 - 38s - loss: 1.4288 - accuracy: 0.4759 - val_loss: 1.8455 - val_accuracy: 0.4112 - 38s/epoch - 49ms/step
Epoch 2/30
782/782 - 33s - loss: 1.0196 - accuracy: 0.6382 - val_loss: 1.4083 - val_accuracy: 0.5257 - 33s/epoch - 42ms/step
Epoch 3/30
782/782 - 33s - loss: 0.8501 - accuracy: 0.7002 - val_loss: 1.3553 - val_accuracy: 0.5501 - 33s/epoch - 42ms/step
Epoch 4/30
782/782 - 33s - loss: 0.7397 - accuracy: 0.7412 - val_loss: 1.1875 - val_accuracy: 0.6136 - 33s/epoch - 43ms/step
Epoch 5/30
782/782 - 33s - loss: 0.6643 - accuracy: 0.7657 - val_loss: 0.8108 - val_accuracy: 0.7243 - 33s/epoch - 42ms/step
Epoch 6/30
782/782 - 33s - loss: 0.6082 - accuracy: 0.7890 - val_loss: 0.8519 - val_accuracy: 0.7150 - 33s/epoch - 42ms/step
Epoch 7/30
782/782 - 33s - loss: 0.5575 - accuracy: 0.8051 - val_loss: 1.1154 - val_accuracy: 0.6568 - 33s/epoch - 43ms/step
Epoch 8/30
782/782 - 33s - loss: 0.5229 - accuracy: 0.8183 - val_loss: 0.8079 - val_accuracy: 0.7

In [11]:
# Load CIFAR-10 (unchanged)
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

# Train with EACP (ResNet-20)
model = train_eacp_model(x_train, y_train, x_test, y_test, sparsity=0.7, epochs=30, batch_size=64)

Applying EACP pruning with sparsity = 0.70...




Fine-tuning pruned model...
Epoch 1/30
782/782 - 33s - loss: 1.4924 - accuracy: 0.4537 - val_loss: 2.5461 - val_accuracy: 0.3102 - 33s/epoch - 43ms/step
Epoch 2/30
782/782 - 32s - loss: 1.1127 - accuracy: 0.6045 - val_loss: 1.2702 - val_accuracy: 0.5577 - 32s/epoch - 41ms/step
Epoch 3/30
782/782 - 30s - loss: 0.9664 - accuracy: 0.6592 - val_loss: 2.6670 - val_accuracy: 0.3517 - 30s/epoch - 39ms/step
Epoch 4/30
782/782 - 30s - loss: 0.8800 - accuracy: 0.6896 - val_loss: 1.0959 - val_accuracy: 0.6229 - 30s/epoch - 38ms/step
Epoch 5/30
782/782 - 30s - loss: 0.8142 - accuracy: 0.7139 - val_loss: 0.9568 - val_accuracy: 0.6703 - 30s/epoch - 38ms/step
Epoch 6/30
782/782 - 30s - loss: 0.7609 - accuracy: 0.7329 - val_loss: 1.0128 - val_accuracy: 0.6449 - 30s/epoch - 38ms/step
Epoch 7/30
782/782 - 30s - loss: 0.7168 - accuracy: 0.7490 - val_loss: 0.9138 - val_accuracy: 0.6811 - 30s/epoch - 39ms/step
Epoch 8/30
782/782 - 30s - loss: 0.6769 - accuracy: 0.7631 - val_loss: 1.0939 - val_accuracy: 0.6

In [None]:
# 1. Force compatible versions (reset everything to Colab defaults)
!pip install -U --force-reinstall numpy==1.23.5
!pip install -U --force-reinstall tensorflow==2.14.0
!pip install -U tensorflow-model-optimization

# 2. Restart runtime automatically after install
import os
os.kill(os.getpid(), 9)

Collecting numpy==1.23.5
  Downloading numpy-1.23.5-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (2.3 kB)
Downloading numpy-1.23.5-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (17.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m17.1/17.1 MB[0m [31m41.4 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: numpy
  Attempting uninstall: numpy
    Found existing installation: numpy 1.26.4
    Uninstalling numpy-1.26.4:
      Successfully uninstalled numpy-1.26.4
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
thinc 8.3.6 requires numpy<3.0.0,>=2.0.0, but you have numpy 1.23.5 which is incompatible.
chex 0.1.89 requires numpy>=1.24.1, but you have numpy 1.23.5 which is incompatible.
xarray 2025.1.2 requires numpy>=1.24, but you have numpy 1.23.5 which is incompatible.
bigframes 1.42.0 requ

Collecting tensorflow==2.14.0
  Downloading tensorflow-2.14.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (4.1 kB)
Collecting absl-py>=1.0.0 (from tensorflow==2.14.0)
  Downloading absl_py-2.2.2-py3-none-any.whl.metadata (2.6 kB)
Collecting astunparse>=1.6.0 (from tensorflow==2.14.0)
  Downloading astunparse-1.6.3-py2.py3-none-any.whl.metadata (4.4 kB)
Collecting flatbuffers>=23.5.26 (from tensorflow==2.14.0)
  Downloading flatbuffers-25.2.10-py2.py3-none-any.whl.metadata (875 bytes)
Collecting gast!=0.5.0,!=0.5.1,!=0.5.2,>=0.2.1 (from tensorflow==2.14.0)
  Downloading gast-0.6.0-py3-none-any.whl.metadata (1.3 kB)
Collecting google-pasta>=0.1.1 (from tensorflow==2.14.0)
  Downloading google_pasta-0.2.0-py3-none-any.whl.metadata (814 bytes)
Collecting h5py>=2.9.0 (from tensorflow==2.14.0)
  Downloading h5py-3.13.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (2.5 kB)
Collecting libclang>=13.0.0 (from tensorflow==2.14.0)
  Downloading libcla

Collecting absl-py~=1.2 (from tensorflow-model-optimization)
  Downloading absl_py-1.4.0-py3-none-any.whl.metadata (2.3 kB)
Collecting numpy~=1.23 (from tensorflow-model-optimization)
  Using cached numpy-1.26.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (61 kB)
Downloading absl_py-1.4.0-py3-none-any.whl (126 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m126.5/126.5 kB[0m [31m2.8 MB/s[0m eta [36m0:00:00[0m
[?25hUsing cached numpy-1.26.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (18.3 MB)
Installing collected packages: numpy, absl-py
  Attempting uninstall: numpy
    Found existing installation: numpy 2.2.4
    Uninstalling numpy-2.2.4:
      Successfully uninstalled numpy-2.2.4
  Attempting uninstall: absl-py
    Found existing installation: absl-py 2.2.2
    Uninstalling absl-py-2.2.2:
      Successfully uninstalled absl-py-2.2.2
[31mERROR: pip's dependency resolver does not currently take into account all the packag