In [None]:
# train_optimized_detr.py
import os, time, math, yaml
from pathlib import Path
from tqdm import tqdm
from PIL import Image
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from transformers import DetrForObjectDetection, DetrImageProcessor
from torch.cuda.amp import autocast, GradScaler

# ============ GPU CHECK ============
# Print a short environment report; training below is CUDA-only, so abort
# right away when no CUDA device is visible to PyTorch.
print("=" * 60)
print("üîç SYSTEM CHECK")
print("=" * 60)
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if not torch.cuda.is_available():
    print("‚ùå CUDA NOT AVAILABLE - Training will be VERY slow on CPU!")
    print("Please install: pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu126")
    raise RuntimeError("CUDA required for training")
# Only reached when a CUDA device exists: report device 0's name, toolkit and memory.
print(f"‚úÖ GPU: {torch.cuda.get_device_name(0)}")
print(f"‚úÖ CUDA version: {torch.version.cuda}")
print(f"‚úÖ GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
print("=" * 60)

# ----------------- CONFIG -----------------
# Dataset locations (YOLO-style layout: one .txt label file per image).
DATA_YAML = r"F:\skills-copilot-codespaces-vscode\thesis\rsuddataset\rsud20k\images\data_fixed.yaml"
TRAIN_IMG_DIR = r"F:\skills-copilot-codespaces-vscode\thesis\rsuddataset\rsud20k\images\train"
TRAIN_LABEL_DIR = r"F:\skills-copilot-codespaces-vscode\thesis\rsuddataset\rsud20k\labels\train"
CACHE_DIR = r"F:\skills-copilot-codespaces-vscode\thesis\cache"
DEVICE = "cuda"
IMAGE_SIZE = 640          # images are resized to IMAGE_SIZE x IMAGE_SIZE
BATCH_SIZE = 8
GRAD_ACCUM_STEPS = 4      # optimizer steps once per this many batches
NUM_EPOCHS = 1
LR = 1e-5
WEIGHT_DECAY = 1e-4
NUM_WORKERS = 0
PIN_MEMORY = True
USE_CACHE = False
CHECKPOINT_DIR = "./checkpoints"
os.makedirs(CHECKPOINT_DIR, exist_ok=True)
os.makedirs(CACHE_DIR, exist_ok=True)
# ------------------------------------------

# ---------- Load YAML ----------
# Explicit UTF-8 so the file parses the same way regardless of the OS locale;
# an empty YAML document loads as None, which is treated as "no keys".
with open(DATA_YAML, "r", encoding="utf-8") as f:
    data_cfg = yaml.safe_load(f) or {}
NUM_CLASSES = data_cfg.get("nc", None)
if NUM_CLASSES is None:
    raise RuntimeError("nc not found in YAML")
# NUM_CLASSES is used as an exclusive upper bound for class IDs and as the
# classifier size, so it has to be a positive integer.
if isinstance(NUM_CLASSES, bool) or not isinstance(NUM_CLASSES, int) or NUM_CLASSES <= 0:
    raise RuntimeError(f"nc must be a positive integer, got {NUM_CLASSES!r}")

print(f"‚úÖ Using device: {DEVICE} | classes: {NUM_CLASSES}")

# ---------- Transform ----------
# Fixed-size resize followed by conversion to a float tensor; no mean/std
# normalization is applied in this pipeline.
simple_transform = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.ToTensor(),
])

