<a href="https://colab.research.google.com/github/bhartiansh/cnn_pruning_cifar10/blob/main/pruning3(MAG_50).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!git clone https://github.com/bhartiansh/cnn_pruning_cifar10.git
%cd cnn_pruning_cifar10

fatal: destination path 'cnn_pruning_cifar10' already exists and is not an empty directory.
/content/cnn_pruning_cifar10


In [2]:
!pip install -q tensorflow-model-optimization

In [11]:
import tensorflow as tf
from models.resnet56_baseline import build_resnet56

In [4]:
def get_global_magnitude_scores(model):
    """Collect |w| for every kernel entry of `model` into a single flat 1-D tensor.

    Only trainable variables whose name contains 'kernel' contribute, so biases and
    BatchNorm parameters are skipped. Variables are visited in
    `model.trainable_weights` order; each is flattened before concatenation.
    """
    kernel_vars = [var for var in model.trainable_weights if 'kernel' in var.name]
    flat_magnitudes = [tf.reshape(tf.abs(var), [-1]) for var in kernel_vars]
    return tf.concat(flat_magnitudes, axis=0)


In [5]:
def apply_global_magnitude_pruning(model, sparsity):
    """Zero out the globally smallest-magnitude kernel weights of `model`, in place.

    All trainable variables whose name contains 'kernel' are pooled and ranked by
    absolute value across the whole model (not per layer). The lowest `sparsity`
    fraction is set to zero via `assign`. Biases and BatchNorm parameters are untouched.

    Args:
        model: Keras model whose kernel variables are modified in place.
        sparsity: Fraction of pooled kernel weights to remove, in [0, 1].
            For example, 0.9 zeroes the 90% smallest-magnitude weights.

    Returns:
        List of (variable, mask) pairs. Each mask has the variable's shape and dtype,
        with 1 for kept entries and 0 for pruned ones. It can be re-applied during
        training; callers that ignore the return value are unaffected.

    Raises:
        ValueError: if `sparsity` is outside [0, 1] (or NaN).
    """
    if not 0.0 <= sparsity <= 1.0:
        raise ValueError(f"sparsity must be in [0, 1], got {sparsity!r}")

    kernels = [weight for weight in model.trainable_weights if 'kernel' in weight.name]
    if not kernels:
        # Nothing prunable; also avoids tf.concat on an empty list.
        return []

    all_scores = get_global_magnitude_scores(model)
    num_weights = int(all_scores.shape[0])
    # Number of smallest-magnitude weights to remove.
    num_pruned = int(sparsity * num_weights)

    if num_pruned >= num_weights:
        # sparsity == 1: no finite magnitude reaches the threshold, so everything is pruned.
        threshold = float('inf')
    else:
        # In ascending order, index `num_pruned` is the smallest magnitude that survives.
        # Entries tied with it are kept too, so ties can leave slightly fewer zeros
        # than requested.
        threshold = tf.sort(all_scores, direction='ASCENDING')[num_pruned]

    masks = []
    for weight in kernels:
        # Cast to the variable's own dtype so the multiply also works for non-float32 weights.
        mask = tf.cast(tf.abs(weight) >= threshold, weight.dtype)
        weight.assign(weight * mask)
        masks.append((weight, mask))
    return masks


In [6]:
class _PruningMaskCallback(tf.keras.callbacks.Callback):
    """Keeps pruned weights at zero by re-applying fixed binary masks after every training batch."""

    def __init__(self, masked_weights):
        super().__init__()
        # (variable, mask) pairs; each mask matches its variable's shape and dtype,
        # with 1 for kept entries and 0 for pruned ones.
        self._masked_weights = masked_weights

    def on_train_batch_end(self, batch, logs=None):
        # The optimizer updates every entry, pruned ones included (Adam's momentum alone
        # would regrow them), so zero them again after each step.
        for weight, mask in self._masked_weights:
            weight.assign(weight * mask)


