In [1]:
import numpy as np
import tensorflow as tf
from models.resnet20 import build_resnet20
from tensorflow.keras.datasets import cifar10
from tensorflow.keras.utils import to_categorical

# -----------------------------
# SNIP Pruning Function
# -----------------------------
def snip_prune(model, x_batch, y_batch, sparsity):
    """Prune a model's kernel weights in place using SNIP saliency scores.

    A single forward/backward pass on ``(x_batch, y_batch)`` yields the
    connection-sensitivity score ``|w * dL/dw|`` for every kernel entry.
    Scores are ranked globally across all kernels and the ``sparsity``
    fraction with the lowest scores is set to zero.

    Args:
        model: Callable Keras-style model exposing ``trainable_variables``.
        x_batch: Inputs for the scoring pass.
        y_batch: Targets passed to ``categorical_crossentropy`` together
            with the model output (presumably one-hot labels -- confirm
            against the caller).
        sparsity: Fraction of kernel weights to remove, in ``[0, 1]``.
            ``0`` prunes nothing, ``1`` zeroes every scored kernel.

    Returns:
        A list of float32 NumPy 0/1 masks, one per pruned kernel, in the
        order those kernels appear in ``model.trainable_variables``.
        Callers that ignore the return value are unaffected.

    Raises:
        ValueError: If ``sparsity`` lies outside ``[0, 1]``.
    """
    if not 0.0 <= sparsity <= 1.0:
        raise ValueError(f"sparsity must be in [0, 1], got {sparsity!r}")

    variables = model.trainable_variables
    with tf.GradientTape() as tape:
        # NOTE: the pass runs with training=True, so layers that behave
        # differently in training mode use their training behaviour here.
        preds = model(x_batch, training=True)
        # Reduce to a scalar explicitly. Scaling the loss by a positive
        # constant does not change the ranking of the scores.
        loss = tf.reduce_mean(
            tf.keras.losses.categorical_crossentropy(y_batch, preds))
    grads = tape.gradient(loss, variables)

    # Keep each kernel paired with its own score so that a kernel without a
    # gradient can never shift masks onto the wrong variable.
    scored = [(w, tf.abs(w * g)) for w, g in zip(variables, grads)
              if g is not None and 'kernel' in w.name]
    if not scored:
        return []

    all_scores = tf.concat([tf.reshape(s, [-1]) for _, s in scored], axis=0)
    total = int(tf.size(all_scores).numpy())
    # Number of weights to remove: the `sparsity` fraction with lowest score.
    num_pruned = min(int(sparsity * total), total)

    if num_pruned >= total:
        masks = [np.zeros(s.shape.as_list(), dtype=np.float32)
                 for _, s in scored]
    else:
        # In ascending order, indices [0, num_pruned) are dropped; the score
        # at index num_pruned is the smallest one that survives.
        threshold = tf.sort(all_scores)[num_pruned]
        masks = [(s >= threshold).numpy().astype(np.float32)
                 for _, s in scored]

    for (w, _), mask in zip(scored, masks):
        w.assign(w * mask)

    return masks

