# Experiment 3: Internal Masking vs Occlusion Training

**Goal:** Test whether internal channel masking during training can improve robustness to occlusions, as an alternative to training on occluded images.

**Sessions (6 per model):**
- S1: Clean train (baseline)
- S2: Occluded train (standard practice)
- S3: Clean train + mask backbone_early
- S4: Clean train + mask backbone_late
- S5: Clean train + mask neck
- S6: Clean train + mask head

**Evaluation:**
- All models are evaluated on both `test_clean` and `test_occluded` (40% synthetic occlusion)
- Evaluation uses our custom evaluation system (precision/recall/F1, per-class metrics, counting quality)

## 0. Configuration

**EDIT THIS CELL to switch between smoke test and full run:**

In [None]:
# ============================================================
# EXPERIMENT 3 CONFIGURATION - EDIT THIS CELL ONLY
# ============================================================
# Every tunable knob for the experiment lives here; later cells only read
# these names.

# --- Run length ---------------------------------------------------------
# 1 epoch = smoke test (pipeline check), 50 epochs = full experiment.
EPOCHS = 1  # <-- CHANGE THIS: 1 = smoke test, 50 = full run

# --- What to train ------------------------------------------------------
# Drop entries from either list for a quicker partial run.
MODELS = ["yolov8n", "rtdetr-l"]
SESSIONS_TO_RUN = [
    "S1_clean_train",
    "S2_occ_train",
    "S3_mask_backbone_early",
    "S4_mask_backbone_late",
    "S5_mask_neck",
    "S6_mask_head",
]

# --- Channel masking (identical for every masking session) --------------
P_APPLY = 0.5      # probability of applying masking per batch
P_CHANNELS = 0.2   # fraction of channels zeroed when masking is applied

# --- Ultralytics training arguments -------------------------------------
IMGSZ = 640
BATCH = -1     # -1 = let Ultralytics pick the batch size automatically
PATIENCE = 10  # early-stopping patience (epochs)

# --- Paths --------------------------------------------------------------
DATASET_ROOT = "data/raw"
OUTPUT_ROOT = "runs/exp3"
OCCLUSION_LEVEL = "level_040"  # occlusion level of the occluded test set (40%)

# ============================================================
# GOOGLE DRIVE BACKUP SETTINGS
# ============================================================
ENABLE_DRIVE_BACKUP = True            # False disables all Drive backups
DRIVE_BACKUP_FOLDER = "exp3_backups"  # folder created under MyDrive

# ============================================================
_mode_label = "(SMOKE TEST)" if EPOCHS <= 1 else "(FULL RUN)"
_backup_label = "ENABLED" if ENABLE_DRIVE_BACKUP else "DISABLED"
for _line in (
    "Configuration:",
    f"  EPOCHS: {EPOCHS} {_mode_label}",
    f"  MODELS: {MODELS}",
    f"  SESSIONS: {len(SESSIONS_TO_RUN)} sessions",
    f"  Masking: p_apply={P_APPLY}, p_channels={P_CHANNELS}",
    f"  Drive Backup: {_backup_label}",
):
    print(_line)

## 1. Setup & Environment

In [None]:
# Check if running in Colab
import sys
import os

# Colab pre-registers the `google.colab` module; its presence is the environment signal.
IN_COLAB = 'google.colab' in sys.modules

if IN_COLAB:
    print("Running in Google Colab")
    # Clone the project once per Colab VM; later runs reuse the checkout.
    # NOTE(review): the clone target is relative, so this assumes the cwd is
    # /content when the cell runs — confirm.
    if not os.path.exists('/content/Deep_Learning_Gil_Alon'):
        !git clone https://github.com/gil-attar/Deep_Learning_Project_Gil_Alon.git Deep_Learning_Gil_Alon
    os.chdir('/content/Deep_Learning_Gil_Alon')
else:
    print("Running locally")
    # Navigate to project root if needed
    # NOTE(review): assumes the notebook sits two directory levels below the
    # repo root (e.g. experiments/Experiment_3/) — confirm if it is moved.
    if 'Experiment_3' in os.getcwd():
        os.chdir('../..')

print(f"Working directory: {os.getcwd()}")

# ============================================================
# MOUNT GOOGLE DRIVE FOR BACKUPS
# ============================================================
# `Path` must be imported here: the main imports cell runs later in the
# notebook. Without this import the mount below raised NameError, which the
# broad `except` swallowed, silently disabling every Drive backup.
from pathlib import Path

DRIVE_MOUNTED = False
DRIVE_BACKUP_PATH = None

if ENABLE_DRIVE_BACKUP and IN_COLAB:
    try:
        from google.colab import drive
        drive.mount('/content/drive')
        DRIVE_BACKUP_PATH = Path(f'/content/drive/MyDrive/{DRIVE_BACKUP_FOLDER}')
        DRIVE_BACKUP_PATH.mkdir(parents=True, exist_ok=True)
        DRIVE_MOUNTED = True
        print(f"Google Drive mounted. Backups will be saved to: {DRIVE_BACKUP_PATH}")
    except Exception as e:
        # Best effort: training must still run without Drive.
        print(f"WARNING: Could not mount Google Drive: {e}")
        print("Backups will be disabled.")
        DRIVE_MOUNTED = False
        # Reset so no later code writes to a half-initialized backup path.
        DRIVE_BACKUP_PATH = None
elif ENABLE_DRIVE_BACKUP and not IN_COLAB:
    print("Not running in Colab - Drive backup disabled (local files are safe)")
    DRIVE_MOUNTED = False

