In [1]:
# Experiment settings carried over from an earlier training setup.
# NOTE(review): several of these values disagree with, or are not read by, the
# Keras pipeline below, which uses its own module-level constants (IMG_SIZE,
# BATCH_SIZE=32 vs "batch_size": 64 here, LEARNING_RATE, ...). The "model" name
# and the audio settings (SR / DURATION / MAX_READ_SAMPLES) are not consumed by
# any visible cell. Confirm which source of truth is intended.
config = {
    # task / training
    "use_aug": False,
    "num_classes": 264,
    "batch_size": 64,
    "epochs": 1,
    "PRECISION": 16,
    "PATIENCE": 8,
    "seed": 64,
    "model": "tf_efficientnet_b1_ns",
    "pretrained": True,
    "weight_decay": 1e-3,
    "use_mixup": True,
    "mixup_alpha": 0.6,
    # image splits
    "train_images": "/mnt/Stuff/phd_projects/esp32-projects/bird_call_id/splitted_dataset_color/train",
    "valid_images": "/mnt/Stuff/phd_projects/esp32-projects/bird_call_id/splitted_dataset_color/test",
    # audio settings
    "SR": 32000,
    "DURATION": 5,
    "MAX_READ_SAMPLES": 5,
    "LR": 1e-3,
}

In [None]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import tensorflow_model_optimization as tfmot
import numpy as np
import os
from tensorflow_model_optimization.python.core.keras.compat import keras

# Configuration
IMG_SIZE = 224          # images are resized to IMG_SIZE x IMG_SIZE
BATCH_SIZE = 32
EPOCHS = 1              # initial training (train_model)
FINE_TUNE_EPOCHS = 30   # pruning fine-tune (apply_pruning)
QAT_EPOCHS = 1          # quantization-aware training
LEARNING_RATE = 0.001
# Dataset root containing `train/` and `test/` class sub-folders. Set the
# BIRD_CALL_DATA_DIR environment variable to point elsewhere instead of
# editing this absolute local path.
DATA_DIR = os.environ.get(
    'BIRD_CALL_DATA_DIR',
    '/mnt/Stuff/phd_projects/esp32-projects/bird_call_id/splitted_dataset_color',
)

# 1. DATA LOADING
def create_datasets(data_dir, img_size=IMG_SIZE, batch_size=BATCH_SIZE):
    """Build train / validation / test image iterators from a directory tree.

    ``<data_dir>/train`` is split 80/20 into training and validation subsets.
    ``<data_dir>/test`` is the held-out test set. Each split uses the
    ``flow_from_directory`` layout of one sub-folder per class.

    Only the training subset is augmented. Validation and test images are just
    rescaled to [0, 1]. This means ``val_loss``, which drives EarlyStopping
    and ReduceLROnPlateau in the training helpers, is measured on
    undistorted inputs.

    Args:
        data_dir: dataset root containing ``train/`` and ``test/``.
        img_size: images are resized to ``img_size x img_size``.
        batch_size: batch size for all three iterators.

    Returns:
        ``(train_ds, val_ds, test_ds, num_classes)``. The first three are the
        iterators returned by ``flow_from_directory`` with categorical
        (one-hot) labels. ``num_classes`` is the number of classes found
        under ``train/``.
    """
    train_dir = os.path.join(data_dir, 'train')
    test_dir = os.path.join(data_dir, 'test')
    validation_split = 0.2  # fraction of train/ held out for validation

    # Augmentation for the training subset only.
    # NOTE(review): if these images are spectrograms, rotation / shear / horizontal
    # flip distort the time-frequency layout. Confirm these are intended.
    train_datagen = keras.preprocessing.image.ImageDataGenerator(
        rescale=1./255,
        rotation_range=20,
        width_shift_range=0.2,
        height_shift_range=0.2,
        horizontal_flip=True,
        zoom_range=0.2,
        shear_range=0.2,
        fill_mode='nearest',
        validation_split=validation_split,
    )

    # Un-augmented generator for the validation subset. Keras slices each
    # class's file list deterministically from validation_split, so the same
    # split value selects exactly the files excluded from the training subset.
    val_datagen = keras.preprocessing.image.ImageDataGenerator(
        rescale=1./255,
        validation_split=validation_split,
    )

    test_datagen = keras.preprocessing.image.ImageDataGenerator(
        rescale=1./255
    )

    def _flow(datagen, directory, **kwargs):
        # Shared flow_from_directory settings for every split.
        return datagen.flow_from_directory(
            directory,
            target_size=(img_size, img_size),
            batch_size=batch_size,
            class_mode='categorical',
            **kwargs,
        )

    train_ds = _flow(train_datagen, train_dir, subset='training', shuffle=True)
    val_ds = _flow(val_datagen, train_dir, subset='validation', shuffle=False)
    test_ds = _flow(test_datagen, test_dir, shuffle=False)

    num_classes = len(train_ds.class_indices)

    return train_ds, val_ds, test_ds, num_classes


