In [8]:
import tensorflow as tf
import os
import logging

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
logger = logging.getLogger(__name__)

def print_tflite_info(tflite_model_path):
    """
    Print all essential information about a TensorFlow Lite (.tflite) model for future use in Java/Android Studio.
    
    Args:
        tflite_model_path (str): Path to the .tflite model file.
    """
    # Load TFLite model
    try:
        interpreter = tf.lite.Interpreter(model_path=tflite_model_path)
        interpreter.allocate_tensors()
        logger.info("TFLite model loaded successfully.")
    except Exception as e:
        logger.error(f"Failed to load TFLite model from {tflite_model_path}: {e}")
        return

    # Model file size
    try:
        model_size = os.path.getsize(tflite_model_path) / (1024 * 1024)  # Convert to MB
        logger.info(f"Model File Size: {model_size:.2f} MB")
    except Exception as e:
        logger.error(f"Failed to compute model size: {e}")

    # Input tensors
    logger.info("\nInput Tensors:")
    input_details = interpreter.get_input_details()
    for tensor in input_details:
        quant = tensor['quantization']
        quant_info = f"Quantization: scale={quant[0]}, zero_point={quant[1]}" if quant != (0.0, 0) else "No quantization"
        logger.info(f"Name: {tensor['name']}, Index: {tensor['index']}, Shape: {tensor['shape']}, Dtype: {tensor['dtype']}, {quant_info}")

    # Output tensors
    logger.info("\nOutput Tensors:")
    output_details = interpreter.get_output_details()
    for tensor in output_details:
        quant = tensor['quantization']
        quant_info = f"Quantization: scale={quant[0]}, zero_point={quant[1]}" if quant != (0.0, 0) else "No quantization"
        logger.info(f"Name: {tensor['name']}, Index: {tensor['index']}, Shape: {tensor['shape']}, Dtype: {tensor['dtype']}, {quant_info}")

    # All tensors
    logger.info("\nAll Tensors:")
    all_tensors = interpreter.get_tensor_details()
    for tensor in all_tensors:
        quant = tensor['quantization']
        quant_info = f"Quantization: scale={quant[0]}, zero_point={quant[1]}" if quant != (0.0, 0) else "No quantization"
        logger.info(f"Name: {tensor['name']}, Index: {tensor['index']}, Shape: {tensor['shape']}, Dtype: {tensor['dtype']}, {quant_info}")

    # Operations
    logger.info("\nOperations:")
    try:
        ops_details = interpreter._get_ops_details()
        for op in ops_details:
            logger.info(f"Type: {op['op_name']}, Inputs: {op['inputs']}, Outputs: {op['outputs']}")
    except Exception as e:
        logger.error(f"Failed to retrieve operations: {e}")

    # Metadata
    logger.info("\nModel Metadata:")
    try:
        from tflite_support import metadata as _metadata
        displayer = _metadata.MetadataDisplayer.with_model_file(tflite_model_path)
        metadata_json = displayer.get_metadata_json()
        if metadata_json:
            logger.info(f"Metadata: {metadata_json}")
        else:
            logger.info("No metadata found in the model.")
    except ImportError:
        logger.warning("tflite_support module not found. Metadata extraction skipped.")
    except Exception as e:
        logger.error(f"Failed to extract metadata: {e}")

    # Signatures
    logger.info("\nSignatures:")
    signatures = interpreter.get_signature_list()
    if signatures:
        for sig_name, sig_details in signatures.items():
            logger.info(f"Signature Name: {sig_name}, Details: {sig_details}")
    else:
        logger.info("No signatures found in the model.")

# Example usage
if __name__ == "__main__":
    tflite_model_path = "student_model_31.tflite"  # Replace with your .tflite file path
    print_tflite_info(tflite_model_path)

2025-04-30 16:50:45,316 - INFO - TFLite model loaded successfully.
2025-04-30 16:50:45,318 - INFO - Model File Size: 0.09 MB
2025-04-30 16:50:45,319 - INFO - 
Input Tensors:
2025-04-30 16:50:45,320 - INFO - Name: inputs, Index: 0, Shape: [ 1 64  3], Dtype: <class 'numpy.float32'>, No quantization
2025-04-30 16:50:45,320 - INFO - 
Output Tensors:
2025-04-30 16:50:45,322 - INFO - Name: Identity, Index: 171, Shape: [1 1], Dtype: <class 'numpy.float32'>, No quantization
2025-04-30 16:50:45,323 - INFO - 
All Tensors:
2025-04-30 16:50:45,326 - INFO - Name: inputs, Index: 0, Shape: [ 1 64  3], Dtype: <class 'numpy.float32'>, No quantization
2025-04-30 16:50:45,327 - INFO - Name: arith.constant, Index: 1, Shape: [32], Dtype: <class 'numpy.float32'>, No quantization
2025-04-30 16:50:45,328 - INFO - Name: arith.constant1, Index: 2, Shape: [32 64], Dtype: <class 'numpy.float32'>, No quantization
2025-04-30 16:50:45,329 - INFO - Name: arith.constant2, Index: 3, Shape: [64 32], Dtype: <class 'numpy

In [12]:

    tflite_model_path = "time_series_transformer.tflite"  # Replace with your .tflite file path
    print_tflite_info(tflite_model_path)

2025-04-30 16:50:58,591 - INFO - TFLite model loaded successfully.
2025-04-30 16:50:58,596 - INFO - Model File Size: 0.10 MB
2025-04-30 16:50:58,597 - INFO - 
Input Tensors:
2025-04-30 16:50:58,598 - INFO - Name: serving_default_args_0:0, Index: 0, Shape: [  1   3 128], Dtype: <class 'numpy.float32'>, No quantization
2025-04-30 16:50:58,600 - INFO - 
Output Tensors:
2025-04-30 16:50:58,601 - INFO - Name: StatefulPartitionedCall:0, Index: 197, Shape: [1 1], Dtype: <class 'numpy.float32'>, No quantization
2025-04-30 16:50:58,602 - INFO - 
All Tensors:
2025-04-30 16:50:58,607 - INFO - Name: serving_default_args_0:0, Index: 0, Shape: [  1   3 128], Dtype: <class 'numpy.float32'>, No quantization
2025-04-30 16:50:58,609 - INFO - Name: arith.constant, Index: 1, Shape: [ 1 32], Dtype: <class 'numpy.float32'>, No quantization
2025-04-30 16:50:58,610 - INFO - Name: arith.constant1, Index: 2, Shape: [1], Dtype: <class 'numpy.float32'>, No quantization
2025-04-30 16:50:58,611 - INFO - Name: arith

In [11]:


    print_tflite_info("watchTousifKd.tflite")

