"""
Approach 2: Fine-tune YAMNet End-to-End - TRAINING ONLY
This script:
1. Loads pre-trained YAMNet
2. Replaces 521-class head with 5-class head
3. Unfreezes all layers for fine-tuning
4. Trains on raw audio with data augmentation
5. Validates on validation set (NO TEST SET EVALUATION)
"""



# In [None]:
import os
import numpy as np
import pandas as pd
import tensorflow as tf
import tensorflow_hub as hub
from sklearn.metrics import classification_report, f1_score
import matplotlib.pyplot as plt
from datetime import datetime
import json



# In [None]:
# ==================== CONFIGURATION ====================
# Input data, pre-trained backbone, and output locations.
SPLIT_ROOT = '../data/split_processed'
YAMNET_HANDLE = 'https://tfhub.dev/google/yamnet/1'
MODELS_DIR = '../models/models_approach2/yamnet_finetuned'
RESULTS_DIR = '../results/results_approach2/yamnet_finetuned'

# Create the output folders up front so the checkpoint / CSV callbacks can write into them.
os.makedirs(MODELS_DIR, exist_ok=True)
os.makedirs(RESULTS_DIR, exist_ok=True)

# Audio settings.
SAMPLE_RATE = 16_000    # Hz
FRAME_LENGTH = 15_360   # samples per clip after pad/trim (0.96 s at SAMPLE_RATE)

# Optimisation settings.
BATCH_SIZE = 16
EPOCHS = 50
LEARNING_RATE = 1e-5    # kept small, presumably because the whole backbone is fine-tuned
PATIENCE_EARLY = 5      # epochs without val_macro_f1 improvement before early stopping
PATIENCE_LR = 3         # epochs without val_macro_f1 improvement before halving the LR

# Global seeds for reproducibility.
tf.random.set_seed(42)
np.random.seed(42)





# In [None]:
# ==================== LOAD METADATA ====================
# Folder names on disk, in the order the split directories are scanned.
CATEGORIES = ['alarm_clock', 'car_horn', 'glass_breaking', 'gunshot', 'siren']

# Folder name -> human-readable class name.
category_to_label = dict(
    alarm_clock='Alarm_Clock',
    car_horn='Car_Horn',
    glass_breaking='Glass_Breaking',
    gunshot='Gunshot',
    siren='Siren',
)

# Integer ids follow the alphabetical order of the readable class names.
label_to_id = {name: idx for idx, name in enumerate(sorted(category_to_label.values()))}
id_to_label = {idx: name for name, idx in label_to_id.items()}

print(f"Classes: {list(label_to_id.keys())}")
print(f"Label mapping: {label_to_id}")

# ==================== PREPARE FILE LISTS ====================
def prepare_split(split):
    """Collect the ``.npy`` clip paths and integer labels for one split.

    Scans ``SPLIT_ROOT/<split>/<category>/`` for every entry in
    ``CATEGORIES``; category folders that do not exist are skipped.

    Args:
        split: Name of the split sub-directory (e.g. ``'train'`` or ``'val'``).

    Returns:
        ``(paths, labels)``: two parallel lists holding the file paths and
        their integer class ids from ``label_to_id``.
    """
    paths, labels = [], []
    for c in CATEGORIES:
        d = os.path.join(SPLIT_ROOT, split, c)
        if not os.path.isdir(d):
            continue
        label = label_to_id[category_to_label[c]]  # same id for every file in this folder
        # os.listdir order is filesystem-dependent; sort it so the seeded
        # shuffle in the input pipeline sees the same ordering on every run.
        for f in sorted(os.listdir(d)):
            if f.endswith('.npy'):
                paths.append(os.path.join(d, f))
                labels.append(label)
    return paths, labels

train_paths, train_labels = prepare_split('train')
val_paths, val_labels = prepare_split('val')

