### Load dependencies 
We start by importing all necessary Python packages, as well as the functionality implemented in the `utils` folder of this repo.

In [1]:
import random

from utils.inference_POLO import *
from utils.model_eval import *
from utils.processing_utils import *
from sahi.predict import predict
from utils.data_params import parameters

### Insert paths, output directories, patch dimensions etc.
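Here we specify the inference parameters, including which model and test set to use, the dimensions of the patches, their overlap, etc. The paths, the model, `is_pseudo` and the device are set by hand in the first cell below. The dataset-specific values (overlap, DoR and IoU thresholds, radii, class names, image format) are then looked up in `utils/data_params.py` for the selected `data_set`. The notes below explain where those stored values come from. Some general info: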
- We use 640x640 patches in the paper. 
- The overlap ratio is the overlap given in Table 2 divided by the patch size, e.g. $\frac{128}{640} = 0.2$.
- The DoR and IoU values can be found in Table 12.
- The radii can be found in the Size Avg. column of Table 1. They need to be passed as a dictionary where `key = class ID`, `value = radius`. The ID of a species is its position in the Species column of Table 1. E.g., for the JE-TL19 dataset, the `radii` dict would look as follows: `{0: 62, 1: 81, 2: 49}`
- `classID2name` is a dictionary mapping class ids to their name. Again, for the JE-TL19 dataset: `classID2name = {0: "Elephant", 1: "Giraffe", 2: "Zebra"}`

In [2]:
# Dataset to evaluate; must be a key of `parameters` (utils/data_params.py), which the next cell indexes with it.
data_set = "EW-IL22"
# The paths below are built from `data_set` so that switching dataset keeps directories and weights consistent.
imgs_dir = f"/home/giacomo/data/test_sets/{data_set}"   #TODO: Insert Path to image directory (as downloaded from Zenodo).
output_dir = f"/home/giacomo/projects/How-to-Minimize-the-Annotation-Cost-of-Aerial-Wildlife-Censuses/outputs/{data_set}/YOLOv8n_p"     #TODO: Insert path to directory you want your output stored in
mdl_path = f"/home/giacomo/data/weights/{data_set}/yolov8n_p_{data_set}.pt"     #TODO: Insert path to the model .pt file
is_pseudo = True       #TODO: Set to True if working with a pseudo model (selects the pseudo IoU threshold)
device = "0"     #TODO: Insert the CUDA device index (e.g. "0") to be used for inference.


In [3]:
# Patch (slice) size used for tiled inference; 640x640 as in the paper.
patch_dims = {"width": 640, "height": 640}

# Dataset-specific settings stored in utils/data_params.py for the selected `data_set`.
ds_params = parameters[data_set]
ovrlp = ds_params["ovrlp"]
dor_thresh = ds_params["dor_thresh"]
# Pseudo models have their own IoU threshold entry.
iou_thresh = ds_params["iou_thresh_pseudo" if is_pseudo else "iou_thresh"]
radii = ds_params["radii"]
classID2name = ds_params["classID2name"]
img_format = ds_params["img_format"]

### Set some more paths (no input required)
Here the path to the file containing the test set annotations, and to the tiling folder are set. The annotations file is needed to compute evaluation metrics and counting errors after inference, and the tiling folder is where the patches extracted from each image are going to be stored. We also set the random seed for reproducibility. 

In [4]:
# Seed Python's `random` module so anything downstream that draws from it is reproducible.
# NOTE(review): only `random` is seeded here; numpy/torch RNGs (if used by the libraries) are not.
random.seed(0)

# Test-set annotations, used to compute evaluation metrics and counting errors after inference.
ann_file = "{}/test_annotations.json".format(imgs_dir)
# Folder intended to hold the patches extracted from each image.
# NOTE(review): not passed to either inference call in this notebook — confirm it is still used.
tiling_dir = "{}/tiles".format(imgs_dir)

### Define Task
Here we specify which type of model we are using. Set the `task` variable to `"locate"` if you are working with a POLO model, and to `"detect"` otherwise.

In [5]:
task = "detect"     #TODO: define task ("locate" for POLO models, "detect" for bounding-box models)
# The inference cell treats anything other than "locate" as a SAHI run, while the evaluation cell only
# reads SAHI output when task == "detect"; reject any other value up front so the two cells agree.
if task not in ("locate", "detect"):
    raise ValueError(f'task must be "locate" or "detect", got {task!r}')

### Run tiled inference (no input required)
This is where we run the actual inference. For bounding box models, we use the `SAHI` library; for POLO, we use the methods implemented in `utils/inference_POLO.py`. `coco_file_path` will point to a JSON file required by `SAHI` to run tiled inference for bounding box models.

In [6]:
if task == "locate":
    # POLO models: tiled point-based inference via run_tiled_inference_POLO (utils/inference_POLO.py).
    # For EW-IL22, per-class boxes with side `radii[cid]` are passed as `box_dims`; other datasets pass None.
    bx_dims = {cid: {"width": radii[cid], "height": radii[cid]} for cid in radii.keys()} if data_set == "EW-IL22" else None
    # NOTE(review): `device` is not forwarded to run_tiled_inference_POLO — confirm which device it runs on.
    run_tiled_inference_POLO(model=mdl_path,
                             class_ids=list(radii.keys()),
                             imgs_dir=imgs_dir,
                             img_files_ext=img_format,
                             patch_dims=patch_dims,
                             patch_overlap=ovrlp,
                             output_dir=output_dir,
                             dor_thresh=dor_thresh,
                             radii=radii,
                             ann_file=ann_file,
                             ann_format="BX_WH",
                             box_dims=bx_dims)
else:
    # Bounding-box models: SAHI needs a COCO-style JSON listing the images and categories.
    categories = [{"id": k, "name": v} for k, v in classID2name.items()]
    coco_file_path = make_coco_file(imgs_dir=imgs_dir, categories=categories)

    # Build SAHI's `model_device`. A bare GPU index such as "0" becomes "cuda:0".
    # "cpu" and explicit CUDA strings ("cuda", "cuda:1") are passed through unchanged.
    # Always prefixing "cuda:" would turn "cpu" into the invalid "cuda:cpu".
    device_str = str(device)
    if device_str == "cpu" or device_str.startswith("cuda"):
        sahi_device = device_str
    else:
        sahi_device = f"cuda:{device_str}"

    # Post-processing of overlapping slice predictions: NMS on IoU for JE-TL19,
    # greedy NMM on IoS for every other dataset.
    is_je_tl19 = data_set == "JE-TL19"

    predict(
        model_type="yolov8",
        model_path=mdl_path,
        model_device=sahi_device,
        source=imgs_dir,
        slice_height=patch_dims["height"],
        slice_width=patch_dims["width"],
        overlap_height_ratio=ovrlp,
        overlap_width_ratio=ovrlp,
        postprocess_match_threshold=iou_thresh,
        postprocess_type="NMS" if is_je_tl19 else "GREEDYNMM",
        postprocess_match_metric="IOU" if is_je_tl19 else "IOS",
        dataset_json_path=coco_file_path,
        project=output_dir,
        name="output_SAHI",
        novisual=True,
        verbose=0
    )


indexing coco dataset annotations...


Loading coco annotations: 100%|██████████| 748/748 [00:00<00:00, 796885.80it/s]
Performing inference on images: 100%|██████████| 748/748 [37:25<00:00,  3.00s/it]


Prediction results are successfully exported to /home/giacomo/projects/How-to-Minimize-the-Annotation-Cost-of-Aerial-Wildlife-Censuses/outputs/EW-IL22/YOLOv8n_p/output_SAHI


### Compute Evaluation metrics (no input required)
This cell evaluates the outputs of the previous cell. It will generate a number of files:
- `count_diffs_img_lvl.xlsx`: Excel sheet containing the difference between predicted and ground truth count for each image.
- `counts_gt_pred_*.png`: Plot of predicted vs. ground-truth counts for class `*`.
- `counts_total.json`: Predicted counts summed over all images.
- `em.json`: Evaluation metrics.
- `errors_img_lvl.json`: Counting metrics.
- `F1_curve.png`: F1 score plotted against the confidence threshold.
- `P_curve.png`: Precision plotted against the confidence threshold.
- `R_curve.png`: Recall plotted against the confidence threshold.

In [None]:
# Folder of per-image detections read by both evaluation helpers below.
detections_dir = f"{output_dir}/detections"

if task == "detect":
    # `box_dims` for read_output_SAHI: a fixed 50x50 box per class for EW-IL22, None for other datasets.
    # NOTE(review): the POLO branch of the inference cell uses the per-class radii instead of 50 — confirm intended.
    if data_set == "EW-IL22":
        bx_dims = {cid: {"width": 50, "height": 50} for cid in radii.keys()}
    else:
        bx_dims = None
    # Reads SAHI's result.json; presumably writes the per-image detections consumed below — confirm in utils.
    read_output_SAHI(
        out_json_SAHI=f"{output_dir}/output_SAHI/result.json",
        dataset_json_SAHI=coco_file_path,
        class_ids=list(classID2name.keys()),
        iou_thresh=iou_thresh,
        ann_file=ann_file,
        ann_format="BX_WH",
        box_dims=bx_dims,
        output_dir=output_dir,
    )

# Image-level counting errors: ground-truth counts vs. predicted detections.
compute_errors_img_lvl(
    gt_counts_dir=f"{imgs_dir}/image_counts",
    pred_counts_dir=detections_dir,
    class_ids=list(classID2name.keys()),
    output_dir=output_dir,
)
# Evaluation metrics; kept as the last expression so the returned per-class dict is displayed.
compute_em_img_lvl(preds_dir=detections_dir, class_id2name=classID2name, task=task, output_dir=output_dir)

  count_diffs_df = pd.concat([count_diffs_df, row_df], ignore_index=True)


{0: {'confusion': 1.9581443025915912,
  'precision': 74.27545069352118,
  'recall': 51.72009140800877,
  'f1': 60.97887281845156,
  'metrics/precision(B)': 0.5706254913610201,
  'metrics/recall(B)': 0.4676139494459104,
  'metrics/mAP50(B)': 0.46977220426815275,
  'metrics/mAP50-95(B)': 0.08319039976647617},
 1: {'confusion': 14.415094404215035,
  'precision': 44.78672984013162,
  'recall': 23.095723009552806,
  'f1': 30.475678123888617,
  'metrics/precision(B)': 0.24782543629956552,
  'metrics/recall(B)': 0.20814663951120163,
  'metrics/mAP50(B)': 0.13745838445591418,
  'metrics/mAP50-95(B)': 0.028411932501594933},
 2: {'confusion': 12.979351289146457,
  'precision': 44.696969629247015,
  'recall': 50.51369854364093,
  'f1': 47.42765215873492,
  'metrics/precision(B)': 0.4098975858010135,
  'metrics/recall(B)': 0.4931506849315068,
  'metrics/mAP50(B)': 0.3054284652313954,
  'metrics/mAP50-95(B)': 0.07388097993809636},
 3: {'confusion': 4.6496609930094746,
  'precision': 82.119021111757