In [None]:
pip install onnx onnxruntime tensorflow tensorflow-model-optimization


In [None]:
import tensorflow as tf
from tensorflow.keras.models import load_model
from tensorflow_model_optimization.sparsity.keras import prune_low_magnitude
import onnx
import onnxruntime as ort


In [None]:
from tensorflow.keras.layers import Dense, Flatten, Input
from tensorflow.keras import Sequential
from tensorflow.keras.datasets import mnist

# Seed Python, NumPy and TensorFlow RNGs so weight init and batch shuffling are reproducible.
SEED = 42
tf.keras.utils.set_random_seed(SEED)

# Load MNIST and scale pixel values from [0, 255] to [0, 1].
# Cast to float32 first: dividing the raw integer arrays by 255.0 would yield float64,
# whereas the TFLite and ONNX cells feed float32 to their runtimes.
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train = x_train.astype('float32') / 255.0
x_test = x_test.astype('float32') / 255.0

# Small fully-connected classifier: 28x28 image -> 128 ReLU units -> 10-way softmax.
model = Sequential([
    Input(shape=(28, 28)),
    Flatten(),
    Dense(128, activation='relu'),
    Dense(10, activation='softmax'),
])

# Integer labels, hence sparse categorical cross-entropy.
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
# NOTE: the test split doubles as the validation set, so the reported validation
# metrics are not an estimate independent of the final test evaluation.
model.fit(x_train, y_train, epochs=3, validation_data=(x_test, y_test))


In [None]:
# Post-training quantization to TensorFlow Lite.
# Optimize.DEFAULT without a representative dataset performs dynamic-range
# (weight-only) quantization; inputs and outputs stay floating point.
converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_model = converter.convert()  # serialized flatbuffer as bytes

# Persist the flatbuffer so it can be reloaded with tf.lite.Interpreter.
with open("model_quantized.tflite", "wb") as tflite_file:
    tflite_file.write(tflite_model)

print("Quantized model saved as model_quantized.tflite")


In [None]:
from tensorflow_model_optimization.sparsity.keras import (
    PolynomialDecay,
    UpdatePruningStep,
    strip_pruning,
)

# Prune a copy of the trained baseline: `model` is still exported and evaluated
# as the "original" in later cells, so fine-tuning must not touch its weights.
model_to_prune = tf.keras.models.clone_model(model)
model_to_prune.set_weights(model.get_weights())

# Magnitude pruning: sparsity ramps polynomially from 20% to 80% over the first
# 1000 training steps, then stays at 80%. The schedule lives in tfmot, not tf.keras.
pruning_params = {
    'pruning_schedule': PolynomialDecay(
        initial_sparsity=0.2, final_sparsity=0.8, begin_step=0, end_step=1000
    )
}
pruned_model = prune_low_magnitude(model_to_prune, **pruning_params)

# Fine-tune the pruned model. UpdatePruningStep is mandatory: it advances the
# pruning step counter, and training a pruning-wrapped model without it fails.
pruned_model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
pruned_model.fit(
    x_train,
    y_train,
    epochs=2,
    validation_data=(x_test, y_test),
    callbacks=[UpdatePruningStep()],
)

# Strip the pruning wrappers for deployment. strip_pruning returns a fresh,
# uncompiled model, so recompile it to keep `pruned_model.evaluate(...)` usable.
pruned_model = strip_pruning(pruned_model)
pruned_model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])

# Save the pruned model
pruned_model.save("model_pruned.h5")
print("Pruned model saved as model_pruned.h5")


In [None]:
import tf2onnx

# Declare the graph input explicitly: float32 28x28 images with a dynamic batch
# dimension, so the exported model accepts any number of images per call.
onnx_input_signature = (tf.TensorSpec((None, 28, 28), tf.float32, name="input"),)

# from_keras returns a (ModelProto, external_tensor_storage) tuple, not a bare ModelProto.
onnx_model, _ = tf2onnx.convert.from_keras(model, input_signature=onnx_input_signature)

# Fail fast on a structurally invalid graph before writing it to disk.
onnx.checker.check_model(onnx_model)

# Save the ONNX model
with open("model.onnx", "wb") as f:
    f.write(onnx_model.SerializeToString())

print("Model converted to ONNX and saved as model.onnx")


In [None]:
# Baseline: the Keras model as trained, evaluated on the MNIST test split.
original_loss, original_accuracy = model.evaluate(x_test, y_test)
print(f"Original Model Accuracy: {original_accuracy:.4f}")

# Quantized model: run the TFLite interpreter over the whole test set so its
# accuracy is directly comparable to the Keras models.
interpreter = tf.lite.Interpreter(model_path="model_quantized.tflite")
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

# Models converted from Keras typically default to batch size 1; resize the
# input to the full test set and (re)allocate before feeding it in one invoke().
interpreter.resize_tensor_input(input_details[0]['index'], [len(x_test), 28, 28])
interpreter.allocate_tensors()
interpreter.set_tensor(input_details[0]['index'], x_test.astype('float32'))
interpreter.invoke()
output_data = interpreter.get_tensor(output_details[0]['index'])  # (n_test, 10) class probabilities

# Predicted class = argmax over the 10 softmax outputs.
quantized_accuracy = (output_data.argmax(axis=1) == y_test).mean()
print(f"Quantized Model Accuracy: {quantized_accuracy:.4f}")

# Evaluate the pruned model
pruned_loss, pruned_accuracy = pruned_model.evaluate(x_test, y_test)
print(f"Pruned Model Accuracy: {pruned_accuracy:.4f}")


In [None]:
# Load the ONNX model for inference. Passing providers explicitly avoids the
# error onnxruntime raises on builds with several available execution providers.
session = ort.InferenceSession("model.onnx", providers=["CPUExecutionProvider"])
input_name = session.get_inputs()[0].name
output_name = session.get_outputs()[0].name

# Run the whole test set in one call; x_test already has shape (n, 28, 28),
# so only the dtype needs converting. session.run returns one array per requested output.
result = session.run([output_name], {input_name: x_test.astype('float32')})
onnx_probs = result[0]  # (n_test, 10) class probabilities

onnx_accuracy = (onnx_probs.argmax(axis=1) == y_test).mean()
print(f"ONNX Model Prediction (first test image): {onnx_probs[0].argmax()} (label {y_test[0]})")
print(f"ONNX Model Accuracy: {onnx_accuracy:.4f}")
