<a href="https://colab.research.google.com/github/bhartiansh/cnn_pruning_cifar10/blob/main/pruning1(lth).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import sys
import os
# Add the sibling ../data directory to the module search path.
# NOTE(review): the project imports further down are spelled `data.preprocessing`
# and `models.resnet20`, which presumably resolve from the working directory
# rather than from this entry — confirm which path is actually needed.
sys.path.append('../data')

In [2]:
import tensorflow as tf
# Ask TensorFlow to run tf.function code eagerly. This is a debugging switch.
# NOTE(review): presumably slows training noticeably — consider turning it off
# for the long training runs below.
tf.config.run_functions_eagerly(True)  # Optional: For debugging

# Confirm eager execution is on
# NOTE(review): this prints the general eager-execution state; it may not
# reflect the run_functions_eagerly switch set above — confirm.
print("Eager Execution:", tf.executing_eagerly())

Eager Execution: True


In [3]:
# Install the TensorFlow Model Optimization toolkit quietly; later cells import
# it as `tfmot`.
# NOTE(review): the version is not pinned — consider pinning for reproducibility.
!pip install -q tensorflow-model-optimization

In [4]:
import os
import numpy as np
from models.resnet20 import build_resnet20
from data.preprocessing import load_cifar10_data
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Conv2D, BatchNormalization, Activation, Add, GlobalAveragePooling2D, Dense
from tensorflow.keras import regularizers

In [5]:
# Paths
# Best-so-far model written by the ModelCheckpoint callback during training.
CHECKPOINT_PATH = './checkpoints/lth_resnet56_cifar10.keras'
# Destination of the retrained, pruned model.
FINAL_MODEL_PATH = './models/lth_resnet56_cifar10_final.keras'
# Archive of the untrained weights. It is written with np.savez, which appends
# ".npz" to any other extension, so the name must already end in ".npz" for the
# os.path.exists() checks on this path to find the file that was written.
INITIAL_WEIGHTS_PATH = './models/initial_resnet56_weights.npz'

In [6]:
import os
import numpy as np

# NumPy .npz archive that stores the network's weights before any training.
INITIAL_WEIGHTS_PATH = './models/initial_resnet56_weights.npz'
# Make sure the folder holding the archive exists (no-op when already present).
os.makedirs(os.path.dirname(INITIAL_WEIGHTS_PATH), exist_ok=True)

def save_initial_weights(model, path=None):
    """Store the model's current weights in a NumPy ``.npz`` archive.

    Args:
        model: Object whose ``get_weights()`` returns a list of arrays
            (for example a Keras model).
        path: Destination archive. Defaults to ``INITIAL_WEIGHTS_PATH``.

    The arrays are saved positionally, so they end up under the keys
    ``arr_0``, ``arr_1``, ... in list order.
    """
    target = INITIAL_WEIGHTS_PATH if path is None else path

    # Create the destination folder here instead of relying on another cell
    # having been executed first.
    parent = os.path.dirname(target)
    if parent:
        os.makedirs(parent, exist_ok=True)

    weights = model.get_weights()
    np.savez(target, *weights)
    print("Initial weights saved.")

def load_initial_weights(model, path=None):
    """Restore weights from a positional ``.npz`` archive into ``model``.

    Args:
        model: Object exposing ``set_weights(list_of_arrays)``.
        path: Archive to read. Defaults to ``INITIAL_WEIGHTS_PATH``.

    The archive is expected to hold the arrays under ``arr_0``, ``arr_1``, ...
    and they are handed to the model in that order.
    """
    source = INITIAL_WEIGHTS_PATH if path is None else path

    # np.load returns a lazily-read archive object that keeps the file open;
    # the context manager closes it once every array has been materialised.
    with np.load(source) as archive:
        weights = [archive[f'arr_{i}'] for i in range(len(archive.files))]

    model.set_weights(weights)
    print("Initial weights loaded.")

