In [2]:
# Log-Mel Spectrogram Model & TFLite Pipeline
import os
import sys
import numpy as np
import tensorflow as tf
from pathlib import Path
import pickle as pkl
from scipy.io import wavfile

# Configuration
# NOTE: PROJECT_ROOT is a relative path, so it (and the sys.path entry derived
# from it) resolves against the process's current working directory.
PROJECT_ROOT = Path("../../Models/TFlite_Model")
sys.path.append(str(PROJECT_ROOT / '..' / '..' / '..' / 'HCAR'))  # adjust if needed

# Audio parameters
SAMPLE_RATE = 16000        # samples per second
STFT_WINDOW_SECONDS = 0.6  # 600ms
STFT_HOP_SECONDS = 0.03    # 30ms
LOWER_EDGE_HZ = 10.0
UPPER_EDGE_HZ = 8000.0     # equals SAMPLE_RATE / 2
NUM_MEL_BINS = 64
LOG_OFFSET = 1e-3          # added to the mel energies before taking the log
# round() instead of int(): int() truncates, so a float product that lands just
# below an integer (e.g. int(0.29 * 100) == 28) would silently lose one sample.
FRAME_LENGTH = round(STFT_WINDOW_SECONDS * SAMPLE_RATE)
FRAME_STEP = round(STFT_HOP_SECONDS * SAMPLE_RATE)
# Smallest power of two that is >= FRAME_LENGTH for the values above (9600).
FFT_LENGTH = 16384

# Paths
SAVED_MODEL_DIR = PROJECT_ROOT / 'saved_log_mel_model'
TFLITE_MODEL_PATH = PROJECT_ROOT / 'log_mel_model.tflite'
# Machine-specific absolute path; the script body globs it for *.pkl test samples.
TEST_DATA_DIR = Path(r"D:/code/2024_11_22_SAMoSA_Replicate/Data_folder/2_TrainingDataset/3/Left")

# Model Definition

def build_log_mel_model(
    sample_rate=SAMPLE_RATE,
    frame_length=FRAME_LENGTH,
    frame_step=FRAME_STEP,
    fft_length=FFT_LENGTH,
    num_mel_bins=NUM_MEL_BINS,
    lower_edge_hz=LOWER_EDGE_HZ,
    upper_edge_hz=UPPER_EDGE_HZ,
    log_offset=LOG_OFFSET
):
    """
    Create a weight-free Keras model mapping raw audio to a log-mel spectrogram.

    The model takes a float32 batch of shape (batch, frame_length), applies a
    Hann-windowed STFT (no end padding), keeps the magnitude, projects it onto
    `num_mel_bins` mel bands between `lower_edge_hz` and `upper_edge_hz`, and
    returns log(mel + log_offset).

    Returns:
        A `tf.keras.Model` named 'LogMelModel' whose input is 'audio_input'.
    """
    # The mel projection depends only on the arguments, so build it up front.
    mel_filterbank = tf.signal.linear_to_mel_weight_matrix(
        num_mel_bins=num_mel_bins,
        num_spectrogram_bins=fft_length // 2 + 1,  # bin count of a real FFT
        sample_rate=float(sample_rate),
        lower_edge_hertz=lower_edge_hz,
        upper_edge_hertz=upper_edge_hz,
    )

    audio_in = tf.keras.Input(
        shape=(frame_length,),
        dtype=tf.float32,
        name='audio_input',
    )

    # Magnitude spectrum of the short-time Fourier transform.
    spectrum = tf.abs(
        tf.signal.stft(
            signals=audio_in,
            frame_length=frame_length,
            frame_step=frame_step,
            fft_length=fft_length,
            window_fn=tf.signal.hann_window,
            pad_end=False,
        )
    )

    # Linear-frequency bins -> mel bands, then log compression.
    mel_energy = tf.matmul(spectrum, mel_filterbank)
    log_mel = tf.math.log(mel_energy + log_offset)

    return tf.keras.Model(inputs=audio_in, outputs=log_mel, name='LogMelModel')