# Both splits are required: training data, plus validation data for the
# val_macro_f1-driven callbacks. Fail early with a clear message instead of a
# ZeroDivisionError in the percentage computation below.
for split_name, split_paths in [('train', train_paths), ('val', val_paths)]:
    if not split_paths:
        raise FileNotFoundError(
            f"No .npy files found for split '{split_name}' under "
            f"{os.path.join(SPLIT_ROOT, split_name)}"
        )

print("\nDataset sizes:")
print(f"  Train: {len(train_paths):,} samples")
print(f"  Val:   {len(val_paths):,} samples")

# Class distribution
print("\nClass distribution:")
for split_name, labels in [('Train', train_labels), ('Val', val_labels)]:
    print(f"\n{split_name}:")
    for label_id, label_name in id_to_label.items():
        count = labels.count(label_id)
        pct = 100 * count / len(labels)  # non-zero denominator: empty splits rejected above
        print(f"  {label_name:<20} {count:>4} ({pct:>5.2f}%)")



# In [None]:
# ==================== DATA LOADING FUNCTIONS ====================
def load_audio(fp):
    """Load one clip from a ``.npy`` file as a fixed-length float32 vector.

    Called through ``tf.py_function``, so ``fp`` arrives as an eager string
    tensor whose ``.numpy()`` value is the file path as bytes.

    The array is flattened to 1-D, then zero-padded at the end or truncated
    so that it has exactly ``FRAME_LENGTH`` samples.
    """
    path = fp.numpy().decode()
    wave = np.load(path)
    if wave.ndim > 1:
        # NOTE(review): flatten() interleaves multi-channel data; this assumes
        # the stored arrays are effectively mono (e.g. shape (N, 1)). Confirm.
        wave = wave.flatten()
    missing = FRAME_LENGTH - len(wave)
    wave = np.pad(wave, (0, missing)) if missing > 0 else wave[:FRAME_LENGTH]
    return wave.astype(np.float32)

def augment_audio(x):
    """Apply random gain and additive Gaussian noise, then clip to [-1, 1].

    The noise standard deviation is drawn from U(0, 0.005) and the gain from
    U(0.8, 1.2), each once per call (i.e. once per clip inside ``tf.data``).
    """
    noise_scale = tf.random.uniform([], 0, 0.005)
    jitter = tf.random.normal(tf.shape(x), stddev=noise_scale)
    gain = tf.random.uniform([], 0.8, 1.2)
    mixed = x * gain + jitter
    return tf.clip_by_value(mixed, clip_value_min=-1.0, clip_value_max=1.0)

def preprocess(fp, y, aug=False):
    """Map a ``(path, label)`` element to ``(waveform, label)`` for ``tf.data``.

    ``load_audio`` runs eagerly via ``tf.py_function``. Its static shape is
    then pinned to ``[FRAME_LENGTH]`` because ``py_function`` outputs carry
    no shape information. When ``aug`` is true, ``augment_audio`` is applied.
    """
    waveform = tf.py_function(func=load_audio, inp=[fp], Tout=tf.float32)
    waveform.set_shape([FRAME_LENGTH])
    return (augment_audio(waveform) if aug else waveform), y

def make_ds(paths, labels, aug=False):
    """Build a batched, prefetched ``tf.data`` pipeline over (path, label) pairs.

    When ``aug`` is true, the pairs are first shuffled (buffer equal to the
    dataset size, seed 42), and ``preprocess`` applies augmentation.
    """
    pairs = tf.data.Dataset.from_tensor_slices((paths, labels))
    if aug:
        pairs = pairs.shuffle(len(paths), seed=42)
    examples = pairs.map(
        lambda path, label: preprocess(path, label, aug),
        num_parallel_calls=tf.data.AUTOTUNE,
    )
    return examples.batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)



# In [None]:
# Only the training pipeline is shuffled and augmented.
train_ds = make_ds(train_paths, train_labels, aug=True)
val_ds = make_ds(val_paths, val_labels, aug=False)