In [7]:
def prune_model(model, pruning_fraction=0.2):
    """Zero the smallest-magnitude entries of each multi-dimensional weight tensor.

    Every tensor with more than one dimension is pruned independently:
    exactly ``int(size * pruning_fraction)`` of its entries, those with the
    smallest absolute value, are set to zero. One-dimensional tensors
    (biases, batch-norm parameters) are left untouched.

    Args:
        model: Object exposing ``get_weights()`` / ``set_weights()``.
        pruning_fraction: Fraction of each tensor to zero, within [0, 1].

    Returns:
        The same ``model`` object, with the pruned weights set on it.

    Raises:
        ValueError: If ``pruning_fraction`` lies outside [0, 1].
    """
    if not 0.0 <= pruning_fraction <= 1.0:
        raise ValueError(
            f"pruning_fraction must be within [0, 1], got {pruning_fraction!r}")

    pruned_weights = []
    for w in model.get_weights():
        if len(w.shape) > 1:  # Prune only weights, not biases or BN params
            k = int(np.prod(w.shape) * pruning_fraction)
            if k > 0:
                magnitudes = np.abs(w).ravel()
                # Select the k smallest entries by index. Comparing against a
                # threshold value would also drop the threshold element itself
                # (and every tie), pruning more than requested, and cannot
                # express k == size at all.
                drop = np.argpartition(magnitudes, k - 1)[:k]
                keep = np.ones(magnitudes.shape, dtype=bool)
                keep[drop] = False
                w = w * keep.reshape(w.shape)
        pruned_weights.append(w)

    model.set_weights(pruned_weights)
    print(f"Model pruned with {pruning_fraction * 100:.1f}% sparsity.")
    return model

In [9]:
def train_lth_model(pruning_fraction=0.2, initial_epochs=30, retrain_epochs=30):
    """Run one lottery-ticket round on ResNet-20: train, mask, rewind, retrain.

    The dense network is trained first. The smallest-magnitude *trained*
    weights decide which connections are pruned; the surviving connections are
    rewound to their initial values, and the sparse network is retrained with
    the mask held fixed and saved to ``FINAL_MODEL_PATH``.

    Args:
        pruning_fraction: Fraction passed to ``prune_model`` when selecting the
            weights to drop.
        initial_epochs: Upper bound on epochs for the dense training phase.
        retrain_epochs: Upper bound on epochs for the sparse retraining phase.
    """
    train_gen, val_gen = load_cifar10_data(batch_size=64)
    model = build_resnet20(input_shape=(32, 32, 3), num_classes=10)

    # Pin the initialisation: the first run stores it, later runs reuse it.
    # NOTE(review): the archive path is shared by every training function in
    # this notebook; an archive written for another architecture will not fit.
    if not os.path.exists(INITIAL_WEIGHTS_PATH):
        save_initial_weights(model)
    else:
        load_initial_weights(model)
    initial_weights = model.get_weights()

    def fresh_callbacks():
        # New instances for every fit(): a reused ModelCheckpoint keeps the
        # best score of the previous phase and would skip saving the sparse
        # model unless it beat the dense one.
        return [
            tf.keras.callbacks.ModelCheckpoint(
                filepath=CHECKPOINT_PATH,
                monitor='val_accuracy',
                save_best_only=True,
                verbose=1,
            ),
            tf.keras.callbacks.EarlyStopping(
                monitor='val_accuracy', patience=5, restore_best_weights=True),
        ]

    # Phase 1: train the dense network.
    model.compile(optimizer='adam',
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])
    print("Training initial/pruned model...")
    model.fit(
        train_gen,
        epochs=initial_epochs,
        validation_data=val_gen,
        callbacks=fresh_callbacks(),
    )

    # Derive the mask from the best *trained* weights. prune_model zeroes the
    # dropped entries of every tensor with more than one dimension, so the
    # non-zero pattern of those tensors afterwards is the mask. (A trained
    # weight that is exactly zero is treated as pruned as well.)
    model.load_weights(CHECKPOINT_PATH)
    prune_model(model, pruning_fraction=pruning_fraction)
    masks = [(w != 0) if len(w.shape) > 1 else None for w in model.get_weights()]

    def apply_masks(weights):
        return [w if m is None else w * m for w, m in zip(weights, masks)]

    # Rewind the surviving connections to their initial values.
    model.set_weights(apply_masks(initial_weights))

    # Phase 2: retrain the sparse network. Recompiling gives a fresh optimizer.
    model.compile(optimizer='adam',
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])

    # Gradient updates would make pruned weights non-zero again, so the mask is
    # re-applied after every batch to keep the network sparse.
    keep_sparse_cb = tf.keras.callbacks.LambdaCallback(
        on_train_batch_end=lambda batch, logs=None: model.set_weights(
            apply_masks(model.get_weights())))

    print("Re-training pruned model from initial weights...")
    model.fit(
        train_gen,
        epochs=retrain_epochs,
        validation_data=val_gen,
        callbacks=[keep_sparse_cb] + fresh_callbacks(),
    )

    model.save(FINAL_MODEL_PATH)
    print("LTH pruned model saved.")

