Converts the fine-tuned Keras model to TFLite format with various optimization strategies for mobile deployment.

In [None]:
import json
import os
import time
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import tensorflow as tf
from tensorflow import keras

# Configuration
# -- Inputs: fine-tuned model artifacts and pre-extracted feature files.
MODELS_DIR = '../models/models_approach2/yamnet_finetuned'
FEATURES_DIR = '../data/approach2/features'

# -- Output: converted .tflite files and the comparison CSV/PNG are written here.
TFLITE_DIR = '../models/models_approach2/tflite'

# -- Upper bound on rows sampled from the test metadata, both for the
#    representative (calibration) set and for the accuracy/latency test.
TEST_SAMPLES = 100

# Create the output directory up front; a pre-existing directory is fine.
os.makedirs(name=TFLITE_DIR, exist_ok=True)


In [None]:
# 1. Load Fine-Tuned Model
print("\nLoading fine-tuned model...")

# Prefer the final export; fall back to the best checkpoint when it is absent.
model_path = os.path.join(MODELS_DIR, 'yamnet_finetuned_final.keras')
if not os.path.exists(model_path):
    print(f"Model not found at {model_path}")
    print("Trying best_model.keras...")
    model_path = os.path.join(MODELS_DIR, 'best_model.keras')
    if not os.path.exists(model_path):
        # Fail here with an actionable message instead of an opaque loader error.
        raise FileNotFoundError(
            f"No fine-tuned model found in {MODELS_DIR} "
            "(looked for yamnet_finetuned_final.keras and best_model.keras)"
        )

model = keras.models.load_model(model_path)
print(f"Model loaded from {model_path}")

# Load model config (read explicitly as UTF-8 so the result does not depend on
# the platform's default text encoding).
config_path = os.path.join(MODELS_DIR, 'model_config.json')
with open(config_path, 'r', encoding='utf-8') as f:
    model_config = json.load(f)

# Keys read from the config file written alongside the model.
categories = model_config['categories']
num_classes = model_config['num_classes']
TARGET_SR = model_config['sample_rate']

print(f"  Classes: {categories}")
print(f"  Number of classes: {num_classes}")


In [None]:
# 2. Prepare Representative Dataset for Quantization
print("\nPreparing representative dataset for quantization...")

# Read the test-split metadata; its 'frame_path' column lists the files to load.
test_metadata_csv = os.path.join(FEATURES_DIR, 'test_metadata.csv')
test_meta = pd.read_csv(test_metadata_csv)

# Fixed-seed subset, capped at the number of rows actually available.
n_representative = min(TEST_SAMPLES, len(test_meta))
representative_rows = test_meta['frame_path'].sample(n_representative, random_state=42)
sample_paths = representative_rows.tolist()

def representative_dataset_gen():
    """Yield calibration inputs for the TFLite converter, one file at a time.

    Every path in the module-level ``sample_paths`` list is loaded with
    ``np.load``, cast to float32, given a leading batch axis and yielded as a
    single-element list. A file that raises is reported and skipped.
    """
    for frame_file in sample_paths:
        try:
            # Cast first, then prepend the batch dimension.
            batched = np.load(frame_file).astype(np.float32)[np.newaxis, ...]
            yield [batched]
        except Exception as err:
            print(f"Warning: Could not load {frame_file}: {err}")

# Counts the selected paths; files that later fail to load are still included.
selected_count = len(sample_paths)
print(f"Prepared {selected_count} representative samples")


In [None]:

# 3. Convert to TFLite - Float32 (No Quantization)
print("\nConverting to TFLite (Float32 - No Quantization)...")

# Baseline export: no optimizations are set on this converter. Its size is the
# numerator of every compression ratio reported by the later conversions.
converter_float32 = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_float32 = converter_float32.convert()

# Write the serialized model into the TFLite output directory.
tflite_float32_path = os.path.join(TFLITE_DIR, 'yamnet_finetuned_float32.tflite')
with open(tflite_float32_path, 'wb') as model_file:
    model_file.write(tflite_float32)

# Size in MiB (bytes / 1024**2).
float32_size = len(tflite_float32) / 1024 ** 2
print(
    f"Float32 model saved to {tflite_float32_path}\n"
    f"  Size: {float32_size:.2f} MB"
)


In [None]:
# 4. Convert to TFLite - Dynamic Range Quantization
print("\nConverting to TFLite (Dynamic Range Quantization)...")

# Optimize.DEFAULT with no representative dataset and no type overrides.
converter_dynamic = tf.lite.TFLiteConverter.from_keras_model(model)
converter_dynamic.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_dynamic = converter_dynamic.convert()

# Persist the converted model.
tflite_dynamic_path = os.path.join(TFLITE_DIR, 'yamnet_finetuned_dynamic.tflite')
with open(tflite_dynamic_path, 'wb') as model_file:
    model_file.write(tflite_dynamic)

