### Load dependencies 
We start by importing the required Python packages and the functionality implemented in the `utils` folder of this repo.

In [None]:
import random

from utils.inference_POLO import *
from utils.model_eval import *
from utils.processing_utils import *

### Insert dataset name, output directories, etc. 
Here we specify the dataset we want to test and set some paths and inference parameters. Below you can find short explanations for each parameter.
- `imgs_dir`: Path to the directory containing the images you want to run inference on.
- `output_dir`: Path to the directory outputs will be stored in (must already exist; it will NOT be created).
- `mdl_path`: Path to the trained model file (.pt).
- `device`: Device on which to load and run the model. E.g., `"0"`. Pass `"cpu"` if you do not have a GPU. 
- `patch_dims`: Dictionary specifying the tile dimensions.
- `ovrlp`: Amount of overlap between tiles, as a fraction of the tile size. E.g., if tiles are 640x640 and `ovrlp` is 0.2, tiles will overlap by 128 pixels.
- `dor_thresh`: DoR threshold to use during inference. Can be different from what was used during training. Will affect performance metrics.
- `radii`: Dictionary mapping category ID to radius in pixels. Can also be different from training and will also affect performance.
- `classID2name`: Dictionary mapping category ID to class name.
- `img_format`: Format extension of the images inference will be performed on.
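
To make the relation between `patch_dims` and `ovrlp` concrete, here is a minimal sketch of how tile positions along one image axis could be derived from a tile size and an overlap fraction. The function `tile_origins` is hypothetical and is not taken from this repo; in particular, placing the last tile flush with the image border is an assumption, and the tiling in `run_tiled_inference_POLO()` may handle edges differently.

```python
def tile_origins(img_size: int, tile_size: int, overlap: float) -> list[int]:
    """Return the start coordinates of tiles along one image axis.

    Hypothetical helper for illustration only; not part of this repo.
    """
    if not 0 <= overlap < 1:
        raise ValueError("overlap must be in [0, 1)")
    if img_size <= tile_size:
        # The image fits into a single tile
        return [0]
    # Overlap in pixels, e.g. 640 * 0.2 = 128
    overlap_px = round(tile_size * overlap)
    # Distance between the starts of two neighbouring tiles, e.g. 640 - 128 = 512
    stride = max(1, tile_size - overlap_px)
    origins = list(range(0, img_size - tile_size + 1, stride))
    # Assumption: add a final tile flush with the border so no pixels are skipped
    if origins[-1] + tile_size < img_size:
        origins.append(img_size - tile_size)
    return origins


print(tile_origins(2000, 640, 0.2))  # [0, 512, 1024, 1360]
```

With 640-pixel tiles and an overlap of 0.2, neighbouring tiles start 512 pixels apart and share 128 pixels. The last tile in this sketch overlaps its neighbour by more than that, because it is shifted back to end at the image border.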

In [None]:
imgs_dir = ""
output_dir = "" 
mdl_path = ""
device = "0"
patch_dims = {"width": 640, "height": 640}  
ovrlp = 0.2  
dor_thresh = 0.3  
radii = {0: 50, 1: 80}
classID2name = {0: "example1", 1: "example2"}
img_format = "jpg"
random.seed(0)

### Optional: Annotation file and format
During tiled inference, one has the option to collect evaluation metrics and counting errors for each image. However, this requires creating a .json file containing a dictionary that maps image names (file name of the image without format extension) to a list of all the ground truth labels in that image. Annotations must be dictionaries containing at least the following two keys: `point` and `category_id`. The value of `point` is expected to be a list storing the x- and y-coordinate of the label, whereas the value of `category_id` must be the integer class ID:

```
{
    "img1": [
                {"point": [x1, y1], "category_id": 0},
                {"point": [x2, y2], "category_id": 2},
                ...
            ],
    "img2": [
                {"point": [x3, y3], "category_id": 1},
                {"point": [x4, y4], "category_id": 1},
                ...
            ],
    ...

}
```
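
If your labels are available as plain coordinate lists, a file in the layout above can be produced with the standard `json` module. The following is a minimal sketch: the helper names `build_annotations` and `save_annotations`, the input layout (a dictionary of `(x, y, class_id)` tuples per image file), and the example coordinates are all assumptions made for illustration, not part of this repo.

```python
import json
from pathlib import Path


def build_annotations(points_per_image):
    """Convert {image_file: [(x, y, class_id), ...]} into the annotation layout.

    Hypothetical helper; the input layout is an assumption for this example.
    """
    annotations = {}
    for img_file, labels in points_per_image.items():
        # Keys are image names without the format extension
        img_name = Path(img_file).stem
        annotations[img_name] = [
            {"point": [x, y], "category_id": int(class_id)}
            for x, y, class_id in labels
        ]
    return annotations


def save_annotations(annotations, path):
    """Write the annotation dictionary to a .json file."""
    with open(path, "w") as f:
        json.dump(annotations, f, indent=4)


# Made-up example labels
example = {
    "img1.jpg": [(120.5, 340.0, 0), (615.0, 92.0, 1)],
    "img2.jpg": [(48.0, 77.5, 1)],
}
annotations = build_annotations(example)
print(json.dumps(annotations, indent=4))
```

Calling `save_annotations(annotations, "annotations.json")` would then write a file whose path can be assigned to `ann_file`.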

In [None]:
# Set to None, or remove both arguments from the call below, if you don't have an annotation file
ann_file = ""
ann_format = "PT_DEFAULT"

### Run tiled inference 


In [None]:

run_tiled_inference_POLO(model=mdl_path, 
                         class_ids=list(radii.keys()),
                         imgs_dir=imgs_dir, 
                         img_files_ext=img_format,
                         patch_dims=patch_dims, 
                         patch_overlap=ovrlp, 
                         output_dir=output_dir,
                         dor_thresh=dor_thresh,
                         radii=radii,
                         ann_file=ann_file,
                         ann_format=ann_format)

### Compute Evaluation metrics
If an annotation file and format were passed to `run_tiled_inference_POLO()` above, this cell can be used to evaluate the inference results. It will generate a number of files:
- `count_diffs_img_lvl.xlsx`: Excel sheet containing the difference between predicted and ground truth count for each image.
- `counts_gt_pred_*.png`: Plot of predicted vs. ground truth count for class `*`.
- `counts_total.json`: Predicted counts summed over all images.
- `em.json`: Evaluation metrics.
- `errors_img_lvl.json`: Counting metrics.
- `F1_curve.png`: F1 score plotted against the confidence threshold.
- `P_curve.png`: Precision plotted against the confidence threshold.
- `R_curve.png`: Recall plotted against the confidence threshold.

Calling `compute_errors_img_lvl()` will compute counting errors per image (MAE, MSE, etc.), whereas `compute_em_img_lvl()` will calculate traditional detection metrics like precision and recall. Note that `compute_errors_img_lvl()` will look for a directory called `image_counts` (which you will have to create) within the directory you specify under the `imgs_dir` parameter. The `image_counts` directory must contain one .json file for each image in `imgs_dir`, named exactly like the image (`img1.jpg` --> `img1.json`). Each .json file in turn must contain a dictionary mapping class IDs to counts. For example, if `img1.jpg` contains 5 objects of class 0 and no objects of any other class, the dictionary in `img1.json` should look as follows: `{"0": 5}` (keys in a .json file are always strings).
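
If you already have an annotation dictionary in the format described in the annotation section, the per-image count files can be derived from it. The following is a minimal sketch: the helper `write_count_files` is hypothetical and not part of this repo, and whether images without any objects need a file containing an empty dictionary is an assumption you should check against `compute_errors_img_lvl()`.

```python
import json
from collections import Counter
from pathlib import Path


def write_count_files(annotations, counts_dir):
    """Write one <image name>.json file with per-class counts for each image.

    Hypothetical helper; `annotations` maps image names to lists of
    {"point": [x, y], "category_id": int} dictionaries.
    """
    counts_dir = Path(counts_dir)
    counts_dir.mkdir(parents=True, exist_ok=True)
    written = []
    for img_name, labels in annotations.items():
        # Count how many labels each class has in this image
        counts = Counter(label["category_id"] for label in labels)
        out_path = counts_dir / f"{img_name}.json"
        with open(out_path, "w") as f:
            # json.dump turns the integer class IDs into string keys
            json.dump(dict(counts), f)
        written.append(out_path)
    return written
```

For example, `write_count_files(annotations, f"{imgs_dir}/image_counts")` would create the `image_counts` directory inside `imgs_dir` if it does not exist yet and fill it with one file per annotated image.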

In [None]:
compute_errors_img_lvl(gt_counts_dir=f"{imgs_dir}/image_counts",
                       pred_counts_dir=f"{output_dir}/detections",
                       class_ids=list(classID2name.keys()),
                       output_dir=output_dir)
compute_em_img_lvl(preds_dir=f"{output_dir}/detections", class_id2name=classID2name,
                   task="locate", output_dir=output_dir)