Convert the best Keras model to TFLite,
with a **true “none”** output (the detector returns `None`) when no target sound is present.

In [None]:
import os
import json
import time
import joblib
import numpy as np
import pandas as pd
import tensorflow as tf
from pathlib import Path
from collections import deque

# ----------------------------------------------------------------------
# CONFIGURATION
# ----------------------------------------------------------------------
# Locations, relative to the notebook's working directory.
MODELS_DIR = '../models/models_approach1'
SAVED_MODELS_DIR = os.path.join(MODELS_DIR, 'saved_models')
RESULTS_DIR = '../results/results_approach1'
TFLITE_DIR = '../models/tflite_models'
TEST_RAW_DIR = '../data/split_processed/test'

# Detection gates (tune once).
RMS_THRESHOLD = 0.03             # windows with RMS below this count as silence
CONFIDENCE_THRESHOLD = 0.85      # top-class probability must be at least this ...
OVERCONFIDENCE_CAP = 0.98        # ... and at most this (above = treated as hallucination)
                                 # NOTE(review): this also rejects very clear detections; confirm intent
MIN_CONSECUTIVE_DETECTIONS = 7   # identical accepted windows in a row needed to trigger
RESET_AFTER_SECONDS = 5.0        # stale-state timeout, measured from the last trigger

os.makedirs(TFLITE_DIR, exist_ok=True)

In [None]:
# CUSTOM YAMNET LAYER (required for load_model)
@tf.keras.utils.register_keras_serializable()
class YamnetEmbedding(tf.keras.layers.Layer):
    """Map a batch of waveforms to one averaged YAMNet embedding per row.

    For each row of ``inputs`` the wrapped ``yamnet_model`` is called; it must
    return a 3-tuple whose middle element is the embedding matrix. That matrix
    is averaged over axis 0 (presumably YAMNet's per-frame axis - confirm
    against the hub model's docs).
    """

    def __init__(self, yamnet_model, **kwargs):
        super().__init__(**kwargs)
        self.yamnet_model = yamnet_model

    def call(self, inputs):
        def _single(w):
            _, emb, _ = self.yamnet_model(w)
            return tf.reduce_mean(emb, axis=0)
        # ``dtype=`` is deprecated in tf.map_fn; ``fn_output_signature`` is the
        # supported spelling with the same meaning (per-element output dtype).
        return tf.map_fn(_single, inputs, fn_output_signature=tf.float32)

In [None]:
# 1. LOAD BEST MODEL
print("1. LOADING BEST KERAS MODEL")
print("="*80)

leaderboard_path = os.path.join(RESULTS_DIR, 'final_test_leaderboard.csv')
df = pd.read_csv(leaderboard_path)
if df.empty:
    raise ValueError(f"Leaderboard has no rows: {leaderboard_path}")

# Pick the row with the highest macro-F1 explicitly rather than trusting the
# CSV's row order (idxmax returns the first row on ties and skips NaN scores).
best_row  = df.loc[df['F1_Macro'].idxmax()]
best_name = best_row['Model']
best_f1   = best_row['F1_Macro']
print(f"Best model: {best_name} (F1-macro = {best_f1:.4f})")

model_path = os.path.join(SAVED_MODELS_DIR, f"{best_name.lower()}_full")
model = tf.keras.models.load_model(
    model_path,
    custom_objects={'YamnetEmbedding': YamnetEmbedding}
)
print("Keras model loaded")

In [None]:
# 2. FREEZE → PLAIN SavedModel (removes hub URL)
print("2. FREEZING TO SAVEDMODEL")
print("=" * 80)

# One fixed-size window per call: batch 1 × 15360 float32 samples.
input_spec = tf.TensorSpec([1, 15360], tf.float32)


def _serve(x):
    """Forward one window through the loaded Keras model."""
    return model(x)


concrete = tf.function(_serve).get_concrete_function(input_spec)
frozen_dir = os.path.join(TFLITE_DIR, "frozen_savedmodel")
tf.saved_model.save(model, frozen_dir, signatures=concrete)
print(f"Frozen SavedModel → {frozen_dir}")

In [None]:
# 3. CONVERT TO TFLITE
print("3. CONVERTING TO TFLITE")
print("=" * 80)

converter = tf.lite.TFLiteConverter.from_saved_model(frozen_dir)
# Let the converter apply its default optimisations.
converter.optimizations = [tf.lite.Optimize.DEFAULT]
# Builtin TFLite ops where possible, falling back to TF (Flex) ops otherwise.
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS,
                                       tf.lite.OpsSet.SELECT_TF_OPS]
# Keep TensorList ops un-lowered - presumably needed for the tf.map_fn inside
# YamnetEmbedding; confirm if the model architecture changes.
converter._experimental_lower_tensor_list_ops = False