# 2. MODEL BUILDING
def build_model(num_classes, img_size=IMG_SIZE):
    """Build a small depthwise-separable CNN classifier.

    This is not EfficientNet-B0: there are no MBConv / squeeze-excitation
    blocks. The network is:

    * a 3x3, stride-2, 32-filter conv stem;
    * four depthwise-separable blocks, each a depthwise 3x3 conv followed by a
      pointwise 1x1 conv with 16 / 24 / 40 / 80 filters. The depthwise convs
      of blocks 2-4 use stride 2;
    * global average pooling and a softmax Dense head.

    Every conv is followed by BatchNormalization and a swish activation.

    Args:
        num_classes: number of units in the softmax output layer.
        img_size: the input shape is ``(img_size, img_size, 3)``.

    Returns:
        An uncompiled ``keras.Sequential`` model.
    """
    def conv_bn_swish(conv):
        # conv -> BatchNorm -> swish: the unit repeated throughout the network.
        return [conv, keras.layers.BatchNormalization(), keras.layers.Activation('swish')]

    stack = [keras.layers.InputLayer(input_shape=(img_size, img_size, 3))]

    # Stem
    stack += conv_bn_swish(keras.layers.Conv2D(32, (3, 3), strides=(2, 2), padding='same'))

    # (pointwise filters, depthwise stride) for each depthwise-separable block
    for filters, stride in ((16, 1), (24, 2), (40, 2), (80, 2)):
        stack += conv_bn_swish(
            keras.layers.DepthwiseConv2D((3, 3), strides=(stride, stride), padding='same'))
        stack += conv_bn_swish(keras.layers.Conv2D(filters, (1, 1), padding='same'))

    # Head
    stack.append(keras.layers.GlobalAveragePooling2D())
    stack.append(keras.layers.Dense(num_classes, activation='softmax'))

    return keras.Sequential(stack)


# 3. INITIAL TRAINING
def train_model(model, train_ds, val_ds, epochs=EPOCHS, lr=LEARNING_RATE):
    """Compile ``model`` and run a supervised training phase.

    Args:
        model: Keras model. It is (re)compiled in place.
        train_ds: training data passed to ``model.fit``.
        val_ds: validation data. Its loss drives early stopping and LR decay.
        epochs: maximum number of epochs.
        lr: initial learning rate for AdamW.

    Returns:
        The ``History`` object produced by ``model.fit``.
    """
    optimizer = keras.optimizers.AdamW(learning_rate=lr)
    model.compile(
        optimizer=optimizer,
        loss=tf.keras.losses.CategoricalCrossentropy(),
        metrics=['accuracy'],
    )

    # Stop after 10 epochs without val_loss improvement, rolling back to the
    # best weights seen, and halve the learning rate after 5 stagnant epochs.
    stop_early = keras.callbacks.EarlyStopping(
        monitor='val_loss', patience=10, restore_best_weights=True)
    decay_lr = keras.callbacks.ReduceLROnPlateau(
        monitor='val_loss', factor=0.5, patience=5, min_lr=1e-7)
    # No ModelCheckpoint callback is used; add one here to persist the best model.

    return model.fit(
        train_ds,
        validation_data=val_ds,
        epochs=epochs,
        callbacks=[stop_early, decay_lr],
    )


