# In [None]:
"""
Approach 3: Progressive Unfreezing - Stage 1
==============================================
Initial training with frozen YAMNet base model.
Train only the new classification head.
"""

import os
import numpy as np
import pandas as pd
import tensorflow as tf
import tensorflow_hub as hub
from tensorflow import keras
from tensorflow.keras import layers
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import classification_report, confusion_matrix
import json
from datetime import datetime

# Silence TensorFlow's C++ INFO/WARNING output and the Python-side TF logger.
# NOTE(review): tensorflow has already been imported above, so this env var
# may not affect messages emitted during import — set it before
# `import tensorflow` if those messages matter.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
tf.get_logger().setLevel('ERROR')

# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
# Input feature location and output directories for this approach.
FEATURES_DIR = '../data/features'
MODELS_DIR = '../models/progressive_unfreezing'
RESULTS_DIR = '../results/progressive_unfreezing'
YAMNET_MODEL_HANDLE = 'https://tfhub.dev/google/yamnet/1'

# Target sample rate; multiplied by 0.96 below to size the dummy input.
TARGET_SR = 16000
BATCH_SIZE = 16
EPOCHS_STAGE1 = 30
LEARNING_RATE_STAGE1 = 1e-3
RANDOM_SEED = 42

# Seed NumPy and TensorFlow for repeatable runs.
np.random.seed(RANDOM_SEED)
tf.random.set_seed(RANDOM_SEED)

# Make sure both output locations exist (no-op when already present).
for _out_dir in (os.path.join(MODELS_DIR, 'stage1'), RESULTS_DIR):
    os.makedirs(_out_dir, exist_ok=True)

_banner = "=" * 70
print(_banner)
print("APPROACH 3: PROGRESSIVE UNFREEZING - STAGE 1")
print("Training frozen base with new head")
print(_banner)

# ============================================================================
# 1. Load Data
# ============================================================================
print("\n[1] Loading data...")

# label_mapping.npy holds a pickled mapping inside an object array (read via
# .item() and ['categories']), which is why allow_pickle=True is needed.
# Unpickling can execute arbitrary code — only load feature files you trust.
_mapping_file = os.path.join(FEATURES_DIR, 'label_mapping.npy')
label_mapping = np.load(_mapping_file, allow_pickle=True).item()
categories = label_mapping['categories']
num_classes = len(categories)

# One metadata table per split; each row describes one saved audio frame.
train_meta, val_meta, test_meta = (
    pd.read_csv(os.path.join(FEATURES_DIR, f'{split}_metadata.csv'))
    for split in ('train', 'val', 'test')
)

print(f"✓ Classes: {categories}")
# Captions are left-padded to a common width so the counts line up.
for _caption, _meta in (
    ("Training:", train_meta),
    ("Validation:", val_meta),
    ("Test:", test_meta),
):
    print(f"  {_caption:<12}{len(_meta)} frames")

# ============================================================================
# 2. Load and Prepare Datasets
# ============================================================================
print("\n[2] Loading audio frames into memory...")

def load_all_frames(metadata_df, label='train'):
    """Load every audio frame listed in *metadata_df* into memory.

    Each row supplies a ``frame_path`` (a file readable by ``np.load``) and a
    ``label`` convertible with ``int()``. A row is kept only when both the
    frame and its label load successfully, so the returned arrays always stay
    aligned. Failing rows are skipped with a warning that names the error, and
    a skip count is printed at the end.

    Args:
        metadata_df: DataFrame with ``frame_path`` and ``label`` columns.
        label: Split name used only in log messages.

    Returns:
        Tuple ``(audio, labels)``. ``audio`` is ``np.array`` over the loaded
        float32 frames (stacked along a new leading axis when all frames share
        a shape). ``labels`` is an int64 array aligned with ``audio``.
    """
    print(f"  Loading {label} frames...")
    audio_data = []
    labels = []
    skipped = 0

    # Read only the two columns we need; iterrows() builds a Series per row.
    for frame_path, frame_label in zip(metadata_df['frame_path'], metadata_df['label']):
        try:
            audio = np.load(frame_path).astype(np.float32)
            target = int(frame_label)
        except Exception as exc:  # best-effort: one bad row must not abort loading
            skipped += 1
            print(f"    Warning: Could not load {frame_path} ({type(exc).__name__}: {exc})")
            continue
        # Append only after both parts succeeded so X and y cannot drift apart.
        audio_data.append(audio)
        labels.append(target)

    if skipped:
        print(f"    Skipped {skipped} of {len(metadata_df)} {label} frames")

    return np.array(audio_data), np.array(labels, dtype=np.int64)