print("Running conversion …")
tflite_model = converter.convert()

tflite_path = os.path.join(TFLITE_DIR, f"{best_name.lower()}_model.tflite")
Path(tflite_path).write_bytes(tflite_model)
size_mb = len(tflite_model) / 1024 ** 2
print(f"TFLite model saved → {tflite_path} ({size_mb:.2f} MB)")

In [None]:
# 4. METADATA
print("4. WRITING METADATA")
print("="*80)

label_encoder = joblib.load(os.path.join(MODELS_DIR, 'label_encoder.pkl'))
classes = label_encoder.classes_

# Derive every size-related field from one place so they cannot drift apart
# (the description previously hard-coded "5-class" regardless of the encoder).
sample_rate  = 16000
input_length = 15360   # must match the [1, 15360] signature used when freezing
n_classes    = len(classes)

metadata = {
    "model_name": best_name,
    "test_f1_score": float(best_f1),
    "classes": classes.tolist(),
    "sample_rate": sample_rate,
    "input_length": input_length,
    "duration_seconds": input_length / sample_rate,   # 0.96 s
    "confidence_threshold": CONFIDENCE_THRESHOLD,
    "min_consecutive_detections": MIN_CONSECUTIVE_DETECTIONS,
    "rms_threshold": RMS_THRESHOLD,
    "overconfidence_cap": OVERCONFIDENCE_CAP,
    "reset_after_seconds": RESET_AFTER_SECONDS,
    "input_shape": [1, input_length],
    "output_shape": [1, n_classes],
    "description": f"{n_classes}-class detector – returns None when no target sound is present"
}
metadata_path = os.path.join(TFLITE_DIR, "model_metadata.json")
with open(metadata_path, 'w', encoding='utf-8') as f:
    json.dump(metadata, f, indent=2)
print(f"Metadata → {metadata_path}")

In [None]:
# 5. QUICK SANITY CHECK
print("5. QUICK TFLITE TEST")
print("="*80)

interp = tf.lite.Interpreter(model_path=tflite_path)
interp.allocate_tensors()
inp = interp.get_input_details()[0]
out = interp.get_output_details()[0]

# Fail fast if the converted model's I/O disagrees with what the metadata and
# the generated wrapper assume: a [1, 15360] window in, one score per class out.
expected_in, expected_out = [1, 15360], [1, len(classes)]
if inp['shape'].tolist() != expected_in:
    raise RuntimeError(f"TFLite input shape {inp['shape'].tolist()} != {expected_in}")
if out['shape'].tolist() != expected_out:
    raise RuntimeError(f"TFLite output shape {out['shape'].tolist()} != {expected_out}")

rand = np.random.randn(15360).astype(np.float32)
interp.set_tensor(inp['index'], rand[np.newaxis, :])
interp.invoke()
probs = interp.get_tensor(out['index'])[0]
print(f"Random → {classes[np.argmax(probs)]} (conf {probs.max():.4f})")