2025-04-30 16:50:49,865 - INFO - TFLite model loaded successfully.
2025-04-30 16:50:49,868 - INFO - Model File Size: 0.14 MB
2025-04-30 16:50:49,870 - INFO - 
Input Tensors:
2025-04-30 16:50:49,873 - INFO - Name: serving_default_args_0:0, Index: 0, Shape: [  1 128   4], Dtype: <class 'numpy.float32'>, No quantization
2025-04-30 16:50:49,874 - INFO - 
Output Tensors:
2025-04-30 16:50:49,876 - INFO - Name: StatefulPartitionedCall:0, Index: 272, Shape: [1 1], Dtype: <class 'numpy.float32'>, No quantization
2025-04-30 16:50:49,878 - INFO - 
All Tensors:
2025-04-30 16:50:49,881 - INFO - Name: serving_default_args_0:0, Index: 0, Shape: [  1 128   4], Dtype: <class 'numpy.float32'>, No quantization
2025-04-30 16:50:49,882 - INFO - Name: arith.constant, Index: 1, Shape: [1], Dtype: <class 'numpy.float32'>, No quantization
2025-04-30 16:50:49,883 - INFO - Name: arith.constant1, Index: 2, Shape: [32], Dtype: <class 'numpy.float32'>, No quantization
2025-04-30 16:50:49,884 - INFO - Name: arith.co

In [15]:
import tensorflow as tf
import numpy as np
import os
import logging
import traceback
import shutil
from sklearn.metrics import mean_squared_error
from keras.saving import register_keras_serializable

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
logger = logging.getLogger(__name__)

# Define and register the custom TransModel class
@register_keras_serializable(package="CustomModels")
class TransModel(tf.keras.Model):
    def __init__(self, acc_frames=64, num_classes=1, num_heads=4, acc_coords=3, embed_dim=32, num_layers=2, dropout=0.5, activation='relu', **kwargs):
        super().__init__(**kwargs)
        self.acc_frames = acc_frames
        self.num_classes = num_classes
        self.num_heads = num_heads
        self.acc_coords = acc_coords
        self.embed_dim = embed_dim
        self.num_layers = num_layers
        self.dropout_rate = dropout
        self.activation = activation
        
        # Define layers
        self.conv_layer = tf.keras.layers.Conv2D(filters=embed_dim, kernel_size=(8, 1), padding='same', name="conv_projection")
        self.layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-6, name="layer_norm")
        self.attention_layers = [
            tf.keras.layers.MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim // num_heads, dropout=dropout, name=f"mha_{i}")
            for i in range(num_layers)
        ]
        self.ffn_layers = [
            tf.keras.Sequential([
                tf.keras.layers.Dense(embed_dim * 2, activation=activation, name=f"ffn_dense1_{i}"),
                tf.keras.layers.Dropout(dropout),
                tf.keras.layers.Dense(embed_dim, name=f"ffn_dense2_{i}"),
                tf.keras.layers.Dropout(dropout)
            ], name=f"ffn_{i}")
            for i in range(num_layers)
        ]
        self.layer_norms = [
            [tf.keras.layers.LayerNormalization(epsilon=1e-6, name=f"ln{i}_{j}") for j in range(2)]
            for i in range(num_layers)
        ]
        self.final_norm = tf.keras.layers.LayerNormalization(epsilon=1e-6, name="final_norm")
        self.global_pool = tf.keras.layers.GlobalAveragePooling1D(name="global_pool")
        self.output_dense = tf.keras.layers.Dense(num_classes, name="output_dense")
        
        # Initialize with a dummy input to ensure the model is built
        dummy_input = tf.zeros((1, acc_frames, acc_coords), dtype=tf.float32)
        self(dummy_input, training=False)
        
    def call(self, inputs, training=False):
        x = inputs.get('accelerometer', inputs) if isinstance(inputs, dict) else inputs
        x = tf.expand_dims(x, axis=2)  # Add channel dimension for Conv2D
        x = self.conv_layer(x)
        x = tf.squeeze(x, axis=2)  # Remove channel dimension
        x = self.layer_norm(x)
        
        # Transformer layers
        for i in range(self.num_layers):
            attn = self.attention_layers[i](x, x, training=training)
            x = self.layer_norms[i][0](x + attn)  # Residual + normalization
            ffn = self.ffn_layers[i](x, training=training)
            x = self.layer_norms[i][1](x + ffn)  # Residual + normalization
            
        x = self.final_norm(x)
        x = self.global_pool(x)
        logits = self.output_dense(x)
        return tf.reshape(logits, [-1, self.num_classes])
    
    # Add get_config method for proper serialization
    def get_config(self):
        config = super().get_config()
        config.update({
            "acc_frames": self.acc_frames,
            "num_classes": self.num_classes,
            "num_heads": self.num_heads,
            "acc_coords": self.acc_coords,
            "embed_dim": self.embed_dim,
            "num_layers": self.num_layers,
            "dropout_rate": self.dropout_rate,
            "activation": self.activation
        })
        return config
        
    # Add TFLite export functionality
    def export_to_tflite(self, save_path, input_shape=None):
        """
        Export the model to TFLite format
        
        Args:
            save_path (str): Path to save the TFLite model
            input_shape (tuple, optional): Input shape for the model. Defaults to (1, acc_frames, acc_coords).
        
        Returns:
            bool: True if export was successful, False otherwise
        """
        if input_shape is None:
            input_shape = (1, self.acc_frames, self.acc_coords)
            
        try:
            os.makedirs(os.path.dirname(save_path) or '.', exist_ok=True)
            save_path = save_path if save_path.endswith('.tflite') else f"{save_path}.tflite"
            logger.info(f"Exporting to TFLite: {save_path}, shape={input_shape}")

            # Define a wrapper model to ensure proper input signature
            class TFLiteModel(tf.keras.Model):
                def __init__(self, parent):
                    super().__init__()
                    self.parent = parent

                @tf.function(input_signature=[tf.TensorSpec(shape=input_shape, dtype=tf.float32, name='accelerometer')])
                def call(self, inputs):
                    return self.parent({'accelerometer': inputs}, training=False)

            tflite_model = TFLiteModel(self)
            
            # Create a temporary directory for the SavedModel
            temp_dir = os.path.join(os.path.dirname(save_path) or '.', "temp_savedmodel")
            if os.path.exists(temp_dir):
                shutil.rmtree(temp_dir)
            os.makedirs(temp_dir, exist_ok=True)

            # Save the model with signatures
            tf.saved_model.save(
                tflite_model, 
                temp_dir, 
                signatures={'serving_default': tflite_model.call}
            )
            
            # Convert to TFLite
            converter = tf.lite.TFLiteConverter.from_saved_model(temp_dir)
            converter.target_spec.supported_ops = [
                tf.lite.OpsSet.TFLITE_BUILTINS, 
                tf.lite.OpsSet.SELECT_TF_OPS
            ]
            converter.inference_input_type = tf.float32
            converter.inference_output_type = tf.float32
            tflite_content = converter.convert()

            # Save TFLite model
            with open(save_path, 'wb') as f:
                f.write(tflite_content)
                
            # Clean up temporary directory
            shutil.rmtree(temp_dir)
            logger.info(f"TFLite model successfully saved: {save_path}")
            return True
        except Exception as e:
            logger.error(f"TFLite export failed: {e}\n{traceback.format_exc()}")
            if 'temp_dir' in locals() and os.path.exists(temp_dir):
                shutil.rmtree(temp_dir)
            return False


