Load the pre-trained YAMNet model, replace its classification head, unfreeze all layers, and train end-to-end with a low learning rate.

In [None]:
import os
import numpy as np
import pandas as pd
import tensorflow as tf
import tensorflow_hub as hub
from tensorflow import keras
from tensorflow.keras import layers
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import classification_report, confusion_matrix
from datetime import datetime
import json

# Suppress warnings
# NOTE(review): this environment variable is assigned after `import tensorflow`
# has already executed in the import block above — confirm it still takes
# effect, or move the assignment ahead of the import.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
tf.get_logger().setLevel('ERROR')

# Configuration
# PROCESSED_ROOT and METADATA_PATH are not referenced by the cells visible in
# this notebook; the split metadata is read from FEATURES_DIR instead.
PROCESSED_ROOT = '../data/processed'
METADATA_PATH = os.path.join(PROCESSED_ROOT, 'processed_frames_metadata.csv')
# Directory holding label_mapping.npy and the train/val/test metadata CSVs.
FEATURES_DIR = '../data/approach2/features'
# Output locations for model checkpoints and for logs/plots/reports.
MODELS_DIR = '../models/models_approach2/yamnet_finetuned'
RESULTS_DIR = '../results/results_approach2/yamnet_finetuned'
# TF-Hub handle passed to hub.KerasLayer when the model is built.
YAMNET_MODEL_HANDLE = 'https://tfhub.dev/google/yamnet/1'

# Sample rate in Hz; TARGET_SR * 0.96 is used below as the dummy-input length.
TARGET_SR = 16000
BATCH_SIZE = 16
EPOCHS = 50
INITIAL_LR = 1e-5  # Very low for fine-tuning
# Shared seed for NumPy, TensorFlow and the dataset shuffle.
RANDOM_SEED = 42

# Set seeds
np.random.seed(RANDOM_SEED)
tf.random.set_seed(RANDOM_SEED)

# Create the output directories up front so later saves cannot fail on a
# missing folder.
os.makedirs(MODELS_DIR, exist_ok=True)
os.makedirs(RESULTS_DIR, exist_ok=True)


In [None]:
# 1. Load Data

print("\nLoading data splits...")

# The label mapping is a dict stored inside a .npy file, hence
# allow_pickle=True and the .item() call to unwrap the 0-d object array.
label_mapping = np.load(
    os.path.join(FEATURES_DIR, 'label_mapping.npy'),
    allow_pickle=True,
).item()
categories = label_mapping['categories']
num_classes = len(categories)

print(f"Classes: {categories}", f"Number of classes: {num_classes}", sep="\n")

# One metadata table per split, read in train / val / test order.
train_meta, val_meta, test_meta = (
    pd.read_csv(os.path.join(FEATURES_DIR, csv_name))
    for csv_name in ('train_metadata.csv', 'val_metadata.csv', 'test_metadata.csv')
)

# Report how many frames each split contains.
print(
    "\nDataset splits:",
    f"  Training:   {len(train_meta)} frames",
    f"  Validation: {len(val_meta)} frames",
    f"  Test:       {len(test_meta)} frames",
    sep="\n",
)



In [None]:

# 2. Create TensorFlow Datasets
# Progress banner for the dataset-construction cell.
print("\nCreating TensorFlow datasets...")

def load_all_frames(metadata_df, label='train'):
    """Load all frames into memory (works for reasonably-sized datasets).

    Args:
        metadata_df: DataFrame with a ``frame_path`` column (path of a ``.npy``
            file) and a ``label`` column (value convertible to ``int``).
            A missing column raises ``KeyError`` immediately instead of
            producing one warning per row.
        label: Split name used only in the progress messages.

    Returns:
        Tuple ``(audio_data, labels)`` of NumPy arrays with the same length.
        Frames are cast to ``float32``. Rows whose file cannot be read or
        whose label cannot be converted are skipped with a warning.
    """
    print(f"  Loading {label} frames into memory...")
    audio_data = []
    labels = []

    for frame_path, raw_label in zip(metadata_df['frame_path'], metadata_df['label']):
        try:
            audio = np.load(frame_path).astype(np.float32)
            class_index = int(raw_label)
        except (OSError, ValueError, TypeError, EOFError) as e:
            print(f"    Warning: Could not load {frame_path}: {e}")
            continue

        # Append only after BOTH values were obtained, so a bad label can
        # never leave the audio list one element longer than the label list.
        audio_data.append(audio)
        labels.append(class_index)

    audio_data = np.array(audio_data)
    labels = np.array(labels)

    print(f"    Loaded {len(audio_data)} frames")
    return audio_data, labels