# In [None]:
# ==================== BUILD MODEL (FULL FINE-TUNING) ====================
print("\nLoading YAMNet as tf.Module...")

# trainable=True marks the hub module's variables as trainable, so the
# backbone is updated together with the new classification head.
yamnet_layer = hub.KerasLayer(handle=YAMNET_HANDLE, trainable=True)

class FineTunedYAMNet(tf.keras.Model):
    """YAMNet backbone + time-averaged embeddings + softmax classification head.

    Each waveform in the batch is run through YAMNet separately (YAMNet takes
    a single 1-D waveform). Its per-frame 1024-d embeddings are averaged over
    time, and a dense softmax layer produces class probabilities.

    Args:
        n_classes: Number of output classes.
        yamnet: Optional YAMNet layer to use as the backbone. Defaults to the
            module-level ``yamnet_layer`` so existing calls keep working.
    """

    def __init__(self, n_classes, yamnet=None):
        super().__init__()
        # Fall back to the notebook-level layer for backward compatibility.
        self.yamnet = yamnet_layer if yamnet is None else yamnet
        self.pool = tf.keras.layers.GlobalAveragePooling1D()
        self.head = tf.keras.layers.Dense(n_classes, activation='softmax')

    def call(self, x):
        """Return class probabilities for a batch of waveforms.

        Args:
            x: Float32 waveforms, presumably ``[batch, FRAME_LENGTH]`` (the
                shape used for the dummy build call). TODO confirm for other callers.

        Returns:
            ``[batch, n_classes]`` softmax probabilities.
        """
        def per_example(wav):
            # YAMNet yields three outputs (scores, embeddings, spectrogram per
            # the TF Hub model card); only the [T, 1024] embeddings are used.
            _, embeddings, _ = self.yamnet(wav)
            return embeddings

        # `dtype=` is deprecated in tf.map_fn; `fn_output_signature=` is the
        # supported argument with the same meaning.
        embeddings = tf.map_fn(
            per_example, x, fn_output_signature=tf.float32, parallel_iterations=4
        )  # [batch, T, 1024]
        pooled = self.pool(embeddings)  # [batch, 1024]
        return self.head(pooled)  # [batch, n_classes]

model = FineTunedYAMNet(n_classes=len(label_to_id))
_ = model(tf.random.normal(shape=(1, FRAME_LENGTH)))  # dummy forward pass creates all weights