In [None]:
# Install dependencies
# IPython shell escape: quietly (-q) installs the third-party packages this notebook uses.
# NOTE(review): versions are unpinned, so reruns may pick up newer ultralytics releases.
!pip install -q ultralytics roboflow pyyaml pillow numpy matplotlib pandas

In [None]:
# Verify GPU: report the PyTorch build and whether a CUDA device is visible.
import torch

print(f"PyTorch version: {torch.__version__}")
cuda_ok = torch.cuda.is_available()
print(f"CUDA available: {cuda_ok}")
if cuda_ok:
    print(f"GPU: {torch.cuda.get_device_name(0)}")

In [None]:
# Imports
from pathlib import Path
import json
import shutil
import yaml
import pandas as pd
from datetime import datetime

from ultralytics import YOLO, RTDETR

# Import our experiment modules
# Put the current working directory first on sys.path so that the
# `experiments.Experiment_3` package resolves. The setup cell chdir's to the
# repo root first.
sys.path.insert(0, str(Path.cwd()))
from experiments.Experiment_3.mask_presets import (
    get_mask_prefixes, get_session_config, SESSIONS
)
from experiments.Experiment_3.channel_masking import MaskingManager

print("All imports successful!")

# ============================================================
# GOOGLE DRIVE BACKUP FUNCTIONS
# ============================================================

def backup_session_to_drive(run_dir: Path, run_name: str) -> bool:
    """
    Back up the important files of one training run to Google Drive.

    Copies these files (relative to ``run_dir``) into
    ``DRIVE_BACKUP_PATH / run_name``:
    - weights/best.pt (trained model, required)
    - weights/last.pt (last checkpoint)
    - results.csv (training metrics)
    - args.yaml (training config)
    - masking_summary.json (if exists)
    - DONE marker (required, copied LAST)

    ``restore_from_drive`` treats a backup that contains DONE as a completed
    session. DONE is therefore copied only after every other required file
    was copied. A backup that is missing ``best.pt``, or that is interrupted
    mid-copy, is never restored as "finished".

    Args:
        run_dir: Local Ultralytics run directory.
        run_name: Folder name to use under the Drive backup root.

    Returns:
        True if all required files were backed up. False if Drive is not
        mounted, a required file is missing, or copying raised (the error is
        printed, not propagated).
    """
    if not DRIVE_MOUNTED or DRIVE_BACKUP_PATH is None:
        return False

    # (relative path, required). DONE is deliberately last; see the docstring.
    files_to_backup = [
        ("weights/best.pt", True),
        ("weights/last.pt", False),
        ("results.csv", False),
        ("args.yaml", False),
        ("masking_summary.json", False),
        ("DONE", True),
    ]

    try:
        # Create session backup folder
        backup_dir = DRIVE_BACKUP_PATH / run_name
        backup_dir.mkdir(parents=True, exist_ok=True)

        backed_up = []
        missing_required = []
        for file_rel, required in files_to_backup:
            if file_rel == "DONE" and missing_required:
                # Never mark an incomplete backup as a completed session.
                print(f"  WARNING: Not copying DONE marker; missing required files: {missing_required}")
                continue
            src = run_dir / file_rel
            if src.exists():
                dst = backup_dir / file_rel
                dst.parent.mkdir(parents=True, exist_ok=True)
                shutil.copy2(src, dst)
                backed_up.append(file_rel)
            elif required:
                print(f"  WARNING: Required file not found: {src}")
                missing_required.append(file_rel)

        # Save backup timestamp and what was (and was not) copied
        backup_info = {
            "timestamp": datetime.now().isoformat(),
            "run_name": run_name,
            "source_dir": str(run_dir),
            "files_backed_up": backed_up,
            "missing_required": missing_required,
        }
        with open(backup_dir / "backup_info.json", 'w') as f:
            json.dump(backup_info, f, indent=2)

        if missing_required:
            print(f"  BACKUP INCOMPLETE: {len(backed_up)} files -> {backup_dir}")
            return False

        print(f"  BACKUP OK: {len(backed_up)} files -> {backup_dir}")
        return True

    except Exception as e:
        # Best effort: a failed backup must not abort the training loop.
        print(f"  BACKUP FAILED: {e}")
        return False


def backup_evaluation_to_drive(eval_dir: Path, eval_name: str) -> bool:
    """
    Copy one evaluation's results into ``DRIVE_BACKUP_PATH/evaluations/<eval_name>``.

    Copies metrics.json and predictions.json when present, plus every
    top-level ``*.png`` plot found in ``eval_dir``.

    Returns:
        True on success. False when Drive is not mounted or copying raised;
        the error is printed, not propagated.
    """
    if not DRIVE_MOUNTED or DRIVE_BACKUP_PATH is None:
        return False

    try:
        target = DRIVE_BACKUP_PATH / "evaluations" / eval_name
        target.mkdir(parents=True, exist_ok=True)

        # JSON results first (only those that exist), then all plot images.
        json_sources = [eval_dir / name for name in ("metrics.json", "predictions.json")]
        sources = [p for p in json_sources if p.exists()] + list(eval_dir.glob("*.png"))

        copied = []
        for src in sources:
            shutil.copy2(src, target / src.name)
            copied.append(src.name)

        print(f"  BACKUP OK: {len(copied)} files -> {target}")
        return True

    except Exception as err:
        print(f"  BACKUP FAILED: {err}")
        return False


