<a href="https://colab.research.google.com/github/bhartiansh/cnn_pruning_cifar10/blob/main/pruning5(Random_Pruning_Unstructured).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [3]:
# Fetch the project repository and make it the working directory
# (presumably it provides the `models` package imported in the next cells).
# NOTE: when the directory already exists, `git clone` prints
# "fatal: destination path ... already exists"; %cd still enters the existing checkout.
!git clone https://github.com/bhartiansh/cnn_pruning_cifar10.git
%cd cnn_pruning_cifar10

fatal: destination path 'cnn_pruning_cifar10' already exists and is not an empty directory.
/content/cnn_pruning_cifar10


In [4]:
# Install the TF Model Optimization Toolkit (imported as `tfmot` in the next cell).
# NOTE(review): unpinned — pin a version (and prefer %pip) for reproducible runs.
# NOTE(review): `tfmot` is not referenced by the pruning/training cells shown here — TODO confirm it is needed.
!pip install -q tensorflow-model-optimization

In [5]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import tensorflow_model_optimization as tfmot
from models.resnet56_baseline import build_resnet56

In [6]:
def random_unstructured_pruning(model, sparsity=0.3, seed=None):
    """Zero a random fraction of every Conv2D kernel in ``model`` (one-shot, in place).

    For each top-level layer that is a ``tf.keras.layers.Conv2D`` and has weights,
    exactly ``int(sparsity * kernel.size)`` kernel entries, drawn uniformly at
    random without replacement, are set to 0. All remaining weights of that layer
    (e.g. the bias) are written back unchanged. Other layer types are skipped.

    Note: this zeroes the weights once. Later training updates the zeroed
    entries like any other weight unless the returned masks are re-applied.

    Args:
        model: Keras model; only ``model.layers`` is scanned (nested sub-models
            are not descended into).
        sparsity: Fraction of each kernel to zero, in [0, 1].
        seed: Optional int seed for a dedicated ``np.random.default_rng``.
            When None, NumPy's global RNG is used, so ``np.random.seed`` /
            ``tf.keras.utils.set_random_seed`` control reproducibility.

    Returns:
        dict mapping ``layer.name`` to a boolean keep-mask shaped like the
        layer's kernel (True = kept, False = pruned).

    Raises:
        ValueError: if ``sparsity`` is not within [0, 1].
    """
    if not 0.0 <= sparsity <= 1.0:
        raise ValueError(f"sparsity must be in [0, 1], got {sparsity!r}")

    # Default to the global RNG so existing seeding via np.random.seed still applies.
    rng = np.random if seed is None else np.random.default_rng(seed)

    masks = {}
    for layer in model.layers:
        if not isinstance(layer, tf.keras.layers.Conv2D):
            continue
        weights = layer.get_weights()
        if not weights:
            continue  # Skip layers without weights

        kernel = weights[0]  # shape: (k, k, in_channels, out_channels)
        n_prune = int(sparsity * kernel.size)

        keep = np.ones(kernel.size, dtype=bool)
        keep[rng.choice(kernel.size, n_prune, replace=False)] = False
        keep = keep.reshape(kernel.shape)

        pruned_kernel = np.where(keep, kernel, 0).astype(kernel.dtype, copy=False)
        # Write back *all* other weights, not just a bias, so nothing is dropped.
        layer.set_weights([pruned_kernel, *weights[1:]])
        masks[layer.name] = keep
    return masks


In [7]:
def _zero_mask_callback(model):
    """Build a Keras callback that keeps currently-zero Conv2D kernel entries at zero.

    The keep-mask of every top-level Conv2D layer is captured from its kernel
    at call time (non-zero entry -> kept, zero entry -> pruned). After each
    training batch the kernel variable is multiplied by that mask, undoing any
    optimizer update to pruned positions.

    Caveat: a kernel entry that was already exactly 0 is treated as pruned too.
    """
    targets = []
    for layer in model.layers:
        if not isinstance(layer, tf.keras.layers.Conv2D):
            continue
        weights = layer.get_weights()
        if not weights:
            continue
        kernel = weights[0]
        keep = tf.constant((kernel != 0).astype(kernel.dtype))
        targets.append((layer, keep))

    def _reapply_masks(batch, logs=None):
        for layer, keep in targets:
            layer.kernel.assign(layer.kernel * keep)

    return tf.keras.callbacks.LambdaCallback(on_train_batch_end=_reapply_masks)