# Size in MiB and shrink factor relative to the float32 export.
dynamic_size = len(tflite_dynamic) / 1024 ** 2
compression_ratio = float32_size / dynamic_size
print(
    f"Dynamic quantized model saved to {tflite_dynamic_path}\n"
    f"  Size: {dynamic_size:.2f} MB\n"
    f"  Compression: {compression_ratio:.2f}x"
)


In [None]:
# 5. Convert to TFLite - Float16 Quantization
print("\nConverting to TFLite (Float16 Quantization)...")

# Optimize.DEFAULT plus float16 as the only supported target type.
converter_float16 = tf.lite.TFLiteConverter.from_keras_model(model)
converter_float16.optimizations = [tf.lite.Optimize.DEFAULT]
converter_float16.target_spec.supported_types = [tf.float16]
tflite_float16 = converter_float16.convert()

# Persist the converted model.
tflite_float16_path = os.path.join(TFLITE_DIR, 'yamnet_finetuned_float16.tflite')
with open(tflite_float16_path, 'wb') as model_file:
    model_file.write(tflite_float16)

# Size in MiB and shrink factor relative to the float32 export.
float16_size = len(tflite_float16) / 1024 ** 2
compression_ratio_16 = float32_size / float16_size
print(
    f"Float16 quantized model saved to {tflite_float16_path}\n"
    f"  Size: {float16_size:.2f} MB\n"
    f"  Compression: {compression_ratio_16:.2f}x"
)



In [None]:

# 6. Convert to TFLite - Full Integer Quantization (INT8)

print("\nConverting to TFLite (Full Integer Quantization - INT8)...")

converter_int8 = tf.lite.TFLiteConverter.from_keras_model(model)
converter_int8.optimizations = [tf.lite.Optimize.DEFAULT]
# Calibration data: the generator defined in the representative-dataset cell.
converter_int8.representative_dataset = representative_dataset_gen

# For full integer quantization
converter_int8.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
converter_int8.inference_input_type = tf.int8  # or tf.uint8
converter_int8.inference_output_type = tf.int8  # or tf.uint8

# Reset the outputs first so a failed re-run never leaves values from an
# earlier successful run lying around in the kernel.
tflite_int8 = None
tflite_int8_path = None
int8_size = None
int8_available = False

try:
    # Only the conversion itself is treated as best-effort. Errors while
    # writing the file are real failures and must not be reported as a
    # quantization problem.
    tflite_int8 = converter_int8.convert()
except Exception as e:
    print(f"INT8 quantization failed: {str(e)}")
    print("  This is common with complex models. Using Float16 as most compressed version.")
else:
    # Save
    tflite_int8_path = os.path.join(TFLITE_DIR, 'yamnet_finetuned_int8.tflite')
    with open(tflite_int8_path, 'wb') as f:
        f.write(tflite_int8)

    # Size in MiB and shrink factor relative to the float32 export.
    int8_size = len(tflite_int8) / (1024 * 1024)
    compression_ratio_int8 = float32_size / int8_size
    print(f"INT8 quantized model saved to {tflite_int8_path}")
    print(f"  Size: {int8_size:.2f} MB")
    print(f"  Compression: {compression_ratio_int8:.2f}x")

    int8_available = True


In [None]:
# 7. Test TFLite Models - Accuracy and Latency
print("\nTesting TFLite models (accuracy and latency)...")

def test_tflite_model(tflite_path, test_samples=50):
    """Test a TFLite model."""
    # Load interpreter
    interpreter = tf.lite.Interpreter(model_path=tflite_path)
    interpreter.allocate_tensors()
    
    # Get input and output details
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()
    
    print(f"\n  Input details: {input_details[0]['shape']}, {input_details[0]['dtype']}")
    print(f"  Output details: {output_details[0]['shape']}, {output_details[0]['dtype']}")
    
    # Test on samples
    correct = 0
    total = 0
    latencies = []
    
    test_paths = test_meta.sample(min(test_samples, len(test_meta)), 
                                  random_state=42)
    
    for _, row in test_paths.iterrows():
        try:
            # Load audio
            audio = np.load(row['frame_path']).astype(np.float32)
            audio = np.expand_dims(audio, axis=0)
            
            # Convert input if needed
            if input_details[0]['dtype'] == np.int8:
                # Quantize input
                input_scale, input_zero_point = input_details[0]['quantization']
                audio = audio / input_scale + input_zero_point
                audio = audio.astype(np.int8)
            
            # Set input
            interpreter.set_tensor(input_details[0]['index'], audio)
            
            # Run inference
            start_time = time.time()
            interpreter.invoke()
            latency = (time.time() - start_time) * 1000  # ms
            latencies.append(latency)
            
            # Get output
            output = interpreter.get_tensor(output_details[0]['index'])
            
            # Dequantize output if needed
            if output_details[0]['dtype'] == np.int8:
                output_scale, output_zero_point = output_details[0]['quantization']
                output = (output.astype(np.float32) - output_zero_point) * output_scale
            
            pred = np.argmax(output[0])
            
            if pred == row['label']:
                correct += 1
            total += 1
            
        except Exception as e:
            print(f"  Warning: Error processing sample: {e}")
            continue
    
    accuracy = correct / total if total > 0 else 0
    avg_latency = np.mean(latencies) if latencies else 0
    std_latency = np.std(latencies) if latencies else 0
    
    return {
        'accuracy': accuracy,
        'avg_latency_ms': avg_latency,
        'std_latency_ms': std_latency,
        'samples_tested': total
    }

