In [None]:

from importlib import import_module
from pathlib import Path
import sys
sys.path.insert(0, "../")
import tensorflow as tf
from tensorflow.keras.models import load_model
from tensorflow_model_optimization.quantization.keras import quantize_model
import numpy as np

# reload modules
import importlib
import models.fc
import models.cnn
importlib.reload(models.fc)
importlib.reload(models.cnn)

import data
importlib.reload(data)

from data import read_data, read_labels, normalize_img
from models.fc import build_fc_model
from models.cnn import build_cnn_model

In [None]:
# Load the MNIST train/test splits through the Keras datasets API.
mnist_train, mnist_test = tf.keras.datasets.mnist.load_data()
train_data, train_labels = mnist_train
test_data, test_labels = mnist_test

In [None]:
# Preprocessing (normalization) via the project's `normalize_img` helper.
# NOTE(review): this cell is not idempotent — re-running it applies
# `normalize_img` again to already-normalized data. Re-run the loading cell first.
print('Raw data pixel value range:', np.min(train_data), 'to', np.max(train_data))
train_data, train_labels = normalize_img(train_data, train_labels)
test_data, test_labels = normalize_img(test_data, test_labels)

# `normalize_img` may return tf.Tensors or NumPy arrays (the evaluation cell
# handles both), so view the result through NumPy once for the range check.
train_pixels = np.asarray(train_data)
print('Normalized datatype: ', type(train_data))
print('Normalized data pixel value range:', train_pixels.min(), 'to', train_pixels.max())

In [None]:
# One-hot encode the class labels to match the `categorical_crossentropy` loss
# used for QAT fine-tuning.
# Guarded on rank so re-running this cell does not re-encode labels that are
# already one-hot (to_categorical would append another class axis).
# Presumably the labels are 1-D integer class ids on the first run; confirm
# `normalize_img` passes them through unchanged.
NUM_CLASSES = 10
if np.ndim(train_labels) == 1:
    train_labels = tf.keras.utils.to_categorical(train_labels, num_classes=NUM_CLASSES)
if np.ndim(test_labels) == 1:
    test_labels = tf.keras.utils.to_categorical(test_labels, num_classes=NUM_CLASSES)

In [None]:
# Model architecture to load: 'fc' (fully connected) or 'cnn' (convolutional).
# It selects the saved_model/mnist_<model_type> directories used below.
model_type = 'cnn'
# Fail early on a typo instead of with a missing-file error in load_model.
assert model_type in ('fc', 'cnn'), f"Unknown model_type: {model_type!r}"

In [None]:
# Directory holding the pretrained FP32 Keras model for the chosen architecture.
SAVED_MODEL_ROOT = Path('./../../../saved_model')
OUTPUT_PATH = SAVED_MODEL_ROOT / f'mnist_{model_type}'

# Restore the HDF5 model file and print its architecture.
model = load_model(OUTPUT_PATH / 'model.h5')
model.summary()

In [None]:
# Baseline: accuracy of the pretrained FP32 Keras model on the test split.
# `test_acc` is reused at the end of the notebook for comparison.
test_loss, test_acc = model.evaluate(test_data, test_labels, verbose=2)
print("Test accuracy: {:.4f}".format(test_acc))

In [None]:
# Wrap the pretrained model for quantization-aware training (QAT) with the
# TF Model Optimization toolkit and compile it; fine-tuning happens in the
# next cell.
quant_aware_model = quantize_model(model)
quant_aware_model.compile(
    optimizer='adam',
    loss='categorical_crossentropy',  # labels are one-hot encoded above
    metrics=['accuracy'],
)

In [None]:
# Fine-tune the QAT-wrapped model.
# Seed Python/NumPy/TF RNGs so batch shuffling is reproducible across
# Restart & Run All.
QAT_EPOCHS = 3
SEED = 42
tf.keras.utils.set_random_seed(SEED)

# NOTE(review): the test split doubles as validation data here; avoid tuning
# QAT_EPOCHS on it if the final test accuracy is meant to be held-out.
qat_history = quant_aware_model.fit(
    train_data,
    train_labels,
    epochs=QAT_EPOCHS,
    validation_data=(test_data, test_labels),
)