# ---------- Dataset with EXTREME validation ----------
class RSUDDataset(Dataset):
    """Detection dataset over YOLO-format labels that yields DETR-style targets.

    Each label file holds lines of ``class xc yc w h`` with coordinates
    normalized to [0, 1].  At construction every image is scanned and only
    images with at least one usable box are kept.  ``__getitem__`` returns
    ``(image_tensor, target)`` where ``target["boxes"]`` is a float tensor of
    normalized ``(cx, cy, w, h)`` rows and ``target["class_labels"]`` an int64
    tensor.  Normalized center format is what Hugging Face DETR expects for
    ``labels``, and it stays valid after the image is resized by ``transform``.

    Args:
        image_dir: Directory containing ``.jpg`` / ``.png`` / ``.jpeg`` images.
        label_dir: Directory containing ``<image stem>.txt`` label files.
        processor: Accepted for interface compatibility; not used by this class.
        num_classes: Class IDs must satisfy ``0 <= id < num_classes``.
        transform: Callable applied to the RGB PIL image; defaults to ``ToTensor``.
        use_cache: Stored on the instance; not otherwise used by this class.
        cache_dir: Directory created at construction time.
    """

    IMAGE_SUFFIXES = (".jpg", ".png", ".jpeg")

    def __init__(self, image_dir, label_dir, processor, num_classes, transform=None, use_cache=False, cache_dir="cache"):
        self.image_dir = Path(image_dir)
        self.label_dir = Path(label_dir)
        self.transform = transform
        self.num_classes = num_classes
        self.use_cache = use_cache
        self.cache_dir = Path(cache_dir)
        self.cache_dir.mkdir(parents=True, exist_ok=True)

        # Sorted so dataset indices are stable across runs.
        all_image_files = sorted(
            p.name for p in self.image_dir.iterdir() if p.suffix.lower() in self.IMAGE_SUFFIXES
        )
        self.image_files = []
        self.valid_box_counts = {}
        skipped = 0
        invalid_classes = set()

        print(f"üîç Performing EXTREME validation (num_classes={num_classes})...")
        for img_name in tqdm(all_image_files, desc="Scanning labels"):
            label_path = self.label_dir / (Path(img_name).stem + ".txt")
            if not label_path.exists():
                skipped += 1
                continue

            try:
                entries, bad_ids = self._parse_label_text(label_path.read_text(), num_classes)
                invalid_classes.update(bad_ids)
                valid_boxes = 0
                if entries:
                    # Apply the same pixel-size filter as __getitem__ so an image
                    # accepted here can never end up with zero boxes later.
                    # Image.open only needs the header to report the size.
                    with Image.open(self.image_dir / img_name) as probe:
                        img_w, img_h = probe.size
                    valid_boxes = sum(
                        self._to_normalized_cxcywh(xc, yc, w, h, img_w, img_h) is not None
                        for _, xc, yc, w, h in entries
                    )
            except Exception:
                # Best-effort scan: unreadable label or image files are skipped.
                skipped += 1
                continue

            if valid_boxes > 0:
                self.image_files.append(img_name)
                self.valid_box_counts[img_name] = valid_boxes
            else:
                skipped += 1

        if invalid_classes:
            print(f"‚ö†Ô∏è  WARNING: Found invalid class IDs: {sorted(invalid_classes)} (valid range: 0-{num_classes-1})")

        print(f"‚úì Dataset validated: {len(self.image_files)} images (skipped {skipped})")
        print(f"‚úì Total boxes: {sum(self.valid_box_counts.values())}")

    @staticmethod
    def _parse_label_text(text, num_classes):
        """Parse YOLO label text.

        Returns:
            ``(entries, invalid_classes)`` where ``entries`` is a list of
            ``(cls, xc, yc, w, h)`` tuples that have an in-range class and
            in-range normalized coordinates, and ``invalid_classes`` is the set
            of out-of-range class IDs encountered.  Blank lines, ``#`` comments,
            lines with fewer than five fields and unparsable lines are ignored.
        """
        entries = []
        invalid_classes = set()
        for line in text.splitlines():
            line = line.strip()
            if not line or line.startswith("#"):
                continue
            parts = line.split()
            if len(parts) < 5:
                continue
            try:
                cls = int(parts[0])
                xc, yc, w, h = map(float, parts[1:5])
            except ValueError:
                continue
            if cls < 0 or cls >= num_classes:
                invalid_classes.add(cls)
                continue
            if 0 <= xc <= 1 and 0 <= yc <= 1 and 0 < w <= 1 and 0 < h <= 1:
                entries.append((cls, xc, yc, w, h))
        return entries, invalid_classes

    @staticmethod
    def _to_normalized_cxcywh(xc, yc, w, h, img_w, img_h):
        """Clip a normalized box to the image and return it as ``[cx, cy, w, h]``.

        The box is clipped in pixel space; boxes that are not more than one
        pixel wide and tall after clipping yield ``None``.  The returned values
        are normalized by the image size again.
        """
        x_min = max(0, (xc - w / 2) * img_w)
        y_min = max(0, (yc - h / 2) * img_h)
        x_max = min(img_w, (xc + w / 2) * img_w)
        y_max = min(img_h, (yc + h / 2) * img_h)
        if not (x_max > x_min + 1 and y_max > y_min + 1):
            return None
        return [
            (x_min + x_max) / 2 / img_w,
            (y_min + y_max) / 2 / img_h,
            (x_max - x_min) / img_w,
            (y_max - y_min) / img_h,
        ]

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        """Return ``(image_tensor, target)`` for the image at ``idx``.

        Raises:
            RuntimeError: If no valid box remains for the image (for example
                because the label file changed after the initial scan).
        """
        img_name = self.image_files[idx]
        img_path = self.image_dir / img_name
        # convert() returns a fully loaded copy, so the file can be closed here.
        with Image.open(img_path) as raw:
            image = raw.convert("RGB")
        img_w, img_h = image.size

        label_path = self.label_dir / (Path(img_name).stem + ".txt")
        label_text = label_path.read_text()
        entries, _ = self._parse_label_text(label_text, self.num_classes)

        boxes = []
        labels = []
        for cls, xc, yc, w, h in entries:
            box = self._to_normalized_cxcywh(xc, yc, w, h, img_w, img_h)
            if box is not None:
                boxes.append(box)
                labels.append(cls)

        if self.transform:
            img_tensor = self.transform(image)
        else:
            img_tensor = transforms.ToTensor()(image)

        # Final safety check: an empty target would break the matching loss.
        if not boxes:
            print(f"\n‚ùå ERROR: {img_name} has 0 valid boxes!")
            print(f"   Expected: {self.valid_box_counts.get(img_name, 'unknown')}")
            print(f"   Label file: {label_path}")
            print(f"   Content preview: {label_text[:300]}")
            raise RuntimeError(f"Image {img_name} has no valid boxes after filtering!")

        target = {
            "boxes": torch.tensor(boxes, dtype=torch.float32),
            "class_labels": torch.tensor(labels, dtype=torch.int64)
        }

        return (img_tensor, target)

