<a href="https://colab.research.google.com/github/bhartiansh/cnn_pruning_cifar10/blob/main/pruning4(L1_Norm_Filter).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!git clone https://github.com/bhartiansh/cnn_pruning_cifar10.git
%cd cnn_pruning_cifar10

Cloning into 'cnn_pruning_cifar10'...
remote: Enumerating objects: 99, done.[K
remote: Counting objects: 100% (99/99), done.[K
remote: Compressing objects: 100% (95/95), done.[K
remote: Total 99 (delta 38), reused 0 (delta 0), pack-reused 0 (from 0)[K
Receiving objects: 100% (99/99), 59.73 KiB | 8.53 MiB/s, done.
Resolving deltas: 100% (38/38), done.
/content/cnn_pruning_cifar10


In [2]:
!pip install -q tensorflow-model-optimization

[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/242.5 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m242.5/242.5 kB[0m [31m15.1 MB/s[0m eta [36m0:00:00[0m
[?25h

In [3]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import tensorflow_model_optimization as tfmot
from models.resnet56_baseline import build_resnet56

In [4]:
def get_prunable_layers(model):
    return [layer for layer in model.layers if isinstance(layer, tf.keras.layers.Conv2D)]
def get_l1_norms(layer):
    weights = layer.get_weights()[0]  # shape: (k, k, in_channels, out_channels)
    l1_norms = np.sum(np.abs(weights), axis=(0, 1, 2))  # per filter
    return l1_norms


In [5]:
def apply_l1_filter_pruning(model, sparsity=0.3):
    prunable_layers = get_prunable_layers(model)
    for layer in prunable_layers:
        weights, bias = layer.get_weights() if len(layer.get_weights()) == 2 else (layer.get_weights()[0], None)
        l1_norms = np.sum(np.abs(weights), axis=(0, 1, 2))
        num_filters = weights.shape[-1]
        k = int(sparsity * num_filters)
        prune_indices = np.argsort(l1_norms)[:k]

        # Zero-out selected filters
        for i in prune_indices:
            weights[..., i] = 0
            if bias is not None:
                bias[i] = 0

        if bias is not None:
            layer.set_weights([weights, bias])
        else:
            layer.set_weights([weights])


In [6]:
def train_l1_pruned_model(build_model_fn, x_train, y_train, x_val, y_val,
                          sparsity=0.3, epochs=50, batch_size=128, save_path='l1prune_model.h5'):
    model = build_model_fn()
    apply_l1_filter_pruning(model, sparsity)

    early_stop = tf.keras.callbacks.EarlyStopping(patience=5, restore_best_weights=True)
    checkpoint = tf.keras.callbacks.ModelCheckpoint(save_path, save_best_only=True)

    model.compile(optimizer='adam',
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])

    history = model.fit(x_train, y_train,
                        validation_data=(x_val, y_val),
                        epochs=epochs,
                        batch_size=batch_size,
                        callbacks=[early_stop, checkpoint],
                        verbose=2)

    return model, history


In [7]:
# Load and normalize CIFAR-10
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

# Train structured pruned model
model, history = train_l1_pruned_model(
    build_model_fn=build_resnet56,
    x_train=x_train,
    y_train=y_train,
    x_val=x_test,
    y_val=y_test,
    sparsity=0.3,  # Prune 30% filters
    save_path='l1prune_model.h5'
)


Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
[1m170498071/170498071[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m14s[0m 0us/step
Epoch 1/50




391/391 - 101s - 258ms/step - accuracy: 0.4232 - loss: 1.6060 - val_accuracy: 0.4636 - val_loss: 1.5088
Epoch 2/50




391/391 - 47s - 120ms/step - accuracy: 0.5981 - loss: 1.1168 - val_accuracy: 0.5264 - val_loss: 1.3259
Epoch 3/50
391/391 - 39s - 101ms/step - accuracy: 0.6866 - loss: 0.8822 - val_accuracy: 0.5508 - val_loss: 1.4677
Epoch 4/50




391/391 - 29s - 75ms/step - accuracy: 0.7378 - loss: 0.7407 - val_accuracy: 0.6347 - val_loss: 1.1221
Epoch 5/50




391/391 - 41s - 105ms/step - accuracy: 0.7796 - loss: 0.6299 - val_accuracy: 0.7008 - val_loss: 0.9262
Epoch 6/50




391/391 - 41s - 106ms/step - accuracy: 0.8072 - loss: 0.5519 - val_accuracy: 0.7129 - val_loss: 0.8532
Epoch 7/50
391/391 - 39s - 99ms/step - accuracy: 0.8304 - loss: 0.4811 - val_accuracy: 0.6584 - val_loss: 1.1734
Epoch 8/50
391/391 - 29s - 74ms/step - accuracy: 0.8497 - loss: 0.4230 - val_accuracy: 0.7259 - val_loss: 0.9460
Epoch 9/50
391/391 - 39s - 101ms/step - accuracy: 0.8705 - loss: 0.3683 - val_accuracy: 0.6864 - val_loss: 1.0908
Epoch 10/50




391/391 - 44s - 112ms/step - accuracy: 0.8881 - loss: 0.3176 - val_accuracy: 0.7683 - val_loss: 0.7603
Epoch 11/50
391/391 - 29s - 74ms/step - accuracy: 0.9023 - loss: 0.2738 - val_accuracy: 0.7263 - val_loss: 0.9759
Epoch 12/50
391/391 - 28s - 70ms/step - accuracy: 0.9147 - loss: 0.2348 - val_accuracy: 0.6591 - val_loss: 1.4414
Epoch 13/50
391/391 - 28s - 71ms/step - accuracy: 0.9287 - loss: 0.1979 - val_accuracy: 0.6784 - val_loss: 1.6478
Epoch 14/50
391/391 - 41s - 105ms/step - accuracy: 0.9393 - loss: 0.1709 - val_accuracy: 0.7124 - val_loss: 1.1624
Epoch 15/50
391/391 - 28s - 71ms/step - accuracy: 0.9478 - loss: 0.1455 - val_accuracy: 0.7171 - val_loss: 1.4336


In [8]:
# Load and normalize CIFAR-10
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

# Train structured pruned model
model, history = train_l1_pruned_model(
    build_model_fn=build_resnet56,
    x_train=x_train,
    y_train=y_train,
    x_val=x_test,
    y_val=y_test,
    sparsity=0.5,  # Prune 50% filters
    save_path='l1prune_model.h5'
)


Epoch 1/50




391/391 - 96s - 245ms/step - accuracy: 0.4350 - loss: 1.5824 - val_accuracy: 0.3271 - val_loss: 2.0344
Epoch 2/50




391/391 - 29s - 75ms/step - accuracy: 0.6060 - loss: 1.1067 - val_accuracy: 0.5465 - val_loss: 1.3425
Epoch 3/50




391/391 - 39s - 101ms/step - accuracy: 0.6749 - loss: 0.9119 - val_accuracy: 0.6019 - val_loss: 1.2364
Epoch 4/50




391/391 - 42s - 107ms/step - accuracy: 0.7197 - loss: 0.7898 - val_accuracy: 0.6380 - val_loss: 1.0626
Epoch 5/50




391/391 - 28s - 71ms/step - accuracy: 0.7538 - loss: 0.6971 - val_accuracy: 0.6620 - val_loss: 1.0181
Epoch 6/50




391/391 - 42s - 108ms/step - accuracy: 0.7825 - loss: 0.6173 - val_accuracy: 0.7030 - val_loss: 0.8705
Epoch 7/50
391/391 - 39s - 100ms/step - accuracy: 0.8084 - loss: 0.5440 - val_accuracy: 0.6800 - val_loss: 1.0013
Epoch 8/50
391/391 - 41s - 104ms/step - accuracy: 0.8293 - loss: 0.4879 - val_accuracy: 0.6561 - val_loss: 1.0871
Epoch 9/50
391/391 - 28s - 71ms/step - accuracy: 0.8483 - loss: 0.4309 - val_accuracy: 0.6510 - val_loss: 1.2013
Epoch 10/50
391/391 - 28s - 72ms/step - accuracy: 0.8658 - loss: 0.3787 - val_accuracy: 0.7257 - val_loss: 0.8885
Epoch 11/50
391/391 - 40s - 103ms/step - accuracy: 0.8809 - loss: 0.3387 - val_accuracy: 0.7191 - val_loss: 0.9385


In [9]:
# Load and normalize CIFAR-10
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

# Train structured pruned model
model, history = train_l1_pruned_model(
    build_model_fn=build_resnet56,
    x_train=x_train,
    y_train=y_train,
    x_val=x_test,
    y_val=y_test,
    sparsity=0.7,  # Prune 70% filters
    save_path='l1prune_model.h5'
)


Epoch 1/50




391/391 - 97s - 248ms/step - accuracy: 0.4134 - loss: 1.6136 - val_accuracy: 0.2934 - val_loss: 1.9230
Epoch 2/50




391/391 - 29s - 75ms/step - accuracy: 0.5653 - loss: 1.2133 - val_accuracy: 0.5259 - val_loss: 1.3937
Epoch 3/50
391/391 - 40s - 103ms/step - accuracy: 0.6441 - loss: 1.0035 - val_accuracy: 0.5321 - val_loss: 1.4187
Epoch 4/50




391/391 - 42s - 106ms/step - accuracy: 0.6943 - loss: 0.8662 - val_accuracy: 0.6399 - val_loss: 1.0543
Epoch 5/50




391/391 - 41s - 104ms/step - accuracy: 0.7340 - loss: 0.7594 - val_accuracy: 0.6589 - val_loss: 0.9829
Epoch 6/50




391/391 - 40s - 101ms/step - accuracy: 0.7599 - loss: 0.6857 - val_accuracy: 0.7090 - val_loss: 0.8821
Epoch 7/50
391/391 - 40s - 103ms/step - accuracy: 0.7823 - loss: 0.6214 - val_accuracy: 0.6629 - val_loss: 1.0522
Epoch 8/50
391/391 - 41s - 104ms/step - accuracy: 0.8023 - loss: 0.5675 - val_accuracy: 0.6250 - val_loss: 1.2435
Epoch 9/50
391/391 - 28s - 73ms/step - accuracy: 0.8195 - loss: 0.5176 - val_accuracy: 0.6942 - val_loss: 0.9262
Epoch 10/50
391/391 - 40s - 101ms/step - accuracy: 0.8296 - loss: 0.4838 - val_accuracy: 0.6525 - val_loss: 1.2405
Epoch 11/50
391/391 - 29s - 74ms/step - accuracy: 0.8478 - loss: 0.4349 - val_accuracy: 0.6636 - val_loss: 1.1216
