# 02 - Model Training

This notebook demonstrates how to train a CNN model for facial keypoint detection.

## Overview

- Build a CNN architecture for keypoint regression
- Train the model and preview data augmentation (the training run below does not apply augmentation)
- Monitor training progress and evaluate performance
- Save the trained model for inference

In [None]:
# Standard imports
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path

# TensorFlow/Keras imports
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, models, callbacks

# Package imports
from facial_keypoints.data.loader import load_data, get_data_statistics
from facial_keypoints.visualization.plotting import plot_training_samples
from facial_keypoints.config import settings

# Display settings
%matplotlib inline
plt.rcParams['figure.figsize'] = (10, 6)

# Check GPU availability
print(f"TensorFlow version: {tf.__version__}")
print(f"GPU available: {len(tf.config.list_physical_devices('GPU')) > 0}")

## 1. Load and Prepare Data

In [None]:
# Load training data and split it into shuffled train/validation sets.


def train_val_split(X, y, train_fraction=0.8, seed=42):
    """Shuffle samples and split them into training and validation subsets.

    Rows are permuted with a seeded RNG before splitting. The validation set
    is then a random sample, not whatever happens to sit at the end of the
    file; the on-disk row order is not guaranteed to be random.

    Args:
        X: Array of inputs; the first axis indexes samples.
        y: Array of targets aligned with ``X`` along the first axis.
        train_fraction: Fraction of samples assigned to the training set.
        seed: Seed for the shuffle, so the split is reproducible.

    Returns:
        Tuple ``(X_train, X_val, y_train, y_val)``.

    Raises:
        ValueError: If ``X`` and ``y`` have different numbers of samples.
    """
    if len(X) != len(y):
        raise ValueError(f"X and y length mismatch: {len(X)} != {len(y)}")
    order = np.random.default_rng(seed).permutation(len(X))
    split_idx = int(len(X) * train_fraction)
    train_idx, val_idx = order[:split_idx], order[split_idx:]
    return X[train_idx], X[val_idx], y[train_idx], y[val_idx]


try:
    X, y = load_data(test=False)
except Exception as e:
    # Loading is best-effort so the rest of the notebook can still be read;
    # later cells handle the split arrays being absent.
    print(f"Could not load data: {e}")
    print("\nTo use this notebook, download the dataset and place it in data/training.csv")
else:
    # Only the load is guarded; errors in the split itself should surface.
    X_train, X_val, y_train, y_val = train_val_split(X, y, train_fraction=0.8)

    print(f"Training samples: {len(X_train)}")
    print(f"Validation samples: {len(X_val)}")
    print(f"Input shape: {X_train.shape[1:]}")
    print(f"Output shape: {y_train.shape[1:]}")

## 2. Data Augmentation

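Preview random transformations that can increase training-data diversity. For keypoint regression, geometric transforms (flips, rotations) also move the keypoints, so the labels would have to be transformed together with the image. Photometric transforms (brightness, contrast) leave the labels valid. Augmentation is only visualised here; it is not applied in `model.fit`.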

In [None]:
def create_augmentation_layer(value_range=(0, 255)):
    """Create a label-preserving data augmentation layer for training.

    Only photometric transforms are used. Geometric transforms such as
    ``RandomFlip`` or ``RandomRotation`` move facial features inside the
    image, but Keras image layers never modify the target vector. The
    keypoint labels would then no longer match the augmented image; a
    horizontal flip would also need left/right keypoints swapped.

    Args:
        value_range: Pixel value range of the incoming images, forwarded to
            ``RandomBrightness`` (which also clips its output to this range).
            ``(0, 255)`` is the Keras default. NOTE(review): if ``load_data``
            scales images to ``[0, 1]``, pass ``(0, 1)``; otherwise brightness
            shifts are far too large.

    Returns:
        A ``keras.Sequential`` model named ``"augmentation"``.
    """
    return keras.Sequential([
        layers.RandomBrightness(0.1, value_range=value_range),
        layers.RandomContrast(0.1),
    ], name="augmentation")

# Preview augmentation: the first panel is the untouched training image, and
# every other panel is an independent random draw from the augmentation stack.
try:
    augmentation = create_augmentation_layer()

    fig, axes = plt.subplots(2, 4, figsize=(12, 6))
    panels = axes.flatten()
    sample_img = X_train[0:1]

    first = panels[0]
    first.imshow(sample_img[0].squeeze(), cmap='gray')
    first.set_title('Original')
    first.axis('off')

    for idx in range(1, len(panels)):
        # training=True forces the random layers to actually transform.
        augmented = augmentation(sample_img, training=True)
        panel = panels[idx]
        panel.imshow(augmented[0].squeeze(), cmap='gray')
        panel.set_title(f'Augmented {idx}')
        panel.axis('off')

    plt.suptitle('Data Augmentation Examples', fontsize=14)
    plt.tight_layout()
    plt.show()
except NameError:
    print("Data not loaded")

## 3. Build CNN Model

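Architecture based on the original Udacity project, extended with batch normalization, progressively increasing dropout, and a `tanh` output layer matching keypoint targets normalized to $[-1, 1]$.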

In [None]:
def build_keypoint_model(input_shape=(96, 96, 1), n_keypoints=15, learning_rate=0.001):
    """Build and compile a CNN for facial keypoint detection.

    The network has four conv blocks (Conv2D -> BatchNorm -> MaxPool ->
    Dropout) with doubling filter counts and rising dropout. These feed two
    dense layers and a ``tanh`` output, so every prediction lies in ``[-1, 1]``.

    Args:
        input_shape: Shape of input images (height, width, channels).
        n_keypoints: Number of keypoints to predict; the output layer has
            ``2 * n_keypoints`` units (an x and a y per keypoint).
        learning_rate: Initial learning rate for the Adam optimizer.

    Returns:
        Keras model compiled with Adam, MSE loss and an MAE metric.
    """
    model = models.Sequential([
        # Input
        layers.Input(shape=input_shape),

        # Conv Block 1
        layers.Conv2D(32, (3, 3), activation='relu', padding='same'),
        layers.BatchNormalization(),
        layers.MaxPooling2D((2, 2)),
        layers.Dropout(0.1),

        # Conv Block 2
        layers.Conv2D(64, (3, 3), activation='relu', padding='same'),
        layers.BatchNormalization(),
        layers.MaxPooling2D((2, 2)),
        layers.Dropout(0.2),

        # Conv Block 3
        layers.Conv2D(128, (3, 3), activation='relu', padding='same'),
        layers.BatchNormalization(),
        layers.MaxPooling2D((2, 2)),
        layers.Dropout(0.3),

        # Conv Block 4
        layers.Conv2D(256, (3, 3), activation='relu', padding='same'),
        layers.BatchNormalization(),
        layers.MaxPooling2D((2, 2)),
        layers.Dropout(0.4),

        # Dense layers
        layers.Flatten(),
        layers.Dense(512, activation='relu'),
        layers.BatchNormalization(),
        layers.Dropout(0.5),

        layers.Dense(256, activation='relu'),
        layers.BatchNormalization(),
        layers.Dropout(0.5),

        # Output: 2 coordinates per keypoint; tanh bounds them to [-1, 1]
        layers.Dense(n_keypoints * 2, activation='tanh'),
    ])

    model.compile(
        optimizer=keras.optimizers.Adam(learning_rate=learning_rate),
        loss='mse',
        metrics=['mae']
    )

    return model

# Instantiate the network for 96x96 single-channel images and 15 keypoints,
# then print its layer-by-layer summary.
model = build_keypoint_model(input_shape=(96, 96, 1), n_keypoints=15)
model.summary()

## 4. Training Configuration

In [None]:
# Training hyperparameters
EPOCHS = 50
BATCH_SIZE = 32
PATIENCE = 10

# Best-so-far checkpoint location. The directory is created here, next to the
# callback that writes into it, so this cell does not rely on a later cell.
CHECKPOINT_PATH = Path('models') / 'best_model.keras'
CHECKPOINT_PATH.parent.mkdir(parents=True, exist_ok=True)

# Create callbacks
model_callbacks = [
    # Stop after PATIENCE epochs without val_loss improvement and roll back
    # to the best weights seen.
    callbacks.EarlyStopping(
        monitor='val_loss',
        patience=PATIENCE,
        restore_best_weights=True,
        verbose=1
    ),
    # Halve the learning rate after 5 stagnant epochs. This is shorter than
    # PATIENCE, so a lower-LR phase is tried before early stopping triggers.
    callbacks.ReduceLROnPlateau(
        monitor='val_loss',
        factor=0.5,
        patience=5,
        min_lr=1e-6,
        verbose=1
    ),
    callbacks.ModelCheckpoint(
        str(CHECKPOINT_PATH),
        monitor='val_loss',
        save_best_only=True,
        verbose=1
    ),
]

print("Training configuration:")
print(f"  Epochs: {EPOCHS}")
print(f"  Batch size: {BATCH_SIZE}")
print(f"  Early stopping patience: {PATIENCE}")

## 5. Train the Model

In [None]:
# Create models directory if it doesn't exist
Path('models').mkdir(parents=True, exist_ok=True)

# Train the model. Only the data lookup is guarded: a NameError raised inside
# model.fit (e.g. from a callback) should surface, not be reported as
# missing data.
try:
    train_data = (X_train, y_train)
    val_data = (X_val, y_val)
except NameError:
    print("Data not loaded - cannot train model")
else:
    # Note: re-running this cell continues training the existing `model`
    # instead of starting again from freshly initialised weights.
    history = model.fit(
        *train_data,
        validation_data=val_data,
        epochs=EPOCHS,
        batch_size=BATCH_SIZE,
        callbacks=model_callbacks,
        verbose=1
    )
    print("\nTraining completed!")

## 6. Training History

In [None]:
def plot_training_history(history):
    """Draw side-by-side training/validation curves for loss (MSE) and MAE.

    Args:
        history: Object returned by ``model.fit``. Its ``history`` dict must
            contain the ``loss``, ``val_loss``, ``mae`` and ``val_mae`` series.
    """
    panels = (
        # (metric key, y-axis label, display name)
        ('loss', 'Loss (MSE)', 'Loss'),
        ('mae', 'MAE', 'MAE'),
    )
    fig, axes = plt.subplots(1, len(panels), figsize=(14, 5))
    logs = history.history

    for ax, (key, ylabel, name) in zip(axes, panels):
        ax.plot(logs[key], label=f'Training {name}')
        ax.plot(logs[f'val_{key}'], label=f'Validation {name}')
        ax.set_xlabel('Epoch')
        ax.set_ylabel(ylabel)
        ax.set_title(f'Training and Validation {name}')
        ax.legend()
        ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

# Plot only if the training cell produced a history. The guard covers just
# the lookup, so errors raised while plotting are not mislabelled.
try:
    training_history = history
except NameError:
    print("Training history not available")
else:
    plot_training_history(training_history)

## 7. Evaluate on Validation Set

In [None]:
# Keypoint targets are assumed normalised as (pixel - 48) / 48 for 96x96
# images, i.e. [0, 96] -> [-1, 1]. An absolute error of e in normalised units
# is therefore e * 48 pixels; the +48 offset cancels in a difference.
KEYPOINT_HALF_SIZE = 48  # half of the 96-pixel image side

# Guard only the data lookup so NameErrors inside evaluate() are not hidden.
try:
    eval_inputs = (X_val, y_val)
except NameError:
    print("Data not loaded - cannot evaluate model")
else:
    val_loss, val_mae = model.evaluate(*eval_inputs, verbose=0)
    print(f"Validation Loss (MSE): {val_loss:.6f}")
    print(f"Validation MAE: {val_mae:.6f}")

    # Convert MAE from normalised units to pixels
    pixel_mae = val_mae * KEYPOINT_HALF_SIZE
    print(f"\nAverage pixel error: {pixel_mae:.2f} pixels")

## 8. Visualize Predictions

In [None]:
def visualize_predictions(model, X, y_true, n_samples=6, half_size=48):
    """Visualize model predictions vs ground truth on a 3-column grid.

    Args:
        model: Keras model whose ``predict`` returns one row of interleaved
            (x, y) normalised coordinates per image.
        X: Image array; ``X[i].squeeze()`` must be a 2-D grayscale image.
        y_true: Normalised ground-truth keypoints aligned with ``X``; columns
            alternate x, y.
        n_samples: Number of images to plot; clamped to ``len(X)``. Nothing is
            plotted when it is zero or negative.
        half_size: Half the image side in pixels, used to map coordinates
            from ``[-1, 1]`` back to pixels (48 for 96x96 images).
    """
    n_samples = min(n_samples, len(X))
    if n_samples <= 0:
        print("No samples to visualize")
        return

    # Get predictions
    y_pred = model.predict(X[:n_samples], verbose=0)

    # Denormalize from [-1, 1] to pixel coordinates
    y_true_pixels = y_true[:n_samples] * half_size + half_size
    y_pred_pixels = y_pred * half_size + half_size

    # Plot. squeeze=False always returns a 2-D array of Axes, so ravel()
    # yields a flat sequence for any grid shape, including a single row.
    n_cols = 3
    n_rows = (n_samples + n_cols - 1) // n_cols
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(12, 4 * n_rows), squeeze=False)
    axes = axes.ravel()

    for i in range(n_samples):
        ax = axes[i]
        ax.imshow(X[i].squeeze(), cmap='gray')

        # Ground truth (green); even columns are x, odd columns are y
        ax.scatter(y_true_pixels[i, 0::2], y_true_pixels[i, 1::2],
                   c='lime', s=30, marker='o', label='Ground Truth', edgecolors='black', linewidths=0.5)

        # Predictions (red)
        ax.scatter(y_pred_pixels[i, 0::2], y_pred_pixels[i, 1::2],
                   c='red', s=30, marker='x', label='Prediction', linewidths=1.5)

        ax.set_title(f'Sample {i}')
        ax.axis('off')

        if i == 0:
            ax.legend(loc='upper right', fontsize=8)

    # Hide unused axes in the last row
    for ax in axes[n_samples:]:
        ax.axis('off')

    plt.suptitle('Predictions (red) vs Ground Truth (green)', fontsize=14)
    plt.tight_layout()
    plt.show()

# Guard only the name lookups so NameErrors raised while predicting or
# plotting are not reported as missing data.
try:
    predictor, val_images, val_targets = model, X_val, y_val
except NameError:
    print("Model or data not available")
else:
    visualize_predictions(predictor, val_images, val_targets, n_samples=6)

## 9. Save Final Model

In [None]:
# Save the trained model. `model` exists from the build cell even when no
# data was loaded, so gate on the training history instead. Otherwise a
# network with random initial weights would be written out for inference.
try:
    trained_epochs = len(history.epoch)
except NameError:
    print("Model not trained - skipping save")
else:
    model_path = Path('models/keypoint_model.keras')
    model_path.parent.mkdir(parents=True, exist_ok=True)
    model.save(model_path)
    print(f"Model (trained for {trained_epochs} epochs) saved to: {model_path}")
    print(f"Model size: {model_path.stat().st_size / 1024 / 1024:.2f} MB")

## Summary

This notebook demonstrated:

1. **Data Preparation**: Loading and splitting data for training
2. **Augmentation**: Previewing augmentations built from Keras preprocessing layers
3. **Model Architecture**: Building a CNN for keypoint regression
4. **Training**: Using callbacks for early stopping and checkpointing
5. **Evaluation**: Visualizing predictions vs ground truth

### Next Steps

- Proceed to `03_inference_pipeline.ipynb` to use the trained model
- Experiment with different architectures or hyperparameters
- Try transfer learning with pre-trained backbones