def train_random_pruned_model(build_model_fn, x_train, y_train, x_val, y_val,
                              sparsity=0.3, epochs=50, batch_size=128, save_path='randomprune_model.h5',
                              enforce_mask=True):
    """Build a model, randomly prune its Conv2D kernels, and train it with Adam.

    Args:
        build_model_fn: Zero-argument callable returning a (built) Keras model.
        x_train, y_train: Training inputs and integer labels.
        x_val, y_val: Data used for EarlyStopping and the best-only checkpoint.
        sparsity: Fraction of each Conv2D kernel to zero (see
            ``random_unstructured_pruning``).
        epochs: Maximum number of epochs.
        batch_size: Mini-batch size.
        save_path: Where ModelCheckpoint writes the best model.
        enforce_mask: If True (default), pruned entries are re-zeroed after
            every batch so the trained model actually keeps ``sparsity``.
            If False, pruning only affects the initialisation (training makes
            the zeroed weights non-zero again).

    Returns:
        (model, history): the trained model and the ``History`` from ``fit``.
    """
    model = build_model_fn()
    random_unstructured_pruning(model, sparsity)

    callbacks = []
    if enforce_mask:
        # Built after pruning so the captured masks reflect the pruned kernels.
        callbacks.append(_zero_mask_callback(model))
    callbacks.append(tf.keras.callbacks.EarlyStopping(patience=5, restore_best_weights=True))
    callbacks.append(tf.keras.callbacks.ModelCheckpoint(save_path, save_best_only=True))

    model.compile(optimizer='adam',
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])

    history = model.fit(x_train, y_train,
                        validation_data=(x_val, y_val),
                        epochs=epochs,
                        batch_size=batch_size,
                        callbacks=callbacks,
                        verbose=2)

    return model, history


In [8]:
SEED = 42
SPARSITY = 0.3   # 30% of Conv2D kernel weights randomly zeroed
VAL_SIZE = 5000  # training images held out for early stopping / checkpointing

# Fresh Keras state per run, then seed Python, NumPy (used by the pruning) and TF.
tf.keras.backend.clear_session()
tf.keras.utils.set_random_seed(SEED)

# Load CIFAR-10 and scale to [0, 1] as float32 (uint8 / 255.0 would give float64).
(x_train_full, y_train_full), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
x_test = x_test.astype('float32') / 255.0

# Hold out validation data from the *training* set so the test set is never
# used for early stopping or checkpoint selection.
perm = np.random.default_rng(SEED).permutation(len(x_train_full))
val_idx, train_idx = perm[:VAL_SIZE], perm[VAL_SIZE:]
x_train = x_train_full[train_idx].astype('float32') / 255.0
y_train = y_train_full[train_idx]
x_val = x_train_full[val_idx].astype('float32') / 255.0
y_val = y_train_full[val_idx]
del x_train_full, y_train_full

# Train random-pruned model (sparsity-tagged checkpoint so runs don't overwrite each other)
model, history = train_random_pruned_model(
    build_model_fn=build_resnet56,
    x_train=x_train,
    y_train=y_train,
    x_val=x_val,
    y_val=y_val,
    sparsity=SPARSITY,
    save_path=f'randomprune_sparsity{round(SPARSITY * 100)}_model.h5',
)

