# In [None]:
import os
import numpy as np
import pandas as pd
import tensorflow as tf
import tensorflow_hub as hub
from sklearn.metrics import f1_score, precision_score, recall_score, confusion_matrix
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns

# Suppress TF warnings.
# NOTE: TF_CPP_MIN_LOG_LEVEL is only fully honoured when set before
# `import tensorflow` (which happens above); consider moving it to the very
# top of the file if C++ log noise still appears.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
tf.get_logger().setLevel('ERROR')

# Configuration
PROCESSED_ROOT = '../data/processed'
MODEL_DIR = '../models/fine_tuned_yamnet'
YAMNET_MODEL_HANDLE = 'https://tfhub.dev/google/yamnet/1'
TARGET_SR = 16000
BATCH_SIZE = 16  # Start with 16, increase if GPU allows
RANDOM_SEED = 42
EPOCHS = 50
LEARNING_RATE = 1e-5
WEIGHT_DECAY = 1e-4

# Create the output directory up front: the checkpoint callback, the
# confusion-matrix figure and the TFLite export all write into MODEL_DIR and
# would fail if it does not exist yet.
os.makedirs(MODEL_DIR, exist_ok=True)

# Set random seeds (NumPy and TensorFlow global seeds).
np.random.seed(RANDOM_SEED)
tf.random.set_seed(RANDOM_SEED)

# Load YAMNet from TensorFlow Hub. hub.load returns a restored SavedModel
# object rather than a Keras layer, so it is wrapped in a custom layer below.
print("\nLoading YAMNet model from TensorFlow Hub...")
yamnet_model = hub.load(YAMNET_MODEL_HANDLE)
print("YAMNet model loaded successfully")

NUM_CLASSES = 5  # size of the new classification head
YAMNET_EMBEDDING_DIM = 1024  # size of one YAMNet per-frame embedding


class YamnetClipEmbedding(tf.keras.layers.Layer):
    """Turn a batch of waveforms into one mean-pooled YAMNet embedding per clip.

    YAMNet is called on a single 1-D waveform at a time and returns
    (scores, embeddings, log-mel spectrogram) with one embedding row per
    analysis frame. The batch is processed example-by-example with
    ``tf.map_fn``, and the per-frame embeddings are averaged so each clip
    yields exactly one vector. This matches the one-label-per-clip targets.
    """

    def __init__(self, yamnet, **kwargs):
        super().__init__(**kwargs)
        self._yamnet = yamnet

    def _embed_clip(self, waveform):
        # waveform: 1-D float32 samples for a single clip.
        _, frame_embeddings, _ = self._yamnet(waveform)
        return tf.reduce_mean(frame_embeddings, axis=0)

    def call(self, waveforms):
        # waveforms: (batch, samples) as produced by the Input layer below.
        return tf.map_fn(
            self._embed_clip,
            waveforms,
            fn_output_signature=tf.TensorSpec(shape=(YAMNET_EMBEDDING_DIM,), dtype=tf.float32),
        )

    def compute_output_shape(self, input_shape):
        return (input_shape[0], YAMNET_EMBEDDING_DIM)


# Raw waveform input: one variable-length 1-D waveform per example.
inputs = tf.keras.layers.Input(shape=(None,), dtype=tf.float32, name='input')
embeddings = YamnetClipEmbedding(yamnet_model, name='yamnet_embedding')(inputs)

# New classification head on top of the pooled embeddings.
output = tf.keras.layers.Dense(NUM_CLASSES, activation='softmax', name='predictions')(embeddings)

# Create full model
model = tf.keras.Model(inputs=inputs, outputs=output)
model.trainable = True  # Unfreeze all layers
# NOTE(review): variables of an object returned by hub.load are not
# necessarily registered as Keras trainable weights. If they are not, only the
# Dense head is updated despite `trainable = True`. Confirm with this count.
print(f"Trainable weight tensors: {len(model.trainable_weights)}")

