<a href="https://colab.research.google.com/github/bhartiansh/cnn_pruning_cifar10/blob/main/pruning2(SNIP).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
!git clone https://github.com/bhartiansh/cnn_pruning_cifar10.git
%cd cnn_pruning_cifar10

fatal: destination path 'cnn_pruning_cifar10' already exists and is not an empty directory.
/content/cnn_pruning_cifar10


In [4]:
!pip install -q tensorflow-model-optimization

[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/61.0 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m61.0/61.0 kB[0m [31m3.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m242.5/242.5 kB[0m [31m10.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m18.3/18.3 MB[0m [31m83.3 MB/s[0m eta [36m0:00:00[0m
[?25h

In [5]:
import tensorflow as tf
import numpy as np
from models.resnet56_baseline import build_resnet56

In [6]:
def compute_snip_scores(model, x_batch, y_batch, loss_fn=tf.keras.losses.SparseCategoricalCrossentropy()):
    with tf.GradientTape() as tape:
        preds = model(x_batch, training=True)
        loss = loss_fn(y_batch, preds)

    grads = tape.gradient(loss, model.trainable_variables)
    snip_scores = [tf.abs(g * w) for g, w in zip(grads, model.trainable_variables) if g is not None]
    return snip_scores

In [7]:
def snip_prune_model(model, snip_scores, sparsity):
    all_scores = tf.concat([tf.reshape(score, [-1]) for score in snip_scores], axis=0)
    k = int((1 - sparsity) * tf.size(all_scores).numpy())
    threshold = tf.sort(all_scores, direction='DESCENDING')[k]

    masks = [tf.cast(score >= threshold, tf.float32) for score in snip_scores]
    pruned_weights = [w * m for w, m in zip(model.trainable_variables, masks)]

    for var, pruned in zip(model.trainable_variables, pruned_weights):
        var.assign(pruned)

    return masks


In [8]:
def train_pruned_model(model, x_train, y_train, x_val, y_val, epochs=50, batch_size=128, save_path="snip_model.h5"):
    early_stop = tf.keras.callbacks.EarlyStopping(patience=5, restore_best_weights=True)
    checkpoint = tf.keras.callbacks.ModelCheckpoint(save_path, save_best_only=True)

    model.compile(optimizer='adam',
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])

    history = model.fit(x_train, y_train,
                        validation_data=(x_val, y_val),
                        epochs=epochs,
                        batch_size=batch_size,
                        callbacks=[early_stop, checkpoint],
                        verbose=2)

    return history


In [9]:
# 1. Load dataset
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

# 2. Build your ResNet-56
model = build_resnet56()  # Replace with your function

# 3. Get SNIP scores from a small batch
batch_x, batch_y = x_train[:512], y_train[:512]
snip_scores = compute_snip_scores(model, batch_x, batch_y)

# 4. Apply pruning with desired sparsity
snip_prune_model(model, snip_scores, sparsity=0.5)

# 5. Train pruned model
history = train_pruned_model(model, x_train, y_train, x_test, y_test)


Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
[1m170498071/170498071[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 0us/step
Epoch 1/50




391/391 - 101s - 259ms/step - accuracy: 0.4687 - loss: 1.4877 - val_accuracy: 0.4468 - val_loss: 1.7489
Epoch 2/50




391/391 - 47s - 120ms/step - accuracy: 0.6298 - loss: 1.0386 - val_accuracy: 0.5198 - val_loss: 1.4935
Epoch 3/50




391/391 - 40s - 103ms/step - accuracy: 0.7029 - loss: 0.8397 - val_accuracy: 0.6123 - val_loss: 1.1114
Epoch 4/50




391/391 - 41s - 105ms/step - accuracy: 0.7521 - loss: 0.7007 - val_accuracy: 0.6833 - val_loss: 0.9335
Epoch 5/50
391/391 - 41s - 105ms/step - accuracy: 0.7889 - loss: 0.5994 - val_accuracy: 0.6213 - val_loss: 1.2115
Epoch 6/50
391/391 - 41s - 105ms/step - accuracy: 0.8203 - loss: 0.5113 - val_accuracy: 0.6118 - val_loss: 1.2621
Epoch 7/50
391/391 - 41s - 104ms/step - accuracy: 0.8462 - loss: 0.4381 - val_accuracy: 0.6774 - val_loss: 1.1848
Epoch 8/50
391/391 - 40s - 103ms/step - accuracy: 0.8703 - loss: 0.3687 - val_accuracy: 0.7000 - val_loss: 0.9911
Epoch 9/50
391/391 - 29s - 74ms/step - accuracy: 0.8896 - loss: 0.3087 - val_accuracy: 0.6997 - val_loss: 1.1888


In [10]:
# 1. Load dataset
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

# 2. Build your ResNet-56
model = build_resnet56()  # Replace with your function

# 3. Get SNIP scores from a small batch
batch_x, batch_y = x_train[:512], y_train[:512]
snip_scores = compute_snip_scores(model, batch_x, batch_y)

# 4. Apply pruning with desired sparsity
snip_prune_model(model, snip_scores, sparsity=0.3)

# 5. Train pruned model
history = train_pruned_model(model, x_train, y_train, x_test, y_test)

Epoch 1/50




391/391 - 108s - 275ms/step - accuracy: 0.4593 - loss: 1.5328 - val_accuracy: 0.4845 - val_loss: 1.4763
Epoch 2/50




391/391 - 47s - 120ms/step - accuracy: 0.6339 - loss: 1.0216 - val_accuracy: 0.5684 - val_loss: 1.2879
Epoch 3/50
391/391 - 30s - 78ms/step - accuracy: 0.7195 - loss: 0.7952 - val_accuracy: 0.5833 - val_loss: 1.4018
Epoch 4/50
391/391 - 39s - 101ms/step - accuracy: 0.7681 - loss: 0.6586 - val_accuracy: 0.5361 - val_loss: 1.6273
Epoch 5/50




391/391 - 42s - 108ms/step - accuracy: 0.8064 - loss: 0.5533 - val_accuracy: 0.7403 - val_loss: 0.7618
Epoch 6/50
391/391 - 40s - 102ms/step - accuracy: 0.8329 - loss: 0.4760 - val_accuracy: 0.7096 - val_loss: 0.9423
Epoch 7/50
391/391 - 42s - 108ms/step - accuracy: 0.8534 - loss: 0.4119 - val_accuracy: 0.6797 - val_loss: 1.1364
Epoch 8/50
391/391 - 41s - 104ms/step - accuracy: 0.8810 - loss: 0.3400 - val_accuracy: 0.7174 - val_loss: 0.9654
Epoch 9/50
391/391 - 40s - 102ms/step - accuracy: 0.8964 - loss: 0.2906 - val_accuracy: 0.6505 - val_loss: 1.4389
Epoch 10/50
391/391 - 41s - 105ms/step - accuracy: 0.9159 - loss: 0.2351 - val_accuracy: 0.6534 - val_loss: 1.3954


In [11]:
# 1. Load dataset
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

# 2. Build your ResNet-56
model = build_resnet56()  # Replace with your function

# 3. Get SNIP scores from a small batch
batch_x, batch_y = x_train[:512], y_train[:512]
snip_scores = compute_snip_scores(model, batch_x, batch_y)

# 4. Apply pruning with desired sparsity
snip_prune_model(model, snip_scores, sparsity=0.7)

# 5. Train pruned model
history = train_pruned_model(model, x_train, y_train, x_test, y_test)

Epoch 1/50




391/391 - 98s - 252ms/step - accuracy: 0.4757 - loss: 1.4870 - val_accuracy: 0.4880 - val_loss: 1.5281
Epoch 2/50




391/391 - 51s - 129ms/step - accuracy: 0.6381 - loss: 1.0220 - val_accuracy: 0.5555 - val_loss: 1.2856
Epoch 3/50




391/391 - 41s - 105ms/step - accuracy: 0.7093 - loss: 0.8243 - val_accuracy: 0.6108 - val_loss: 1.2367
Epoch 4/50




391/391 - 40s - 102ms/step - accuracy: 0.7589 - loss: 0.6884 - val_accuracy: 0.6962 - val_loss: 0.8799
Epoch 5/50
391/391 - 29s - 74ms/step - accuracy: 0.7943 - loss: 0.5877 - val_accuracy: 0.6089 - val_loss: 1.3514
Epoch 6/50
391/391 - 42s - 107ms/step - accuracy: 0.8225 - loss: 0.5066 - val_accuracy: 0.6816 - val_loss: 1.0406
Epoch 7/50
391/391 - 29s - 74ms/step - accuracy: 0.8485 - loss: 0.4316 - val_accuracy: 0.6668 - val_loss: 1.2577
Epoch 8/50
391/391 - 41s - 104ms/step - accuracy: 0.8689 - loss: 0.3699 - val_accuracy: 0.6390 - val_loss: 1.4140
Epoch 9/50
391/391 - 42s - 106ms/step - accuracy: 0.8901 - loss: 0.3121 - val_accuracy: 0.7385 - val_loss: 0.9246