# Test all models
# Each entry: (display name, path to the .tflite file, size in MB).
models_to_test = [
    ('Float32', tflite_float32_path, float32_size),
    ('Dynamic Quant', tflite_dynamic_path, dynamic_size),
    ('Float16', tflite_float16_path, float16_size),
]

# The INT8 variant only exists when its conversion succeeded.
if int8_available:
    models_to_test.append(('INT8', tflite_int8_path, int8_size))

results = []

# The loop names are deliberately distinct from `model_path`, the Keras model
# path set when the model was loaded, so that notebook-global is not overwritten.
for variant_name, variant_path, variant_size_mb in models_to_test:
    print(f"\nTesting {variant_name} model...")
    test_results = test_tflite_model(variant_path, test_samples=TEST_SAMPLES)

    # One row per variant; the keys become the comparison table's columns.
    results.append({
        'Model': variant_name,
        'Size (MB)': variant_size_mb,
        'Accuracy': test_results['accuracy'],
        'Avg Latency (ms)': test_results['avg_latency_ms'],
        'Std Latency (ms)': test_results['std_latency_ms'],
        'Samples': test_results['samples_tested']
    })

    print(f"  Accuracy: {test_results['accuracy']:.4f}")
    print(f"  Avg Latency: {test_results['avg_latency_ms']:.2f} ms")
    print(f"  Std Latency: {test_results['std_latency_ms']:.2f} ms")


In [None]:
# 8. Compare Models
# This section previously appeared twice in the cell, so the CSV and PNG were
# written twice and the figure was shown twice. It now runs once.
print("\n" + "="*70)
print("TFLITE MODEL COMPARISON")
print("="*70)

results_df = pd.DataFrame(results)
print("\n" + results_df.to_string(index=False))

# Save comparison
results_df.to_csv(os.path.join(TFLITE_DIR, 'tflite_comparison.csv'), index=False)

# Visualize comparison: one bar chart per metric, driven by a small spec table
# instead of three copy-pasted stanzas.
# (column, title, y-axis label, bar colour, value-label offset, value format, y-limits)
panel_specs = [
    ('Size (MB)', 'Model Size Comparison', 'Size (MB)', 'steelblue', 0.5, '.2f', None),
    ('Accuracy', 'Accuracy Comparison', 'Accuracy', 'coral', 0.02, '.4f', [0, 1]),
    ('Avg Latency (ms)', 'Inference Latency Comparison', 'Average Latency (ms)',
     'lightgreen', 1, '.2f', None),
]

fig, axes = plt.subplots(1, 3, figsize=(16, 5))

for ax, (column, title, ylabel, color, label_offset, label_fmt, ylim) in zip(axes, panel_specs):
    ax.bar(results_df['Model'], results_df[column], color=color)
    ax.set_title(title, fontsize=14, fontweight='bold')
    ax.set_ylabel(ylabel, fontsize=12)
    ax.set_xlabel('Model Type', fontsize=12)
    if ylim is not None:
        ax.set_ylim(ylim)
    ax.tick_params(axis='x', rotation=45)
    ax.grid(axis='y', alpha=0.3)

    # Add value labels just above each bar
    for bar_pos, value in enumerate(results_df[column]):
        ax.text(bar_pos, value + label_offset, format(value, label_fmt),
                ha='center', fontsize=10)

plt.tight_layout()
plt.savefig(os.path.join(TFLITE_DIR, 'tflite_comparison.png'), dpi=150, bbox_inches='tight')
plt.show()

print(f"\n✓ Comparison plot saved to {TFLITE_DIR}/tflite_comparison.png")



In [None]:
# Find best model based on accuracy-size tradeoff
# Score = accuracy / sqrt(size in MB): higher accuracy raises it, larger size lowers it.
results_df['Score'] = results_df['Accuracy'] / (results_df['Size (MB)'] ** 0.5)

# idxmax() returns an index *label*, so the row must be selected with the
# label-based .loc. Positional .iloc only matches while the index is the
# default 0..n-1 range.
best_idx = results_df['Score'].idxmax()
best_model = results_df.loc[best_idx]

print(f"\nRecommended model: {best_model['Model']}")
print(f"  - Size: {best_model['Size (MB)']:.2f} MB")
print(f"  - Accuracy: {best_model['Accuracy']:.4f}")
print(f"  - Avg Latency: {best_model['Avg Latency (ms)']:.2f} ms")

