<a href="https://colab.research.google.com/github/bhartiansh/cnn_pruning_cifar10/blob/main/lth_pruning_20_40_60.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Clone the project repository and make it the working directory.
# NOTE(review): the later `from models...` / `from data...` imports are
# presumably provided by this checkout — confirm against the repository layout.
!git clone https://github.com/bhartiansh/cnn_pruning_cifar10.git
%cd cnn_pruning_cifar10

Cloning into 'cnn_pruning_cifar10'...
remote: Enumerating objects: 50, done.[K
remote: Counting objects: 100% (50/50), done.[K
remote: Compressing objects: 100% (46/46), done.[K
remote: Total 50 (delta 14), reused 0 (delta 0), pack-reused 0 (from 0)[K
Receiving objects: 100% (50/50), 21.16 KiB | 1.17 MiB/s, done.
Resolving deltas: 100% (14/14), done.
/content/cnn_pruning_cifar10


In [2]:
import tensorflow as tf

# Force tf.function bodies (including Keras train/test steps) to run eagerly.
# Useful for debugging, but it makes training considerably slower.
tf.config.run_functions_eagerly(True)  # Optional: For debugging

# tf.executing_eagerly() reports TF2's global eager mode, which is on by
# default at the top level regardless of the setting above.
print("Eager Execution:", tf.executing_eagerly())
# This is the flag that actually reflects run_functions_eagerly(True).
print("Functions run eagerly:", tf.config.functions_run_eagerly())

Eager Execution: True


In [3]:
import os
import numpy as np
from models.resnet56_baseline import build_resnet56
from data.cifar10_loader import load_cifar10_data
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Conv2D, BatchNormalization, Activation, Add, GlobalAveragePooling2D, Dense
from tensorflow.keras import regularizers

In [4]:
# Paths
# Best-so-far checkpoint written during training.
CHECKPOINT_PATH = './checkpoints/lth_resnet56_cifar10.keras'
# Final pruned-and-retrained model.
FINAL_MODEL_PATH = './models/lth_resnet56_cifar10_final.keras'
# Initial (untrained) weights. They are written with np.savez, which appends
# ".npz" to any other extension, so the path must end in ".npz" for the
# os.path.exists() checks on it to ever succeed.
INITIAL_WEIGHTS_PATH = './models/initial_resnet56_weights.npz'

In [5]:
import os
import numpy as np

# Archive (.npz) holding the network's weights as they were at initialization.
INITIAL_WEIGHTS_PATH = './models/initial_resnet56_weights.npz'

# Make sure the folder that will contain the archive is present.
os.makedirs(os.path.dirname(INITIAL_WEIGHTS_PATH), exist_ok=True)

def save_initial_weights(model, path=None):
    """Save a model's current weights to a NumPy ``.npz`` archive.

    The arrays returned by ``model.get_weights()`` are stored positionally, so
    they end up under the keys ``arr_0``, ``arr_1``, ... in that order.

    Args:
        model: Object exposing ``get_weights()`` that returns a list of arrays.
        path: Destination file. Defaults to ``INITIAL_WEIGHTS_PATH``. Note that
            ``np.savez`` appends ``.npz`` if the name does not already end with it.
    """
    target = INITIAL_WEIGHTS_PATH if path is None else path

    # Create the parent directory here so the function does not depend on a
    # module-level makedirs having been run first.
    parent = os.path.dirname(target)
    if parent:
        os.makedirs(parent, exist_ok=True)

    np.savez(target, *model.get_weights())
    print("Initial weights saved.")

def load_initial_weights(model, path=None):
    """Load weights from a NumPy ``.npz`` archive into a model.

    The archive is expected to hold arrays under the positional keys
    ``arr_0``, ``arr_1``, ... (the layout ``np.savez(path, *weights)``
    produces); they are passed to ``model.set_weights`` in index order.

    Args:
        model: Object exposing ``set_weights(list_of_arrays)``.
        path: Archive to read. Defaults to ``INITIAL_WEIGHTS_PATH``.
    """
    source = INITIAL_WEIGHTS_PATH if path is None else path

    # NpzFile keeps the underlying zip file open; the context manager closes
    # it once every array has been read into memory.
    with np.load(source) as data:
        # Index by key rather than iterating data.files so the order is
        # numeric (arr_2 before arr_10), matching the order used when saving.
        weights = [data[f'arr_{i}'] for i in range(len(data.files))]

    model.set_weights(weights)
    print("Initial weights loaded.")

In [6]:
def prune_model(model, pruning_fraction=0.2):
    """Magnitude-prune a model in place, array by array.

    For every weight array with more than one dimension, the
    ``int(size * pruning_fraction)`` entries with the smallest absolute value
    are set to zero. One-dimensional arrays (biases, batch-norm parameters)
    are left untouched.

    Args:
        model: Object exposing ``get_weights()`` / ``set_weights()`` with
            NumPy arrays.
        pruning_fraction: Fraction of entries to zero in each prunable array;
            must lie in ``[0, 1]``.

    Returns:
        The same ``model`` object, with the pruned weights set.

    Raises:
        ValueError: If ``pruning_fraction`` is outside ``[0, 1]``.
    """
    if not 0.0 <= pruning_fraction <= 1.0:
        raise ValueError(
            f"pruning_fraction must be in [0, 1], got {pruning_fraction!r}")

    pruned_weights = []
    for w in model.get_weights():
        if w.ndim > 1:  # Prune only weights, not biases or BN params
            k = int(w.size * pruning_fraction)
            if k > 0:
                flat = w.flatten()  # flatten() copies, so `w` is not mutated
                # argpartition(k - 1) puts the k smallest magnitudes in the
                # first k slots, so exactly k entries are zeroed. This also
                # stays in bounds when k == size (fraction 1.0).
                smallest = np.argpartition(np.abs(flat), k - 1)[:k]
                flat[smallest] = 0
                w = flat.reshape(w.shape)
        pruned_weights.append(w)

    model.set_weights(pruned_weights)
    print(f"Model pruned with {pruning_fraction * 100:.1f}% sparsity.")
    return model

In [45]:
def train_lth_model(pruning_fraction=0.2, epochs=30):
    """Run one Lottery Ticket Hypothesis (LTH) round on ResNet-56 / CIFAR-10.

    Steps:
      1. Build the network, then save (first run) or load its initial weights.
      2. Train the dense network, checkpointing the best ``val_accuracy``.
      3. Magnitude-prune the best *trained* weights to obtain a binary mask.
      4. Rewind to the initial weights, apply the mask, and retrain while
         holding the pruned entries at zero.
      5. Save the retrained sparse model to ``FINAL_MODEL_PATH``.

    Args:
        pruning_fraction: Fraction of each multi-dimensional weight array to
            prune. Defaults to 0.2, the value previously hard-coded here.
        epochs: Maximum number of epochs for each of the two training phases.
    """
    train_gen, val_gen = load_cifar10_data(batch_size=64)
    model = build_resnet56(input_shape=(32, 32, 3), num_classes=10)

    if not os.path.exists(INITIAL_WEIGHTS_PATH):
        save_initial_weights(model)
    else:
        load_initial_weights(model)

    def compile_model():
        model.compile(optimizer='adam',
                      loss='sparse_categorical_crossentropy',
                      metrics=['accuracy'])

    def make_callbacks():
        # Build fresh callbacks for each phase: a reused ModelCheckpoint keeps
        # its `best` score, so the sparse model would only be checkpointed if
        # it beat the dense model from phase 1.
        return [
            tf.keras.callbacks.ModelCheckpoint(
                filepath=CHECKPOINT_PATH,
                monitor='val_accuracy',
                save_best_only=True,
                verbose=1
            ),
            tf.keras.callbacks.EarlyStopping(
                monitor='val_accuracy', patience=5, restore_best_weights=True),
        ]

    # Phase 1: train the dense network.
    compile_model()
    print("Training initial/pruned model...")
    model.fit(
        train_gen,
        epochs=epochs,
        validation_data=val_gen,
        callbacks=make_callbacks()
    )

    # Derive the mask from the best TRAINED weights. It must be captured
    # before rewinding; otherwise pruning would rank the untrained weights and
    # phase 1 would have no effect on the result.
    model.load_weights(CHECKPOINT_PATH)
    model = prune_model(model, pruning_fraction=pruning_fraction)
    masks = [
        (w != 0).astype(w.dtype) if w.ndim > 1 else None
        for w in model.get_weights()
    ]

    def apply_masks(batch=None, logs=None):
        # model.weights is in the same order as model.get_weights().
        for var, mask in zip(model.weights, masks):
            if mask is not None:
                var.assign(var * mask)

    # Rewind to the initial weights and keep only the surviving connections.
    load_initial_weights(model)
    apply_masks()

    # Phase 2: retrain the sparse network. The optimizer would otherwise move
    # pruned entries away from zero, so the mask is re-applied after every batch.
    compile_model()
    print("Re-training pruned model from initial weights...")
    mask_cb = tf.keras.callbacks.LambdaCallback(on_train_batch_end=apply_masks)
    model.fit(
        train_gen,
        epochs=epochs,
        validation_data=val_gen,
        callbacks=[mask_cb] + make_callbacks()
    )

    model.save(FINAL_MODEL_PATH)
    print("LTH pruned model saved.")

In [None]:
# Run the 20%-sparsity LTH experiment when this cell is executed as the main
# module.
if __name__ == "__main__":
    train_lth_model()

Initial weights saved.
Training initial/pruned model...
Epoch 1/30


  self._warn_if_super_not_called()


[1m782/782[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 49ms/step - accuracy: 0.3438 - loss: 1.9568
Epoch 1: val_accuracy improved from -inf to 0.45920, saving model to ./checkpoints/lth_resnet56_cifar10.keras
[1m782/782[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m96s[0m 57ms/step - accuracy: 0.3439 - loss: 1.9563 - val_accuracy: 0.4592 - val_loss: 1.6808
Epoch 2/30
[1m782/782[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 32ms/step - accuracy: 0.5653 - loss: 1.2008
Epoch 2: val_accuracy improved from 0.45920 to 0.59130, saving model to ./checkpoints/lth_resnet56_cifar10.keras
[1m782/782[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m27s[0m 34ms/step - accuracy: 0.5654 - loss: 1.2008 - val_accuracy: 0.5913 - val_loss: 1.1975
Epoch 3/30
[1m781/782[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 32ms/step - accuracy: 0.6581 - loss: 0.9663
Epoch 3: val_accuracy improved from 0.59130 to 0.67750, saving model to ./checkpoints/lth_resnet56_cifar10.keras
[1m78

## LTH pruning at 40% sparsity

In [46]:
def train_lth_model_sparsity40(pruning_fraction=0.4, epochs=30):
    """Run one LTH round on ResNet-56 / CIFAR-10 at 40% sparsity by default.

    The dense network is trained, a magnitude mask is taken from its best
    *trained* weights, the network is rewound to its initial weights with the
    mask applied, and the sparse network is retrained with the pruned entries
    held at zero. The result is saved to ``FINAL_MODEL_PATH``.

    Args:
        pruning_fraction: Fraction of each multi-dimensional weight array to
            prune. Defaults to 0.4, the value previously hard-coded here.
        epochs: Maximum number of epochs for each of the two training phases.
    """
    train_gen, val_gen = load_cifar10_data(batch_size=64)
    model = build_resnet56(input_shape=(32, 32, 3), num_classes=10)

    if not os.path.exists(INITIAL_WEIGHTS_PATH):
        save_initial_weights(model)
    else:
        load_initial_weights(model)

    def compile_model():
        model.compile(optimizer='adam',
                      loss='sparse_categorical_crossentropy',
                      metrics=['accuracy'])

    def make_callbacks():
        # Build fresh callbacks for each phase: a reused ModelCheckpoint keeps
        # its `best` score, so the sparse model would only be checkpointed if
        # it beat the dense model from phase 1.
        return [
            tf.keras.callbacks.ModelCheckpoint(
                filepath=CHECKPOINT_PATH,
                monitor='val_accuracy',
                save_best_only=True,
                verbose=1
            ),
            tf.keras.callbacks.EarlyStopping(
                monitor='val_accuracy', patience=5, restore_best_weights=True),
        ]

    # Phase 1: train the dense network.
    compile_model()
    print("Training initial/pruned model...")
    model.fit(
        train_gen,
        epochs=epochs,
        validation_data=val_gen,
        callbacks=make_callbacks()
    )

    # Derive the mask from the best TRAINED weights, before rewinding.
    model.load_weights(CHECKPOINT_PATH)
    model = prune_model(model, pruning_fraction=pruning_fraction)
    masks = [
        (w != 0).astype(w.dtype) if w.ndim > 1 else None
        for w in model.get_weights()
    ]

    def apply_masks(batch=None, logs=None):
        # model.weights is in the same order as model.get_weights().
        for var, mask in zip(model.weights, masks):
            if mask is not None:
                var.assign(var * mask)

    # Rewind to the initial weights and keep only the surviving connections.
    load_initial_weights(model)
    apply_masks()

    # Phase 2: retrain the sparse network, re-applying the mask after every
    # batch so optimizer updates cannot revive pruned entries.
    compile_model()
    print("Re-training pruned model from initial weights...")
    mask_cb = tf.keras.callbacks.LambdaCallback(on_train_batch_end=apply_masks)
    model.fit(
        train_gen,
        epochs=epochs,
        validation_data=val_gen,
        callbacks=[mask_cb] + make_callbacks()
    )

    model.save(FINAL_MODEL_PATH)
    print("LTH pruned model saved.")

In [None]:
# Run the 40%-sparsity LTH experiment when this cell is executed as the main
# module.
if __name__ == "__main__":
    train_lth_model_sparsity40()

Initial weights saved.
Training initial/pruned model...
Epoch 1/30


  self._warn_if_super_not_called()


[1m782/782[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 2s/step - accuracy: 0.3684 - loss: 1.8503
Epoch 1: val_accuracy improved from -inf to 0.46230, saving model to ./checkpoints/lth_resnet56_cifar10.keras
[1m782/782[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1288s[0m 2s/step - accuracy: 0.3685 - loss: 1.8499 - val_accuracy: 0.4623 - val_loss: 1.7135
Epoch 2/30
[1m782/782[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 2s/step - accuracy: 0.5832 - loss: 1.1690
Epoch 2: val_accuracy improved from 0.46230 to 0.55720, saving model to ./checkpoints/lth_resnet56_cifar10.keras
[1m782/782[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1262s[0m 2s/step - accuracy: 0.5832 - loss: 1.1689 - val_accuracy: 0.5572 - val_loss: 1.5803
Epoch 3/30
[1m782/782[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 2s/step - accuracy: 0.6858 - loss: 0.9002
Epoch 3: val_accuracy did not improve from 0.55720
[1m782/782[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1253s[0m 2s/st

## LTH pruning at 60% sparsity

In [9]:
def train_lth_model_sparsity60(pruning_fraction=0.6, epochs=10):
    """Run one LTH round on ResNet-56 / CIFAR-10 at 60% sparsity by default.

    The dense network is trained, a magnitude mask is taken from its best
    *trained* weights, the network is rewound to its initial weights with the
    mask applied, and the sparse network is retrained with the pruned entries
    held at zero. The result is saved to ``FINAL_MODEL_PATH``.

    Args:
        pruning_fraction: Fraction of each multi-dimensional weight array to
            prune. Defaults to 0.6, the value previously hard-coded here.
        epochs: Maximum number of epochs for each of the two training phases
            (default 10, as before).
    """
    train_gen, val_gen = load_cifar10_data(batch_size=64)
    model = build_resnet56(input_shape=(32, 32, 3), num_classes=10)

    if not os.path.exists(INITIAL_WEIGHTS_PATH):
        save_initial_weights(model)
    else:
        load_initial_weights(model)

    def compile_model():
        model.compile(optimizer='adam',
                      loss='sparse_categorical_crossentropy',
                      metrics=['accuracy'])

    def make_callbacks():
        # Build fresh callbacks for each phase: a reused ModelCheckpoint keeps
        # its `best` score, so the sparse model would only be checkpointed if
        # it beat the dense model from phase 1.
        return [
            tf.keras.callbacks.ModelCheckpoint(
                filepath=CHECKPOINT_PATH,
                monitor='val_accuracy',
                save_best_only=True,
                verbose=1
            ),
            tf.keras.callbacks.EarlyStopping(
                monitor='val_accuracy', patience=5, restore_best_weights=True),
        ]

    # Phase 1: train the dense network.
    compile_model()
    print("Training initial/pruned model...")
    model.fit(
        train_gen,
        epochs=epochs,
        validation_data=val_gen,
        callbacks=make_callbacks()
    )

    # Derive the mask from the best TRAINED weights, before rewinding.
    model.load_weights(CHECKPOINT_PATH)
    model = prune_model(model, pruning_fraction=pruning_fraction)
    masks = [
        (w != 0).astype(w.dtype) if w.ndim > 1 else None
        for w in model.get_weights()
    ]

    def apply_masks(batch=None, logs=None):
        # model.weights is in the same order as model.get_weights().
        for var, mask in zip(model.weights, masks):
            if mask is not None:
                var.assign(var * mask)

    # Rewind to the initial weights and keep only the surviving connections.
    load_initial_weights(model)
    apply_masks()

    # Phase 2: retrain the sparse network, re-applying the mask after every
    # batch so optimizer updates cannot revive pruned entries.
    compile_model()
    print("Re-training pruned model from initial weights...")
    mask_cb = tf.keras.callbacks.LambdaCallback(on_train_batch_end=apply_masks)
    model.fit(
        train_gen,
        epochs=epochs,
        validation_data=val_gen,
        callbacks=[mask_cb] + make_callbacks()
    )

    model.save(FINAL_MODEL_PATH)
    print("LTH pruned model saved.")

In [8]:
# Run the 60%-sparsity LTH experiment when this cell is executed as the main
# module. The recorded output below ends in a KeyboardInterrupt.
if __name__ == "__main__":
    train_lth_model_sparsity60()

Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
[1m170498071/170498071[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 0us/step
Initial weights saved.
Training initial/pruned model...
Epoch 1/10


  self._warn_if_super_not_called()


[1m782/782[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 2s/step - accuracy: 0.3555 - loss: 1.9610
Epoch 1: val_accuracy improved from -inf to 0.47080, saving model to ./checkpoints/lth_resnet56_cifar10.keras
[1m782/782[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1283s[0m 2s/step - accuracy: 0.3557 - loss: 1.9604 - val_accuracy: 0.4708 - val_loss: 1.5816
Epoch 2/10
[1m782/782[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 2s/step - accuracy: 0.5904 - loss: 1.1380
Epoch 2: val_accuracy improved from 0.47080 to 0.65020, saving model to ./checkpoints/lth_resnet56_cifar10.keras
[1m782/782[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1299s[0m 2s/step - accuracy: 0.5904 - loss: 1.1379 - val_accuracy: 0.6502 - val_loss: 1.0119
Epoch 3/10
[1m782/782[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 2s/step - accuracy: 0.6807 - loss: 0.9008
Epoch 3: val_accuracy did not improve from 0.65020
[1m782/782[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1321s[0m 2s/st

KeyboardInterrupt: 

In [None]:
# Re-run of the 60%-sparsity LTH experiment; the recorded output below starts
# with "Initial weights loaded.", i.e. the saved initial weights were reused.
if __name__ == "__main__":
    train_lth_model_sparsity60()

Initial weights loaded.
Training initial/pruned model...
Epoch 1/10
[1m782/782[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 2s/step - accuracy: 0.3685 - loss: 1.9454
Epoch 1: val_accuracy improved from -inf to 0.44530, saving model to ./checkpoints/lth_resnet56_cifar10.keras
[1m782/782[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1259s[0m 2s/step - accuracy: 0.3686 - loss: 1.9449 - val_accuracy: 0.4453 - val_loss: 1.5660
Epoch 2/10
[1m782/782[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 2s/step - accuracy: 0.5913 - loss: 1.1489
Epoch 2: val_accuracy improved from 0.44530 to 0.61630, saving model to ./checkpoints/lth_resnet56_cifar10.keras
[1m782/782[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1284s[0m 2s/step - accuracy: 0.5913 - loss: 1.1488 - val_accuracy: 0.6163 - val_loss: 1.0683
Epoch 3/10
[1m782/782[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 2s/step - accuracy: 0.6819 - loss: 0.9078
Epoch 3: val_accuracy improved from 0.61630 to 0.67230, s

In [1]:
def train_lth_model_sparsity60(pruning_fraction=0.6, epochs=10):
    """Prune ResNet-56 at initialization and train the sparse network.

    Unlike the two-phase variant, no dense training is performed: the saved
    initial weights are loaded, magnitude-pruned directly, and the resulting
    sparse network is trained with the pruned entries held at zero. The
    trained model is saved to ``FINAL_MODEL_PATH``.

    NOTE(review): this definition reuses the name of an earlier function in
    the notebook and therefore replaces it once this cell has run.

    Args:
        pruning_fraction: Fraction of each multi-dimensional weight array to
            prune. Defaults to 0.6, the value previously hard-coded here.
        epochs: Maximum number of training epochs (default 10, as before).
    """
    train_gen, val_gen = load_cifar10_data(batch_size=64)

    # Load model and initial weights
    model = build_resnet56(input_shape=(32, 32, 3), num_classes=10)
    if os.path.exists(INITIAL_WEIGHTS_PATH):
        load_initial_weights(model)
    else:
        # Same fallback as the other training functions, rather than failing
        # when no initial weights were saved yet.
        save_initial_weights(model)

    # Apply pruning with desired sparsity and remember which entries survived.
    model = prune_model(model, pruning_fraction=pruning_fraction)
    masks = [
        (w != 0).astype(w.dtype) if w.ndim > 1 else None
        for w in model.get_weights()
    ]

    def apply_masks(batch=None, logs=None):
        # model.weights is in the same order as model.get_weights().
        for var, mask in zip(model.weights, masks):
            if mask is not None:
                var.assign(var * mask)

    # Compile and train the pruned model
    model.compile(optimizer='adam',
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])

    print("Training pruned model from initial weights...")
    checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(
        filepath=CHECKPOINT_PATH,
        monitor='val_accuracy',
        save_best_only=True,
        verbose=1
    )
    earlystop_cb = tf.keras.callbacks.EarlyStopping(
        monitor='val_accuracy', patience=5, restore_best_weights=True)
    # Optimizer updates would otherwise move pruned entries away from zero, so
    # the mask is re-applied after every training batch.
    mask_cb = tf.keras.callbacks.LambdaCallback(on_train_batch_end=apply_masks)

    model.fit(
        train_gen,
        epochs=epochs,
        validation_data=val_gen,
        callbacks=[mask_cb, checkpoint_cb, earlystop_cb]
    )

    model.save(FINAL_MODEL_PATH)
    print("LTH pruned model saved.")

In [None]:
# Run the prune-at-initialization experiment defined in the previous cell.
# The notebook defines no `train_lth_model_sparsity80`; the function above is
# `train_lth_model_sparsity60`, so calling the former raised NameError.
if __name__ == "__main__":
    train_lth_model_sparsity60()