def _copy_marker_last(src_dir: Path, dst_dir: Path, marker_name: str, recursive: bool) -> None:
    """Copy every file of ``src_dir`` into ``dst_dir``, copying ``marker_name`` last.

    The marker (DONE for training runs, metrics.json for evaluations) is what
    marks the local copy as complete. Copying it last means an interrupted
    restore is never mistaken for a finished one.
    """
    candidates = src_dir.rglob("*") if recursive else src_dir.iterdir()
    files = [f for f in candidates if f.is_file()]
    marker = src_dir / marker_name
    files.sort(key=lambda f: f == marker)  # stable sort: marker (True) goes last
    for src_file in files:
        dst_file = dst_dir / src_file.relative_to(src_dir)
        dst_file.parent.mkdir(parents=True, exist_ok=True)
        shutil.copy2(src_file, dst_file)


def restore_from_drive(output_dir: Path) -> dict:
    """
    Restore completed sessions and evaluations from the Google Drive backup.

    Use this to continue after a Colab disconnection. Sessions are restored
    only if the backup has a DONE marker and the local copy does not.
    Evaluations are restored only if the local copy has no metrics.json.
    Each entry is restored independently: a failure is printed and the
    remaining entries are still processed.

    Args:
        output_dir: Local experiment root (e.g. runs/exp3).

    Returns:
        ``{"restored": [...]}``, listing the run names restored and
        ``"eval:<name>"`` for each restored evaluation.
    """
    if not DRIVE_MOUNTED or DRIVE_BACKUP_PATH is None:
        return {"restored": []}

    restored = []

    # ---- Training sessions ------------------------------------------------
    try:
        backup_dirs = sorted(
            p for p in DRIVE_BACKUP_PATH.iterdir()
            if p.is_dir() and p.name != "evaluations"
        )
    except Exception as e:
        print(f"Restore error: {e}")
        backup_dirs = []

    for backup_dir in backup_dirs:
        run_name = backup_dir.name
        local_dir = output_dir / run_name

        # Skip if already completed locally, or if the backup never finished.
        if (local_dir / "DONE").exists() or not (backup_dir / "DONE").exists():
            continue

        print(f"Restoring from Drive: {run_name}")
        try:
            local_dir.mkdir(parents=True, exist_ok=True)
            _copy_marker_last(backup_dir, local_dir, "DONE", recursive=True)
        except Exception as e:
            print(f"Restore error ({run_name}): {e}")
            continue
        restored.append(run_name)
        print(f"  Restored: {run_name}")

    # ---- Evaluations -------------------------------------------------------
    eval_backup_dir = DRIVE_BACKUP_PATH / "evaluations"
    try:
        eval_dirs = sorted(p for p in eval_backup_dir.iterdir() if p.is_dir()) if eval_backup_dir.exists() else []
    except Exception as e:
        print(f"Restore error: {e}")
        eval_dirs = []

    for eval_dir in eval_dirs:
        eval_name = eval_dir.name
        local_eval_dir = output_dir / "evaluations" / eval_name
        if (local_eval_dir / "metrics.json").exists():
            continue

        print(f"Restoring evaluation: {eval_name}")
        try:
            local_eval_dir.mkdir(parents=True, exist_ok=True)
            _copy_marker_last(eval_dir, local_eval_dir, "metrics.json", recursive=False)
        except Exception as e:
            print(f"Restore error ({eval_name}): {e}")
            continue
        restored.append(f"eval:{eval_name}")

    return {"restored": restored}


print("Backup functions defined!")

## 2. Download Dataset

In [None]:
# Download dataset if not exists
if not Path("data/raw/train/images").exists():
    os.environ["ROBOFLOW_API_KEY"] = "zEF9icmDY2oTcPkaDcQY"
    !python scripts/download_dataset.py --output_dir data/raw
else:
    print("Dataset already exists")

# Verify
print(f"Train images: {len(list(Path('data/raw/train/images').glob('*')))}")
print(f"Val images: {len(list(Path('data/raw/valid/images').glob('*')))}")
print(f"Test images: {len(list(Path('data/raw/test/images').glob('*')))}")

## 3. Generate Occluded Training Data

In [None]:
# Generate occluded training data for S2
# We need 40% occlusion on training images
# NOTE(review): only the top-level directory's existence is checked, so a run
# interrupted mid-generation is treated as complete on the next execution.

occluded_train_dir = Path("data/occluded_train_040")

if not occluded_train_dir.exists():
    print("Generating occluded training data (40% occlusion)...")
    
    # First build evaluation indices if not exists
    # (the generator below is driven by an index JSON - here the train split's)
    if not Path("data/processed/evaluation/train_index.json").exists():
        !python scripts/build_evaluation_indices.py \
            --dataset_root data/raw \
            --output_dir data/processed/evaluation
    
    # Generate occluded training images
    # The script's `--test_index` flag is pointed at the *train* index here.
    # NOTE(review): data_occ_train.yaml expects the images under
    # occluded_train_040/level_040/images - presumably the script creates a
    # level_040 subfolder for --levels 0.4; confirm.
    !python scripts/generate_synthetic_occlusions.py \
        --test_index data/processed/evaluation/train_index.json \
        --images_dir data/raw/train/images \
        --labels_dir data/raw/train/labels \
        --output_dir data/occluded_train_040 \
        --levels 0.4 \
        --seed 42
else:
    print(f"Occluded training data already exists at {occluded_train_dir}")

In [None]:
# Generate occluded test data if not exists
occluded_test_dir = Path(f"data/synthetic_occlusion/{OCCLUSION_LEVEL}")