# Compile model
optimizer = tf.keras.optimizers.AdamW(learning_rate=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
model.compile(
    optimizer=optimizer,
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy']
)

# Load metadata for splits
train_meta = pd.read_csv(os.path.join(PROCESSED_ROOT, 'train/train_metadata.csv'))
val_meta = pd.read_csv(os.path.join(PROCESSED_ROOT, 'val/val_metadata.csv'))
test_meta = pd.read_csv(os.path.join(PROCESSED_ROOT, 'test/test_metadata.csv'))

# Create label encoding from the training split. Categories are sorted so the
# id mapping is stable across runs. Classes are assumed to be balanced, so no
# class weights are used.
categories = sorted(train_meta['category'].unique())
category_to_id = {cat: idx for idx, cat in enumerate(categories)}

# The classification head has a fixed number of outputs; fail early if the
# data disagrees rather than training a mis-sized head.
_n_model_outputs = model.outputs[0].shape[-1]
if len(categories) != _n_model_outputs:
    raise ValueError(
        f"Found {len(categories)} categories in train metadata but the model "
        f"head has {_n_model_outputs} outputs: {categories}"
    )


def _encode_labels(split_name, meta_df):
    """Add an int64 'label' column to *meta_df* using ``category_to_id``.

    Raises ValueError if the split contains a category absent from train.
    Otherwise .map() would yield NaN and silently turn the labels into floats.
    """
    unknown = sorted(set(meta_df['category']) - set(category_to_id))
    if unknown:
        raise ValueError(f"{split_name} split has categories not seen in train: {unknown}")
    # int64 matches the tf.int64 output type declared for the dataset loader.
    meta_df['label'] = meta_df['category'].map(category_to_id).astype(np.int64)


_encode_labels('train', train_meta)
_encode_labels('val', val_meta)
_encode_labels('test', test_meta)

# Loader used inside tf.py_function by the tf.data pipeline.
def load_frame(path, label):
    """Load one saved ``.npy`` frame as float32 and pass the label through.

    Args:
        path: Object whose ``.numpy()`` yields the UTF-8 encoded file path as
            bytes (an eager string tensor when called via ``tf.py_function``).
        label: Returned unchanged.

    Returns:
        ``(frame, label)`` where ``frame`` is the loaded array cast to float32.
    """
    file_name = path.numpy().decode('utf-8')
    raw = np.load(file_name)
    return raw.astype(np.float32), label

def make_dataset(meta_df, batch_size=BATCH_SIZE, shuffle=False):
    """Build a batched, prefetched tf.data pipeline of (waveform, label) pairs.

    Args:
        meta_df: DataFrame with a 'frame_path' column (paths to .npy files)
            and an integer 'label' column.
        batch_size: Number of examples per batch.
        shuffle: If True, reshuffle the examples every epoch (training only).

    Returns:
        A tf.data.Dataset yielding (float32 [batch, samples], int64 [batch]).
        NOTE: plain ``batch`` assumes every frame in the split has the same
        length; use ``padded_batch`` if that is not the case.
    """
    paths = meta_df['frame_path'].values
    labels = meta_df['label'].values
    dataset = tf.data.Dataset.from_tensor_slices((paths, labels))
    if shuffle:
        # Shuffle the (cheap) path/label pairs before the expensive file load.
        dataset = dataset.shuffle(
            max(len(paths), 1), seed=RANDOM_SEED, reshuffle_each_iteration=True
        )

    def _load(path, label):
        frame, label = tf.py_function(load_frame, [path, label], [tf.float32, tf.int64])
        # py_function drops static shape information; restore it so batching
        # and the Keras Input(shape=(None,)) see a 1-D waveform and a scalar.
        frame.set_shape([None])
        label.set_shape([])
        return frame, label

    dataset = dataset.map(_load, num_parallel_calls=tf.data.AUTOTUNE)
    dataset = dataset.batch(batch_size).prefetch(tf.data.AUTOTUNE)
    return dataset


# Augmentation in pipeline
def augment(frame, label):
    """Add small Gaussian noise (std 0.005) to a batch of waveforms."""
    noise = tf.random.normal(tf.shape(frame), dtype=tf.float32) * 0.005
    return frame + noise, label


# Only the training split is shuffled and augmented; the evaluation splits
# stay deterministic and in metadata order.
train_ds = (
    make_dataset(train_meta, shuffle=True)
    .map(augment, num_parallel_calls=tf.data.AUTOTUNE)
    .prefetch(tf.data.AUTOTUNE)
)
val_ds = make_dataset(val_meta)
test_ds = make_dataset(test_meta)

# Callbacks (all driven by validation loss).
# The checkpoint stores weights only. The model wraps a TF Hub SavedModel,
# which full-model HDF5 saving is not expected to serialize. The
# '.weights.h5' suffix is accepted by both tf.keras 2 and Keras 3 for
# weights-only saving.
callbacks = [
    tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True),
    tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=3),
    tf.keras.callbacks.ModelCheckpoint(
        os.path.join(MODEL_DIR, 'best_model.weights.h5'),
        monitor='val_loss',
        save_best_only=True,
        save_weights_only=True,
    ),
]