# Materialise every split in memory as (frames, labels) NumPy arrays,
# loading training, then validation, then test.
(train_X, train_y), (val_X, val_y), (test_X, test_y) = [
    load_all_frames(split_meta, split_name)
    for split_meta, split_name in (
        (train_meta, 'training'),
        (val_meta, 'validation'),
        (test_meta, 'test'),
    )
]

def create_dataset(X, y, batch_size, shuffle=True, seed=RANDOM_SEED):
    """Wrap in-memory arrays in a batched, prefetching ``tf.data.Dataset``.

    Args:
        X: Feature array whose first axis indexes examples.
        y: Label array aligned with ``X``.
        batch_size: Examples per batch (the final batch may be smaller).
        shuffle: If True, shuffle using a buffer as large as the input.
        seed: Seed passed to the shuffle op.

    Returns:
        Dataset yielding ``(X_batch, y_batch)`` tuples.
    """
    pipeline = tf.data.Dataset.from_tensor_slices((X, y))
    if shuffle:
        # A buffer covering the whole dataset gives a full shuffle.
        pipeline = pipeline.shuffle(buffer_size=len(X), seed=seed)
    return pipeline.batch(batch_size).prefetch(tf.data.AUTOTUNE)

# Only the training split is shuffled; validation and test keep their
# original row order.
train_dataset = create_dataset(train_X, train_y, BATCH_SIZE, shuffle=True)
val_dataset, test_dataset = (
    create_dataset(split_X, split_y, BATCH_SIZE, shuffle=False)
    for split_X, split_y in ((val_X, val_y), (test_X, test_y))
)

print("✓ Datasets created")

# ============================================================================
# 3. Build Model with Frozen Base
# ============================================================================
print("\n[3] Building model with frozen YAMNet base...")