def train_magnitude_pruned_model(build_model_fn, x_train, y_train, x_val, y_val,
                                 sparsity=0.5, epochs=50, batch_size=128, save_path='magprune_model.h5'):
    """Build a fresh model, prune it by global kernel magnitude, and train it with the mask held fixed.

    Pruning is done once, right after construction (i.e. on the initial weights), by
    `apply_global_magnitude_pruning`. The resulting sparsity pattern is then enforced
    after every training batch. Without that, the first optimizer step would make the
    pruned connections non-zero again and the trained model would not be sparse.

    Args:
        build_model_fn: Zero-argument callable returning a Keras model (compiled here).
        x_train, y_train: Training inputs and integer class labels
            (the loss is sparse_categorical_crossentropy).
        x_val, y_val: Validation data. It drives early stopping and checkpoint
            selection, so it should NOT be the held-out test set.
        sparsity: Fraction of kernel weights to prune, forwarded to
            `apply_global_magnitude_pruning`.
        epochs: Maximum number of training epochs.
        batch_size: Mini-batch size.
        save_path: File ModelCheckpoint writes the best model to (save_best_only=True).

    Returns:
        (model, history): the trained model and the Keras History object.
        EarlyStopping uses restore_best_weights=True.
    """
    model = build_model_fn()  # build fresh model
    apply_global_magnitude_pruning(model, sparsity)

    # Recover the sparsity pattern from the weights themselves: pruned entries were
    # assigned exactly 0. This assumes no kept entry is exactly 0 at init, which holds
    # for the usual random initializers.
    # NOTE(review): confirm for build_model_fn.
    masked_weights = [
        (weight, tf.cast(tf.not_equal(weight, 0), weight.dtype))
        for weight in model.trainable_weights
        if 'kernel' in weight.name
    ]
    mask_callback = _PruningMaskCallback(masked_weights)

    early_stop = tf.keras.callbacks.EarlyStopping(patience=5, restore_best_weights=True)
    checkpoint = tf.keras.callbacks.ModelCheckpoint(save_path, save_best_only=True)

    model.compile(optimizer='adam',
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])

    history = model.fit(x_train, y_train,
                        validation_data=(x_val, y_val),
                        epochs=epochs,
                        batch_size=batch_size,
                        # Mask callback first so checkpoints/early-stopping snapshots see masked weights.
                        callbacks=[mask_callback, early_stop, checkpoint],
                        verbose=2)

    return model, history


In [7]:
# Sanity check: list the repository checkout (e.g. that the `models/` package this notebook imports is present).
!ls


 data	 'pruning1(lth).ipynb'	  README.md			  traning
 models  'pruning2(SNIP).ipynb'   ResNet56_baseline_model.ipynb


In [12]:
SEED = 42
VAL_SIZE = 5000  # training images held out for early stopping / checkpoint selection
SPARSITY = 0.5   # fraction of kernel weights pruned globally (50% pruning)

tf.keras.utils.set_random_seed(SEED)

# Load CIFAR-10 and scale pixels to [0, 1]. Cast to float32 first: dividing the uint8
# arrays by 255.0 directly yields float64, doubling memory for no benefit.
(x_train_full, y_train_full), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
x_train_full = x_train_full.astype('float32') / 255.0
x_test = x_test.astype('float32') / 255.0

# Model selection (early stopping, best checkpoint) uses a validation split carved out
# of the training set. The test set is touched only once, for the final evaluation.
x_train, y_train = x_train_full[:-VAL_SIZE], y_train_full[:-VAL_SIZE]
x_val, y_val = x_train_full[-VAL_SIZE:], y_train_full[-VAL_SIZE:]

# Train magnitude-pruned model
model, history = train_magnitude_pruned_model(
    build_model_fn=build_resnet56,
    x_train=x_train,
    y_train=y_train,
    x_val=x_val,
    y_val=y_val,
    sparsity=SPARSITY,
    save_path='magprune_model.h5'
)

test_loss, test_acc = model.evaluate(x_test, y_test, verbose=0)
{'sparsity': SPARSITY, 'test_loss': test_loss, 'test_accuracy': test_acc}


Epoch 1/50


KeyboardInterrupt: 