def get_keras_layer_info(model):
    """Extract layer information from a Keras model"""
    try:
        layers_info = []
        for layer in model.layers:
            layer_info = {
                'name': layer.name,
                'type': layer.__class__.__name__,
                'trainable': layer.trainable,
                'params': layer.count_params()
            }
            
            # Try to extract shape information if available
            try:
                if hasattr(layer, 'output_shape'):
                    layer_info['output_shape'] = str(layer.output_shape)
                if hasattr(layer, 'input_shape'):
                    layer_info['input_shape'] = str(layer.input_shape)
            except:
                pass
                
            # If it's a Sequential model, get info about its layers
            if isinstance(layer, tf.keras.Sequential):
                sublayers = []
                for sublayer in layer.layers:
                    sublayer_info = {
                        'name': sublayer.name,
                        'type': sublayer.__class__.__name__,
                        'trainable': sublayer.trainable,
                        'params': sublayer.count_params()
                    }
                    sublayers.append(sublayer_info)
                layer_info['sublayers'] = sublayers
                
            layers_info.append(layer_info)
        return layers_info
    except Exception as e:
        logger.error(f"Failed to extract Keras layer info: {e}")
        return []


def get_tflite_op_info(interpreter):
    """Extract operation information from a TFLite interpreter"""
    try:
        ops_info = []
        
        # Get information about ops from the interpreter
        with open(interpreter._model_path, 'rb') as f:
            model_data = f.read()
        
        # Get input and output details
        input_details = interpreter.get_input_details()
        output_details = interpreter.get_output_details()
        
        # Get operator codes and subgraphs from the model
        try:
            # Get basic tensor details
            tensor_details = interpreter._get_tensor_details()
            for i, tensor in enumerate(tensor_details):
                is_input = any(tensor['index'] == det['index'] for det in input_details)
                is_output = any(tensor['index'] == det['index'] for det in output_details)
                
                op_info = {
                    'tensor_index': tensor['index'],
                    'name': tensor['name'],
                    'shape': str(tensor['shape']),
                    'dtype': str(tensor['dtype']),
                    'is_input': is_input,
                    'is_output': is_output
                }
                ops_info.append(op_info)
        except Exception as e:
            logger.warning(f"Couldn't extract detailed TFLite tensor info: {e}")
            
            # Fallback to basic input/output info
            for i, details in enumerate(input_details):
                ops_info.append({
                    'tensor_index': details['index'],
                    'name': details.get('name', f'input_{i}'),
                    'shape': str(details['shape']),
                    'dtype': str(details['dtype']),
                    'is_input': True,
                    'is_output': False
                })
                
            for i, details in enumerate(output_details):
                ops_info.append({
                    'tensor_index': details['index'],
                    'name': details.get('name', f'output_{i}'),
                    'shape': str(details['shape']),
                    'dtype': str(details['dtype']),
                    'is_input': False,
                    'is_output': True
                })
                
        return ops_info
    except Exception as e:
        logger.error(f"Failed to extract TFLite op info: {e}")
        return []


