In [1]:
%pip install tensorflow-model-optimization

Collecting tensorflow-model-optimization
  Downloading tensorflow_model_optimization-0.8.0-py2.py3-none-any.whl.metadata (904 bytes)
Collecting absl-py~=1.2 (from tensorflow-model-optimization)
  Downloading absl_py-1.4.0-py3-none-any.whl.metadata (2.3 kB)
Collecting dm-tree~=0.1.1 (from tensorflow-model-optimization)
  Downloading dm_tree-0.1.9-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (2.4 kB)
Collecting numpy~=1.23 (from tensorflow-model-optimization)
  Using cached numpy-1.26.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (61 kB)
Downloading tensorflow_model_optimization-0.8.0-py2.py3-none-any.whl (242 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m242.5/242.5 kB[0m [31m4.1 MB/s[0m eta [36m0:00:00[0m [36m0:00:01[0m
[?25hDownloading absl_py-1.4.0-py3-none-any.whl (126 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m126.5/126.5 kB[0m [31m16.7 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading

In [2]:
import tensorflow_model_optimization as tfmot
import tensorflow as tf
from tensorflow.keras.datasets import mnist

# Load dataset
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

# Build a simple model
model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(10, activation='softmax')
])

# Compile the model
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])

# Train the model
model.fit(x_train, y_train, epochs=5, validation_data=(x_test, y_test))

# Apply pruning to the model
pruning_params = {
    'pruning_schedule': tfmot.sparsity.keras.PolynomialDecay(initial_sparsity=0.0, final_sparsity=0.5, begin_step=0, end_step=1000)
}
pruned_model = tfmot.sparsity.keras.prune_low_magnitude(model, **pruning_params)

# Compile the pruned model
pruned_model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])

# Train the pruned model to finalize pruning
callbacks = [tfmot.sparsity.keras.UpdatePruningStep()]
pruned_model.fit(x_train, y_train, epochs=2, validation_data=(x_test, y_test), callbacks=callbacks)

# Strip pruning wrappers to remove pruning-specific layers and metadata
pruned_model = tfmot.sparsity.keras.strip_pruning(pruned_model)

2025-09-14 11:59:56.964949: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-09-14 11:59:57.142611: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1757843997.211062  615522 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1757843997.231560  615522 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1757843997.359747  615522 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking 

Epoch 1/5


I0000 00:00:1757844002.355097  615930 service.cc:152] XLA service 0x748610019080 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1757844002.355120  615930 service.cc:160]   StreamExecutor device (0): NVIDIA GeForce RTX 4060 Laptop GPU, Compute Capability 8.9
2025-09-14 12:00:02.384254: I tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc:269] disabling MLIR crash reproducer, set env var `MLIR_CRASH_REPRODUCER_DIRECTORY` to enable.
I0000 00:00:1757844002.455750  615930 cuda_dnn.cc:529] Loaded cuDNN version 90701


[1m  85/1875[0m [37m━━━━━━━━━━━━━━━━━━━━[0m [1m3s[0m 2ms/step - accuracy: 0.4959 - loss: 1.5728    

I0000 00:00:1757844003.771619  615930 device_compiler.h:188] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.


[1m1863/1875[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 2ms/step - accuracy: 0.8560 - loss: 0.4898




[1m1875/1875[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m7s[0m 2ms/step - accuracy: 0.9129 - loss: 0.2974 - val_accuracy: 0.9564 - val_loss: 0.1467
Epoch 2/5
[1m1875/1875[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 2ms/step - accuracy: 0.9573 - loss: 0.1437 - val_accuracy: 0.9661 - val_loss: 0.1062
Epoch 3/5
[1m1875/1875[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 2ms/step - accuracy: 0.9667 - loss: 0.1078 - val_accuracy: 0.9738 - val_loss: 0.0847
Epoch 4/5
[1m1875/1875[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 2ms/step - accuracy: 0.9736 - loss: 0.0856 - val_accuracy: 0.9753 - val_loss: 0.0827
Epoch 5/5
[1m1875/1875[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 2ms/step - accuracy: 0.9754 - loss: 0.0758 - val_accuracy: 0.9774 - val_loss: 0.0722


ValueError: `prune_low_magnitude` can only prune an object of the following types: keras.models.Sequential, keras functional model, keras.layers.Layer, list of keras.layers.Layer. You passed an object of type: Sequential.

In [5]:
# Convert the pruned model to a TensorFlow Lite quantized model
converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
quantized_model = converter.convert()

INFO:tensorflow:Assets written to: /tmp/tmpy2jw9x1k/assets


INFO:tensorflow:Assets written to: /tmp/tmpy2jw9x1k/assets


Saved artifact at '/tmp/tmpy2jw9x1k'. The following endpoints are available:

* Endpoint 'serve'
  args_0 (POSITIONAL_ONLY): TensorSpec(shape=(None, 28, 28), dtype=tf.float32, name='keras_tensor_5')
Output Type:
  TensorSpec(shape=(None, 10), dtype=tf.float32, name=None)
Captures:
  128124570010448: TensorSpec(shape=(), dtype=tf.resource, name=None)
  128124570012944: TensorSpec(shape=(), dtype=tf.resource, name=None)
  128124570013904: TensorSpec(shape=(), dtype=tf.resource, name=None)
  128124570013136: TensorSpec(shape=(), dtype=tf.resource, name=None)


W0000 00:00:1757844113.716766  615522 tf_tfl_flatbuffer_helpers.cc:365] Ignored output_format.
W0000 00:00:1757844113.716783  615522 tf_tfl_flatbuffer_helpers.cc:368] Ignored drop_control_dependency.
2025-09-14 12:01:53.717087: I tensorflow/cc/saved_model/reader.cc:83] Reading SavedModel from: /tmp/tmpy2jw9x1k
2025-09-14 12:01:53.717406: I tensorflow/cc/saved_model/reader.cc:52] Reading meta graph with tags { serve }
2025-09-14 12:01:53.717412: I tensorflow/cc/saved_model/reader.cc:147] Reading SavedModel debug info (if present) from: /tmp/tmpy2jw9x1k
I0000 00:00:1757844113.720188  615522 mlir_graph_optimization_pass.cc:425] MLIR V1 optimization pass is not enabled
2025-09-14 12:01:53.720645: I tensorflow/cc/saved_model/loader.cc:236] Restoring SavedModel bundle.
2025-09-14 12:01:53.738786: I tensorflow/cc/saved_model/loader.cc:220] Running initialization op on SavedModel bundle at path: /tmp/tmpy2jw9x1k
2025-09-14 12:01:53.744291: I tensorflow/cc/saved_model/loader.cc:471] SavedModel 

Example (grid search in Scikit-Learn)

In [6]:
# Measure accuracy of the quantized model using the test set
interpreter = tf.lite.Interpreter(model_content=quantized_model)
interpreter.allocate_tensors()

input_index = interpreter.get_input_details()[0]['index']
output_index = interpreter.get_output_details()[0]['index']

# Evaluate accuracy
correct_predictions = 0
for i in range(len(x_test)):
    input_data = x_test[i:i+1].astype('float32')
    interpreter.set_tensor(input_index, input_data)
    interpreter.invoke()
    output = interpreter.get_tensor(output_index)
    predicted_label = output.argmax()
    if predicted_label == y_test[i]:
        correct_predictions += 1

accuracy = correct_predictions / len(x_test)
print(f'Quantized model accuracy: {accuracy * 100:.2f}%')

Quantized model accuracy: 97.69%


    TF 2.20. Please use the LiteRT interpreter from the ai_edge_litert package.
    See the [migration guide](https://ai.google.dev/edge/litert/migration)
    for details.
    
INFO: Created TensorFlow Lite XNNPACK delegate for CPU.