In [None]:
# 6. DETECTION LOGIC (returns None for “none”)
class Detector:
    """Stateful detector around the exported TFLite model (notebook version).

    Each window is gated on RMS, then on its top-class probability
    (CONFIDENCE_THRESHOLD <= conf <= OVERCONFIDENCE_CAP). A class is only
    reported once MIN_CONSECUTIVE_DETECTIONS consecutive accepted windows
    agree. Reads the notebook globals ``tflite_path`` and ``classes``.
    """

    def __init__(self):
        self.interp = tf.lite.Interpreter(model_path=tflite_path)
        self.interp.allocate_tensors()
        self.in_idx = self.interp.get_input_details()[0]['index']
        self.out_idx = self.interp.get_output_details()[0]['index']
        # Labels of the most recent accepted windows.
        self.buffer = deque(maxlen=MIN_CONSECUTIVE_DETECTIONS)
        # time.time() of the most recent trigger, or None.
        self.last_trigger = None

    def _infer(self, audio: np.ndarray) -> np.ndarray:
        """Run one window through the interpreter and return its class scores."""
        self.interp.set_tensor(self.in_idx, audio[np.newaxis, :])
        self.interp.invoke()
        return self.interp.get_tensor(self.out_idx)[0]

    def detect(self, audio: np.ndarray):
        """Return dict with 'class' == None when no target sound.

        Rejected windows carry a ``"reason"`` ("silence", "overconfident" or
        "low_confidence"); every result carries ``"rms"``.
        """
        # The interpreter expects float32; casting first also keeps integer
        # PCM input from overflowing when squared.
        audio = np.asarray(audio, dtype=np.float32)
        rms = float(np.sqrt(np.mean(audio ** 2)))
        if rms < RMS_THRESHOLD:                     # ---- silence ----
            self.buffer.clear()
            self.last_trigger = None
            return {"class": None, "confidence": 0.0, "triggered": False,
                    "reason": "silence", "rms": rms}

        probs = self._infer(audio)
        idx   = int(np.argmax(probs))
        conf  = float(probs[idx])
        cls   = classes[idx]

        # ---- over-confident hallucination ----
        if conf > OVERCONFIDENCE_CAP:
            self.buffer.clear()
            return {"class": None, "confidence": conf, "triggered": False,
                    "reason": "overconfident", "rms": rms}

        # ---- not confident enough for *any* class ----
        if conf < CONFIDENCE_THRESHOLD:
            self.buffer.clear()
            return {"class": None, "confidence": conf, "triggered": False,
                    "reason": "low_confidence", "rms": rms}

        # ---- we have a plausible candidate ----
        now = time.time()
        # No trigger for RESET_AFTER_SECONDS: drop stale state once, before
        # adding this window. Clearing last_trigger matters; if it were kept,
        # this branch would fire on every later window, the buffer could never
        # refill and the detector would stay locked out until a silent window.
        if (self.last_trigger is not None
                and now - self.last_trigger > RESET_AFTER_SECONDS):
            self.buffer.clear()
            self.last_trigger = None

        self.buffer.append(cls)
        triggered = (len(self.buffer) >= MIN_CONSECUTIVE_DETECTIONS and
                     len(set(self.buffer)) == 1)
        if triggered:
            self.last_trigger = now

        return {
            "class": cls if triggered else None,
            "confidence": conf,
            "triggered": triggered,
            "buffer_len": len(self.buffer),
            "rms": rms
        }

In [None]:
# 7. BACKGROUND TEST (must stay at 0 %)
print("7. BACKGROUND TEST")
print("="*80)

bg_dir = Path(TEST_RAW_DIR) / "background"

# Real background clips when available: first 10 files in sorted order (so the
# selection is reproducible), keeping only windows of the model's input size.
bg_clips = []
if bg_dir.is_dir():
    for npy in sorted(bg_dir.glob("*.npy"))[:10]:
        a = np.load(npy).astype(np.float32)
        if a.shape == (15360,):
            bg_clips.append(a)

if not bg_clips:
    # Fall back to synthetic clips held in memory instead of writing files into
    # the test split (which would leak into later evaluations). This also stops
    # an empty directory from making the assert below pass vacuously.
    print(f"No usable background clips in {bg_dir} – using synthetic silence/noise")
    bg_clips = [
        np.zeros(15360, dtype=np.float32),
        np.random.randn(15360).astype(np.float32) * 0.05,
    ]

det = Detector()
bg_res = [det.detect(a) for a in bg_clips]

rate = sum(1 for r in bg_res if r["triggered"]) / len(bg_res)
# Triggering needs MIN_CONSECUTIVE_DETECTIONS accepted windows in a row, so with
# a handful of unrelated clips the trigger rate alone is a weak check; also
# report how many windows got past the RMS/confidence gates (no "reason" key).
candidate_rate = sum(1 for r in bg_res if "reason" not in r) / len(bg_res)
print(f"Trigger rate on background: {rate:.2%}")
print(f"Candidate rate on background: {candidate_rate:.2%}")
assert rate == 0.0, "Background must never trigger!"

In [None]:
# 8. GENERATE PYTHON WRAPPER (audio_detector.py)
print("\n" + "="*80)
print("8. WRITING INFERENCE WRAPPER")
print("="*80)

