# Model Training - Driver Drowsiness Detection CNN

This notebook trains a PyTorch CNN model for binary classification of driver drowsiness.

## Steps:
1. Load and prepare the dataset
2. Define the CNN architecture
3. Set up training loop with W&B logging
4. Train the model
5. Evaluate on test set
6. Save the best model


In [None]:
import os
import sys
from pathlib import Path
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms
from PIL import Image
import numpy as np
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns
import wandb
from dotenv import load_dotenv

# Read variables from a .env file (if present) into the process environment
load_dotenv()

# The project root is taken to be the parent of the current working
# directory; putting it first on sys.path makes the `src` package importable.
project_root = Path().resolve().parent
sys.path.insert(0, str(project_root))

from src.models.cnn_model import DrowsinessCNN
from src.config.settings import WANDB_PROJECT, WANDB_API_KEY

# Prefer the GPU when CUDA is available, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# Fix the PyTorch and NumPy seeds so repeated runs draw the same random numbers
torch.manual_seed(42)
np.random.seed(42)


## 1. Dataset Class


In [None]:
class DrowsinessDataset(Dataset):
    """Image dataset for binary driver-drowsiness classification.

    Every ``*.png`` file directly inside ``drowsy_dir`` is labelled 1 and
    every ``*.png`` file directly inside ``non_drowsy_dir`` is labelled 0.
    Files are listed in sorted order so that a given index always refers to
    the same file, which keeps seeded splits of this dataset reproducible.

    Args:
        drowsy_dir: Directory containing the drowsy images.
        non_drowsy_dir: Directory containing the non-drowsy (alert) images.
        transform: Optional callable applied to each loaded PIL image.
    """

    def __init__(self, drowsy_dir, non_drowsy_dir, transform=None):
        self.transform = transform
        self.images = []
        self.labels = []

        # drowsy = 1, alert = 0; drowsy files come first, as before
        for directory, label in ((drowsy_dir, 1), (non_drowsy_dir, 0)):
            # Path.glob yields entries in arbitrary filesystem order, so sort
            # to get a stable index -> file mapping.
            paths = sorted(str(p) for p in Path(directory).glob("*.png"))
            self.images.extend(paths)
            self.labels.extend([label] * len(paths))

    def __len__(self):
        """Return the number of images found in both directories."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the sample at position ``idx``.

        The image is opened from disk, converted to RGB and, when a
        transform is set, passed through it. The label is the int 0 or 1.
        """
        label = self.labels[idx]

        # convert() returns an independent copy, so the underlying file can be
        # closed as soon as the conversion is done.
        with Image.open(self.images[idx]) as source:
            image = source.convert('RGB')

        if self.transform:
            image = self.transform(image)

        return image, label


## 2. Load and Split Dataset


In [None]:
import copy

# Define paths
data_dir = project_root / "Data"
drowsy_dir = data_dir / "Drowsy"
non_drowsy_dir = data_dir / "Non Drowsy"