# Materialise every split as in-memory arrays, in train / val / test order.
(train_X, train_y), (val_X, val_y), (test_X, test_y) = (
    load_all_frames(split_meta, split_name)
    for split_meta, split_name in (
        (train_meta, 'training'),
        (val_meta, 'validation'),
        (test_meta, 'test'),
    )
)

# Wrap in-memory arrays in a tf.data pipeline (slicing the arrays directly is
# what the original author relied on for proper shape inference).
def create_dataset_from_arrays(X, y, batch_size, shuffle=True, seed=RANDOM_SEED):
    """Build a batched, prefetched ``tf.data.Dataset`` of ``(X, y)`` pairs.

    Args:
        X: Array of inputs; sliced along its first axis.
        y: Array of targets aligned with ``X``.
        batch_size: Number of examples per batch.
        shuffle: When true, shuffle with a buffer covering all of ``X``.
        seed: Seed handed to the shuffle step.

    Returns:
        The resulting ``tf.data.Dataset``.
    """
    pipeline = tf.data.Dataset.from_tensor_slices((X, y))

    if shuffle:
        # A buffer as large as the data gives a full shuffle.
        pipeline = pipeline.shuffle(buffer_size=len(X), seed=seed)

    return pipeline.batch(batch_size).prefetch(tf.data.AUTOTUNE)

# Only the training split is shuffled; validation and test keep file order.
train_dataset = create_dataset_from_arrays(
    X=train_X, y=train_y, batch_size=BATCH_SIZE, shuffle=True
)
val_dataset = create_dataset_from_arrays(
    X=val_X, y=val_y, batch_size=BATCH_SIZE, shuffle=False
)
test_dataset = create_dataset_from_arrays(
    X=test_X, y=test_y, batch_size=BATCH_SIZE, shuffle=False
)

print("TensorFlow datasets created with proper shape inference")


In [None]:
# 3. Build Fine-Tuning Model
# Progress banner for the model-definition cell.
print("\nBuilding fine-tuning model architecture...")

class YAMNetFineTuned(keras.Model):
    """
    YAMNet model with custom classification head.
    All layers are trainable for end-to-end fine-tuning.

    Args:
        num_classes: Number of output classes of the softmax head.
        yamnet_model_handle: TF-Hub handle (or path) passed to
            ``hub.KerasLayer`` to load YAMNet.
    """
    def __init__(self, num_classes, yamnet_model_handle):
        super().__init__()

        # Remember the constructor arguments so get_config() can return them
        # and a saved model can be re-instantiated on load.
        self.num_classes = num_classes
        self.yamnet_model_handle = yamnet_model_handle

        # Load YAMNet
        self.yamnet = hub.KerasLayer(
            yamnet_model_handle,
            trainable=True,  # Make YAMNet layers trainable
            name='yamnet'
        )

        # Custom classification head
        self.classifier = keras.Sequential([
            layers.Dense(256, activation='relu', name='fc1'),
            layers.BatchNormalization(name='bn1'),
            layers.Dropout(0.3, name='dropout1'),
            layers.Dense(128, activation='relu', name='fc2'),
            layers.BatchNormalization(name='bn2'),
            layers.Dropout(0.2, name='dropout2'),
            layers.Dense(num_classes, activation='softmax', name='output')
        ], name='classifier')

    def call(self, inputs, training=False):
        """Process batch of waveforms through YAMNet.

        Args:
            inputs: Batch of waveforms, one per row.
            training: Forwarded to the classifier head (controls dropout and
                batch normalisation).

        Returns:
            Softmax class probabilities, one row per input waveform.
        """
        # Input shape: (batch_size, 15360) for 0.96s at 16kHz

        # Process each waveform through YAMNet
        def process_single_waveform(waveform):
            # YAMNet expects 1D audio
            scores, embeddings, spectrogram = self.yamnet(waveform)
            # embeddings shape: (num_frames, 1024)
            # Average across temporal dimension
            embedding = tf.reduce_mean(embeddings, axis=0)
            return embedding

        # Apply to batch: map_fn feeds the 1-D rows to YAMNet one at a time
        embeddings = tf.map_fn(
            process_single_waveform,
            inputs,
            fn_output_signature=tf.float32
        )

        # Pass through classifier
        outputs = self.classifier(embeddings, training=training)
        return outputs

    def get_config(self):
        """Return the constructor arguments needed to rebuild this model.

        Only the two ``__init__`` parameters are included, so the inherited
        ``from_config`` can rebuild the model by calling ``cls(**config)``.
        """
        return {
            'num_classes': self.num_classes,
            'yamnet_model_handle': self.yamnet_model_handle,
        }

