<a href="https://colab.research.google.com/github/bhartiansh/cnn_pruning_cifar10/blob/main/Effective_Automatic_Channel_Pruning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!git clone https://github.com/bhartiansh/cnn_pruning_cifar10.git
%cd cnn_pruning_cifar10

fatal: destination path 'cnn_pruning_cifar10' already exists and is not an empty directory.
/content/cnn_pruning_cifar10


In [2]:
!pip install -q tensorflow-model-optimization

In [3]:
import tensorflow as tf
import numpy as np
from sklearn.cluster import KMeans
from models.resnet20 import build_resnet20  # Your custom ResNet-20 model

In [4]:
# -------------------------------
# Step 1: Cluster filters using cosine similarity
# -------------------------------
# NOTE(review): `cluster_filters` is defined again in a later cell; keep only one.
def cluster_filters(layer_weights, num_clusters):
    """Group a conv layer's output filters into ``num_clusters`` clusters.

    Each filter is flattened and L2-normalised, so k-means on the normalised
    vectors groups filters by direction (an approximation of clustering by
    cosine similarity) rather than by magnitude.

    Args:
        layer_weights: kernel array whose LAST axis indexes the output filters,
            e.g. ``(kh, kw, in_channels, out_channels)`` for a Keras ``Conv2D``.
        num_clusters: number of k-means clusters; KMeans requires this to be at
            least 1 and at most the number of filters.

    Returns:
        Integer array with one cluster id per filter, in filter order.
    """
    num_filters = layer_weights.shape[-1]
    # Flatten everything except the filter axis, then transpose so each ROW is
    # one filter. reshape(num_filters, -1) would mix values from different
    # filters, because the filter axis is the last (fastest-varying) one.
    filters = layer_weights.reshape(-1, num_filters).T
    filters_norm = np.linalg.norm(filters, axis=1, keepdims=True)
    filters_normalized = filters / (filters_norm + 1e-8)  # epsilon guards all-zero filters
    kmeans = KMeans(n_clusters=num_clusters, random_state=0, n_init='auto')
    # labels_ only exists after fitting.
    kmeans.fit(filters_normalized)
    return kmeans.labels_

In [10]:
from sklearn.cluster import KMeans
import numpy as np

def cluster_filters(layer_weights, num_clusters):
    """Assign each output filter of a conv kernel to one of ``num_clusters`` groups.

    Filters are flattened to vectors and scaled to (near) unit length before
    k-means, so the grouping reflects filter direction, i.e. cosine similarity,
    rather than magnitude.

    Args:
        layer_weights: kernel array whose last axis indexes the filters.
        num_clusters: how many clusters k-means should form.

    Returns:
        Array of cluster ids, one per filter, in filter order.
    """
    out_channels = layer_weights.shape[-1]
    # Rows of `filter_matrix` are individual filters: (num_filters, filter_size).
    filter_matrix = layer_weights.reshape(-1, out_channels).T
    lengths = np.linalg.norm(filter_matrix, axis=1, keepdims=True)
    unit_filters = filter_matrix / (lengths + 1e-8)
    clusterer = KMeans(n_clusters=num_clusters, random_state=0, n_init='auto')
    return clusterer.fit(unit_filters).labels_


In [11]:
def _prune_filters(layer, cluster_labels):
    """Zero every filter of a Conv2D layer except one representative per cluster.

    The representative kept for each cluster is the member with the largest L2
    norm (a heuristic; swap in a distance-to-centroid rule if preferred).

    Args:
        layer: a built Conv2D layer whose weights' last axis indexes filters.
        cluster_labels: one cluster id per filter (as from ``cluster_filters``).

    Returns:
        Mask of shape ``(num_filters,)``: 1.0 for kept filters, 0.0 for pruned.
    """
    weights = layer.get_weights()
    kernel = weights[0]
    labels = np.asarray(cluster_labels)
    norms = np.linalg.norm(kernel.reshape(-1, kernel.shape[-1]), axis=0)
    mask = np.zeros(kernel.shape[-1], dtype=kernel.dtype)
    for cluster_id in np.unique(labels):
        members = np.flatnonzero(labels == cluster_id)
        mask[members[np.argmax(norms[members])]] = 1.0
    # Kernel and (optional) bias both have the filter axis last, so the same
    # mask broadcasts over every weight of the layer.
    layer.set_weights([w * mask for w in weights])
    return mask