# ---------- Collate function ----------
def collate_fn(batch):
    """Turn a list of ``(image, target)`` samples into two parallel lists.

    Images are left unstacked.  Every target is checked before the batch is
    returned.

    Raises:
        RuntimeError: If a target has no boxes, or its largest class ID is
            not below the module-level ``NUM_CLASSES``.
    """
    imgs = []
    targets = []
    for sample in batch:
        imgs.append(sample[0])
        targets.append(sample[1])

    for index, target in enumerate(targets):
        box_count = target["boxes"].shape[0]
        if box_count == 0:
            print(f"\n‚ùå BATCH ERROR: Item {index} has 0 boxes!")
            raise RuntimeError("Empty boxes in batch!")
        highest = target["class_labels"].max()
        if highest >= NUM_CLASSES:
            print(f"\n‚ùå BATCH ERROR: Item {index} has invalid class ID {highest} (max allowed: {NUM_CLASSES-1})")
            raise RuntimeError("Invalid class ID in batch!")

    return imgs, targets

# ---------- Create dataset ----------
# The processor is loaded and handed to the dataset; image preprocessing in
# this script is done by `simple_transform`.
processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50")
train_ds = RSUDDataset(
    TRAIN_IMG_DIR, 
    TRAIN_LABEL_DIR, 
    processor, 
    num_classes=NUM_CLASSES,
    transform=simple_transform, 
    use_cache=USE_CACHE, 
    cache_dir=CACHE_DIR
)

# collate_fn keeps images and targets as Python lists (targets differ in
# length per image); the training loop stacks the images itself.
train_loader = DataLoader(
    train_ds,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    pin_memory=PIN_MEMORY,
    collate_fn=collate_fn
)

# ---------- Load model ----------
print("üì• Loading DETR model...")
# ignore_mismatched_sizes lets the pretrained checkpoint load even though the
# classification head is resized to NUM_CLASSES labels.
model = DetrForObjectDetection.from_pretrained(
    "facebook/detr-resnet-50", 
    num_labels=NUM_CLASSES,
    ignore_mismatched_sizes=True
)

print(f"‚úÖ Model loaded | Classifier expects {NUM_CLASSES} classes")
print(f"   Classifier head: {model.class_labels_classifier}")
model.to(DEVICE)

# ---------- Optimizer & Scaler ----------
# A single learning rate is used for every parameter of the model.
optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
# GradScaler performs loss scaling for the autocast forward pass used in training.
scaler = GradScaler()

torch.backends.cudnn.benchmark = True
print("üîπ cuDNN benchmark enabled")
print("üîπ Mixed precision (FP16) enabled")
print(f"üîπ Effective batch size: {BATCH_SIZE * GRAD_ACCUM_STEPS}")

# ---------- TRAINING ----------
print("\n" + "=" * 60)
print("üöÄ STARTING TRAINING")
print("=" * 60)

