In [None]:
from cods.od.cp import ODConformalizer
from cods.od.data import MSCOCODataset
from cods.od.models import DETRModel
from cods.od.visualization import plot_preds, create_pdf_with_plots


# --- Data: COCO val, split (unshuffled) into calibration / held-out halves ---
COCO_PATH = "/datasets/shared_datasets/coco/"
data = MSCOCODataset(root=COCO_PATH, split="val")
calibration_ratio = 0.5
data_cal, data_val = data.split_dataset(calibration_ratio, shuffle=False)

# --- Model: pretrained DETR (ResNet-50 backbone) on CPU ---
model = DETRModel(model_name="detr_resnet50", pretrained=True, device="cpu")

# Both prediction runs share every setting except the data subset and the
# split name. NOTE(review): force_recompute=False presumably lets
# build_predictions reuse previously stored predictions — confirm in cods.
shared_prediction_kwargs = dict(
    dataset_name="mscoco",
    batch_size=12,
    collate_fn=data._collate_fn,  # TODO: make this a default for COCO
    shuffle=False,
    force_recompute=False,
    deletion_method="nms",
    filter_preds_by_confidence=1e-3,
)
preds_cal = model.build_predictions(
    data_cal, split_name="cal", **shared_prediction_kwargs
)
preds_val = model.build_predictions(
    data_val, split_name="test", **shared_prediction_kwargs
)

# NOTE(review): `row2` is not defined anywhere in this notebook as exported.
# It is presumably one row of an external results table (indexable by the
# column names below) and must exist before this cell runs, otherwise the
# lookups raise NameError. Define it explicitly here or in an earlier cell.
conf = ODConformalizer(
    guarantee_level="image",
    matching_function=row2["Matching Function"],
    multiple_testing_correction=None,
    confidence_method=row2["Confidence Method"],
    localization_method=row2["Localization Method"],
    localization_prediction_set=row2["Localization Prediction Set"],
    classification_method="binary",
    classification_prediction_set=row2["Classification Prediction Set"],
    backend="auto",
    optimizer="binary_search",
)

# Calibrate on the calibration half with the per-component risk levels.
calibration_alphas = {
    "alpha_confidence": row2["Confidence Alpha"],
    "alpha_localization": row2["Localization Alpha"],
    "alpha_classification": row2["Classification Alpha"],
}
parameters = conf.calibrate(preds_cal, **calibration_alphas)

# Apply the calibrated parameters to the held-out half and evaluate them.
conformal_preds = conf.conformalize(preds_val, parameters=parameters)
results_val = conf.evaluate(
    preds_val,
    parameters=parameters,
    conformalized_predictions=conformal_preds,
    include_confidence_in_global=False,
)

In [None]:
# Mean of the localization set sizes reported by `conf.evaluate` on the held-out split.
results_val.localization_set_sizes.mean()

tensor(1.0915)

In [None]:
# Mean of the localization coverages reported by `conf.evaluate` on the held-out split.
# NOTE(review): the recorded output for this cell (~0.10) looks far below a typical
# conformal coverage target — confirm the alpha values taken from `row2` and that
# this is the intended coverage quantity.
results_val.localization_coverages.mean()

tensor(0.1020)

In [None]:
# Write visualizations of the raw and conformalized held-out predictions to a PDF,
# labelling classes with the dataset's index-to-name mapping.
visualization_pdf_path = "visualizations.pdf"
create_pdf_with_plots(
    preds_val,
    conformal_preds,
    idx_to_label=data.NAMES,
    output_pdf=visualization_pdf_path,
)