<a href="https://colab.research.google.com/github/lucas6028/aortic_valve_detection/blob/main/predict_hybrid_wbf_optimized.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# 🚀 Hybrid Model Ensemble with WBF (Optimized)

## YOLOv8 5-Fold + Faster R-CNN Multi-Model Fusion

This notebook combines predictions from:
1. **YOLOv8 5-Fold Models**: Using ALL 5 folds for maximum generalization
2. **Faster R-CNN Models**: High precision two-stage detector

## 1. Environment Setup

In [1]:
# Check GPU availability
# Shell escape: runs the NVIDIA driver utility and prints its report below the cell.
!nvidia-smi

Sat Nov 29 08:33:50 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  Tesla T4                       Off |   00000000:00:04.0 Off |                    0 |
| N/A   41C    P8              9W /   70W |       0MiB /  15360MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

In [2]:
# Fix encoding issues
import locale
def getpreferredencoding(do_setlocale=True):
    """Report "UTF-8" as the preferred encoding, whatever ``do_setlocale`` is.

    The argument is accepted only so the signature stays call-compatible with
    ``locale.getpreferredencoding``, which this function replaces below.
    """
    forced_encoding = "UTF-8"
    return forced_encoding

# Monkeypatch the stdlib hook so later callers in this process receive UTF-8.
setattr(locale, "getpreferredencoding", getpreferredencoding)

In [3]:
# Install required packages
# -q keeps pip's output terse. No versions are pinned, so each run installs
# whatever releases are current at that time.
!pip install ultralytics ensemble-boxes torch torchvision pycocotools -q

# NOTE(review): this message is printed unconditionally; pip's exit status is not checked.
print("‚úÖ Packages installed successfully")

