In [1]:
from math import ceil
import tensorflow as tf
import tensorflow_model_optimization.sparsity.keras as sparsity

from utils.metrics import count_nonzero_params
from prune_model import prune_layer

In [2]:
(X_train, y_train), (X_test, y_test) = tf.keras.datasets.cifar10.load_data()

y_train = tf.keras.utils.to_categorical(y_train)
y_test = tf.keras.utils.to_categorical(y_test)

In [3]:
batch_size = 64

model = tf.keras.models.Sequential([
    tf.keras.layers.Input(shape=(32, 32, 3)),
    tf.keras.layers.DepthwiseConv2D(kernel_size=4),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(10, activation="softmax")
])

model.compile(
    optimizer=tf.keras.optimizers.Adam(),
    loss=tf.keras.losses.CategoricalCrossentropy(),
    metrics=[
        tf.keras.metrics.TopKCategoricalAccuracy(k=1, name="top1"),
        tf.keras.metrics.TopKCategoricalAccuracy(k=3, name="top3")
    ]
)

model.fit(
    X_train, y_train,
    validation_data=(X_test, y_test),
    batch_size=batch_size,
    epochs=5
)

Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


<tensorflow.python.keras.callbacks.History at 0x7fc0084a3130>

In [4]:
initial_sparsity = 0
final_sparsity = 0.9
start_pruning = 0
pruning_steps = 5
frequency = 2

epochs = start_pruning + pruning_steps * frequency
steps_per_epoch = ceil(len(X_train) // batch_size + 1)

pruning_schedule = sparsity.PolynomialDecay(
    initial_sparsity=initial_sparsity,
    final_sparsity=final_sparsity,
    begin_step=start_pruning*steps_per_epoch,
    end_step=epochs*steps_per_epoch,
    frequency=frequency*steps_per_epoch
)

pruning_model = tf.keras.models.clone_model(
    model=model, 
    clone_function=lambda layer: prune_layer(layer, pruning_schedule)
)



In [5]:
for l1, l2 in zip(pruning_model.layers, model.layers):
    print(type(l1).__name__, "---", type(l2).__name__)

PruneLowMagnitude --- DepthwiseConv2D
PruneLowMagnitude --- Flatten
PruneLowMagnitude --- Dense


In [6]:
pruning_model.compile(
    optimizer=tf.keras.optimizers.Adam(),
    loss=tf.keras.losses.CategoricalCrossentropy(),
    metrics=[
        tf.keras.metrics.TopKCategoricalAccuracy(k=1, name="top1"),
        tf.keras.metrics.TopKCategoricalAccuracy(k=3, name="top3")
    ]
)

pruning_model.fit(
    X_train, y_train,
    validation_data=(X_test, y_test),
    batch_size=batch_size,
    epochs=epochs,
    callbacks=[
        sparsity.UpdatePruningStep()
    ]
)

pruning_model = sparsity.strip_pruning(pruning_model)

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


In [7]:
for layer in pruning_model.layers:
    sparsity = (1 - count_nonzero_params(layer) / layer.count_params()) if layer.count_params() > 0 else 0.0
    print("[{: >5.2f}%] {}".format(sparsity * 100, type(layer).__name__))

[ 0.00%] DepthwiseConv2D
[ 0.00%] Flatten
[89.24%] Dense