# Mixed-precision training with gradient accumulation; one checkpoint per epoch.
for epoch in range(NUM_EPOCHS):
    model.train()
    epoch_loss = 0.0
    num_batches = len(train_loader)
    pbar = tqdm(enumerate(train_loader), total=num_batches, desc=f"Epoch {epoch+1}/{NUM_EPOCHS}", ncols=120)
    optimizer.zero_grad()

    for step, (imgs, targets) in pbar:
        # Every image has the same size after the dataset transform, so they stack.
        pixel_values = torch.stack(imgs).to(DEVICE, non_blocking=True)

        # Validate each target and move it to the training device.
        tgt_for_model = []
        for i, t in enumerate(targets):
            num_boxes = t["boxes"].shape[0]
            num_labels = t["class_labels"].shape[0]

            if num_boxes == 0:
                print(f"\n‚ùå Step {step}, sample {i}: 0 boxes!")
                raise RuntimeError("Empty boxes!")
            if num_boxes != num_labels:
                print(f"\n‚ùå Step {step}, sample {i}: boxes({num_boxes}) != labels({num_labels})")
                raise RuntimeError("Mismatched boxes/labels!")
            # Taken only after the emptiness checks: max() of an empty tensor
            # raises its own error and would hide the messages above.
            max_class = t["class_labels"].max().item()
            if max_class >= NUM_CLASSES:
                print(f"\n‚ùå Step {step}, sample {i}: class {max_class} >= {NUM_CLASSES}")
                raise RuntimeError("Invalid class ID!")

            tgt_for_model.append({
                "boxes": t["boxes"].to(DEVICE, non_blocking=True),
                "class_labels": t["class_labels"].to(DEVICE, non_blocking=True)
            })

        # Forward pass under autocast; the loss is divided so the accumulated
        # gradient averages over GRAD_ACCUM_STEPS batches.
        with autocast():
            outputs = model(pixel_values=pixel_values, labels=tgt_for_model)
            loss = outputs.loss / GRAD_ACCUM_STEPS

        scaler.scale(loss).backward()

        # Step at the end of each accumulation window, and also on the last
        # batch so gradients from a trailing partial window are applied
        # instead of being discarded by the next epoch's zero_grad().
        is_last_batch = (step + 1) == num_batches
        if (step + 1) % GRAD_ACCUM_STEPS == 0 or is_last_batch:
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()

        # Undo the accumulation scaling so the logged loss is per batch.
        epoch_loss += loss.item() * GRAD_ACCUM_STEPS

        if step % 5 == 0:
            gpu_mem = torch.cuda.memory_allocated(0) / 1e9
            pbar.set_postfix(loss=f"{(epoch_loss / (step+1)):.4f}", gpu=f"{gpu_mem:.1f}GB")

    # max(..., 1) keeps an empty loader from raising ZeroDivisionError.
    avg_loss = epoch_loss / max(num_batches, 1)
    print(f"‚úÖ Epoch {epoch+1} complete | avg loss: {avg_loss:.4f}")

    # Save checkpoint
    ckpt_path = os.path.join(CHECKPOINT_DIR, f"detr_epoch{epoch+1}.pt")
    torch.save({
        "epoch": epoch+1,
        "model_state": model.state_dict(),
        "optimizer": optimizer.state_dict()
    }, ckpt_path)
    print(f"üíæ Saved: {ckpt_path}")

print("\n" + "=" * 60)
print("üéØ TRAINING FINISHED!")
print("=" * 60)


In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms, models
from torch.cuda.amp import autocast, GradScaler
from tqdm import tqdm
import yaml
import os
from PIL import Image
import matplotlib.pyplot as plt

# =========================================
# 1. Configuration
# =========================================
TRAIN_DIR = r"F:\skills-copilot-codespaces-vscode\thesis\rsuddataset\rsud20k\images\train"
VAL_DIR = r"F:\skills-copilot-codespaces-vscode\thesis\rsuddataset\rsud20k\images\val"
LABELS_FILE = r"F:\skills-copilot-codespaces-vscode\thesis\rsuddataset\rsud20k\images\data_fixed.yaml"
SAVE_PATH = r"F:\skills-copilot-codespaces-vscode\thesis\best_dinov2_model.pt"
EPOCHS = 25
BATCH_SIZE = 96          # per-step batch size (author sized this for a 12GB GPU)
LEARNING_RATE = 1e-4
NUM_WORKERS = 0          # load data in the main process (no DataLoader workers)
PIN_MEMORY = True        # page-locked host memory for host-to-device copies
ACCUMULATION_STEPS = 2   # optimizer steps once per this many batches (effective batch = 192)

# Falls back to CPU when CUDA is unavailable.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"üöÄ Using device: {device}")

# Let cuDNN autotune kernels and allow TF32 for matmul and cuDNN operations.
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True  # TF32 matmul
torch.backends.cudnn.allow_tf32 = True
print("‚úÖ cuDNN benchmark + TF32 enabled")

# =========================================
# 2. Load YAML labels
# =========================================
with open(LABELS_FILE, "r") as f:
    labels_yaml = yaml.safe_load(f)