if not occluded_test_dir.exists():
    print(f"Generating occluded test data ({OCCLUSION_LEVEL})...")
    
    !python scripts/generate_synthetic_occlusions.py \
        --test_index data/processed/evaluation/test_index.json \
        --images_dir data/raw/test/images \
        --labels_dir data/raw/test/labels \
        --output_dir data/synthetic_occlusion \
        --levels 0.4 \
        --seed 42
else:
    print(f"Occluded test data already exists at {occluded_test_dir}")

## 4. Create Data YAML Files

In [None]:
# Create data.yaml files for clean and occluded training.
# Both files share the class list from the original data.yaml; they differ
# only in which image folders Ultralytics trains / validates / tests on.

with open('data/raw/data.yaml', 'r') as f:
    original_config = yaml.safe_load(f)

_class_list = original_config['names']

# Clean training: every split comes straight from the raw export.
clean_config = dict(
    path=str(Path('data/raw').resolve()),
    train='train/images',
    val='valid/images',
    test='test/images',
    names=_class_list,
    nc=len(_class_list),
)

# Occluded training (S2): occluded train images; validation and test stay clean.
occ_train_config = dict(
    path=str(Path('data').resolve()),
    train='occluded_train_040/level_040/images',
    val='raw/valid/images',
    test='raw/test/images',
    names=_class_list,
    nc=len(_class_list),
)

Path('data/processed').mkdir(parents=True, exist_ok=True)
for _yaml_path, _yaml_config in (
    ('data/processed/data_clean.yaml', clean_config),
    ('data/processed/data_occ_train.yaml', occ_train_config),
):
    with open(_yaml_path, 'w') as f:
        yaml.dump(_yaml_config, f, default_flow_style=False)
    print(f"Created {_yaml_path}")

## 5. Build Evaluation Indices

In [None]:
# Build evaluation indices if not exists
# test_index.json is read by the evaluation functions. It supplies both the
# images to run inference on and the ground truth used for scoring.
if not Path("data/processed/evaluation/test_index.json").exists():
    !python scripts/build_evaluation_indices.py \
        --dataset_root data/raw \
        --output_dir data/processed/evaluation
else:
    print("Evaluation indices already exist")

# Verify
# Sanity check: report the test-set size from the index metadata.
with open("data/processed/evaluation/test_index.json") as f:
    test_index = json.load(f)
print(f"Test set: {test_index['metadata']['num_images']} images, {test_index['metadata']['total_objects']} objects")

## 6. Training Functions

In [None]:
def get_model(model_name: str):
    """Instantiate the Ultralytics model whose pretrained weights are ``<model_name>.pt``.

    Names containing "yolo" (case-insensitive) load as YOLO and names
    containing "rtdetr" load as RTDETR; "yolo" is checked first.

    Raises:
        ValueError: if the name matches neither family.
    """
    lowered = model_name.lower()
    weights = f"{model_name}.pt"
    if 'yolo' in lowered:
        return YOLO(weights)
    if 'rtdetr' in lowered:
        return RTDETR(weights)
    raise ValueError(f"Unknown model: {model_name}")


def get_model_type(model_name: str) -> str:
    """Map a model name to its mask-preset family, 'yolo' or 'rtdetr'.

    Matching is a case-insensitive substring test; 'yolo' wins if both appear.

    Raises:
        ValueError: if neither family name occurs in ``model_name``.
    """
    lowered = model_name.lower()
    for family in ('yolo', 'rtdetr'):
        if family in lowered:
            return family
    raise ValueError(f"Unknown model: {model_name}")