# 4. PRUNING
def apply_pruning(model, train_ds, val_ds, epochs=FINE_TUNE_EPOCHS,
                  final_sparsity=0.5, lr=None):
    """Fine-tune ``model`` with magnitude-based pruning, then strip the wrappers.

    Sparsity ramps polynomially from 0 to ``final_sparsity`` over
    ``len(train_ds) * epochs`` optimizer steps.

    Args:
        model: Keras model to prune. It is wrapped by ``prune_low_magnitude``.
        train_ds: training data. ``len(train_ds)`` must give batches per epoch.
        val_ds: validation data used for early stopping.
        epochs: maximum number of fine-tuning epochs.
        final_sparsity: target fraction of pruned weights at the end of the
            schedule (default 0.5, as before).
        lr: Adam learning rate. Defaults to ``LEARNING_RATE * 0.1``.

    Returns:
        ``(model_pruned, history)``: the model with pruning wrappers removed,
        and the ``History`` from the fine-tuning ``fit``.

    NOTE(review): EarlyStopping can end training before the schedule reaches
    ``final_sparsity``. ``restore_best_weights=True`` may also roll back to an
    epoch with lower sparsity. Verify the sparsity of the returned model.
    """
    if lr is None:
        lr = LEARNING_RATE * 0.1

    prune_low_magnitude = tfmot.sparsity.keras.prune_low_magnitude

    # The schedule is expressed in optimizer steps, hence batches * epochs.
    end_step = len(train_ds) * epochs
    pruning_params = {
        'pruning_schedule': tfmot.sparsity.keras.PolynomialDecay(
            initial_sparsity=0.0,
            final_sparsity=final_sparsity,
            begin_step=0,
            end_step=end_step
        )
    }

    model_for_pruning = prune_low_magnitude(model, **pruning_params)

    model_for_pruning.compile(
        optimizer=keras.optimizers.Adam(learning_rate=lr),
        loss='categorical_crossentropy',
        metrics=['accuracy', keras.metrics.TopKCategoricalAccuracy(k=3, name='top3_acc')]
    )

    callbacks = [
        # Required by tfmot: advances the pruning step counter every batch.
        tfmot.sparsity.keras.UpdatePruningStep(),
        keras.callbacks.EarlyStopping(
            monitor='val_loss',
            patience=5,
            restore_best_weights=True
        )
    ]

    print("\n" + "="*50)
    print("PRUNING PHASE")
    print("="*50)

    history = model_for_pruning.fit(
        train_ds,
        validation_data=val_ds,
        epochs=epochs,
        callbacks=callbacks
    )

    # Remove the pruning wrappers so the model can be exported / converted.
    model_pruned = tfmot.sparsity.keras.strip_pruning(model_for_pruning)

    return model_pruned, history


# 7. EVALUATION
def evaluate_model(model, test_ds):
    """Evaluate ``model`` on ``test_ds`` and print every compiled metric.

    Labels come from ``model.metrics_names`` rather than fixed positions.
    This works both for models compiled with only loss + accuracy and for
    models that also track top-3 accuracy. Hard-coding ``results[2]`` would
    raise IndexError for the former.

    Args:
        model: compiled Keras model.
        test_ds: evaluation data passed to ``model.evaluate``.

    Returns:
        Exactly what ``model.evaluate`` returned: ``[loss, *metrics]``, or a
        bare scalar loss when no metrics are compiled.
    """
    results = model.evaluate(test_ds)

    # evaluate() returns a bare scalar when only the loss is tracked.
    values = list(results) if isinstance(results, (list, tuple)) else [results]
    names = list(getattr(model, 'metrics_names', None) or [])
    if len(names) != len(values):
        # Names unavailable or out of sync: fall back to positional labels.
        names = ['loss'] + [f'metric_{i}' for i in range(1, len(values))]

    pretty = {'loss': 'Loss', 'accuracy': 'Accuracy', 'top3_acc': 'Top-3 Accuracy'}
    print()
    for name, value in zip(names, values):
        print(f"Test {pretty.get(name, name)}: {float(value):.4f}")
    return results

In [4]:
# Seed the Python, NumPy and TensorFlow RNGs from the config cell. Shuffling,
# augmentation and weight initialisation then repeat across kernel restarts
# (GPU kernels may still be non-deterministic).
keras.utils.set_random_seed(config["seed"])

train_ds, val_ds, test_ds, num_classes = create_datasets(DATA_DIR)

model = build_model(num_classes)

Found 10860 images belonging to 264 classes.
Found 2585 images belonging to 264 classes.
Found 3496 images belonging to 264 classes.


In [5]:
# Wrap the float model with fake-quantization ops for quantization-aware training.
# NOTE(review): no visible cell trains `model` before this point, so QAT below
# starts from freshly initialised weights. Confirm this is intended.
quantize_model = tfmot.quantization.keras.quantize_model
q_aware_model = quantize_model(model)

q_aware_model.compile(
    optimizer=keras.optimizers.AdamW(learning_rate=LEARNING_RATE),
    loss=tf.keras.losses.CategoricalCrossentropy(),
    metrics=['accuracy'],
)

In [10]:
# Quantization-aware training. `train_model` compiles the model with
# AdamW(learning_rate=lr), categorical cross-entropy and accuracy. It then fits
# with EarlyStopping (patience 10, best weights restored) and ReduceLROnPlateau
# (factor 0.5, patience 5), both on val_loss.
qat_history = train_model(
    q_aware_model,
    train_ds,
    val_ds,
    epochs=QAT_EPOCHS,
    lr=LEARNING_RATE,
)

2025-10-20 22:59:04.712765: I external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:473] Loaded cuDNN version 91002
2025-10-20 22:59:05.180486: I external/local_xla/xla/service/service.cc:163] XLA service 0x771044cb52b0 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
2025-10-20 22:59:05.180497: I external/local_xla/xla/service/service.cc:171]   StreamExecutor device (0): NVIDIA GeForce RTX 3060, Compute Capability 8.6
I0000 00:00:1760979545.233024   40138 device_compiler.h:196] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.