# Use the "names" entry when present; otherwise the whole document is treated
# as the class collection.
CLASSES = labels_yaml["names"] if "names" in labels_yaml else labels_yaml
NUM_CLASSES = len(CLASSES)
print("‚úÖ Classes detected:", CLASSES)
print(f"‚úÖ Number of classes: {NUM_CLASSES}")

# =========================================
# üîπ 3. Fast Dataset with AGGRESSIVE RAM CACHING
# =========================================
class FastImageDataset(Dataset):
    """Image-folder dataset that can preload every image into RAM.

    Args:
        root_dir: Directory scanned (non-recursively) for image files.
        transform: Optional callable applied to each RGB PIL image.
        cache_images: When true, all images are decoded up front with a
            thread pool and kept in ``self.cache`` keyed by index.

    NOTE(review): the label returned by ``__getitem__`` is ``idx % NUM_CLASSES``,
    i.e. derived from the file position rather than from any annotation. It
    looks like a placeholder; confirm before trusting loss/accuracy numbers.
    """

    IMAGE_EXTENSIONS = ('.jpg', '.png', '.jpeg')

    def __init__(self, root_dir, transform=None, cache_images=True):
        self.root_dir = root_dir
        # Case-insensitive extension match (e.g. "IMG.JPG"), and sorted because
        # os.listdir order is arbitrary while labels depend on the index.
        self.image_files = sorted(
            os.path.join(root_dir, f)
            for f in os.listdir(root_dir)
            if f.lower().endswith(self.IMAGE_EXTENSIONS)
        )
        self.transform = transform
        self.cache = {}
        print(f"‚úÖ Found {len(self.image_files)} images")

        if cache_images:
            print(f"üîÑ Loading {len(self.image_files)} images to RAM (this will take ~2 min)...")
            from concurrent.futures import ThreadPoolExecutor

            def load_image(img_path):
                # convert() returns a loaded copy, so the file handle can be
                # closed immediately instead of staying open per cached image.
                with Image.open(img_path) as handle:
                    return handle.convert("RGB")

            # executor.map yields results in input order, so enumerate()
            # reproduces the dataset index for each image.
            with ThreadPoolExecutor(max_workers=8) as executor:
                loaded = tqdm(
                    executor.map(load_image, self.image_files),
                    total=len(self.image_files),
                    desc="Caching to RAM"
                )
                self.cache = dict(enumerate(loaded))

            print(f"‚úÖ All {len(self.cache)} images cached in RAM!")

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        """Return ``(image, label)``; the image comes from the RAM cache when present."""
        if idx in self.cache:
            image = self.cache[idx].copy()  # copy so transforms never touch the cached image
        else:
            with Image.open(self.image_files[idx]) as handle:
                image = handle.convert("RGB")

        # Placeholder label derived from the index (see class docstring).
        label = idx % NUM_CLASSES
        if self.transform:
            image = self.transform(image)
        return image, label