def compare_keras_tflite(keras_model_path, tflite_model_path, weights_path=None, sample_input=None, model_architecture=None):
    """
    Compare a Keras TransModel with its TFLite version for robustness and compatibility.
    
    Args:
        keras_model_path (str): Path to the .keras model file.
        tflite_model_path (str): Path to the .tflite model file (or where it will be saved if converted).
        weights_path (str, optional): Path to the .weights.h5 file if .keras is unavailable.
        sample_input (np.ndarray, optional): Sample input for inference comparison.
        model_architecture (callable, optional): Function to define model architecture if using .weights.h5.
    """
    keras_model = None
    
    # Load or create Keras model
    try:
        if os.path.exists(keras_model_path):
            logger.info(f"Loading Keras model from: {keras_model_path}")
            keras_model = tf.keras.models.load_model(
                keras_model_path, 
                custom_objects={'TransModel': TransModel}
            )
            logger.info("Keras model loaded from .keras file.")
        elif weights_path and model_architecture and os.path.exists(weights_path):
            logger.info(f"Creating model and loading weights from: {weights_path}")
            keras_model = model_architecture()
            keras_model.load_weights(weights_path)
            logger.info("Keras model loaded from .weights.h5 with provided architecture.")
        else:
            logger.info("Creating new model instance")
            keras_model = model_architecture() if model_architecture else TransModel()
            logger.info("New TransModel instance created.")
    except Exception as e:
        logger.error(f"Failed to load/create Keras model: {e}")
        return
    
    # Ensure model is built
    try:
        dummy_input = np.zeros((1, 64, 3), dtype=np.float32)
        _ = keras_model(dummy_input)
        logger.info("Keras model built successfully with dummy input.")
    except Exception as e:
        logger.error(f"Failed to build model: {e}")
    
    # Log Keras model details
    logger.info("\n=== Keras Model Details ===")
    
    # Get shapes by inference instead of direct attribute access
    try:
        dummy_input = np.zeros((1, 64, 3), dtype=np.float32)
        test_output = keras_model.predict(dummy_input, verbose=0)
        keras_input_shape = [None, 64, 3]  # Input shape is known from our dummy input pattern
        keras_output_shape = list(test_output.shape)
        logger.info(f"Keras Input Shape (inferred): {keras_input_shape}")
        logger.info(f"Keras Output Shape (inferred): {keras_output_shape}")
    except Exception as e:
        logger.error(f"Failed to infer shapes: {e}")
    
    # Extract keras layer information
    keras_layers = get_keras_layer_info(keras_model)
    
    # Convert to TFLite if necessary
    if not os.path.exists(tflite_model_path):
        try:
            logger.info(f"Converting model to TFLite: {tflite_model_path}")
            # Try to use the built-in export method if it's a TransModel instance
            if isinstance(keras_model, TransModel):
                logger.info("Using TransModel's built-in TFLite export method")
                success = keras_model.export_to_tflite(tflite_model_path)
                if not success:
                    raise ValueError("Export failed using TransModel's export_to_tflite method")
            else:
                # Fallback to generic conversion
                logger.info("Using generic TFLite conversion")
                @tf.function(input_signature=[tf.TensorSpec(shape=(None, 64, 3), dtype=tf.float32)])
                def model_func(inputs):
                    return keras_model(inputs, training=False)
                concrete_func = model_func.get_concrete_function()
                converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
                tflite_model = converter.convert()
                with open(tflite_model_path, 'wb') as f:
                    f.write(tflite_model)
            logger.info(f"Converted Keras model to TFLite and saved to {tflite_model_path}.")
        except Exception as e:
            logger.error(f"Failed to convert Keras model to TFLite: {e}")
            return
    else:
        logger.info(f"Using existing TFLite model at {tflite_model_path}.")
    
    # Load TFLite model
    try:
        logger.info(f"Loading TFLite model from {tflite_model_path}")
        interpreter = tf.lite.Interpreter(model_path=tflite_model_path)
        interpreter.allocate_tensors()
        logger.info("TFLite model loaded successfully.")
    except Exception as e:
        logger.error(f"Failed to load TFLite model: {e}")
        return
    
    # Get TFLite model details
    logger.info("\n=== TFLite Model Details ===")
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()
    tflite_input_shape = input_details[0]['shape']
    tflite_output_shape = output_details[0]['shape']
    logger.info(f"TFLite Input Shape: {tflite_input_shape}")
    logger.info(f"TFLite Output Shape: {tflite_output_shape}")
    tflite_input_dtype = input_details[0]['dtype']
    logger.info(f"TFLite Input Data Type: {tflite_input_dtype}")
    
    # Extract TFLite operation information
    tflite_ops = get_tflite_op_info(interpreter)
    
    # Compare and display layer information
    logger.info("\n=== Layer/Operator Comparison ===")
    logger.info(f"Keras model has {len(keras_layers)} main layers")
    logger.info(f"TFLite model has {len(tflite_ops)} tensors/operators")
    
    # Display Keras layers
    logger.info("\nKeras Layers:")
    for i, layer in enumerate(keras_layers):
        logger.info(f"  {i+1}. {layer['name']} ({layer['type']}) - Params: {layer['params']}")
        if 'input_shape' in layer:
            logger.info(f"     Input shape: {layer['input_shape']}")
        if 'output_shape' in layer:
            logger.info(f"     Output shape: {layer['output_shape']}")
        if 'sublayers' in layer:
            for j, sublayer in enumerate(layer['sublayers']):
                logger.info(f"       {j+1}. {sublayer['name']} ({sublayer['type']}) - Params: {sublayer['params']}")
    
    # Display TFLite operators
    logger.info("\nTFLite Tensors/Operators:")
    for i, op in enumerate(tflite_ops):
        role = "INPUT" if op['is_input'] else "OUTPUT" if op['is_output'] else "INTERMEDIATE"
        logger.info(f"  {i+1}. {op['name']} ({role}) - Shape: {op['shape']}, Type: {op['dtype']}")
    
    # Compare shapes
    logger.info("\n=== Shape Comparison ===")
    try:
        if 'keras_input_shape' in locals() and 'tflite_input_shape' in locals():
            # Compare all except batch dimension
            if keras_input_shape[1:] == list(tflite_input_shape)[1:]:
                logger.info("✓ Input shapes match (ignoring batch dimension).")
            else:
                logger.warning(f"✗ Input shapes differ: Keras {keras_input_shape} vs TFLite {tflite_input_shape}")
        else:
            logger.info(f"⚠ Full shape comparison unavailable - displaying separately:")
            logger.info(f"  Keras input (inferred): {locals().get('keras_input_shape', 'Unknown')}")
            logger.info(f"  TFLite input: {locals().get('tflite_input_shape', 'Unknown')}")
        
        if 'keras_output_shape' in locals() and 'tflite_output_shape' in locals():
            if keras_output_shape == list(tflite_output_shape):
                logger.info("✓ Output shapes match.")
            else:
                logger.warning(f"✗ Output shapes differ: Keras {keras_output_shape} vs TFLite {tflite_output_shape}")
        else:
            logger.info(f"⚠ Output shape comparison unavailable - displaying separately:")
            logger.info(f"  Keras output (inferred): {locals().get('keras_output_shape', 'Unknown')}")
            logger.info(f"  TFLite output: {locals().get('tflite_output_shape', 'Unknown')}")
    except Exception as e:
        logger.error(f"Error during shape comparison: {e}")
    
    # Generate sample input if not provided
    if sample_input is None:
        sample_input = np.random.randn(1, 64, 3).astype(np.float32)
        logger.info("Generated random input for inference comparison.")
    
    # Perform inference and compare outputs
    logger.info("\n=== Inference Comparison ===")
    try:
        logger.info("Running Keras model inference...")
        keras_output = keras_model.predict(sample_input, verbose=0)
        
        logger.info("Running TFLite model inference...")
        interpreter.set_tensor(input_details[0]['index'], sample_input)
        interpreter.invoke()
        tflite_output = interpreter.get_tensor(output_details[0]['index'])
        
        logger.info("Comparing outputs...")
        mse = mean_squared_error(keras_output.flatten(), tflite_output.flatten())
        logger.info(f"Mean Squared Error between outputs: {mse:.10f}")
        
        if mse < 1e-5:
            logger.info("✓ Outputs are approximately equal (MSE < 1e-5).")
        else:
            logger.warning(f"✗ Outputs differ significantly (MSE = {mse:.10f}).")
            
        # Display sample values from both outputs
        logger.info("\nSample output values:")
        keras_flat = keras_output.flatten()
        tflite_flat = tflite_output.flatten()
        num_samples = min(5, len(keras_flat))
        for i in range(num_samples):
            logger.info(f"  Value {i+1}: Keras={keras_flat[i]:.6f}, TFLite={tflite_flat[i]:.6f}, Diff={abs(keras_flat[i]-tflite_flat[i]):.6f}")
    except Exception as e:
        logger.error(f"Failed to compare outputs: {e}")
    
    # Compare model sizes
    logger.info("\n=== Model Size Comparison ===")
    try:
        keras_size = os.path.getsize(keras_model_path) / (1024 * 1024) if os.path.exists(keras_model_path) else 0
        weights_size = os.path.getsize(weights_path) / (1024 * 1024) if weights_path and os.path.exists(weights_path) else 0
        tflite_size = os.path.getsize(tflite_model_path) / (1024 * 1024)
        
        if keras_size:
            logger.info(f"Keras model size (.keras): {keras_size:.2f} MB")
        if weights_size:
            logger.info(f"Weights size (.weights.h5): {weights_size:.2f} MB")
        logger.info(f"TFLite model size: {tflite_size:.2f} MB")
        
        if keras_size and tflite_size:
            reduction = (1 - tflite_size/keras_size) * 100
            logger.info(f"Size reduction: {reduction:.2f}%")
    except Exception as e:
        logger.error(f"Failed to compute model sizes: {e}")
    
    logger.info("\n=== Compatibility Check ===")
    logger.info("✓ Keras model can be loaded and executed")
    logger.info("✓ TFLite model can be loaded and executed")
    logger.info(f"✓ Inference results {'match' if mse < 1e-5 else 'differ'} between implementations")
    
    logger.info("\n=== Comparison Complete ===")
    return keras_model, interpreter


