In [1]:
config = {
    "use_aug" : False,
    "num_classes" : 264,
    "batch_size" : 64,
    "epochs" : 1,
    "PRECISION" : 16,
    "PATIENCE" : 8,    
    "seed" : 64,
    "model" : "tf_efficientnet_b1_ns",
    "pretrained" : True,            
    "weight_decay" : 1e-3,
    "use_mixup" : True,
    "mixup_alpha" : 0.6,  

    "train_images" : "/mnt/Stuff/phd_projects/esp32-projects/bird_call_id/splitted_dataset_color/train",
    "valid_images" : "/mnt/Stuff/phd_projects/esp32-projects/bird_call_id/splitted_dataset_color/test",
    
    "SR" : 32000,
    "DURATION" : 5,
    "MAX_READ_SAMPLES" : 5,
    "LR" : 1e-3
}

In [1]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import tensorflow_model_optimization as tfmot
import numpy as np
import os

# Configuration
IMG_SIZE = 224
BATCH_SIZE = 32
EPOCHS = 1
FINE_TUNE_EPOCHS = 30
QAT_EPOCHS = 1
LEARNING_RATE = 0.001
DATA_DIR = '/mnt/Stuff/phd_projects/esp32-projects/bird_call_id/splitted_dataset_color'  # Update this path

# 1. DATA LOADING
def create_datasets(data_dir, img_size=IMG_SIZE, batch_size=BATCH_SIZE):
    """Load and prepare train/test datasets"""
    train_dir = os.path.join(data_dir, 'train')
    test_dir = os.path.join(data_dir, 'test')
    
    # Data augmentation for training
    train_datagen = keras.preprocessing.image.ImageDataGenerator(
        rescale=1./255,
        rotation_range=20,
        width_shift_range=0.2,
        height_shift_range=0.2,
        horizontal_flip=True,
        zoom_range=0.2,
        shear_range=0.2,
        fill_mode='nearest',
        validation_split=0.2  # 20% for validation
    )
    
    test_datagen = keras.preprocessing.image.ImageDataGenerator(
        rescale=1./255
    )
    
    train_ds = train_datagen.flow_from_directory(
        train_dir,
        target_size=(img_size, img_size),
        batch_size=batch_size,
        class_mode='categorical',
        subset='training',
        shuffle=True
    )
    
    val_ds = train_datagen.flow_from_directory(
        train_dir,
        target_size=(img_size, img_size),
        batch_size=batch_size,
        class_mode='categorical',
        subset='validation',
        shuffle=False
    )
    
    test_ds = test_datagen.flow_from_directory(
        test_dir,
        target_size=(img_size, img_size),
        batch_size=batch_size,
        class_mode='categorical',
        shuffle=False
    )
    
    num_classes = len(train_ds.class_indices)
    
    return train_ds, val_ds, test_ds, num_classes


# 2. MODEL BUILDING
def build_model(num_classes, img_size=IMG_SIZE):
    """Build EfficientNetB0 model with custom head"""
    inputs = keras.Input(shape=(img_size, img_size, 3))
    
    # Load EfficientNetB0 backbone (trainable)
    backbone = keras.applications.EfficientNetB0(
        include_top=False,
        weights='imagenet',
        input_tensor=inputs,
        pooling='avg'
    )
    backbone.trainable = True  # Backbone is NOT frozen
    
    x = backbone.output
    x = layers.Dropout(0.3)(x)
    x = layers.Dense(256, activation='relu')(x)
    x = layers.Dropout(0.2)(x)
    outputs = layers.Dense(num_classes, activation='softmax')(x)
    
    model = keras.Model(inputs=inputs, outputs=outputs)
    
    return model


# 3. INITIAL TRAINING
def train_model(model, train_ds, val_ds, epochs=EPOCHS, lr=LEARNING_RATE):
    """Initial training phase"""
    model.compile(
        optimizer=keras.optimizers.Adam(learning_rate=lr),
        loss='categorical_crossentropy',
        metrics=['accuracy', keras.metrics.TopKCategoricalAccuracy(k=3, name='top3_acc')]
    )
    
    callbacks = [
        keras.callbacks.EarlyStopping(
            monitor='val_loss',
            patience=10,
            restore_best_weights=True
        ),
        keras.callbacks.ReduceLROnPlateau(
            monitor='val_loss',
            factor=0.5,
            patience=5,
            min_lr=1e-7
        ),
        # keras.callbacks.ModelCheckpoint(
        #     'best_model',
        #     monitor='val_accuracy',
        #     save_best_only=True,
        #     mode='max',
        #     save_format='tf'  # Use SavedModel format
        # )
    ]
    
    history = model.fit(
        train_ds,
        validation_data=val_ds,
        epochs=epochs,
        callbacks=callbacks
    )
    
    return history


