# 02 - Model Training

This notebook demonstrates how to train a CNN model for facial keypoint detection.

## Overview

- Build a CNN architecture for keypoint regression
- Train the model with an 80/20 train/validation split
- Monitor training progress and evaluate performance
- Save the trained model for inference

## Dataset

- **Input**: 96x96 grayscale images, normalized to [0, 1]
- **Output**: 30 values (15 keypoints x 2 coordinates), normalized to [-1, 1]
- **Training samples**: ~2140 (samples with all 15 keypoints)
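
The target normalization listed above can be sketched as follows. This is a minimal illustration that assumes the mapping `(pixels - 48) / 48`, which is the inverse of the `y * 48 + 48` denormalization used in the evaluation cells of this notebook. The helper names are hypothetical and are not part of the `facial_keypoints` package.

```python
import numpy as np

IMAGE_SIZE = 96            # images are 96x96 pixels
HALF = IMAGE_SIZE / 2      # 48.0: one normalized unit corresponds to 48 pixels


def normalize_keypoints(pixels):
    """Map pixel coordinates in [0, 96] to the range [-1, 1] (assumed mapping)."""
    return (np.asarray(pixels, dtype=np.float32) - HALF) / HALF


def denormalize_keypoints(normalized):
    """Map normalized coordinates in [-1, 1] back to pixel coordinates."""
    return np.asarray(normalized, dtype=np.float32) * HALF + HALF


example_pixels = np.array([0.0, 48.0, 96.0, 24.0], dtype=np.float32)
example_norm = normalize_keypoints(example_pixels)
roundtrip = denormalize_keypoints(example_norm)

print("normalized:", example_norm)
print("round trip:", roundtrip)
```

The `tanh` output layer used later can only produce values in (-1, 1), which is why the targets are scaled to this range rather than left in pixels.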

In [None]:
# Standard imports
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path

# TensorFlow/Keras imports
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, models, callbacks

# Package imports
from facial_keypoints.data.loader import load_data, get_data_statistics
from facial_keypoints.visualization.plotting import plot_keypoints
from facial_keypoints.config import settings

# Display settings
%matplotlib inline
plt.rcParams['figure.figsize'] = (12, 8)

# Check TensorFlow and GPU
print(f"TensorFlow version: {tf.__version__}")
gpus = tf.config.list_physical_devices('GPU')
print(f"GPUs available: {len(gpus)}")
if gpus:
    for gpu in gpus:
        print(f"  - {gpu.name}")

## 1. Load and Prepare Data

In [None]:
# Load training data (already shuffled by load_data)
X, y = load_data(test=False)

print(f"Loaded data shape: X={X.shape}, y={y.shape}")
print(f"Data types: X={X.dtype}, y={y.dtype}")

# Get statistics
stats = get_data_statistics(X, y)
print("\nDataset Statistics:")
print(f"  Samples: {stats['n_samples']}")
print(f"  Keypoints: {stats['n_keypoints']}")
print(f"  Image range: [{stats['x_min']:.3f}, {stats['x_max']:.3f}]")
print(f"  Keypoint range: [{stats['y_min']:.3f}, {stats['y_max']:.3f}]")

In [None]:
# Split into train/validation sets (80/20)
VAL_SPLIT = 0.2
split_idx = int(len(X) * (1 - VAL_SPLIT))

X_train, X_val = X[:split_idx], X[split_idx:]
y_train, y_val = y[:split_idx], y[split_idx:]

print(f"Training samples: {len(X_train)} ({100*(1-VAL_SPLIT):.0f}%)")
print(f"Validation samples: {len(X_val)} ({100*VAL_SPLIT:.0f}%)")
print(f"\nInput shape: {X_train.shape[1:]}")
print(f"Output shape: {y_train.shape[1:]} (15 keypoints x 2 coords)")

## 2. Visualize Training Samples

In [None]:
# Display a few training samples
fig, axes = plt.subplots(2, 4, figsize=(14, 7))

for i, ax in enumerate(axes.flatten()):
    plot_keypoints(X_train[i], y_train[i], ax=ax, denormalize=True, 
                   marker_color='cyan', marker_size=30)
    ax.set_title(f'Sample {i}')

plt.suptitle('Training Samples with Ground Truth Keypoints', fontsize=14)
plt.tight_layout()
plt.show()

## 3. Build CNN Model

The architecture stacks four convolutional blocks, followed by dense layers that regress the keypoint coordinates.
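
As a quick sanity check on the feature-map sizes, the sketch below traces the spatial resolution through the four blocks. It assumes `'same'` padding in every convolution (so only the 2x2 max pooling changes the size) and the filter counts 32, 64, 128 and 256 used in `build_keypoint_model`. The function name `trace_feature_map` is hypothetical.

```python
def trace_feature_map(size=96, filters=(32, 64, 128, 256), pool=2):
    """Return (side, channels) after each conv block.

    Assumes 'same'-padded convolutions, so only the pooling layer
    reduces the spatial size.
    """
    shapes = []
    for n_filters in filters:
        size = size // pool  # 2x2 max pooling halves height and width
        shapes.append((size, n_filters))
    return shapes


shapes = trace_feature_map()
for block, (side, channels) in enumerate(shapes, start=1):
    print(f"Block {block}: {side}x{side}x{channels}")

final_side, final_channels = shapes[-1]
flatten_features = final_side * final_side * final_channels
print(f"Flattened features: {flatten_features}")
```

The flattened size determines the number of weights in the first dense layer, which typically dominates the parameter count of a network shaped like this one.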

In [None]:
def build_keypoint_model(input_shape=(96, 96, 1), n_keypoints=15):
    """Build a CNN model for facial keypoint detection.
    
    Architecture:
    - 4 Conv blocks with increasing filters (32 -> 64 -> 128 -> 256)
    - BatchNorm + MaxPool + Dropout in each block
    - 2 Dense layers (512 -> 256) with dropout
    - Output: n_keypoints * 2 with tanh activation (for [-1, 1] range)
    
    Args:
        input_shape: Shape of input images (height, width, channels).
        n_keypoints: Number of keypoints to predict.
        
    Returns:
        Compiled Keras model.
    """
    model = models.Sequential([
        # Input
        layers.Input(shape=input_shape),
        
        # Conv Block 1: 96x96 -> 48x48
        layers.Conv2D(32, (3, 3), activation='relu', padding='same'),
        layers.BatchNormalization(),
        layers.MaxPooling2D((2, 2)),
        layers.Dropout(0.1),
        
        # Conv Block 2: 48x48 -> 24x24
        layers.Conv2D(64, (3, 3), activation='relu', padding='same'),
        layers.BatchNormalization(),
        layers.MaxPooling2D((2, 2)),
        layers.Dropout(0.2),
        
        # Conv Block 3: 24x24 -> 12x12
        layers.Conv2D(128, (3, 3), activation='relu', padding='same'),
        layers.BatchNormalization(),
        layers.MaxPooling2D((2, 2)),
        layers.Dropout(0.3),
        
        # Conv Block 4: 12x12 -> 6x6
        layers.Conv2D(256, (3, 3), activation='relu', padding='same'),
        layers.BatchNormalization(),
        layers.MaxPooling2D((2, 2)),
        layers.Dropout(0.4),
        
        # Flatten: 6x6x256 = 9216 features
        layers.Flatten(),
        
        # Dense Block 1
        layers.Dense(512, activation='relu'),
        layers.BatchNormalization(),
        layers.Dropout(0.5),
        
        # Dense Block 2
        layers.Dense(256, activation='relu'),
        layers.BatchNormalization(),
        layers.Dropout(0.5),
        
        # Output: 2 coordinates per keypoint, tanh for [-1, 1] range
        layers.Dense(n_keypoints * 2, activation='tanh'),
    ])
    
    model.compile(
        optimizer=keras.optimizers.Adam(learning_rate=0.001),
        loss='mse',
        metrics=['mae']
    )
    
    return model

# Build model
model = build_keypoint_model()
model.summary()

In [None]:
# Count parameters
total_params = model.count_params()
trainable_params = sum(int(np.prod(w.shape)) for w in model.trainable_weights)

print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")
# Size estimate assumes 4 bytes per parameter (float32 weights)
print(f"Model size (estimated): {total_params * 4 / 1024 / 1024:.2f} MB")

## 4. Training Configuration

In [None]:
# Training hyperparameters
EPOCHS = 100
BATCH_SIZE = 32
PATIENCE = 15  # Early stopping patience

# Create models directory (including missing parents)
models_dir = Path('../models')
models_dir.mkdir(parents=True, exist_ok=True)

# Create callbacks
training_callbacks = [
    # Early stopping: stop if val_loss doesn't improve
    callbacks.EarlyStopping(
        monitor='val_loss',
        patience=PATIENCE,
        restore_best_weights=True,
        verbose=1
    ),
    # Reduce learning rate on plateau
    callbacks.ReduceLROnPlateau(
        monitor='val_loss',
        factor=0.5,
        patience=7,
        min_lr=1e-6,
        verbose=1
    ),
    # Save best model
    callbacks.ModelCheckpoint(
        str(models_dir / 'best_model.keras'),
        monitor='val_loss',
        save_best_only=True,
        verbose=1
    ),
]

print("Training Configuration")
print("=" * 40)
print(f"Epochs (max): {EPOCHS}")
print(f"Batch size: {BATCH_SIZE}")
print(f"Steps per epoch: {int(np.ceil(len(X_train) / BATCH_SIZE))}")  # last partial batch counts as a step
print(f"Early stopping patience: {PATIENCE}")
print(f"Model checkpoint: {models_dir / 'best_model.keras'}")
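
To get a feel for the `ReduceLROnPlateau` settings above, the sketch below lists the learning rates that would result if every plateau triggered a reduction. It is plain arithmetic rather than a call into Keras, and it assumes the update rule `new_lr = max(lr * factor, min_lr)`. The function name is hypothetical.

```python
def lr_reduction_schedule(initial_lr=1e-3, factor=0.5, min_lr=1e-6):
    """List the learning rates reached by repeated plateau reductions."""
    schedule = [initial_lr]
    lr = initial_lr
    while lr > min_lr:
        lr = max(lr * factor, min_lr)  # halve, but never go below the floor
        schedule.append(lr)
    return schedule


schedule = lr_reduction_schedule()
n_reductions = len(schedule) - 1

for step, lr in enumerate(schedule):
    print(f"after {step:2d} reductions: lr = {lr:.2e}")
```

With these values the floor is reached after ten reductions. In practice far fewer are likely: each reduction needs 7 epochs without improvement, so roughly two fit into a stalled stretch before early stopping (patience 15) ends training.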

## 5. Train the Model

In [None]:
# Train the model
print("Starting training...\n")

history = model.fit(
    X_train, y_train,
    validation_data=(X_val, y_val),
    epochs=EPOCHS,
    batch_size=BATCH_SIZE,
    callbacks=training_callbacks,
    verbose=1
)

print("\nTraining completed!")
print(f"Total epochs trained: {len(history.history['loss'])}")

## 6. Training History

In [None]:
def plot_training_history(history):
    """Plot training and validation metrics."""
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))
    epochs = range(1, len(history.history['loss']) + 1)
    
    # Loss
    axes[0].plot(epochs, history.history['loss'], 'b-', label='Training Loss', linewidth=2)
    axes[0].plot(epochs, history.history['val_loss'], 'r-', label='Validation Loss', linewidth=2)
    axes[0].set_xlabel('Epoch')
    axes[0].set_ylabel('Loss (MSE)')
    axes[0].set_title('Training and Validation Loss')
    axes[0].grid(True, alpha=0.3)
    
    # Mark the best epoch (lowest validation loss)
    best_epoch = int(np.argmin(history.history['val_loss'])) + 1
    best_val_loss = min(history.history['val_loss'])
    axes[0].axvline(best_epoch, color='green', linestyle='--', alpha=0.7,
                    label=f'Best: epoch {best_epoch}')
    axes[0].scatter([best_epoch], [best_val_loss], color='green', s=100, zorder=5)
    
    # Draw the legend last so it includes the best-epoch marker
    axes[0].legend()
    
    # MAE
    axes[1].plot(epochs, history.history['mae'], 'b-', label='Training MAE', linewidth=2)
    axes[1].plot(epochs, history.history['val_mae'], 'r-', label='Validation MAE', linewidth=2)
    axes[1].set_xlabel('Epoch')
    axes[1].set_ylabel('MAE (normalized coords)')
    axes[1].set_title('Training and Validation MAE')
    axes[1].legend()
    axes[1].grid(True, alpha=0.3)
    
    # Add secondary y-axis for pixel error
    ax2 = axes[1].secondary_yaxis('right', functions=(lambda x: x * 48, lambda x: x / 48))
    ax2.set_ylabel('MAE (pixels)')
    
    plt.tight_layout()
    plt.show()
    
    return best_epoch, best_val_loss

best_epoch, best_val_loss = plot_training_history(history)
print(f"\nBest model at epoch {best_epoch} with val_loss = {best_val_loss:.6f}")

## 7. Evaluate on Validation Set

In [None]:
# Evaluate model
val_loss, val_mae = model.evaluate(X_val, y_val, verbose=0)

# Convert to pixel error: targets span [-1, 1] across 96 pixels, so 1 unit = 48 pixels
pixel_mae = val_mae * 48
pixel_rmse = np.sqrt(val_loss) * 48

print("Validation Metrics")
print("=" * 40)
print(f"Loss (MSE): {val_loss:.6f}")
print(f"MAE (normalized): {val_mae:.6f}")
print(f"\nPixel-space metrics (96x96 image):")
print(f"  MAE: {pixel_mae:.2f} pixels")
print(f"  RMSE: {pixel_rmse:.2f} pixels")
print(f"\nRelative error: {pixel_mae / 96 * 100:.1f}% of image size")
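
The conversion above relies on MAE and RMSE scaling linearly with the coordinates. The sketch below checks this on synthetic data: it computes the errors directly in pixel space and compares them with the normalized errors multiplied by 48. The arrays are randomly generated for illustration and are not model outputs.

```python
import numpy as np

rng = np.random.default_rng(0)

# Synthetic targets and predictions in normalized [-1, 1] coordinates
y_true_demo = rng.uniform(-1, 1, size=(5, 30))
y_pred_demo = y_true_demo + rng.normal(0, 0.05, size=(5, 30))

# Errors in normalized space
mse_norm = np.mean((y_true_demo - y_pred_demo) ** 2)
mae_norm = np.mean(np.abs(y_true_demo - y_pred_demo))

# Errors computed directly in pixel space
true_px = y_true_demo * 48 + 48
pred_px = y_pred_demo * 48 + 48
rmse_px_direct = np.sqrt(np.mean((true_px - pred_px) ** 2))
mae_px_direct = np.mean(np.abs(true_px - pred_px))

print(f"RMSE: direct = {rmse_px_direct:.4f}, scaled = {np.sqrt(mse_norm) * 48:.4f}")
print(f"MAE:  direct = {mae_px_direct:.4f}, scaled = {mae_norm * 48:.4f}")
```

The `+ 48` offset cancels in the difference, so only the scale factor matters. MSE itself scales by 48 squared, which is why the square root is taken before multiplying.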

## 8. Visualize Predictions

In [None]:
def visualize_predictions(model, X, y_true, n_samples=8):
    """Visualize model predictions vs ground truth."""
    # Get predictions
    y_pred = model.predict(X[:n_samples], verbose=0)
    
    # Denormalize to pixel coordinates
    y_true_pixels = y_true[:n_samples] * 48 + 48
    y_pred_pixels = y_pred * 48 + 48
    
    # Calculate per-sample error
    errors = np.sqrt(np.mean((y_true_pixels - y_pred_pixels) ** 2, axis=1))
    
    # Plot
    n_cols = 4
    n_rows = (n_samples + n_cols - 1) // n_cols
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(14, 3.5 * n_rows))
    axes = axes.flatten()
    
    for i in range(n_samples):
        ax = axes[i]
        ax.imshow(X[i].squeeze(), cmap='gray')
        
        # Ground truth (green circles)
        ax.scatter(y_true_pixels[i, 0::2], y_true_pixels[i, 1::2],
                   c='lime', s=40, marker='o', label='Ground Truth', 
                   edgecolors='black', linewidths=0.5, zorder=10)
        
        # Predictions (red x markers)
        ax.scatter(y_pred_pixels[i, 0::2], y_pred_pixels[i, 1::2],
                   c='red', s=40, marker='x', label='Prediction', 
                   linewidths=2, zorder=11)
        
        ax.set_title(f'Sample {i} (RMSE: {errors[i]:.1f}px)')
        ax.axis('off')
        
        if i == 0:
            ax.legend(loc='upper right', fontsize=8)
    
    # Hide unused axes
    for i in range(n_samples, len(axes)):
        axes[i].axis('off')
    
    plt.suptitle('Predictions (red X) vs Ground Truth (green O)', fontsize=14, y=1.02)
    plt.tight_layout()
    plt.show()

visualize_predictions(model, X_val, y_val, n_samples=8)

## 9. Per-Keypoint Error Analysis

In [None]:
# Analyze error per keypoint
y_pred_all = model.predict(X_val, verbose=0)

# Denormalize
y_true_pixels = y_val * 48 + 48
y_pred_pixels = y_pred_all * 48 + 48

# Reshape to (n_samples, 15, 2)
y_true_kp = y_true_pixels.reshape(-1, 15, 2)
y_pred_kp = y_pred_pixels.reshape(-1, 15, 2)

# Calculate Euclidean distance per keypoint
distances = np.sqrt(np.sum((y_true_kp - y_pred_kp) ** 2, axis=2))  # Shape: (n_samples, 15)

# Mean error per keypoint
mean_errors = distances.mean(axis=0)
std_errors = distances.std(axis=0)

# Keypoint names
KEYPOINT_NAMES = [
    'left_eye_center', 'right_eye_center', 'left_eye_inner', 'left_eye_outer',
    'right_eye_inner', 'right_eye_outer', 'left_eyebrow_inner', 'left_eyebrow_outer',
    'right_eyebrow_inner', 'right_eyebrow_outer', 'nose_tip',
    'mouth_left', 'mouth_right', 'mouth_top', 'mouth_bottom'
]

print("Per-Keypoint Error (pixels)")
print("=" * 50)
for i, (name, mean_err, std_err) in enumerate(zip(KEYPOINT_NAMES, mean_errors, std_errors)):
    bar = '#' * int(mean_err)
    print(f"{i:2d}. {name:22s}: {mean_err:5.2f} +/- {std_err:4.2f}  {bar}")

print(f"\nOverall mean: {mean_errors.mean():.2f} +/- {distances.std():.2f} pixels")  # std over all distances

In [None]:
# Visualize per-keypoint error
fig, ax = plt.subplots(figsize=(12, 6))

x_pos = np.arange(15)
bars = ax.bar(x_pos, mean_errors, yerr=std_errors, capsize=4, color='steelblue', edgecolor='black')

# Color bars by error magnitude (set_facecolor keeps the black edge)
for bar, err in zip(bars, mean_errors):
    if err > 4:
        bar.set_facecolor('lightcoral')
    elif err > 3:
        bar.set_facecolor('gold')

ax.set_xticks(x_pos)
ax.set_xticklabels(range(15))
ax.set_xlabel('Keypoint Index')
ax.set_ylabel('Mean Error (pixels)')
ax.set_title('Per-Keypoint Prediction Error')
ax.axhline(mean_errors.mean(), color='red', linestyle='--', label=f'Overall mean: {mean_errors.mean():.2f}px')
ax.legend()
ax.grid(True, alpha=0.3, axis='y')

plt.tight_layout()
plt.show()

## 10. Save Final Model

In [None]:
# Save the model as it stands after training
# (EarlyStopping restores the best weights when it triggers)
final_model_path = models_dir / 'model.keras'
model.save(final_model_path)

print(f"Model saved to: {final_model_path}")
print(f"Model size: {final_model_path.stat().st_size / 1024 / 1024:.2f} MB")

# Also check best model
best_model_path = models_dir / 'best_model.keras'
if best_model_path.exists():
    print(f"\nBest model (checkpointed): {best_model_path}")
    print(f"Best model size: {best_model_path.stat().st_size / 1024 / 1024:.2f} MB")

## Summary

This notebook demonstrated:

1. **Data Preparation**: Loading and splitting data for training/validation
2. **Model Architecture**: 4-block CNN with BatchNorm and Dropout
3. **Training**: Using callbacks for early stopping, LR reduction, and checkpointing
4. **Evaluation**: MSE/MAE metrics converted to pixel error
5. **Visualization**: Comparing predictions vs ground truth
6. **Error Analysis**: Per-keypoint error breakdown

### Model Performance

Accuracy varies from run to run; the validation metrics in Section 7 and the per-keypoint breakdown in Section 9 quantify it. Common error patterns:
- Eyes and nose tip are typically predicted well
- Mouth corners can be more challenging due to expression variation
- Eyebrow endpoints may have higher variance

### Next Steps

- Proceed to `03_inference_pipeline.ipynb` to use the trained model on new images
- Experiment with data augmentation for better generalization
- Try different architectures (ResNet, EfficientNet) via transfer learning
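
As a starting point for the augmentation idea, here is a minimal sketch of a horizontal flip. It assumes images shaped `(n, 96, 96, 1)`, targets laid out as `x0, y0, x1, y1, ...` in [-1, 1], and the keypoint order of `KEYPOINT_NAMES` from Section 9. The names `FLIP_PAIRS` and `horizontal_flip` are hypothetical. Negating x mirrors around pixel 48, while reversing the pixel array mirrors around 47.5, so the result is accurate to about half a pixel.

```python
import numpy as np

# (left, right) keypoint indices that trade places under a mirror flip,
# assuming the KEYPOINT_NAMES order used in this notebook
FLIP_PAIRS = [(0, 1), (2, 4), (3, 5), (6, 8), (7, 9), (11, 12)]


def horizontal_flip(images, keypoints):
    """Mirror images left-right and adjust the normalized keypoints."""
    flipped_images = images[:, :, ::-1, :]        # reverse the width axis
    kp = keypoints.reshape(-1, 15, 2).copy()
    kp[:, :, 0] *= -1                             # mirror x in [-1, 1]
    for left, right in FLIP_PAIRS:
        kp[:, [left, right]] = kp[:, [right, left]]  # swap left/right labels
    return flipped_images, kp.reshape(-1, 30)


# Tiny demo: one bright pixel and one non-zero keypoint coordinate
demo_images = np.zeros((1, 96, 96, 1), dtype=np.float32)
demo_images[0, 10, 20, 0] = 1.0
demo_kp = np.zeros((1, 30), dtype=np.float32)
demo_kp[0, 0] = 0.5                               # left_eye_center x

flipped_images, flipped_kp = horizontal_flip(demo_images, demo_kp)
restored_images, restored_kp = horizontal_flip(flipped_images, flipped_kp)

print("bright pixel column after flip:", int(np.argmax(flipped_images[0, 10, :, 0])))
print("right_eye_center x after flip:", flipped_kp[0, 2])
```

Swapping the left/right labels matters: without it, a flipped face would carry a "left eye" label on what now appears as the right eye, and the model would receive contradictory targets.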