### Load dependencies 
We start by importing all necessary python packages, and the functionalities implemented in the utils folder of this repo.

In [None]:
import random

from utils.inference_POLO import *
from utils.model_eval import *
from utils.processing_utils import *
from sahi.predict import predict
from utils.data_params import parameters

### Insert dataset name, output directories, etc. (**TODO**)
Here we specify the dataset we want to test and set some paths. For the dataset, please use the abbreviations given in the paper, e.g., `data_set = "BK-L23"`. This will load the corresponding IoU-/DoR-threshold, patch overlap, etc. in the next cell (values can also be found in the paper). Please set the following variables:
- `data_set`: Name (abbreviation) of the dataset.
- `imgs_dir`: Path to the directory containing the test images of the dataset.
- `output_dir`: Path to the directory outputs can be stored in (must exist, is NOT going to be created).
- `mdl_path`: Path to the trained model file (.pt).
- `is_pseudo`: Set to true if you are using a YOLOv8_p model.
- `device`: Device on which to load and run the model. E.g., `"0"`. Pass `"cpu"` if you do not have a GPU. 

In [None]:
# --- User configuration: fill in every value below before running the notebook. ---

# Dataset abbreviation; used as the key into `parameters` in the next cell.
data_set = ""
# Directory containing the test images; the annotation file and tiling folder
# paths are derived from it further below.
imgs_dir = ""
# Existing directory that receives all inference and evaluation outputs.
output_dir = ""
# Path to the trained model weights (.pt).
mdl_path = ""
# When True, the "iou_thresh_pseudo" entry is used instead of "iou_thresh".
is_pseudo = False
# Device on which the model runs: a GPU index given as a string (e.g. "0").
device = "0"


For this next cell, no input is required - just run it. 

In [None]:
# Fail early with a readable message instead of a bare KeyError('') when the
# dataset abbreviation was left empty or mistyped in the configuration cell.
if data_set not in parameters:
    raise KeyError(
        f"Unknown data_set {data_set!r}; expected one of: "
        + ", ".join(map(str, parameters))
    )

# All dataset-specific settings live under a single entry of `parameters`.
ds_params = parameters[data_set]

# Width/height of the tiles each image is split into for inference.
patch_dims = {"width": 640, "height": 640}
# Overlap between neighbouring tiles (passed on as the patch overlap).
ovrlp = ds_params["ovrlp"]
# Threshold forwarded to the POLO inference routine as `dor_thresh`.
dor_thresh = ds_params["dor_thresh"]
# Pseudo-box models use their own IoU threshold; the pseudo entry is only
# looked up when it is actually needed.
iou_thresh = ds_params["iou_thresh_pseudo"] if is_pseudo else ds_params["iou_thresh"]
# Per-class radii, keyed by class ID.
radii = ds_params["radii"]
# Mapping from class ID to class name.
classID2name = ds_params["classID2name"]
# Image file extension/format of the dataset's images.
img_format = ds_params["img_format"]

### Set some more paths (no input required)
Here the path to the file containing the test set annotations, and to the tiling folder are set. The annotations file is needed to compute evaluation metrics and counting errors after inference, and the tiling folder is where the patches extracted from each image are going to be stored. We also set the random seed for reproducibility. 

In [None]:
# Fix the seed of Python's `random` module so repeated runs are reproducible.
random.seed(0)

# Annotation file of the test set, expected directly inside the image directory.
ann_file = f"{imgs_dir}/test_annotations.json"

# Folder inside the image directory reserved for the extracted image tiles.
tiling_dir = f"{imgs_dir}/tiles"

### Define Task (**TODO**)
Here we specify what model we will be using. Set the `task` variable to `"locate"` if you are working with a POLO model, use `"detect"` otherwise.

In [None]:
task = "detect"     # TODO: set to "locate" for POLO models, keep "detect" for bounding-box models

### Run tiled inference (no input required)
This is where we run the actual inference. For bounding box models, we use the `SAHI` library; for POLO, we use the methods implemented in `utils/inference_POLO.py`. `coco_file_path` will point to a JSON file required by `SAHI` to run tiled inference for bounding box models.