# Training transform: resize, light augmentation, then normalisation with the
# commonly used ImageNet channel statistics
train_transform = transforms.Compose([
    transforms.Resize((227, 227)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(degrees=10),
    transforms.ColorJitter(brightness=0.2, contrast=0.2),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Validation/test transform: same resize and normalisation, no augmentation
val_test_transform = transforms.Compose([
    transforms.Resize((227, 227)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Create full dataset; it carries the deterministic evaluation transform
full_dataset = DrowsinessDataset(drowsy_dir, non_drowsy_dir, transform=val_test_transform)
print(f"Total images: {len(full_dataset)}")

# Fail early with a clear message instead of a sampler error further down
if len(full_dataset) == 0:
    raise FileNotFoundError(
        f"No .png images found in {drowsy_dir} or {non_drowsy_dir}"
    )

# Split: 70% train, 15% val, 15% test (test takes the rounding remainder)
train_size = int(0.7 * len(full_dataset))
val_size = int(0.15 * len(full_dataset))
test_size = len(full_dataset) - train_size - val_size

train_dataset, val_dataset, test_dataset = random_split(
    full_dataset, [train_size, val_size, test_size],
    generator=torch.Generator().manual_seed(42)
)

# All three Subsets returned by random_split reference the SAME underlying
# dataset object, so assigning `.dataset.transform` on each of them in turn
# leaves every split with whichever transform was assigned last. Give the
# training subset its own shallow copy (sharing the image/label lists, so the
# split indices stay valid) that carries the augmenting transform.
train_source = copy.copy(full_dataset)
train_source.transform = train_transform
train_dataset.dataset = train_source

print(f"Train: {len(train_dataset)}, Val: {len(val_dataset)}, Test: {len(test_dataset)}")

# Create data loaders; only the training loader is shuffled
batch_size = 32
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=2)


## 3. Initialize W&B


In [None]:
# Authenticate against W&B with the API key from the project settings
wandb.login(key=WANDB_API_KEY)

# Hyperparameters and split sizes recorded alongside the run
run_config = {
    "learning_rate": 0.001,
    "batch_size": batch_size,
    "epochs": 10,
    "optimizer": "Adam",
    "model_architecture": "DrowsinessCNN",
    "input_size": "227x227",
    "num_classes": 2,
    "train_size": len(train_dataset),
    "val_size": len(val_dataset),
    "test_size": len(test_dataset),
}

# Start the run; later cells read hyperparameters back from wandb.config
wandb.init(project=WANDB_PROJECT, config=run_config)


## 4. Initialize Model, Loss, and Optimizer


In [None]:
# Initialize model on the selected device and report its parameter count
model = DrowsinessCNN(num_classes=2).to(device)
print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

# Loss and optimizer; the learning rate comes from the W&B run config
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=wandb.config.learning_rate)

# Learning rate scheduler: halve the learning rate when the monitored value
# (the validation loss passed to scheduler.step) has not improved for more
# than 2 epochs.
# NOTE: the `verbose` argument is intentionally not passed; it is deprecated
# in recent PyTorch releases, and the training loop already logs the current
# learning rate every epoch.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=2)


## 5. Training Loop


In [None]:
def train_epoch(model, loader, criterion, optimizer, device):
    """Run one optimisation pass over ``loader`` and report training metrics.

    Args:
        model: Network to train; switched to training mode here.
        loader: Iterable yielding ``(images, labels)`` batches.
        criterion: Loss callable invoked as ``criterion(outputs, labels)``.
        optimizer: Optimizer stepped once per batch.
        device: Device the batches are moved to before the forward pass.

    Returns:
        Tuple ``(loss, accuracy, precision, recall, f1)``. The loss is the
        mean of the per-batch loss values; precision, recall and F1 are
        weighted averages with ``zero_division=0``.
    """
    model.train()
    total_loss = 0.0
    predicted = []
    targets = []

    for images, labels in loader:
        images = images.to(device)
        labels = labels.to(device)

        # Clear old gradients, then forward, backward and parameter update
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Book-keeping: batch loss plus the arg-max class for every sample
        total_loss += loss.item()
        batch_preds = torch.max(outputs, 1).indices
        predicted.extend(batch_preds.cpu().numpy())
        targets.extend(labels.cpu().numpy())

    metric_options = {"average": 'weighted', "zero_division": 0}
    return (
        total_loss / len(loader),
        accuracy_score(targets, predicted),
        precision_score(targets, predicted, **metric_options),
        recall_score(targets, predicted, **metric_options),
        f1_score(targets, predicted, **metric_options),
    )


def validate(model, loader, criterion, device):
    """Evaluate ``model`` on ``loader`` without updating any parameters.

    Args:
        model: Network to evaluate; switched to evaluation mode here.
        loader: Iterable yielding ``(images, labels)`` batches.
        criterion: Loss callable invoked as ``criterion(outputs, labels)``.
        device: Device the batches are moved to before the forward pass.

    Returns:
        Tuple ``(loss, accuracy, precision, recall, f1)``. The loss is the
        mean of the per-batch loss values; precision, recall and F1 are
        weighted averages with ``zero_division=0``.
    """
    model.eval()
    total_loss = 0.0
    predicted = []
    targets = []

    # Gradient tracking is disabled for the whole evaluation pass
    with torch.no_grad():
        for images, labels in loader:
            images = images.to(device)
            labels = labels.to(device)

            outputs = model(images)
            total_loss += criterion(outputs, labels).item()

            batch_preds = torch.max(outputs, 1).indices
            predicted.extend(batch_preds.cpu().numpy())
            targets.extend(labels.cpu().numpy())

    metric_options = {"average": 'weighted', "zero_division": 0}
    return (
        total_loss / len(loader),
        accuracy_score(targets, predicted),
        precision_score(targets, predicted, **metric_options),
        recall_score(targets, predicted, **metric_options),
        f1_score(targets, predicted, **metric_options),
    )


In [None]:
# Training
num_epochs = wandb.config.epochs
best_val_acc = 0.0
best_model_path = project_root / "models" / "best_model.pth"

# Create models directory (and any missing parents)
best_model_path.parent.mkdir(parents=True, exist_ok=True)

print("Starting training...")
for epoch in range(num_epochs):
    # Train
    train_loss, train_acc, train_prec, train_rec, train_f1 = train_epoch(
        model, train_loader, criterion, optimizer, device
    )

    # Validate
    val_loss, val_acc, val_prec, val_rec, val_f1 = validate(
        model, val_loader, criterion, device
    )

    # Update learning rate based on the validation loss
    scheduler.step(val_loss)

    # Log to W&B
    wandb.log({
        "epoch": epoch + 1,
        "train_loss": train_loss,
        "train_accuracy": train_acc,
        "train_precision": train_prec,
        "train_recall": train_rec,
        "train_f1": train_f1,
        "val_loss": val_loss,
        "val_accuracy": val_acc,
        "val_precision": val_prec,
        "val_recall": val_rec,
        "val_f1": val_f1,
        "learning_rate": optimizer.param_groups[0]['lr']
    })

    print(f"Epoch [{epoch+1}/{num_epochs}]")
    print(f"  Train - Loss: {train_loss:.4f}, Acc: {train_acc:.4f}, F1: {train_f1:.4f}")
    print(f"  Val   - Loss: {val_loss:.4f}, Acc: {val_acc:.4f}, F1: {val_f1:.4f}")

    # Save best model: a checkpoint is written whenever validation accuracy
    # beats the best value seen so far
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save({
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            # Stored as plain Python floats: the metric helpers may return
            # NumPy scalars, which torch.load refuses to unpickle when it
            # runs in weights-only mode.
            'val_acc': float(val_acc),
            'val_loss': float(val_loss),
        }, best_model_path)
        print(f"  ✓ Saved best model (val_acc: {val_acc:.4f})")

        # Log model artifact to W&B
        artifact = wandb.Artifact('best_model', type='model')
        artifact.add_file(str(best_model_path))
        wandb.log_artifact(artifact)

print(f"\nTraining complete! Best validation accuracy: {best_val_acc:.4f}")


## 6. Test Set Evaluation


In [None]:
# Load best model; map_location lets a checkpoint saved on a GPU be restored
# on whichever device this session is using
checkpoint = torch.load(best_model_path, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
print("Loaded best model for testing")

# Evaluate on test set
test_loss, test_acc, test_prec, test_rec, test_f1 = validate(model, test_loader, criterion, device)

# Confusion matrix: validate() returns only aggregate metrics, so the
# per-sample predictions are collected in a second pass over the test set
model.eval()
all_preds = []
all_labels = []
with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        _, preds = torch.max(outputs, 1)
        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

# Fix the label set so the matrix is always 2x2 and matches the tick labels,
# even if one class is absent from the test split or the predictions
cm = confusion_matrix(all_labels, all_preds, labels=[0, 1])

# Log test metrics to W&B
wandb.log({
    "test_loss": test_loss,
    "test_accuracy": test_acc,
    "test_precision": test_prec,
    "test_recall": test_rec,
    "test_f1": test_f1
})

print("\nTest Results:")
print(f"  Loss: {test_loss:.4f}")
print(f"  Accuracy: {test_acc:.4f}")
print(f"  Precision: {test_prec:.4f}")
print(f"  Recall: {test_rec:.4f}")
print(f"  F1-Score: {test_f1:.4f}")

# Plot confusion matrix (rows: true label, columns: predicted label)
fig = plt.figure(figsize=(8, 6))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=['Alert', 'Drowsy'],
            yticklabels=['Alert', 'Drowsy'])
plt.title('Confusion Matrix - Test Set')
plt.ylabel('True Label')
plt.xlabel('Predicted Label')
plt.tight_layout()

# Log confusion matrix to W&B from the figure handle BEFORE showing it:
# plt.show() may close the figure, after which pyplot would hand W&B a new,
# empty one.
wandb.log({"confusion_matrix": wandb.Image(fig)})

plt.show()


## 7. Finish W&B Run


In [None]:
# Mark the active W&B run as finished, then confirm on stdout
wandb.finish()
print("W&B run completed!")