def train_session(
    model_name: str,
    session_name: str,
    epochs: int,
    output_dir: Path,
    p_apply: float = 0.5,
    p_channels: float = 0.2
) -> dict:
    """
    Train a single (model, session) run and mark it DONE on success.

    The session config from ``get_session_config`` selects the training
    data: 'occluded' -> data_occ_train.yaml, otherwise data_clean.yaml.
    Through ``mask_location`` it also selects which layers get
    channel-masking hooks.

    Masking hooks are attached to ``trainer.model`` inside an Ultralytics
    ``on_pretrain_routine_end`` callback. ``Model.train()`` builds a fresh
    network inside the trainer and loads the pretrained weights into it, so
    hooks registered on ``model.model`` beforehand never run during training.

    On success, files are written to ``output_dir/<model>__<session>/`` in
    this order: masking_summary.json (masking sessions only), then the DONE
    marker. DONE is written last because the skip/restore logic treats it as
    "finished". On failure, a FAILED file containing the error text is
    written instead.

    Args:
        model_name: Ultralytics model name, e.g. "yolov8n" or "rtdetr-l".
        session_name: Key understood by ``get_session_config``.
        epochs: Number of training epochs.
        output_dir: Ultralytics ``project`` directory shared by all runs.
        p_apply: Masking probability passed to ``MaskingManager``.
        p_channels: Channel fraction passed to ``MaskingManager``.

    Returns:
        Dict with "status" ("success", "skipped" or "failed") and "run_dir",
        plus "weights_path" on success or "error" on failure.
    """
    session_config = get_session_config(session_name)
    run_name = f"{model_name}__{session_name}"
    run_dir = output_dir / run_name

    print(f"\n{'='*60}")
    print(f"TRAINING: {run_name}")
    print(f"{'='*60}")
    print(f"Description: {session_config['description']}")
    print(f"Epochs: {epochs}")

    # Check if already completed (local or from Drive backup)
    done_marker = run_dir / "DONE"
    if done_marker.exists():
        print(f"Session already completed. Skipping.")
        return {"status": "skipped", "run_dir": str(run_dir)}

    # Select data.yaml based on session
    if session_config['train_data'] == 'occluded':
        data_yaml = 'data/processed/data_occ_train.yaml'
    else:
        data_yaml = 'data/processed/data_clean.yaml'

    print(f"Data: {data_yaml}")

    model = get_model(model_name)

    # The manager is created inside the trainer callback; keep a handle so the
    # summary can be saved and the hooks removed after training.
    masking_state = {"manager": None}
    mask_location = session_config['mask_location']
    if mask_location is not None:
        layer_prefixes = get_mask_prefixes(get_model_type(model_name), mask_location)

        print(f"Masking: {mask_location} -> layers {layer_prefixes}")
        print(f"Masking params: p_apply={p_apply}, p_channels={p_channels}")

        def _attach_masking(trainer):
            # Hook the network the trainer actually optimizes. Unwrap a
            # `.module` wrapper (DataParallel-style) if present, so layer
            # names match the bare model's.
            net = getattr(trainer.model, "module", trainer.model)
            manager = MaskingManager(net, p_apply, p_channels)
            num_hooks = manager.add_masking_to_layers(layer_prefixes)
            masking_state["manager"] = manager
            print(f"Added {num_hooks} masking hooks")
            if num_hooks == 0:
                print("WARNING: no layers matched the masking prefixes; this session trains unmasked.")

        model.add_callback("on_pretrain_routine_end", _attach_masking)
    else:
        print("Masking: None")

    try:
        model.train(
            data=data_yaml,
            epochs=epochs,
            imgsz=IMGSZ,
            batch=BATCH,
            patience=PATIENCE,
            save=True,
            project=str(output_dir),
            name=run_name,
            exist_ok=True,
            pretrained=True,
            optimizer='auto',
            verbose=True,
            seed=42
        )

        run_dir.mkdir(parents=True, exist_ok=True)

        manager = masking_state["manager"]
        if manager is not None:
            with open(run_dir / "masking_summary.json", 'w') as f:
                json.dump(manager.get_summary(), f, indent=2)
        elif mask_location is not None:
            # Training finished but never masked: the result is not a valid
            # S3-S6 run, so do not mark it DONE.
            raise RuntimeError("Masking was requested but the masking callback never ran")

        # Write DONE last; it is the completion marker that skip/restore trusts.
        done_marker.touch()
        # Clear a stale marker left by an earlier failed attempt.
        (run_dir / "FAILED").unlink(missing_ok=True)

        print(f"\n Training complete: {run_name}")

        # ============================================================
        # BACKUP TO GOOGLE DRIVE IMMEDIATELY AFTER TRAINING
        # ============================================================
        if DRIVE_MOUNTED:
            print("Backing up to Google Drive...")
            backup_session_to_drive(run_dir, run_name)

        return {
            "status": "success",
            "run_dir": str(run_dir),
            "weights_path": str(run_dir / "weights" / "best.pt")
        }

    except Exception as e:
        print(f"\n Training FAILED: {run_name}")
        print(f"Error: {e}")

        # Mark as failed
        run_dir.mkdir(parents=True, exist_ok=True)
        (run_dir / "FAILED").write_text(str(e))

        return {"status": "failed", "error": str(e), "run_dir": str(run_dir)}

    finally:
        # Clean up masking hooks
        manager = masking_state["manager"]
        if manager is not None:
            manager.remove_all_hooks()

print("Training functions defined!")

## 7. Run All Training Sessions

In [None]:
# Run all training sessions
output_dir = Path(OUTPUT_ROOT)
output_dir.mkdir(parents=True, exist_ok=True)

# ============================================================
# RESTORE FROM GOOGLE DRIVE IF AVAILABLE
# ============================================================
# This allows continuing after Colab disconnection
if DRIVE_MOUNTED:
    print("Checking Google Drive for previous backups...")
    restore_result = restore_from_drive(output_dir)
    if restore_result['restored']:
        print(f"Restored {len(restore_result['restored'])} sessions from Drive:")
        for name in restore_result['restored']:
            print(f"  - {name}")
    else:
        print("No previous backups found (starting fresh)")
    print()

# ============================================================
# RUN TRAINING
# ============================================================
training_results = []

total_sessions = len(MODELS) * len(SESSIONS_TO_RUN)
current = 0

for model_name in MODELS:
    for session_name in SESSIONS_TO_RUN:
        current += 1
        print(f"\n[{current}/{total_sessions}] Starting {model_name} - {session_name}")

        result = train_session(
            model_name=model_name,
            session_name=session_name,
            epochs=EPOCHS,
            output_dir=output_dir,
            p_apply=P_APPLY,
            p_channels=P_CHANNELS
        )

        result['model'] = model_name
        result['session'] = session_name
        training_results.append(result)

# Summary
# ASCII tags per status; the previous ternary produced "" for every status.
_STATUS_TAGS = {"success": "[OK]", "skipped": "[SKIP]", "failed": "[FAIL]"}
print("\n" + "="*60)
print("TRAINING SUMMARY")
print("="*60)
for r in training_results:
    status_icon = _STATUS_TAGS.get(r['status'], "[?]")
    print(f"{status_icon} {r['model']}__{r['session']}: {r['status']}")

# Save results
with open(output_dir / "training_results.json", 'w') as f:
    json.dump(training_results, f, indent=2)
print(f"\nResults saved to {output_dir / 'training_results.json'}")

## 8. Evaluation Functions