# Instantiate the network
print("Building YAMNet fine-tuning model...")
model = YAMNetFineTuned(
    num_classes=num_classes,
    yamnet_model_handle=YAMNET_MODEL_HANDLE,
)

# One forward pass on a random 0.96 s clip builds the model before counting
dummy_input = tf.random.normal([1, int(TARGET_SR * 0.96)])
_ = model(dummy_input)

print("Model built")
print(f"  Total parameters: {model.count_params():,}")

# Sum the element counts of every trainable weight tensor
trainable_count = sum(tf.size(weight).numpy() for weight in model.trainable_weights)
print(f"  Trainable parameters: {trainable_count:,}")

model.summary()


In [None]:
# 4. Compile Model

print("\nCompiling model...")

# Calculate class weights for imbalanced data.
# weight = total / (num_classes * count), so rarer classes get larger weights.
class_counts = pd.Series(train_y).value_counts().sort_index()
total_samples = len(train_y)
# Key each weight by the actual label value (the Series index), not by its
# position: if a class is absent from train_y, positional numbering would
# shift every later weight onto the wrong class.
class_weights = {
    int(class_index): total_samples / (num_classes * int(count))
    for class_index, count in class_counts.items()
}

print(f"Class weights: {class_weights}")

# Optimizer with low learning rate for fine-tuning
optimizer = keras.optimizers.Adam(learning_rate=INITIAL_LR)

# Compile: labels are integer class indices, hence the sparse loss
model.compile(
    optimizer=optimizer,
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy']
)

print(f"Model compiled with learning rate: {INITIAL_LR}")


In [None]:
# 5. Callbacks
print("\nSetting up callbacks...")

# Keep only the checkpoint with the highest validation accuracy.
checkpoint_path = os.path.join(MODELS_DIR, 'best_model.keras')
checkpoint_callback = keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_path,
    monitor='val_accuracy',
    mode='max',
    save_best_only=True,
    verbose=1,
)

# Stop after 15 epochs without val_loss improvement and roll back to the
# best weights. Note this watches val_loss while the checkpoint above
# watches val_accuracy.
early_stopping = keras.callbacks.EarlyStopping(
    monitor='val_loss',
    patience=15,
    restore_best_weights=True,
    verbose=1,
)

# Halve the learning rate after 5 stagnant epochs, down to a floor of 1e-7.
lr_scheduler = keras.callbacks.ReduceLROnPlateau(
    monitor='val_loss',
    factor=0.5,
    patience=5,
    min_lr=1e-7,
    verbose=1,
)

# Per-epoch metrics are written to a CSV file in the results directory.
csv_logger = keras.callbacks.CSVLogger(
    filename=os.path.join(RESULTS_DIR, 'training_log.csv'),
)

callbacks = [checkpoint_callback, early_stopping, lr_scheduler, csv_logger]

print("Callbacks configured")


In [None]:
# 6. Train Model
print("\nTraining model...")

# Print the run configuration between two 70-character rules.
print(
    "=" * 70,
    "Configuration:",
    f"  Epochs: {EPOCHS}",
    f"  Batch size: {BATCH_SIZE}",
    f"  Initial learning rate: {INITIAL_LR}",
    "  Class weights: Enabled",
    "=" * 70,
    sep="\n",
)

# Fit on the training pipeline, validating each epoch; class_weight
# re-weights the loss per class and the callbacks handle checkpointing,
# early stopping, LR decay and CSV logging.
history = model.fit(
    x=train_dataset,
    epochs=EPOCHS,
    validation_data=val_dataset,
    class_weight=class_weights,
    callbacks=callbacks,
    verbose=1,
)

print("\nTraining completed!")


In [None]:

# 7. Plot Training History
print("\nPlotting training history...")

fig, axes = plt.subplots(1, 2, figsize=(14, 5))


