# Driver Drowsiness Detection - Multimodal Transformer Model
## Novel Fusion Architecture for Enhanced Performance

This notebook implements a multimodal transformer-based driver drowsiness detection system featuring:
- Vision Transformer (ViT) for spatial feature extraction
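- Temporal sequence modeling for video frame sequences (not yet active: the current model classifies single frames, and its optional temporal-attention module is not used in the forward pass)
- Multi-modal fusion mechanisms (currently a fusion MLP over the ViT [CLS] embedding of RGB frames; attention-map and temporal inputs are not yet wired in)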
- Complete EDA and data science lifecycle
- Interactive dashboard with real-time predictions


## 1. Installation and Setup


In [None]:
# Install required packages (with Python 3.12 compatible PyTorch)
# Fix for Python 3.12 compatibility: Use PyTorch 2.1.0+ from official wheel
# For Google Colab with GPU, use CUDA 12.1; for CPU or local, adjust as needed

# Install PyTorch (Python 3.12 compatible)
%pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121 -q

# Install other packages
%pip install transformers matplotlib scikit-learn gradio pandas seaborn pillow timm -q

print("‚úì All packages installed successfully!")
print("Note: If CUDA 12.1 fails, try: %pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118 -q")


In [None]:
# Setup paths: mount Google Drive when running on Colab, otherwise use a local copy.
import os

# Local fallback; set the DROWSINESS_DATA_DIR environment variable instead of editing this path.
LOCAL_DATA_DIR = os.environ.get(
    'DROWSINESS_DATA_DIR',
    '/Users/spartan/Downloads/cs163 Modules/project/cs163_ds',
)

try:
    from google.colab import drive
except ImportError:
    # Not on Colab. Only the missing-module case falls back; a failing Drive mount
    # on Colab is surfaced instead of silently pointing at a non-existent local path.
    BASE_DIR = LOCAL_DATA_DIR
else:
    drive.mount('/content/drive')
    BASE_DIR = '/content/drive/MyDrive/cs163_ds'

print(f"Dataset directory: {BASE_DIR}")


## 2. Import Libraries and Data Loading


In [None]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from PIL import Image
from collections import Counter
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import ViTModel, ViTImageProcessor
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix
import warnings
warnings.filterwarnings('ignore')

import random

# Seed every RNG this notebook relies on (Python, NumPy, PyTorch) so that the
# initialisation of the new heads, DataLoader shuffling and dropout are repeatable.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)  # also seeds the RNGs of all CUDA devices

# Set device: use the GPU when one is available.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")


## 3. Exploratory Data Analysis


In [None]:
# Dataset exploration: every non-hidden sub-folder of BASE_DIR that contains images is a class.
IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')


def count_images(folder_path):
    """Return how many entries directly inside ``folder_path`` have an image extension."""
    return sum(1 for name in os.listdir(folder_path) if name.lower().endswith(IMAGE_EXTENSIONS))


# Sorted so the class order (and the plot) is deterministic across filesystems; hidden
# folders such as `.ipynb_checkpoints` are tooling artefacts, not classes.
candidate_folders = sorted(
    name for name in os.listdir(BASE_DIR)
    if not name.startswith('.') and os.path.isdir(os.path.join(BASE_DIR, name))
)
image_counts = {name: count_images(os.path.join(BASE_DIR, name)) for name in candidate_folders}

# A class with no images never appears in y_true/y_pred, which can make the later
# classification_report(target_names=...) call fail, so such folders are excluded.
empty_folders = [name for name, count in image_counts.items() if count == 0]
dataset_info = {name: count for name, count in image_counts.items() if count > 0}
folders = list(dataset_info)
total_images = sum(dataset_info.values())

print("="*60)
print("DATASET EXPLORATION")
print("="*60)
for folder, count in dataset_info.items():
    print(f"{folder:20s}: {count:5d} images")
if empty_folders:
    print(f"Skipped folders without images: {', '.join(empty_folders)}")
print("="*60)
print(f"Total Images: {total_images}")

# Visualize class distribution
fig, ax = plt.subplots(figsize=(14, 6))
ax.bar(folders, list(dataset_info.values()), color=sns.color_palette("husl", len(folders)))
ax.tick_params(axis='x', labelrotation=45)
ax.set_title("Class Distribution", fontsize=14, fontweight='bold')
ax.set_xlabel("Class")
ax.set_ylabel("Number of Images")
fig.tight_layout()
fig.savefig('class_distribution_transformer.png', dpi=300, bbox_inches='tight')
plt.show()


## 4. Multimodal Transformer Architecture


### 🛡️ Anti-Overfitting Measures Implemented:

1. **Layer Freezing**: First 6 ViT encoder layers are frozen (only fine-tuning last layers)
2. **Increased Dropout**:
   - Fusion layer: 0.3 (was 0.1)
   - Classifier: 0.5 and 0.4 (was 0.3)
3. **Label Smoothing**: 0.1 smoothing factor in loss function
4. **Lower Learning Rate**: 1e-5 (reduced from 2e-5)
5. **Higher Weight Decay**: 0.05 (increased from 0.01)
6. **Gradient Clipping**: max_norm=1.0 to prevent gradient explosion
7. **Early Stopping**: Patience of 3 epochs (no improvement in validation accuracy or loss), with automatic best-model restoration
8. **Learning Rate Scheduling**: Adaptive reduction on plateau


In [None]:
# Novel Multimodal Fusion Transformer Architecture with Anti-Overfitting Techniques
class MultimodalDrowsinessTransformer(nn.Module):
    """
    ViT-based drowsiness classifier with a fusion MLP and a regularised classification head.

    Architecture:
    1. Pretrained Vision Transformer (``google/vit-base-patch16-224``); the [CLS] token of
       the last hidden state is used as the image embedding.
    2. Fusion layer (Linear -> LayerNorm -> GELU -> Dropout) projecting that embedding
       to ``hidden_dim``.
    3. Classification head (two hidden layers with LayerNorm, GELU and dropout).

    Regularisation: the lower part of the ViT (its embeddings and the first
    ``num_frozen_layers`` encoder layers) is frozen, and dropout is applied in the
    fusion layer and classifier.

    Args:
        num_classes: number of output classes.
        img_size: unused; kept for backward compatibility (input size is fixed by the
            ViT image processor).
        sequence_length: if > 1, a temporal attention module is created. NOTE: ``forward``
            does not use it yet; the model currently classifies single frames.
        hidden_dim: width of the fusion layer / classifier input. The ViT output width is
            read from the pretrained config, so values other than 768 are supported.
        num_frozen_layers: number of leading ViT encoder layers to freeze (clamped to the
            number of layers available).
        freeze_embeddings: also freeze the ViT patch/position embeddings and [CLS] token.
            They sit below the frozen layers; leaving them trainable keeps shifting the
            inputs of the "frozen" layers.
    """
    def __init__(self, num_classes=7, img_size=224, sequence_length=1, hidden_dim=768,
                 num_frozen_layers=6, freeze_embeddings=True):
        super().__init__()

        # Vision Transformer encoder (the processor is kept for callers that need it)
        self.vit_encoder = ViTModel.from_pretrained('google/vit-base-patch16-224')
        self.vit_processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224')

        # FREEZE THE LOWER PART OF ViT TO PREVENT OVERFITTING - only the last layers are fine-tuned
        encoder_layers = self.vit_encoder.encoder.layer
        num_frozen_layers = max(0, min(num_frozen_layers, len(encoder_layers)))
        frozen_modules = list(encoder_layers[:num_frozen_layers])
        if freeze_embeddings:
            frozen_modules.append(self.vit_encoder.embeddings)
        for module in frozen_modules:
            for param in module.parameters():
                param.requires_grad = False

        print(f"✓ Frozen first {num_frozen_layers} ViT encoder layers to prevent overfitting")
        if freeze_embeddings:
            print("✓ Frozen ViT embeddings (patch/position embeddings and [CLS] token)")

        # Feature dimensions: the ViT width comes from the pretrained config (768 for ViT-Base)
        self.vit_dim = self.vit_encoder.config.hidden_size
        self.hidden_dim = hidden_dim

        # Multi-head attention for temporal fusion (only created if sequence_length > 1).
        # NOTE: not used by forward() yet.
        if sequence_length > 1:
            self.temporal_attention = nn.MultiheadAttention(hidden_dim, num_heads=8, batch_first=True)
            self.temporal_norm = nn.LayerNorm(hidden_dim)

        # Fusion layer with INCREASED DROPOUT; maps the ViT embedding to hidden_dim
        self.fusion_layer = nn.Sequential(
            nn.Linear(self.vit_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.GELU(),
            nn.Dropout(0.3)  # Increased from 0.1 to 0.3
        )

        # Classification head with HIGHER DROPOUT RATES
        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.LayerNorm(hidden_dim // 2),
            nn.GELU(),
            nn.Dropout(0.5),  # Increased from 0.3 to 0.5
            nn.Linear(hidden_dim // 2, hidden_dim // 4),
            nn.LayerNorm(hidden_dim // 4),
            nn.GELU(),
            nn.Dropout(0.4),  # Additional dropout layer
            nn.Linear(hidden_dim // 4, num_classes)
        )

    def forward(self, images):
        """
        Classify a batch of preprocessed images.

        Args:
            images: pixel-value tensor as produced by ViTImageProcessor, presumably
                (batch_size, 3, 224, 224).
        Returns:
            (logits, fused_features): logits of shape (batch_size, num_classes) and the
            fusion-layer output of shape (batch_size, hidden_dim).
        """
        # Extract vision features using ViT; index 0 of the sequence is the [CLS] token
        outputs = self.vit_encoder(pixel_values=images)
        vision_features = outputs.last_hidden_state[:, 0, :]

        # Apply fusion layer
        fused_features = self.fusion_layer(vision_features)

        # Classification
        logits = self.classifier(fused_features)

        return logits, fused_features

# Build the model with one output unit per class folder (label ids follow sorted folder names)
label_names = sorted(folders)
num_classes = len(label_names)
model_multimodal = MultimodalDrowsinessTransformer(num_classes=num_classes).to(device)

# Report model size in millions of parameters (total vs. still trainable after freezing)
_model_params = list(model_multimodal.parameters())
_total_params_m = sum(p.numel() for p in _model_params) / 1e6
_trainable_params_m = sum(p.numel() for p in _model_params if p.requires_grad) / 1e6
print(f"\nModel initialized with {num_classes} classes")
print(f"Model parameters: {_total_params_m:.2f}M")
print(f"Trainable parameters: {_trainable_params_m:.2f}M")


## 5. Dataset Class and Data Loading


## 6. Training with Anti-Overfitting Configuration

This section implements comprehensive anti-overfitting measures to ensure robust model performance and prevent overfitting:
- **Label Smoothing**: Reduces overconfidence
- **Gradient Clipping**: Prevents gradient explosion
- **Early Stopping**: Stops training when validation performance plateaus
- **Reduced Learning Rate**: More stable training
- **Increased Weight Decay**: Stronger regularization


In [None]:
# Custom Dataset Class
class DriverDrowsinessDataset(Dataset):
    """Image dataset yielding ``(pixel_values, label)`` pairs for the ViT model.

    Args:
        image_paths: sequence of image file paths.
        labels: sequence of integer class ids aligned with ``image_paths``.
        processor: callable such as ``ViTImageProcessor`` that accepts ``images=...`` and
            ``return_tensors="pt"`` and returns a mapping with a ``'pixel_values'`` tensor.
        transform: optional callable applied to the PIL image before the processor
            (e.g. data augmentation).

    Raises:
        ValueError: if ``image_paths`` and ``labels`` have different lengths.
    """
    def __init__(self, image_paths, labels, processor, transform=None):
        if len(image_paths) != len(labels):
            raise ValueError(
                f"image_paths ({len(image_paths)}) and labels ({len(labels)}) must have the same length"
            )
        self.image_paths = image_paths
        self.labels = labels
        self.processor = processor
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        label = self.labels[idx]

        # Load the image inside a context manager so the file handle is closed right away;
        # convert() returns an independent, fully loaded copy.
        with Image.open(img_path) as img:
            image = img.convert('RGB')
        if self.transform:
            image = self.transform(image)

        # Process with ViT processor and drop the batch dimension it adds
        inputs = self.processor(images=image, return_tensors="pt")
        pixel_values = inputs['pixel_values'].squeeze(0)

        return pixel_values, torch.tensor(label, dtype=torch.long)

import sys

# Prepare data: collect (image path, label id) pairs from one sub-folder per class
all_images = []
all_labels = []
label2id = {name: idx for idx, name in enumerate(sorted(folders))}
id2label = {idx: name for name, idx in label2id.items()}

for label_name in sorted(folders):
    folder_path = os.path.join(BASE_DIR, label_name)
    # os.listdir order is filesystem-dependent; sorting keeps the train/val split below
    # reproducible (a fixed random_state only helps if the input order is fixed too).
    images = [os.path.join(folder_path, f) for f in sorted(os.listdir(folder_path))
              if f.lower().endswith(('.jpg', '.jpeg', '.png'))]
    all_images.extend(images)
    all_labels.extend([label2id[label_name]] * len(images))

# Stratified 80/20 train-validation split
train_images, val_images, train_labels, val_labels = train_test_split(
    all_images, all_labels, test_size=0.2, stratify=all_labels, random_state=42
)

print(f"Training samples: {len(train_images)}")
print(f"Validation samples: {len(val_images)}")

# Create datasets
processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224')
train_dataset = DriverDrowsinessDataset(train_images, train_labels, processor)
val_dataset = DriverDrowsinessDataset(val_images, val_labels, processor)

# Create data loaders. Where worker processes are started with "spawn" (macOS, Windows),
# workers cannot import a Dataset class defined inside the notebook, so load in the
# main process there.
BATCH_SIZE = 16
NUM_WORKERS = 2 if sys.platform.startswith('linux') else 0
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)


## 6. Training Loop with Anti-Overfitting Measures

**⚠️ IMPORTANT: The actual training code is in the next cell below!**

This section contains the complete training implementation with all anti-overfitting techniques.


In [None]:
# ====================================================================
# 🚀 COMPLETE TRAINING LOOP WITH ANTI-OVERFITTING MEASURES
# ====================================================================
# This is the MAIN TRAINING CELL - Run this to train your model!
# ====================================================================

# Training setup with ANTI-OVERFITTING MEASURES
print("="*60)
print("TRAINING SETUP - ANTI-OVERFITTING CONFIGURATION")
print("="*60)

# Label smoothing to reduce overconfident predictions (smoothing factor = 0.1)
criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
print("✓ Label smoothing enabled (smoothing=0.1)")

# Optimise only the parameters left trainable (frozen ViT weights have requires_grad=False).
# Lower learning rate and higher weight decay for better regularization.
trainable_parameters = [p for p in model_multimodal.parameters() if p.requires_grad]
optimizer = torch.optim.AdamW(
    trainable_parameters,
    lr=1e-5,  # Reduced from 2e-5 to 1e-5
    weight_decay=0.05  # Increased from 0.01 to 0.05
)
print("✓ Learning rate: 1e-5 (reduced)")
print("✓ Weight decay: 0.05 (increased)")

# Halve the LR when validation loss plateaus. ReduceLROnPlateau reduces only after MORE
# than `patience` non-improving epochs, so patience=1 acts after 2 such epochs; the former
# patience=3 needed 4, which early stopping (patience=3 in the training loop below)
# normally pre-empts. `verbose` is deprecated in recent PyTorch releases, so it is not passed.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.5, patience=1
)
print("✓ Learning rate scheduler enabled (ReduceLROnPlateau)")
print("="*60)

# Training function with GRADIENT CLIPPING
def train_epoch(model, dataloader, criterion, optimizer, device, clip_grad_norm=1.0):
    """Run one training epoch.

    Args:
        model: module whose forward returns ``(logits, features)``.
        dataloader: iterable of ``(pixel_values, labels)`` batches.
        criterion: loss taking ``(logits, labels)``; assumed to return the batch mean.
        optimizer: optimizer over the model's trainable parameters.
        device: device the batches are moved to.
        clip_grad_norm: maximum global gradient L2 norm; ``None`` disables clipping.

    Returns:
        ``(mean_loss, accuracy_percent)`` over all samples seen, or ``(0.0, 0.0)`` when
        the dataloader yields no samples.
    """
    model.train()
    total_loss = 0.0
    correct = 0
    total = 0

    for pixel_values, labels in dataloader:
        pixel_values = pixel_values.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        logits, _ = model(pixel_values)
        loss = criterion(logits, labels)
        loss.backward()

        # GRADIENT CLIPPING to prevent exploding gradients
        if clip_grad_norm is not None:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=clip_grad_norm)

        optimizer.step()

        batch_size = labels.size(0)
        # Weight each batch mean by its size so a short final batch doesn't skew the epoch loss
        total_loss += loss.item() * batch_size
        correct += (logits.detach().argmax(dim=1) == labels).sum().item()
        total += batch_size

    if total == 0:
        return 0.0, 0.0
    return total_loss / total, 100 * correct / total

# Validation function
def validate(model, dataloader, criterion, device):
    """Evaluate ``model`` on ``dataloader`` without gradient tracking.

    Args:
        model: module whose forward returns ``(logits, features)``.
        dataloader: iterable of ``(pixel_values, labels)`` batches.
        criterion: loss taking ``(logits, labels)``; assumed to return the batch mean.
        device: device the batches are moved to.

    Returns:
        ``(mean_loss, accuracy_percent, all_preds, all_labels)``: the last two are lists of
        predicted and true class ids in dataloader order. Loss and accuracy are ``0.0`` when
        the dataloader yields no samples.
    """
    model.eval()
    total_loss = 0.0
    correct = 0
    total = 0
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for pixel_values, labels in dataloader:
            pixel_values = pixel_values.to(device)
            labels = labels.to(device)

            logits, _ = model(pixel_values)
            loss = criterion(logits, labels)

            predicted = logits.argmax(dim=1)
            batch_size = labels.size(0)
            # Weight each batch mean by its size so a short final batch doesn't skew the mean
            total_loss += loss.item() * batch_size
            correct += (predicted == labels).sum().item()
            total += batch_size

            all_preds.extend(predicted.cpu().tolist())
            all_labels.extend(labels.cpu().tolist())

    if total == 0:
        return 0.0, 0.0, all_preds, all_labels
    return total_loss / total, 100 * correct / total, all_preds, all_labels

# ====================================================================
# TRAINING LOOP WITH EARLY STOPPING
# ====================================================================
num_epochs = 8  # Reduced for faster training (early stopping will likely stop before this)
best_val_acc = float('-inf')  # so the first epoch always counts as "best" and a checkpoint is written
best_val_loss = float('inf')
patience = 3  # Early stopping patience (reduced for faster training)
patience_counter = 0
train_losses, val_losses = [], []
train_accs, val_accs = [], []
best_model_state = None
CHECKPOINT_PATH = 'best_multimodal_transformer.pth'


def snapshot_state_dict(model):
    """Return a CPU copy of ``model.state_dict()`` that does not share storage with the model."""
    # state_dict().copy() is only a shallow dict copy: its tensors alias the live parameters,
    # so later optimizer steps would silently overwrite the "best" weights.
    return {name: tensor.detach().to('cpu', copy=True) for name, tensor in model.state_dict().items()}


print("\n" + "="*60)
print("STARTING TRAINING WITH ANTI-OVERFITTING MEASURES")
print("="*60)
print(f"Max epochs: {num_epochs}")
print(f"Early stopping patience: {patience} epochs")
print(f"Gradient clipping: max_norm=1.0")
print("="*60 + "\n")

for epoch in range(num_epochs):
    # Train with gradient clipping
    train_loss, train_acc = train_epoch(
        model_multimodal, train_loader, criterion, optimizer, device, clip_grad_norm=1.0
    )

    # Validate
    val_loss, val_acc, _, _ = validate(model_multimodal, val_loader, criterion, device)

    # Update learning rate
    scheduler.step(val_loss)

    # Store metrics
    train_losses.append(train_loss)
    val_losses.append(val_loss)
    train_accs.append(train_acc)
    val_accs.append(val_acc)

    # Print epoch results
    print(f"Epoch {epoch+1}/{num_epochs}:")
    print(f"  Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%")
    print(f"  Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%")
    print(f"  Learning Rate: {optimizer.param_groups[0]['lr']:.2e}")

    # Improvement = better val accuracy (also checkpointed) or lower val loss than ever before
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        best_model_state = snapshot_state_dict(model_multimodal)
        torch.save(model_multimodal.state_dict(), CHECKPOINT_PATH)
        print(f"  ✓ Saved best model (Val Acc: {val_acc:.2f}%)")
        patience_counter = 0
    elif val_loss < best_val_loss:
        patience_counter = 0
    else:
        patience_counter += 1
    # Track the lowest validation loss seen so far; it must never increase.
    best_val_loss = min(best_val_loss, val_loss)

    # Check for overfitting warning
    if epoch > 0 and val_loss > val_losses[-2] and train_loss < train_losses[-2]:
        print(f"  ⚠️  Warning: Possible overfitting (val loss ↑ while train loss ↓)")

    # Early stopping
    if patience_counter >= patience:
        print(f"\n{'='*60}")
        print(f"EARLY STOPPING triggered after {epoch+1} epochs")
        print(f"No improvement for {patience} consecutive epochs")
        print(f"Best validation accuracy: {best_val_acc:.2f}%")
        print(f"{'='*60}\n")
        break

    print()

# Restore best model (load_state_dict copies the CPU snapshot onto the model's device)
if best_model_state is not None:
    model_multimodal.load_state_dict(best_model_state)
    print(f"{'='*60}")
    print(f"TRAINING COMPLETED!")
    print(f"Best validation accuracy: {best_val_acc:.2f}%")
    print(f"Best model restored and ready for evaluation")
    print(f"{'='*60}")
else:
    print(f"\nTraining completed! Best validation accuracy: {best_val_acc:.2f}%")


### ⚠️ The Actual Training Code Is in the Main Training Cell Above

**The training setup and loop with all anti-overfitting measures are in the "COMPLETE TRAINING LOOP" cell above.**
The cell below defines the analysis function that will be used AFTER training completes.


## 6.5 Training History Visualization & Overfitting Analysis


In [None]:
# Visualize training history and detect overfitting for Transformer model
def _plot_training_curves_transformer(train_losses, val_losses, train_accs, val_accs):
    """Draw accuracy and loss curves side by side, save them to training_history_transformer.png and show them."""
    fig, (ax_acc, ax_loss) = plt.subplots(1, 2, figsize=(16, 6))
    epochs = range(1, len(train_accs) + 1)
    curve_style = dict(linewidth=2, markersize=8)

    # Accuracy plot
    ax_acc.plot(epochs, train_accs, 'o-', label='Train Accuracy', color='#2E86AB', **curve_style)
    ax_acc.plot(epochs, val_accs, 's-', label='Val Accuracy', color='#A23B72', **curve_style)
    ax_acc.set_title('Transformer Model Accuracy Over Epochs', fontsize=14, fontweight='bold')
    ax_acc.set_xlabel('Epoch', fontsize=12)
    ax_acc.set_ylabel('Accuracy (%)', fontsize=12)
    ax_acc.legend(fontsize=11)
    ax_acc.grid(alpha=0.3)
    all_accs = list(train_accs) + list(val_accs)
    ax_acc.set_ylim([min(all_accs) - 2, max(all_accs) + 2])

    # Loss plot
    ax_loss.plot(epochs, train_losses, 'o-', label='Train Loss', color='#2E86AB', **curve_style)
    ax_loss.plot(epochs, val_losses, 's-', label='Val Loss', color='#A23B72', **curve_style)
    ax_loss.set_title('Transformer Model Loss Over Epochs', fontsize=14, fontweight='bold')
    ax_loss.set_xlabel('Epoch', fontsize=12)
    ax_loss.set_ylabel('Loss', fontsize=12)
    ax_loss.legend(fontsize=11)
    ax_loss.grid(alpha=0.3)

    fig.tight_layout()
    fig.savefig('training_history_transformer.png', dpi=300, bbox_inches='tight')
    plt.show()


def plot_training_history_transformer(train_losses, val_losses, train_accs, val_accs):
    """Plot training curves and print an overfitting analysis.

    Args:
        train_losses, val_losses: per-epoch mean losses.
        train_accs, val_accs: per-epoch accuracies in percent.
        All four must be non-empty and have one entry per epoch.

    Returns:
        ``(best_epoch, best_val_acc)``: the 1-based epoch with the highest validation
        accuracy (first one on ties) and that accuracy.

    Raises:
        ValueError: if the histories are empty or differ in length.
    """
    n_epochs = len(train_accs)
    if n_epochs == 0:
        raise ValueError("Training history is empty - run the training loop first.")
    if not (len(train_losses) == len(val_losses) == len(val_accs) == n_epochs):
        raise ValueError("train_losses, val_losses, train_accs and val_accs must have the same length.")

    _plot_training_curves_transformer(train_losses, val_losses, train_accs, val_accs)

    # Analyze overfitting
    print("="*60)
    print("OVERFITTING ANALYSIS - MULTIMODAL TRANSFORMER")
    print("="*60)

    # Best epoch = highest validation accuracy
    best_idx = int(np.argmax(val_accs))
    best_epoch = best_idx + 1
    best_val_acc = val_accs[best_idx]
    best_train_acc = train_accs[best_idx]
    best_val_loss = val_losses[best_idx]
    best_train_loss = train_losses[best_idx]

    # Final epoch metrics
    final_train_acc = train_accs[-1]
    final_val_acc = val_accs[-1]
    final_train_loss = train_losses[-1]
    final_val_loss = val_losses[-1]

    # Train/validation gaps
    acc_gap_best = best_train_acc - best_val_acc
    acc_gap_final = final_train_acc - final_val_acc
    loss_gap_best = abs(best_train_loss - best_val_loss)
    loss_gap_final = abs(final_train_loss - final_val_loss)

    print(f"Best Model Performance (Epoch {best_epoch}):")
    print(f"  Train Accuracy: {best_train_acc:.2f}%")
    print(f"  Val Accuracy:   {best_val_acc:.2f}%")
    print(f"  Accuracy Gap:   {acc_gap_best:.2f}%")
    print(f"  Train Loss:     {best_train_loss:.4f}")
    print(f"  Val Loss:       {best_val_loss:.4f}")
    print(f"  Loss Gap:       {loss_gap_best:.4f}")

    print(f"\nFinal Epoch Performance:")
    print(f"  Train Accuracy: {final_train_acc:.2f}%")
    print(f"  Val Accuracy:   {final_val_acc:.2f}%")
    print(f"  Accuracy Gap:   {acc_gap_final:.2f}%")
    print(f"  Train Loss:     {final_train_loss:.4f}")
    print(f"  Val Loss:       {final_val_loss:.4f}")
    print(f"  Loss Gap:       {loss_gap_final:.4f}")

    # Overfitting indicators
    print(f"\n" + "="*60)
    print("OVERFITTING INDICATORS:")
    print("="*60)

    overfitting_signs = []

    if final_train_acc >= 99.5:
        overfitting_signs.append("⚠️  Training accuracy very high (≥99.5%) - potential memorization")

    if acc_gap_final > acc_gap_best + 0.5:
        overfitting_signs.append(f"⚠️  Accuracy gap increased by {abs(acc_gap_final - acc_gap_best):.2f}%")

    if final_val_loss > best_val_loss:
        overfitting_signs.append(f"⚠️  Validation loss increased from {best_val_loss:.4f} to {final_val_loss:.4f}")

    if final_val_acc < best_val_acc:
        overfitting_signs.append(f"⚠️  Validation accuracy decreased from {best_val_acc:.2f}% to {final_val_acc:.2f}%")

    # Validation loss rising over the last 3 epochs while training loss falls
    if n_epochs >= 3:
        if val_losses[-1] > val_losses[-3] and train_losses[-1] < train_losses[-3]:
            overfitting_signs.append("⚠️  Validation loss increasing while training loss decreasing (classic overfitting sign)")

    if not overfitting_signs:
        print("✓ No significant overfitting detected!")
        print("  The model generalizes well to validation data.")
    else:
        print(f"Found {len(overfitting_signs)} indicator(s) of overfitting:")
        for sign in overfitting_signs:
            print(f"  {sign}")

        if len(overfitting_signs) >= 3:
            severity = "SEVERE"
        elif len(overfitting_signs) >= 2:
            severity = "MODERATE"
        else:
            severity = "MILD"

        # Recommendations are phrased relative to the current configuration, which already
        # uses dropout, weight decay, label smoothing and gradient clipping.
        print(f"\n📊 Overfitting Severity: {severity}")
        print(f"\n💡 Recommendations for Transformer Model:")
        print(f"  1. Use the model from Epoch {best_epoch} (best validation performance)")
        print(f"  2. Increase dropout further in the fusion layer / classifier head")
        print(f"  3. Increase weight decay further")
        print(f"  4. Increase label smoothing")
        print(f"  5. Use mixup or cutmix augmentation")
        print(f"  6. Reduce learning rate")
        print(f"  7. Freeze more ViT layers and fine-tune only classifier")

    print("="*60)

    return best_epoch, best_val_acc

# Call the analysis function AFTER training completes.
# Run this cell after training to visualize and analyze overfitting; it is skipped
# (with a hint) while the training loop has not produced any history yet.
if globals().get('train_losses'):
    print("="*60)
    print("TRAINING HISTORY ANALYSIS")
    print("="*60)
    best_epoch_analysis, best_val_acc_analysis = plot_training_history_transformer(
        train_losses, val_losses, train_accs, val_accs
    )
    print(f"\n✓ Analysis complete! Best model from Epoch {best_epoch_analysis}")
else:
    print("⚠️  Training history not found. Please run the training loop first.")
    print("   The analysis will be available after training completes.")


## 4.5 Alternative: Improved Transformer with Anti-Overfitting Techniques


In [None]:
# OPTIONAL: Improved Transformer model with stronger regularization to reduce overfitting
# Uncomment and use this if you want to retrain with anti-overfitting measures
#
# NOTE(review): everything below is a no-op string literal kept for reference only.
# Its model uses the same frozen-layer count (6) and dropout rates (0.3 / 0.5 / 0.4) as
# MultimodalDrowsinessTransformer, and the main training cell already applies label
# smoothing 0.1, AdamW(lr=1e-5, weight_decay=0.05) and gradient clipping (max_norm=1.0),
# so enabling it adds nothing new.
# If you do enable it, instantiate the model first, e.g.
#   model_multimodal_improved = MultimodalDrowsinessTransformerImproved(num_classes=num_classes).to(device)
# because the optimizer below references `model_multimodal_improved`, which is never defined.

"""
class MultimodalDrowsinessTransformerImproved(nn.Module):
    \"\"\"
    Improved multimodal transformer with anti-overfitting techniques:
    - Higher dropout rates
    - Label smoothing
    - Gradient clipping
    - Stronger regularization
    \"\"\"
    def __init__(self, num_classes=7, img_size=224, sequence_length=1, hidden_dim=768):
        super().__init__()

        # Vision Transformer encoder (freeze more layers)
        self.vit_encoder = ViTModel.from_pretrained('google/vit-base-patch16-224')
        # Freeze first 6 layers of ViT
        for i in range(6):
            for param in list(self.vit_encoder.encoder.layer[i].parameters()):
                param.requires_grad = False

        self.vit_processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224')
        self.hidden_dim = hidden_dim

        # Cross-modal fusion layers with higher dropout
        self.fusion_layer = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.GELU(),
            nn.Dropout(0.3)  # Increased from 0.1
        )

        # Classification head with higher dropout
        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.LayerNorm(hidden_dim // 2),
            nn.GELU(),
            nn.Dropout(0.5),  # Increased from 0.3
            nn.Linear(hidden_dim // 2, hidden_dim // 4),
            nn.LayerNorm(hidden_dim // 4),
            nn.GELU(),
            nn.Dropout(0.4),
            nn.Linear(hidden_dim // 4, num_classes)
        )

    def forward(self, images):
        outputs = self.vit_encoder(pixel_values=images)
        vision_features = outputs.last_hidden_state[:, 0, :]
        fused_features = self.fusion_layer(vision_features)
        logits = self.classifier(fused_features)
        return logits, fused_features

# Training with label smoothing
criterion_improved = nn.CrossEntropyLoss(label_smoothing=0.1)  # Label smoothing
optimizer_improved = torch.optim.AdamW(
    model_multimodal_improved.parameters(),
    lr=1e-5,  # Lower learning rate
    weight_decay=0.05  # Stronger weight decay
)

# Add gradient clipping in training loop:
# torch.nn.utils.clip_grad_norm_(model_multimodal_improved.parameters(), max_norm=1.0)
"""


## 6.6 Load Best Model (Non-Overfit Version)


## 6.7 Training History Visualization & Overfitting Analysis

**Run this cell after training completes** to visualize training curves and analyze overfitting patterns.


## 6. Training Results Analysis (Anti-Overfitting Summary)

**The main training cell (Section 6, "COMPLETE TRAINING LOOP") applies all anti-overfitting measures:**
- ✅ Label smoothing (0.1)
- ✅ Reduced learning rate (1e-5)
- ✅ Higher weight decay (0.05)
- ✅ Gradient clipping (max_norm=1.0)
- ✅ Early stopping (patience=3)
- ✅ Learning rate scheduling
- ✅ Overfitting detection warnings

**After training has finished, run the cell below to analyze the results.**


In [None]:
# Training completed - now analyze the results.
# Skipped (with a hint) while the training loop has not produced any history yet.
if globals().get('train_losses'):
    print("="*60)
    print("TRAINING HISTORY ANALYSIS")
    print("="*60)

    # Visualize and analyze training history
    best_epoch_analysis, best_val_acc_analysis = plot_training_history_transformer(
        train_losses, val_losses, train_accs, val_accs
    )

    print(f"\n✓ Analysis complete! Best model from Epoch {best_epoch_analysis}")
else:
    print("⚠️  Training history not found. Please run the main training cell (Section 6) first.")
    print("   The analysis will be available after training completes.")


In [None]:
# Note: this cell previously duplicated the training loop and was removed. Use the main training cell in Section 6 ("COMPLETE TRAINING LOOP WITH ANTI-OVERFITTING MEASURES") for training.


## 7. Model Evaluation


In [None]:
# Load best model and evaluate.
# map_location lets a checkpoint written on another device load here; weights_only=True
# restricts unpickling to tensors and plain containers, which is all a state_dict needs.
checkpoint_state = torch.load('best_multimodal_transformer.pth', map_location=device, weights_only=True)
model_multimodal.load_state_dict(checkpoint_state)
val_loss, val_acc, y_pred, y_true = validate(model_multimodal, val_loader, criterion, device)

# Class ids and names in label-id order. Passing `labels` explicitly keeps the report and
# the confusion matrix aligned with these names even if a class is absent from this split.
class_ids = list(range(len(id2label)))
class_names = [id2label[i] for i in class_ids]

# Classification report
print("="*60)
print("CLASSIFICATION REPORT")
print("="*60)
print(classification_report(y_true, y_pred, labels=class_ids, target_names=class_names, zero_division=0))

# Confusion matrix (rows = true class, columns = predicted class)
cm = confusion_matrix(y_true, y_pred, labels=class_ids)
fig, ax = plt.subplots(figsize=(10, 8))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=class_names, yticklabels=class_names, ax=ax)
ax.set_title('Confusion Matrix - Multimodal Transformer', fontsize=14, fontweight='bold')
ax.set_ylabel('True Label')
ax.set_xlabel('Predicted Label')
fig.tight_layout()
fig.savefig('confusion_matrix_transformer.png', dpi=300, bbox_inches='tight')
plt.show()

# Binary drowsiness evaluation: these class folders count as "Drowsy", all others as "Not Drowsy"
drowsy_labels = ['Yawn', 'closed']
drowsy_ids = [label2id[label] for label in drowsy_labels if label in label2id]
if len(drowsy_ids) < len(drowsy_labels):
    missing_drowsy = [label for label in drowsy_labels if label not in label2id]
    print(f"Warning: drowsy class folder(s) not found in the dataset: {missing_drowsy}")

y_true_bin = [1 if id2label[y] in drowsy_labels else 0 for y in y_true]
y_pred_bin = [1 if id2label[y] in drowsy_labels else 0 for y in y_pred]

print("\n" + "="*60)
print("BINARY DROWSINESS CLASSIFICATION REPORT")
print("="*60)
# labels=[0, 1] keeps the two target names valid even if only one class occurs
print(classification_report(y_true_bin, y_pred_bin, labels=[0, 1],
                            target_names=['Not Drowsy', 'Drowsy'], zero_division=0))


## 8. Interactive Dashboard with Gradio


In [None]:
# Interactive dashboard (gradio is installed by the setup cell at the top of the notebook)
import gradio as gr

# Load the image processor once rather than on every request.
dashboard_processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224')
# Class folders that count as "Drowsy" for the binary status.
DROWSY_LABELS = ['Yawn', 'closed']


def predict_image_multimodal(img_path):
    """Predict the driver state for one image using the multimodal transformer.

    Args:
        img_path: path to the uploaded image (Gradio ``type="filepath"``), or ``None``.

    Returns:
        ``(markdown_summary, predicted_class, class_probabilities)``; ``(None, None, None)``
        when no image was provided.
    """
    if img_path is None:
        return None, None, None

    model_multimodal.eval()

    # Load and preprocess image (context manager closes the file handle)
    with Image.open(img_path) as img:
        image = img.convert('RGB')
    inputs = dashboard_processor(images=image, return_tensors="pt")
    pixel_values = inputs['pixel_values'].to(device)

    # Predict; the batch holds a single image, so keep row 0 of the probabilities
    with torch.no_grad():
        logits, _ = model_multimodal(pixel_values)
        probs = torch.nn.functional.softmax(logits, dim=-1)[0]
    pred_id = int(torch.argmax(probs))
    pred_class = id2label[pred_id]
    confidence = probs[pred_id].item()

    # Binary drowsiness status
    binary_status = "Drowsy" if pred_class in DROWSY_LABELS else "Not Drowsy"

    # Class probabilities
    class_probs = {id2label[i]: float(probs[i]) for i in range(len(id2label))}

    # Markdown needs a blank line for a paragraph break; the former "\\n" was rendered
    # as a literal backslash followed by "n".
    output_text = (
        f"**Predicted Class:** {pred_class}\n\n"
        f"**Confidence:** {confidence*100:.2f}%\n\n"
        f"**Drowsiness Status:** {binary_status}"
    )

    return output_text, pred_class, class_probs

# Create Gradio interface
iface = gr.Interface(
    fn=predict_image_multimodal,
    inputs=gr.Image(type="filepath"),
    outputs=[
        gr.Markdown(label="Prediction Results"),
        gr.Textbox(label="Predicted Class"),
        gr.Label(label="Class Probabilities")
    ],
    title="🚗 Driver Drowsiness Detection - Multimodal Transformer",
    description="Upload an image to predict driver drowsiness using our novel multimodal transformer architecture."
)

# NOTE: server_name="0.0.0.0" listens on all network interfaces, not just localhost.
iface.launch(share=False, server_name="0.0.0.0")
