In [None]:
import numpy as np
import tensorflow as tf
from models.resnet20 import build_resnet20
from tensorflow.keras.datasets import cifar10
from tensorflow.keras.utils import to_categorical
import os

# --------------------------
# Global Magnitude Pruning
# --------------------------
def global_magnitude_pruning(model, sparsity):
    """Zero out the smallest-magnitude kernel weights across all Conv2D/Dense layers.

    One magnitude threshold is computed jointly over the kernels (``weights[0]``)
    of every Conv2D and Dense layer in ``model.layers`` (a global, not per-layer,
    threshold). Kernel entries whose absolute value is below it are set to zero.
    Biases and any other weights of those layers are left untouched.

    Args:
        model: A built Keras model. Only the top-level ``model.layers`` are
            scanned; layers nested inside sub-models are not visited.
        sparsity: Fraction of kernel weights to remove, in [0, 1]. Entries tied
            with the threshold are kept, so the achieved sparsity can be slightly
            lower than requested.

    Returns:
        dict mapping layer name -> boolean keep-mask with the kernel's shape, so
        callers can re-apply the masks during fine-tuning. Empty when the model
        has no Conv2D/Dense layer with weights (the model is then unchanged).

    Raises:
        ValueError: if ``sparsity`` is outside [0, 1].
    """
    if not 0.0 <= sparsity <= 1.0:
        raise ValueError(f"sparsity must be in [0, 1], got {sparsity}")
    print(f"[INFO] Applying global magnitude pruning with sparsity: {sparsity}")

    # Snapshot each prunable layer's weights once; weights[0] is the kernel.
    prunable = []
    for layer in model.layers:
        if isinstance(layer, (tf.keras.layers.Conv2D, tf.keras.layers.Dense)):
            weights = layer.get_weights()
            if weights:
                prunable.append((layer, weights))

    if not prunable:
        # np.concatenate([]) would raise; there is simply nothing to prune.
        print("[WARN] No Conv2D/Dense kernels found; model left unchanged.")
        return {}

    all_magnitudes = np.concatenate([np.abs(weights[0]).ravel() for _, weights in prunable])
    threshold = np.percentile(all_magnitudes, sparsity * 100)

    masks = {}
    for layer, weights in prunable:
        kernel = weights[0]
        mask = np.abs(kernel) >= threshold
        # Pass through the bias (if present) and any further weights unchanged:
        # layers built with use_bias=False carry only the kernel, which a fixed
        # `w, b = weights` unpacking would reject.
        layer.set_weights([kernel * mask] + list(weights[1:]))
        masks[layer.name] = mask
    return masks

# --------------------------
# Load CIFAR-10 Dataset
# --------------------------
def load_cifar10():
    """Load CIFAR-10 via Keras and prepare it for training.

    Images are cast to float32 and scaled by 1/255; labels are one-hot encoded
    over 10 classes.

    Returns:
        ((x_train, y_train), (x_test, y_test))
    """
    def _prepare(images, labels):
        # Same preprocessing for both splits: scale pixels, one-hot the labels.
        return images.astype("float32") / 255.0, to_categorical(labels, 10)

    train_split, test_split = cifar10.load_data()
    return _prepare(*train_split), _prepare(*test_split)

# --------------------------
# Training Function
# --------------------------
def train(model, x_train, y_train, x_test, y_test, epochs, batch_size, checkpoint_path,
          extra_callbacks=None):
    """Compile ``model`` (Adam, categorical cross-entropy) and fit it, checkpointing the best epoch.

    Args:
        model: Keras model to compile and train in place.
        x_train, y_train: Training inputs and one-hot targets.
        x_test, y_test: Held-out data passed to ``fit`` as ``validation_data``.
        epochs: Number of training epochs.
        batch_size: Mini-batch size.
        checkpoint_path: File the model is saved to whenever ``val_accuracy``
            improves (``save_best_only=True``). Its parent directory is created
            if it does not exist.
        extra_callbacks: Optional sequence of additional Keras callbacks run
            alongside the checkpoint callback (e.g. to keep pruned weights at zero).

    Returns:
        The ``History`` object returned by ``model.fit``.

    NOTE(review): the same (x_test, y_test) split drives checkpoint selection,
    so the best saved accuracy is chosen on the test set; consider a separate
    validation split for unbiased reporting.
    """
    model.compile(optimizer='adam',
                  loss='categorical_crossentropy',
                  metrics=['accuracy'])

    # Make sure the checkpoint's directory exists before the first save.
    checkpoint_dir = os.path.dirname(checkpoint_path)
    if checkpoint_dir:
        os.makedirs(checkpoint_dir, exist_ok=True)

    callbacks = [
        tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path,
                                           monitor='val_accuracy',
                                           save_best_only=True,
                                           verbose=1)
    ]
    if extra_callbacks:
        callbacks.extend(extra_callbacks)

    history = model.fit(x_train, y_train,
                        epochs=epochs,
                        batch_size=batch_size,
                        validation_data=(x_test, y_test),
                        callbacks=callbacks)

    return history