# Fine-tune with early stopping / LR scheduling / checkpointing from `callbacks`.
history = model.fit(train_ds, validation_data=val_ds, epochs=EPOCHS, callbacks=callbacks)

# Held-out evaluation; evaluate() returns [loss, accuracy] in compile() order.
eval_results = model.evaluate(test_ds)
test_loss, test_acc = eval_results
print(f"Test Accuracy: {test_acc:.4f}")

# Custom metrics: one pass over the (unshuffled) test set, collecting true
# labels and argmax predictions in matching order.
y_true = []
y_pred = []
for frames, labels in test_ds:
    # A direct call avoids model.predict's per-call setup and progress bar
    # inside a Python loop; training=False selects inference behaviour.
    preds = model(frames, training=False)
    y_true.extend(labels.numpy())
    y_pred.extend(np.argmax(np.asarray(preds), axis=1))

f1 = f1_score(y_true, y_pred, average='macro')
precision = precision_score(y_true, y_pred, average='macro')
recall = recall_score(y_true, y_pred, average='macro')
# Pin the label set so the matrix is always len(categories) x len(categories).
# It then lines up with the category tick labels even if some class never
# occurs in y_true or y_pred.
cm = confusion_matrix(y_true, y_pred, labels=list(range(len(categories))))

print(f"F1 (macro): {f1:.4f}")
print(f"Precision (macro): {precision:.4f}")
print(f"Recall (macro): {recall:.4f}")

# Plot confusion matrix. sklearn's convention is rows = true labels and
# columns = predicted labels, so the axes are labelled accordingly.
plt.figure(figsize=(8, 6))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=categories, yticklabels=categories)
plt.xlabel('Predicted label')
plt.ylabel('True label')
plt.title('Confusion Matrix')
plt.tight_layout()  # keep long category names from being clipped in the PNG
plt.savefig(os.path.join(MODEL_DIR, 'confusion_matrix.png'))

# Optimization for mobile (TFLite): dynamic-range quantization.
# Optimize.DEFAULT without a representative dataset quantizes weights to
# 8 bits while keeping float32 inputs/outputs. Restricting supported_ops to
# TFLITE_BUILTINS_INT8 demands full-integer quantization. That requires
# `converter.representative_dataset`, and without one convert() raises.
tflite_path = os.path.join(MODEL_DIR, 'yamnet_fine_tuned.tflite')
converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_model = converter.convert()

with open(tflite_path, 'wb') as f:
    f.write(tflite_model)

# Test latency (example on one test sample)
import time

interpreter = tf.lite.Interpreter(model_path=os.path.join(MODEL_DIR, 'yamnet_fine_tuned.tflite'))
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
input_index = input_details[0]['index']

# Sample inference: a batch of one 1-D waveform.
sample_frame = np.load(test_meta['frame_path'].iloc[0]).astype(np.float32)[np.newaxis, :]

# The converted model has a dynamic waveform length. Size the input to this
# sample before allocating, otherwise set_tensor rejects the shape.
interpreter.resize_tensor_input(input_index, list(sample_frame.shape))
interpreter.allocate_tensors()

# Warm-up run so one-time initialisation is not counted as latency.
interpreter.set_tensor(input_index, sample_frame)
interpreter.invoke()

start = time.perf_counter()  # monotonic, high-resolution timer
interpreter.set_tensor(input_index, sample_frame)
interpreter.invoke()
output = interpreter.get_tensor(output_details[0]['index'])
latency = time.perf_counter() - start
print(f"Inference latency: {latency:.4f} seconds")