# 4. PRUNING
def apply_pruning(model, train_ds, val_ds, epochs=FINE_TUNE_EPOCHS):
    """Apply magnitude-based pruning"""
    prune_low_magnitude = tfmot.sparsity.keras.prune_low_magnitude
    
    # Pruning configuration
    pruning_params = {
        'pruning_schedule': tfmot.sparsity.keras.PolynomialDecay(
            initial_sparsity=0.0,
            final_sparsity=0.5,
            begin_step=0,
            end_step=len(train_ds) * epochs
        )
    }
    
    # Apply pruning to model
    model_for_pruning = prune_low_magnitude(model, **pruning_params)
    
    model_for_pruning.compile(
        optimizer=keras.optimizers.Adam(learning_rate=LEARNING_RATE * 0.1),
        loss='categorical_crossentropy',
        metrics=['accuracy', keras.metrics.TopKCategoricalAccuracy(k=3, name='top3_acc')]
    )
    
    callbacks = [
        tfmot.sparsity.keras.UpdatePruningStep(),
        keras.callbacks.EarlyStopping(
            monitor='val_loss',
            patience=5,
            restore_best_weights=True
        )
    ]
    
    print("\n" + "="*50)
    print("PRUNING PHASE")
    print("="*50)
    
    history = model_for_pruning.fit(
        train_ds,
        validation_data=val_ds,
        epochs=epochs,
        callbacks=callbacks
    )
    
    # Strip pruning wrappers
    model_pruned = tfmot.sparsity.keras.strip_pruning(model_for_pruning)
    
    return model_pruned, history


# 5. QUANTIZATION AWARE TRAINING (QAT)
def apply_qat(model, train_ds, val_ds, epochs=QAT_EPOCHS):
    """Apply Quantization Aware Training"""
    quantize_model = tfmot.quantization.keras.quantize_model
    
    # Apply QAT
    q_aware_model = quantize_model(model)
    
    q_aware_model.compile(
        optimizer=keras.optimizers.Adam(learning_rate=LEARNING_RATE * 0.01),
        loss='categorical_crossentropy',
        metrics=['accuracy', keras.metrics.TopKCategoricalAccuracy(k=3, name='top3_acc')]
    )
    
    callbacks = [
        keras.callbacks.EarlyStopping(
            monitor='val_loss',
            patience=5,
            restore_best_weights=True
        )
    ]
    
    print("\n" + "="*50)
    print("QUANTIZATION AWARE TRAINING PHASE")
    print("="*50)
    
    history = q_aware_model.fit(
        train_ds,
        validation_data=val_ds,
        epochs=epochs,
        callbacks=callbacks
    )
    
    return q_aware_model, history


# 6. TFLITE CONVERSION
def convert_to_tflite(model, representative_dataset=None, quantize=True):
    """Convert model to TFLite format"""
    # For QAT models, we need to convert directly without saving first
    converter = tf.lite.TFLiteConverter.from_keras_model(model)
    
    if quantize:
        # Full integer quantization
        converter.optimizations = [tf.lite.Optimize.DEFAULT]
        
        if representative_dataset is not None:
            converter.representative_dataset = representative_dataset
            converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
            converter.inference_input_type = tf.uint8
            converter.inference_output_type = tf.uint8
    
    tflite_model = converter.convert()
    
    return tflite_model


def save_qat_model(model, path):
    """Save QAT model properly by converting to concrete function"""
    try:
        # Try direct save first
        model.save(path, save_format='tf')
        print(f"Model saved successfully to {path}/")
    except Exception as e:
        print(f"Direct save failed: {e}")
        print("Attempting alternative save method...")
        
        # Alternative: Clone the model to a standard Keras model
        try:
            # Get model config and weights
            config = model.get_config()
            weights = model.get_weights()
            
            # Create new model from config
            new_model = keras.Model.from_config(config)
            new_model.set_weights(weights)
            
            # Save the cloned model
            new_model.save(path, save_format='tf')
            print(f"Model saved successfully using cloning method to {path}/")
        except Exception as e2:
            print(f"Cloning save also failed: {e2}")
            print("Model will only be saved as TFLite format")


def representative_data_gen(dataset, num_samples=100):
    """Generate representative dataset for quantization"""
    def gen():
        for i, (images, _) in enumerate(dataset):
            if i >= num_samples // BATCH_SIZE:
                break
            for img in images:
                yield [np.expand_dims(img, axis=0).astype(np.float32)]
    return gen


# 7. EVALUATION
def evaluate_model(model, test_ds):
    """Evaluate model performance"""
    results = model.evaluate(test_ds)
    print(f"\nTest Loss: {results[0]:.4f}")
    print(f"Test Accuracy: {results[1]:.4f}")
    print(f"Test Top-3 Accuracy: {results[2]:.4f}")
    return results