In [None]:
# Import evaluation modules
from evaluation.io import load_predictions, load_ground_truth, load_class_names
from evaluation.metrics import (
    eval_detection_prf_at_iou,
    eval_per_class_metrics_and_confusions,
    eval_counting_quality
)
from evaluation.plots import plot_all_metrics
from tqdm import tqdm


def generate_predictions(
    model_path: str,
    test_images_dir: str,
    test_index: dict,
    conf: float = 0.01,
    imgsz: int = 640,
) -> list:
    """
    Run a trained detector over every image listed in a test index.

    Args:
        model_path: Trained weights. Paths containing "rtdetr"
            (case-insensitive) load as RTDETR; everything else loads as YOLO.
        test_images_dir: Directory holding the images named in the index.
        test_index: Parsed index dict. Reads ``image_filename`` and
            ``image_id`` from each entry of ``test_index['images']``.
        conf: Minimum confidence passed to Ultralytics. Kept low so the later
            confidence-threshold sweep still has candidates to filter.
        imgsz: Inference image size.

    Returns:
        One ``{"image_id", "detections"}`` dict per image found on disk. Each
        detection holds class_id, class_name, confidence, an xyxy bbox and
        ``bbox_format``. Images missing from ``test_images_dir`` produce no
        entry; their count is printed as a warning.
    """
    model = RTDETR(model_path) if 'rtdetr' in model_path.lower() else YOLO(model_path)

    images_root = Path(test_images_dir)
    predictions = []
    missing = []

    for img_data in tqdm(test_index['images'], desc="Inference"):
        image_path = images_root / img_data['image_filename']
        if not image_path.exists():
            missing.append(img_data['image_filename'])
            continue

        result = model.predict(
            source=str(image_path),
            conf=conf,
            imgsz=imgsz,
            verbose=False
        )[0]

        boxes = result.boxes
        # Convert each tensor to Python lists once instead of one .item() per field per box.
        class_ids = [int(c) for c in boxes.cls.tolist()]
        detections = [
            {
                "class_id": cid,
                "class_name": result.names[cid],
                "confidence": float(score),
                "bbox": bbox,
                "bbox_format": "xyxy"
            }
            for cid, score, bbox in zip(class_ids, boxes.conf.tolist(), boxes.xyxy.tolist())
        ]

        predictions.append({
            "image_id": img_data['image_id'],
            "detections": detections
        })

    if missing:
        # Previously silent: a wrong images dir would yield empty predictions.
        print(f"  WARNING: {len(missing)} indexed image(s) not found in {images_root} "
              f"and got no predictions (e.g. {missing[0]})")

    return predictions


def evaluate_session(
    model_name: str,
    session_name: str,
    weights_path: str,
    test_type: str,  # 'clean' or 'occluded'
    output_dir: Path
) -> dict:
    """
    Evaluate a trained model on a test set.

    Runs inference with ``weights_path`` on either the clean test images
    (``test_type == 'clean'``) or, for any other value, the synthetic
    occlusion images for ``OCCLUSION_LEVEL``. The predictions are then scored
    against the ground truth in test_index.json.

    Writes into ``output_dir/evaluations/<model>__<session>__test_<test_type>/``,
    in this order:
    1. predictions.json
    2. the plots from ``plot_all_metrics``
    3. metrics.json
    metrics.json comes last; the evaluation loop's ``is_eval_done`` checks
    for it to decide whether the evaluation is complete.

    Returns:
        Flat metrics dict with:
        - run/model/session/test_type identifiers;
        - the confidence threshold with the best F1 in the 0.1-0.8 sweep at
          IoU 0.5;
        - precision/recall/F1 and TP/FP/FN at that threshold;
        - counting MAE (matched-only and all-predictions).
    """
    run_name = f"{model_name}__{session_name}"
    eval_name = f"{run_name}__test_{test_type}"
    eval_dir = output_dir / "evaluations" / eval_name
    eval_dir.mkdir(parents=True, exist_ok=True)
    
    print(f"\nEvaluating: {eval_name}")
    
    # Select test set
    # Any test_type other than 'clean' falls through to the occluded images.
    if test_type == 'clean':
        test_images_dir = "data/raw/test/images"
    else:
        test_images_dir = f"data/synthetic_occlusion/{OCCLUSION_LEVEL}/images"
    
    # Load test index
    # NOTE(review): both test sets are scored against this same test index, so
    # this assumes the occluded copies keep the original filenames and box
    # annotations - confirm against generate_synthetic_occlusions.py.
    with open("data/processed/evaluation/test_index.json") as f:
        test_index = json.load(f)
    
    # Generate predictions
    predictions = generate_predictions(weights_path, test_images_dir, test_index)
    
    # Save predictions
    # Written as an artifact only; scoring below uses the in-memory list.
    pred_json = {
        "run_id": eval_name,
        "split": "test",
        "test_type": test_type,
        "model_family": get_model_type(model_name),
        "predictions": predictions
    }
    pred_path = eval_dir / "predictions.json"
    with open(pred_path, 'w') as f:
        json.dump(pred_json, f)
    
    # Load for evaluation
    preds = predictions
    gts = load_ground_truth("data/processed/evaluation/test_index.json")
    class_names = load_class_names("data/processed/evaluation/test_index.json")
    
    # Run metrics
    # P/R/F1 at IoU 0.5 for each confidence threshold in the sweep.
    threshold_sweep = eval_detection_prf_at_iou(
        preds, gts,
        iou_threshold=0.5,
        conf_thresholds=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]
    )
    
    # Find best threshold
    # The sweep result is keyed by threshold; pick the key with the highest F1.
    # NOTE(review): key type (str vs float) depends on the metrics module,
    # hence the float() conversions below.
    best_thr = max(threshold_sweep.keys(), key=lambda k: threshold_sweep[k]['f1'])
    best_metrics = threshold_sweep[best_thr]
    
    # Per-class metrics/confusions and counting quality use the best-F1 threshold.
    per_class = eval_per_class_metrics_and_confusions(
        preds, gts,
        conf_threshold=float(best_thr),
        class_names=class_names
    )
    
    counting = eval_counting_quality(
        preds, gts,
        conf_threshold=float(best_thr),
        class_names=class_names
    )
    
    # Generate plots
    plot_all_metrics(
        threshold_sweep=threshold_sweep,
        per_class_results=per_class['per_class'],
        confusion_data=per_class,
        counting_results=counting,
        output_dir=str(eval_dir),
        run_name=eval_name
    )
    
    # Save metrics
    # Written last: its existence marks this evaluation as complete.
    metrics = {
        "run_name": eval_name,
        "model": model_name,
        "session": session_name,
        "test_type": test_type,
        "best_threshold": float(best_thr),
        "precision": best_metrics['precision'],
        "recall": best_metrics['recall'],
        "f1": best_metrics['f1'],
        "tp": best_metrics['tp'],
        "fp": best_metrics['fp'],
        "fn": best_metrics['fn'],
        "count_mae_matched": counting['matched_only']['global_mae'],
        "count_mae_all": counting['all_predictions']['global_mae']
    }
    
    with open(eval_dir / "metrics.json", 'w') as f:
        json.dump(metrics, f, indent=2)
    
    print(f"  Best F1: {best_metrics['f1']:.4f} @ conf={best_thr}")
    
    return metrics

