<a href="https://colab.research.google.com/github/bhartiansh/cnn_pruning_cifar10/blob/main/pruning4(One_cycle_Structured_Pruning_with_Stability_Driven_Structure_Search).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Fetch the project repository and switch into it; presumably this is what
# makes the `models` package imported further down resolvable.
# NOTE(review): re-running this cell in the same session would clone again
# inside the checkout, since the working directory has already changed.
!git clone https://github.com/bhartiansh/cnn_pruning_cifar10.git
%cd cnn_pruning_cifar10

Cloning into 'cnn_pruning_cifar10'...
remote: Enumerating objects: 114, done.[K
remote: Counting objects: 100% (114/114), done.[K
remote: Compressing objects: 100% (110/110), done.[K
remote: Total 114 (delta 47), reused 0 (delta 0), pack-reused 0 (from 0)[K
Receiving objects: 100% (114/114), 85.20 KiB | 660.00 KiB/s, done.
Resolving deltas: 100% (47/47), done.
/content/cnn_pruning_cifar10


In [9]:
# List the contents of the current working directory.
!ls

 data			     'pruning4(L1_Norm_Filter).ipynb'
 lth_pruning_20_40_60.ipynb  'pruning5(Random_Pruning_Unstructured).ipynb'
 models			      README.md
'pruning1(lth).ipynb'	      ResNet56_baseline_model.ipynb
'pruning2(SNIP).ipynb'	      traning
'pruning3(MAG_50).ipynb'


In [2]:
!pip install -q tensorflow-model-optimization

[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/242.5 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m[90m━[0m [32m235.5/242.5 kB[0m [31m7.2 MB/s[0m eta [36m0:00:01[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m242.5/242.5 kB[0m [31m4.9 MB/s[0m eta [36m0:00:00[0m
[?25h

In [3]:
import tensorflow as tf
import numpy as np
from models.resnet56_baseline import build_resnet56

def get_conv_layers(model):
    """Collect the Conv2D layers listed in ``model.layers``.

    The result keeps the order of ``model.layers``. Only the entries of that
    list are checked; layers wrapped inside other layers are not inspected.
    """
    conv_layers = []
    for candidate in model.layers:
        if isinstance(candidate, tf.keras.layers.Conv2D):
            conv_layers.append(candidate)
    return conv_layers

def compute_group_saliency(model):
    """Score each filter of every Conv2D layer by the L2 norm of its kernel slice.

    Returns:
        dict mapping layer name to a 1-D array with one norm per index of the
        kernel's last axis.
    """
    def _filter_norms(conv):
        # First weight tensor is the kernel; per the original note its shape is
        # (k, k, in_channels, out_channels).
        kernel = conv.get_weights()[0]
        # Collapse every axis except the last so each column holds one filter.
        columns = kernel.reshape(-1, kernel.shape[-1])
        return np.linalg.norm(columns, axis=0)

    return {conv.name: _filter_norms(conv) for conv in get_conv_layers(model)}

def prune_filters(model, saliency, pruning_ratio):
    """Zero the lowest-saliency filters of every Conv2D layer, in place.

    Args:
        model: model whose Conv2D layers (as returned by ``get_conv_layers``)
            have their weights overwritten.
        saliency: dict mapping layer name to a 1-D array holding one score per
            filter, e.g. the output of ``compute_group_saliency``.
        pruning_ratio: fraction in [0, 1] of each layer's filters to zero. The
            per-layer count is ``int(pruning_ratio * num_filters)``.

    Raises:
        ValueError: if ``pruning_ratio`` is outside [0, 1]. A negative value
            would otherwise become a negative slice bound and silently select
            the wrong filters.
    """
    if not 0.0 <= pruning_ratio <= 1.0:
        raise ValueError(f"pruning_ratio must be in [0, 1], got {pruning_ratio!r}")

    for layer in get_conv_layers(model):
        # Conv2D layers built with use_bias=False expose only the kernel, so
        # the weight list cannot be unpacked into a fixed (kernel, bias) pair.
        params = layer.get_weights()
        kernel = params[0]
        num_filters = kernel.shape[-1]
        num_prune = int(pruning_ratio * num_filters)
        prune_indices = np.argsort(saliency[layer.name])[:num_prune]

        kernel[..., prune_indices] = 0
        if len(params) > 1:
            # Second entry is the bias; clear it too so a pruned filter
            # contributes nothing to the layer output.
            params[1][prune_indices] = 0
        # Hand back the list with the same length it arrived with.
        layer.set_weights(params)

def train_one_cycle_pruned_model(x_train, y_train, x_val, y_val,
                                 pruning_ratio=0.3, epochs=50, batch_size=128,
                                 warmup_epochs=5):
    """Train a ResNet-56, prune its conv filters once, then fine-tune.

    Steps:
        1. Build the model with ``build_resnet56`` and train it for
           ``warmup_epochs`` epochs so the filter norms carry some signal.
        2. Score the filters with ``compute_group_saliency`` and zero the
           lowest-scoring ones with ``prune_filters``.
        3. Fine-tune for ``epochs`` epochs while keeping the pruned filters
           at zero.

    Args:
        x_train, y_train: training inputs and integer class labels.
        x_val, y_val: data passed to ``fit`` as ``validation_data``.
        pruning_ratio: fraction of filters zeroed in each Conv2D layer.
        epochs: number of fine-tuning epochs after pruning.
        batch_size: batch size used for both training phases.
        warmup_epochs: number of epochs trained before saliency is computed.

    Returns:
        The fine-tuned model.
    """
    model = build_resnet56()
    optimizer = tf.keras.optimizers.Adam()
    loss_fn = tf.keras.losses.SparseCategoricalCrossentropy()
    train_acc_metric = tf.keras.metrics.SparseCategoricalAccuracy()

    # Initial training to compute saliency
    model.compile(optimizer=optimizer, loss=loss_fn, metrics=[train_acc_metric])
    model.fit(x_train, y_train, epochs=warmup_epochs, batch_size=batch_size,
              validation_data=(x_val, y_val))

    # Compute saliency and prune
    saliency = compute_group_saliency(model)
    prune_filters(model, saliency, pruning_ratio)

    class _KeepFiltersPruned(tf.keras.callbacks.Callback):
        """Re-apply the pruning after every optimizer step.

        The optimizer has no notion of the zeroed filters, so its updates
        would make them non-zero again. The saliency scores are frozen, so
        the same filters are selected on every call.
        """

        def on_train_batch_end(self, batch, logs=None):
            prune_filters(model, saliency, pruning_ratio)

    # Fine-tune pruned model
    model.compile(optimizer=optimizer, loss=loss_fn, metrics=[train_acc_metric])
    model.fit(x_train, y_train, epochs=epochs, batch_size=batch_size,
              validation_data=(x_val, y_val), callbacks=[_KeepFiltersPruned()])

    return model

In [4]:
# Load CIFAR-10 and scale the pixel values to [0, 1].
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
# Cast to float32 before dividing: dividing the integer image arrays by 255.0
# directly would yield float64 arrays, twice the size of these.
x_train = x_train.astype("float32") / 255.0
x_test = x_test.astype("float32") / 255.0

# Train pruned model
# NOTE(review): the test split is passed as the validation data, so the logged
# val_* metrics are not an independent hold-out estimate.
model = train_one_cycle_pruned_model(x_train, y_train, x_test, y_test, pruning_ratio=0.3)

Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
[1m170498071/170498071[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 0us/step
Epoch 1/5
[1m391/391[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m99s[0m 120ms/step - loss: 2.0687 - sparse_categorical_accuracy: 0.3562 - val_loss: 1.4917 - val_sparse_categorical_accuracy: 0.4618
Epoch 2/5
[1m391/391[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m47s[0m 73ms/step - loss: 1.1231 - sparse_categorical_accuracy: 0.5958 - val_loss: 1.4507 - val_sparse_categorical_accuracy: 0.5460
Epoch 3/5
[1m391/391[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m41s[0m 72ms/step - loss: 0.8655 - sparse_categorical_accuracy: 0.6929 - val_loss: 0.9674 - val_sparse_categorical_accuracy: 0.6581
Epoch 4/5
[1m391/391[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m41s[0m 72ms/step - loss: 0.6988 - sparse_categorical_accuracy: 0.7505 - val_loss: 1.1170 - val_sparse_categorical_accuracy: 0.6447
Epoch 5/5
[1m391/391[0m [32