In [38]:
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow import keras
import numpy as np


mnist = keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

train_images = train_images / 255.0
test_images = test_images / 255.0


model = keras.Sequential([
  keras.layers.InputLayer(input_shape=(28, 28)),
  keras.layers.Reshape(target_shape=(28, 28, 1)),
  keras.layers.Conv2D(filters=12, kernel_size=(3, 3), activation='relu'),
  keras.layers.MaxPooling2D(pool_size=(2, 2)),
  keras.layers.Conv2D(filters=24, kernel_size=(3, 3), activation='relu'),
  keras.layers.MaxPooling2D(pool_size=(2, 2)),
  keras.layers.Conv2D(filters=48, kernel_size=(3, 3), activation='relu'),
  keras.layers.MaxPooling2D(pool_size=(2, 2)),
  keras.layers.Flatten(),
  keras.layers.Dense(10)
])

model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

model.fit(
  train_images,
  train_labels,
  epochs=4,
  validation_split=0.1,
)

Epoch 1/4
Epoch 2/4
Epoch 3/4
Epoch 4/4


<keras.src.callbacks.History at 0x7cb7ce5a3010>

In [47]:
import tensorflow as tf
import numpy as np

def prune_weights(model, pruning_threshold=0.1):
    pruned_model = tf.keras.models.clone_model(model)
    pruned_model.set_weights(model.get_weights())  

    for layer, pruned_layer in zip(model.layers, pruned_model.layers):
        if isinstance(layer, tf.keras.layers.Conv2D):
            weights = layer.get_weights()
            kernel_weights, bias = weights[0], weights[1]
            
            # IMPORTANCE SCORE USING L1 NORM ON INDIVIDUAL WEIGHTS
            importance_scores = np.abs(kernel_weights)
            threshold = np.percentile(importance_scores, pruning_threshold * 100)
            
            # IDENTIFY THE WEIGHTS TO BE PRUNED
            weights_to_prune = importance_scores < threshold
            
            # CHECK THE WEIGHTS GONNA BE PRUNED
            print("\n weights to be pruned:", weights_to_prune.sum(), len(weights_to_prune))
            
            # PRUNE THE WEIGHTS BY SETTING THEM TO ZERO
            kernel_weights[weights_to_prune] = 0.0
            
            # UPDATE THE WEIGHT
            pruned_layer.set_weights([kernel_weights, bias])

    return pruned_model


In [51]:
pruned_model = prune_filters(model, pruning_threshold=0.1)
pruned_model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
pruned_model.fit(train_images, train_labels, epochs=5, validation_split=0.2)


 filters to be pruned: [False False False False False False False  True False False  True False] 12

 filters to be pruned: [False False  True False False False False False False False False False
 False False False  True False False False False False  True False False] 24

 filters to be pruned: [False False False False False  True False False False False False False
 False False  True False False False False False False  True False False
 False False False  True False False False False False False  True False
 False False False False False False False False False False False False] 48
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


<keras.src.callbacks.History at 0x7cb7116dc790>

In [52]:
model.evaluate(test_images, test_labels)



[0.0739019513130188, 0.9750999808311462]

In [53]:
pruned_model.evaluate(test_images, test_labels)



[2.300748825073242, 0.5504999756813049]