# In [None]:
# ================== TRAINING ==================
class MacroF1(tf.keras.metrics.Metric):
    """Streaming macro-averaged F1 for integer labels vs. class-probability outputs.

    Per-class true-positive, false-positive and false-negative counts are
    accumulated over every batch seen since the last reset, and F1 is
    computed from those totals in ``result()``. This yields the epoch-level
    macro-F1, consistent with ``sklearn.metrics.f1_score(average='macro')``.
    Averaging per-batch F1 scores would instead be biased low, because a
    class missing from a small batch scores F1 = 0 for that batch.

    Args:
        n_classes: Number of classes; predictions are the ``argmax`` over the
            last axis of ``y_pred``.
        name: Metric name (Keras logs it as ``val_<name>`` on validation data).
    """

    def __init__(self, n_classes, name="macro_f1", **kwargs):
        super().__init__(name=name, **kwargs)
        self.n_classes = n_classes
        self.tp = self.add_weight(name="tp", shape=(n_classes,), initializer="zeros", dtype=tf.float32)
        self.fp = self.add_weight(name="fp", shape=(n_classes,), initializer="zeros", dtype=tf.float32)
        self.fn = self.add_weight(name="fn", shape=(n_classes,), initializer="zeros", dtype=tf.float32)

    def update_state(self, y_true, y_pred, sample_weight=None):
        # Flatten so labels shaped [B] or [B, 1] are handled identically.
        y_true = tf.reshape(tf.cast(y_true, tf.int32), [-1])
        y_pred = tf.reshape(tf.argmax(y_pred, axis=-1, output_type=tf.int32), [-1])
        y_true_oh = tf.one_hot(y_true, depth=self.n_classes)  # float32
        y_pred_oh = tf.one_hot(y_pred, depth=self.n_classes)

        # Per-sample weights broadcast across the class axis.
        if sample_weight is None:
            w = 1.0
        else:
            w = tf.reshape(tf.cast(sample_weight, tf.float32), [-1, 1])

        self.tp.assign_add(tf.reduce_sum(w * y_true_oh * y_pred_oh, axis=0))
        self.fp.assign_add(tf.reduce_sum(w * (1.0 - y_true_oh) * y_pred_oh, axis=0))
        self.fn.assign_add(tf.reduce_sum(w * y_true_oh * (1.0 - y_pred_oh), axis=0))

    def result(self):
        precision = tf.math.divide_no_nan(self.tp, self.tp + self.fp)
        recall = tf.math.divide_no_nan(self.tp, self.tp + self.fn)
        f1 = tf.math.divide_no_nan(2.0 * precision * recall, precision + recall)
        # Like sklearn's default, average only over classes that occur in
        # y_true or y_pred; never-seen classes would otherwise count as 0.
        present = tf.cast(self.tp + self.fp + self.fn > 0, f1.dtype)
        return tf.math.divide_no_nan(tf.reduce_sum(f1 * present), tf.reduce_sum(present))

    def reset_state(self):
        for v in (self.tp, self.fp, self.fn):
            v.assign(tf.zeros([self.n_classes], dtype=tf.float32))

    def get_config(self):
        # Include n_classes so the metric can be re-created when a saved
        # model (with its compiled metrics) is reloaded.
        config = super().get_config()
        config.update({"n_classes": self.n_classes})
        return config



# Integer labels + softmax outputs -> SparseCategoricalCrossentropy on probabilities.
model.compile(
    optimizer=tf.keras.optimizers.AdamW(learning_rate=LEARNING_RATE, weight_decay=1e-4),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(),
    metrics=['accuracy', MacroF1(n_classes=len(label_to_id))],
)

# Every model-selection callback tracks validation macro-F1 (higher is better).
cb = [
    tf.keras.callbacks.EarlyStopping(
        monitor='val_macro_f1',
        mode='max',
        patience=PATIENCE_EARLY,
        restore_best_weights=True,
    ),
    tf.keras.callbacks.ReduceLROnPlateau(
        monitor='val_macro_f1',
        mode='max',
        factor=0.5,
        patience=PATIENCE_LR,
    ),
    tf.keras.callbacks.ModelCheckpoint(
        filepath=os.path.join(MODELS_DIR, 'best.keras'),
        monitor='val_macro_f1',
        mode='max',
        save_best_only=True,
    ),
    tf.keras.callbacks.CSVLogger(filename=os.path.join(RESULTS_DIR, 'train_log.csv')),
]

history = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=EPOCHS,
    callbacks=cb,
)

# In [None]:
# ================== VALIDATION REPORT ==================
# Evaluate on the validation split only. This script deliberately performs no
# test-set evaluation, and no `test_ds` exists in it.
class_ids = list(range(len(id_to_label)))
y_true, y_pred = [], []
for x, y in val_ds:
    # predict_on_batch avoids model.predict()'s heavy per-call setup cost.
    p = model.predict_on_batch(x)
    y_true.extend(y.numpy())
    y_pred.extend(np.argmax(p, axis=1))
val_f1 = f1_score(y_true, y_pred, average='macro')
# Pass `labels` explicitly so `target_names` stays aligned (instead of raising)
# when a class is absent from both y_true and y_pred.
print(classification_report(
    y_true,
    y_pred,
    labels=class_ids,
    target_names=[id_to_label[i] for i in class_ids],
    zero_division=0,
))
print("Val Macro-F1:", val_f1)

# ================== SAVE MODEL ==================
model.save(os.path.join(MODELS_DIR, 'yamnet_finetuned_final.keras'))