# Example usage
if __name__ == "__main__":
    keras_model_path = "student_model_31.keras"
    tflite_model_path = "student_model_31.tflite"
    weights_path = "student_model_31.weights.h5"
    
    # Define a function to create a new instance of TransModel
    def create_trans_model():
        return TransModel()
    
    # Run comparison
    compare_keras_tflite(
        keras_model_path, 
        tflite_model_path, 
        weights_path=weights_path, 
        model_architecture=create_trans_model
    )

2025-04-30 16:59:00,451 - INFO - Loading Keras model from: student_model_31.keras
2025-04-30 16:59:01,013 - INFO - Keras model loaded from .keras file.
2025-04-30 16:59:01,055 - INFO - Keras model built successfully with dummy input.
2025-04-30 16:59:01,055 - INFO - 
=== Keras Model Details ===




2025-04-30 16:59:01,503 - INFO - Keras Input Shape (inferred): [None, 64, 3]
2025-04-30 16:59:01,504 - INFO - Keras Output Shape (inferred): [1, 1]
2025-04-30 16:59:01,505 - INFO - Using existing TFLite model at student_model_31.tflite.
2025-04-30 16:59:01,506 - INFO - Loading TFLite model from student_model_31.tflite
2025-04-30 16:59:01,508 - INFO - TFLite model loaded successfully.
2025-04-30 16:59:01,509 - INFO - 
=== TFLite Model Details ===
2025-04-30 16:59:01,510 - INFO - TFLite Input Shape: [ 1 64  3]
2025-04-30 16:59:01,512 - INFO - TFLite Output Shape: [1 1]
2025-04-30 16:59:01,513 - INFO - TFLite Input Data Type: <class 'numpy.float32'>
2025-04-30 16:59:01,515 - ERROR - Failed to extract TFLite op info: 'Interpreter' object has no attribute '_model_path'
2025-04-30 16:59:01,516 - INFO - 
=== Layer/Operator Comparison ===
2025-04-30 16:59:01,516 - INFO - Keras model has 13 main layers
2025-04-30 16:59:01,517 - INFO - TFLite model has 0 tensors/operators
2025-04-30 16:59:01,518

In [17]:
import tensorflow as tf
import numpy as np
import os
import logging
import traceback
import shutil
from sklearn.metrics import mean_squared_error
from keras.saving import register_keras_serializable

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
logger = logging.getLogger(__name__)

# Define and register the custom TransModel class
@register_keras_serializable(package="CustomModels")
class TransModel(tf.keras.Model):
    def __init__(self, acc_frames=64, num_classes=1, num_heads=4, acc_coords=3, embed_dim=32, num_layers=2, dropout=0.5, activation='relu', **kwargs):
        super().__init__(**kwargs)
        self.acc_frames = acc_frames
        self.num_classes = num_classes
        self.num_heads = num_heads
        self.acc_coords = acc_coords
        self.embed_dim = embed_dim
        self.num_layers = num_layers
        self.dropout_rate = dropout
        self.activation = activation
        
        # Define layers
        self.conv_layer = tf.keras.layers.Conv2D(filters=embed_dim, kernel_size=(8, 1), padding='same', name="conv_projection")
        self.layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-6, name="layer_norm")
        self.attention_layers = [
            tf.keras.layers.MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim // num_heads, dropout=dropout, name=f"mha_{i}")
            for i in range(num_layers)
        ]
        self.ffn_layers = [
            tf.keras.Sequential([
                tf.keras.layers.Dense(embed_dim * 2, activation=activation, name=f"ffn_dense1_{i}"),
                tf.keras.layers.Dropout(dropout),
                tf.keras.layers.Dense(embed_dim, name=f"ffn_dense2_{i}"),
                tf.keras.layers.Dropout(dropout)
            ], name=f"ffn_{i}")
            for i in range(num_layers)
        ]
        self.layer_norms = [
            [tf.keras.layers.LayerNormalization(epsilon=1e-6, name=f"ln{i}_{j}") for j in range(2)]
            for i in range(num_layers)
        ]
        self.final_norm = tf.keras.layers.LayerNormalization(epsilon=1e-6, name="final_norm")
        self.global_pool = tf.keras.layers.GlobalAveragePooling1D(name="global_pool")
        self.output_dense = tf.keras.layers.Dense(num_classes, name="output_dense")
        
        # Initialize with a dummy input to ensure the model is built
        dummy_input = tf.zeros((1, acc_frames, acc_coords), dtype=tf.float32)
        self(dummy_input, training=False)
        
    def call(self, inputs, training=False):
        x = inputs.get('accelerometer', inputs) if isinstance(inputs, dict) else inputs
        x = tf.expand_dims(x, axis=2)  # Add channel dimension for Conv2D
        x = self.conv_layer(x)
        x = tf.squeeze(x, axis=2)  # Remove channel dimension
        x = self.layer_norm(x)
        
        # Transformer layers
        for i in range(self.num_layers):
            attn = self.attention_layers[i](x, x, training=training)
            x = self.layer_norms[i][0](x + attn)  # Residual + normalization
            ffn = self.ffn_layers[i](x, training=training)
            x = self.layer_norms[i][1](x + ffn)  # Residual + normalization
            
        x = self.final_norm(x)
        x = self.global_pool(x)
        logits = self.output_dense(x)
        return tf.reshape(logits, [-1, self.num_classes])
    
    # Add get_config method for proper serialization
    def get_config(self):
        config = super().get_config()
        config.update({
            "acc_frames": self.acc_frames,
            "num_classes": self.num_classes,
            "num_heads": self.num_heads,
            "acc_coords": self.acc_coords,
            "embed_dim": self.embed_dim,
            "num_layers": self.num_layers,
            "dropout_rate": self.dropout_rate,
            "activation": self.activation
        })
        return config
        
    # Add TFLite export functionality
    def export_to_tflite(self, save_path, input_shape=None):
        """
        Export the model to TFLite format
        
        Args:
            save_path (str): Path to save the TFLite model
            input_shape (tuple, optional): Input shape for the model. Defaults to (1, acc_frames, acc_coords).
        
        Returns:
            bool: True if export was successful, False otherwise
        """
        if input_shape is None:
            input_shape = (1, self.acc_frames, self.acc_coords)
            
        try:
            os.makedirs(os.path.dirname(save_path) or '.', exist_ok=True)
            save_path = save_path if save_path.endswith('.tflite') else f"{save_path}.tflite"
            logger.info(f"Exporting to TFLite: {save_path}, shape={input_shape}")

            # Define a wrapper model to ensure proper input signature
            class TFLiteModel(tf.keras.Model):
                def __init__(self, parent):
                    super().__init__()
                    self.parent = parent

                @tf.function(input_signature=[tf.TensorSpec(shape=input_shape, dtype=tf.float32, name='accelerometer')])
                def call(self, inputs):
                    return self.parent({'accelerometer': inputs}, training=False)

            tflite_model = TFLiteModel(self)
            
            # Create a temporary directory for the SavedModel
            temp_dir = os.path.join(os.path.dirname(save_path) or '.', "temp_savedmodel")
            if os.path.exists(temp_dir):
                shutil.rmtree(temp_dir)
            os.makedirs(temp_dir, exist_ok=True)

            # Save the model with signatures
            tf.saved_model.save(
                tflite_model, 
                temp_dir, 
                signatures={'serving_default': tflite_model.call}
            )
            
            # Convert to TFLite
            converter = tf.lite.TFLiteConverter.from_saved_model(temp_dir)
            converter.target_spec.supported_ops = [
                tf.lite.OpsSet.TFLITE_BUILTINS, 
                tf.lite.OpsSet.SELECT_TF_OPS
            ]
            converter.inference_input_type = tf.float32
            converter.inference_output_type = tf.float32
            tflite_content = converter.convert()

            # Save TFLite model
            with open(save_path, 'wb') as f:
                f.write(tflite_content)
                
            # Clean up temporary directory
            shutil.rmtree(temp_dir)
            logger.info(f"TFLite model successfully saved: {save_path}")
            return True
        except Exception as e:
            logger.error(f"TFLite export failed: {e}\n{traceback.format_exc()}")
            if 'temp_dir' in locals() and os.path.exists(temp_dir):
                shutil.rmtree(temp_dir)
            return False