# =========================================
# 4. Data Augmentation
# =========================================
# Resize, flip and rotation come before ToTensor, so they operate on the PIL
# image handed over by the dataset (not on a GPU tensor). Normalize uses the
# usual ImageNet mean/std.
transform = transforms.Compose([
    transforms.Resize((224, 224), antialias=True),  # fixed 224x224 input
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406],
                         [0.229, 0.224, 0.225])
])
# =========================================
# üîπ MAIN EXECUTION (Windows multiprocessing fix)
# =========================================
if __name__ == '__main__':
    # =========================================
    # 5. Datasets & Loaders (all data in RAM)
    # =========================================
    print("\n" + "="*60)
    print("üî• LOADING ALL IMAGES TO RAM (2-3 minutes)")
    print("="*60)

    # Validation must be deterministic: same resize and normalization as the
    # training pipeline, but without the random flip/rotation, so the
    # validation loss is comparable from epoch to epoch.
    val_transform = transforms.Compose([
        transforms.Resize((224, 224), antialias=True),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406],
                             [0.229, 0.224, 0.225])
    ])

    train_dataset = FastImageDataset(TRAIN_DIR, transform, cache_images=True)
    val_dataset = FastImageDataset(VAL_DIR, val_transform, cache_images=True)

    train_loader = DataLoader(
        train_dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=NUM_WORKERS,
        pin_memory=PIN_MEMORY
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=BATCH_SIZE,
        shuffle=False,
        num_workers=NUM_WORKERS,
        pin_memory=PIN_MEMORY
    )

    print(f"üìÅ Train: {len(train_dataset)} images | Validation: {len(val_dataset)} images")
    print(f"üî• Batch size: {BATCH_SIZE} | Effective batch (with accumulation): {BATCH_SIZE * ACCUMULATION_STEPS}")
    # =========================================
    # 6. Model Setup (DINOv2 / ResNet fallback)
    # =========================================
    def _freeze_parameters(network):
        """Turn off gradients for every parameter currently in ``network``."""
        for weight in network.parameters():
            weight.requires_grad = False

    try:
        dino = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14', pretrained=True)
        print("‚úÖ Loaded DINOv2 backbone.")

        # Freeze first, then attach the new head, so only the head is trainable.
        _freeze_parameters(dino)

        # 384 is the feature width the original code uses for dinov2_vits14.
        dino.head = nn.Linear(384, NUM_CLASSES)
        print("‚úÖ Replaced DINOv2 head with classifier")

    except Exception as err:
        # Any failure while loading DINOv2 falls back to a torchvision ResNet50.
        print(f"‚ö†Ô∏è DINOv2 not found: {err}")
        print("‚ö†Ô∏è Using ResNet50 instead.")
        dino = models.resnet50(pretrained=True)

        _freeze_parameters(dino)

        # New trainable classifier sized from the existing fc layer's input width.
        dino.fc = nn.Linear(dino.fc.in_features, NUM_CLASSES)
        print("‚úÖ Replaced ResNet classifier")

    model = dino.to(device)

    # Optional torch.compile; failures are reported and the eager model is kept.
    try:
        model = torch.compile(model, mode='max-autotune')
        print("‚úÖ Model compiled with torch.compile (max-autotune)")
    except Exception as err:
        print(f"‚ö†Ô∏è torch.compile not available: {err}")

    print(f"‚úÖ Model on {device}")
    
    # =========================================
    # 7. Loss, Optimizer + Mixed Precision
    # =========================================
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=0.01)

    # GradScaler performs loss scaling for the autocast forward passes below.
    scaler = GradScaler()
    print("‚úÖ Mixed precision (FP16) + AdamW optimizer enabled")

    # =========================================
    # 8. Train + Validate Loop
    # =========================================
    train_loss_list = []
    val_loss_list = []
    val_acc_list = []
    best_val_loss = float("inf")

    print("\n" + "="*60)
    print("üöÄ STARTING TRAINING (MAXIMUM GPU ACCELERATION)")
    print("="*60)

    num_train_batches = len(train_loader)

    for epoch in range(EPOCHS):
        model.train()
        running_loss = 0.0

        for batch_idx, (images, labels) in enumerate(tqdm(train_loader, desc=f"üåÄ Epoch {epoch+1}/{EPOCHS}", ncols=100)):
            images = images.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)

            with autocast():
                outputs = model(images)
                loss = criterion(outputs, labels)
                loss = loss / ACCUMULATION_STEPS  # average over the accumulation window

            scaler.scale(loss).backward()

            # Step at the end of each accumulation window, and also on the
            # final batch: otherwise gradients from a trailing partial window
            # would leak into the first optimizer step of the next epoch.
            is_last_batch = (batch_idx + 1) == num_train_batches
            if (batch_idx + 1) % ACCUMULATION_STEPS == 0 or is_last_batch:
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad(set_to_none=True)

            running_loss += loss.item() * ACCUMULATION_STEPS  # undo the scaling for logging

        # max(..., 1) keeps an empty loader from raising ZeroDivisionError.
        avg_train_loss = running_loss / max(num_train_batches, 1)
        train_loss_list.append(avg_train_loss)

        # ===================== Validation =====================
        model.eval()
        val_loss = 0.0
        correct = 0
        total = 0

        with torch.no_grad():
            for images, labels in val_loader:
                images = images.to(device, non_blocking=True)
                labels = labels.to(device, non_blocking=True)

                with autocast():
                    outputs = model(images)
                    loss = criterion(outputs, labels)

                val_loss += loss.item()
                _, predicted = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        avg_val_loss = val_loss / max(len(val_loader), 1)
        val_accuracy = 100 * correct / max(total, 1)
        val_loss_list.append(avg_val_loss)
        val_acc_list.append(val_accuracy)

        gpu_mem = torch.cuda.memory_allocated() / 1e9 if torch.cuda.is_available() else 0
        print(f"üìâ Train Loss: {avg_train_loss:.4f} | üßæ Val Loss: {avg_val_loss:.4f} | üéØ Val Acc: {val_accuracy:.2f}% | üíæ GPU: {gpu_mem:.2f}GB")

        # Keep only the weights with the lowest validation loss seen so far.
        # NOTE(review): if torch.compile wrapped the model, the saved keys may
        # carry an `_orig_mod.` prefix - confirm before loading them into an
        # uncompiled model.
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save(model.state_dict(), SAVE_PATH)
            print(f"‚úÖ Model saved at {SAVE_PATH} (Best Val Loss: {best_val_loss:.4f})")

    print("\n" + "="*60)
    print("üéØ TRAINING COMPLETE!")
    print("="*60)

    # =========================================
    # 9. Plot Graphs
    # =========================================
    # Figure 1: per-epoch training and validation loss, saved as loss_curve.png
    # in the current working directory and then shown.
    plt.figure(figsize=(10,5))
    plt.plot(train_loss_list, label='Train Loss', color='blue')
    plt.plot(val_loss_list, label='Val Loss', color='red')
    plt.title('Training & Validation Loss')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()
    plt.savefig("loss_curve.png")
    plt.show()

    # Figure 2: per-epoch validation accuracy (percent), saved as accuracy_curve.png.
    plt.figure(figsize=(10,5))
    plt.plot(val_acc_list, label='Validation Accuracy', color='green')
    plt.title('Validation Accuracy')
    plt.xlabel('Epochs')
    plt.ylabel('Accuracy (%)')
    plt.legend()
    plt.savefig("accuracy_curve.png")
    plt.show()

    print("‚úÖ Training Complete! Best model saved at:", SAVE_PATH)