# Final estimate on the untouched test set.
test_loss, test_acc = model.evaluate(x_test, y_test, verbose=0)
print(f"sparsity={SPARSITY:.0%}  test_loss={test_loss:.4f}  test_acc={test_acc:.4f}")


Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
[1m170498071/170498071[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m15s[0m 0us/step
Epoch 1/50


KeyboardInterrupt: 

In [None]:
import pickle

# Persist the per-epoch metrics dict of whichever training run last assigned `history`.
epoch_metrics = history.history
with open('randomprune_training_log.pkl', mode='wb') as log_file:
    pickle.dump(epoch_metrics, log_file)


In [None]:
import pickle

SEED = 42
SPARSITY = 0.5   # 50% of Conv2D kernel weights randomly zeroed
VAL_SIZE = 5000  # training images held out for early stopping / checkpointing

# Fresh Keras state per run, then seed Python, NumPy (used by the pruning) and TF.
tf.keras.backend.clear_session()
tf.keras.utils.set_random_seed(SEED)

# Load CIFAR-10 and scale to [0, 1] as float32 (uint8 / 255.0 would give float64).
(x_train_full, y_train_full), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
x_test = x_test.astype('float32') / 255.0

# Hold out validation data from the *training* set so the test set is never
# used for early stopping or checkpoint selection.
perm = np.random.default_rng(SEED).permutation(len(x_train_full))
val_idx, train_idx = perm[:VAL_SIZE], perm[VAL_SIZE:]
x_train = x_train_full[train_idx].astype('float32') / 255.0
y_train = y_train_full[train_idx]
x_val = x_train_full[val_idx].astype('float32') / 255.0
y_val = y_train_full[val_idx]
del x_train_full, y_train_full

# Train random-pruned model (sparsity-tagged checkpoint so runs don't overwrite each other)
model, history = train_random_pruned_model(
    build_model_fn=build_resnet56,
    x_train=x_train,
    y_train=y_train,
    x_val=x_val,
    y_val=y_val,
    sparsity=SPARSITY,
    save_path=f'randomprune_sparsity{round(SPARSITY * 100)}_model.h5',
)

# Keep this run's per-epoch metrics under a sparsity-specific name so logs of
# different sparsity levels are not lost or overwritten.
with open(f'randomprune_sparsity{round(SPARSITY * 100)}_training_log.pkl', 'wb') as log_file:
    pickle.dump(history.history, log_file)

# Final estimate on the untouched test set.
test_loss, test_acc = model.evaluate(x_test, y_test, verbose=0)
print(f"sparsity={SPARSITY:.0%}  test_loss={test_loss:.4f}  test_acc={test_acc:.4f}")


In [None]:
import pickle

SEED = 42
SPARSITY = 0.7   # 70% of Conv2D kernel weights randomly zeroed
VAL_SIZE = 5000  # training images held out for early stopping / checkpointing

# Fresh Keras state per run, then seed Python, NumPy (used by the pruning) and TF.
tf.keras.backend.clear_session()
tf.keras.utils.set_random_seed(SEED)

# Load CIFAR-10 and scale to [0, 1] as float32 (uint8 / 255.0 would give float64).
(x_train_full, y_train_full), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
x_test = x_test.astype('float32') / 255.0

# Hold out validation data from the *training* set so the test set is never
# used for early stopping or checkpoint selection.
perm = np.random.default_rng(SEED).permutation(len(x_train_full))
val_idx, train_idx = perm[:VAL_SIZE], perm[VAL_SIZE:]
x_train = x_train_full[train_idx].astype('float32') / 255.0
y_train = y_train_full[train_idx]
x_val = x_train_full[val_idx].astype('float32') / 255.0
y_val = y_train_full[val_idx]
del x_train_full, y_train_full

# Train random-pruned model (sparsity-tagged checkpoint so runs don't overwrite each other)
model, history = train_random_pruned_model(
    build_model_fn=build_resnet56,
    x_train=x_train,
    y_train=y_train,
    x_val=x_val,
    y_val=y_val,
    sparsity=SPARSITY,
    save_path=f'randomprune_sparsity{round(SPARSITY * 100)}_model.h5',
)

# Keep this run's per-epoch metrics under a sparsity-specific name so logs of
# different sparsity levels are not lost or overwritten.
with open(f'randomprune_sparsity{round(SPARSITY * 100)}_training_log.pkl', 'wb') as log_file:
    pickle.dump(history.history, log_file)

# Final estimate on the untouched test set.
test_loss, test_acc = model.evaluate(x_test, y_test, verbose=0)
print(f"sparsity={SPARSITY:.0%}  test_loss={test_loss:.4f}  test_acc={test_acc:.4f}")