def get_keras_layer_info(model):
    """Extract layer information from a Keras model"""
    try:
        layers_info = []
        for layer in model.layers:
            layer_info = {
                'name': layer.name,
                'type': layer.__class__.__name__,
                'trainable': layer.trainable,
                'params': layer.count_params()
            }
            
            # Try to extract shape information if available
            try:
                if hasattr(layer, 'output_shape'):
                    layer_info['output_shape'] = str(layer.output_shape)
                if hasattr(layer, 'input_shape'):
                    layer_info['input_shape'] = str(layer.input_shape)
            except:
                pass
                
            # If it's a Sequential model, get info about its layers
            if isinstance(layer, tf.keras.Sequential):
                sublayers = []
                for sublayer in layer.layers:
                    sublayer_info = {
                        'name': sublayer.name,
                        'type': sublayer.__class__.__name__,
                        'trainable': sublayer.trainable,
                        'params': sublayer.count_params()
                    }
                    sublayers.append(sublayer_info)
                layer_info['sublayers'] = sublayers
                
            layers_info.append(layer_info)
        return layers_info
    except Exception as e:
        logger.error(f"Failed to extract Keras layer info: {e}")
        return []


def get_tflite_tensor_info(interpreter):
    """Extract tensor information from a TFLite interpreter"""
    try:
        # Get tensor details from the interpreter
        tensor_details = interpreter.get_tensor_details()
        input_details = interpreter.get_input_details()
        output_details = interpreter.get_output_details()
        
        # Organize tensors by their role
        tensors_info = []
        for tensor in tensor_details:
            tensor_idx = tensor['index']
            tensor_info = {
                'index': tensor_idx,
                'name': tensor['name'],
                'shape': str(tensor['shape']),
                'dtype': str(tensor['dtype']),
                'quantization': str(tensor.get('quantization', 'None')),
                'is_input': any(detail['index'] == tensor_idx for detail in input_details),
                'is_output': any(detail['index'] == tensor_idx for detail in output_details),
                'role': 'INTERMEDIATE'
            }
            
            if tensor_info['is_input']:
                tensor_info['role'] = 'INPUT'
            elif tensor_info['is_output']:
                tensor_info['role'] = 'OUTPUT'
                
            tensors_info.append(tensor_info)
            
        return tensors_info
    except Exception as e:
        logger.error(f"Failed to extract TFLite tensor info: {e}")
        return []


