# Technical Audit of a Distracted Driver Detection System

**Course:** DS-UA 202 — Responsible Data Science, Spring 2026  
**ADS Under Audit:** State Farm Distracted Driver Detection (Kaggle Competition)  
**Data & Code Source:** [Kaggle — State Farm Distracted Driver Detection](https://www.kaggle.com/c/state-farm-distracted-driver-detection)  

---

### Notebook Organization

| Part | Section | Course Requirement |
|------|---------|-------------------|
| **A** | Setup, EDA, Data Profiling | §2 — Input and Output |
| **B** | Model Implementation & Training | §3 — Implementation and Validation |
| **C** | Accuracy Analysis Across Subpopulations | §4a — Outcomes (Accuracy) |
| **D** | Fairness Analysis | §4b — Outcomes (Fairness) |
| **E** | Stability, Robustness & Interpretability | §4c — Outcomes (Additional Audits) |
| **F** | Test Inference & Submission | Kaggle Submission |
| **G** | Summary & Recommendations | §5 — Summary |

---
# PART A — Setup, Data Profiling & Exploratory Analysis
*Covers: §1 Background, §2 Input and Output*

---

## A.1 Background

**What is the purpose of this ADS?**  
The State Farm Distracted Driver Detection system is an image classification ADS designed to automatically detect whether a driver is paying attention to the road or is engaged in one of nine distracted driving behaviors. The stated goal is to improve automobile insurance safety analytics and potentially enable real-time in-vehicle warning systems.

**Why we selected this ADS for audit:**  
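This ADS sits at the intersection of computer vision and safety-critical decision-making, so misclassifications have real consequences: false negatives miss dangerous behavior, while false positives may unjustly penalize attentive drivers. It also raises fairness concerns, since accuracy may vary across groups of drivers depending on how diverse the training data is. The provided metadata (`driver_imgs_list.csv`) identifies drivers only by ID and contains no demographic attributes, so demographic disparities cannot be measured directly; driver ID is the only grouping available beyond the class label.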

**The 10 classes:**

| Class | Label | Description |
|-------|-------|-------------|
| c0 | Safe Driving | Attentive, hands on wheel |
| c1 | Texting (Right) | Using phone with right hand |
| c2 | Phone Call (Right) | Talking on phone, right hand |
| c3 | Texting (Left) | Using phone with left hand |
| c4 | Phone Call (Left) | Talking on phone, left hand |
| c5 | Operating Radio | Reaching for dashboard controls |
| c6 | Drinking | Holding beverage |
| c7 | Reaching Behind | Turning/reaching to back seat |
| c8 | Hair & Makeup | Grooming while driving |
| c9 | Talking to Passenger | Head turned toward passenger |

## A.2 Setup & Imports

In [1]:
import os
import glob
import copy
import random
import warnings
import hashlib
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import seaborn as sns

from PIL import Image, ImageFilter, ImageEnhance
from tqdm import tqdm
from collections import Counter, defaultdict

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, Subset, Dataset
from torch.cuda.amp import GradScaler, autocast

import torchvision.models as models
from torchvision import datasets, transforms

from sklearn.model_selection import StratifiedShuffleSplit
from sklearn.metrics import (
    classification_report, confusion_matrix,
    f1_score, accuracy_score, precision_recall_fscore_support,
    log_loss, roc_auc_score, roc_curve, auc
)
from sklearn.preprocessing import label_binarize

warnings.filterwarnings('ignore')

# ---- Reproducibility ----
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available:  {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")

PyTorch version: 2.9.0+cu128
CUDA available:  True
GPU: Tesla T4


In [None]:
# Mount Google Drive (Colab only)
try:
    from google.colab import drive
    drive.mount('/content/drive')
    IN_COLAB = True
except ImportError:
    IN_COLAB = False
    print("Not running in Colab.")

In [None]:
# ===================== CONFIGURATION =====================
# >>> UPDATE THESE PATHS <<<
TRAIN_PATH     = '/content/drive/MyDrive/DS project/imgs/train'
TEST_PATH      = '/content/drive/MyDrive/DS project/imgs/test'
DRIVER_CSV     = '/content/drive/MyDrive/DS project/driver_imgs_list.csv'  # If available

# Hyperparameters
IMG_SIZE        = 300
BATCH_SIZE      = 16
NUM_WORKERS     = 2
NUM_CLASSES     = 10
MAX_EPOCHS      = 25
PATIENCE        = 5
LEARNING_RATE   = 1e-3
LR_BACKBONE     = 1e-4
WEIGHT_DECAY    = 1e-4
LABEL_SMOOTHING = 0.1
VAL_SPLIT       = 0.15
UNFREEZE_EPOCH  = 3

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# Class definitions
CLASSES = {
    'c0': 'Safe Driving',       'c1': 'Texting (Right)',
    'c2': 'Phone Call (Right)', 'c3': 'Texting (Left)',
    'c4': 'Phone Call (Left)',  'c5': 'Operating Radio',
    'c6': 'Drinking',           'c7': 'Reaching Behind',
    'c8': 'Hair & Makeup',      'c9': 'Talking to Passenger'
}
CLASS_NAMES = [CLASSES[f'c{i}'] for i in range(10)]

# ImageNet stats
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD  = [0.229, 0.224, 0.225]

print(f"Device: {DEVICE}")

## A.3 Data Profiling (§2 — Input and Output)

We now profile the input data: image dimensions, file sizes, value distributions, per-class statistics, driver-level analysis, and pixel-level statistics. This corresponds to **Section 2** of the report.

In [None]:
# ====================================================================
# A.3.1 — Dataset Overview: class counts, image dimensions, file sizes
# ====================================================================
print("="*70)
print("DATASET OVERVIEW")
print("="*70)

image_metadata = []  # Collect metadata for every image

for class_key in sorted(CLASSES.keys()):
    folder = os.path.join(TRAIN_PATH, class_key)
    if not os.path.isdir(folder):
        continue
    for fname in os.listdir(folder):
        fpath = os.path.join(folder, fname)
        try:
            with Image.open(fpath) as img:  # close each file promptly; thousands of images are opened
                w, h = img.size
                mode = img.mode
            file_size_kb = os.path.getsize(fpath) / 1024
            image_metadata.append({
                'filename': fname,
                'class': class_key,
                'class_label': CLASSES[class_key],
                'width': w,
                'height': h,
                'aspect_ratio': round(w / h, 3),
                'channels': mode,
                'file_size_kb': round(file_size_kb, 1)
            })
        except Exception as e:
            print(f"  ERROR reading {fpath}: {e}")

meta_df = pd.DataFrame(image_metadata)
print(f"\nTotal images loaded: {len(meta_df)}")
print(f"Columns: {list(meta_df.columns)}")
meta_df.head()

In [None]:
# ====================================================================
# A.3.2 — Per-feature profiling: datatypes, missing values, distributions
# ====================================================================
print("="*70)
print("INPUT FEATURE PROFILING")
print("="*70)

print("\n--- Datatypes ---")
print(meta_df.dtypes)

print("\n--- Missing Values ---")
print(meta_df.isnull().sum())

print("\n--- Image Dimensions ---")
print(meta_df[['width', 'height', 'aspect_ratio', 'file_size_kb']].describe().round(2))

print("\n--- Unique Image Modes (channels) ---")
print(meta_df['channels'].value_counts())

# Check for any non-RGB images (potential data issue)
non_rgb = meta_df[meta_df['channels'] != 'RGB']
if len(non_rgb) > 0:
    print(f"\n⚠ WARNING: {len(non_rgb)} non-RGB images found:")
    print(non_rgb[['filename', 'class', 'channels']].head(10))
else:
    print("\n✓ All images are RGB (3 channels).")

# Check for duplicate filenames
dupes = meta_df['filename'].duplicated().sum()
print(f"\n--- Duplicate filenames: {dupes} ---")
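The filename check above only catches repeated names; byte-identical images saved under different names would slip through. A minimal sketch of a content-level check, using the `hashlib` module imported in A.2, groups files by their SHA-256 digest. The helper names `file_sha256` and `find_duplicate_files` are hypothetical, and the check only finds exact byte copies, not near-duplicates (e.g., re-encoded or slightly shifted images).

```python
import hashlib
import os
from collections import defaultdict


def file_sha256(path, chunk_size=1 << 20):
    """Return the SHA-256 hex digest of a file, read in chunks to bound memory."""
    h = hashlib.sha256()
    with open(path, 'rb') as f:
        for chunk in iter(lambda: f.read(chunk_size), b''):
            h.update(chunk)
    return h.hexdigest()


def find_duplicate_files(root):
    """Hypothetical helper: map digest -> list of paths for byte-identical files under root."""
    by_digest = defaultdict(list)
    for dirpath, _, filenames in os.walk(root):
        for name in sorted(filenames):
            path = os.path.join(dirpath, name)
            by_digest[file_sha256(path)].append(path)
    # Keep only digests shared by two or more files
    return {d: paths for d, paths in by_digest.items() if len(paths) > 1}
```

In this notebook it would be called as `find_duplicate_files(TRAIN_PATH)`; a duplicate whose copies sit in two different class folders would point to label noise.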

In [None]:
# ====================================================================
# A.3.3 — Class distribution analysis
# ====================================================================
class_counts = meta_df['class_label'].value_counts().reindex(CLASS_NAMES)

fig, axes = plt.subplots(1, 2, figsize=(18, 5))

# Bar chart
colors = plt.cm.Set3(np.linspace(0, 1, 10))
bars = axes[0].bar(CLASS_NAMES, class_counts.values, color=colors, edgecolor='gray')
for bar, count in zip(bars, class_counts.values):
    axes[0].text(bar.get_x() + bar.get_width()/2., bar.get_height() + 15,
                 str(count), ha='center', fontsize=8)
axes[0].set_ylabel('Count')
axes[0].set_title('Class Distribution (Training Set)', fontweight='bold')
axes[0].tick_params(axis='x', rotation=45)

# Imbalance ratio
imbalance = class_counts / class_counts.sum() * 100
axes[1].barh(CLASS_NAMES, imbalance.values, color=colors, edgecolor='gray')
axes[1].axvline(x=10, color='red', linestyle='--', alpha=0.7, label='Perfectly balanced (10%)')
axes[1].set_xlabel('Percentage of Dataset (%)')
axes[1].set_title('Class Balance Analysis', fontweight='bold')
axes[1].legend()

plt.tight_layout()
plt.show()

# Compute imbalance metrics
max_class = class_counts.max()
min_class = class_counts.min()
print(f"Largest class:  {class_counts.idxmax()} ({max_class})")
print(f"Smallest class: {class_counts.idxmin()} ({min_class})")
print(f"Imbalance ratio (max/min): {max_class/min_class:.2f}")

In [None]:
# ====================================================================
# A.3.4 — File size distribution per class (potential data quality signal)
# ====================================================================
fig, ax = plt.subplots(figsize=(14, 5))
meta_df.boxplot(column='file_size_kb', by='class_label', ax=ax,
                rot=45, grid=False, patch_artist=True,
                boxprops=dict(facecolor='lightblue'))
ax.set_title('File Size Distribution Per Class', fontweight='bold')
ax.set_xlabel('Class')
ax.set_ylabel('File Size (KB)')
plt.suptitle('')  # Remove auto-title
plt.tight_layout()
plt.show()

In [None]:
# ====================================================================
# A.3.5 — Driver-level analysis (critical for fair train/val splitting)
# The dataset has a driver_imgs_list.csv mapping images to driver IDs.
# Images from the SAME driver should ideally stay in the SAME split
# to avoid data leakage.
# ====================================================================
print("="*70)
print("DRIVER-LEVEL ANALYSIS")
print("="*70)

driver_df = None
if os.path.exists(DRIVER_CSV):
    driver_df = pd.read_csv(DRIVER_CSV)
    print(f"Loaded driver metadata: {driver_df.shape}")
    print(f"Columns: {list(driver_df.columns)}")
    print(f"\nUnique drivers: {driver_df['subject'].nunique()}")
    print(f"\nImages per driver:")
    driver_counts = driver_df.groupby('subject').size().sort_values(ascending=False)
    print(driver_counts.describe().round(1))

    # Driver distribution across classes
    print("\nDriver × Class distribution:")
    driver_class = driver_df.groupby(['subject', 'classname']).size().unstack(fill_value=0)
    print(f"  Shape: {driver_class.shape} (drivers × classes)")

    # Heatmap of driver contributions per class
    fig, ax = plt.subplots(figsize=(14, 6))
    sns.heatmap(driver_class, cmap='YlOrRd', ax=ax, linewidths=0.5)
    ax.set_title('Images per Driver per Class', fontweight='bold')
    ax.set_ylabel('Driver ID')
    ax.set_xlabel('Class')
    plt.tight_layout()
    plt.show()

    # Key audit concern: how many images per driver?
    fig, ax = plt.subplots(figsize=(12, 4))
    driver_counts.plot(kind='bar', ax=ax, color='steelblue', edgecolor='white')
    ax.set_title('Images per Driver (Potential Data Leakage Concern)', fontweight='bold')
    ax.set_xlabel('Driver ID')
    ax.set_ylabel('Number of Images')
    plt.tight_layout()
    plt.show()

    print("\n⚠ AUDIT NOTE: If we split train/val randomly (not by driver),")
    print("  the model may memorize driver appearance instead of learning")
    print("  generalizable distraction features. We analyze this below.")
else:
    print("driver_imgs_list.csv not found — skipping driver-level analysis.")
    print("Download it from the Kaggle competition page for complete audit.")

In [None]:
# ====================================================================
# A.3.6 — Pixel-level statistics (per class)
# Compute mean/std of pixel values per class to detect visual biases
# ====================================================================
print("="*70)
print("PIXEL-LEVEL STATISTICS PER CLASS")
print("="*70)

N_SAMPLE = 100  # Sample per class (for speed)
class_pixel_stats = {}

for class_key in sorted(CLASSES.keys()):
    folder = os.path.join(TRAIN_PATH, class_key)
    files = os.listdir(folder)
    sampled = random.sample(files, min(N_SAMPLE, len(files)))

    means = []
    stds = []
    brightnesses = []
    for fname in sampled:
        img = np.array(Image.open(os.path.join(folder, fname)).convert('RGB')) / 255.0
        means.append(img.mean(axis=(0, 1)))
        stds.append(img.std(axis=(0, 1)))
        brightnesses.append(img.mean())

    class_pixel_stats[class_key] = {
        'mean_r': np.mean([m[0] for m in means]),
        'mean_g': np.mean([m[1] for m in means]),
        'mean_b': np.mean([m[2] for m in means]),
        'std_r': np.mean([s[0] for s in stds]),
        'std_g': np.mean([s[1] for s in stds]),
        'std_b': np.mean([s[2] for s in stds]),
        'brightness': np.mean(brightnesses),
        'brightness_std': np.std(brightnesses)
    }

pixel_df = pd.DataFrame(class_pixel_stats).T
pixel_df.index.name = 'class'
print(pixel_df.round(4))

# Visualize brightness distribution
fig, axes = plt.subplots(1, 2, figsize=(16, 5))

# Mean RGB per class
x = np.arange(10)
w = 0.25
axes[0].bar(x - w, pixel_df['mean_r'], w, label='Red', color='red', alpha=0.7)
axes[0].bar(x, pixel_df['mean_g'], w, label='Green', color='green', alpha=0.7)
axes[0].bar(x + w, pixel_df['mean_b'], w, label='Blue', color='blue', alpha=0.7)
axes[0].set_xticks(x)
axes[0].set_xticklabels([f'c{i}' for i in range(10)])
axes[0].set_title('Mean RGB Values Per Class', fontweight='bold')
axes[0].set_ylabel('Mean Pixel Value (0-1)')
axes[0].legend()

# Brightness
axes[1].bar(x, pixel_df['brightness'], color='goldenrod', edgecolor='gray')
axes[1].errorbar(x, pixel_df['brightness'], yerr=pixel_df['brightness_std'],
                 fmt='none', color='black', capsize=3)
axes[1].set_xticks(x)
axes[1].set_xticklabels([f'c{i}' for i in range(10)])
axes[1].set_title('Mean Brightness Per Class (±1 std)', fontweight='bold')
axes[1].set_ylabel('Brightness')

plt.tight_layout()
plt.show()

print("\nAUDIT NOTE: Large brightness differences across classes would mean lighting is")
print("  correlated with the label, giving the model a shortcut unrelated to driver behavior.")
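Eyeballing the error bars can be complemented by a single number. One option (our choice, not part of the original analysis) is the one-way ANOVA effect size η², the fraction of per-image brightness variance explained by the class label: values near 0 suggest brightness carries little label information, while values approaching 1 would signal a lighting shortcut. The sketch assumes per-image brightness values and their class keys have been collected into two flat, aligned sequences (e.g., by appending inside the sampling loop above); `eta_squared` is a hypothetical helper.

```python
import numpy as np


def eta_squared(values, groups):
    """Share of total variance in `values` explained by group membership (one-way ANOVA eta^2)."""
    values = np.asarray(values, dtype=float)
    groups = np.asarray(groups)
    grand_mean = values.mean()
    ss_total = ((values - grand_mean) ** 2).sum()
    if ss_total == 0:
        return 0.0  # no variance at all, so nothing to explain
    # Between-group sum of squares: group size times squared deviation of the group mean
    ss_between = sum(
        (groups == g).sum() * (values[groups == g].mean() - grand_mean) ** 2
        for g in np.unique(groups)
    )
    return float(ss_between / ss_total)
```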

In [None]:
# ====================================================================
# A.3.7 — Pairwise visual similarity between classes (average images)
# ====================================================================
print("Computing average images per class...")

avg_images = {}
SMALL = 64  # Resize for fast computation

for class_key in sorted(CLASSES.keys()):
    folder = os.path.join(TRAIN_PATH, class_key)
    files = os.listdir(folder)
    sampled = random.sample(files, min(200, len(files)))

    accum = np.zeros((SMALL, SMALL, 3), dtype=np.float64)
    for fname in sampled:
        img = Image.open(os.path.join(folder, fname)).convert('RGB').resize((SMALL, SMALL))
        accum += np.array(img, dtype=np.float64)
    avg_images[class_key] = (accum / len(sampled)).astype(np.uint8)

# Display average images
fig, axes = plt.subplots(2, 5, figsize=(18, 8))
for idx, class_key in enumerate(sorted(CLASSES.keys())):
    ax = axes[idx // 5, idx % 5]
    ax.imshow(avg_images[class_key])
    ax.set_title(f"{class_key}: {CLASSES[class_key]}", fontsize=9, fontweight='bold')
    ax.axis('off')
plt.suptitle('Average Image Per Class (Reveals Structural Biases)', fontsize=13, fontweight='bold')
plt.tight_layout()
plt.show()

# Compute pairwise cosine similarity between average images
avg_flat = {k: v.flatten().astype(float) for k, v in avg_images.items()}
keys = sorted(avg_flat.keys())
sim_matrix = np.zeros((10, 10))
for i, ki in enumerate(keys):
    for j, kj in enumerate(keys):
        a, b = avg_flat[ki], avg_flat[kj]
        sim_matrix[i, j] = np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b) + 1e-8)

fig, ax = plt.subplots(figsize=(10, 8))
sns.heatmap(sim_matrix, annot=True, fmt='.3f', cmap='coolwarm',
            xticklabels=CLASS_NAMES, yticklabels=CLASS_NAMES,
            ax=ax, vmin=0.85, vmax=1.0)
ax.set_title('Pairwise Visual Similarity Between Classes (Cosine Similarity of Avg Images)',
             fontweight='bold')
plt.tight_layout()
plt.show()

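print("AUDIT NOTE: Class pairs with unusually high similarity (relative to the rest of the")
print("  matrix) may be harder to distinguish. All values are high because every image shares")
print("  the same camera viewpoint and cabin, so compare pairs relative to each other.")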
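Because every frame shares the same camera position and cabin, raw cosine similarity between average images is dominated by the common background and is uniformly high (hence the 0.85 to 1.0 color range above). A sketch of a variant that first subtracts the across-class mean image, so only class-specific deviations are compared, vectorized with NumPy. It assumes the `avg_images` dict built above; `centered_cosine_similarity` is a hypothetical helper.

```python
import numpy as np


def centered_cosine_similarity(avg_images):
    """Cosine similarity between class-average images after removing the shared mean image.

    avg_images: dict mapping class key -> HxWxC array. Returns (sorted keys, KxK matrix).
    """
    keys = sorted(avg_images)
    X = np.stack([avg_images[k].astype(np.float64).ravel() for k in keys])  # K x D
    X -= X.mean(axis=0, keepdims=True)       # remove the background shared by all classes
    norms = np.linalg.norm(X, axis=1, keepdims=True)
    X /= np.maximum(norms, 1e-12)            # guard against an all-zero row
    return keys, X @ X.T
```

Usage would be `keys, sim_centered = centered_cosine_similarity(avg_images)`, plotted with `vmin=-1, vmax=1`. Values now range over [-1, 1], so pairs with positive similarity share deviations from the average cabin image rather than just the background.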

In [None]:
# ====================================================================
# A.3.8 — Output description
# ====================================================================
print("="*70)
print("OUTPUT DESCRIPTION (§2c)")
print("="*70)
print("""
The ADS outputs a probability vector of length 10 for each input image.
Each element is the predicted probability that the image shows one of
the 10 behavior categories (c0–c9).

Output type: Probability distribution (softmax over logits)
Interpretation: The class with the highest probability is the predicted
  behavior. Its probability is often read as the model's confidence, but
  that reading is only reliable if the model is well calibrated.

Kaggle evaluation metric: Multi-class Logarithmic Loss (LogLoss)
  LogLoss = -(1/N) Σ_i Σ_j y_ij * log(p_ij)
  where y_ij = 1 if image i belongs to class j (else 0) and p_ij is the
  predicted probability. Lower is better; confident wrong predictions
  are penalized heavily.
""")

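The LogLoss formula above can be checked numerically. A minimal NumPy sketch, assuming Kaggle's usual multi-class log loss conventions (each row rescaled to sum to 1, then probabilities clipped to [1e-15, 1 - 1e-15]); the competition's evaluation page is the authority on the exact rule, and `multiclass_logloss` is a hypothetical helper.

```python
import numpy as np


def multiclass_logloss(y_true, probs, eps=1e-15):
    """Multi-class log loss: -(1/N) * sum_i sum_j y_ij * log(p_ij).

    y_true: length-N integer class indices; probs: N x K array of scores.
    Rows are rescaled to sum to 1 and clipped to [eps, 1 - eps] before the log.
    """
    probs = np.asarray(probs, dtype=np.float64)
    probs = probs / probs.sum(axis=1, keepdims=True)
    probs = np.clip(probs, eps, 1 - eps)
    y_true = np.asarray(y_true)
    # Only the true-class term survives because y_ij is one-hot
    return float(-np.mean(np.log(probs[np.arange(len(y_true)), y_true])))


# A uniform guess over 10 classes scores log(10), about 2.3026
uniform = np.full((4, 10), 0.1)
print(multiclass_logloss([0, 3, 7, 9], uniform))
```

The clipping explains "penalizes confident wrong predictions heavily": a single image whose true-class probability is 0 contributes -log(1e-15), about 34.5, to the sum, more than fifteen times the cost of a uniform guess.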
---
# PART B — Model Implementation & Training
*Covers: §3 Implementation and Validation*

---

## B.1 Data Augmentation & Transforms

In [None]:
train_transform = transforms.Compose([
    transforms.Resize((IMG_SIZE + 20, IMG_SIZE + 20)),
    transforms.RandomCrop(IMG_SIZE),
    # NO horizontal flip — would confuse left vs right hand classes
    transforms.RandomRotation(degrees=8),
    transforms.RandomAffine(degrees=0, translate=(0.05, 0.05), scale=(0.95, 1.05)),
    transforms.RandomPerspective(distortion_scale=0.1, p=0.3),
    transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.2, hue=0.05),
    transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 1.0)),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    transforms.RandomErasing(p=0.3, scale=(0.02, 0.15)),
])

val_transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

# TTA transforms for inference
tta_transforms = [
    val_transform,
    transforms.Compose([
        transforms.Resize((IMG_SIZE + 20, IMG_SIZE + 20)),
        transforms.CenterCrop(IMG_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]),
    transforms.Compose([
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.ColorJitter(brightness=0.2, contrast=0.2),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]),
]

print("Transforms defined.")
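At inference, the three TTA views are typically combined by averaging their per-class probabilities. A minimal NumPy sketch of that aggregation step, assuming each view produces an N x 10 softmax array; `average_tta_probs` and its optional `weights` argument are hypothetical. Note that the third view applies a random `ColorJitter`, so its predictions will differ between runs unless the RNG is reseeded.

```python
import numpy as np


def average_tta_probs(prob_list, weights=None):
    """Combine per-view probability arrays (each N x K) into one N x K array.

    Weighted arithmetic mean of probabilities; rows are renormalized to sum to 1.
    """
    stacked = np.stack([np.asarray(p, dtype=np.float64) for p in prob_list])  # V x N x K
    if weights is None:
        weights = np.ones(len(prob_list))
    weights = np.asarray(weights, dtype=np.float64)
    avg = np.tensordot(weights / weights.sum(), stacked, axes=1)  # sum over views -> N x K
    return avg / avg.sum(axis=1, keepdims=True)
```

Averaging probabilities (rather than logits) is one common choice; it tends to soften confident disagreements between views, which matters under a log-loss metric.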

## B.2 Dataset & DataLoaders

In [None]:
class TransformSubset(Dataset):
    """Wraps a Subset with its own transform (avoids shared-transform bug)."""
    def __init__(self, subset, transform):
        self.subset = subset
        self.transform = transform
    def __len__(self):
        return len(self.subset)
    def __getitem__(self, idx):
        img, label = self.subset[idx]
        if self.transform:
            img = self.transform(img)
        return img, label

In [None]:
# ---- Two splitting strategies: random vs driver-aware ----
# We implement BOTH and compare in the audit section.

full_dataset = datasets.ImageFolder(root=TRAIN_PATH, transform=None)
all_labels = [label for _, label in full_dataset.samples]
all_fnames = [os.path.basename(p) for p, _ in full_dataset.samples]

# ---- Strategy 1: Stratified Random Split ----
splitter = StratifiedShuffleSplit(n_splits=1, test_size=VAL_SPLIT, random_state=SEED)
train_indices_random, val_indices_random = next(
    splitter.split(np.zeros(len(all_labels)), all_labels)
)

# ---- Strategy 2: Driver-Aware Split (if metadata available) ----
train_indices_driver, val_indices_driver = None, None

if driver_df is not None:
    # Map filenames to driver IDs
    fname_to_driver = dict(zip(driver_df['img'], driver_df['subject']))
    drivers = [fname_to_driver.get(fn, 'unknown') for fn in all_fnames]
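    # Sort first: set order of strings can change between Python sessions
    # (hash randomization), which would make the split non-reproducible.
    unique_drivers = sorted(set(drivers))

    # Hold out ~15% of drivers for validation, using a dedicated seeded RNG so the
    # split does not depend on how many random calls earlier cells made.
    random.Random(SEED).shuffle(unique_drivers)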
    n_val_drivers = max(1, int(len(unique_drivers) * VAL_SPLIT))
    val_drivers = set(unique_drivers[:n_val_drivers])
    train_drivers = set(unique_drivers[n_val_drivers:])

    train_indices_driver = [i for i, d in enumerate(drivers) if d in train_drivers]
    val_indices_driver = [i for i, d in enumerate(drivers) if d in val_drivers]

    print(f"Driver-aware split: {len(train_drivers)} train drivers, {len(val_drivers)} val drivers")
    print(f"  Train images: {len(train_indices_driver)}, Val images: {len(val_indices_driver)}")

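# Use the random split by default; we compare both in the audit.
# Caveat: under the random split the same drivers appear in train and val,
# so validation metrics are likely optimistic estimates for unseen drivers.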
train_indices = train_indices_random
val_indices = val_indices_random

print(f"\nUsing stratified random split:")
print(f"  Train: {len(train_indices)}, Val: {len(val_indices)}")
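Whichever split is used, it is worth checking how many drivers appear on both sides. A minimal sketch, assuming `drivers` is the per-image list of driver IDs built above; `split_leakage_report` is a hypothetical helper.

```python
def split_leakage_report(groups, train_idx, val_idx):
    """Summarize group (driver) overlap between two index sets.

    groups: per-sample group IDs; train_idx / val_idx: integer indices into `groups`.
    Returns the shared groups and the fraction of val samples whose group also occurs in train.
    """
    train_groups = {groups[i] for i in train_idx}
    val_groups = {groups[i] for i in val_idx}
    shared = train_groups & val_groups
    n_val = len(val_idx)
    leaked = sum(1 for i in val_idx if groups[i] in shared)
    return {
        'n_train_groups': len(train_groups),
        'n_val_groups': len(val_groups),
        'shared_groups': sorted(shared),
        'val_fraction_leaked': leaked / n_val if n_val else 0.0,
    }
```

For the stratified random split, `split_leakage_report(drivers, train_indices_random, val_indices_random)` would very likely report every validation driver as shared; for the driver-aware split, `shared_groups` should be empty.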

In [None]:
train_subset = Subset(full_dataset, train_indices)
val_subset = Subset(full_dataset, val_indices)

train_dataset = TransformSubset(train_subset, train_transform)
val_dataset = TransformSubset(val_subset, val_transform)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True,
                          num_workers=NUM_WORKERS, pin_memory=True, drop_last=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False,
                        num_workers=NUM_WORKERS, pin_memory=True)

print(f"Train batches: {len(train_loader)}, Val batches: {len(val_loader)}")

## B.3 Model Architecture

In [None]:
class DistractedDriverModel(nn.Module):
    """
    EfficientNet-B3 backbone + custom classification head.

    IMPORTANT: Outputs raw logits (no softmax) because
    nn.CrossEntropyLoss applies LogSoftmax internally.
    Adding Softmax here would be a double-softmax bug.
    """
    def __init__(self, num_classes=10, dropout_rate=0.4):
        super().__init__()
        self.backbone = models.efficientnet_b3(weights='IMAGENET1K_V1')
        in_features = self.backbone.classifier[1].in_features
        self.backbone.classifier = nn.Identity()
        self.head = nn.Sequential(
            nn.Dropout(dropout_rate),
            nn.Linear(in_features, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout_rate * 0.75),
            nn.Linear(512, num_classes)
        )

    def forward(self, x):
        return self.head(self.backbone(x))

    def freeze_backbone(self):
        for param in self.backbone.parameters():
            param.requires_grad = False
        print("Backbone FROZEN.")

    def unfreeze_backbone(self):
        for param in self.backbone.parameters():
            param.requires_grad = True
        print("Backbone UNFROZEN.")


model = DistractedDriverModel(num_classes=NUM_CLASSES)
model.freeze_backbone()
model = model.to(DEVICE)

total_p = sum(p.numel() for p in model.parameters())
train_p = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Total params: {total_p:,}, Trainable: {train_p:,}")
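The docstring's warning about double softmax can be made concrete. If the model ended in a Softmax layer, `nn.CrossEntropyLoss` would apply a second softmax to values already in [0, 1], so the implied probability of any class could never exceed e / (e + K - 1), about 0.23 for K = 10, however confident the network is. A small NumPy illustration (the logit values are arbitrary):

```python
import numpy as np


def softmax(z):
    """Numerically stable softmax over the last axis."""
    z = np.asarray(z, dtype=np.float64)
    e = np.exp(z - z.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)


K = 10
logits = np.array([20.0] + [0.0] * (K - 1))  # a very confident prediction for class 0
once = softmax(logits)    # what CrossEntropyLoss sees when the model returns raw logits
twice = softmax(once)     # what it would see if the model already applied Softmax
bound = np.e / (np.e + K - 1)

print(f"single softmax max prob: {once.max():.6f}")
print(f"double softmax max prob: {twice.max():.6f} (upper bound {bound:.6f})")
```

Without label smoothing, the per-sample loss under double softmax therefore cannot fall below -log(bound), about 1.46, even on perfectly separated training data.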

## B.4 Loss, Optimizer, Scheduler

In [None]:
criterion = nn.CrossEntropyLoss(label_smoothing=LABEL_SMOOTHING)

optimizer = optim.AdamW([
    {'params': model.head.parameters(), 'lr': LEARNING_RATE},
    {'params': model.backbone.parameters(), 'lr': LR_BACKBONE},
], weight_decay=WEIGHT_DECAY)

scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(
    optimizer, T_0=5, T_mult=2, eta_min=1e-6
)

scaler = GradScaler(enabled=(DEVICE == 'cuda'))
print("Optimizer & scheduler configured.")
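Because `scheduler.step()` is called once per epoch in `train_full`, it helps to see where the warm restarts land relative to `UNFREEZE_EPOCH` and `PATIENCE`. The sketch below re-implements the per-step update of `CosineAnnealingWarmRestarts` as PyTorch documents it (lr = eta_min + (base_lr - eta_min) * (1 + cos(pi * T_cur / T_i)) / 2, with T_i multiplied by `T_mult` at each restart) in plain Python. It is an illustration for this configuration, not a replacement for the scheduler; `cosine_warm_restarts_lrs` is a hypothetical helper.

```python
import math


def cosine_warm_restarts_lrs(base_lr, n_epochs, T_0=5, T_mult=2, eta_min=1e-6):
    """Learning rate used in each epoch when scheduler.step() is called once per epoch."""
    lrs = []
    T_cur, T_i = 0, T_0
    for _ in range(n_epochs):
        lrs.append(eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * T_cur / T_i)) / 2)
        T_cur += 1
        if T_cur >= T_i:      # restart: reset position, lengthen the next cycle
            T_cur -= T_i
            T_i *= T_mult
    return lrs


head_lrs = cosine_warm_restarts_lrs(1e-3, 25)
restart_epochs = [e for e in range(2, 26) if head_lrs[e - 1] > head_lrs[e - 2]]
print("Epochs (1-indexed) where the LR jumps back up:", restart_epochs)
```

Under these assumptions the learning rate returns to its maximum at epochs 6 and 16. The first restart comes three epochs after the backbone is unfrozen; the resulting spike may temporarily raise validation loss and consume part of the five-epoch patience budget.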

## B.5 Training Loop

In [None]:
def train_one_epoch(model, loader, criterion, optimizer, scaler, device):
    model.train()
    running_loss, correct, total = 0.0, 0, 0
    loop = tqdm(loader, desc='  Train', leave=False)
    for inputs, labels in loop:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        with autocast(enabled=(device == 'cuda')):
            outputs = model(inputs)
            loss = criterion(outputs, labels)
        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        scaler.step(optimizer)
        scaler.update()
        running_loss += loss.item() * inputs.size(0)
        _, predicted = outputs.max(1)
        correct += predicted.eq(labels).sum().item()
        total += labels.size(0)
        loop.set_postfix(loss=loss.item(), acc=f"{100.*correct/total:.1f}%")
    return running_loss / total, 100. * correct / total


@torch.no_grad()
def validate(model, loader, criterion, device):
    model.eval()
    running_loss, correct, total = 0.0, 0, 0
    all_preds, all_labels, all_probs = [], [], []
    for inputs, labels in loader:
        inputs, labels = inputs.to(device), labels.to(device)
        with autocast(enabled=(device == 'cuda')):
            outputs = model(inputs)
            loss = criterion(outputs, labels)
        running_loss += loss.item() * inputs.size(0)
        probs = F.softmax(outputs.float(), dim=1)  # fp32 softmax: AMP logits may be fp16
        _, predicted = outputs.max(1)
        correct += predicted.eq(labels).sum().item()
        total += labels.size(0)
        all_preds.extend(predicted.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())
        all_probs.extend(probs.cpu().numpy())
    return (running_loss / total, 100. * correct / total,
            np.array(all_preds), np.array(all_labels), np.array(all_probs))

In [None]:
def train_full(model, train_loader, val_loader, criterion, optimizer,
               scheduler, scaler, max_epochs, patience, unfreeze_epoch, device):
    history = {'train_loss':[], 'train_acc':[], 'val_loss':[], 'val_acc':[], 'lr':[]}
    best_val_loss = float('inf')
    best_state = None
    no_improve = 0

    for epoch in range(1, max_epochs + 1):
        print(f"\n{'='*60}\nEpoch {epoch}/{max_epochs}")
        if epoch == unfreeze_epoch:
            model.unfreeze_backbone()

        lr = optimizer.param_groups[0]['lr']
        print(f"  LR: {lr:.2e}")

        t_loss, t_acc = train_one_epoch(model, train_loader, criterion, optimizer, scaler, device)
        v_loss, v_acc, v_preds, v_labels, v_probs = validate(model, val_loader, criterion, device)
        scheduler.step()

        history['train_loss'].append(t_loss)
        history['train_acc'].append(t_acc)
        history['val_loss'].append(v_loss)
        history['val_acc'].append(v_acc)
        history['lr'].append(lr)

        print(f"  Train: loss={t_loss:.4f} acc={t_acc:.2f}%")
        print(f"  Val:   loss={v_loss:.4f} acc={v_acc:.2f}%")

        if v_loss < best_val_loss:
            best_val_loss = v_loss
            best_state = copy.deepcopy(model.state_dict())
            no_improve = 0
            print(f"  ✓ Best model saved.")
        else:
            no_improve += 1
            print(f"  No improvement ({no_improve}/{patience})")

        if no_improve >= patience:
            print(f"\nEarly stopping at epoch {epoch}.")
            break

    if best_state is not None:
        model.load_state_dict(best_state)
    return model, history

In [None]:
# ==================== TRAIN ====================
trained_model, history = train_full(
    model, train_loader, val_loader, criterion, optimizer,
    scheduler, scaler, MAX_EPOCHS, PATIENCE, UNFREEZE_EPOCH, DEVICE
)

In [None]:
# Save model
save_path = '/content/drive/MyDrive/DS project/best_driver_model.pth'
torch.save(trained_model.state_dict(), save_path)
print(f"Saved to {save_path}")

In [None]:
# ---- Training curves ----
epochs = range(1, len(history['train_loss']) + 1)
fig, axes = plt.subplots(1, 3, figsize=(20, 5))

best_ep = np.argmin(history['val_loss']) + 1

axes[0].plot(epochs, history['train_loss'], 'b-o', ms=4, label='Train')
axes[0].plot(epochs, history['val_loss'], 'r-o', ms=4, label='Val')
axes[0].axvline(best_ep, color='green', ls='--', alpha=0.7, label=f'Best ({best_ep})')
axes[0].set_title('Loss', fontweight='bold'); axes[0].legend(); axes[0].grid(alpha=0.3)

axes[1].plot(epochs, history['train_acc'], 'b-o', ms=4, label='Train')
axes[1].plot(epochs, history['val_acc'], 'r-o', ms=4, label='Val')
axes[1].axvline(best_ep, color='green', ls='--', alpha=0.7, label=f'Best ({best_ep})')
axes[1].set_title('Accuracy (%)', fontweight='bold'); axes[1].legend(); axes[1].grid(alpha=0.3)

axes[2].plot(epochs, history['lr'], 'g-o', ms=4)
axes[2].set_title('Learning Rate', fontweight='bold'); axes[2].set_yscale('log'); axes[2].grid(alpha=0.3)

for ax in axes: ax.set_xlabel('Epoch')
plt.tight_layout()
plt.show()

---
# PART C — Accuracy Analysis Across Subpopulations
*Covers: §4a — Outcomes (Accuracy)*

We analyze model accuracy across multiple subpopulations:
1. Per-class accuracy, precision, recall, F1
2. Per-driver accuracy (if driver metadata available)
3. Accuracy by image properties (brightness, file size)
4. ROC-AUC analysis
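Items 1 and 2 reduce to the same computation: accuracy within each group, where the group is either the true class or the driver ID. A minimal NumPy sketch, assuming `preds`, `labels`, and `groups` are aligned 1-D sequences over the validation images; `accuracy_by_group` is a hypothetical helper.

```python
import numpy as np


def accuracy_by_group(preds, labels, groups):
    """Return {group: (accuracy, n)} for each distinct group ID."""
    preds, labels, groups = map(np.asarray, (preds, labels, groups))
    correct = preds == labels
    out = {}
    for g in np.unique(groups):
        mask = groups == g
        key = g.item() if hasattr(g, 'item') else g  # numpy scalar -> plain Python value
        out[key] = (float(correct[mask].mean()), int(mask.sum()))
    return out
```

For per-driver accuracy, `groups` would be `[drivers[i] for i in val_indices]`, which matches the order of the unshuffled `val_loader`. Under the random split every validation driver also appears in training, so per-driver numbers are more informative on the driver-aware split.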

---

In [None]:
# ====================================================================
# C.1 — Full validation evaluation
# ====================================================================
val_loss, val_acc, val_preds, val_labels, val_probs = validate(
    trained_model, val_loader, criterion, DEVICE
)

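print(f"Validation Accuracy:  {val_acc:.2f}%")
print(f"Validation Loss:      {val_loss:.4f}  (label-smoothed cross-entropy)")
print(f"Weighted F1:          {f1_score(val_labels, val_preds, average='weighted'):.4f}")
print(f"Macro F1:             {f1_score(val_labels, val_preds, average='macro'):.4f}")

# LogLoss (Kaggle metric) on the validation split, without label smoothing.
# Cast to float64 and renormalize so rows sum to 1 in full precision.
val_probs_64 = val_probs.astype(np.float64)
val_probs_64 /= val_probs_64.sum(axis=1, keepdims=True)
val_logloss = log_loss(val_labels, val_probs_64, labels=list(range(NUM_CLASSES)))
print(f"LogLoss (Kaggle):     {val_logloss:.4f}")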
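The two loss numbers above are not on the same scale. `val_loss` comes from `criterion`, which uses `label_smoothing=0.1`; PyTorch documents the smoothed target as a mix of the one-hot vector and a uniform distribution, q = (1 - eps) * one_hot + eps / K. Even a model that outputs exactly q has a smoothed cross-entropy equal to the entropy of q, while its ordinary LogLoss is only -log(1 - eps + eps / K). A NumPy check for K = 10, eps = 0.1:

```python
import numpy as np

K, eps = 10, 0.1
q = np.full(K, eps / K)
q[0] += 1 - eps                         # smoothed target for a sample whose true class is 0

p = q.copy()                            # a model that predicts exactly the smoothed target
smoothed_ce = -(q * np.log(p)).sum()    # per-sample value nn.CrossEntropyLoss(label_smoothing=eps) averages
plain_logloss = -np.log(p[0])           # per-sample value log_loss / the Kaggle metric averages

print(f"smoothed CE floor: {smoothed_ce:.4f}")
print(f"plain LogLoss:     {plain_logloss:.4f}")
```

So a smoothed validation loss around 0.5 can coexist with a much lower LogLoss. Early stopping in `train_full` tracks the smoothed quantity, not the Kaggle metric.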

In [None]:
# ====================================================================
# C.2 — Classification report (per-class precision, recall, F1)
# ====================================================================
print("\n" + "="*70)
print("PER-CLASS CLASSIFICATION REPORT")
print("="*70)
print(classification_report(val_labels, val_preds, target_names=CLASS_NAMES, digits=4))

In [None]:
# ====================================================================
# C.3 — Confusion matrix (raw + normalized)
# ====================================================================
cm = confusion_matrix(val_labels, val_preds)
cm_norm = cm.astype(float) / cm.sum(axis=1, keepdims=True)

fig, axes = plt.subplots(1, 2, figsize=(22, 9))

sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=CLASS_NAMES, yticklabels=CLASS_NAMES, ax=axes[0])
axes[0].set_title('Confusion Matrix (Counts)', fontweight='bold')
axes[0].set_ylabel('True'); axes[0].set_xlabel('Predicted')

sns.heatmap(cm_norm, annot=True, fmt='.3f', cmap='Oranges',
            xticklabels=CLASS_NAMES, yticklabels=CLASS_NAMES, ax=axes[1], vmin=0, vmax=1)
axes[1].set_title('Confusion Matrix (Normalized)', fontweight='bold')
axes[1].set_ylabel('True'); axes[1].set_xlabel('Predicted')

plt.tight_layout()
plt.show()

# Top confused pairs
cm_off = cm.copy()
np.fill_diagonal(cm_off, 0)
print("\nTop 5 Most Confused Pairs:")
for _ in range(5):
    i, j = np.unravel_index(cm_off.argmax(), cm_off.shape)
    if cm_off[i, j] == 0: break
    print(f"  {CLASS_NAMES[i]:>25s} → {CLASS_NAMES[j]:<25s}: {cm_off[i,j]} errors")
    cm_off[i, j] = 0
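The loop above ranks directed errors (true → predicted), so a mirrored pair such as Texting (Right) and Texting (Left) can be split across the ranking. A sketch that instead ranks unordered pairs by total confusion in both directions, assuming a square confusion matrix such as `cm` above; `top_symmetric_confusions` is a hypothetical helper.

```python
import numpy as np


def top_symmetric_confusions(cm, k=5):
    """Return [(i, j, count)] for the k unordered class pairs with most errors in either direction."""
    cm = np.asarray(cm)
    sym = cm + cm.T                                 # errors i->j plus j->i
    i_idx, j_idx = np.triu_indices_from(sym, k=1)   # each unordered pair once, diagonal excluded
    counts = sym[i_idx, j_idx]
    order = np.argsort(-counts, kind='stable')[:k]
    return [(int(i_idx[o]), int(j_idx[o]), int(counts[o])) for o in order if counts[o] > 0]
```

Usage: `for i, j, n in top_symmetric_confusions(cm): print(f"{CLASS_NAMES[i]} <-> {CLASS_NAMES[j]}: {n}")`.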

In [None]:
# ====================================================================
# C.4 — Per-class accuracy bar chart with error analysis
# ====================================================================
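# Per-class accuracy equals per-class recall. Passing labels / minlength keeps every
# array at length NUM_CLASSES even if a class were absent from the validation split.
per_class_acc = (confusion_matrix(val_labels, val_preds, labels=list(range(NUM_CLASSES))).diagonal()
                 / np.bincount(val_labels, minlength=NUM_CLASSES))
prec, rec, f1, sup = precision_recall_fscore_support(val_labels, val_preds,
                                                     labels=list(range(NUM_CLASSES)), average=None)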

fig, axes = plt.subplots(1, 2, figsize=(20, 6))

# Accuracy
colors = ['#2ecc71' if a >= 0.95 else '#f39c12' if a >= 0.90 else '#e74c3c' for a in per_class_acc]
bars = axes[0].bar(CLASS_NAMES, per_class_acc * 100, color=colors, edgecolor='gray')
for bar, a in zip(bars, per_class_acc):
    axes[0].text(bar.get_x()+bar.get_width()/2., bar.get_height()+0.3,
                 f'{a*100:.1f}%', ha='center', fontsize=8, fontweight='bold')
axes[0].axhline(95, color='green', ls='--', alpha=0.5, label='95%')
axes[0].axhline(90, color='orange', ls='--', alpha=0.5, label='90%')
axes[0].set_title('Per-Class Accuracy', fontweight='bold')
axes[0].set_ylim(0, 105)
axes[0].tick_params(axis='x', rotation=45)
axes[0].legend()

# Precision vs Recall
x = np.arange(10)
axes[1].bar(x - 0.2, prec, 0.4, label='Precision', color='steelblue', alpha=0.8)
axes[1].bar(x + 0.2, rec, 0.4, label='Recall', color='coral', alpha=0.8)
axes[1].set_xticks(x)
axes[1].set_xticklabels(CLASS_NAMES, rotation=45, ha='right')
axes[1].set_title('Precision vs Recall Per Class', fontweight='bold')
axes[1].set_ylim(0, 1.05)
axes[1].legend()

plt.tight_layout()
plt.show()

# Disparity analysis
acc_gap = per_class_acc.max() - per_class_acc.min()
print(f"\nAccuracy gap (max - min class): {acc_gap*100:.2f} percentage points")
print(f"  Best class:  {CLASS_NAMES[per_class_acc.argmax()]} ({per_class_acc.max()*100:.2f}%)")
print(f"  Worst class: {CLASS_NAMES[per_class_acc.argmin()]} ({per_class_acc.min()*100:.2f}%)")

In [None]:
# ====================================================================
# C.5 — ROC Curves and AUC (one-vs-rest)
# ====================================================================
val_labels_bin = label_binarize(val_labels, classes=list(range(NUM_CLASSES)))

fig, ax = plt.subplots(figsize=(10, 8))
auc_scores = {}

for i in range(NUM_CLASSES):
    fpr, tpr, _ = roc_curve(val_labels_bin[:, i], val_probs[:, i])
    roc_auc = auc(fpr, tpr)
    auc_scores[CLASS_NAMES[i]] = roc_auc
    ax.plot(fpr, tpr, label=f'{CLASS_NAMES[i]} (AUC={roc_auc:.4f})')

ax.plot([0, 1], [0, 1], 'k--', alpha=0.3)
ax.set_xlabel('False Positive Rate')
ax.set_ylabel('True Positive Rate')
ax.set_title('ROC Curves (One-vs-Rest)', fontweight='bold')
ax.legend(loc='lower right', fontsize=8)
ax.grid(alpha=0.3)
plt.tight_layout()
plt.show()

# Macro AUC
macro_auc = roc_auc_score(val_labels_bin, val_probs, average='macro', multi_class='ovr')
print(f"\nMacro AUC: {macro_auc:.4f}")
print("\nPer-class AUC:")
for name, score in sorted(auc_scores.items(), key=lambda x: x[1]):
    print(f"  {name:>25s}: {score:.4f}")
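
Each per-class AUC above has a direct reading: the probability that a randomly chosen image of that class gets a higher score for the class than a randomly chosen image of any other class, with ties counting one half. The sketch below computes AUC in exactly that pairwise (Mann-Whitney) form; `ovr_auc` is a hypothetical helper. It builds an n_pos × n_neg comparison matrix, which is fine for a validation set of a few thousand images but grows quadratically with size.

```python
import numpy as np

def ovr_auc(y_binary, scores):
    """AUC as P(score of a random positive > score of a random negative), ties = 1/2."""
    y_binary = np.asarray(y_binary, dtype=bool)
    scores = np.asarray(scores, dtype=float)
    pos, neg = scores[y_binary], scores[~y_binary]
    if len(pos) == 0 or len(neg) == 0:
        raise ValueError("AUC needs at least one positive and one negative sample")
    # Compare every positive with every negative via broadcasting: shape (n_pos, n_neg)
    diff = pos[:, None] - neg[None, :]
    return float((diff > 0).mean() + 0.5 * (diff == 0).mean())

print(ovr_auc([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8]))  # → 0.75
# With the notebook's variables: ovr_auc(val_labels_bin[:, i], val_probs[:, i])
```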

In [None]:
# ====================================================================
# C.6 — Per-driver accuracy analysis (subpopulation fairness)
# ====================================================================
if driver_df is not None:
    print("="*70)
    print("PER-DRIVER ACCURACY (SUBPOPULATION ANALYSIS)")
    print("="*70)

    # Map validation indices back to driver IDs
    val_drivers = [fname_to_driver.get(all_fnames[val_indices[i]], 'unknown')
                   for i in range(len(val_indices))]

    driver_results = defaultdict(lambda: {'correct': 0, 'total': 0})
    for i, (pred, label) in enumerate(zip(val_preds, val_labels)):
        d = val_drivers[i]
        driver_results[d]['total'] += 1
        if pred == label:
            driver_results[d]['correct'] += 1

    driver_accs = {d: r['correct']/r['total']*100 for d, r in driver_results.items() if r['total'] > 0}

    fig, ax = plt.subplots(figsize=(14, 5))
    drivers_sorted = sorted(driver_accs.keys())
    accs = [driver_accs[d] for d in drivers_sorted]
    colors = ['#2ecc71' if a >= 95 else '#f39c12' if a >= 90 else '#e74c3c' for a in accs]
    ax.bar(drivers_sorted, accs, color=colors, edgecolor='white')
    ax.axhline(np.mean(accs), color='blue', ls='--', label=f'Mean: {np.mean(accs):.1f}%')
    ax.set_title('Accuracy Per Driver (Subpopulation Analysis)', fontweight='bold')
    ax.set_xlabel('Driver ID')
    ax.set_ylabel('Accuracy (%)')
    ax.legend()
    plt.tight_layout()
    plt.show()

    # Disparity
    max_d = max(driver_accs, key=driver_accs.get)
    min_d = min(driver_accs, key=driver_accs.get)
    print(f"Best driver:  {max_d} ({driver_accs[max_d]:.1f}%)")
    print(f"Worst driver: {min_d} ({driver_accs[min_d]:.1f}%)")
    print(f"Gap: {driver_accs[max_d] - driver_accs[min_d]:.1f} percentage points")
    print(f"Std across drivers: {np.std(accs):.2f}")
else:
    print("Driver metadata not available — skipping per-driver analysis.")
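
Per-driver accuracies rest on different, sometimes small, numbers of validation images, so part of the best-to-worst gap can be sampling noise. A Wilson score interval per driver is a lightweight way to see which gaps exceed that noise. `wilson_interval` is a hypothetical helper; `z = 1.96` gives an approximate 95% interval.

```python
import math

def wilson_interval(correct, total, z=1.96):
    """Wilson score interval for a binomial proportion such as one driver's accuracy."""
    if total <= 0:
        raise ValueError("total must be positive")
    p_hat = correct / total
    denom = 1 + z**2 / total
    center = (p_hat + z**2 / (2 * total)) / denom
    half_width = z * math.sqrt(p_hat * (1 - p_hat) / total + z**2 / (4 * total**2)) / denom
    return max(0.0, center - half_width), min(1.0, center + half_width)

# Usage with C.6's driver_results:
# for d, r in sorted(driver_results.items()):
#     lo_d, hi_d = wilson_interval(r['correct'], r['total'])
#     print(f"{d}: {r['correct'] / r['total']:.3f}  95% CI [{lo_d:.3f}, {hi_d:.3f}]  n={r['total']}")
```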

---
# PART D — Fairness Analysis
*Covers: §4b — Outcomes (Fairness)*

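The dataset has no demographic labels, so we adapt group-fairness criteria by treating each class (one-vs-rest) and each driver as a group. Class-wise gaps show whether some behaviors are systematically harder to detect; they do not measure demographic bias. We examine:
1. **Equalized Odds (class-wise):** Are TPR and FPR similar across classes?
2. **Predictive Parity:** Is precision consistent across classes?
3. **Calibration:** Do predicted probabilities match observed frequencies?
4. **Error Rate Parity:** Are error rates balanced, in particular for the safe-vs-distracted decision?
5. **Driver-level Parity:** Is accuracy consistent across drivers (a weak proxy for demographic subgroups)?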
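
Per-class precision is plotted in C.4; the predictive-parity gap can also be read directly off the confusion matrix, because column $k$ holds every image predicted as class $k$. A minimal sketch, with `precision_gap` as a hypothetical helper (classes that are never predicted get NaN precision and are ignored in the gap):

```python
import numpy as np

def precision_gap(cm):
    """Per-class precision TP / (TP + FP) from a confusion matrix (rows = true,
    columns = predicted) and its max-min gap, a class-wise predictive-parity summary."""
    cm = np.asarray(cm, dtype=float)
    predicted_counts = cm.sum(axis=0)            # column sums: every image predicted as k
    with np.errstate(divide='ignore', invalid='ignore'):
        precision = np.where(predicted_counts > 0, np.diag(cm) / predicted_counts, np.nan)
    gap = float(np.nanmax(precision) - np.nanmin(precision))
    return precision, gap

demo_prec, demo_gap = precision_gap([[8, 2], [1, 9]])
print(demo_prec.round(3), round(demo_gap, 3))
# With the notebook's variables: precision_gap(confusion_matrix(val_labels, val_preds))
```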

---

In [None]:
# ====================================================================
# D.1 — Equalized Odds Analysis
# For each class (treated as binary: class-vs-rest), check whether
# TPR and FPR are similar across all classes.
# ====================================================================
print("="*70)
print("FAIRNESS ANALYSIS: EQUALIZED ODDS")
print("="*70)

tpr_list = []
fpr_list = []
fnr_list = []

cm_full = confusion_matrix(val_labels, val_preds)
for i in range(NUM_CLASSES):
    tp = cm_full[i, i]
    fn = cm_full[i, :].sum() - tp
    fp = cm_full[:, i].sum() - tp
    tn = cm_full.sum() - tp - fn - fp

    tpr = tp / (tp + fn + 1e-10)
    fpr = fp / (fp + tn + 1e-10)
    fnr = fn / (fn + tp + 1e-10)

    tpr_list.append(tpr)
    fpr_list.append(fpr)
    fnr_list.append(fnr)

eq_odds_df = pd.DataFrame({
    'Class': CLASS_NAMES,
    'TPR (Recall)': tpr_list,
    'FPR': fpr_list,
    'FNR': fnr_list
})

print(eq_odds_df.to_string(index=False))

fig, axes = plt.subplots(1, 3, figsize=(20, 5))

for ax, metric, values, color in [
    (axes[0], 'TPR (Recall)', tpr_list, 'steelblue'),
    (axes[1], 'FPR', fpr_list, 'coral'),
    (axes[2], 'FNR', fnr_list, 'goldenrod')
]:
    ax.bar(CLASS_NAMES, values, color=color, edgecolor='gray')
    ax.axhline(np.mean(values), color='red', ls='--', label=f'Mean: {np.mean(values):.4f}')
    ax.set_title(f'{metric} Per Class', fontweight='bold')
    ax.tick_params(axis='x', rotation=45)
    ax.legend()

plt.suptitle('Equalized Odds Analysis', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

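# Class-wise equalized-odds gaps (max - min across classes).
# Note: one-vs-rest FPRs are small by construction (about 90% of images are
# negatives for any one class), so a fixed 0.05 cutoff is much easier to meet
# for FPR than for TPR. The 0.05 cutoff itself is an audit heuristic.
tpr_gap = max(tpr_list) - min(tpr_list)
fpr_gap = max(fpr_list) - min(fpr_list)
print(f"\nTPR gap across classes: {tpr_gap:.4f}")
print(f"FPR gap across classes: {fpr_gap:.4f}")
print(f"Class-wise parity within 0.05? TPR={'YES' if tpr_gap<0.05 else 'NO'}, FPR={'YES' if fpr_gap<0.05 else 'NO'}")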
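
D.1 compares one-vs-rest rates across classes. Equalized odds in its usual form compares a binary decision's TPR and FPR across groups. Driver ID is the only group variable this dataset provides, so the sketch below treats "distracted (c1-c9) vs safe (c0)" as the decision and drivers as the groups. `group_equalized_odds` is a hypothetical helper; drivers with no safe (or no distracted) validation images get NaN for the affected rate.

```python
import numpy as np

def group_equalized_odds(y_true, y_pred, groups, is_positive=lambda y: y != 0):
    """Per-group TPR/FPR for a binary decision, plus max-min gaps across groups.

    By default any class other than 0 (c0, safe driving) counts as the positive
    "distracted" outcome. Groups lacking positives (or negatives) get NaN.
    """
    t = is_positive(np.asarray(y_true))
    p = is_positive(np.asarray(y_pred))
    groups = np.asarray(groups)
    rates = {}
    for g in np.unique(groups):
        m = groups == g
        pos, neg = t & m, ~t & m                 # actual positives / negatives in group g
        tpr = (p & pos).sum() / pos.sum() if pos.any() else np.nan
        fpr = (p & neg).sum() / neg.sum() if neg.any() else np.nan
        rates[str(g)] = (float(tpr), float(fpr))
    tprs = np.array([r[0] for r in rates.values()])
    fprs = np.array([r[1] for r in rates.values()])
    tpr_gap_g = float(np.nanmax(tprs) - np.nanmin(tprs))
    fpr_gap_g = float(np.nanmax(fprs) - np.nanmin(fprs))
    return rates, tpr_gap_g, fpr_gap_g

# Usage with C.6's val_drivers (only defined when driver metadata is available):
# rates, tpr_gap_d, fpr_gap_d = group_equalized_odds(val_labels, val_preds, val_drivers)
```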

In [None]:
# ====================================================================
# D.2 — Calibration Analysis
# Are the model's confidence scores well calibrated? A well-calibrated
# model that predicts with ~80% confidence is correct ~80% of the time.
# The bins below use the top-class (max softmax) confidence.
# ====================================================================
print("="*70)
print("FAIRNESS ANALYSIS: CALIBRATION")
print("="*70)

n_bins = 10
max_confs = val_probs.max(axis=1)
correct_mask = (val_preds == val_labels)

bin_boundaries = np.linspace(0, 1, n_bins + 1)
bin_accs = []
bin_confs = []
bin_counts = []

for i in range(n_bins):
    lo, hi = bin_boundaries[i], bin_boundaries[i + 1]
    mask = (max_confs > lo) & (max_confs <= hi)  # right-closed bins, so confidence 1.0 is counted
    if mask.sum() > 0:
        bin_accs.append(correct_mask[mask].mean())
        bin_confs.append(max_confs[mask].mean())
        bin_counts.append(mask.sum())
    else:
        bin_accs.append(0)
        bin_confs.append((lo + hi) / 2)
        bin_counts.append(0)

# Expected Calibration Error
ece = sum(c * abs(a - conf) for c, a, conf in zip(bin_counts, bin_accs, bin_confs)) / sum(bin_counts)

fig, axes = plt.subplots(1, 2, figsize=(16, 6))

# Reliability diagram
axes[0].bar(bin_confs, bin_accs, width=0.08, color='steelblue', edgecolor='white', alpha=0.8,
            label='Model')
axes[0].plot([0, 1], [0, 1], 'r--', label='Perfect calibration')
axes[0].set_xlabel('Mean Predicted Confidence')
axes[0].set_ylabel('Accuracy (fraction correct)')
axes[0].set_title(f'Reliability Diagram (ECE = {ece:.4f})', fontweight='bold')
axes[0].legend()
axes[0].set_xlim(0, 1); axes[0].set_ylim(0, 1)
axes[0].grid(alpha=0.3)

# Confidence histogram
axes[1].hist(max_confs[correct_mask], bins=50, alpha=0.7, label='Correct', color='green')
axes[1].hist(max_confs[~correct_mask], bins=50, alpha=0.7, label='Incorrect', color='red')
axes[1].set_xlabel('Prediction Confidence')
axes[1].set_ylabel('Count')
axes[1].set_title('Confidence Distribution (Correct vs Incorrect)', fontweight='bold')
axes[1].legend()

plt.tight_layout()
plt.show()

print(f"Expected Calibration Error (ECE): {ece:.4f}")
print("  (Lower is better. ECE < 0.05 is a common rule of thumb for good calibration.)")
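
The ECE computed above is $\mathrm{ECE} = \sum_{b=1}^{B} \frac{n_b}{N}\,\lvert \mathrm{acc}(b) - \mathrm{conf}(b) \rvert$, where $n_b$ is the number of predictions whose top-class confidence falls in bin $b$. A reusable sketch of the same quantity follows; `expected_calibration_error` is a hypothetical helper. It uses right-closed bins $(lo, hi]$, so a confidence of exactly 1.0, common for overconfident softmax outputs, lands in the top bin instead of being dropped.

```python
import numpy as np

def expected_calibration_error(confidences, correct, n_bins=10):
    """Top-label ECE with equal-width, right-closed bins (lo, hi]."""
    confidences = np.asarray(confidences, dtype=float)
    correct = np.asarray(correct, dtype=float)
    edges = np.linspace(0.0, 1.0, n_bins + 1)
    ece_total = 0.0
    for lo_edge, hi_edge in zip(edges[:-1], edges[1:]):
        in_bin = (confidences > lo_edge) & (confidences <= hi_edge)
        if in_bin.any():
            gap = abs(correct[in_bin].mean() - confidences[in_bin].mean())
            ece_total += in_bin.mean() * gap      # weight = fraction of all samples in this bin
    return float(ece_total)

# Usage with D.2's variables: expected_calibration_error(max_confs, correct_mask)
```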

In [None]:
# ====================================================================
# D.3 — Error Rate Parity
# Collapse the 10 classes into one binary decision, with "distracted"
# (c1-c9) as the positive class. In a safety-critical system we care about:
# - Missed distractions: predicted safe (c0), actually distracted
# - False alarms: predicted distracted, actually safe (c0)
# Confusions between two distraction classes (e.g. texting vs talking on
# the phone) do not count as misses in this binary view.
# ====================================================================
print("="*70)
print("SAFETY-CRITICAL ERROR ANALYSIS")
print("="*70)

safe_idx = 0  # c0
safe_mask = (val_labels == safe_idx)
distracted_mask = (val_labels != safe_idx)

# fn_safe = missed distractions (false negatives for distraction detection):
# the model says safe, but the driver is actually distracted
fn_safe = ((val_preds == safe_idx) & distracted_mask).sum()
# fp_safe = false alarms (false positives for distraction detection):
# the model says distracted, but the driver is actually safe
fp_safe = ((val_preds != safe_idx) & safe_mask).sum()

print("Treating 'distracted' (c1-c9) as the positive class:")
print(f"  Missed distractions (false negatives): {fn_safe} / {distracted_mask.sum()}")
print(f"    → Rate: {fn_safe/distracted_mask.sum()*100:.2f}%")
print(f"  False alarms on safe drivers (false positives): {fp_safe} / {safe_mask.sum()}")
print(f"    → Rate: {fp_safe/safe_mask.sum()*100:.2f}%")
print()
print("AUDIT NOTE:")
print("  For an in-vehicle warning system, missed distractions are the more")
print("  dangerous error, so the miss rate should be minimized even at the")
print("  cost of more false alarms. For insurance scoring, false alarms can")
print("  unjustly penalize attentive drivers, so the right trade-off depends")
print("  on how the system is deployed.")
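
The rates above come from argmax decisions, which fix a single operating point. If a deployment caps the missed-distraction rate, one option (an assumption, not something the notebook implements) is to flag distraction whenever $P(\text{distracted}) = 1 - P(c0)$ reaches a tuned threshold. The sketch picks the largest threshold that meets a target miss rate; `threshold_for_miss_rate` and `target_miss_rate` are hypothetical names. Tuning on the same validation images used for reporting makes the reported rates optimistic, so a separate held-out split is preferable.

```python
import numpy as np

def threshold_for_miss_rate(p_safe, is_distracted, target_miss_rate=0.01):
    """Largest threshold t on P(distracted) = 1 - p_safe whose miss rate is <= target.

    A frame is flagged as distracted when 1 - p_safe >= t.
    Returns (threshold, miss_rate, false_alarm_rate) at that operating point.
    """
    p_distracted = 1.0 - np.asarray(p_safe, dtype=float)
    is_distracted = np.asarray(is_distracted, dtype=bool)
    if not is_distracted.any():
        raise ValueError("need at least one distracted sample")
    n_distracted = is_distracted.sum()
    n_safe = max((~is_distracted).sum(), 1)
    # Candidate thresholds: every observed score, highest first (fewest alarms first)
    for t in np.unique(p_distracted)[::-1]:
        flagged = p_distracted >= t
        miss_rate = (is_distracted & ~flagged).sum() / n_distracted
        if miss_rate <= target_miss_rate:
            false_alarm_rate = (flagged & ~is_distracted).sum() / n_safe
            return float(t), float(miss_rate), float(false_alarm_rate)
    raise RuntimeError("unreachable: the lowest threshold flags every sample")

# Usage with the notebook's variables:
# t, miss, fa = threshold_for_miss_rate(val_probs[:, 0], val_labels != 0, target_miss_rate=0.01)
```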

In [None]:
# ====================================================================
# D.4 — Per-class calibration (detailed fairness)
# ====================================================================
# Class-wise (one-vs-rest) calibration: for each class k, bin P(class = k)
# over ALL validation images and compare it with the observed frequency of k.
fig, axes = plt.subplots(2, 5, figsize=(22, 9))

for idx in range(NUM_CLASSES):
    ax = axes[idx // 5, idx % 5]

    class_probs = val_probs[:, idx]        # predicted P(class = idx) for every image
    class_is_true = (val_labels == idx)    # whether the image really is class idx

    # 5 equal-width bins; digitizing on the interior edges maps p = 1.0 to the last bin
    bins = np.linspace(0, 1, 6)
    bin_ids = np.digitize(class_probs, bins[1:-1])
    b_acc, b_conf = [], []
    for i in range(len(bins) - 1):
        m = (bin_ids == i)
        if m.sum() > 0:
            b_acc.append(class_is_true[m].mean())
            b_conf.append(class_probs[m].mean())

    ax.bar(b_conf, b_acc, width=0.15, color='steelblue', alpha=0.7)
    ax.plot([0, 1], [0, 1], 'r--', alpha=0.5)
    ax.set_title(CLASS_NAMES[idx], fontsize=9, fontweight='bold')
    ax.set_xlim(0, 1); ax.set_ylim(0, 1)
    ax.grid(alpha=0.2)

plt.suptitle('Per-Class Calibration Diagrams', fontsize=13, fontweight='bold')
plt.tight_layout()
plt.show()

---
# PART E — Stability, Robustness & Interpretability
*Covers: §4c — Additional Audit Methods*

---

In [None]:
# ====================================================================
# E.1 — Robustness to Input Perturbations
# Test how accuracy degrades under realistic corruptions:
# Gaussian noise, blur, brightness changes, JPEG compression
# ====================================================================
print("="*70)
print("ROBUSTNESS ANALYSIS: INPUT PERTURBATIONS")
print("="*70)

def evaluate_perturbation(model, dataset_subset, perturb_fn, perturb_name,
                          val_transform_base, device, n_samples=500):
    """Evaluate model accuracy under a specific perturbation."""
    model.eval()
    correct, total = 0, 0
    indices = random.sample(range(len(dataset_subset)), min(n_samples, len(dataset_subset)))

    for idx in indices:
        img_pil, label = dataset_subset[idx]  # PIL image
        img_perturbed = perturb_fn(img_pil)
        input_tensor = val_transform_base(img_perturbed).unsqueeze(0).to(device)

        with torch.no_grad():
            output = model(input_tensor)
        pred = output.argmax(1).item()
        if pred == label:
            correct += 1
        total += 1

    return 100.0 * correct / total


# Define perturbations
perturbations = {
    'Clean (no perturbation)': lambda img: img,
    'Gaussian Noise (σ=15)': lambda img: Image.fromarray(
        np.clip(np.array(img).astype(float) + np.random.normal(0, 15, np.array(img).shape), 0, 255).astype(np.uint8)
    ),
    'Gaussian Noise (σ=30)': lambda img: Image.fromarray(
        np.clip(np.array(img).astype(float) + np.random.normal(0, 30, np.array(img).shape), 0, 255).astype(np.uint8)
    ),
    'Gaussian Blur (r=2)': lambda img: img.filter(ImageFilter.GaussianBlur(radius=2)),
    'Gaussian Blur (r=5)': lambda img: img.filter(ImageFilter.GaussianBlur(radius=5)),
    'Brightness +50%': lambda img: ImageEnhance.Brightness(img).enhance(1.5),
    'Brightness -50%': lambda img: ImageEnhance.Brightness(img).enhance(0.5),
    'Contrast -50%': lambda img: ImageEnhance.Contrast(img).enhance(0.5),
    'JPEG Quality 10': lambda img: _jpeg_compress(img, 10),
    'JPEG Quality 5': lambda img: _jpeg_compress(img, 5),
}

def _jpeg_compress(img, quality):
    from io import BytesIO
    buffer = BytesIO()
    img.save(buffer, format='JPEG', quality=quality)
    buffer.seek(0)
    return Image.open(buffer).convert('RGB')

# Base transform for evaluation (resize + normalize only)
eval_transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

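robustness_results = {}
ROBUSTNESS_SEED = 0  # fixed seed so every perturbation is scored on the same images
for name, fn in tqdm(perturbations.items(), desc='Robustness tests'):
    # evaluate_perturbation draws its 500 images with random.sample, so re-seeding
    # before every call keeps the subset identical across perturbations; otherwise
    # accuracy differences would mix perturbation effects with sampling noise.
    random.seed(ROBUSTNESS_SEED)
    np.random.seed(ROBUSTNESS_SEED)  # also makes the Gaussian noise reproducible
    acc = evaluate_perturbation(
        trained_model, val_subset, fn, name, eval_transform, DEVICE, n_samples=500
    )
    robustness_results[name] = acc
    print(f"  {name:>35s}: {acc:.2f}%")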

# Plot
fig, ax = plt.subplots(figsize=(14, 6))
names = list(robustness_results.keys())
accs = list(robustness_results.values())
colors = ['green' if a >= 90 else 'orange' if a >= 80 else 'red' for a in accs]
ax.barh(names, accs, color=colors, edgecolor='gray')
ax.axvline(x=accs[0], color='blue', ls='--', alpha=0.5, label=f'Clean baseline: {accs[0]:.1f}%')
ax.set_xlabel('Accuracy (%)')
ax.set_title('Robustness to Input Perturbations', fontweight='bold')
ax.legend()
ax.set_xlim(0, 105)
plt.tight_layout()
plt.show()
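
With 500 images per condition, a difference of a point or two in accuracy can be sampling noise. When the clean and perturbed runs are scored on the same images, an exact McNemar test on the discordant pairs indicates whether a drop is statistically meaningful. `evaluate_perturbation` as written returns only an aggregate accuracy, so this sketch assumes it is extended to also return per-image correctness in a fixed image order; `mcnemar_exact` is a hypothetical helper that uses only the standard library.

```python
from math import comb

def mcnemar_exact(correct_a, correct_b):
    """Exact two-sided McNemar test on paired per-image correctness.

    b = images right under condition A but wrong under B, c = the reverse.
    Under the null hypothesis of equal accuracy, b ~ Binomial(b + c, 0.5).
    Returns (b, c, p_value).
    """
    if len(correct_a) != len(correct_b):
        raise ValueError("inputs must be paired (same images, same order)")
    b = sum(1 for x, y in zip(correct_a, correct_b) if x and not y)
    c = sum(1 for x, y in zip(correct_a, correct_b) if y and not x)
    n = b + c
    if n == 0:
        return b, c, 1.0                     # no discordant pairs: no evidence of a difference
    tail = sum(comb(n, i) for i in range(min(b, c) + 1)) / 2**n
    return b, c, min(1.0, 2 * tail)

# Usage (assuming per-image correctness lists for the clean and blurred runs):
# b, c, p = mcnemar_exact(clean_correct, blur_correct)
```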

In [None]:
# ====================================================================
# E.2 — Prediction Stability (Monte Carlo Dropout)
# Run inference multiple times with dropout enabled to measure
# prediction variance (epistemic uncertainty).
# ====================================================================
print("="*70)
print("STABILITY ANALYSIS: MONTE CARLO DROPOUT")
print("="*70)

def enable_dropout(model):
    """Enable dropout layers during inference for MC Dropout."""
    for m in model.modules():
        if isinstance(m, nn.Dropout):
            m.train()

def mc_dropout_predict(model, img_tensor, n_forward=30):
    """Run n_forward stochastic passes and return mean/std of predictions."""
    model.eval()
    enable_dropout(model)

    all_probs = []
    with torch.no_grad():
        for _ in range(n_forward):
            output = model(img_tensor)
            probs = F.softmax(output, dim=1)
            all_probs.append(probs.cpu().numpy())

    all_probs = np.array(all_probs)  # (n_forward, 1, 10)
    mean_probs = all_probs.mean(axis=0).squeeze()
    std_probs = all_probs.std(axis=0).squeeze()

    model.eval()  # Reset
    return mean_probs, std_probs


# Evaluate MC Dropout on a sample of validation images
N_MC_SAMPLES = 200
mc_indices = random.sample(range(len(val_subset)), min(N_MC_SAMPLES, len(val_subset)))

mc_uncertainties = []
mc_correct = []

for idx in tqdm(mc_indices, desc='MC Dropout'):
    img_pil, label = val_subset[idx]
    input_tensor = val_transform(img_pil).unsqueeze(0).to(DEVICE)

    mean_p, std_p = mc_dropout_predict(trained_model, input_tensor)
    pred = mean_p.argmax()
    uncertainty = std_p.mean()  # Average uncertainty across classes

    mc_uncertainties.append(uncertainty)
    mc_correct.append(pred == label)

mc_uncertainties = np.array(mc_uncertainties)
mc_correct = np.array(mc_correct)

fig, axes = plt.subplots(1, 2, figsize=(16, 5))

# Uncertainty distribution
axes[0].hist(mc_uncertainties[mc_correct], bins=30, alpha=0.7, color='green', label='Correct')
axes[0].hist(mc_uncertainties[~mc_correct], bins=30, alpha=0.7, color='red', label='Incorrect')
axes[0].set_xlabel('MC Dropout Uncertainty (mean std)')
axes[0].set_ylabel('Count')
axes[0].set_title('Uncertainty vs Correctness', fontweight='bold')
axes[0].legend()

# Accuracy vs uncertainty threshold
thresholds = np.percentile(mc_uncertainties, np.arange(5, 101, 5))  # 5% to 100% coverage
threshold_accs = []
threshold_coverages = []
for t in thresholds:
    mask = mc_uncertainties <= t
    if mask.sum() > 0:
        threshold_accs.append(mc_correct[mask].mean() * 100)
        threshold_coverages.append(mask.mean() * 100)

axes[1].plot(threshold_coverages, threshold_accs, 'b-o', ms=4)
axes[1].set_xlabel('Coverage (% of predictions kept)')
axes[1].set_ylabel('Accuracy (%)')
axes[1].set_title('Accuracy vs Coverage (Reject Uncertain)', fontweight='bold')
axes[1].grid(alpha=0.3)

plt.tight_layout()
plt.show()

print(f"\nMean uncertainty (correct):   {mc_uncertainties[mc_correct].mean():.4f}")
print(f"Mean uncertainty (incorrect): {mc_uncertainties[~mc_correct].mean():.4f}")
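# Selective prediction: drop the 10% most uncertain predictions and re-score
keep_mask = mc_uncertainties <= np.percentile(mc_uncertainties, 90)
print(f"\nAccuracy on all {len(mc_correct)} sampled images:          {mc_correct.mean()*100:.2f}%")
print(f"Accuracy after rejecting the 10% most uncertain: {mc_correct[keep_mask].mean()*100:.2f}% "
      f"(coverage {keep_mask.mean()*100:.1f}%)")
print("\nAUDIT NOTE: If accuracy rises when the most uncertain predictions are")
print("  rejected, a selective-prediction strategy (deferring uncertain cases")
print("  to human review) is worth considering for deployment.")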
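
E.2 summarizes uncertainty as the mean per-class standard deviation across passes. A common alternative splits the predictive entropy $H[\bar p]$ into the expected entropy $\mathbb{E}[H[p]]$ (uncertainty every pass agrees on) and the mutual information $H[\bar p] - \mathbb{E}[H[p]]$ (disagreement between passes, often read as epistemic uncertainty). The sketch assumes the stacked softmax outputs for one image as an array of shape `(n_forward, n_classes)`, for example `all_probs.squeeze(1)` inside `mc_dropout_predict`; `mc_uncertainty` is a hypothetical helper. If the model contains no `nn.Dropout` layers, every pass is identical and the mutual information is zero.

```python
import numpy as np

def mc_uncertainty(mc_probs, eps=1e-12):
    """Decompose MC Dropout uncertainty for one image.

    mc_probs: array of shape (n_forward, n_classes), one softmax vector per pass.
    Returns (predictive_entropy, expected_entropy, mutual_information), in nats.
    """
    mc_probs = np.asarray(mc_probs, dtype=float)

    def entropy(p):
        # Clip only inside the log so that 0 * log(0) contributes 0
        return -(p * np.log(np.clip(p, eps, 1.0))).sum(axis=-1)

    predictive = float(entropy(mc_probs.mean(axis=0)))   # H[E[p]]
    expected = float(entropy(mc_probs).mean())           # E[H[p]]
    return predictive, expected, predictive - expected   # MI = H[E[p]] - E[H[p]]
```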

In [None]:
# ====================================================================
# E.3 — Grad-CAM Interpretability
# Verify the model looks at the right regions (hands, face, posture)
# rather than background shortcuts.
# ====================================================================
print("="*70)
print("INTERPRETABILITY: GRAD-CAM VISUALIZATION")
print("="*70)

class GradCAM:
    def __init__(self, model, target_layer):
        self.model = model
        self.gradients = None
        self.activations = None
        target_layer.register_forward_hook(self._save_act)
        target_layer.register_full_backward_hook(self._save_grad)

    def _save_act(self, module, inp, out):
        self.activations = out.detach()

    def _save_grad(self, module, grad_in, grad_out):
        self.gradients = grad_out[0].detach()

    def generate(self, input_tensor, class_idx=None):
        self.model.eval()
        output = self.model(input_tensor)
        if class_idx is None:
            class_idx = output.argmax(dim=1).item()
        self.model.zero_grad()
        output[0, class_idx].backward()
        weights = self.gradients.mean(dim=[2, 3], keepdim=True)
        cam = (weights * self.activations).sum(dim=1, keepdim=True)
        cam = F.relu(cam)
        cam = F.interpolate(cam, size=input_tensor.shape[2:], mode='bilinear', align_corners=False)
        cam = cam - cam.min()
        cam = cam / (cam.max() + 1e-8)
        return cam.squeeze().cpu().numpy(), class_idx, output


target_layer = trained_model.backbone.features[-1]
grad_cam = GradCAM(trained_model, target_layer)

inv_normalize = transforms.Normalize(
    mean=[-m/s for m, s in zip(IMAGENET_MEAN, IMAGENET_STD)],
    std=[1/s for s in IMAGENET_STD]
)

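# Generate Grad-CAM for one sample per class. These images come from TRAIN_PATH
# and may have been seen during training; validation images give a stricter check.
fig, axes = plt.subplots(2, 5, figsize=(22, 10))
for idx, class_key in enumerate(sorted(CLASSES.keys())):
    folder = os.path.join(TRAIN_PATH, class_key)
    # os.listdir order is arbitrary, so sort for a reproducible pick
    img_path = os.path.join(folder, sorted(os.listdir(folder))[5])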
    img_pil = Image.open(img_path).convert('RGB')

    inp = val_transform(img_pil).unsqueeze(0).to(DEVICE)
    inp.requires_grad_(True)
    cam, pred, out = grad_cam.generate(inp)
    conf = F.softmax(out, dim=1)[0, pred].item()

    img_vis = inv_normalize(inp.squeeze().detach().cpu()).clamp(0, 1).permute(1, 2, 0).numpy()

    ax = axes[idx // 5, idx % 5]
    ax.imshow(img_vis)
    ax.imshow(cam, cmap='jet', alpha=0.4)
    color = 'green' if CLASSES[class_key] == CLASS_NAMES[pred] else 'red'
    ax.set_title(f'True: {CLASSES[class_key]}\nPred: {CLASS_NAMES[pred]} ({conf:.0%})',
                 fontsize=9, color=color, fontweight='bold')
    ax.axis('off')

plt.suptitle('Grad-CAM: What the Model Attends To', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

print("AUDIT NOTE: Verify that highlighted regions correspond to")
print("  driver hands, face, and posture — not background features.")
print("  If the model focuses on car interior or seat patterns, it may")
print("  be learning spurious correlations rather than actual behavior.")
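
Visual inspection of Grad-CAM maps is subjective. One simple quantitative check is the share of CAM mass that falls inside a region of interest such as the driver's hands or face. The dataset ships no bounding boxes, so any box used here is hypothetical and would need manual annotation on a sample of images; because `GradCAM.generate` upsamples the map to the input size, box coordinates refer to the resized `IMG_SIZE × IMG_SIZE` image. `cam_energy_in_box` is a hypothetical helper.

```python
import numpy as np

def cam_energy_in_box(cam, box):
    """Fraction of total CAM mass inside box = (row0, row1, col0, col1), half-open."""
    cam = np.asarray(cam, dtype=float)
    r0, r1, c0, c1 = box
    total = cam.sum()
    if total <= 0:
        return 0.0          # an all-zero map attends nowhere
    return float(cam[r0:r1, c0:c1].sum() / total)

# Usage (box coordinates are a made-up example, not from the dataset):
# share = cam_energy_in_box(cam, (100, 250, 50, 200))
```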

In [None]:
# ====================================================================
# E.4 — Data Leakage Analysis (Random vs Driver-Aware Split)
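# Compare accuracy on a random split vs a split by driver ID.
# A large gap suggests the model relies on driver-specific appearance.
# Caveat: a fair comparison needs two models, one trained on each split's
# training set. Evaluating a single trained_model on both validation sets
# means one of them overlaps its training data, which distorts the gap.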
# ====================================================================
print("="*70)
print("DATA LEAKAGE ANALYSIS")
print("="*70)

if train_indices_driver is not None:
    # Create driver-aware val loader
    driver_val_subset = Subset(full_dataset, val_indices_driver)
    driver_val_dataset = TransformSubset(driver_val_subset, val_transform)
    driver_val_loader = DataLoader(driver_val_dataset, batch_size=BATCH_SIZE,
                                   shuffle=False, num_workers=NUM_WORKERS)

    # Evaluate on driver-aware split
    d_loss, d_acc, d_preds, d_labels, d_probs = validate(
        trained_model, driver_val_loader, criterion, DEVICE
    )

    print(f"\nAccuracy on random split:       {val_acc:.2f}%")
    print(f"Accuracy on driver-aware split: {d_acc:.2f}%")
    print(f"Gap: {val_acc - d_acc:.2f} percentage points")
    print()
    if val_acc - d_acc > 5:
        print("⚠ SIGNIFICANT DATA LEAKAGE DETECTED!")
        print("  The model performs much worse on unseen drivers.")
        print("  It has likely memorized driver appearance features.")
        print("  Recommendation: Always use driver-aware splitting.")
    elif val_acc - d_acc > 2:
        print("⚠ MODERATE leakage detected. Consider driver-aware splitting.")
    else:
        print("✓ Small gap between the two splits: little evidence of driver memorization.")

    # Per-class comparison
    fig, ax = plt.subplots(figsize=(14, 6))
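    all_classes = list(range(NUM_CLASSES))  # keep both arrays at length NUM_CLASSES
    random_acc = confusion_matrix(val_labels, val_preds, labels=all_classes).diagonal() / np.bincount(val_labels, minlength=NUM_CLASSES)
    driver_acc = confusion_matrix(d_labels, d_preds, labels=all_classes).diagonal() / np.bincount(d_labels, minlength=NUM_CLASSES)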

    x = np.arange(NUM_CLASSES)
    ax.bar(x - 0.2, random_acc * 100, 0.4, label='Random Split', color='steelblue')
    ax.bar(x + 0.2, driver_acc * 100, 0.4, label='Driver-Aware Split', color='coral')
    ax.set_xticks(x)
    ax.set_xticklabels(CLASS_NAMES, rotation=45, ha='right')
    ax.set_ylabel('Accuracy (%)')
    ax.set_title('Random vs Driver-Aware Split Accuracy', fontweight='bold')
    ax.legend()
    plt.tight_layout()
    plt.show()
else:
    print("Driver metadata not available — cannot perform leakage analysis.")
    print("Download driver_imgs_list.csv from the Kaggle competition.")
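
A driver-aware split assigns whole drivers, not individual images, to either training or validation. The notebook already provides `train_indices_driver` / `val_indices_driver`; the sketch below only illustrates the idea, assuming a `driver_ids` sequence aligned with dataset indices (for example built from `fname_to_driver`). `driver_group_split` is hypothetical; scikit-learn's `GroupShuffleSplit` and `GroupKFold` implement the same constraint. Because drivers contribute different numbers of images, the validation share of images only approximates `val_fraction`.

```python
import numpy as np

def driver_group_split(driver_ids, val_fraction=0.2, seed=0):
    """Return (train_idx, val_idx) so that no driver appears in both sets."""
    driver_ids = np.asarray(driver_ids)
    rng = np.random.default_rng(seed)
    drivers = rng.permutation(np.unique(driver_ids))      # shuffle drivers, not images
    n_val = max(1, int(round(val_fraction * len(drivers))))
    val_drivers = set(drivers[:n_val].tolist())
    is_val = np.array([d in val_drivers for d in driver_ids])
    return np.flatnonzero(~is_val), np.flatnonzero(is_val)
```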

In [None]:
# ====================================================================
# E.5 — Difficult Example Analysis
# Identify images where the model is most confident yet WRONG,
# and images where it is least confident.
# ====================================================================
print("="*70)
print("DIFFICULT EXAMPLE ANALYSIS")
print("="*70)

max_confs = val_probs.max(axis=1)
incorrect_mask = (val_preds != val_labels)

# Confident but wrong
confident_wrong = np.where(incorrect_mask)[0]
if len(confident_wrong) > 0:
    confident_wrong_sorted = confident_wrong[
        np.argsort(max_confs[confident_wrong])[::-1]
    ][:12]

    fig, axes = plt.subplots(2, 6, figsize=(24, 9))
    for i, vidx in enumerate(confident_wrong_sorted):
        global_idx = val_indices[vidx]
        img_pil, true_label = full_dataset[global_idx]

        ax = axes[i // 6, i % 6]
        ax.imshow(img_pil)
        pred = val_preds[vidx]
        conf = max_confs[vidx]
        ax.set_title(f'True: {CLASS_NAMES[true_label]}\nPred: {CLASS_NAMES[pred]} ({conf:.0%})',
                     fontsize=8, color='red', fontweight='bold')
        ax.axis('off')

    plt.suptitle('Most Confident WRONG Predictions (High-Risk Errors)',
                 fontsize=13, fontweight='bold', color='red')
    plt.tight_layout()
    plt.show()
else:
    print("No incorrect predictions found!")

# Least confident predictions overall
least_conf_idx = np.argsort(max_confs)[:12]

fig, axes = plt.subplots(2, 6, figsize=(24, 9))
for i, vidx in enumerate(least_conf_idx):
    global_idx = val_indices[vidx]
    img_pil, true_label = full_dataset[global_idx]

    ax = axes[i // 6, i % 6]
    ax.imshow(img_pil)
    pred = val_preds[vidx]
    conf = max_confs[vidx]
    color = 'green' if pred == true_label else 'red'
    ax.set_title(f'True: {CLASS_NAMES[true_label]}\nPred: {CLASS_NAMES[pred]} ({conf:.0%})',
                 fontsize=8, color=color, fontweight='bold')
    ax.axis('off')

plt.suptitle('Least Confident Predictions (Ambiguous Cases)',
             fontsize=13, fontweight='bold', color='orange')
plt.tight_layout()
plt.show()

---
# PART F — Test Inference & Submission

---

In [None]:
@torch.no_grad()
def predict_with_tta(image_path, model, tta_transforms, device, num_classes=10):
    model.eval()
    img_pil = Image.open(image_path).convert('RGB')
    avg_probs = torch.zeros(num_classes).to(device)
    for t in tta_transforms:
        inp = t(img_pil).unsqueeze(0).to(device)
        with autocast(enabled=(device == 'cuda')):
            logits = model(inp)
        avg_probs += F.softmax(logits, dim=1).squeeze(0)
    avg_probs /= len(tta_transforms)
    return avg_probs.cpu().numpy()


def predict_test_folder(folder_path, model, tta_transforms, device, use_tta=True):
    model.eval()
    names, probs = [], []
    files = sorted(glob.glob(folder_path + '/*'))
    print(f"Predicting {len(files)} images (TTA={'ON' if use_tta else 'OFF'})...")

    for fpath in tqdm(files, desc='Inference'):
        if use_tta:
            p = predict_with_tta(fpath, model, tta_transforms, device)
        else:
            img = Image.open(fpath).convert('RGB')
            inp = val_transform(img).unsqueeze(0).to(device)
            with torch.no_grad(), autocast(enabled=(device == 'cuda')):
                logits = model(inp)
            p = F.softmax(logits, dim=1).squeeze(0).cpu().numpy()
        names.append(os.path.basename(fpath))
        probs.append(p)

    df = pd.DataFrame(probs, columns=[f'c{i}' for i in range(10)])
    df.insert(0, 'img', names)
    return df

In [None]:
submission_df = predict_test_folder(TEST_PATH, trained_model, tta_transforms, DEVICE, use_tta=True)
print(f"\nSubmission shape: {submission_df.shape}")
print(f"Nulls: {submission_df.isnull().sum().sum()}")
submission_df.head()
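
Before uploading, a few structural checks catch the most common submission mistakes. Averaged softmax outputs should already sum to 1 per row up to floating-point error, and Kaggle's exact acceptance rules are on the competition page; this sketch only checks shape, duplicate names, finiteness, non-negativity, and row sums. `check_submission` is a hypothetical helper.

```python
import numpy as np

def check_submission(img_names, probs, n_classes=10, atol=1e-4):
    """Return a list of problems found in a (names, probability-matrix) submission."""
    probs = np.asarray(probs, dtype=float)
    if probs.ndim != 2 or probs.shape[1] != n_classes:
        return [f"expected shape (N, {n_classes}), got {probs.shape}"]
    problems = []
    if len(img_names) != probs.shape[0]:
        problems.append("number of names does not match number of rows")
    if len(set(img_names)) != len(img_names):
        problems.append("duplicate image names")
    if not np.isfinite(probs).all():
        problems.append("non-finite probabilities")
    elif (probs < 0).any():
        problems.append("negative probabilities")
    elif not np.allclose(probs.sum(axis=1), 1.0, atol=atol):
        problems.append("rows do not sum to 1")
    return problems

# Usage: an empty list means no problems were found
# check_submission(submission_df['img'].tolist(), submission_df[[f'c{i}' for i in range(10)]].to_numpy())
```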

In [None]:
# Confidence analysis on test set
test_max_probs = submission_df[[f'c{i}' for i in range(10)]].max(axis=1)
test_pred_classes = submission_df[[f'c{i}' for i in range(10)]].idxmax(axis=1)

fig, axes = plt.subplots(1, 2, figsize=(16, 5))

axes[0].hist(test_max_probs, bins=50, color='steelblue', edgecolor='white', alpha=0.8)
axes[0].axvline(test_max_probs.median(), color='red', ls='--',
                label=f'Median: {test_max_probs.median():.3f}')
axes[0].set_title('Test Prediction Confidence', fontweight='bold')
axes[0].set_xlabel('Max Probability'); axes[0].set_ylabel('Count'); axes[0].legend()

pred_counts = test_pred_classes.value_counts().sort_index()
axes[1].bar(pred_counts.index, pred_counts.values,
            color=plt.cm.Set3(np.linspace(0, 1, 10)), edgecolor='gray')
axes[1].set_title('Predicted Class Distribution (Test)', fontweight='bold')

plt.tight_layout()
plt.show()

low_conf = (test_max_probs < 0.5).sum()
print(f"Low confidence (<50%): {low_conf}/{len(test_max_probs)} ({100*low_conf/len(test_max_probs):.1f}%)")
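
The predicted class distribution on the test set can be compared with the training label distribution. A large total variation distance means either the test set has a different class mix or the model favors some classes; the number alone cannot tell which. `total_variation` is a hypothetical helper, and `train_labels` in the usage comment is an assumed array of training labels.

```python
import numpy as np

def total_variation(counts_a, counts_b):
    """Total variation distance between two class-count vectors (0 = identical, 1 = disjoint)."""
    p = np.asarray(counts_a, dtype=float)
    q = np.asarray(counts_b, dtype=float)
    p, q = p / p.sum(), q / q.sum()          # counts -> probability distributions
    return float(0.5 * np.abs(p - q).sum())

# Usage with the cell above:
# pred_vec = pred_counts.reindex([f'c{i}' for i in range(10)], fill_value=0).to_numpy()
# total_variation(pred_vec, np.bincount(train_labels, minlength=10))
```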

In [None]:
# Save submission
sub_path = '/content/drive/MyDrive/DS project/submission_audit.csv'
submission_df.to_csv(sub_path, index=False)
print(f"Saved to {sub_path}")

---
# PART G — Summary & Recommendations
*Covers: §5 — Summary*

---

In [None]:
# ====================================================================
# G.1 — Comprehensive Audit Summary
# ====================================================================
print("="*70)
print("AUDIT SUMMARY")
print("="*70)

print("""
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
§5a — Was the data appropriate for this ADS?
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

STRENGTHS:
- Reasonably balanced across 10 classes
- Consistent image dimensions and format (RGB, 640×480)
- Real-world driving scenarios (not synthetic)

WEAKNESSES:
- Limited driver diversity (~26 unique drivers): model may learn
  driver identity rather than distraction behavior
- No demographic metadata: cannot audit for racial, gender, or age bias
- Controlled environment (consistent vehicle interior): may not
  generalize to diverse vehicle types and lighting conditions
- Some classes are visually ambiguous (e.g., texting vs operating the radio)

━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
§5b — Is the implementation robust, accurate, and fair?
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
""")

print(f"Accuracy:  {val_acc:.2f}%")
print(f"Macro F1:  {f1_score(val_labels, val_preds, average='macro'):.4f}")
print(f"LogLoss:   {val_logloss:.4f}")
print(f"ECE:       {ece:.4f}")
print(f"TPR gap:   {tpr_gap:.4f}")
print(f"FPR gap:   {fpr_gap:.4f}")

print("""
ACCURACY: The EfficientNet-B3 model's overall and per-class accuracy
  are reported above. Accuracy varies across classes, and most errors
  fall between visually similar distraction types (see the top
  confused pairs in C.3).

ROBUSTNESS: Accuracy drops under noise and blur (see E.1), both of
  which are realistic in-vehicle conditions, so the model needs
  additional hardening before production deployment.

FAIRNESS: Without demographic labels, we used per-class and per-driver
  accuracy as proxy fairness metrics. Class-level TPR/FPR gaps show
  that some behaviors are systematically harder to detect; they are
  not evidence for or against demographic bias.

CALIBRATION: See the ECE printed above; values below ~0.05 are
  commonly read as well calibrated. Well-calibrated confidence is
  critical for a safety system, because alert and deferral thresholds
  depend on it.

━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
§5c — Would we deploy this ADS?
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

NOT YET. Before deployment, we would need:

1. Diverse training data spanning many drivers, vehicle types,
   lighting conditions, and demographics
2. Driver-aware evaluation (not random splits) to ensure
   generalization to unseen individuals
3. Robustness hardening against common corruptions
4. A selective prediction mechanism (reject uncertain predictions
   and defer to human review)
5. Demographic fairness audit with labeled subgroup data
6. Real-time performance benchmarking on edge devices

━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
§5d — Recommended improvements
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

DATA:
- Collect data from >100 diverse drivers
- Include demographic metadata for fairness auditing
- Add nighttime and varied weather conditions
- Include borderline/ambiguous cases with multi-annotator labels

MODEL:
- Use driver-aware cross-validation
- Apply noise/blur augmentation and adversarial training for robustness
- Add a "reject" option for low-confidence predictions
- Use Focal Loss to emphasize hard examples
- Ensemble multiple architectures

DEPLOYMENT:
- Implement sliding-window temporal smoothing over consecutive video frames
- Set asymmetric thresholds: prioritize detecting distraction
  (minimize false negatives) even at cost of more false alarms
- Regular model re-training as new data becomes available
- Human-in-the-loop review for edge cases
""")

In [None]:
# ====================================================================
# G.2 — Audit Metrics Dashboard (Summary Table)
# ====================================================================
print("\n" + "="*70)
print("AUDIT METRICS DASHBOARD")
print("="*70)

# Compute macro F1 once instead of twice inside the table row
macro_f1 = f1_score(val_labels, val_preds, average="macro")

dashboard = pd.DataFrame([
    ['Overall Accuracy', f'{val_acc:.2f}%', '>95%', '✓' if val_acc > 95 else '✗'],
    ['Macro F1 Score', f'{macro_f1:.4f}', '>0.95', '✓' if macro_f1 > 0.95 else '✗'],
    ['LogLoss (Kaggle)', f'{val_logloss:.4f}', '<0.5', '✓' if val_logloss < 0.5 else '✗'],
    ['ECE (Calibration)', f'{ece:.4f}', '<0.05', '✓' if ece < 0.05 else '✗'],
    ['TPR Gap (Eq. Odds)', f'{tpr_gap:.4f}', '<0.05', '✓' if tpr_gap < 0.05 else '✗'],
    ['FPR Gap (Eq. Odds)', f'{fpr_gap:.4f}', '<0.02', '✓' if fpr_gap < 0.02 else '✗'],
    ['Accuracy Gap (max-min class)', f'{acc_gap*100:.2f}%', '<5%', '✓' if acc_gap < 0.05 else '✗'],
    ['Macro AUC', f'{macro_auc:.4f}', '>0.99', '✓' if macro_auc > 0.99 else '✗'],
], columns=['Metric', 'Value', 'Target', 'Pass'])

print(dashboard.to_string(index=False))

passed = (dashboard['Pass'] == '✓').sum()
total = len(dashboard)
print(f"\nPassed: {passed}/{total} metrics")

In [None]:
print("\n" + "="*70)
print("NOTEBOOK COMPLETE")
print("="*70)
print("""
This notebook covers all required sections:
  §1 Background                          ✓ (Part A)
  §2 Input and Output                    ✓ (Part A — data profiling)
  §3 Implementation and Validation       ✓ (Part B — model + training)
  §4a Accuracy across subpopulations     ✓ (Part C)
  §4b Fairness analysis                  ✓ (Part D)
  §4c Stability, robustness, interpret.  ✓ (Part E)
  §5 Summary & recommendations           ✓ (Part G)
""")