<a href="https://colab.research.google.com/github/bhartiansh/cnn_pruning_cifar10/blob/main/pruning4(L1_Norm_Filter).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Fetch the project repository and switch the working directory to its root.
# NOTE(review): the later `from models.resnet56_baseline import ...` presumably
# resolves against this checkout - confirm the `models` package lives in the repo.
!git clone https://github.com/bhartiansh/cnn_pruning_cifar10.git
%cd cnn_pruning_cifar10

Cloning into 'cnn_pruning_cifar10'...
remote: Enumerating objects: 90, done.[K
remote: Counting objects: 100% (90/90), done.[K
remote: Compressing objects: 100% (86/86), done.[K
remote: Total 90 (delta 33), reused 0 (delta 0), pack-reused 0 (from 0)[K
Receiving objects: 100% (90/90), 46.67 KiB | 1.20 MiB/s, done.
Resolving deltas: 100% (33/33), done.
/content/cnn_pruning_cifar10


In [2]:
# Use the %pip magic rather than !pip: %pip installs into the environment of the
# running kernel, whereas !pip runs whichever `pip` the shell happens to find.
# NOTE(review): consider pinning the package version for reproducible runs.
%pip install -q tensorflow-model-optimization

[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/242.5 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m[90m━━━━━━━━━━━[0m [32m174.1/242.5 kB[0m [31m5.0 MB/s[0m eta [36m0:00:01[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m242.5/242.5 kB[0m [31m3.2 MB/s[0m eta [36m0:00:00[0m
[?25h

In [8]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import tensorflow_model_optimization as tfmot
from models.resnet56_baseline import build_resnet56

In [4]:
def get_prunable_layers(model):
    """Collect the convolution layers of ``model`` that are candidates for pruning.

    Only the entries of ``model.layers`` are inspected; a layer qualifies when it
    is an instance of ``tf.keras.layers.Conv2D`` (subclasses included).

    Args:
        model: Object exposing a ``layers`` iterable (e.g. a Keras model).

    Returns:
        list: The qualifying layers, in the order they appear in ``model.layers``.
    """
    conv_layers = []
    for candidate in model.layers:
        if isinstance(candidate, tf.keras.layers.Conv2D):
            conv_layers.append(candidate)
    return conv_layers
def get_l1_norms(layer):
    """Return the L1 norm of each output filter of ``layer``.

    The first array returned by ``layer.get_weights()`` is taken as the kernel
    and its absolute values are summed over axes 0, 1 and 2, leaving one value
    per index of the last axis.

    Args:
        layer: Object with a ``get_weights()`` method whose first element is an
            array with at least three leading axes.
            (assumes a Conv2D kernel laid out as (k, k, in_channels, out_channels))

    Returns:
        numpy.ndarray: One summed absolute value per filter.
    """
    kernel = layer.get_weights()[0]
    # Collapse the spatial and input-channel axes; the filter axis (last) remains.
    return np.abs(kernel).sum(axis=(0, 1, 2))


In [5]:
def apply_l1_filter_pruning(model, sparsity=0.3):
    """Zero out the lowest-L1-norm filters of every prunable layer, in place.

    For each layer returned by ``get_prunable_layers(model)``, the filters are
    ranked by the sum of absolute kernel values, and the ``int(sparsity * n)``
    weakest filters (and their bias entries, when the layer has a bias) are set
    to zero. Filters are zeroed, not removed: tensor shapes are unchanged.

    Note that nothing here prevents later training from updating the zeroed
    weights again; keeping them at zero is the caller's responsibility.

    Args:
        model: Model whose prunable layers are modified through
            ``layer.set_weights``.
        sparsity: Fraction of filters to zero per layer, in ``[0, 1]``.

    Raises:
        ValueError: If ``sparsity`` is outside ``[0, 1]`` (a negative value would
            otherwise slice from the end and prune almost every filter).
    """
    if not 0.0 <= sparsity <= 1.0:
        raise ValueError(f"sparsity must be in [0, 1], got {sparsity!r}")

    for layer in get_prunable_layers(model):
        # Fetch once: every get_weights() call copies all of the layer's weights.
        layer_weights = layer.get_weights()
        if not layer_weights:
            # Layer has no weights yet (e.g. not built); nothing to prune.
            continue

        # Work on private, writable copies.
        kernel = np.array(layer_weights[0])
        num_filters = kernel.shape[-1]
        num_pruned = int(sparsity * num_filters)
        if num_pruned == 0:
            continue

        l1_norms = np.sum(np.abs(kernel), axis=(0, 1, 2))
        prune_indices = np.argsort(l1_norms)[:num_pruned]

        # Zero all selected filters with a single vectorised assignment.
        kernel[..., prune_indices] = 0
        new_weights = [kernel]
        if len(layer_weights) == 2:
            bias = np.array(layer_weights[1])
            bias[prune_indices] = 0
            new_weights.append(bias)

        layer.set_weights(new_weights)


In [6]:
def train_l1_pruned_model(build_model_fn, x_train, y_train, x_val, y_val,
                          sparsity=0.3, epochs=50, batch_size=128, save_path='l1prune_model.h5',
                          enforce_masks=True):
    """Build a model, L1-prune its Conv2D filters, and train it.

    The model from ``build_model_fn()`` is pruned with
    ``apply_l1_filter_pruning`` and then trained with Adam and sparse
    categorical cross-entropy, using early stopping (patience 5, best weights
    restored) and best-only checkpointing to ``save_path``.

    Pruning only zeroes filters; without further action the optimizer would
    update those weights on the first step and the trained model would no
    longer be pruned. With ``enforce_masks=True`` (default) the filters that
    are all-zero right after pruning are re-zeroed after every training batch,
    so the returned and checkpointed weights keep the requested sparsity.

    NOTE(review): pruning happens on whatever weights ``build_model_fn``
    returns. If that is a freshly initialised model, the L1 ranking reflects
    the random initialisation rather than learned importance - confirm intent.

    Args:
        build_model_fn: Zero-argument callable returning an uncompiled model.
        x_train, y_train: Training inputs and integer labels.
        x_val, y_val: Validation inputs and labels (monitored by the callbacks).
        sparsity: Fraction of filters zeroed per Conv2D layer.
        epochs: Maximum number of epochs.
        batch_size: Mini-batch size.
        save_path: Where the best checkpoint is written.
        enforce_masks: Keep pruned filters at zero during training. Pass
            ``False`` for the previous prune-once-then-train-freely behaviour.

    Returns:
        tuple: ``(model, history)`` where ``history`` is the object returned by
        ``model.fit``.
    """
    model = build_model_fn()
    apply_l1_filter_pruning(model, sparsity)

    callbacks = []
    if enforce_masks:
        # Record, per pruned layer, which filters must stay at zero. A filter
        # counts as pruned when its entire kernel slice is zero after pruning.
        masked_variables = []
        for layer in model.layers:
            if not isinstance(layer, tf.keras.layers.Conv2D):
                continue
            layer_weights = layer.get_weights()
            if not layer_weights:
                continue
            kernel = layer_weights[0]
            keep = np.any(kernel != 0, axis=(0, 1, 2))
            if keep.all():
                continue
            keep = keep.astype(kernel.dtype)
            # `keep` has one entry per filter and broadcasts over the last axis.
            masked_variables.append((layer.kernel, keep))
            if len(layer_weights) == 2:
                masked_variables.append((layer.bias, keep))

        class _KeepFiltersPruned(tf.keras.callbacks.Callback):
            """Re-zero pruned filters after each optimizer step."""

            def on_train_batch_end(self, batch, logs=None):
                for variable, mask in masked_variables:
                    variable.assign(variable * mask)

        if masked_variables:
            # Listed first so weights are masked before any later callback runs.
            callbacks.append(_KeepFiltersPruned())

    early_stop = tf.keras.callbacks.EarlyStopping(patience=5, restore_best_weights=True)
    checkpoint = tf.keras.callbacks.ModelCheckpoint(save_path, save_best_only=True)
    callbacks.extend([early_stop, checkpoint])

    model.compile(optimizer='adam',
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])

    history = model.fit(x_train, y_train,
                        validation_data=(x_val, y_val),
                        epochs=epochs,
                        batch_size=batch_size,
                        callbacks=callbacks,
                        verbose=2)

    return model, history


In [9]:
# Load CIFAR-10 and scale pixels to [0, 1].
# Cast to float32 before dividing: dividing the integer image arrays by a Python
# float would produce float64 copies, doubling memory for no benefit.
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
x_train = x_train.astype("float32") / 255.0
x_test = x_test.astype("float32") / 255.0

# Train structured pruned model
# NOTE(review): the test split is used as validation data, so early stopping and
# checkpoint selection see the test set - consider holding out part of x_train.
model, history = train_l1_pruned_model(
    build_model_fn=build_resnet56,
    x_train=x_train,
    y_train=y_train,
    x_val=x_test,
    y_val=y_test,
    sparsity=0.3,  # Prune 30% filters
    save_path='l1prune_model.h5'
)


Epoch 1/50


KeyboardInterrupt: 

In [None]:
# Load CIFAR-10 and scale pixels to [0, 1].
# Cast to float32 before dividing: dividing the integer image arrays by a Python
# float would produce float64 copies, doubling memory for no benefit.
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
x_train = x_train.astype("float32") / 255.0
x_test = x_test.astype("float32") / 255.0

# Train structured pruned model
# NOTE(review): the test split is used as validation data, so early stopping and
# checkpoint selection see the test set - consider holding out part of x_train.
model, history = train_l1_pruned_model(
    build_model_fn=build_resnet56,
    x_train=x_train,
    y_train=y_train,
    x_val=x_test,
    y_val=y_test,
    sparsity=0.5,  # Prune 50% filters
    # Sparsity-specific file name so this run's best checkpoint is not
    # overwritten by (and does not overwrite) the other sparsity runs.
    save_path='l1prune_model_sparsity50.h5'
)


In [None]:
# Load CIFAR-10 and scale pixels to [0, 1].
# Cast to float32 before dividing: dividing the integer image arrays by a Python
# float would produce float64 copies, doubling memory for no benefit.
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
x_train = x_train.astype("float32") / 255.0
x_test = x_test.astype("float32") / 255.0

# Train structured pruned model
# NOTE(review): the test split is used as validation data, so early stopping and
# checkpoint selection see the test set - consider holding out part of x_train.
model, history = train_l1_pruned_model(
    build_model_fn=build_resnet56,
    x_train=x_train,
    y_train=y_train,
    x_val=x_test,
    y_val=y_test,
    sparsity=0.7,  # Prune 70% filters
    # Sparsity-specific file name so this run's best checkpoint is not
    # overwritten by (and does not overwrite) the other sparsity runs.
    save_path='l1prune_model_sparsity70.h5'
)