In [10]:
# Entry point for the baseline run: executes train_lth_model() with its defaults.
# NOTE(review): the output recorded for this cell is a TypeError saying
# load_cifar10_data() does not accept `batch_size` — confirm the loader's signature.
if __name__ == "__main__":
    train_lth_model()

TypeError: load_cifar10_data() got an unexpected keyword argument 'batch_size'

**Sparsity 40%** — lottery-ticket run that prunes 40% of each weight tensor.

In [46]:
def train_lth_model_sparsity40():
    """Run one lottery-ticket round on ResNet-56 with a 0.4 pruning fraction.

    The dense network is trained (up to 30 epochs). The smallest-magnitude
    *trained* weights decide which connections are pruned; the survivors are
    rewound to their initial values, and the sparse network is retrained (up to
    30 epochs) with the mask held fixed and saved to ``FINAL_MODEL_PATH``.
    """
    pruning_fraction = 0.4
    dense_epochs, sparse_epochs = 30, 30

    train_gen, val_gen = load_cifar10_data(batch_size=64)
    model = build_resnet56(input_shape=(32, 32, 3), num_classes=10)

    # Pin the initialisation: the first run stores it, later runs reuse it.
    if not os.path.exists(INITIAL_WEIGHTS_PATH):
        save_initial_weights(model)
    else:
        load_initial_weights(model)
    initial_weights = model.get_weights()

    def fresh_callbacks():
        # New instances for every fit(): a reused ModelCheckpoint keeps the
        # best score of the previous phase and would skip saving the sparse
        # model unless it beat the dense one.
        return [
            tf.keras.callbacks.ModelCheckpoint(
                filepath=CHECKPOINT_PATH,
                monitor='val_accuracy',
                save_best_only=True,
                verbose=1,
            ),
            tf.keras.callbacks.EarlyStopping(
                monitor='val_accuracy', patience=5, restore_best_weights=True),
        ]

    # Phase 1: train the dense network.
    model.compile(optimizer='adam',
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])
    print("Training initial/pruned model...")
    model.fit(
        train_gen,
        epochs=dense_epochs,
        validation_data=val_gen,
        callbacks=fresh_callbacks(),
    )

    # Derive the mask from the best *trained* weights. prune_model zeroes the
    # dropped entries of every tensor with more than one dimension, so the
    # non-zero pattern of those tensors afterwards is the mask. (A trained
    # weight that is exactly zero is treated as pruned as well.)
    model.load_weights(CHECKPOINT_PATH)
    prune_model(model, pruning_fraction=pruning_fraction)
    masks = [(w != 0) if len(w.shape) > 1 else None for w in model.get_weights()]

    def apply_masks(weights):
        return [w if m is None else w * m for w, m in zip(weights, masks)]

    # Rewind the surviving connections to their initial values.
    model.set_weights(apply_masks(initial_weights))

    # Phase 2: retrain the sparse network. Recompiling gives a fresh optimizer.
    model.compile(optimizer='adam',
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])

    # Gradient updates would make pruned weights non-zero again, so the mask is
    # re-applied after every batch to keep the network sparse.
    keep_sparse_cb = tf.keras.callbacks.LambdaCallback(
        on_train_batch_end=lambda batch, logs=None: model.set_weights(
            apply_masks(model.get_weights())))

    print("Re-training pruned model from initial weights...")
    model.fit(
        train_gen,
        epochs=sparse_epochs,
        validation_data=val_gen,
        callbacks=[keep_sparse_cb] + fresh_callbacks(),
    )

    model.save(FINAL_MODEL_PATH)
    print("LTH pruned model saved.")