class _KeepPrunedFiltersZero(tf.keras.callbacks.Callback):
    """Re-apply per-filter masks after every training batch.

    Without this, the optimizer keeps updating zeroed filters during
    fine-tuning and the pruning is silently undone.
    """

    def __init__(self, pruned_layers):
        super().__init__()
        # (layer, mask tensor over that layer's filters) pairs.
        self._pruned_layers = [(layer, tf.constant(mask)) for layer, mask in pruned_layers]

    def on_train_batch_end(self, batch, logs=None):
        for layer, mask in self._pruned_layers:
            for variable in layer.trainable_weights:
                variable.assign(variable * mask)


def train_eacp_model(x_train, y_train, x_val, y_val,
                     sparsity=0.5, epochs=50, batch_size=128,
                     pretrain_epochs=0):
    """Build a ResNet-20, prune its Conv2D filters by clustering, then fine-tune.

    For each ``Conv2D`` layer with ``F`` filters, filters are grouped into
    ``max(1, int(F * (1 - sparsity)))`` clusters via ``cluster_filters``.
    One filter per cluster is kept and the rest are zeroed. A callback keeps
    the zeroed filters at zero throughout fine-tuning.

    NOTE(review): a BatchNormalization after a pruned conv still emits its beta
    offset for zeroed channels; mask BN too if true channel removal is intended.

    Args:
        x_train, y_train: training images and integer labels.
        x_val, y_val: passed to ``fit`` as ``validation_data``.
        sparsity: fraction of filters to prune per layer, in ``[0, 1]``.
        epochs: fine-tuning epochs after pruning.
        batch_size: batch size for pre-training and fine-tuning.
        pretrain_epochs: epochs to train the dense model BEFORE pruning.
            ``0`` (default) keeps the original behaviour of clustering freshly
            initialised weights. Clustering is only meaningful on trained
            filters, so a positive value is recommended.

    Returns:
        The pruned, fine-tuned Keras model.

    Raises:
        ValueError: if ``sparsity`` is outside ``[0, 1]``.
    """
    if not 0.0 <= sparsity <= 1.0:
        raise ValueError(f"sparsity must be in [0, 1], got {sparsity!r}")

    model = build_resnet20()

    if pretrain_epochs > 0:
        model.compile(optimizer='adam',
                      loss='sparse_categorical_crossentropy',
                      metrics=['accuracy'])
        print(f"🏋️  Pre-training dense model for {pretrain_epochs} epochs...")
        model.fit(x_train, y_train, epochs=pretrain_epochs, batch_size=batch_size,
                  validation_data=(x_val, y_val))

    # Apply EACP pruning
    print(f"✂️  Applying EACP pruning with sparsity = {sparsity:.2f}...")
    pruned_layers = []
    for layer in model.layers:
        if not isinstance(layer, tf.keras.layers.Conv2D):
            continue
        kernel = layer.get_weights()[0]
        num_filters = kernel.shape[-1]
        num_clusters = max(1, int(num_filters * (1 - sparsity)))
        if num_clusters < num_filters:
            cluster_labels = cluster_filters(kernel, num_clusters)
            pruned_layers.append((layer, _prune_filters(layer, cluster_labels)))

    # Fine-tune pruned model; compiling here gives a fresh optimizer state.
    print("🔁 Fine-tuning pruned model...")
    model.compile(optimizer='adam',
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])
    model.fit(x_train, y_train, epochs=epochs, batch_size=batch_size,
              validation_data=(x_val, y_val),
              callbacks=[_KeepPrunedFiltersZero(pruned_layers)])

    return model


In [None]:
# Load CIFAR-10 and scale pixels to [0, 1] as float32. Plain `x / 255.0`
# produces float64 arrays (about 2x the memory), which Keras casts back to
# float32 anyway.
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
x_train = x_train.astype("float32") / 255.0
x_test = x_test.astype("float32") / 255.0

# Train with EACP (ResNet-20) at 30% sparsity.
# NOTE(review): the test split doubles as validation data, so test accuracy is
# not an unbiased estimate if it guides any choices; hold out part of x_train
# for validation instead.
model = train_eacp_model(x_train, y_train, x_test, y_test, sparsity=0.3, epochs=30, batch_size=64)

✂️  Applying EACP pruning with sparsity = 0.30...
🔁 Fine-tuning pruned model...
Epoch 1/30
Epoch 2/30
Epoch 3/30
Epoch 4/30