# The wrapper is written into TFLITE_DIR next to the model and metadata, so it
# resolves them relative to its own file. Embedding the notebook-relative paths
# breaks when run from another working directory, and on Windows the
# backslashes would be read as escapes (the "\t" in "\tflite_models" -> tab).
wrapper_code = f'''\
import os
import json
import time
from collections import deque

import numpy as np
import tensorflow as tf

_HERE = os.path.dirname(os.path.abspath(__file__))


class AudioEventDetector:
    """
    Robust {len(classes)}-class detector.
    Returns **None** when none of the target sounds is present.
    """
    def __init__(self,
                 model_path=os.path.join(_HERE, {os.path.basename(tflite_path)!r}),
                 metadata_path=os.path.join(_HERE, {os.path.basename(metadata_path)!r}),
                 rms_threshold={RMS_THRESHOLD},
                 confidence_threshold={CONFIDENCE_THRESHOLD},
                 min_consecutive={MIN_CONSECUTIVE_DETECTIONS},
                 overconfidence_cap={OVERCONFIDENCE_CAP},
                 reset_after_seconds={RESET_AFTER_SECONDS}):

        self.interp = tf.lite.Interpreter(model_path=model_path)
        self.interp.allocate_tensors()
        self.in_idx = self.interp.get_input_details()[0]["index"]
        self.out_idx = self.interp.get_output_details()[0]["index"]

        with open(metadata_path) as f:
            meta = json.load(f)
        self.classes = meta["classes"]
        # Window length (samples) the model was exported with.
        self.input_length = int(meta.get("input_length", 15360))

        self.rms_thr = rms_threshold
        self.conf_thr = confidence_threshold
        self.min_cons = min_consecutive
        self.over_cap = overconfidence_cap
        self.reset_sec = reset_after_seconds

        self.buffer = deque(maxlen=self.min_cons)
        self.last_trigger = None

    # --------------------------------------------------------------
    def _infer(self, audio: np.ndarray) -> np.ndarray:
        if audio.shape != (self.input_length,):
            raise ValueError(f"Audio must be exactly {{self.input_length}} samples")
        inp = audio.astype(np.float32)[np.newaxis, :]
        self.interp.set_tensor(self.in_idx, inp)
        self.interp.invoke()
        return self.interp.get_tensor(self.out_idx)[0]

    # --------------------------------------------------------------
    def detect(self, audio: np.ndarray):
        # Cast up front so integer PCM input cannot overflow when squared.
        audio = np.asarray(audio, dtype=np.float32)
        rms = np.sqrt(np.mean(audio**2))

        # ---- silence ----
        if rms < self.rms_thr:
            self.buffer.clear()
            self.last_trigger = None
            return {{"class": None, "confidence": 0.0, "triggered": False,
                    "reason": "silence", "probabilities": [0.0]*len(self.classes)}}

        probs = self._infer(audio)
        idx = int(np.argmax(probs))
        conf = float(probs[idx])
        cls = self.classes[idx]

        # ---- over-confident hallucination ----
        if conf > self.over_cap:
            self.buffer.clear()
            return {{"class": None, "confidence": conf, "triggered": False,
                    "reason": "overconfident", "probabilities": probs.tolist()}}

        # ---- not confident enough for any class ----
        if conf < self.conf_thr:
            self.buffer.clear()
            return {{"class": None, "confidence": conf, "triggered": False,
                    "reason": "low_confidence", "probabilities": probs.tolist()}}

        # ---- plausible candidate ----
        now = time.time()
        # No trigger for reset_after_seconds: drop stale state once, before
        # adding this window. Clearing last_trigger prevents this branch from
        # firing on every later window, which would keep the buffer from ever
        # refilling and lock the detector out until a silent window.
        if self.last_trigger is not None and now - self.last_trigger > self.reset_sec:
            self.buffer.clear()
            self.last_trigger = None

        self.buffer.append(cls)

        triggered = (len(self.buffer) >= self.min_cons and
                     len(set(self.buffer)) == 1)
        if triggered:
            self.last_trigger = now

        return {{
            "class": cls if triggered else None,
            "confidence": conf,
            "triggered": triggered,
            "probabilities": probs.tolist()
        }}

    # --------------------------------------------------------------
    def reset(self):
        self.buffer.clear()
        self.last_trigger = None

# ------------------------------------------------------------------
# Demo
# ------------------------------------------------------------------
if __name__ == "__main__":
    d = AudioEventDetector()
    def demo(name, gen):
        print(f"\\n--- {{name}} ---")
        for i in range(8):
            r = d.detect(gen())
            cls = r["class"] if r["class"] else "None"
            print(f"{{i:02d}}: {{cls:<12}} conf={{r['confidence']:.4f}} trigger={{r['triggered']}}")
    demo("SILENCE", lambda: np.zeros(15360, dtype=np.float32))
    demo("LOW NOISE", lambda: np.random.uniform(-0.05,0.05,15360).astype(np.float32))
    demo("WHITE NOISE", lambda: np.random.randn(15360).astype(np.float32)*0.15)
'''

wrapper_path = os.path.join(TFLITE_DIR, "audio_detector.py")
with open(wrapper_path, 'w', encoding='utf-8') as f:
    f.write(wrapper_code)
print(f"Wrapper written → {wrapper_path}")



In [None]:
# FINAL SUMMARY
print("CONVERSION COMPLETE")
print("=" * 80)
summary_rows = (
    ("TFLite model", f"{tflite_path} ({size_mb:.2f} MB)"),
    ("Metadata", metadata_path),
    ("Wrapper", wrapper_path),
)
for label, value in summary_rows:
    # Labels are padded to a common width so the colons line up.
    print(f"{label:<12} : {value}")
print("\nRun the demo:")
print(f"    python {wrapper_path}")