# 5. Ensemble Methods

This notebook combines predictions from multiple trained models using:
- **Weighted Boxes Fusion (WBF)**: the primary ensemble method.
- **Grid search** over WBF parameters (IoU threshold, skip threshold, model weights).
- **Per-class COCO evaluation** of ensemble results.

**Prerequisites:** Run `1_setup.ipynb` first, and make sure every model you want to ensemble has evaluation results (`coco_instances_results.json`) on disk.

## 5.1 Imports

In [None]:
import os
import json

from pycocotools.coco import COCO

import config
from utils.ensemble import (
    load_predictions,
    run_wbf,
    evaluate_ensemble,
    save_ensemble,
    grid_search_wbf,
)

## 5.2 Configure Ensemble

Select the models whose predictions will be combined and the test set that serves as ground truth.

In [None]:
# ===================== CONFIGURE ENSEMBLE =====================

# --- Ground truth ---
DATASET_SOURCE = "agar"    # 'agar' or 'roboflow'
SUBSET = "total"           # For AGAR: 'total', 'bright', 'dark', 'vague', 'lowres'

# --- Models to ensemble (keys from config) ---
# Example for AGAR total subset:
ENSEMBLE_MODEL_KEYS = [
    "total_faster_rcnn_R50",
    "total_faster_rcnn_R101",
    "total_retinanet_R50",
    "total_retinanet_R101",
    "total_mask_rcnn_R50",
    "total_mask_rcnn_R101",
]
MODEL_SOURCE = "agar"  # 'agar' or 'roboflow'

# Subfolder containing coco_instances_results.json for each model
# (typically '0', '2', or 'test' depending on evaluation run)
PREDICTIONS_SUBFOLDER = "0"

# ====================================================================

# Resolve paths
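if DATASET_SOURCE == "agar":
    gt_path = config.AGAR_DATASETS[SUBSET]["test"]
elif DATASET_SOURCE == "roboflow":
    gt_path = config.ROBOFLOW_DATASETS["curated"]["test"]
else:
    # Fail early instead of hitting a NameError on gt_path further down
    raise ValueError(f"Unknown DATASET_SOURCE: {DATASET_SOURCE!r}")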

prediction_paths = [
    config.get_predictions_path(key, MODEL_SOURCE, PREDICTIONS_SUBFOLDER)
    for key in ENSEMBLE_MODEL_KEYS
]

print(f"Ground truth: {gt_path}")
print(f"\nModel predictions ({len(prediction_paths)}):")
for k, p in zip(ENSEMBLE_MODEL_KEYS, prediction_paths):
    exists = '✓' if os.path.exists(p) else '✗'
    print(f"  [{exists}] {k}: {p}")
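
The listing above only prints a marker for each path. If you would rather stop before loading when a file is absent, a small check along these lines could be used. This is a sketch only: `find_missing` is a hypothetical helper, not part of `config` or `utils.ensemble`, and the demo uses a temporary directory instead of the real prediction paths.

```python
import os
import tempfile


def find_missing(keys, paths):
    """Return the model keys whose prediction file does not exist."""
    return [k for k, p in zip(keys, paths) if not os.path.isfile(p)]


# Demo with one existing and one absent file (made-up model keys).
with tempfile.TemporaryDirectory() as tmp:
    present = os.path.join(tmp, "coco_instances_results.json")
    open(present, "w").close()
    absent = os.path.join(tmp, "missing", "coco_instances_results.json")
    missing = find_missing(["model_a", "model_b"], [present, absent])

print(missing)  # ['model_b']
```

In the notebook this would be called as `find_missing(ENSEMBLE_MODEL_KEYS, prediction_paths)`, raising an error if the returned list is not empty.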

## 5.3 Load Predictions

In [None]:
coco_gt, coco_dts, img_ids = load_predictions(gt_path, prediction_paths)
print(f"Loaded {len(coco_dts)} model predictions over {len(img_ids)} images.")
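
`load_predictions` lives in `utils.ensemble` and its internals are not shown here. For orientation, the sketch below assumes each results file follows the standard COCO detection results format (a JSON list of dicts with `image_id`, `category_id`, `bbox` as `[x, y, width, height]`, and `score`) and shows one way to group such detections per image. `group_by_image` and the sample values are made up for illustration.

```python
import json
from collections import defaultdict


def group_by_image(detections):
    """Group COCO-style detection dicts by their image_id."""
    grouped = defaultdict(list)
    for det in detections:
        grouped[det["image_id"]].append(det)
    return dict(grouped)


# Toy detections; bbox is [x, y, width, height] in pixels.
sample = json.loads("""
[
  {"image_id": 1, "category_id": 0, "bbox": [10.0, 20.0, 30.0, 40.0], "score": 0.9},
  {"image_id": 1, "category_id": 2, "bbox": [50.0, 60.0, 10.0, 10.0], "score": 0.4},
  {"image_id": 7, "category_id": 1, "bbox": [5.0, 5.0, 15.0, 15.0], "score": 0.7}
]
""")

by_image = group_by_image(sample)
print({image_id: len(dets) for image_id, dets in by_image.items()})
```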

## 5.4 Run WBF Ensemble

In [None]:
# WBF parameters (best found via grid search in the thesis)
IOU_THR = 0.75
SKIP_BOX_THR = 0.01
WEIGHTS = [5, 5, 7, 7, 5, 5]  # Per-model weights (match ENSEMBLE_MODEL_KEYS order)
assert len(WEIGHTS) == len(ENSEMBLE_MODEL_KEYS), "Need exactly one weight per model"

ensemble_results = run_wbf(
    coco_gt, coco_dts, img_ids,
    iou_thr=IOU_THR,
    skip_box_thr=SKIP_BOX_THR,
    weights=WEIGHTS,
)

print(f"\nEnsemble produced {len(ensemble_results)} predictions.")
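
To make the three parameters concrete, here is a simplified, single-class sketch of the WBF idea. It is not the code behind `run_wbf`: all function names are hypothetical, boxes are `[x1, y1, x2, y2]`, and details of the full algorithm (per-class handling, alternative confidence modes) are left out. Boxes scoring below `skip_box_thr` are dropped, each remaining score is multiplied by its model weight, boxes whose IoU with an existing fused box exceeds `iou_thr` join that cluster, and each cluster becomes one score-weighted average box.

```python
def iou(a, b):
    """Intersection over union of two [x1, y1, x2, y2] boxes."""
    inter_w = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    inter_h = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = inter_w * inter_h
    union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter
    return inter / union if union > 0 else 0.0


def fuse(cluster):
    """Score-weighted average box of a cluster of (box, weighted_score) pairs."""
    total = sum(score for _, score in cluster)
    box = [sum(b[i] * s for b, s in cluster) / total for i in range(4)]
    return box, total


def simple_wbf(boxes_per_model, scores_per_model, weights, iou_thr, skip_box_thr):
    # Pool boxes from all models, dropping low scores and applying model weights.
    pooled = []
    for boxes, scores, weight in zip(boxes_per_model, scores_per_model, weights):
        for box, score in zip(boxes, scores):
            if score >= skip_box_thr:
                pooled.append((box, score * weight))
    pooled.sort(key=lambda item: item[1], reverse=True)

    # Assign each box to the first cluster whose fused box overlaps enough.
    clusters = []
    for box, score in pooled:
        for cluster in clusters:
            if iou(fuse(cluster)[0], box) > iou_thr:
                cluster.append((box, score))
                break
        else:
            clusters.append([(box, score)])

    # Normalise by the total weight so lone detections are down-weighted.
    norm = sum(weights)
    return [(fuse(c)[0], fuse(c)[1] / norm) for c in clusters]


boxes = [[[0, 0, 10, 10]], [[1, 1, 11, 11], [50, 50, 60, 60]]]
scores = [[0.8], [0.6, 0.005]]
fused = simple_wbf(boxes, scores, weights=[1, 1], iou_thr=0.5, skip_box_thr=0.01)
print(fused)
```

In this toy case the two overlapping boxes have an IoU of about 0.68, so they merge at `iou_thr=0.5` but would stay separate at the 0.75 used above. The third box falls below `skip_box_thr` and is discarded.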

## 5.5 Evaluate Ensemble

In [None]:
print("=" * 60)
print("Overall evaluation:")
print("=" * 60)
evaluate_ensemble(coco_gt, ensemble_results)

print("\n" + "=" * 60)
print("Per-class evaluation:")
print("=" * 60)
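# Assumes category ids 0, 1, 2 map to these names in this order;
# check the "categories" list of the ground-truth file if results look off.
class_names = ["S. aureus", "P. aeruginosa", "E. coli"]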
for i, name in enumerate(class_names):
    print(f"\n--- {name} (category_id={i}) ---")
    evaluate_ensemble(coco_gt, ensemble_results, category_ids=[i])
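
The loop above relies on the position in `class_names` matching the `category_id` in the annotations. One way to verify that is to read the id-to-name mapping from the ground-truth JSON itself. The sketch below uses a made-up in-memory dict in place of the real file; `category_names` is a hypothetical helper and the ids shown are illustrative only.

```python
def category_names(gt):
    """Map category id -> name from a COCO-format ground-truth dict."""
    return {cat["id"]: cat["name"] for cat in gt["categories"]}


# Stand-in for json.load(open(gt_path)); ids here are illustrative.
toy_gt = {
    "categories": [
        {"id": 0, "name": "S. aureus"},
        {"id": 1, "name": "P. aeruginosa"},
        {"id": 2, "name": "E. coli"},
    ]
}

id_to_name = category_names(toy_gt)
for cat_id, name in sorted(id_to_name.items()):
    print(f"category_id={cat_id}: {name}")
```

If the real file uses different ids (COCO files often start at 1), iterate over this mapping instead of `enumerate(class_names)`.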

## 5.6 Save Ensemble Results

In [None]:
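# Join weights with '-' so the filename has no brackets, commas, or spaces
weights_tag = "-".join(str(w) for w in WEIGHTS)
output_name = f"{SUBSET}_{IOU_THR}_{SKIP_BOX_THR}_{weights_tag}_ensemble.json"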
output_path = os.path.join(config.RESULTS_DIR, output_name)
save_ensemble(ensemble_results, output_path)

## 5.7 Grid Search (Optional)

Search over combinations of IoU thresholds, skip-box thresholds, and per-model weights. Each weight list must match the length and order of `ENSEMBLE_MODEL_KEYS`.
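
To estimate the cost before running, it can help to count the configurations. The sketch below assumes that `grid_search_wbf` evaluates the full Cartesian product of the three lists, which is not confirmed by this notebook, and uses the example values from the cell that follows.

```python
from itertools import product

iou_thresholds = [0.5, 0.75]
skip_box_thresholds = [0.01, 0.05]
weight_options = [
    [1, 1, 1, 1, 1, 1],
    [5, 5, 7, 7, 5, 5],
    [7, 7, 1, 1, 7, 7],
]

# Every (iou_thr, skip_box_thr, weights) triple in the grid.
grid = list(product(iou_thresholds, skip_box_thresholds, weight_options))
print(len(grid))  # 12

for iou_thr, skip_box_thr, weights in grid[:3]:
    print(iou_thr, skip_box_thr, weights)
```

Under that assumption the example grid means 2 × 2 × 3 = 12 WBF runs, each followed by an evaluation, so the runtime grows quickly as options are added.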

In [None]:
# Uncomment to run grid search
# grid_results = grid_search_wbf(
#     coco_gt, coco_dts, img_ids,
#     iou_thresholds=[0.5, 0.75],
#     skip_box_thresholds=[0.01, 0.05],
#     weight_options=[
#         [1, 1, 1, 1, 1, 1],
#         [5, 5, 7, 7, 5, 5],
#         [7, 7, 1, 1, 7, 7],
#     ],
#     output_dir=os.path.join(config.RESULTS_DIR, "grid_search"),
# )