üöÄ Using device: cuda
‚úÖ cuDNN benchmark + TF32 enabled
‚úÖ Classes detected: ['person', 'rickshaw', 'rickshaw_van', 'auto_rickshaw', 'truck', 'pickup_truck', 'private_car', 'motorcycle', 'bicycle', 'bus', 'micro_bus', 'covered_van', 'human_hauler']
‚úÖ Number of classes: 13

üî• LOADING ALL IMAGES TO RAM (2-3 minutes)
‚úÖ Found 18681 images
üîÑ Loading 18681 images to RAM (this will take ~2 min)...


  self.setter(val)
Caching to RAM:  25%|‚ñà‚ñà‚ñç       | 4661/18681 [01:20<04:00, 58.19it/s] 
Caching to RAM:  25%|‚ñà‚ñà‚ñç       | 4661/18681 [01:20<04:00, 58.19it/s]


KeyboardInterrupt: 

In [2]:
# Video inference (batched frames) ‚Äî paste your video path below then run this cell.
# Outputs:
#  - Annotated video: runs/detect/video_test/<video>_annotated.mp4
#  - CSV detections: runs/detect/video_test/<video>_detections.csv
#  - Prints class summary and displays first annotated frame inline

from pathlib import Path
import csv
import cv2
from tqdm import tqdm
import numpy as np

# ------------------- USER CONFIG -------------------
VIDEO_PATH = Path(r"F:\skills-copilot-codespaces-vscode\PXL_20250507_113206344.TS.mp4")  # <-- set this to your video
OUTPUT_DIR = Path("runs/detect/video_test")
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)  # created eagerly so later writes cannot fail on a missing folder
IMGSZ = 1280     # model input size (reduce if OOM)
CONF = 0.25      # confidence threshold passed to model.predict
IOU = 0.45       # IoU threshold passed to model.predict
FRAME_BATCH = 8  # process frames in batches for higher throughput
# ---------------------------------------------------

# Both outputs are named after the input video's stem.
OUTPUT_VIDEO = OUTPUT_DIR / (VIDEO_PATH.stem + "_annotated.mp4")
DETECTIONS_CSV = OUTPUT_DIR / (VIDEO_PATH.stem + "_detections.csv")

# Reuse a detector already present in the notebook globals, otherwise load one.
# A leftover `model` from another cell (e.g. a classifier or a DETR model) is
# not usable here, so it is only reused when it exposes the `predict` method
# and `names` mapping that the inference loop relies on.
try:
    model
except NameError:
    _model_is_reusable = False
else:
    _model_is_reusable = callable(getattr(model, "predict", None)) and hasattr(model, "names")

if _model_is_reusable:
    print("Reusing `model` from notebook globals")
else:
    # Imported here because nothing else in this cell brings YOLO into scope.
    from ultralytics import YOLO

    TRAINED_MODEL = Path("runs/detect/rsud20k_yolo11/weights/best.pt")
    MODEL_PATH = str(TRAINED_MODEL) if TRAINED_MODEL.exists() else "yolo11x.pt"
    print(f"Loading YOLO model: {MODEL_PATH}")
    model = YOLO(MODEL_PATH)

# Best-effort: also store the thresholds on the model object. The same values
# are passed explicitly to model.predict, so a failure here is harmless.
try:
    model.conf = CONF
    model.iou = IOU
except Exception:
    pass

# open video
cap = cv2.VideoCapture(str(VIDEO_PATH))
if not cap.isOpened():
    raise FileNotFoundError(f"Cannot open video: {VIDEO_PATH}")