[?25l   [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m0.0/1.1 MB[0m [31m?[0m eta [36m-:--:--[0m[2K   [91m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m[91m‚ï∏[0m [32m1.1/1.1 MB[0m [31m45.9 MB/s[0m eta [36m0:00:01[0m[2K   [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m1.1/1.1 MB[0m [31m14.6 MB/s[0m eta [36m0:00:00[0m
[?25h‚úÖ Packages installed successfully


In [4]:
# Mount Google Drive
# The model checkpoint paths configured later in this notebook live under
# /content/drive/MyDrive, so the mount has to succeed before models are loaded.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


## 2. Download Test Dataset

In [5]:
import gdown
import os

# Download test dataset
# The archive is large (about 1.8 GB in the recorded run), so an archive that is
# already on disk is reused instead of being fetched again on every re-run.
TEST_ZIP_URL = "https://drive.google.com/uc?export=download&id=1drfMaRrfIL0XBRY16Vq9IJ3nJ2Zr9B60"
TEST_ZIP_PATH = "/content/testing_image.zip"

if os.path.exists(TEST_ZIP_PATH):
    print(f"Test dataset already present at {TEST_ZIP_PATH}, skipping download")
else:
    print("üì• Downloading test dataset...")
    gdown.download(TEST_ZIP_URL, TEST_ZIP_PATH)
    # Verify the result on disk rather than trusting the call: a failed download
    # must stop the notebook here, not surface later as an empty test set.
    if not os.path.exists(TEST_ZIP_PATH):
        raise RuntimeError(f"Download failed: {TEST_ZIP_PATH} was not created")
    print("‚úÖ Download complete")

üì• Downloading test dataset...


Downloading...
From (original): https://drive.google.com/uc?export=download&id=1drfMaRrfIL0XBRY16Vq9IJ3nJ2Zr9B60
From (redirected): https://drive.google.com/uc?export=download&id=1drfMaRrfIL0XBRY16Vq9IJ3nJ2Zr9B60&confirm=t&uuid=751412a9-67dd-4c55-9537-ae3ae63dfdf1
To: /content/testing_image.zip
100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1.83G/1.83G [00:29<00:00, 62.5MB/s]

‚úÖ Download complete





## 3. Prepare Test Images

Split images into two batches to manage Colab RAM limits

In [6]:
import os
import shutil

# Extract test dataset
# Runs only when the target directory does not exist yet and the archive is present,
# so re-running the cell does not unzip a second time.
# NOTE(review): the archive is referenced by a relative path here, while the download
# cell writes /content/testing_image.zip; this presumably relies on the working
# directory being /content — confirm.
if not os.path.isdir("./testing_image") and os.path.exists("testing_image.zip"):
    os.makedirs("./testing_image", exist_ok=True)
    !unzip -q testing_image.zip -d ./testing_image

def find_patient_root(root):
    """Locate the first directory under ``root`` that holds patient folders.

    The tree is walked with ``os.walk``; the first directory that has at least
    one immediate subdirectory whose name starts with "patient" is returned.
    When no such directory exists, ``root`` itself is returned.
    """
    matching_dirs = (
        current_dir
        for current_dir, subdirs, _files in os.walk(root)
        if any(name.startswith("patient") for name in subdirs)
    )
    return next(matching_dirs, root)

TEST_ROOT = find_patient_root("./testing_image")
print(f"Test root: {TEST_ROOT}")

# Validate test dataset (should be patient0051-0100)
# Entries are ordered by name, so the first and last elements give the patient range.
patient_folders = sorted(
    entry for entry in os.listdir(TEST_ROOT) if entry.startswith("patient")
)
if patient_folders:
    first_patient, last_patient = patient_folders[0], patient_folders[-1]
    # Strip the "patient" text to recover the numeric id of each end of the range.
    first_num = int(first_patient.replace("patient", ""))
    last_num = int(last_patient.replace("patient", ""))

    print(f"\nüìä Dataset Validation:")
    print(f"  First patient: {first_patient}")
    print(f"  Last patient: {last_patient}")
    print(f"  Total patients: {len(patient_folders)}")

    # Ids up to 50 are treated as the training split; abort before predicting on them.
    if first_num <= 50:
        raise ValueError("‚ö†Ô∏è ERROR: Downloaded TRAINING data instead of TEST data!")
    print(f"‚úÖ Correct test dataset: patient{first_num:04d}-patient{last_num:04d}")

Test root: ./testing_image/testing_image

üìä Dataset Validation:
  First patient: patient0051
  Last patient: patient0100
  Total patients: 50
‚úÖ Correct test dataset: patient0051-patient0100


In [7]:
# Split images into two batches
dst_root1 = "./datasets/test/images1"
dst_root2 = "./datasets/test/images2"

os.makedirs(dst_root1, exist_ok=True)
os.makedirs(dst_root2, exist_ok=True)

# Collect all image paths (every .png directly inside a patient* folder)
all_files = []
for patient_folder in os.listdir(TEST_ROOT):
    patient_path = os.path.join(TEST_ROOT, patient_folder)
    if os.path.isdir(patient_path) and patient_folder.startswith("patient"):
        for fname in os.listdir(patient_path):
            if fname.endswith(".png"):
                all_files.append(os.path.join(patient_path, fname))

# Sort for reproducibility
all_files.sort()

# The batch directories are flat: images are copied under their base name only.
# Two images with the same file name in different patient folders would silently
# overwrite each other, so detect that up front and fail loudly.
first_seen = {}
for src_path in all_files:
    base_name = os.path.basename(src_path)
    if base_name in first_seen:
        raise ValueError(
            f"Duplicate image file name {base_name!r}: "
            f"{first_seen[base_name]} and {src_path}"
        )
    first_seen[base_name] = src_path

# Calculate split point (the second batch receives the extra image when the count is odd)
half = len(all_files) // 2

# Copy to batch directories
print(f"üì¶ Splitting {len(all_files)} images into 2 batches...")

for batch_files, batch_root in ((all_files[:half], dst_root1), (all_files[half:], dst_root2)):
    for src_path in batch_files:
        shutil.copy2(src_path, os.path.join(batch_root, os.path.basename(src_path)))

# Counts are read back from disk, so files left over from an earlier run are included.
print(f"‚úÖ Batch 1: {len(os.listdir(dst_root1))} images ‚Üí {dst_root1}")
print(f"‚úÖ Batch 2: {len(os.listdir(dst_root2))} images ‚Üí {dst_root2}")

üì¶ Splitting 16620 images into 2 batches...
‚úÖ Batch 1: 8310 images ‚Üí ./datasets/test/images1
‚úÖ Batch 2: 8310 images ‚Üí ./datasets/test/images2


## 4. Configuration

Configure model paths and WBF parameters

In [None]:
# Model paths
YOLO_KFOLD_PATH = '/content/drive/MyDrive/AI_CUP_2025/aortic_valve_kfold'
FASTER_RCNN_PATH = '/content/drive/MyDrive/AI_CUP_2025/faster_rcnn_checkpoints'

# Detection parameters
CONF_THRESHOLD = 0.0005  # Low threshold to capture all candidates
IMG_SIZE = 640
MAX_INPUT_DET = 50      # Max detections per model to feed into WBF
MAX_OUTPUT_DET = 50     # Max fused detections kept per image (values <= 0 disable the cap)

# TTA (Test-Time Augmentation) parameters
# NOTE(review): the outputs recorded in this notebook report "TTA Enabled: True"
# with 4 variants per model, i.e. they come from a run with ENABLE_TTA = True —
# confirm which setting is intended before re-running.
ENABLE_TTA = False
TTA_SCALES = [1.0, 0.9, 1.1]
TTA_HORIZONTAL_FLIP = True

# WBF parameters
WBF_IOU_THR = 0.64
WBF_SKIP_BOX_THR = 0.0001

# Model weights
# Optimized for 5 Folds + Faster R-CNN
# Order must be [YOLO fold 1..5, Faster R-CNN]; the prediction code indexes by position.
# The numbers in parentheses are presumably each model's validation score — confirm.
# Fold 1 (0.99) -> 0.12
# Fold 2 (0.986) -> 0.10
# Fold 3 (0.985) -> 0.10
# Fold 4 (0.985) -> 0.10
# Fold 5 (0.977) -> 0.08
# Faster R-CNN (0.987) -> 0.50
MODEL_WEIGHTS = [0.12, 0.10, 0.10, 0.10, 0.08, 0.50]

# Only claim single-object enforcement when the cap really is one detection per
# image; with any other cap the note would misdescribe the run.
max_det_note = " (Enforcing single object)" if MAX_OUTPUT_DET == 1 else ""

print("=" * 80)
print("üìã Configuration:")
print(f"  YOLOv8 K-Fold: {YOLO_KFOLD_PATH}")
print(f"  Faster R-CNN: {FASTER_RCNN_PATH}")
print(f"  Confidence Threshold: {CONF_THRESHOLD}")
print(f"  Max Output Detections: {MAX_OUTPUT_DET}{max_det_note}")
print(f"  TTA Enabled: {ENABLE_TTA}")
if ENABLE_TTA:
    print(f"  TTA Scales: {TTA_SCALES}")
    print(f"  TTA Horizontal Flip: {TTA_HORIZONTAL_FLIP}")
print(f"  WBF IoU Threshold: {WBF_IOU_THR}")
print("=" * 80)

üìã Configuration:
  YOLOv8 K-Fold: /content/drive/MyDrive/AI_CUP_2025/aortic_valve_kfold
  Faster R-CNN: /content/drive/MyDrive/AI_CUP_2025/faster_rcnn_checkpoints
  Confidence Threshold: 0.0005
  Max Output Detections: 50 (Enforcing single object)
  TTA Enabled: True
  TTA Scales: [1.0, 0.9, 1.1]
  TTA Horizontal Flip: True
  WBF IoU Threshold: 0.64


## 5. Load Models

Load YOLOv8 5-Fold models and Faster R-CNN model

In [9]:
import torch
import torchvision
from torchvision.models.detection import fasterrcnn_resnet50_fpn_v2
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from ultralytics import YOLO

def get_faster_rcnn_model(num_classes=2):
    """Build the Faster R-CNN (ResNet-50 FPN v2) architecture used for inference.

    The detector is constructed with ``weights=None`` and its box predictor is
    replaced by a ``FastRCNNPredictor`` that outputs ``num_classes`` classes.
    """
    detector = fasterrcnn_resnet50_fpn_v2(weights=None)
    # Reuse the existing head's input width so the replacement head fits the ROI features.
    head_in_features = detector.roi_heads.box_predictor.cls_score.in_features
    detector.roi_heads.box_predictor = FastRCNNPredictor(head_in_features, num_classes)
    return detector

def load_all_models():
    """Load all YOLOv8 fold models and Faster R-CNN model

    Reads the module-level ``YOLO_KFOLD_PATH`` and ``FASTER_RCNN_PATH`` settings.

    Returns:
        dict with two keys:
          'yolo'        -- list of YOLO models, one per fold checkpoint found
                           (folds 1-5, in order; missing folds are skipped).
          'faster_rcnn' -- the Faster R-CNN model in eval mode on the selected
                           device, or None when its checkpoint is missing.
    """
    models = {
        'yolo': [],
        'faster_rcnn': None
    }

    print("=" * 80)
    print("üîÑ Loading Models...")
    print("=" * 80)

    # Load YOLOv8 5-Fold models
    print("\nüì¶ Loading YOLOv8 K-Fold models:")
    for fold in range(1, 6):  # folds are numbered 1..5 on disk
        model_path = f'{YOLO_KFOLD_PATH}/fold{fold}/weights/best.pt'
        if os.path.exists(model_path):
            model = YOLO(model_path)
            models['yolo'].append(model)
            print(f"  ‚úÖ Fold {fold}: {model_path}")
        else:
            print(f"  ‚ö†Ô∏è Fold {fold}: NOT FOUND at {model_path}")

    print(f"\n‚úÖ Loaded {len(models['yolo'])} YOLOv8 models")

    # Load Faster R-CNN model
    print("\nüì¶ Loading Faster R-CNN model:")
    faster_rcnn_checkpoint = os.path.join(FASTER_RCNN_PATH, 'checkpoint_epoch_39.pth')

    if os.path.exists(faster_rcnn_checkpoint):
        device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
        model = get_faster_rcnn_model(num_classes=2)

        # map_location lets a checkpoint saved from a GPU run load on a CPU-only
        # runtime; without it the CPU fallback above would fail during loading.
        # SECURITY: weights_only=False unpickles arbitrary objects. Only load
        # checkpoints from a trusted source.
        checkpoint = torch.load(
            faster_rcnn_checkpoint,
            map_location=device,
            weights_only=False
        )
        model.load_state_dict(checkpoint['model_state_dict'])
        model.to(device)
        model.eval()

        models['faster_rcnn'] = model
        print(f"  ‚úÖ Faster R-CNN: {faster_rcnn_checkpoint}")
        # Metadata is informational only: do not fail a successful load because
        # an optional key is absent from the checkpoint.
        print(f"     Epoch: {checkpoint.get('epoch', 'unknown')}, AP@0.5: {checkpoint.get('best_ap', 0.0):.4f}")
    else:
        print(f"  ‚ö†Ô∏è Faster R-CNN: NOT FOUND at {faster_rcnn_checkpoint}")

    print("\n" + "=" * 80)
    print(f"üìä Total models loaded: {len(models['yolo'])} YOLO + {1 if models['faster_rcnn'] is not None else 0} Faster R-CNN")
    print("=" * 80 + "\n")

    return models

# Load all models
all_models = load_all_models()

# Module-level device handle; the batch prediction function reads it for Faster R-CNN inference.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"\nüíª Using device: {device}")

Creating new Ultralytics Settings v0.0.6 file ‚úÖ 
View Ultralytics Settings with 'yolo settings' or at '/root/.config/Ultralytics/settings.json'
Update Settings with 'yolo settings key=value', i.e. 'yolo settings runs_dir=path/to/dir'. For help see https://docs.ultralytics.com/quickstart/#ultralytics-settings.
üîÑ Loading Models...

üì¶ Loading YOLOv8 K-Fold models:
  ‚úÖ Fold 1: /content/drive/MyDrive/AI_CUP_2025/aortic_valve_kfold/fold1/weights/best.pt
  ‚úÖ Fold 2: /content/drive/MyDrive/AI_CUP_2025/aortic_valve_kfold/fold2/weights/best.pt
  ‚úÖ Fold 3: /content/drive/MyDrive/AI_CUP_2025/aortic_valve_kfold/fold3/weights/best.pt
  ‚úÖ Fold 4: /content/drive/MyDrive/AI_CUP_2025/aortic_valve_kfold/fold4/weights/best.pt
  ‚úÖ Fold 5: /content/drive/MyDrive/AI_CUP_2025/aortic_valve_kfold/fold5/weights/best.pt

‚úÖ Loaded 5 YOLOv8 models

üì¶ Loading Faster R-CNN model:
  ‚úÖ Faster R-CNN: /content/drive/MyDrive/AI_CUP_2025/faster_rcnn_checkpoints/checkpoint_epoch_39.pth
     Epoch: 3

## 6. Hybrid WBF Prediction Functions

Combine predictions from YOLOv8 and Faster R-CNN using WBF

In [10]:
import numpy as np
from ensemble_boxes import weighted_boxes_fusion
from PIL import Image
import torchvision.transforms as T
import torchvision.transforms.functional as F
from tqdm import tqdm
import cv2

def get_image_size(image_path):
    """Return the ``(width, height)`` of the image stored at ``image_path``.

    The file is opened with PIL inside a context manager, so the handle is
    released before returning.
    """
    with Image.open(image_path) as opened_image:
        width, height = opened_image.size
    return width, height

def apply_tta_transform(img_tensor, transform_type, scale=1.0):
    """
    Produce the TTA variant of ``img_tensor`` selected by ``transform_type``.

    'original' returns the tensor untouched, 'hflip' mirrors it horizontally,
    and 'scale' resizes it to ``int(h * scale)`` x ``int(w * scale)`` (a scale
    of exactly 1.0 returns the tensor untouched). Any other value raises
    ``ValueError``.
    """
    if transform_type not in ('original', 'hflip', 'scale'):
        raise ValueError(f"Unknown transform type: {transform_type}")

    if transform_type == 'hflip':
        return F.hflip(img_tensor)

    if transform_type == 'scale' and scale != 1.0:
        # The tensor is unpacked as (channels, height, width).
        _, height, width = img_tensor.shape
        target_size = [int(height * scale), int(width * scale)]
        return F.resize(img_tensor, target_size)

    # 'original', or 'scale' at factor 1.0: nothing to do.
    return img_tensor

def reverse_tta_boxes(boxes, transform_type, orig_width, orig_height, scale=1.0):
    """
    Map normalized boxes predicted on a TTA variant back to the original frame.

    Only 'hflip' changes coordinates (x is mirrored around the image centre);
    every other transform type yields a copy of each box as a fresh list.
    ``orig_width``, ``orig_height`` and ``scale`` are accepted but not used.
    An empty ``boxes`` value is returned as is.
    """
    if not boxes:
        return boxes

    if transform_type == 'hflip':
        # Mirroring swaps the roles of the left and right edges.
        return [[1.0 - x2, y1, 1.0 - x1, y2] for x1, y1, x2, y2 in boxes]

    # 'scale' and everything else: coordinates are already normalized, copy through.
    return [[x1, y1, x2, y2] for x1, y1, x2, y2 in boxes]

def predict_yolo_single_image_with_tta(yolo_models, img_path, orig_width, orig_height):
    """
    Get predictions from all YOLOv8 fold models with TTA for a single image

    Reads the module-level ENABLE_TTA, TTA_SCALES, TTA_HORIZONTAL_FLIP,
    IMG_SIZE, CONF_THRESHOLD and MAX_INPUT_DET settings.

    Args:
        yolo_models: iterable of loaded YOLO models; each one is queried in turn.
        img_path: path of the image, read with OpenCV.
        orig_width: width in pixels used to normalize x coordinates.
        orig_height: height in pixels used to normalize y coordinates.

    Returns:
        (boxes_list, scores_list, labels_list): three parallel lists with one
        entry per (model, TTA variant) pair, ordered model by model and, within
        a model, variant by variant. Each entry holds the [x1, y1, x2, y2]
        boxes divided by the original image size, the matching confidences and
        the matching integer class ids. Three empty lists are returned when
        OpenCV cannot read the image.
    """
    boxes_list = []
    scores_list = []
    labels_list = []

    # Load image with cv2 for YOLO
    img_cv2 = cv2.imread(img_path)
    if img_cv2 is None:
        return [], [], []

    # Generate TTA variants
    tta_configs = []

    if ENABLE_TTA:
        # Original + scales
        for scale in TTA_SCALES:
            tta_configs.append(('scale', scale))

        # Horizontal flip (only at original scale to avoid too many variants)
        if TTA_HORIZONTAL_FLIP:
            tta_configs.append(('hflip', 1.0))
    else:
        # No TTA, just original
        tta_configs.append(('original', 1.0))

    # Use the first GPU when one is available and fall back to the CPU otherwise,
    # instead of failing on a hard-coded CUDA device index.
    yolo_device = 0 if torch.cuda.is_available() else 'cpu'

    for model in yolo_models:
        for transform_type, scale in tta_configs:
            # Prepare image source
            if transform_type == 'hflip':
                # Flip image horizontally
                source_img = cv2.flip(img_cv2, 1)
                img_size = IMG_SIZE
            elif transform_type == 'scale':
                # YOLO handles scaling via imgsz, pass original image
                source_img = img_cv2
                img_size = int(IMG_SIZE * scale)
            else:
                source_img = img_cv2
                img_size = IMG_SIZE

            # Predict with YOLO
            # NOTE(review): flipud/fliplr look like training-augmentation settings;
            # presumably they have no effect at predict time — confirm. They are
            # kept so the call stays equivalent to the original.
            results = model.predict(
                source=source_img,
                save=False,
                imgsz=img_size,
                device=yolo_device,
                verbose=False,
                conf=CONF_THRESHOLD,
                iou=0.5,
                max_det=MAX_INPUT_DET,
                flipud=False,
                fliplr=False  # We handle flipping manually
            )

            detections = results[0].boxes

            # Convert each tensor to plain Python values once per result rather
            # than once per box.
            xyxy_rows = detections.xyxy.tolist()
            conf_values = detections.conf.tolist()
            class_values = detections.cls.tolist()

            # Normalize to 0-1 using the original image size (the boxes are
            # divided by orig_width/orig_height for every variant).
            fold_boxes = [
                [x1 / orig_width, y1 / orig_height, x2 / orig_width, y2 / orig_height]
                for x1, y1, x2, y2 in xyxy_rows
            ]
            fold_scores = list(conf_values)
            fold_labels = [int(class_value) for class_value in class_values]

            # Reverse TTA transformation if needed
            fold_boxes = reverse_tta_boxes(fold_boxes, transform_type, orig_width, orig_height, scale)

            boxes_list.append(fold_boxes)
            scores_list.append(fold_scores)
            labels_list.append(fold_labels)

    return boxes_list, scores_list, labels_list

def predict_faster_rcnn_single_image_with_tta(model, img_path, orig_width, orig_height, device):
    """
    Get predictions from Faster R-CNN model with TTA for a single image

    Reads the module-level ENABLE_TTA, TTA_SCALES, TTA_HORIZONTAL_FLIP and
    CONF_THRESHOLD settings.

    Args:
        model: the Faster R-CNN model, or None to skip this detector.
        img_path: path of the image, opened with PIL and converted to RGB.
        orig_width: forwarded to ``reverse_tta_boxes``.
        orig_height: forwarded to ``reverse_tta_boxes``.
        device: device the input tensor is moved to before inference.

    Returns:
        (boxes_list, scores_list, labels_list): three parallel lists with one
        entry per TTA variant. Boxes are [x1, y1, x2, y2] normalized by the
        size of the transformed input; only detections with label 1 and a score
        of at least CONF_THRESHOLD are kept, and their label is emitted as 0.
        Three empty lists are returned when ``model`` is None.
    """
    if model is None:
        return [], [], []

    boxes_list = []
    scores_list = []
    labels_list = []

    # Load image. The context manager releases the file handle as soon as the
    # pixel data has been converted, matching get_image_size.
    with Image.open(img_path) as raw_img:
        img = raw_img.convert('RGB')
    transform = T.ToTensor()
    img_tensor = transform(img)

    # Generate TTA variants
    tta_configs = []

    if ENABLE_TTA:
        # Original + scales
        for scale in TTA_SCALES:
            tta_configs.append(('scale', scale))

        # Horizontal flip
        if TTA_HORIZONTAL_FLIP:
            tta_configs.append(('hflip', 1.0))
    else:
        # No TTA, just original
        tta_configs.append(('original', 1.0))

    for transform_type, scale in tta_configs:
        # Apply TTA transform
        # NOTE(review): torchvision detection models presumably resize their
        # input internally, which may largely cancel the 'scale' variants —
        # confirm against the model's transform settings.
        if transform_type == 'scale' and scale != 1.0:
            transformed_tensor = apply_tta_transform(img_tensor, 'scale', scale)
        elif transform_type == 'hflip':
            transformed_tensor = apply_tta_transform(img_tensor, 'hflip')
        else:
            transformed_tensor = img_tensor

        # Predict
        with torch.no_grad():
            predictions = model([transformed_tensor.to(device)])

        prediction = predictions[0]
        boxes = prediction['boxes'].cpu().numpy()
        scores = prediction['scores'].cpu().numpy()
        labels = prediction['labels'].cpu().numpy()

        # Get transformed image dimensions
        _, h_trans, w_trans = transformed_tensor.shape

        # Filter by confidence and normalize
        variant_boxes = []
        variant_scores = []
        variant_labels = []

        for box, score, label in zip(boxes, scores, labels):
            if score >= CONF_THRESHOLD and label == 1:  # Class 1 = aortic_valve
                x1, y1, x2, y2 = box

                # Normalize to transformed image dimensions first
                x1_norm = x1 / w_trans
                y1_norm = y1 / h_trans
                x2_norm = x2 / w_trans
                y2_norm = y2 / h_trans

                variant_boxes.append([x1_norm, y1_norm, x2_norm, y2_norm])
                variant_scores.append(float(score))
                variant_labels.append(0)  # Convert to 0-indexed for WBF

        # Reverse TTA transformation
        variant_boxes = reverse_tta_boxes(variant_boxes, transform_type, orig_width, orig_height, scale)

        boxes_list.append(variant_boxes)
        scores_list.append(variant_scores)
        labels_list.append(variant_labels)

    return boxes_list, scores_list, labels_list

def predict_batch_hybrid_wbf(models, source_path, output_file):
    """
    Predict on a batch of images using hybrid YOLOv8 + Faster R-CNN ensemble with TTA and WBF

    Reads the module-level ENABLE_TTA, TTA_SCALES, TTA_HORIZONTAL_FLIP,
    MODEL_WEIGHTS, WBF_IOU_THR, WBF_SKIP_BOX_THR and MAX_OUTPUT_DET settings,
    and the module-level ``device`` for Faster R-CNN inference.

    Args:
        models: dict with key 'yolo' (list of YOLO models) and key
            'faster_rcnn' (a model, or None when it is not available).
        source_path: directory whose ``.png`` files are processed in sorted order.
        output_file: path of the text file to write. Each detection becomes one
            line: ``<image name> <label> <conf> <x1> <y1> <x2> <y2>`` with the
            image name without extension and integer pixel coordinates.

    Returns:
        None. Predictions are written to ``output_file`` and a summary is
        printed. Images for which no model produced a box get no line.
    """
    # Get all image files
    image_files = [f for f in os.listdir(source_path) if f.endswith('.png')]
    image_files.sort()

    # Calculate correct weights for TTA
    # We need to expand MODEL_WEIGHTS to match the number of TTA variants per model

    # Determine number of TTA variants
    num_tta_variants = 1
    if ENABLE_TTA:
        num_tta_variants = len(TTA_SCALES) + (1 if TTA_HORIZONTAL_FLIP else 0)

    # Expand weights
    # MODEL_WEIGHTS structure: [YOLO_Fold1, YOLO_Fold2, ..., Faster_RCNN]
    expanded_weights = []

    # Add weights for YOLO models (a model without a configured weight gets 1.0)
    for i in range(len(models['yolo'])):
        weight = MODEL_WEIGHTS[i] if i < len(MODEL_WEIGHTS) else 1.0
        expanded_weights.extend([weight] * num_tta_variants)

    # Add weights for Faster R-CNN
    if models['faster_rcnn'] is not None:
        frcnn_idx = len(models['yolo'])
        weight = MODEL_WEIGHTS[frcnn_idx] if frcnn_idx < len(MODEL_WEIGHTS) else 1.0
        expanded_weights.extend([weight] * num_tta_variants)

    tta_info = ""
    if ENABLE_TTA:
        tta_info = f" with TTA ({num_tta_variants} variants per model)"

    print(f"\nüîÆ Processing {len(image_files)} images with Hybrid WBF ensemble{tta_info}...")
    print(f"   Models: {len(models['yolo'])} YOLOv8 + {1 if models['faster_rcnn'] is not None else 0} Faster R-CNN")
    print(f"   Total prediction lists per image: {len(expanded_weights)}")

    if ENABLE_TTA:
        print(f"   TTA Scales: {TTA_SCALES}")
        print(f"   TTA Flip: {TTA_HORIZONTAL_FLIP}")

    final_predictions = {}
    # Number of images fused with equal weights because the list count was off.
    weight_mismatches = 0

    for img_file in tqdm(image_files, desc="Predicting"):
        img_path = os.path.join(source_path, img_file)
        # Strip only the trailing extension; splitting on '.png' would truncate
        # a name that contains '.png' somewhere in the middle.
        filename = os.path.splitext(img_file)[0]

        # Get original image size
        orig_width, orig_height = get_image_size(img_path)

        # Collect predictions from all models and TTA variants
        all_boxes_list = []
        all_scores_list = []
        all_labels_list = []

        # 1. Get YOLOv8 predictions with TTA
        yolo_boxes, yolo_scores, yolo_labels = predict_yolo_single_image_with_tta(
            models['yolo'], img_path, orig_width, orig_height
        )
        all_boxes_list.extend(yolo_boxes)
        all_scores_list.extend(yolo_scores)
        all_labels_list.extend(yolo_labels)

        # 2. Get Faster R-CNN predictions with TTA
        if models['faster_rcnn'] is not None:
            frcnn_boxes, frcnn_scores, frcnn_labels = predict_faster_rcnn_single_image_with_tta(
                models['faster_rcnn'], img_path, orig_width, orig_height, device
            )
            all_boxes_list.extend(frcnn_boxes)
            all_scores_list.extend(frcnn_scores)
            all_labels_list.extend(frcnn_labels)

        # 3. Apply WBF to fuse all predictions (including TTA variants)
        if any(len(boxes) > 0 for boxes in all_boxes_list):
            # Ensure weights match boxes list length
            current_weights = expanded_weights
            if len(all_boxes_list) != len(expanded_weights):
                # Happens when a predictor returns no lists at all (the YOLO
                # helper does so when OpenCV cannot read the image). Fall back
                # to equal weights, but report it instead of doing so silently.
                current_weights = None
                weight_mismatches += 1
                if weight_mismatches == 1:
                    tqdm.write(
                        f"   Warning: {img_file}: {len(all_boxes_list)} prediction lists "
                        f"but {len(expanded_weights)} weights; using equal weights"
                    )

            fused_boxes, fused_scores, fused_labels = weighted_boxes_fusion(
                all_boxes_list,
                all_scores_list,
                all_labels_list,
                weights=current_weights,
                iou_thr=WBF_IOU_THR,
                skip_box_thr=WBF_SKIP_BOX_THR
            )

            # Sort by confidence descending
            sorted_indices = np.argsort(fused_scores)[::-1]

            # Keep only the MAX_OUTPUT_DET highest-scoring boxes (no cap when <= 0)
            if MAX_OUTPUT_DET > 0:
                sorted_indices = sorted_indices[:MAX_OUTPUT_DET]

            # Convert normalized coordinates back to pixels
            final_predictions[filename] = []
            for idx in sorted_indices:
                box = fused_boxes[idx]
                score = fused_scores[idx]
                label = fused_labels[idx]

                x1 = int(box[0] * orig_width)
                y1 = int(box[1] * orig_height)
                x2 = int(box[2] * orig_width)
                y2 = int(box[3] * orig_height)

                final_predictions[filename].append({
                    'label': int(label),
                    'conf': float(score),
                    'box': [x1, y1, x2, y2]
                })

    if weight_mismatches:
        print(f"   Warning: {weight_mismatches} image(s) were fused with equal weights")

    # Write predictions to file (UTF-8, the encoding the merge step reads with)
    print(f"\nüíæ Writing predictions to {output_file}...")
    with open(output_file, 'w', encoding='utf-8') as f:
        for filename, predictions in final_predictions.items():
            for pred in predictions:
                line = f"{filename} {pred['label']} {pred['conf']:.4f} {pred['box'][0]} {pred['box'][1]} {pred['box'][2]} {pred['box'][3]}\n"
                f.write(line)

    num_detections = sum(len(preds) for preds in final_predictions.values())
    num_images_with_detections = len(final_predictions)

    print(f"‚úÖ Predictions saved!")
    print(f"   Total detections: {num_detections}")
    print(f"   Images with detections: {num_images_with_detections}")
    if num_images_with_detections > 0:
        print(f"   Average detections per image: {num_detections/num_images_with_detections:.2f}")

# Marks the end of the cell that defines the prediction helpers.
# NOTE(review): the text says "with TTA" regardless of the ENABLE_TTA setting.
print("‚úÖ Hybrid WBF prediction functions with TTA ready")

‚úÖ Hybrid WBF prediction functions with TTA ready


## 7. Predict Batch 1 with Hybrid WBF

In [11]:
# Create output directory
os.makedirs('/content/predict_txt', exist_ok=True)

print("=" * 80, "üöÄ Starting Batch 1 Prediction (Hybrid WBF + TTA)", "=" * 80, sep="\n")

# Run the ensemble over the first batch directory; arguments are
# (models, source_path, output_file).
predict_batch_hybrid_wbf(
    all_models,
    dst_root1,
    '/content/predict_txt/images1_hybrid_wbf_tta.txt',
)

print("\n‚úÖ Batch 1 complete!")

üöÄ Starting Batch 1 Prediction (Hybrid WBF + TTA)

üîÆ Processing 8310 images with Hybrid WBF ensemble with TTA (4 variants per model)...
   Models: 5 YOLOv8 + 1 Faster R-CNN
   Total prediction lists per image: 24
   TTA Scales: [1.0, 0.9, 1.1]
   TTA Flip: True


Predicting: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 8310/8310 [2:21:55<00:00,  1.02s/it]


üíæ Writing predictions to /content/predict_txt/images1_hybrid_wbf_tta.txt...
‚úÖ Predictions saved!
   Total detections: 7156
   Images with detections: 2807
   Average detections per image: 2.55

‚úÖ Batch 1 complete!





## 8. Memory Cleanup

In [12]:
# Clear memory between batches
import torch
import gc

print("üßπ Cleaning up memory...")

# Drop every reference to the loaded models so they can be garbage-collected
if 'all_models' in locals():
    if 'yolo' in all_models:
        all_models['yolo'].clear()
    if all_models.get('faster_rcnn') is not None:
        all_models.pop('faster_rcnn')
    del all_models

# Collect unreachable objects first, then release PyTorch's cached CUDA memory
gc.collect()
torch.cuda.empty_cache()

print("‚úÖ Memory cleared")

üßπ Cleaning up memory...
‚úÖ Memory cleared


## 9. Reload Models for Batch 2

In [13]:
print("=" * 80, "üîÑ Reloading models for Batch 2...", "=" * 80, sep="\n")

# The cleanup cell deleted all_models, so build the dictionary again from disk.
all_models = load_all_models()

print("\n‚úÖ Models reloaded successfully")

üîÑ Reloading models for Batch 2...
üîÑ Loading Models...

üì¶ Loading YOLOv8 K-Fold models:
  ‚úÖ Fold 1: /content/drive/MyDrive/AI_CUP_2025/aortic_valve_kfold/fold1/weights/best.pt
  ‚úÖ Fold 2: /content/drive/MyDrive/AI_CUP_2025/aortic_valve_kfold/fold2/weights/best.pt
  ‚úÖ Fold 3: /content/drive/MyDrive/AI_CUP_2025/aortic_valve_kfold/fold3/weights/best.pt
  ‚úÖ Fold 4: /content/drive/MyDrive/AI_CUP_2025/aortic_valve_kfold/fold4/weights/best.pt
  ‚úÖ Fold 5: /content/drive/MyDrive/AI_CUP_2025/aortic_valve_kfold/fold5/weights/best.pt

‚úÖ Loaded 5 YOLOv8 models

üì¶ Loading Faster R-CNN model:
  ‚úÖ Faster R-CNN: /content/drive/MyDrive/AI_CUP_2025/faster_rcnn_checkpoints/checkpoint_epoch_39.pth
     Epoch: 39, AP@0.5: 0.9872

üìä Total models loaded: 5 YOLO + 1 Faster R-CNN


‚úÖ Models reloaded successfully


## 10. Predict Batch 2 with Hybrid WBF

In [14]:
print("=" * 80, "üöÄ Starting Batch 2 Prediction (Hybrid WBF + TTA)", "=" * 80, sep="\n")

# Run the ensemble over the second batch directory; arguments are
# (models, source_path, output_file).
predict_batch_hybrid_wbf(
    all_models,
    dst_root2,
    '/content/predict_txt/images2_hybrid_wbf_tta.txt',
)

print("\n‚úÖ Batch 2 complete!")

üöÄ Starting Batch 2 Prediction (Hybrid WBF + TTA)

üîÆ Processing 8310 images with Hybrid WBF ensemble with TTA (4 variants per model)...
   Models: 5 YOLOv8 + 1 Faster R-CNN
   Total prediction lists per image: 24
   TTA Scales: [1.0, 0.9, 1.1]
   TTA Flip: True


Predicting: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 8310/8310 [2:22:01<00:00,  1.03s/it]


üíæ Writing predictions to /content/predict_txt/images2_hybrid_wbf_tta.txt...
‚úÖ Predictions saved!
   Total detections: 7126
   Images with detections: 2720
   Average detections per image: 2.62

‚úÖ Batch 2 complete!





## 11. Merge Results and Save

In [15]:
# Merge two batch result files
file1 = "/content/predict_txt/images1_hybrid_wbf_tta.txt"
file2 = "/content/predict_txt/images2_hybrid_wbf_tta.txt"
output = "/content/predict_txt/submission_hybrid_wbf_optimized.txt"

print("=" * 80, "üîó Merging batch results...", "=" * 80, sep="\n")

# Concatenate the batch files in order; a missing batch is reported and skipped.
with open(output, "w", encoding="utf-8") as fout:
    for f in [file1, file2]:
        if not os.path.exists(f):
            print(f"  ‚ö†Ô∏è Not found: {f}")
            continue
        with open(f, "r", encoding="utf-8") as fin:
            for line in fin:
                fout.write(line)
        print(f"  ‚úÖ Merged: {f}")

print(f"\n‚úÖ Final submission file: {output}")

# Count statistics
def count_predictions(file_path):
    """Count detections and distinct images in a prediction text file.

    Each non-blank line is treated as one detection whose first
    whitespace-separated field is the image name. Blank lines are ignored
    rather than raising ``IndexError``.

    Args:
        file_path: path of the prediction file.

    Returns:
        ``(num_detections, num_images)``; ``(0, 0)`` when the file does not exist.
    """
    if not os.path.exists(file_path):
        return 0, 0

    num_detections = 0
    unique_images = set()
    # Stream the file line by line; it is written as UTF-8 by the merge step.
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            fields = line.split()
            if not fields:
                continue  # blank or whitespace-only line carries no detection
            num_detections += 1
            unique_images.add(fields[0])

    return num_detections, len(unique_images)

# Summarize the merged submission file
total_boxes, total_images = count_predictions(output)

print("\n" + "=" * 80, "üìä Final Statistics (Optimized):", "=" * 80, sep="\n")
print(
    f"  Total detections: {total_boxes}",
    f"  Images with detections: {total_images}",
    sep="\n",
)
# The average is only defined when at least one image has a detection.
if total_images > 0:
    print(f"  Average detections per image: {total_boxes/total_images:.2f}")
print("=" * 80)

üîó Merging batch results...
  ‚úÖ Merged: /content/predict_txt/images1_hybrid_wbf_tta.txt
  ‚úÖ Merged: /content/predict_txt/images2_hybrid_wbf_tta.txt

‚úÖ Final submission file: /content/predict_txt/submission_hybrid_wbf_optimized.txt

üìä Final Statistics (Optimized):
  Total detections: 14282
  Images with detections: 5527
  Average detections per image: 2.58


## 12. Download Submission File

In [16]:
# Download the final submission file
# Uses the Colab-only files helper on the merged submission file's path.
from google.colab import files

print("üì• Downloading submission file (Optimized)...")
files.download('/content/predict_txt/submission_hybrid_wbf_optimized.txt')
# NOTE(review): printed unconditionally after the call returns; it does not
# verify that the browser actually saved the file.
print("‚úÖ Download complete!")

üì• Downloading submission file (Optimized)...


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

‚úÖ Download complete!