<tf_keras.src.callbacks.History at 0x77141c70de50>

In [6]:
# Convert the QAT model to a TFLite flatbuffer. For a quantization-aware model,
# Optimize.DEFAULT produces a quantized model (see the TF QAT guide).
converter = tf.lite.TFLiteConverter.from_keras_model(q_aware_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
quantized_tflite_model = converter.convert()

INFO:tensorflow:Assets written to: /tmp/tmpkdj58zla/assets


INFO:tensorflow:Assets written to: /tmp/tmpkdj58zla/assets
W0000 00:00:1760979018.271967   40038 tf_tfl_flatbuffer_helpers.cc:364] Ignored output_format.
W0000 00:00:1760979018.271981   40038 tf_tfl_flatbuffer_helpers.cc:367] Ignored drop_control_dependency.
2025-10-20 22:50:18.272216: I tensorflow/cc/saved_model/reader.cc:83] Reading SavedModel from: /tmp/tmpkdj58zla
2025-10-20 22:50:18.280161: I tensorflow/cc/saved_model/reader.cc:52] Reading meta graph with tags { serve }
2025-10-20 22:50:18.280170: I tensorflow/cc/saved_model/reader.cc:147] Reading SavedModel debug info (if present) from: /tmp/tmpkdj58zla
I0000 00:00:1760979018.316171   40038 mlir_graph_optimization_pass.cc:437] MLIR V1 optimization pass is not enabled
2025-10-20 22:50:18.325412: I tensorflow/cc/saved_model/loader.cc:236] Restoring SavedModel bundle.
2025-10-20 22:50:18.464260: I tensorflow/cc/saved_model/loader.cc:220] Running initialization op on SavedModel bundle at path: /tmp/tmpkdj58zla
2025-10-20 22:50:18.508

In [7]:
import tempfile

# Baseline: convert the float Keras model without optimizations, to compare
# its size against the quantized flatbuffer.
# NOTE(review): `model` is the float model built in cell 4. This comparison
# measures file size only, not accuracy.
float_converter = tf.lite.TFLiteConverter.from_keras_model(model)
float_tflite_model = float_converter.convert()


def _write_temp_tflite(flatbuffer):
    """Write a TFLite flatbuffer to a new temp file and return the file path.

    ``tempfile.mkstemp`` returns an already-open OS file descriptor. It is
    reused for the write via ``os.fdopen`` so it gets closed instead of leaking.
    """
    fd, path = tempfile.mkstemp('.tflite')
    with os.fdopen(fd, 'wb') as f:
        f.write(flatbuffer)
    return path


# Measure sizes of models.
quant_file = _write_temp_tflite(quantized_tflite_model)
float_file = _write_temp_tflite(float_tflite_model)

print("Float model in Mb:", os.path.getsize(float_file) / float(2**20))
print("Quantized model in Mb:", os.path.getsize(quant_file) / float(2**20))

INFO:tensorflow:Assets written to: /tmp/tmpx277a8bn/assets


INFO:tensorflow:Assets written to: /tmp/tmpx277a8bn/assets


Float model in Mb: 0.11505889892578125
Quantized model in Mb: 0.0447998046875


W0000 00:00:1760979022.161131   40038 tf_tfl_flatbuffer_helpers.cc:364] Ignored output_format.
W0000 00:00:1760979022.161145   40038 tf_tfl_flatbuffer_helpers.cc:367] Ignored drop_control_dependency.
2025-10-20 22:50:22.161251: I tensorflow/cc/saved_model/reader.cc:83] Reading SavedModel from: /tmp/tmpx277a8bn
2025-10-20 22:50:22.164891: I tensorflow/cc/saved_model/reader.cc:52] Reading meta graph with tags { serve }
2025-10-20 22:50:22.164899: I tensorflow/cc/saved_model/reader.cc:147] Reading SavedModel debug info (if present) from: /tmp/tmpx277a8bn
2025-10-20 22:50:22.182822: I tensorflow/cc/saved_model/loader.cc:236] Restoring SavedModel bundle.
2025-10-20 22:50:22.237997: I tensorflow/cc/saved_model/loader.cc:220] Running initialization op on SavedModel bundle at path: /tmp/tmpx277a8bn
2025-10-20 22:50:22.257607: I tensorflow/cc/saved_model/loader.cc:471] SavedModel load for tags { serve }; Status: success: OK. Took 96358 microseconds.


In [9]:
# Export the quantized flatbuffer for on-device deployment.
# NOTE(review): in the saved outputs, conversion (In [6]) and this export
# (In [9]) ran before QAT training (In [10]). Re-run with Restart & Run All so
# the exported file reflects the trained QAT model.
with open('quant_model.tflite', mode='wb') as tflite_out:
    tflite_out.write(quantized_tflite_model)