class ProgressiveUnfreezeModel(keras.Model):
    """YAMNet feature extractor followed by a trainable classification head.

    In Stage 1 the TF-Hub YAMNet layer is created with ``trainable=False``, so
    only the dense head learns. Each input waveform goes through YAMNet on its
    own. Its embeddings are mean-pooled over axis 0 into one vector, and the
    head maps that vector to softmax class probabilities.

    Args:
        num_classes: Number of output classes (width of the softmax layer).
        yamnet_model_handle: TF-Hub handle/URL of the YAMNet model.
        **kwargs: Forwarded to ``keras.Model`` (e.g. ``name``). This lets the
            model be rebuilt from a saved config.
    """

    def __init__(self, num_classes, yamnet_model_handle, **kwargs):
        super().__init__(**kwargs)

        # Stored so get_config() can report the constructor arguments.
        self.num_classes = num_classes
        self.yamnet_model_handle = yamnet_model_handle

        # Pretrained YAMNet, frozen for Stage 1 so only the head is updated.
        self.yamnet = hub.KerasLayer(
            yamnet_model_handle,
            trainable=False,  # FROZEN in Stage 1
            name='yamnet'
        )

        # Head: two Dense -> BatchNorm -> Dropout stages, then softmax output.
        self.classifier = keras.Sequential([
            layers.Dense(256, activation='relu', name='fc1'),
            layers.BatchNormalization(name='bn1'),
            layers.Dropout(0.3, name='dropout1'),
            layers.Dense(128, activation='relu', name='fc2'),
            layers.BatchNormalization(name='bn2'),
            layers.Dropout(0.2, name='dropout2'),
            layers.Dense(num_classes, activation='softmax', name='output')
        ], name='classifier')

    def call(self, inputs, training=False):
        """Embed each waveform with YAMNet, mean-pool, and classify.

        Args:
            inputs: Batch of waveforms; each element along axis 0 is fed to
                YAMNet separately (presumably 1-D float32 audio at TARGET_SR —
                confirm against the feature-extraction step).
            training: Forwarded to the head so Dropout/BatchNorm switch mode.

        Returns:
            Softmax class probabilities, one row per input waveform.
        """
        def process_waveform(waveform):
            # YAMNet yields (scores, embeddings, spectrogram); only the
            # embeddings are used, averaged over axis 0 into one vector.
            scores, embeddings, spectrogram = self.yamnet(waveform)
            embedding = tf.reduce_mean(embeddings, axis=0)
            return embedding

        # Apply YAMNet per example (presumably because the hub signature
        # takes a single waveform rather than a batch).
        embeddings = tf.map_fn(
            process_waveform,
            inputs,
            fn_output_signature=tf.float32
        )

        outputs = self.classifier(embeddings, training=training)
        return outputs

    def get_config(self):
        """Return the constructor arguments needed to rebuild this model.

        ``keras.models.load_model`` rebuilds a custom subclassed model through
        ``from_config``, which calls ``cls(**config)``. Returning the
        arguments explicitly avoids relying on automatic argument capture,
        which older Keras versions lack for subclassed models.
        """
        return {
            'num_classes': self.num_classes,
            'yamnet_model_handle': self.yamnet_model_handle,
        }

model = ProgressiveUnfreezeModel(num_classes, YAMNET_MODEL_HANDLE)

# Subclassed models create their variables on first call, so run one forward
# pass on a random waveform of int(TARGET_SR * 0.96) samples.
dummy_input = tf.random.normal([1, int(TARGET_SR * 0.96)])
_ = model(dummy_input)

print(f"✓ Model built")
print(f"  Total parameters: {model.count_params():,}")

# Count frozen vs trainable parameters.
# YAMNet was created with trainable=False, so Keras reports every one of its
# variables under `non_trainable_weights`, and its `trainable_weights` list
# is empty. Summing the latter always reported 0 frozen parameters.
frozen_count = sum(tf.size(w).numpy() for w in model.yamnet.non_trainable_weights)
trainable_count = sum(tf.size(w).numpy() for w in model.classifier.trainable_weights)
print(f"  Frozen parameters (YAMNet): {frozen_count:,}")
print(f"  Trainable parameters (Head): {trainable_count:,}")

print(f"\n  YAMNet trainable: {model.yamnet.trainable}")
print(f"  Classifier trainable: {model.classifier.trainable}")

# ============================================================================
# 4. Compile and Train - Stage 1
# ============================================================================
print("\n[4] Compiling model (Stage 1)...")

# Class weights: n_samples / (n_classes * n_samples_in_class).
# The weights are keyed by the label value reported by value_counts().
# Enumerating the sorted counts would key them by position instead, which
# shifts weights onto the wrong class when some label has no training frames.
# Classes absent from train_y get a neutral 1.0, so every index
# 0..num_classes-1 has an entry.
class_counts = pd.Series(train_y).value_counts().sort_index()
total_samples = len(train_y)
class_weights = {class_idx: 1.0 for class_idx in range(num_classes)}
class_weights.update({
    int(label): float(total_samples / (num_classes * count))
    for label, count in class_counts.items()
})

optimizer = keras.optimizers.Adam(learning_rate=LEARNING_RATE_STAGE1)

model.compile(
    optimizer=optimizer,
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy']
)