In [None]:
# Convert the fine-tuned QAT model to a TensorFlow Lite flatbuffer (bytes).
# Optimize.DEFAULT enables the converter's default optimizations, which use the
# quantization ranges learned during QAT.
# NOTE(review): no inference_input_type / inference_output_type is set, so the
# model's input/output tensors may remain float32 even though its ops are
# quantized. `tflite_predict` below handles both int8 and non-int8 I/O — check
# `input_details` to confirm which one was produced.
converter = tf.lite.TFLiteConverter.from_keras_model(quant_aware_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_model = converter.convert()

In [None]:
# Destination for the QAT-converted TFLite model. This rebinds OUTPUT_PATH,
# which previously pointed at the FP32 model directory.
OUTPUT_PATH = Path(f'./../../../saved_model/mnist_{model_type}_int8_qat')
OUTPUT_PATH.mkdir(parents=True, exist_ok=True)

# Write the serialized flatbuffer to disk (trailing ';' hides the byte count).
(OUTPUT_PATH / "model_int8.tflite").write_bytes(tflite_model);

In [None]:
# Load the exported TFLite model and allocate its tensors before inference.
interpreter = tf.lite.Interpreter(
    model_path=str(OUTPUT_PATH.joinpath("model_int8.tflite"))
)
interpreter.allocate_tensors()

# Cache the input/output tensor metadata (index, dtype, shape, quantization)
# read by `tflite_predict`.
input_details, output_details = (
    interpreter.get_input_details(),
    interpreter.get_output_details(),
)

In [None]:
def tflite_predict(x, interp=None):
    """Run one forward pass of the TFLite model on ``x``.

    Args:
        x: Input batch as a NumPy array. A 3-D array gets a trailing channel
            axis appended (presumably ``(batch, height, width)`` ->
            ``(batch, height, width, 1)``). It must match the interpreter's
            input shape; this notebook passes one sample at a time.
        interp: Optional ``tf.lite.Interpreter`` whose tensors are already
            allocated. Defaults to the notebook-level ``interpreter`` together
            with its cached ``input_details`` / ``output_details``.

    Returns:
        The raw output tensor. int8/uint8-quantized outputs are dequantized
        to float using the tensor's ``(scale, zero_point)``; other output
        dtypes are returned unchanged.
    """
    quantized_types = (np.dtype(np.int8), np.dtype(np.uint8))

    if interp is None:
        interp = interpreter
        in_det, out_det = input_details[0], output_details[0]
    else:
        in_det = interp.get_input_details()[0]
        out_det = interp.get_output_details()[0]

    x = np.asarray(x)
    if x.ndim == 3:
        x = np.expand_dims(x, axis=-1)

    in_dtype = np.dtype(in_det['dtype'])
    if in_dtype in quantized_types:
        scale, zero_point = in_det['quantization']
        limits = np.iinfo(in_dtype)
        # Affine quantization rounds to the nearest integer; a bare astype()
        # would truncate toward zero and bias values off by one.
        x = np.round(x / scale + zero_point)
        x = np.clip(x, limits.min, limits.max).astype(in_dtype)
    elif np.issubdtype(in_dtype, np.floating):
        # set_tensor rejects dtype mismatches (e.g. float64 into float32).
        x = x.astype(in_dtype, copy=False)

    interp.set_tensor(in_det['index'], x)
    interp.invoke()
    output_data = interp.get_tensor(out_det['index'])

    if np.dtype(out_det['dtype']) in quantized_types:
        scale, zero_point = out_det['quantization']
        output_data = (output_data.astype(np.float32) - zero_point) * scale

    return output_data


# The evaluation slices and casts samples with NumPy, so unwrap any tf.Tensor
# produced by preprocessing.
test_data = test_data.numpy() if isinstance(test_data, tf.Tensor) else test_data
test_labels = test_labels.numpy() if isinstance(test_labels, tf.Tensor) else test_labels

# Accuracy of the TFLite model, evaluated one sample (batch of 1) at a time
# and compared against the argmax of the one-hot label.
total = len(test_data)
correct = sum(
    np.argmax(tflite_predict(test_data[i:i + 1].astype(np.float32))) == np.argmax(test_labels[i])
    for i in range(total)
)
accuracy = correct / total

print(f"Quantized model accuracy: {accuracy:.4f}")
print(f"FP32 pretrained model accuracy: {test_acc:.4f}")