def evaluate_tflite(tflite_model, test_ds, num_samples=500):
    """Evaluate TFLite model"""
    interpreter = tf.lite.Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()
    
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()
    
    correct = 0
    total = 0
    
    for images, labels in test_ds:
        if total >= num_samples:
            break
            
        for i in range(len(images)):
            img = np.expand_dims(images[i], axis=0)
            
            # Adjust input type
            if input_details[0]['dtype'] == np.uint8:
                img = (img * 255).astype(np.uint8)
            else:
                img = img.astype(np.float32)
            
            interpreter.set_tensor(input_details[0]['index'], img)
            interpreter.invoke()
            predictions = interpreter.get_tensor(output_details[0]['index'])
            
            pred_class = np.argmax(predictions[0])
            true_class = np.argmax(labels[i])
            
            if pred_class == true_class:
                correct += 1
            total += 1
            
            if total >= num_samples:
                break
    
    accuracy = correct / total
    print(f"\nTFLite Model Accuracy: {accuracy:.4f}")
    return accuracy


# 8. MAIN PIPELINE
def main():
    print("="*60)
    print("EFFICIENTNETB0 TRAINING PIPELINE")
    print("="*60)
    
    # Load datasets
    print("\n1. Loading datasets...")
    train_ds, val_ds, test_ds, num_classes = create_datasets(DATA_DIR)
    print(f"Number of classes: {num_classes}")
    
    # Build model
    print("\n2. Building model...")
    model = build_model(num_classes)
    print(f"Total parameters: {model.count_params():,}")
    
    # Initial training
    print("\n3. Initial training...")
    history = train_model(model, train_ds, val_ds)
    
    # Evaluate base model
    # print("\n4. Evaluating base model...")
    # evaluate_model(model, test_ds)
    
    # Apply pruning
    print("\n5. Applying pruning...")
    model_pruned, prune_history = apply_pruning(model, train_ds, val_ds)
    evaluate_model(model_pruned, test_ds)
    
    # Save pruned model
    print("\nSaving pruned model...")
    model_pruned.save('pruned_model', save_format='tf')
    
    # Apply QAT
    print("\n6. Applying Quantization Aware Training...")
    model_qat, qat_history = apply_qat(model_pruned, train_ds, val_ds)
    evaluate_model(model_qat, test_ds)
    
    # Convert to TFLite (direct conversion without intermediate save)
    print("\n7. Converting to TFLite...")
    
    # Create representative dataset for full integer quantization
    rep_data_gen = representative_data_gen(train_ds)
    
    # Convert directly to TFLite (this avoids the tracking issue)
    tflite_model = convert_to_tflite(model_qat, rep_data_gen, quantize=True)
    
    # Save TFLite model
    with open('model_quantized.tflite', 'wb') as f:
        f.write(tflite_model)
    
    tflite_size = len(tflite_model) / 1024 / 1024
    print(f"TFLite model size: {tflite_size:.2f} MB")
    
    # Optionally try to save the QAT model (may fail, but TFLite is already saved)
    print("\n8. Attempting to save QAT model...")
    save_qat_model(model_qat, 'final_model')
    
    # Evaluate TFLite model
    print("\n9. Evaluating TFLite model...")
    evaluate_tflite(tflite_model, test_ds)


if __name__ == "__main__":
    # Verify TensorFlow Model Optimization toolkit is installed
    try:
        import tensorflow_model_optimization as tfmot
    except ImportError:
        print("Please install: pip install tensorflow-model-optimization")
        exit(1)
    
    main()

2025-10-20 00:27:48.176294: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-10-20 00:27:48.213047: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2025-10-20 00:27:49.049944: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.


EFFICIENTNETB0 TRAINING PIPELINE

1. Loading datasets...
Found 10860 images belonging to 264 classes.
Found 2585 images belonging to 264 classes.
Found 3496 images belonging to 264 classes.
Number of classes: 264

2. Building model...


I0000 00:00:1760898470.487851   88091 gpu_device.cc:2020] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 9648 MB memory:  -> device: 0, name: NVIDIA GeForce RTX 3060, pci bus id: 0000:01:00.0, compute capability: 8.6


Total parameters: 4,049,571

3. Initial training...


E0000 00:00:1760898479.196422   88091 meta_optimizer.cc:967] layout failed: INVALID_ARGUMENT: Size of values 0 does not match size of permutation 4 @ fanin shape inmodel/block2b_drop/dropout/SelectV2-2-TransposeNHWCToNCHW-LayoutOptimizer
2025-10-20 00:28:00.656288: I external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:473] Loaded cuDNN version 91002
2025-10-20 00:28:01.381926: I external/local_xla/xla/service/service.cc:163] XLA service 0x703485128c70 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
2025-10-20 00:28:01.381938: I external/local_xla/xla/service/service.cc:171]   StreamExecutor device (0): NVIDIA GeForce RTX 3060, Compute Capability 8.6
2025-10-20 00:28:01.386269: I tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc:269] disabling MLIR crash reproducer, set env var `MLIR_CRASH_REPRODUCER_DIRECTORY` to enable.
I0000 00:00:1760898481.434005   88200 device_compiler.h:196] Compiled cluster using XLA!  This line is logged at 


5. Applying pruning...


ValueError: Please initialize `Prune` with a supported layer. Layers should either be supported by the PruneRegistry (built-in keras layers) or should be a `PrunableLayer` instance, or should has a customer defined `get_prunable_weights` method. You passed: <class 'tf_keras.src.layers.preprocessing.image_preprocessing.Rescaling'>