def compare_keras_tflite(keras_model_path, tflite_model_path, weights_path=None, sample_input=None, model_architecture=None):
    """
    Compare a Keras TransModel with its TFLite version for robustness and compatibility.
    
    Args:
        keras_model_path (str): Path to the .keras model file.
        tflite_model_path (str): Path to the .tflite model file (or where it will be saved if converted).
        weights_path (str, optional): Path to the .weights.h5 file if .keras is unavailable.
        sample_input (np.ndarray, optional): Sample input for inference comparison.
        model_architecture (callable, optional): Function to define model architecture if using .weights.h5.
    """
    keras_model = None
    
    # Load or create Keras model
    try:
        if os.path.exists(keras_model_path):
            logger.info(f"Loading Keras model from: {keras_model_path}")
            keras_model = tf.keras.models.load_model(
                keras_model_path, 
                custom_objects={'TransModel': TransModel}
            )
            logger.info("Keras model loaded from .keras file.")
        elif weights_path and model_architecture and os.path.exists(weights_path):
            logger.info(f"Creating model and loading weights from: {weights_path}")
            keras_model = model_architecture()
            keras_model.load_weights(weights_path)
            logger.info("Keras model loaded from .weights.h5 with provided architecture.")
        else:
            logger.info("Creating new model instance")
            keras_model = model_architecture() if model_architecture else TransModel()
            logger.info("New TransModel instance created.")
    except Exception as e:
        logger.error(f"Failed to load/create Keras model: {e}")
        return
    
    # Ensure model is built
    try:
        dummy_input = np.zeros((1, 64, 3), dtype=np.float32)
        _ = keras_model(dummy_input)
        logger.info("Keras model built successfully with dummy input.")
    except Exception as e:
        logger.error(f"Failed to build model: {e}")
    
    # Log Keras model details
    logger.info("\n=== Keras Model Details ===")
    
    # Variables to store shapes for comparison
    keras_input_shape_list = None
    keras_output_shape_list = None
    tflite_input_shape_list = None
    tflite_output_shape_list = None
    
    # Get shapes by inference instead of direct attribute access
    try:
        dummy_input = np.zeros((1, 64, 3), dtype=np.float32)
        test_output = keras_model.predict(dummy_input, verbose=0)
        keras_input_shape_list = [None, 64, 3]  # Input shape is known from our dummy input pattern
        keras_output_shape_list = list(test_output.shape)
        logger.info(f"Keras Input Shape (inferred): {keras_input_shape_list}")
        logger.info(f"Keras Output Shape (inferred): {keras_output_shape_list}")
    except Exception as e:
        logger.error(f"Failed to infer shapes: {e}")
    
    # Extract keras layer information
    keras_layers = get_keras_layer_info(keras_model)
    
    # Convert to TFLite if necessary
    if not os.path.exists(tflite_model_path):
        try:
            logger.info(f"Converting model to TFLite: {tflite_model_path}")
            # Try to use the built-in export method if it's a TransModel instance
            if isinstance(keras_model, TransModel):
                logger.info("Using TransModel's built-in TFLite export method")
                success = keras_model.export_to_tflite(tflite_model_path)
                if not success:
                    raise ValueError("Export failed using TransModel's export_to_tflite method")
            else:
                # Fallback to generic conversion
                logger.info("Using generic TFLite conversion")
                @tf.function(input_signature=[tf.TensorSpec(shape=(None, 64, 3), dtype=tf.float32)])
                def model_func(inputs):
                    return keras_model(inputs, training=False)
                concrete_func = model_func.get_concrete_function()
                converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
                tflite_model = converter.convert()
                with open(tflite_model_path, 'wb') as f:
                    f.write(tflite_model)
            logger.info(f"Converted Keras model to TFLite and saved to {tflite_model_path}.")
        except Exception as e:
            logger.error(f"Failed to convert Keras model to TFLite: {e}")
            return
    else:
        logger.info(f"Using existing TFLite model at {tflite_model_path}.")
    
    # Load TFLite model
    try:
        logger.info(f"Loading TFLite model from {tflite_model_path}")
        interpreter = tf.lite.Interpreter(model_path=tflite_model_path)
        interpreter.allocate_tensors()
        logger.info("TFLite model loaded successfully.")
    except Exception as e:
        logger.error(f"Failed to load TFLite model: {e}")
        return
    
    # Get TFLite model details
    logger.info("\n=== TFLite Model Details ===")
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()
    tflite_input_shape = input_details[0]['shape']
    tflite_output_shape = output_details[0]['shape']
    tflite_input_shape_list = list(tflite_input_shape)
    tflite_output_shape_list = list(tflite_output_shape)
    logger.info(f"TFLite Input Shape: {tflite_input_shape_list}")
    logger.info(f"TFLite Output Shape: {tflite_output_shape_list}")
    logger.info(f"TFLite Input Data Type: {input_details[0]['dtype']}")
    
    # Extract TFLite tensor information
    tflite_tensors = get_tflite_tensor_info(interpreter)
    
    # Compare and display layer information
    logger.info("\n=== Layer/Operator Comparison ===")
    logger.info(f"Keras model has {len(keras_layers)} main layers")
    logger.info(f"TFLite model has {len(tflite_tensors)} tensors")
    
    # Display Keras layers
    logger.info("\nKeras Layers:")
    for i, layer in enumerate(keras_layers):
        logger.info(f"  {i+1}. {layer['name']} ({layer['type']}) - Params: {layer['params']}")
        if 'input_shape' in layer:
            logger.info(f"     Input shape: {layer['input_shape']}")
        if 'output_shape' in layer:
            logger.info(f"     Output shape: {layer['output_shape']}")
        if 'sublayers' in layer:
            for j, sublayer in enumerate(layer['sublayers']):
                logger.info(f"       {j+1}. {sublayer['name']} ({sublayer['type']}) - Params: {sublayer['params']}")
    
    # Display TFLite tensors
    logger.info("\nTFLite Tensors:")
    
    # First show inputs and outputs
    input_tensors = [t for t in tflite_tensors if t['is_input']]
    output_tensors = [t for t in tflite_tensors if t['is_output']]
    intermediate_tensors = [t for t in tflite_tensors if not t['is_input'] and not t['is_output']]
    
    # Show inputs
    if input_tensors:
        logger.info("\n  Input Tensors:")
        for i, tensor in enumerate(input_tensors):
            logger.info(f"    {i+1}. {tensor['name']} - Shape: {tensor['shape']}, Type: {tensor['dtype']}")
    
    # Show outputs
    if output_tensors:
        logger.info("\n  Output Tensors:")
        for i, tensor in enumerate(output_tensors):
            logger.info(f"    {i+1}. {tensor['name']} - Shape: {tensor['shape']}, Type: {tensor['dtype']}")
    
    # Show intermediate tensors (operators)
    if intermediate_tensors:
        logger.info("\n  Intermediate Tensors (operators):")
        # Group by tensor name to identify patterns
        tensor_groups = {}
        
        for tensor in intermediate_tensors:
            name_parts = tensor['name'].split('/')
            base_name = name_parts[0] if len(name_parts) > 0 else tensor['name']
            
            if base_name not in tensor_groups:
                tensor_groups[base_name] = []
            tensor_groups[base_name].append(tensor)
        
        # Display tensor groups
        for group_name, tensors in tensor_groups.items():
            logger.info(f"\n    {group_name}:")
            for i, tensor in enumerate(tensors[:5]):  # Show only first 5 tensors per group
                logger.info(f"      {i+1}. {tensor['name']} - Shape: {tensor['shape']}")
            if len(tensors) > 5:
                logger.info(f"      ... and {len(tensors) - 5} more")
    
    # Display TFLite operator summary
    try:
        # Try to get number of operations in the model
        op_codes = set()
        for i, tensor in enumerate(tflite_tensors):
            name_parts = tensor['name'].split('/')
            if len(name_parts) > 1:
                op_type = name_parts[1].split(':')[0] if ':' in name_parts[1] else name_parts[1]
                op_codes.add(op_type)
        
        if op_codes:
            logger.info("\n  TFLite Operation Types:")
            for i, op_type in enumerate(sorted(op_codes)):
                logger.info(f"    {i+1}. {op_type}")
    except Exception as e:
        logger.warning(f"Could not extract TFLite operation types: {e}")
    
    # Compare shapes
    logger.info("\n=== Shape Comparison ===")
    try:
        # Check if we have both input shapes for comparison
        if keras_input_shape_list is not None and tflite_input_shape_list is not None:
            # Compare all except batch dimension
            keras_rest = keras_input_shape_list[1:]
            tflite_rest = tflite_input_shape_list[1:]
            
            shapes_match = True
            for i in range(len(keras_rest)):
                if i >= len(tflite_rest) or keras_rest[i] != tflite_rest[i]:
                    shapes_match = False
                    break
            
            if shapes_match:
                logger.info("✓ Input shapes match (ignoring batch dimension).")
            else:
                logger.warning(f"✗ Input shapes differ: Keras {keras_input_shape_list} vs TFLite {tflite_input_shape_list}")
        else:
            logger.info(f"⚠ Full shape comparison unavailable - displaying separately:")
            logger.info(f"  Keras input (inferred): {keras_input_shape_list}")
            logger.info(f"  TFLite input: {tflite_input_shape_list}")
        
        # Check if we have both output shapes for comparison
        if keras_output_shape_list is not None and tflite_output_shape_list is not None:
            output_shapes_match = True
            for i in range(len(keras_output_shape_list)):
                if i >= len(tflite_output_shape_list) or keras_output_shape_list[i] != tflite_output_shape_list[i]:
                    output_shapes_match = False
                    break
                    
            if output_shapes_match:
                logger.info("✓ Output shapes match.")
            else:
                logger.warning(f"✗ Output shapes differ: Keras {keras_output_shape_list} vs TFLite {tflite_output_shape_list}")
        else:
            logger.info(f"⚠ Output shape comparison unavailable - displaying separately:")
            logger.info(f"  Keras output (inferred): {keras_output_shape_list}")
            logger.info(f"  TFLite output: {tflite_output_shape_list}")
    except Exception as e:
        logger.error(f"Error during shape comparison: {e}")
    
    # Generate sample input if not provided
    if sample_input is None:
        sample_input = np.random.randn(1, 64, 3).astype(np.float32)
        logger.info("Generated random input for inference comparison.")
    
    # Perform inference and compare outputs
    logger.info("\n=== Inference Comparison ===")
    try:
        logger.info("Running Keras model inference...")
        keras_output = keras_model.predict(sample_input, verbose=0)
        
        logger.info("Running TFLite model inference...")
        interpreter.set_tensor(input_details[0]['index'], sample_input)
        interpreter.invoke()
        tflite_output = interpreter.get_tensor(output_details[0]['index'])
        
        logger.info("Comparing outputs...")
        mse = mean_squared_error(keras_output.flatten(), tflite_output.flatten())
        logger.info(f"Mean Squared Error between outputs: {mse:.10f}")
        
        # Store MSE for later use
        outputs_match = mse < 1e-5
        
        if outputs_match:
            logger.info("✓ Outputs are approximately equal (MSE < 1e-5).")
        else:
            logger.warning(f"✗ Outputs differ significantly (MSE = {mse:.10f}).")
            
        # Display sample values from both outputs
        logger.info("\nSample output values:")
        keras_flat = keras_output.flatten()
        tflite_flat = tflite_output.flatten()
        num_samples = min(5, len(keras_flat))
        for i in range(num_samples):
            logger.info(f"  Value {i+1}: Keras={keras_flat[i]:.6f}, TFLite={tflite_flat[i]:.6f}, Diff={abs(keras_flat[i]-tflite_flat[i]):.6f}")
    except Exception as e:
        logger.error(f"Failed to compare outputs: {e}")
        outputs_match = False
    
    # Compare model sizes
    logger.info("\n=== Model Size Comparison ===")
    try:
        keras_size = os.path.getsize(keras_model_path) / (1024 * 1024) if os.path.exists(keras_model_path) else 0
        weights_size = os.path.getsize(weights_path) / (1024 * 1024) if weights_path and os.path.exists(weights_path) else 0
        tflite_size = os.path.getsize(tflite_model_path) / (1024 * 1024)
        
        if keras_size:
            logger.info(f"Keras model size (.keras): {keras_size:.2f} MB")
        if weights_size:
            logger.info(f"Weights size (.weights.h5): {weights_size:.2f} MB")
        logger.info(f"TFLite model size: {tflite_size:.2f} MB")
        
        if keras_size and tflite_size:
            reduction = (1 - tflite_size/keras_size) * 100
            logger.info(f"Size reduction: {reduction:.2f}%")
    except Exception as e:
        logger.error(f"Failed to compute model sizes: {e}")
    
    # Additional analysis
    logger.info("\n=== TFLite Model Analysis ===")
    try:
        # Count parameters
        keras_params = sum(layer['params'] for layer in keras_layers)
        logger.info(f"Total Keras parameters: {keras_params:,}")
        
        # Check for quantization
        quantized_tensors = [t for t in tflite_tensors if 'quantization' in t and t['quantization'] != 'None' and 'scale' in str(t['quantization'])]
        if quantized_tensors:
            logger.info(f"Model has {len(quantized_tensors)} quantized tensors.")
            quant_types = set(t['dtype'] for t in quantized_tensors)
            logger.info(f"Quantization data types: {', '.join(str(t) for t in quant_types)}")
        else:
            logger.info("Model does not use quantization.")
    except Exception as e:
        logger.error(f"Failed to perform additional analysis: {e}")
    
    # Provide compatibility summary
    logger.info("\n=== Compatibility Check ===")
    logger.info("✓ Keras model can be loaded and executed")
    logger.info("✓ TFLite model can be loaded and executed")
    
    # Safe checks for shape comparison results
    if 'outputs_match' in locals():
        match_text = "match (MSE < 1e-5)" if outputs_match else f"differ (MSE = {mse:.10f})"
        logger.info(f"✓ Inference results {match_text}")
    
    input_shapes_match = False
    if keras_input_shape_list is not None and tflite_input_shape_list is not None:
        if len(keras_input_shape_list) > 1 and len(tflite_input_shape_list) > 1:
            input_shapes_match = keras_input_shape_list[1:] == tflite_input_shape_list[1:]
            match_symbol = "✓" if input_shapes_match else "✗"
            match_text = "match" if input_shapes_match else "differ"
            logger.info(f"{match_symbol} Input shapes {match_text}")
    
    output_shapes_match = False
    if keras_output_shape_list is not None and tflite_output_shape_list is not None:
        output_shapes_match = keras_output_shape_list == tflite_output_shape_list
        match_symbol = "✓" if output_shapes_match else "✗"
        match_text = "match" if output_shapes_match else "differ"
        logger.info(f"{match_symbol} Output shapes {match_text}")
    
    logger.info("\n=== Comparison Complete ===")
    return keras_model, interpreter