fps = cap.get(cv2.CAP_PROP_FPS) or 30.0  # fall back to 30 fps when the container reports 0
w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
writer = cv2.VideoWriter(str(OUTPUT_VIDEO), fourcc, fps, (w, h))
if not writer.isOpened():
    # Fail loudly instead of running the whole inference and producing no video.
    cap.release()
    raise RuntimeError(f"Cannot open video writer: {OUTPUT_VIDEO}")

print(f"Processing video: {VIDEO_PATH} -> {OUTPUT_VIDEO} | fps={fps}, size=({w},{h})")

frame_idx = 0
summary_counts = {}
per_frame_rows = []  # (frame_idx, class_name, conf, x1,y1,x2,y2)

batch_frames = []
batch_frame_indices = []

pbar = tqdm(total=int(cap.get(cv2.CAP_PROP_FRAME_COUNT) or 0), desc="Frames")


def _as_numpy(value):
    """Return ``value`` as a NumPy array, moving it off the device if needed."""
    return value.cpu().numpy() if hasattr(value, 'cpu') else np.asarray(value)


def _process_batch(frames, frame_indices):
    """Detect on ``frames``, draw boxes in place, record detections, write frames.

    Appends one row per detection to ``per_frame_rows``, updates
    ``summary_counts`` and writes every processed frame to ``writer`` in order.
    """
    results = model.predict(source=frames, imgsz=IMGSZ, conf=CONF, iou=IOU, verbose=False)
    for cur_frame, cur_idx, res in zip(frames, frame_indices, results):
        boxes = getattr(res, 'boxes', None)
        if boxes is not None and len(boxes) > 0:
            for box in boxes:
                xyxy = _as_numpy(box.xyxy[0])
                conf_val = float(_as_numpy(box.conf[0]))
                cls_id = int(_as_numpy(box.cls[0]))
                cls_name = model.names[cls_id] if cls_id in model.names else str(cls_id)

                x1, y1, x2, y2 = map(int, xyxy)
                cv2.rectangle(cur_frame, (x1, y1), (x2, y2), (0,255,0), 2)
                label = f"{cls_name} {conf_val:.2f}"
                # Keep the label inside the frame when the box touches the top edge.
                cv2.putText(cur_frame, label, (x1, max(15, y1-6)), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255,255,255), 2)

                summary_counts[cls_name] = summary_counts.get(cls_name, 0) + 1
                per_frame_rows.append([cur_idx, cls_name, conf_val, x1, y1, x2, y2])

        # Frames without detections are written unchanged.
        writer.write(cur_frame)
        pbar.update(1)


# Read frames and run inference in batches; always release the capture,
# the writer and the progress bar, even if inference fails part-way.
try:
    while True:
        ret, frame = cap.read()
        if not ret:
            break
        frame_idx += 1
        batch_frames.append(frame.copy())
        batch_frame_indices.append(frame_idx)

        if len(batch_frames) == FRAME_BATCH:
            _process_batch(batch_frames, batch_frame_indices)
            batch_frames = []
            batch_frame_indices = []

    # handle leftover frames (video length not a multiple of FRAME_BATCH)
    if batch_frames:
        _process_batch(batch_frames, batch_frame_indices)
        batch_frames = []
        batch_frame_indices = []
finally:
    pbar.close()
    cap.release()
    writer.release()

# Write every recorded detection to the CSV: header first, then one row per box.
with open(DETECTIONS_CSV, 'w', newline='') as csv_file:
    csv_out = csv.writer(csv_file)
    csv_out.writerows([['frame_idx', 'class', 'conf', 'x1', 'y1', 'x2', 'y2'], *per_frame_rows])

print(f"Done. Annotated video saved: {OUTPUT_VIDEO}")
print(f"Detections CSV saved: {DETECTIONS_CSV}")
print("Summary counts:")
# Most frequent class first.
ranked_counts = sorted(summary_counts.items(), key=lambda pair: -pair[1])
for class_name, count in ranked_counts:
    print(f"  {class_name}: {count}")

# Best-effort inline preview of the first annotated frame; any failure
# (e.g. not running inside a notebook) is ignored on purpose.
try:
    from IPython.display import display
    import PIL.Image
    preview_cap = cv2.VideoCapture(str(OUTPUT_VIDEO))
    got_frame, first_frame = preview_cap.read()
    preview_cap.release()
    if got_frame:
        # OpenCV frames are BGR; PIL expects RGB.
        display(PIL.Image.fromarray(cv2.cvtColor(first_frame, cv2.COLOR_BGR2RGB)))
except Exception:
    pass


ModuleNotFoundError: No module named 'cv2'