print(f"✓ Model compiled")
print(f"  Learning rate: {LEARNING_RATE_STAGE1}")
print(f"  Optimizer: Adam")

# Callbacks. The checkpoint keeps the epoch with the highest val_accuracy.
# Early stopping and LR reduction both watch val_loss.
# NOTE(review): the monitors differ, so the in-memory weights after training
# and the saved checkpoint may come from different epochs. The evaluation
# step reloads the checkpoint.
checkpoint_path = os.path.join(MODELS_DIR, 'stage1', 'best_model_stage1.keras')

checkpoint_callback = keras.callbacks.ModelCheckpoint(
    checkpoint_path, monitor='val_accuracy', mode='max', save_best_only=True, verbose=1,
)
# Stop after 10 epochs without val_loss improvement; restore_best_weights=True
# asks Keras to roll back to the best-val_loss weights.
early_stopping = keras.callbacks.EarlyStopping(
    monitor='val_loss', patience=10, restore_best_weights=True, verbose=1,
)
# Halve the learning rate after 5 stagnant val_loss epochs, floor at 1e-6.
lr_scheduler = keras.callbacks.ReduceLROnPlateau(
    monitor='val_loss', factor=0.5, patience=5, min_lr=1e-6, verbose=1,
)
# Per-epoch metrics written to a CSV next to the other results.
csv_logger = keras.callbacks.CSVLogger(
    os.path.join(RESULTS_DIR, 'stage1_training_log.csv')
)
stage1_callbacks = [checkpoint_callback, early_stopping, lr_scheduler, csv_logger]

print("\n[5] Training Stage 1 (frozen base, training head only)...")
print("=" * 70)

history_stage1 = model.fit(
    train_dataset,
    epochs=EPOCHS_STAGE1,
    validation_data=val_dataset,
    class_weight=class_weights,
    callbacks=stage1_callbacks,
    verbose=1,
)

print("\n✓ Stage 1 training completed!")

# ============================================================================
# 5. Plot Training History - Stage 1
# ============================================================================
print("\n[6] Plotting training history...")

fig, axes = plt.subplots(1, 2, figsize=(14, 5))
_history = history_stage1.history

# One panel per metric: (axis, history key, display name, panel title).
_panels = (
    (axes[0], 'loss', 'Loss', 'Stage 1: Training Loss (Frozen Base)'),
    (axes[1], 'accuracy', 'Accuracy', 'Stage 1: Training Accuracy'),
)
for _ax, _key, _noun, _title in _panels:
    _ax.plot(_history[_key], label=f'Training {_noun}', linewidth=2)
    _ax.plot(_history[f'val_{_key}'], label=f'Validation {_noun}', linewidth=2)
    _ax.set_title(_title, fontsize=14, fontweight='bold')
    _ax.set_xlabel('Epoch')
    _ax.set_ylabel(_noun)
    _ax.legend()
    _ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig(os.path.join(RESULTS_DIR, 'stage1_training_history.png'), dpi=150)
plt.show()

print("✓ Training history saved")

# ============================================================================
# 6. Evaluate Stage 1
# ============================================================================
print("\n[7] Evaluating Stage 1 on test set...")

# Reload the checkpoint saved at the best val_accuracy epoch.
model = keras.models.load_model(checkpoint_path,
                               custom_objects={'ProgressiveUnfreezeModel': ProgressiveUnfreezeModel})

# Predicted class = index of the highest softmax probability.
y_prob = model.predict(test_X, verbose=0)
y_pred = np.argmax(y_prob, axis=1)

test_accuracy = np.mean(test_y == y_pred)
print(f"\n✓ Stage 1 Test Accuracy: {test_accuracy:.4f}")

# Pin the label set to every class index. Without this, a class absent from
# both test_y and y_pred makes target_names longer than the inferred label set
# (classification_report errors out), and the confusion matrix silently drops
# that row/column.
class_labels = list(range(num_classes))

report = classification_report(test_y, y_pred, labels=class_labels,
                               target_names=categories, digits=4)