# TFLite conversion
def convert_to_tflite(saved_model_dir: Path, tflite_path: Path, model=None):
    """
    Export a model as a SavedModel and convert it to a TFLite flatbuffer.

    Args:
        saved_model_dir: Directory the SavedModel is written to.
        tflite_path: Destination file for the converted TFLite model.
        model: Model to export. When omitted, a module-level `model` is used
            if one exists (the behaviour callers relied on before this
            parameter was added); otherwise a default model is built with
            `build_log_mel_model()`.
    """
    if model is None:
        # Legacy fallback: this function used to read the global `model`
        # directly, which raised NameError whenever the module was imported
        # rather than run as a script.
        model = globals().get('model')
    if model is None:
        model = build_log_mel_model()

    tflite_path = Path(tflite_path)
    # Hand TensorFlow plain strings for filesystem paths.
    tf.saved_model.save(model, str(saved_model_dir))
    converter = tf.lite.TFLiteConverter.from_saved_model(str(saved_model_dir))
    tflite_model = converter.convert()

    # write_bytes() does not create missing parent directories.
    tflite_path.parent.mkdir(parents=True, exist_ok=True)
    tflite_path.write_bytes(tflite_model)
    print(f"Saved TFLite model to {tflite_path}")

# Inference Utilities
def run_tflite_inference(audio: np.ndarray, model_path: Path) -> np.ndarray:
    """
    Feed `audio` through the TFLite model stored at `model_path`.

    A fresh interpreter is created on every call. `audio` is written to the
    model's first input tensor as-is (the caller passes mono audio shaped
    (1, frame_length)), and the model's first output tensor is returned.
    """
    interp = tf.lite.Interpreter(model_path=str(model_path))
    interp.allocate_tensors()

    # Only the first input and first output tensor are used.
    in_index = interp.get_input_details()[0]['index']
    out_index = interp.get_output_details()[0]['index']

    interp.set_tensor(in_index, audio)
    interp.invoke()
    return interp.get_tensor(out_index)

# Main execution
if __name__ == '__main__':
    # Build and summarize.
    # Keep `model` bound at module scope: convert_to_tflite() is called below
    # without a model argument and picks it up from there.
    model = build_log_mel_model()
    model.summary()

    # Smoke-test the Keras model on random input.
    dummy = np.random.randn(1, FRAME_LENGTH).astype(np.float32)
    out = model.predict(dummy)
    print(f"Log-mel output shape: {out.shape}")

    # Convert to TFLite
    convert_to_tflite(SAVED_MODEL_DIR, TFLITE_MODEL_PATH)

    # Load a test example from pickle
    sample_files = sorted(TEST_DATA_DIR.glob('*.pkl'))
    if sample_files:
        # SECURITY: pickle.load can execute arbitrary code; only use it on
        # sample files from a trusted source.
        with open(sample_files[0], 'rb') as f:
            data = pkl.load(f)
        # Dividing by 32768 scales to [-1, 1) assuming 16-bit PCM samples
        # -- TODO confirm the dtype stored under 'Audio'.
        audio = data['Audio'].astype(np.float32)[:FRAME_LENGTH] / 32768.0
        # The model input is exactly FRAME_LENGTH samples wide; zero-pad short
        # clips so set_tensor() does not fail on a shape mismatch.
        missing = FRAME_LENGTH - audio.shape[0]
        if missing > 0:
            audio = np.pad(audio, (0, missing))
        audio = np.expand_dims(audio, axis=0)
        # Run TFLite inference
        tflite_out = run_tflite_inference(audio, TFLITE_MODEL_PATH)
        print(f"TFLite output shape: {tflite_out.shape}")
    else:
        # Make the skipped check visible instead of exiting silently.
        print(f"No .pkl files found in {TEST_DATA_DIR}; skipping TFLite inference test")


Model: "LogMelModel"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 audio_input (InputLayer)    [(None, 9600)]            0         
                                                                 
 tf.signal.stft_1 (TFOpLambd  (None, 1, 8193)          0         
 a)                                                              
                                                                 
 tf.math.abs_1 (TFOpLambda)  (None, 1, 8193)           0         
                                                                 
 tf.linalg.matmul_1 (TFOpLam  (None, 1, 64)            0         
 bda)                                                            
                                                                 
 tf.__operators__.add_1 (TFO  (None, 1, 64)            0         
 pLambda)                                                        
                                                       