<a href="https://colab.research.google.com/github/bhartiansh/cnn_pruning_cifar10/blob/main/SNIP.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [6]:
!git clone https://github.com/bhartiansh/cnn_pruning_cifar10.git
%cd cnn_pruning_cifar10

Cloning into 'cnn_pruning_cifar10'...
remote: Enumerating objects: 81, done.[K
remote: Counting objects: 100% (81/81), done.[K
remote: Compressing objects: 100% (77/77), done.[K
remote: Total 81 (delta 28), reused 0 (delta 0), pack-reused 0 (from 0)[K
Receiving objects: 100% (81/81), 39.53 KiB | 2.82 MiB/s, done.
Resolving deltas: 100% (28/28), done.
/content/cnn_pruning_cifar10


In [5]:
!pip install -q tensorflow-model-optimization

[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/242.5 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m[90m━[0m [32m235.5/242.5 kB[0m [31m7.5 MB/s[0m eta [36m0:00:01[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m242.5/242.5 kB[0m [31m5.2 MB/s[0m eta [36m0:00:00[0m
[?25h

In [13]:
import tensorflow as tf
import numpy as np
from models.resnet56_baseline import build_resnet56

In [14]:
def compute_snip_scores(model, x_batch, y_batch, loss_fn=tf.keras.losses.SparseCategoricalCrossentropy()):
    """Compute SNIP connection-sensitivity scores |dL/dw * w| for every trainable variable.

    Runs one forward/backward pass of `model` on (x_batch, y_batch) and returns one
    score tensor per entry of `model.trainable_variables`, in the same order and with
    the same shape as that variable.

    Args:
        model: Keras model. It is called with ``training=True``; NOTE(review): in Keras
            this may also update BatchNorm moving statistics from this single batch.
        x_batch: model inputs for the scoring pass.
        y_batch: integer class labels matching ``x_batch``.
        loss_fn: callable ``loss_fn(y_true, y_pred)``. The default
            SparseCategoricalCrossentropy uses ``from_logits=False``, i.e. it assumes the
            model outputs probabilities (softmax) — confirm against the model builder.

    Returns:
        list of tensors aligned one-to-one with ``model.trainable_variables``.
        A variable that receives no gradient gets an all-zero score instead of being
        dropped, so positional pairing with the variables stays correct.
    """
    variables = model.trainable_variables
    with tf.GradientTape() as tape:
        preds = model(x_batch, training=True)
        loss = loss_fn(y_batch, preds)

    grads = tape.gradient(loss, variables)

    snip_scores = []
    for grad, var in zip(grads, variables):
        if grad is None:
            # The loss does not depend on this variable: its sensitivity is zero.
            # Keep a placeholder so indices still match model.trainable_variables.
            snip_scores.append(tf.zeros_like(var))
        else:
            # convert_to_tensor densifies IndexedSlices gradients before the product.
            snip_scores.append(tf.abs(tf.convert_to_tensor(grad) * var))
    return snip_scores

In [15]:
def snip_prune_model(model, snip_scores, sparsity):
    """Zero out the lowest-scoring weights of `model` in place, SNIP-style.

    Scores of all variables are ranked globally; the top ``(1 - sparsity)`` fraction of
    entries is kept and every other entry is set to 0 via ``var.assign``.

    Args:
        model: Keras model whose ``trainable_variables`` are pruned in place.
        snip_scores: one score tensor per trainable variable, same order and shapes
            (e.g. the output of ``compute_snip_scores``).
        sparsity: fraction of entries to remove, in [0, 1].

    Returns:
        list of float32 0/1 masks, one per trainable variable. The masks are applied once
        here; nothing in this function prevents later optimizer updates from changing
        the pruned entries, so callers must re-apply them during training.

    Raises:
        ValueError: if ``sparsity`` is outside [0, 1], or the number of score tensors
            differs from the number of trainable variables.
    """
    variables = model.trainable_variables
    if not 0.0 <= sparsity <= 1.0:
        raise ValueError(f"sparsity must be in [0, 1], got {sparsity}")
    if len(snip_scores) != len(variables):
        raise ValueError(
            f"got {len(snip_scores)} score tensors for {len(variables)} trainable variables; "
            "scores must be aligned one-to-one with model.trainable_variables"
        )
    if not variables:
        return []

    all_scores = tf.concat([tf.reshape(score, [-1]) for score in snip_scores], axis=0)
    total = int(tf.size(all_scores).numpy())
    # round() rather than int(): e.g. (1 - 0.9) * 100 == 9.999..., which int() turns into 9.
    n_keep = int(round((1.0 - sparsity) * total))

    if n_keep <= 0:
        masks = [tf.zeros_like(score, dtype=tf.float32) for score in snip_scores]
    elif n_keep >= total:
        masks = [tf.ones_like(score, dtype=tf.float32) for score in snip_scores]
    else:
        # Threshold is the n_keep-th largest score (the last of the top-k), so exactly
        # n_keep entries are kept, plus any ties at the threshold value.
        threshold = tf.math.top_k(all_scores, k=n_keep, sorted=True).values[-1]
        masks = [tf.cast(score >= threshold, tf.float32) for score in snip_scores]

    # NOTE(review): this prunes every trainable variable, including biases and
    # BatchNorm gamma/beta if the model has them — confirm that is intended.
    for var, mask in zip(variables, masks):
        var.assign(var * tf.cast(mask, var.dtype))

    return masks


In [16]:
class _ReapplyMasks(tf.keras.callbacks.Callback):
    """Keras callback: multiply each variable by its 0/1 mask at train start and after every batch."""

    def __init__(self, variables, masks):
        super().__init__()
        self._pairs = list(zip(variables, masks))

    def _apply(self):
        for var, mask in self._pairs:
            var.assign(var * tf.cast(mask, var.dtype))

    def on_train_begin(self, logs=None):
        self._apply()

    def on_train_batch_end(self, batch, logs=None):
        # Optimizer steps (incl. Adam momentum) move pruned weights off zero; undo that.
        self._apply()


def train_pruned_model(model, x_train, y_train, x_val, y_val, epochs=50, batch_size=128,
                       save_path="snip_model.h5", masks=None):
    """Compile `model` with Adam and fit it with early stopping and best-model checkpointing.

    Args:
        model: Keras model to train in place. It is compiled with
            'sparse_categorical_crossentropy' (from_logits=False), which presumes the
            model outputs probabilities — confirm against the model builder.
        x_train, y_train: training inputs and integer labels.
        x_val, y_val: validation data used by EarlyStopping and ModelCheckpoint
            (both on Keras' default monitor, val_loss).
        epochs: maximum number of epochs.
        batch_size: mini-batch size passed to ``model.fit``.
        save_path: file the best model (by the monitored metric) is written to.
        masks: optional list of 0/1 tensors aligned with ``model.trainable_variables``
            (e.g. the return value of ``snip_prune_model``). When given, weights are
            multiplied by their masks at the start of training and after every batch so
            pruned entries stay zero. When None (default), training is unconstrained.

    Returns:
        the ``History`` object returned by ``model.fit``.

    Raises:
        ValueError: if ``masks`` is given but its length differs from the number of
            trainable variables.
    """
    callbacks = []
    if masks is not None:
        variables = model.trainable_variables
        if len(masks) != len(variables):
            raise ValueError(
                f"got {len(masks)} masks for {len(variables)} trainable variables"
            )
        callbacks.append(_ReapplyMasks(variables, masks))

    callbacks.append(tf.keras.callbacks.EarlyStopping(patience=5, restore_best_weights=True))
    callbacks.append(tf.keras.callbacks.ModelCheckpoint(save_path, save_best_only=True))

    model.compile(optimizer='adam',
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])

    history = model.fit(x_train, y_train,
                        validation_data=(x_val, y_val),
                        epochs=epochs,
                        batch_size=batch_size,
                        callbacks=callbacks,
                        verbose=2)

    return history


In [17]:
# 0. Seed Python, NumPy and TensorFlow so init, SNIP scores and training are repeatable.
SEED = 42
tf.keras.utils.set_random_seed(SEED)

# 1. Load CIFAR-10 and scale pixels to [0, 1]. Cast to float32 first: dividing the
#    uint8 arrays by 255.0 directly would create float64 copies (twice the memory).
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
x_train = x_train.astype("float32") / 255.0
x_test = x_test.astype("float32") / 255.0

# Hold out the last VAL_SIZE training images for early stopping / checkpoint selection,
# so the test set is only touched once, for the final evaluation.
VAL_SIZE = 5000
x_fit, y_fit = x_train[:-VAL_SIZE], y_train[:-VAL_SIZE]
x_val, y_val = x_train[-VAL_SIZE:], y_train[-VAL_SIZE:]

# 2. Build ResNet-56 (freshly initialized; SNIP scores at initialization).
model = build_resnet56()

# 3. Get SNIP scores from a small batch of the training split.
SNIP_BATCH = 512
batch_x, batch_y = x_fit[:SNIP_BATCH], y_fit[:SNIP_BATCH]
snip_scores = compute_snip_scores(model, batch_x, batch_y)

# 4. Apply pruning with the desired sparsity (fraction of weights removed).
SPARSITY = 0.5
masks = snip_prune_model(model, snip_scores, sparsity=SPARSITY)
# NOTE(review): `masks` is not passed to training below, so optimizer updates can move
# pruned weights off zero and the sparsity only holds at initialization. For a faithful
# SNIP run, re-apply the masks during training (e.g. after every batch).

# 5. Train the pruned model, selecting the best epoch on the validation split.
history = train_pruned_model(model, x_fit, y_fit, x_val, y_val)

# 6. Final, one-time evaluation on the held-out test set.
test_loss, test_acc = model.evaluate(x_test, y_test, verbose=0)
print(f"Test loss: {test_loss:.4f}  Test accuracy: {test_acc:.4f}")


Epoch 1/50


KeyboardInterrupt: 