In [None]:
# Load CIFAR-10 and scale pixels to [0, 1] as float32. Plain `x / 255.0`
# produces float64 arrays (about 2x the memory), which Keras casts back to
# float32 anyway.
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
x_train = x_train.astype("float32") / 255.0
x_test = x_test.astype("float32") / 255.0

# Train with EACP (ResNet-20) at 50% sparsity.
# NOTE(review): the test split doubles as validation data, so test accuracy is
# not an unbiased estimate if it guides any choices; hold out part of x_train
# for validation instead.
model = train_eacp_model(x_train, y_train, x_test, y_test, sparsity=0.5, epochs=30, batch_size=64)

In [None]:
# Load CIFAR-10 and scale pixels to [0, 1] as float32. Plain `x / 255.0`
# produces float64 arrays (about 2x the memory), which Keras casts back to
# float32 anyway.
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
x_train = x_train.astype("float32") / 255.0
x_test = x_test.astype("float32") / 255.0

# Train with EACP (ResNet-20) at 70% sparsity.
# NOTE(review): the test split doubles as validation data, so test accuracy is
# not an unbiased estimate if it guides any choices; hold out part of x_train
# for validation instead.
model = train_eacp_model(x_train, y_train, x_test, y_test, sparsity=0.7, epochs=30, batch_size=64)

In [None]:
# 1. Force compatible versions (reset everything to Colab defaults)
!pip install -U --force-reinstall numpy==1.23.5
!pip install -U --force-reinstall tensorflow==2.14.0
!pip install -U tensorflow-model-optimization

# 2. Restart runtime automatically after install
import os
os.kill(os.getpid(), 9)

Collecting numpy==1.23.5
  Downloading numpy-1.23.5-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (2.3 kB)
Downloading numpy-1.23.5-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (17.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m17.1/17.1 MB[0m [31m26.3 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: numpy
  Attempting uninstall: numpy
    Found existing installation: numpy 2.0.2
    Uninstalling numpy-2.0.2:
      Successfully uninstalled numpy-2.0.2
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
thinc 8.3.6 requires numpy<3.0.0,>=2.0.0, but you have numpy 1.23.5 which is incompatible.
chex 0.1.89 requires numpy>=1.24.1, but you have numpy 1.23.5 which is incompatible.
xarray 2025.1.2 requires numpy>=1.24, but you have numpy 1.23.5 which is incompatible.
bigframes 1.42.0 require

Collecting tensorflow==2.14.0
  Downloading tensorflow-2.14.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (4.1 kB)
Collecting absl-py>=1.0.0 (from tensorflow==2.14.0)
  Downloading absl_py-2.2.2-py3-none-any.whl.metadata (2.6 kB)
Collecting astunparse>=1.6.0 (from tensorflow==2.14.0)
  Downloading astunparse-1.6.3-py2.py3-none-any.whl.metadata (4.4 kB)
Collecting flatbuffers>=23.5.26 (from tensorflow==2.14.0)
  Downloading flatbuffers-25.2.10-py2.py3-none-any.whl.metadata (875 bytes)
Collecting gast!=0.5.0,!=0.5.1,!=0.5.2,>=0.2.1 (from tensorflow==2.14.0)
  Downloading gast-0.6.0-py3-none-any.whl.metadata (1.3 kB)
Collecting google-pasta>=0.1.1 (from tensorflow==2.14.0)
  Downloading google_pasta-0.2.0-py3-none-any.whl.metadata (814 bytes)
Collecting h5py>=2.9.0 (from tensorflow==2.14.0)
  Downloading h5py-3.13.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (2.5 kB)
Collecting libclang>=13.0.0 (from tensorflow==2.14.0)
  Downloading libcla

Collecting tensorflow-model-optimization
  Downloading tensorflow_model_optimization-0.8.0-py2.py3-none-any.whl.metadata (904 bytes)
Collecting absl-py~=1.2 (from tensorflow-model-optimization)
  Downloading absl_py-1.4.0-py3-none-any.whl.metadata (2.3 kB)
Collecting numpy~=1.23 (from tensorflow-model-optimization)
  Downloading numpy-1.26.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (61 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m61.0/61.0 kB[0m [31m258.1 kB/s[0m eta [36m0:00:00[0m
Downloading tensorflow_model_optimization-0.8.0-py2.py3-none-any.whl (242 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m242.5/242.5 kB[0m [31m6.3 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading absl_py-1.4.0-py3-none-any.whl (126 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m126.5/126.5 kB[0m [31m8.9 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading numpy-1.26.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_