# Example usage
if __name__ == "__main__":
    keras_model_path = "student_model_31.keras"
    tflite_model_path = "student_model_31.tflite"
    weights_path = "student_model_31.weights.h5"
    
    # Define a function to create a new instance of TransModel
    def create_trans_model():
        return TransModel()
    
    # Run comparison
    compare_keras_tflite(
        keras_model_path, 
        tflite_model_path, 
        weights_path=weights_path, 
        model_architecture=create_trans_model
    )

2025-04-30 17:09:56,633 - INFO - Loading Keras model from: student_model_31.keras
2025-04-30 17:09:57,236 - INFO - Keras model loaded from .keras file.
2025-04-30 17:09:57,312 - INFO - Keras model built successfully with dummy input.
2025-04-30 17:09:57,313 - INFO - 
=== Keras Model Details ===
2025-04-30 17:09:57,691 - INFO - Keras Input Shape (inferred): [None, 64, 3]
2025-04-30 17:09:57,691 - INFO - Keras Output Shape (inferred): [1, 1]
2025-04-30 17:09:57,693 - INFO - Using existing TFLite model at student_model_31.tflite.
2025-04-30 17:09:57,694 - INFO - Loading TFLite model from student_model_31.tflite
2025-04-30 17:09:57,696 - INFO - TFLite model loaded successfully.
2025-04-30 17:09:57,697 - INFO - 
=== TFLite Model Details ===
2025-04-30 17:09:57,698 - INFO - TFLite Input Shape: [1, 64, 3]
2025-04-30 17:09:57,699 - INFO - TFLite Output Shape: [1, 1]
2025-04-30 17:09:57,699 - INFO - TFLite Input Data Type: <class 'numpy.float32'>
2025-04-30 17:09:57,708 - INFO - 
=== Layer/Oper