In [None]:
if task == "locate":
    # POLO (point) model: tiled inference via the project's own implementation.
    # Only EW-IL22 passes explicit box dimensions (one square per class, sized
    # by that class's radius); every other dataset passes None.
    if data_set == "EW-IL22":
        bx_dims = {cid: {"width": r, "height": r} for cid, r in radii.items()}
    else:
        bx_dims = None

    run_tiled_inference_POLO(
        model=mdl_path,
        class_ids=list(radii),
        imgs_dir=imgs_dir,
        img_files_ext=img_format,
        patch_dims=patch_dims,
        patch_overlap=ovrlp,
        output_dir=output_dir,
        dor_thresh=dor_thresh,
        radii=radii,
        ann_file=ann_file,
        ann_format="BX_WH",
        box_dims=bx_dims,
    )
else:
    # Bounding-box model: SAHI needs a COCO-style dataset file describing the
    # images and categories before it can run sliced prediction.
    categories = [{"id": cid, "name": cname} for cid, cname in classID2name.items()]
    coco_file_path = make_coco_file(imgs_dir=imgs_dir, categories=categories)

    # Only a bare GPU index ("0", "1", ...) gets the "cuda:" prefix. Values
    # such as "cpu" or an already-qualified "cuda:0" are passed through as-is;
    # blindly prefixing them would yield invalid strings like "cuda:cpu".
    device_str = str(device)
    sahi_device = f"cuda:{device_str}" if device_str.isdigit() else device_str

    # JE-TL19 is post-processed with plain NMS on IoU; all other datasets use
    # greedy non-maximum merging on IoS.
    use_nms = data_set == "JE-TL19"

    predict(
        model_type="yolov8",
        model_path=mdl_path,
        model_device=sahi_device,
        source=imgs_dir,
        slice_height=patch_dims["height"],
        slice_width=patch_dims["width"],
        overlap_height_ratio=ovrlp,
        overlap_width_ratio=ovrlp,
        postprocess_match_threshold=iou_thresh,
        postprocess_type="NMS" if use_nms else "GREEDYNMM",
        postprocess_match_metric="IOU" if use_nms else "IOS",
        dataset_json_path=coco_file_path,
        project=output_dir,
        name="output_SAHI",
        novisual=True,
        verbose=0,
    )


### Compute Evaluation metrics (no input required)
This cell evaluates the outputs of the previous cell. It will generate a number of files:
- `count_diffs_img_lvl.xlsx`: Excel sheet containing the difference between predicted and ground truth count for each image.
- `counts_gt_pred_*.png`: Plot of predicted vs. ground truth count for class `*`.
- `counts_total.json`: Predicted counts summed over all images.
- `em.json`: Evaluation metrics.
- `errors_img_lvl.json`: Counting metrics.
- `F1_curve.png`: F1 score plotted against the confidence threshold.
- `P_curve.png`: Precision plotted against the confidence threshold.
- `R_curve.png`: Recall plotted against the confidence threshold.

In [None]:
if task == "detect":
    # SAHI writes its raw predictions to result.json; they are read back here
    # using the COCO file created in the inference cell.
    # Only EW-IL22 passes fixed 50x50 box dimensions; otherwise None is passed.
    if data_set == "EW-IL22":
        bx_dims = {cid: {"width": 50, "height": 50} for cid in radii}
    else:
        bx_dims = None

    read_output_SAHI(
        out_json_SAHI=f"{output_dir}/output_SAHI/result.json",
        dataset_json_SAHI=coco_file_path,
        class_ids=list(classID2name),
        iou_thresh=iou_thresh,
        ann_file=ann_file,
        ann_format="BX_WH",
        box_dims=bx_dims,
        output_dir=output_dir,
    )

# Image-level counting errors: ground-truth counts vs. predicted counts.
compute_errors_img_lvl(
    gt_counts_dir=f"{imgs_dir}/image_counts",
    pred_counts_dir=f"{output_dir}/detections",
    class_ids=list(classID2name),
    output_dir=output_dir,
)

# Evaluation metrics computed from the stored detections.
compute_em_img_lvl(
    preds_dir=f"{output_dir}/detections",
    class_id2name=classID2name,
    task=task,
    output_dir=output_dir,
)