# Conformal Object Detection: first steps

This tutorial should get you started doing **Conformal Object Detection (COD)** with the [`cods`](https://github.com/leoandeol/cods) library.

For more information on the methods implemented in CODS, see the papers: 
- [Andéol et al. 2023: Confident Object Detection via Conformal Prediction and Conformal Risk Control](https://proceedings.mlr.press/v204/andeol23a.html)
- [Angelopoulos et al. 2022: Conformal Risk Control](https://arxiv.org/abs/2208.02814)
- [Li et al. 2022: Towards PAC Multi-Object Detection and Tracking](https://arxiv.org/abs/2204.07482)
- [Bates et al. 2021: Risk Controlling Prediction Sets](https://dl.acm.org/doi/abs/10.1145/3478535)


### Get started
1. Download the MS-COCO dataset: 
    - https://cocodataset.org/
2. Download DETR: automatically via Pytorch hub: https://pytorch.org/hub/
    - source: https://github.com/facebookresearch/detr

### Contents
What we will be doing:
1. Setup inference [⤵](#Setup-inferences)
    - load predictor (DETR) pretrained on COCO
    - Split the validation into: calibration & validation dataset
2. Run inferences on these datasets [⤵](#Setup-inferences)
    - Save predictions to disk: faster than re-predict for every test
3. Test Conformal Prediction !

In [1]:
%load_ext autoreload
%autoreload 2


In [2]:
from cods.od.data import MSCOCODataset
from cods.od.models import YOLOModel, DETRModel
import logging
import os

os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"  # see issue #152
os.environ["CUDA_VISIBLE_DEVICES"] = (
    "1"  # chose the GPU. If only one, then "0"
)

logging.getLogger().setLevel(logging.INFO)

## 2. Setup inferences [🔝](#conformal-object-detection-first-steps)

In [3]:
# set [COCO_PATH] to the directory to your local copy of the COCO dataset
COCO_PATH = "/datasets/shared_datasets/coco/"

data = MSCOCODataset(root=COCO_PATH, split="val")

In [4]:
calibration_ratio = (
    0.5  # set 0.5 to use 50% for calibration and 50% for testing
)

use_smaller_subset = True  # TODO: Temp

if use_smaller_subset:
    data_cal, data_val = data.split_dataset(
        calibration_ratio, shuffle=False, n_calib_test=800
    )
else:
    data_cal, data_val = data.split_dataset(calibration_ratio, shuffle=False)

# model and weights are downloaded from https://github.com/facebookresearch/detr
model = DETRModel(model_name="detr_resnet50", pretrained=True, device="cpu")
# model = YOLOModel(model_name="yolov8x.pt", pretrained=True)


print(f"{len(data) = }")
print(f"{len(data_cal) = }")
print(f"{len(data_val) = }")

Using cache found in /home/leo.andeol/.cache/torch/hub/facebookresearch_detr_main


len(data) = 5000
len(data_cal) = 400
len(data_val) = 400


Run inferences:
- the first time, run inferences and save them disk
- if predictions are saved on disk, load them

In [5]:
preds_cal = model.build_predictions(
    data_cal,
    dataset_name="mscoco",
    split_name="cal",
    batch_size=12,
    collate_fn=data._collate_fn,  # TODO: make this a default for COCO
    shuffle=False,
    force_recompute=False,  # False,
    deletion_method="nms",
)
preds_val = model.build_predictions(
    data_val,
    dataset_name="mscoco",
    split_name="test",
    batch_size=12,
    collate_fn=data._collate_fn,
    shuffle=False,
    force_recompute=False,  # False,
    deletion_method="nms",
)

Loading predictions from ./saved_predictions/2c92d1aaa0cc2db665dc992cc2c004015b949d723cda785c3c3a140ebe8a808b.pkl
Predictions already exist, loading them...
Loading predictions from ./saved_predictions/27b7022a01eb9f119e53d0e6c2c7e9a25a4444c25cea01599ce79e2a14f06cd0.pkl
Predictions already exist, loading them...


# New Tests

In [None]:
from cods.od.cp import ODConformalizer
from cods.od.loss import (
    PixelWiseRecallLoss,
    ClassificationLossWrapper,
    ODBinaryClassificationLoss,
)
from cods.od.data import ODParameters, ODConformalizedPredictions
from cods.od.utils import (
    generalized_iou,
    compute_risk_image_level,
    match_predictions_to_true_boxes,
    apply_margins,
)
import numpy as np
from tqdm import tqdm
import torch
import matplotlib.pyplot as plt

scs = np.linspace(0, 1, 20)
lbd_loc = 0
lbd_cls = 0

pred_cls = preds_cal.pred_cls
true_cls = preds_cal.true_cls
pred_boxes = preds_cal.pred_boxes
true_boxes = preds_cal.true_boxes

loc_loss_f = PixelWiseRecallLoss()
_loss = ODBinaryClassificationLoss()
cls_loss_f = ClassificationLossWrapper(_loss)

cls_lbds = 1 - np.logspace(-6, 1, 20)
loc_lbds = np.linspace(0, 300, 20)

results_loc = {}
results_cls = {}
# TODO(leo): métrique monotonization
# TODO(leo): générer pdf avec des paires images, et courbes en loc et cls

for cls_lbd, loc_lbd in tqdm(zip(cls_lbds, loc_lbds)):
    res_loc = []
    res_cls = []
    for confi in scs:
        match_predictions_to_true_boxes(
            preds_cal,
            distance_function="hausdorff",
            verbose=False,
            overload_confidence_threshold=confi,
        )

        conf_boxes = apply_margins(preds_cal.pred_boxes, loc_lbd)
        conf_cls = [
            [torch.where(cli >= 1 - cls_lbd)[0] for cli in cl]
            for cl in preds_cal.pred_cls
        ]

        tmp_parameters = ODParameters(
            global_alpha=None,
            confidence_threshold=confi,
            predictions_id=preds_cal.unique_id,
        )
        tmp_conformalized_predictions = ODConformalizedPredictions(
            predictions=preds_cal,
            parameters=tmp_parameters,
            conf_boxes=conf_boxes,
            conf_cls=conf_cls,
        )
        # TODO(leo): cannot do that with with object level or can I ?
        # TODO(leoandeol): classwise risk ????
        loc_loss = compute_risk_image_level(
            tmp_conformalized_predictions,
            preds_cal,
            loc_loss_f,
            return_list=True,
        )

        cls_loss = compute_risk_image_level(
            tmp_conformalized_predictions,
            preds_cal,
            cls_loss_f,
            return_list=True,
        )

        res_loc.append(loc_loss)
        res_cls.append(cls_loss)

    results_loc[f"{loc_lbd}"] = res_loc
    results_cls[f"{cls_lbd}"] = res_cls

# Mean

In [None]:
import matplotlib.pyplot as plt

x = scs
for k, v in results_loc.items():
    plt.plot(x, [vv.mean() for vv in v], label=f"{int(float(k))}")
plt.xlabel(f"Confidence Threshold")
plt.ylabel(f"Loss value")
plt.title("Localization")
plt.yscale("log")
plt.legend()

In [None]:
import matplotlib.pyplot as plt

x = scs
for k, v in results_cls.items():
    plt.plot(x, [vv[2].mean() for vv in v], label=f"{float(k):.5f}")
plt.xlabel(f"Confidence Threshold")
plt.ylabel(f"Loss value")
plt.title("Classification")
plt.yscale("log")
plt.legend()

# Per Sample

In [None]:
idx = 1

In [None]:
for k, v in results_loc.items():
    plt.plot(scs, [vv[idx] for vv in v], label=f"{int(float(k))}")
plt.xlabel(f"Confidence Threshold")
plt.ylabel(f"Loss value")
plt.title("Localization")
plt.yscale("log")
plt.legend()

In [None]:
for k, v in results_cls.items():
    plt.plot(scs, [vv[idx] for vv in v], label=f"{int(float(k))}")
plt.xlabel(f"Confidence Threshold")
plt.ylabel(f"Loss value")
plt.title("Classification")
plt.yscale("log")
plt.legend()

# Creating a PDF

In [19]:
def generate_plot(idx, plot_type):
    """
    Generates a matplotlib plot for the given image.

    Args:
        image: PIL Image object.
        plot_type: Type of plot to generate (1 or 2).

    Returns:
        Matplotlib figure object.
    """
    import matplotlib.pyplot as plt
    import numpy as np

    fig, ax = plt.subplots(figsize=(4, 3))
    if plot_type == "lox":
        for k, v in results_loc.items():
            ax.plot(scs, [vv[idx] for vv in v], label=f"{int(float(k))}")
        ax.set_xlabel("Confidence Threshold")  # Correct method
        ax.set_ylabel("Loss Value")  # Correct method
        ax.set_title("Localization")  # Correct method
        ax.set_yscale("log")  # Set y-axis to log scale
        ax.legend()
    elif plot_type == "cls":
        for k, v in results_cls.items():
            ax.plot(scs, [vv[idx] for vv in v], label=f"{float(k):.4f}")
        ax.set_xlabel("Confidence Threshold")  # Correct method
        ax.set_ylabel("Loss Value")  # Correct method
        ax.set_title("Classification")  # Correct method
        ax.set_yscale("log")
        ax.legend()
    return fig

In [20]:
from reportlab.pdfgen import canvas
from reportlab.lib.pagesizes import letter
from reportlab.lib.utils import ImageReader
from reportlab.pdfbase.ttfonts import TTFont
from reportlab.pdfbase import pdfmetrics
from reportlab.lib.units import inch
import matplotlib.pyplot as plt
import io
import math
from PIL import Image
from tqdm import tqdm


def create_dataset_pdf_with_plots(
    predictions,
    output_filename="monotonicity_plot_with_images.pdf",
):
    """
    Creates a PDF displaying rows of three images (one loaded image and two plots) on each page.

    Args:
        predictions: An object with attribute `image_paths` containing paths to the images.
        generate_plot: A function that generates matplotlib plots based on image data.
        output_filename: Name of the output PDF file.
    """

    # Set up the PDF canvas
    c = canvas.Canvas(output_filename, pagesize=letter)
    width, height = letter

    # Register a default font
    pdfmetrics.registerFont(TTFont("Monospace", "monospace.medium.ttf"))

    # Calculate positions for the three images per row
    row_height = (height - 2 * inch) / 2
    image_width = (width - 4 * inch) / 3  # One row with three columns
    image_height = row_height * 0.8  # Reserve some space for titles

    # Positions for three items per row
    positions = [
        (1 * inch, height - 1.5 * inch - row_height),  # Top row: Image
        (
            1 * inch + image_width + 1 * inch,
            height - 1.5 * inch - row_height,
        ),  # Top row: Plot 1
        (
            1 * inch + 2 * (image_width + 1 * inch),
            height - 1.5 * inch - row_height,
        ),  # Top row: Plot 2
        (1 * inch, height - 2.5 * inch - 2 * row_height),  # Bottom row: Image
        (
            1 * inch + image_width + 1 * inch,
            height - 2.5 * inch - 2 * row_height,
        ),  # Bottom row: Plot 1
        (
            1 * inch + 2 * (image_width + 1 * inch),
            height - 2.5 * inch - 2 * row_height,
        ),  # Bottom row: Plot 2
    ]

    image_count = 0

    for i, path in tqdm(enumerate(predictions.image_paths)):
        # Load image
        img = Image.open(path)

        # Convert image to bytes
        img_byte_arr = io.BytesIO()
        img.save(img_byte_arr, format="JPEG")
        img_byte_arr = img_byte_arr.getvalue()

        # Create an ImageReader object
        img_reader = ImageReader(io.BytesIO(img_byte_arr))

        # Generate two plots
        plot1 = generate_plot(i, plot_type="loc")
        plot2 = generate_plot(i, plot_type="cls")

        # Convert plots to bytes
        plot1_buf = io.BytesIO()
        plot1.savefig(plot1_buf, format="PNG", bbox_inches="tight")
        plot1_buf.seek(0)
        plot1_reader = ImageReader(plot1_buf)

        plot2_buf = io.BytesIO()
        plot2.savefig(plot2_buf, format="PNG", bbox_inches="tight")
        plot2_buf.seek(0)
        plot2_reader = ImageReader(plot2_buf)

        # Get positions
        pos_image = positions[(image_count % 6) // 3 * 3]
        pos_plot1 = positions[(image_count % 6) // 3 * 3 + 1]
        pos_plot2 = positions[(image_count % 6) // 3 * 3 + 2]

        # Draw the images and plots
        c.drawImage(
            img_reader,
            pos_image[0],
            pos_image[1],
            width=image_width,
            height=image_height,
        )
        c.drawImage(
            plot1_reader,
            pos_plot1[0],
            pos_plot1[1],
            width=image_width,
            height=image_height,
        )
        c.drawImage(
            plot2_reader,
            pos_plot2[0],
            pos_plot2[1],
            width=image_width,
            height=image_height,
        )

        # Add title for the loaded image
        title = os.path.basename(path)
        c.setFont("Monospace", 8)
        c.drawString(pos_image[0], pos_image[1] - 0.2 * inch, title)

        image_count += 1

        # Start a new page if six images (two rows) are filled
        if image_count % 6 == 0:
            c.showPage()

    # Save the PDF
    c.save()

    print(
        f"PDF created with {image_count} images (each with two plots) on {math.ceil(image_count / 6)} pages."
    )

In [None]:
create_dataset_pdf_with_plots(
    preds_cal,
    output_filename="monotonicity_HAUSDORFF.pdf",
);

  fig, ax = plt.subplots(figsize=(4, 3))
  ax.set_yscale("log")
400it [01:51,  3.59it/s]


PDF created with 400 images (each with two plots) on 67 pages.