# --------------------------
# One-time pruning pipeline (prune once, then train with the mask fixed)
# --------------------------
def _capture_pruning_masks(model):
    """Return (kernel_variable, keep_mask) pairs marking the current non-zero kernel entries.

    Scans the same layers ``global_magnitude_pruning`` prunes (top-level Conv2D and
    Dense layers, kernel = first weight). Intended to be called right after
    pruning, when the only exact zeros in a kernel are the pruned entries.
    """
    masked_kernels = []
    for layer in model.layers:
        if isinstance(layer, (tf.keras.layers.Conv2D, tf.keras.layers.Dense)) and layer.weights:
            kernel = layer.weights[0]
            keep_mask = tf.cast(tf.not_equal(kernel, 0), kernel.dtype)
            masked_kernels.append((kernel, keep_mask))
    return masked_kernels


def _make_pruning_mask_callback(masked_kernels):
    """Build a callback that multiplies each kernel by its keep-mask after every training batch."""
    # Defined locally so the module does not need `tf` at import time.
    class _ReapplyPruningMasks(tf.keras.callbacks.Callback):
        def on_train_batch_end(self, batch, logs=None):
            # The optimizer updates pruned entries too; force them back to zero.
            for kernel, keep_mask in masked_kernels:
                kernel.assign(kernel * keep_mask)

    return _ReapplyPruningMasks()


def run_one_time_pruning(sparsity=0.5, epochs=100, batch_size=128):
    """Build ResNet-20, prune it once by global kernel magnitude, then train the sparse model.

    The pruned entries are held at zero for the whole training run. Without
    that, the optimizer would update them and the saved checkpoint would be
    dense again.

    NOTE(review): pruning is applied to the freshly built model before any
    training, whereas Han et al. prune a trained network -- confirm that
    pruning at initialization is the intended baseline.

    Args:
        sparsity: Fraction of Conv2D/Dense kernel weights to remove.
        epochs: Number of training epochs after pruning.
        batch_size: Mini-batch size.
    """
    (x_train, y_train), (x_test, y_test) = load_cifar10()

    model = build_resnet20()
    model.build(input_shape=(None, 32, 32, 3))
    model.summary()

    global_magnitude_pruning(model, sparsity)
    # Pin the zeros created by pruning so training cannot regrow them.
    mask_callback = _make_pruning_mask_callback(_capture_pruning_masks(model))

    checkpoint_path = f"results/one_time_resnet20_sparsity_{sparsity}.h5"
    train(model, x_train, y_train, x_test, y_test, epochs, batch_size, checkpoint_path,
          extra_callbacks=[mask_callback])

    print("✅ One-time pruning and training complete.")


In [2]:
# `run_lwc_training` is not defined in this notebook, so Restart & Run All raised NameError.
# Call the pipeline defined above instead. It prunes exactly once, so the former
# `iterations=5` argument has no counterpart here.
run_one_time_pruning(sparsity=0.3, epochs=150, batch_size=128)

Model: "model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_1 (InputLayer)           [(None, 32, 32, 3)]  0           []                               
                                                                                                  
 conv2d (Conv2D)                (None, 32, 32, 16)   448         ['input_1[0][0]']                
                                                                                                  
 batch_normalization (BatchNorm  (None, 32, 32, 16)  64          ['conv2d[0][0]']                 
 alization)                                                                                       
                                                                                                  
 activation (Activation)        (None, 32, 32, 16)   0           ['batch_normalization[0][0]']

KeyboardInterrupt: 