def _draw_history_panel(panel, train_key, val_key, train_label, val_label,
                        title, y_label):
    """Draw one training/validation curve pair from ``history`` on ``panel``."""
    panel.plot(history.history[train_key], label=train_label, linewidth=2)
    panel.plot(history.history[val_key], label=val_label, linewidth=2)
    panel.set_title(title, fontsize=14, fontweight='bold')
    panel.set_xlabel('Epoch', fontsize=12)
    panel.set_ylabel(y_label, fontsize=12)
    panel.legend(fontsize=10)
    panel.grid(True, alpha=0.3)


# Left panel: loss
_draw_history_panel(
    axes[0], 'loss', 'val_loss',
    'Training Loss', 'Validation Loss',
    'Training and Validation Loss', 'Loss',
)

# Right panel: accuracy
_draw_history_panel(
    axes[1], 'accuracy', 'val_accuracy',
    'Training Accuracy', 'Validation Accuracy',
    'Training and Validation Accuracy', 'Accuracy',
)

plt.tight_layout()
# Save before show() so the written file contains the rendered figure.
plt.savefig(os.path.join(RESULTS_DIR, 'training_history.png'), dpi=150)
plt.show()

print(f"Training history saved to {RESULTS_DIR}/training_history.png")


In [None]:
# 8. Evaluate on Test Set
print("\nEvaluating on test set...")

# Load best model (the checkpoint with the highest validation accuracy)
model = keras.models.load_model(checkpoint_path, custom_objects={'YAMNetFineTuned': YAMNetFineTuned})
print(f"Loaded best model from {checkpoint_path}")

# Get predictions: argmax over the softmax outputs gives the class index
y_pred = model.predict(test_X, verbose=0)
y_pred = np.argmax(y_pred, axis=1)

# Calculate metrics
test_accuracy = np.mean(test_y == y_pred)
print(f"\nTest Accuracy: {test_accuracy:.4f}")

# Pass the full label set explicitly. Without it, scikit-learn derives the
# labels from the data, so a class missing from both test_y and y_pred makes
# classification_report reject target_names (length mismatch) and shrinks the
# confusion matrix so the tick labels no longer line up.
class_indices = list(range(num_classes))

# Classification report
print("\nClassification Report:")
print("-"*70)
report = classification_report(test_y, y_pred, labels=class_indices,
                               target_names=categories, digits=4)
print(report)

# Save classification report (dict form, also reused by later cells)
report_dict = classification_report(test_y, y_pred, labels=class_indices,
                                    target_names=categories,
                                    output_dict=True)
with open(os.path.join(RESULTS_DIR, 'classification_report.json'), 'w') as f:
    json.dump(report_dict, f, indent=2)

# Confusion Matrix (rows: true label, columns: predicted label)
cm = confusion_matrix(test_y, y_pred, labels=class_indices)

plt.figure(figsize=(10, 8))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
           xticklabels=categories, yticklabels=categories,
           cbar_kws={'label': 'Count'})
plt.title('Confusion Matrix - Fine-Tuned YAMNet (Test Set)', 
         fontsize=16, fontweight='bold')
plt.xlabel('Predicted Label', fontsize=12)
plt.ylabel('True Label', fontsize=12)
plt.tight_layout()
plt.savefig(os.path.join(RESULTS_DIR, 'confusion_matrix_test.png'), dpi=150)
plt.show()

print(f"Confusion matrix saved to {RESULTS_DIR}/confusion_matrix_test.png")

# Per-class metrics
print("\nPer-class Performance:")
print("-"*70)
print(f"{'Class':<20s} {'Precision':>10s} {'Recall':>10s} {'F1-Score':>10s} {'Support':>10s}")
print("-"*70)

for cat in categories:
    precision = report_dict[cat]['precision']
    recall = report_dict[cat]['recall']
    f1 = report_dict[cat]['f1-score']
    support = report_dict[cat]['support']
    print(f"{cat:<20s} {precision:>10.4f} {recall:>10.4f} {f1:>10.4f} {support:>10.0f}")


In [None]:

# 9. Compare with Baseline Classifiers
print("\nComparing with baseline classifiers...")