# -----------------------------
# Training Function
# -----------------------------
def run_snip_training(sparsity=0.5, batch_size=128, epochs=150):
    """Build a model, prune it once with SNIP, then train it on CIFAR-10.

    The pruning pattern is computed from the first ``batch_size`` training
    samples and is then held fixed for the whole run: after every training
    batch the pruned kernel entries are reset to zero, so optimizer updates
    cannot turn the sparse network back into a dense one.

    Args:
        sparsity: Pruning level forwarded to ``snip_prune``.
        batch_size: Size of the SNIP scoring batch and the training batch.
        epochs: Number of training epochs.
    """
    # Load and preprocess data: scale pixels to [0, 1], one-hot the labels.
    (x_train, y_train), (x_test, y_test) = cifar10.load_data()
    x_train, x_test = x_train.astype('float32') / 255.0, x_test.astype('float32') / 255.0
    y_train, y_test = to_categorical(y_train, 10), to_categorical(y_test, 10)

    # Build and initialize model
    model = build_resnet20()
    model.build(input_shape=(None, 32, 32, 3))
    model.summary()

    # SNIP pruning on small batch
    x_sample = x_train[:batch_size]
    y_sample = y_train[:batch_size]
    snip_prune(model, x_sample, y_sample, sparsity)

    # snip_prune zeroes pruned kernel entries in place, so the zero pattern of
    # each kernel right after pruning is its mask. Capture it before training.
    kernel_masks = [(var, (var.numpy() != 0).astype(np.float32))
                    for var in model.trainable_variables
                    if 'kernel' in var.name]

    class _ReapplyPruningMasks(tf.keras.callbacks.Callback):
        """Reset pruned kernel entries to zero after each optimizer step."""

        def on_train_batch_end(self, batch, logs=None):
            for var, mask in kernel_masks:
                var.assign(var * mask)

    # Compile and train
    model.compile(optimizer=tf.keras.optimizers.Adam(),
                  loss='categorical_crossentropy',
                  metrics=['accuracy'])

    model.fit(x_train, y_train,
              epochs=epochs,
              batch_size=batch_size,
              validation_data=(x_test, y_test),
              callbacks=[_ReapplyPruningMasks()],
              verbose=1)

    print("✅ SNIP pruning and training complete.")

In [3]:
# Experiment: fresh model, one-shot SNIP pruning, 150 training epochs, sparsity argument 0.3.
run_snip_training(sparsity=0.3, batch_size=128, epochs=150)

Model: "model_1"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_2 (InputLayer)           [(None, 32, 32, 3)]  0           []                               
                                                                                                  
 conv2d_21 (Conv2D)             (None, 32, 32, 16)   448         ['input_2[0][0]']                
                                                                                                  
 batch_normalization_19 (BatchN  (None, 32, 32, 16)  64          ['conv2d_21[0][0]']              
 ormalization)                                                                                    
                                                                                                  
 activation_19 (Activation)     (None, 32, 32, 16)   0           ['batch_normalization_19[0]

In [4]:
# Experiment: fresh model, one-shot SNIP pruning, 150 training epochs, sparsity argument 0.5.
run_snip_training(sparsity=0.5, batch_size=128, epochs=150)

Model: "model_2"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_3 (InputLayer)           [(None, 32, 32, 3)]  0           []                               
                                                                                                  
 conv2d_42 (Conv2D)             (None, 32, 32, 16)   448         ['input_3[0][0]']                
                                                                                                  
 batch_normalization_38 (BatchN  (None, 32, 32, 16)  64          ['conv2d_42[0][0]']              
 ormalization)                                                                                    
                                                                                                  
 activation_38 (Activation)     (None, 32, 32, 16)   0           ['batch_normalization_38[0]

In [5]:
# Experiment: fresh model, one-shot SNIP pruning, 150 training epochs, sparsity argument 0.7.
run_snip_training(sparsity=0.7, batch_size=128, epochs=150)

Model: "model_3"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_4 (InputLayer)           [(None, 32, 32, 3)]  0           []                               
                                                                                                  
 conv2d_63 (Conv2D)             (None, 32, 32, 16)   448         ['input_4[0][0]']                
                                                                                                  
 batch_normalization_57 (BatchN  (None, 32, 32, 16)  64          ['conv2d_63[0][0]']              
 ormalization)                                                                                    
                                                                                                  
 activation_57 (Activation)     (None, 32, 32, 16)   0           ['batch_normalization_57[0]

In [6]:
# Experiment: fresh model, one-shot SNIP pruning, 150 training epochs, sparsity argument 0.9.
run_snip_training(sparsity=0.9, batch_size=128, epochs=150)

Model: "model_4"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_5 (InputLayer)           [(None, 32, 32, 3)]  0           []                               
                                                                                                  
 conv2d_84 (Conv2D)             (None, 32, 32, 16)   448         ['input_5[0][0]']                
                                                                                                  
 batch_normalization_76 (BatchN  (None, 32, 32, 16)  64          ['conv2d_84[0][0]']              
 ormalization)                                                                                    
                                                                                                  
 activation_76 (Activation)     (None, 32, 32, 16)   0           ['batch_normalization_76[0]