print("\nClassification Report:")
print(report)

report_dict = classification_report(test_y, y_pred, labels=class_labels,
                                    target_names=categories, output_dict=True)

# Confusion matrix (rows = true class, columns = predicted class)
cm = confusion_matrix(test_y, y_pred, labels=class_labels)

plt.figure(figsize=(10, 8))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
           xticklabels=categories, yticklabels=categories,
           cbar_kws={'label': 'Count'})
plt.title('Stage 1: Confusion Matrix (Frozen Base)', fontsize=16, fontweight='bold')
plt.xlabel('Predicted')
plt.ylabel('True')
plt.tight_layout()
plt.savefig(os.path.join(RESULTS_DIR, 'stage1_confusion_matrix.png'), dpi=150)
plt.show()

# ============================================================================
# 7. Save Stage 1 Results
# ============================================================================
print("\n[8] Saving Stage 1 artifacts...")


def _write_results_json(filename, payload):
    """Serialize *payload* as indented JSON into RESULTS_DIR/<filename>."""
    with open(os.path.join(RESULTS_DIR, filename), 'w') as handle:
        json.dump(payload, handle, indent=2)


_weighted_avg = report_dict['weighted avg']
_history = history_stage1.history

# Headline test metrics (weighted averages across classes) plus run settings.
stage1_results = {
    'stage': 1,
    'description': 'Frozen YAMNet base, training head only',
    'test_accuracy': float(test_accuracy),
    'test_precision': float(_weighted_avg['precision']),
    'test_recall': float(_weighted_avg['recall']),
    'test_f1': float(_weighted_avg['f1-score']),
    'epochs_trained': len(_history['loss']),
    'learning_rate': LEARNING_RATE_STAGE1,
    'trainable_layers': 'Classifier head only',
    'frozen_layers': 'YAMNet (all)'
}
_write_results_json('stage1_results.json', stage1_results)

# Best and final-epoch validation metrics taken from the Keras history.
val_metrics_stage1 = {
    'best_val_loss': float(min(_history['val_loss'])),
    'best_val_accuracy': float(max(_history['val_accuracy'])),
    'final_val_loss': float(_history['val_loss'][-1]),
    'final_val_accuracy': float(_history['val_accuracy'][-1])
}
_write_results_json('stage1_val_metrics.json', val_metrics_stage1)

print(f"✓ Results saved to {RESULTS_DIR}/")

# ============================================================================
# 8. Summary
# ============================================================================
_rule = "=" * 70
_weighted_avg = report_dict['weighted avg']
_val_accuracy = history_stage1.history['val_accuracy']

print("\n" + _rule)
print("STAGE 1 SUMMARY")
print(_rule)
print("\n✓ Stage 1 Complete: Frozen Base Training")

print("\nPerformance:")
# Metric names are padded to a common width so the values line up.
for _metric_name, _metric_value in (
    ("Accuracy:", test_accuracy),
    ("Precision:", _weighted_avg['precision']),
    ("Recall:", _weighted_avg['recall']),
    ("F1-Score:", _weighted_avg['f1-score']),
):
    print(f"  Test {_metric_name:<11}{_metric_value:.4f}")

print("\nTraining:")
print(f"  Epochs: {len(history_stage1.history['loss'])}")
print(f"  Best Val Accuracy: {max(_val_accuracy):.4f}")
print(f"  Final Val Accuracy: {_val_accuracy[-1]:.4f}")

print("\nConfiguration:")
print("  Frozen: YAMNet base")
print("  Trainable: Classifier head only")
print(f"  Learning rate: {LEARNING_RATE_STAGE1}")

print("\nSaved:")
print(f"  Model: {checkpoint_path}")
print(f"  Results: {RESULTS_DIR}/stage1_results.json")

print("\n" + _rule)
print("Next: Stage 2 - Unfreeze top layers of YAMNet")
print(_rule)