print("Evaluation functions defined!")

## 9. Run All Evaluations

In [None]:
# Run evaluations on both clean and occluded test sets
# With crash recovery: skip already-completed evaluations

all_metrics = []

# First, load any existing metrics from previous runs.
# Informational only: completed evaluations are re-read from their own
# evaluations/*/metrics.json below. A crash while this aggregate file was
# being written can leave it truncated, so a parse failure must not stop the
# whole recovery cell.
existing_metrics_file = output_dir / "all_metrics.json"
if existing_metrics_file.exists():
    try:
        with open(existing_metrics_file) as f:
            existing_data = json.load(f)
        print(f"Found {len(existing_data)} existing evaluation results")
    except (json.JSONDecodeError, OSError) as e:
        print(f"WARNING: Ignoring unreadable {existing_metrics_file}: {e}")


def is_eval_done(model_name, session_name, test_type):
    """Return True if metrics.json exists for this (model, session, test set) evaluation."""
    eval_name = f"{model_name}__{session_name}__test_{test_type}"
    metrics_path = output_dir / "evaluations" / eval_name / "metrics.json"
    return metrics_path.exists()


def load_existing_eval(model_name, session_name, test_type):
    """Read back the metrics.json dict saved for a finished evaluation."""
    metrics_path = (
        output_dir
        / "evaluations"
        / f"{model_name}__{session_name}__test_{test_type}"
        / "metrics.json"
    )
    with metrics_path.open() as fh:
        return json.load(fh)


# Scan for all completed training sessions (not just from this run)
failed_evals = []

for model_name in MODELS:
    for session_name in SESSIONS_TO_RUN:
        run_name = f"{model_name}__{session_name}"
        run_dir = output_dir / run_name

        if not (run_dir / "DONE").exists():
            print(f"SKIP (no training): {run_name}")
            continue

        # Find weights: expected location first, then search output_dir
        # (sorted so the fallback pick is deterministic).
        weights_file = run_dir / "weights" / "best.pt"
        if not weights_file.exists():
            found = sorted(output_dir.rglob(f"*{run_name}*/weights/best.pt"))
            weights_file = found[0] if found else None

        if weights_file is None:
            print(f"WARNING: Weights not found for {model_name}__{session_name}")
            continue
        weights_path = str(weights_file)

        # Evaluate on both test sets
        for test_type in ['clean', 'occluded']:
            eval_name = f"{run_name}__test_{test_type}"

            if is_eval_done(model_name, session_name, test_type):
                print(f"SKIP (already done): {eval_name}")
                all_metrics.append(load_existing_eval(model_name, session_name, test_type))
                continue

            try:
                metrics = evaluate_session(
                    model_name, session_name, weights_path, test_type, output_dir
                )
            except Exception as e:
                # Keep going so one broken evaluation doesn't block the rest.
                # It wrote no metrics.json, so rerunning this cell retries it.
                print(f"EVAL FAILED: {eval_name}: {e}")
                failed_evals.append(eval_name)
                continue
            all_metrics.append(metrics)

            # Backup evaluation to Google Drive
            if DRIVE_MOUNTED:
                backup_evaluation_to_drive(output_dir / "evaluations" / eval_name, eval_name)

            # Save after each evaluation for crash recovery
            with open(output_dir / "all_metrics.json", 'w') as f:
                json.dump(all_metrics, f, indent=2)

# Final save
with open(output_dir / "all_metrics.json", 'w') as f:
    json.dump(all_metrics, f, indent=2)

# Backup final summary to Drive
if DRIVE_MOUNTED:
    try:
        shutil.copy2(output_dir / "all_metrics.json", DRIVE_BACKUP_PATH / "all_metrics.json")
        print("Backed up all_metrics.json to Google Drive")
    except Exception as e:
        print(f"Could not backup all_metrics.json: {e}")