# The comparison is skipped entirely when the baseline CSV is absent.
comparison_file = '../results/model_comparison.csv'
if os.path.exists(comparison_file):
    baseline_df = pd.read_csv(comparison_file)

    # Row for this model, using the weighted-average test metrics
    weighted_avg = report_dict['weighted avg']
    yamnet_results = {
        'Model': 'Fine-Tuned YAMNet',
        'Accuracy': report_dict['accuracy'],
        'Precision': weighted_avg['precision'],
        'Recall': weighted_avg['recall'],
        'F1-Score': weighted_avg['f1-score'],
    }

    # Append the row and rank all models by F1, best first
    yamnet_row = pd.DataFrame([yamnet_results])
    comparison_df = pd.concat([baseline_df, yamnet_row], ignore_index=True)
    comparison_df = comparison_df.sort_values('F1-Score', ascending=False)

    print("\nModel Comparison (including Fine-Tuned YAMNet):")
    print(comparison_df.to_string(index=False))

    comparison_df.to_csv(
        os.path.join(RESULTS_DIR, 'final_model_comparison.csv'),
        index=False,
    )

    # Visualize: one horizontal bar chart per metric on a 2x2 grid
    fig, axes = plt.subplots(2, 2, figsize=(14, 10))

    metrics = ['Accuracy', 'Precision', 'Recall', 'F1-Score']
    # axes.flat walks the grid row by row, matching the metric order
    for ax, metric in zip(axes.flat, metrics):
        sorted_df = comparison_df.sort_values(metric)

        # Highlight this notebook's model against the baselines
        colors = [
            'coral' if model_name == 'Fine-Tuned YAMNet' else 'steelblue'
            for model_name in sorted_df['Model']
        ]

        ax.barh(sorted_df['Model'], sorted_df[metric], color=colors)
        ax.set_xlabel(metric, fontsize=12)
        ax.set_title(f'{metric} Comparison', fontsize=14, fontweight='bold')
        ax.set_xlim([0, 1])
        ax.grid(axis='x', alpha=0.3)

        # Print the numeric value just right of each bar
        for bar_pos, value in enumerate(sorted_df[metric]):
            ax.text(value + 0.01, bar_pos, f'{value:.4f}', va='center', fontsize=9)

    plt.tight_layout()
    plt.savefig(
        os.path.join(RESULTS_DIR, 'final_comparison.png'),
        dpi=150,
        bbox_inches='tight',
    )
    plt.show()

    print(f"Comparison saved to {RESULTS_DIR}/final_comparison.png")


In [None]:
# 10. Save Final Model
print("\nSaving final model...")

final_model_path = os.path.join(MODELS_DIR, 'yamnet_finetuned_final.keras')
model.save(final_model_path)
print(f"Model saved to {final_model_path}")

# Describe the run in a JSON-friendly dict; values are cast to built-in
# int/float before being serialised.
weighted_scores = report_dict['weighted avg']
model_config = dict(
    model_type='YAMNet Fine-Tuned (Approach 2)',
    num_classes=num_classes,
    categories=categories,
    sample_rate=TARGET_SR,
    frame_duration=0.96,
    total_parameters=int(model.count_params()),
    trainable_parameters=int(trainable_count),
    training_config=dict(
        batch_size=BATCH_SIZE,
        epochs=EPOCHS,
        initial_lr=INITIAL_LR,
        optimizer='Adam',
    ),
    test_performance=dict(
        accuracy=float(test_accuracy),
        precision=float(weighted_scores['precision']),
        recall=float(weighted_scores['recall']),
        f1_score=float(weighted_scores['f1-score']),
    ),
)

# Write the configuration next to the saved model
with open(os.path.join(MODELS_DIR, 'model_config.json'), 'w') as f:
    json.dump(model_config, f, indent=2)

print("Model configuration saved")


In [None]:
# 11. Summary
print("\nSuccessfully fine-tuned complete YAMNet model")

# Headline test metrics (weighted averages across classes)
summary_scores = report_dict['weighted avg']
print(
    "\nTest Set Performance:",
    f"  Accuracy:  {test_accuracy:.4f}",
    f"  Precision: {summary_scores['precision']:.4f}",
    f"  Recall:    {summary_scores['recall']:.4f}",
    f"  F1-Score:  {summary_scores['f1-score']:.4f}",
    sep="\n",
)

# Model size is the on-disk size of the saved file, converted to MiB
print(
    "\nModel Details:",
    f"  Total parameters: {model.count_params():,}",
    f"  Trainable parameters: {trainable_count:,}",
    f"  Model size: ~{os.path.getsize(final_model_path) / (1024*1024):.2f} MB",
    sep="\n",
)

print(
    "\nSaved artifacts:",
    f"  - Model: {final_model_path}",
    f"  - Config: {MODELS_DIR}/model_config.json",
    f"  - Results: {RESULTS_DIR}/",
    sep="\n",
)