In [None]:
# Entry point for the 40% run: executes train_lth_model_sparsity40().
# The recorded output below is cut off during epoch 3 of the first training phase.
if __name__ == "__main__":
    train_lth_model_sparsity40()

Initial weights saved.
Training initial/pruned model...
Epoch 1/30


  self._warn_if_super_not_called()


[1m782/782[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 2s/step - accuracy: 0.3684 - loss: 1.8503
Epoch 1: val_accuracy improved from -inf to 0.46230, saving model to ./checkpoints/lth_resnet56_cifar10.keras
[1m782/782[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1288s[0m 2s/step - accuracy: 0.3685 - loss: 1.8499 - val_accuracy: 0.4623 - val_loss: 1.7135
Epoch 2/30
[1m782/782[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 2s/step - accuracy: 0.5832 - loss: 1.1690
Epoch 2: val_accuracy improved from 0.46230 to 0.55720, saving model to ./checkpoints/lth_resnet56_cifar10.keras
[1m782/782[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1262s[0m 2s/step - accuracy: 0.5832 - loss: 1.1689 - val_accuracy: 0.5572 - val_loss: 1.5803
Epoch 3/30
[1m782/782[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 2s/step - accuracy: 0.6858 - loss: 0.9002
Epoch 3: val_accuracy did not improve from 0.55720
[1m782/782[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1253s[0m 2s/st

**Sparsity 60%** — lottery-ticket run that prunes 60% of each weight tensor.

In [7]:
def train_lth_model_sparsity60():
    """Run one lottery-ticket round on ResNet-56 with a 0.6 pruning fraction.

    The dense network is trained (up to 10 epochs). The smallest-magnitude
    *trained* weights decide which connections are pruned; the survivors are
    rewound to their initial values, and the sparse network is retrained (up to
    30 epochs) with the mask held fixed and saved to ``FINAL_MODEL_PATH``.
    """
    pruning_fraction = 0.6
    dense_epochs, sparse_epochs = 10, 30

    train_gen, val_gen = load_cifar10_data(batch_size=64)
    model = build_resnet56(input_shape=(32, 32, 3), num_classes=10)

    # Pin the initialisation: the first run stores it, later runs reuse it.
    if not os.path.exists(INITIAL_WEIGHTS_PATH):
        save_initial_weights(model)
    else:
        load_initial_weights(model)
    initial_weights = model.get_weights()

    def fresh_callbacks():
        # New instances for every fit(): a reused ModelCheckpoint keeps the
        # best score of the previous phase and would skip saving the sparse
        # model unless it beat the dense one.
        return [
            tf.keras.callbacks.ModelCheckpoint(
                filepath=CHECKPOINT_PATH,
                monitor='val_accuracy',
                save_best_only=True,
                verbose=1,
            ),
            tf.keras.callbacks.EarlyStopping(
                monitor='val_accuracy', patience=5, restore_best_weights=True),
        ]

    # Phase 1: train the dense network.
    model.compile(optimizer='adam',
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])
    print("Training initial/pruned model...")
    model.fit(
        train_gen,
        epochs=dense_epochs,
        validation_data=val_gen,
        callbacks=fresh_callbacks(),
    )

    # Derive the mask from the best *trained* weights. prune_model zeroes the
    # dropped entries of every tensor with more than one dimension, so the
    # non-zero pattern of those tensors afterwards is the mask. (A trained
    # weight that is exactly zero is treated as pruned as well.)
    model.load_weights(CHECKPOINT_PATH)
    prune_model(model, pruning_fraction=pruning_fraction)
    masks = [(w != 0) if len(w.shape) > 1 else None for w in model.get_weights()]

    def apply_masks(weights):
        return [w if m is None else w * m for w, m in zip(weights, masks)]

    # Rewind the surviving connections to their initial values.
    model.set_weights(apply_masks(initial_weights))

    # Phase 2: retrain the sparse network. Recompiling gives a fresh optimizer.
    model.compile(optimizer='adam',
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])

    # Gradient updates would make pruned weights non-zero again, so the mask is
    # re-applied after every batch to keep the network sparse.
    keep_sparse_cb = tf.keras.callbacks.LambdaCallback(
        on_train_batch_end=lambda batch, logs=None: model.set_weights(
            apply_masks(model.get_weights())))

    print("Re-training pruned model from initial weights...")
    model.fit(
        train_gen,
        epochs=sparse_epochs,
        validation_data=val_gen,
        callbacks=[keep_sparse_cb] + fresh_callbacks(),
    )

    model.save(FINAL_MODEL_PATH)
    print("LTH pruned model saved.")

In [None]:
# Entry point for the 60% run: executes train_lth_model_sparsity60().
# The recorded output below is cut off during epoch 3 of the first training phase.
if __name__ == "__main__":
    train_lth_model_sparsity60()

Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
[1m170498071/170498071[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 0us/step
Initial weights saved.
Training initial/pruned model...
Epoch 1/10


  self._warn_if_super_not_called()


[1m782/782[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 2s/step - accuracy: 0.3555 - loss: 1.9610
Epoch 1: val_accuracy improved from -inf to 0.47080, saving model to ./checkpoints/lth_resnet56_cifar10.keras
[1m782/782[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1283s[0m 2s/step - accuracy: 0.3557 - loss: 1.9604 - val_accuracy: 0.4708 - val_loss: 1.5816
Epoch 2/10
[1m782/782[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 2s/step - accuracy: 0.5904 - loss: 1.1380
Epoch 2: val_accuracy improved from 0.47080 to 0.65020, saving model to ./checkpoints/lth_resnet56_cifar10.keras
[1m782/782[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1299s[0m 2s/step - accuracy: 0.5904 - loss: 1.1379 - val_accuracy: 0.6502 - val_loss: 1.0119
Epoch 3/10
[1m782/782[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 2s/step - accuracy: 0.6807 - loss: 0.9008
Epoch 3: val_accuracy did not improve from 0.65020
[1m782/782[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1321s[0m 2s/st

In [None]:
def train_lth_model_sparsity80():
    """Run one lottery-ticket round on ResNet-56 with a 0.8 pruning fraction.

    The dense network is trained (up to 30 epochs). The smallest-magnitude
    *trained* weights decide which connections are pruned; the survivors are
    rewound to their initial values, and the sparse network is retrained (up to
    10 epochs) with the mask held fixed and saved to ``FINAL_MODEL_PATH``.
    """
    pruning_fraction = 0.8
    dense_epochs, sparse_epochs = 30, 10

    train_gen, val_gen = load_cifar10_data(batch_size=64)
    model = build_resnet56(input_shape=(32, 32, 3), num_classes=10)

    # Pin the initialisation: the first run stores it, later runs reuse it.
    if not os.path.exists(INITIAL_WEIGHTS_PATH):
        save_initial_weights(model)
    else:
        load_initial_weights(model)
    initial_weights = model.get_weights()

    def fresh_callbacks():
        # New instances for every fit(): a reused ModelCheckpoint keeps the
        # best score of the previous phase and would skip saving the sparse
        # model unless it beat the dense one.
        return [
            tf.keras.callbacks.ModelCheckpoint(
                filepath=CHECKPOINT_PATH,
                monitor='val_accuracy',
                save_best_only=True,
                verbose=1,
            ),
            tf.keras.callbacks.EarlyStopping(
                monitor='val_accuracy', patience=5, restore_best_weights=True),
        ]

    # Phase 1: train the dense network.
    model.compile(optimizer='adam',
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])
    print("Training initial/pruned model...")
    model.fit(
        train_gen,
        epochs=dense_epochs,
        validation_data=val_gen,
        callbacks=fresh_callbacks(),
    )

    # Derive the mask from the best *trained* weights. prune_model zeroes the
    # dropped entries of every tensor with more than one dimension, so the
    # non-zero pattern of those tensors afterwards is the mask. (A trained
    # weight that is exactly zero is treated as pruned as well.)
    model.load_weights(CHECKPOINT_PATH)
    prune_model(model, pruning_fraction=pruning_fraction)
    masks = [(w != 0) if len(w.shape) > 1 else None for w in model.get_weights()]

    def apply_masks(weights):
        return [w if m is None else w * m for w, m in zip(weights, masks)]

    # Rewind the surviving connections to their initial values.
    model.set_weights(apply_masks(initial_weights))

    # Phase 2: retrain the sparse network. Recompiling gives a fresh optimizer.
    model.compile(optimizer='adam',
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])

    # Gradient updates would make pruned weights non-zero again, so the mask is
    # re-applied after every batch to keep the network sparse.
    keep_sparse_cb = tf.keras.callbacks.LambdaCallback(
        on_train_batch_end=lambda batch, logs=None: model.set_weights(
            apply_masks(model.get_weights())))

    print("Re-training pruned model from initial weights...")
    model.fit(
        train_gen,
        epochs=sparse_epochs,
        validation_data=val_gen,
        callbacks=[keep_sparse_cb] + fresh_callbacks(),
    )

    model.save(FINAL_MODEL_PATH)
    print("LTH pruned model saved.")

In [None]:
# Entry point for the 80% run: executes train_lth_model_sparsity80().
# No output is recorded for this cell in the notebook.
if __name__ == "__main__":
    train_lth_model_sparsity80()

**Sparsity 80%** — the lottery-ticket run defined above prunes 80% of each weight tensor.

In [23]:
from tensorflow import keras

In [29]:
import tensorflow_model_optimization as tfmot

def build_prunable_resnet56():
    """Build a small CNN whose conv and dense layers are wrapped for pruning.

    Despite the name this is a placeholder rather than a real ResNet-56: four
    16-filter 3x3 convolutions, global average pooling and a 10-way softmax
    classifier. Each conv/dense layer is wrapped with
    ``tfmot.sparsity.keras.prune_low_magnitude`` using its default schedule.

    NOTE(review): the outputs recorded in this notebook show
    ``prune_low_magnitude`` rejecting Keras model and layer objects in this
    environment — presumably a Keras/tfmot version mismatch; confirm before
    relying on this builder.

    Returns:
        An uncompiled functional Keras model taking 32x32x3 inputs.
    """
    # `keras.layers` is spelled out because the bare `layers` alias is only
    # imported by a later cell; on a fresh kernel this cell would otherwise
    # fail with a NameError.
    prune = tfmot.sparsity.keras.prune_low_magnitude

    inputs = keras.Input(shape=(32, 32, 3))

    # Wrap convolutional and dense layers with pruning
    x = prune(
        keras.layers.Conv2D(16, (3, 3), padding='same', activation='relu')
    )(inputs)

    # Add the rest of ResNet blocks with pruning as needed
    # Dummy blocks here — replace with real ResNet-56 blocks
    for _ in range(3):
        x = prune(
            keras.layers.Conv2D(16, (3, 3), padding='same', activation='relu')
        )(x)

    x = keras.layers.GlobalAveragePooling2D()(x)
    outputs = prune(
        keras.layers.Dense(10, activation='softmax')
    )(x)

    return keras.Model(inputs, outputs)

In [30]:
import tensorflow as tf
import tensorflow_model_optimization as tfmot
import os
import gc
import numpy as np

# Define your sparsity levels
sparsity_levels = [0.4, 0.6, 0.8]

# Store results, keyed by sparsity level
val_accuracies = {}
test_accuracies = {}

# Create directory to save models
os.makedirs("lth_models", exist_ok=True)

# Clear memory
gc.collect()
tf.keras.backend.clear_session()

# Build and store weights before starting pruning loop
# NOTE(review): build_resnet56, x_train/y_train and x_test/y_test are defined
# by a cell further down in this notebook; run that cell first.
base_model = build_resnet56()
initial_weights = base_model.get_weights()

for sparsity in sparsity_levels:
    print(f"\n🔧 Starting LTH iteration with {int(sparsity * 100)}% sparsity...\n")

    # Reload initial model and weights
    model = build_resnet56()
    model.set_weights(initial_weights)

    # Apply pruning
    pruning_params = {
        'pruning_schedule': tfmot.sparsity.keras.ConstantSparsity(sparsity, begin_step=0)
    }
    pruned_model = tfmot.sparsity.keras.prune_low_magnitude(model, **pruning_params)

    # Only the wrapped model is trained, so only it is compiled, with an
    # optimizer instance of its own (the unwrapped model previously shared it).
    pruned_model.compile(
        optimizer=tf.keras.optimizers.Adam(),
        loss='categorical_crossentropy',
        metrics=['accuracy'],
    )

    # Callbacks
    callbacks = [
        tfmot.sparsity.keras.UpdatePruningStep(),
        tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=3, restore_best_weights=True)
    ]

    # Train
    history = pruned_model.fit(
        x_train, y_train,
        batch_size=128,
        epochs=10,  # Use a lower epoch count to conserve Colab units
        validation_split=0.1,
        callbacks=callbacks,
        verbose=1
    )

    # Record the validation accuracy of the epoch with the lowest val_loss,
    # the quantity EarlyStopping monitors.
    val_losses = history.history['val_loss']
    best_epoch = val_losses.index(min(val_losses))
    val_accuracies[sparsity] = history.history['val_accuracy'][best_epoch]

    # Evaluate
    test_loss, test_acc = pruned_model.evaluate(x_test, y_test, verbose=0)
    test_accuracies[sparsity] = test_acc
    print(f"🔍 Test Accuracy at {int(sparsity*100)}% sparsity: {test_acc:.4f}")


🔧 Starting LTH iteration with 40% sparsity...



ValueError: `prune_low_magnitude` can only prune an object of the following types: keras.models.Sequential, keras functional model, keras.layers.Layer, list of keras.layers.Layer. You passed an object of type: Functional.

In [12]:
# Repeat of the earlier tensorflow-model-optimization install (same package, same flags).
# The recorded output reports a numpy version conflict with `thinc`.
!pip install -q tensorflow-model-optimization

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m61.0/61.0 kB[0m [31m4.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m242.5/242.5 kB[0m [31m19.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m18.3/18.3 MB[0m [31m84.7 MB/s[0m eta [36m0:00:00[0m
[?25h[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
thinc 8.3.6 requires numpy<3.0.0,>=2.0.0, but you have numpy 1.26.4 which is incompatible.[0m[31m
[0m

In [25]:
# Debugging aid: print the concrete class of the current `model` object.
print(type(model))

<class 'keras.src.models.functional.Functional'>


In [26]:
# Debugging aid: check that the current `model` is a tf.keras.Model instance.
print(isinstance(model, tf.keras.Model))  # Should be True

True


In [36]:
import tensorflow as tf
from tensorflow.keras import layers, models
import tensorflow_model_optimization as tfmot

# Function to build ResNet56 model
def build_resnet56():
    """Assemble the compact stand-in CNN used as the unpruned baseline.

    Despite its name the network is not a 56-layer ResNet: it chains a
    16-filter conv, 2x2 max-pooling, a 32-filter conv, global average pooling
    and a 10-way softmax classifier.

    Returns:
        An uncompiled functional Keras model taking 32x32x3 inputs.
    """
    image_in = tf.keras.Input(shape=(32, 32, 3))

    # Example layers, you can add more blocks as needed
    stack = (
        layers.Conv2D(16, (3, 3), padding='same', activation='relu'),
        layers.MaxPooling2D((2, 2)),
        layers.Conv2D(32, (3, 3), padding='same', activation='relu'),
        layers.GlobalAveragePooling2D(),
        layers.Dense(10, activation='softmax'),
    )

    # Thread the tensor through the layers in declaration order.
    tensor = image_in
    for layer in stack:
        tensor = layer(tensor)

    return models.Model(image_in, tensor)

# Function to build pruned ResNet56 model with pruning applied
def build_prunable_resnet56(sparsity=0.5):
    """Build the stand-in CNN with its conv/dense layers wrapped for pruning.

    Args:
        sparsity: Target sparsity given to ``ConstantSparsity`` for every
            wrapped layer. Defaults to 0.5, the value that used to be
            hard-coded.

    Returns:
        An uncompiled functional Keras model taking 32x32x3 inputs.
    """
    def prunable(layer):
        # Each wrapped layer gets its own schedule object.
        return tfmot.sparsity.keras.prune_low_magnitude(
            layer,
            pruning_schedule=tfmot.sparsity.keras.ConstantSparsity(sparsity, begin_step=0),
        )

    inputs = tf.keras.Input(shape=(32, 32, 3))

    # Wrap the layers you want to prune with prune_low_magnitude
    x = prunable(layers.Conv2D(16, (3, 3), padding='same', activation='relu'))(inputs)
    x = layers.MaxPooling2D((2, 2))(x)
    x = prunable(layers.Conv2D(32, (3, 3), padding='same', activation='relu'))(x)
    x = layers.GlobalAveragePooling2D()(x)
    outputs = prunable(layers.Dense(10, activation='softmax'))(x)

    model = models.Model(inputs, outputs)
    return model


def _load_unwrapped_weights(model, weights):
    """Copy weights taken from the plain network into its pruning-wrapped twin.

    The pruning wrappers carry extra bookkeeping variables of their own, so the
    wrapped model's weight list is longer than the plain model's and a direct
    ``model.set_weights(weights)`` does not line up. The arrays are therefore
    assigned layer by layer to the layer inside each wrapper.
    """
    remaining = list(weights)
    for layer in model.layers:
        inner = getattr(layer, 'layer', layer)  # unwrap a pruning wrapper if present
        count = len(inner.weights)
        if count:
            inner.set_weights(remaining[:count])
            remaining = remaining[count:]
    if remaining:
        raise ValueError(
            f"{len(remaining)} weight arrays were left over after loading")


# Function to apply pruning for a given sparsity level
def apply_pruning(sparsity, x_train, y_train, x_test, y_test, initial_weights):
    """Train a pruning-wrapped model at ``sparsity`` starting from ``initial_weights``.

    Args:
        sparsity: Target sparsity for this iteration; forwarded to the model
            builder so the pruning schedule actually uses it.
        x_train, y_train: Training data passed to ``fit``.
        x_test, y_test: Data passed to ``fit`` as ``validation_data``.
        initial_weights: Weight list obtained from the unwrapped baseline model.

    Returns:
        The trained, pruning-wrapped model.
    """
    print(f"🔧 Starting LTH iteration with {int(sparsity * 100)}% sparsity...")

    # Pruning scope is not needed when pruning the whole model
    model = build_prunable_resnet56(sparsity)  # Create pruned model
    _load_unwrapped_weights(model, initial_weights)  # Reset weights to initial

    # Compile the model
    optimizer = tf.keras.optimizers.Adam()
    model.compile(optimizer=optimizer, loss='categorical_crossentropy', metrics=['accuracy'])

    # Callbacks for pruning updates
    callbacks = [tfmot.sparsity.keras.UpdatePruningStep()]

    # Train the model with pruning
    model.fit(x_train, y_train, epochs=10, validation_data=(x_test, y_test), callbacks=callbacks)

    return model

# ... (rest of your code remains the same) ...

# Load CIFAR-10 data
# Labels are one-hot encoded below to match the 'categorical_crossentropy'
# loss that apply_pruning compiles the model with.
# NOTE(review): the images are used exactly as loaded; no rescaling is visible
# in this cell — confirm whether normalisation is intended.
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
y_train = tf.keras.utils.to_categorical(y_train, 10)
y_test = tf.keras.utils.to_categorical(y_test, 10)

# Initialize the base model to get the initial weights
# Every sparsity level below starts from this same weight list.
base_model = build_resnet56()
initial_weights = base_model.get_weights()

# Set the sparsity levels for pruning
sparsity_levels = [0.4, 0.5, 0.6, 0.7, 0.8]  # Example sparsity levels

# Loop over sparsity levels
# `model` is rebound on every iteration, so only the last trained model remains
# afterwards unless the save line below is enabled.
# NOTE(review): apply_pruning passes the test split to fit() as validation data.
for sparsity in sparsity_levels:
    model = apply_pruning(sparsity, x_train, y_train, x_test, y_test, initial_weights)
    # Save the model after each pruning iteration if needed
    # model.save(f"pruned_model_sparsity_{int(sparsity * 100)}.h5")

🔧 Starting LTH iteration with 40% sparsity...


ValueError: `prune_low_magnitude` can only prune an object of the following types: keras.models.Sequential, keras functional model, keras.layers.Layer, list of keras.layers.Layer. You passed an object of type: Conv2D.