if failed_evals:
    print(f"\n{len(failed_evals)} evaluation(s) failed; rerun this cell to retry:")
    for name in failed_evals:
        print(f"  - {name}")
else:
    print(f"\n All evaluations complete! Results in {output_dir}")

## 10. Generate Summary Table & Comparison Plots

In [None]:
# Create summary DataFrame
if not all_metrics:
    # With no rows the DataFrame has no 'test_type' column, and the filtering
    # below would fail with an opaque KeyError.
    raise RuntimeError(
        "No evaluation metrics available - run the training and evaluation cells first."
    )
df = pd.DataFrame(all_metrics)

# One row per (model, session), with clean-test and occluded-test metrics side by side
_summary_cols = ['model', 'session', 'f1', 'precision', 'recall']
summary_clean = df[df['test_type'] == 'clean'][_summary_cols]
summary_clean = summary_clean.rename(columns={'f1': 'F1_clean', 'precision': 'P_clean', 'recall': 'R_clean'})

summary_occ = df[df['test_type'] == 'occluded'][_summary_cols]
summary_occ = summary_occ.rename(columns={'f1': 'F1_occ', 'precision': 'P_occ', 'recall': 'R_occ'})

# Outer merge: a run evaluated on only one test set still appears, with NaN
# for the missing side, instead of silently disappearing from the table.
summary = pd.merge(summary_clean, summary_occ, on=['model', 'session'], how='outer')

# Display
print("\n" + "="*80)
print("EXPERIMENT 3 RESULTS SUMMARY")
print("="*80)
print(summary.to_string(index=False))

# Save CSV
summary.to_csv(output_dir / "summary_metrics.csv", index=False)
print(f"\nSaved to {output_dir / 'summary_metrics.csv'}")

In [None]:
# Generate comparison bar plots
import matplotlib.pyplot as plt
import numpy as np

def plot_comparison(df, metric_col, test_type, title, output_path):
    """
    Create a grouped bar chart comparing models across sessions.

    Args:
        df: Metrics DataFrame with 'test_type', 'session', 'model' and
            ``metric_col`` columns.
        metric_col: Column to plot (e.g. 'f1'); upper-cased for the y label.
        test_type: Only rows whose 'test_type' equals this value are plotted.
        title: Figure title.
        output_path: Destination of the saved PNG (150 dpi).

    A (model, session) pair with no row is left as a gap. It is not drawn as
    a 0.000 bar, which would read as a real score of zero.
    """
    data = df[df['test_type'] == test_type]

    sessions = data['session'].unique()
    models = data['model'].unique()

    x = np.arange(len(sessions))
    # 0.35 for up to two models; narrower bars when more share a group so they don't overlap.
    width = min(0.35, 0.8 / max(len(models), 1))

    fig, ax = plt.subplots(figsize=(14, 6))

    for i, model in enumerate(models):
        model_data = data[data['model'] == model]
        values = []
        for s in sessions:
            match = model_data.loc[model_data['session'] == s, metric_col]
            values.append(float(match.iloc[0]) if len(match) > 0 else np.nan)
        offset = width * (i - len(models) / 2 + 0.5)
        bars = ax.bar(x + offset, values, width, label=model)

        # Add value labels (skip missing results)
        for bar, val in zip(bars, values):
            if np.isnan(val):
                continue
            ax.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.01,
                    f'{val:.3f}', ha='center', va='bottom', fontsize=8)

    ax.set_xlabel('Session')
    ax.set_ylabel(metric_col.upper())
    ax.set_title(title)
    ax.set_xticks(x)
    ax.set_xticklabels([s.replace('_', '\n') for s in sessions], fontsize=9)
    ax.legend()
    ax.set_ylim(0, 1.0)
    ax.grid(axis='y', alpha=0.3)

    plt.tight_layout()
    plt.savefig(output_path, dpi=150, bbox_inches='tight')
    plt.show()
    plt.close(fig)  # release the figure when not using the inline backend
    print(f"Saved: {output_path}")

# Create plots directory
plots_dir = output_dir / "plots"
plots_dir.mkdir(exist_ok=True)

# One F1 comparison chart per test set: (test_type, title, output filename)
for _test_type, _title, _filename in (
    ('clean', 'Experiment 3: F1 Score on Clean Test Set', "comparison_f1_clean.png"),
    ('occluded', f'Experiment 3: F1 Score on Occluded Test Set ({OCCLUSION_LEVEL})', "comparison_f1_occluded.png"),
):
    plot_comparison(df, 'f1', _test_type, _title, plots_dir / _filename)

## 11. Final Summary

In [None]:
# Final report: configuration used and the artifacts written under output_dir.
_banner = "=" * 80
print("\n" + _banner)
print("EXPERIMENT 3 COMPLETE!")
print(_banner)

print(f"\nConfiguration:")
for _label, _value in (
    ("Epochs", EPOCHS),
    ("Models", MODELS),
    ("Sessions", SESSIONS_TO_RUN),
):
    print(f"  {_label}: {_value}")
print(f"  Masking: p_apply={P_APPLY}, p_channels={P_CHANNELS}")

print(f"\nOutputs saved to: {output_dir}")
for _artifact in (
    "training_results.json",
    "all_metrics.json",
    "summary_metrics.csv",
    "plots/comparison_f1_clean.png",
    "plots/comparison_f1_occluded.png",
    "evaluations/*/metrics.json (per-session details)",
):
    print(f"  - {_artifact}")

print("\n" + _banner)