In [None]:
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
import tensorflow as tf

import tensorflow_model_optimization as tfmot
from tensorflow_model_optimization.python.core.keras.compat import keras

In [3]:
# Synthetic binary-classification dataset: 1000 rows with 10 features each.
# The fixed seed makes the generated data repeatable across runs.
x, y = make_classification(
    n_samples=1000,
    n_features=10,
    n_classes=2,
    random_state=42,
)

# Hold out 20% of the rows as a test set; the seed keeps the split repeatable.
x_train, x_test, y_train, y_test = train_test_split(
    x,
    y,
    test_size=0.2,
    random_state=42,
)

In [10]:
# Small fully connected binary classifier:
# 10 input features -> Dense(128, relu) -> Dense(64, relu) -> Dense(1, sigmoid).
model = keras.Sequential()
model.add(keras.layers.Dense(128, activation='relu', input_shape=(10,)))
model.add(keras.layers.Dense(64, activation="relu"))
model.add(keras.layers.Dense(1, activation="sigmoid"))

## Model Pruning

In [11]:
# The sparsity schedule is counted in optimizer steps (one per batch), so its
# end must line up with the length of the pruning run. These two values must
# match the `batch_size` and `epochs` passed to `model.fit` for that run.
PRUNING_BATCH_SIZE = 32
PRUNING_EPOCHS = 10

# Ceiling division: a trailing partial batch still counts as one step.
steps_per_epoch = (len(x_train) + PRUNING_BATCH_SIZE - 1) // PRUNING_BATCH_SIZE
pruning_end_step = steps_per_epoch * PRUNING_EPOCHS

pruning_params = {
    'pruning_schedule': tfmot.sparsity.keras.PolynomialDecay(
        initial_sparsity=0.0,
        final_sparsity=0.5,
        begin_step=0,
        # Previously hard-coded to 1000. With 800 training rows, batches of 32
        # and 10 epochs the run only lasts 250 steps, so the ramp was cut off
        # long before it approached the 50% target.
        end_step=pruning_end_step,
    )
}

# Wrap the model so low-magnitude weights are pruned during training.
# NOTE: this rebinds `model`, so re-running this cell would try to wrap the
# already-wrapped model -- re-run the model-definition cell first.
model = tfmot.sparsity.keras.prune_low_magnitude(model, **pruning_params)

In [14]:
# Compile the pruning-wrapped model for binary classification (training itself
# happens in the next cell), then print its layer table.
model.compile(
    loss='binary_crossentropy',
    optimizer='adam',
    metrics=['accuracy'],
)
model.summary()

Model: "sequential_1"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 prune_low_magnitude_dense   (None, 128)               2690      
 (PruneLowMagnitude)                                             
                                                                 
 prune_low_magnitude_dense_  (None, 64)                16450     
 1 (PruneLowMagnitude)                                           
                                                                 
 prune_low_magnitude_dense_  (None, 1)                 131       
 2 (PruneLowMagnitude)                                           
                                                                 
Total params: 19271 (75.29 KB)
Trainable params: 9729 (38.00 KB)
Non-trainable params: 9542 (37.29 KB)
_________________________________________________________________


In [15]:
# Train the pruning-wrapped model. The UpdatePruningStep callback is supplied
# so the pruning wrappers are stepped as training proceeds.
model.fit(
    x_train,
    y_train,
    batch_size=32,
    epochs=10,
    callbacks=[tfmot.sparsity.keras.UpdatePruningStep()],
)

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


<tf_keras.src.callbacks.History at 0x186c20197c0>

In [16]:
# Remove the PruneLowMagnitude wrappers added by prune_low_magnitude and rebind
# `model` to the stripped model; the conversion and quantization cells below
# operate on this stripped model.
model = tfmot.sparsity.keras.strip_pruning(model)

## Quantization

- Post-training quantization

In [17]:
# Build a TFLite converter directly from the in-memory Keras model.
converter = tf.lite.TFLiteConverter.from_keras_model(model)
# Request the converter's default optimization set; this is the post-training
# quantization step named in the section heading.
converter.optimizations = [tf.lite.Optimize.DEFAULT]

# Run the conversion. The result is only held in `tflite_model`; no visible
# cell writes it to disk or evaluates it.
tflite_model = converter.convert()

INFO:tensorflow:Assets written to: C:\Users\User\AppData\Local\Temp\tmp_8vvlo6k\assets


INFO:tensorflow:Assets written to: C:\Users\User\AppData\Local\Temp\tmp_8vvlo6k\assets


- Quantization-aware training

In [18]:
# `model` has already been pruned and stripped. Plain `quantize_model` followed
# by fine-tuning would update every weight with no sparsity constraint, which
# can undo the pruning. Annotate the model and apply the prune-preserving
# 8-bit scheme instead, so quantization-aware fine-tuning keeps the zeroed
# weights at zero.
annotated_model = tfmot.quantization.keras.quantize_annotate_model(model)
qat_model = tfmot.quantization.keras.quantize_apply(
    annotated_model,
    tfmot.experimental.combine.Default8BitPrunePreserveQuantizeScheme(),
)
qat_model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

# Fine-tune with quantization emulated in the forward pass.
qat_model.fit(x_train, y_train, epochs=10, batch_size=32)

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


<tf_keras.src.callbacks.History at 0x186c19b5fd0>