### Load Dependencies

In [1]:
from utils.HN import *
from utils.model_eval import *
import random

### Define variables
- `class_name2ID` is a dictionary mapping class names to their ID. The ID of a species is its (zero-based) position in the Species column of Table 1. Its name is the name given in Table 1. E.g., for the JE-TL19 dataset: `class_name2ID = {"Elephant": 0, "Giraffe": 1, "Zebra": 2}`
- The DoR values can be found in Table 12.
- The radii can be found in the Size Avg. column of Table 1. They need to be passed as a dictionary where `key = class ID`, `value = radius`. Again, for the JE-TL19 dataset, the `radii` dict would look as follows: `{0: 62, 1: 81, 2: 49}`


In [None]:
# --- User configuration: adapt every value below to your own setup. ---

# TODO: Insert Path to image directory (as downloaded from Zenodo).
imgs_dir = "/home/giacomo/data/test_sets/EW-IL22/"

# TODO: Insert Path to the model weights you want to run (as downloaded from Zenodo).
model_weights = "/home/giacomo/data/weights/EW-IL22/HN_EW-IL22.pth"

# TODO: Insert class names and IDs as found in Table 1.
class_name2ID = {
    "Brant": 0,
    "Other": 1,
    "Gull": 2,
    "Canada": 3,
    "Emperor": 4,
}

# TODO: Define DoR threshold for NMS and computing evaluation metrics (cf. Table 12).
dor_thresh = 0.4

# TODO: Insert radius for each species in the dataset (key = class ID, value = radius); cf. Table 1.
radii = {
    0: 50,
    1: 50,
    2: 37.5,
    3: 50,
    4: 50,
}

# TODO: Insert path to directory you want your output stored in.
output_dir = "/home/giacomo/projects/How-to-Minimize-the-Annotation-Cost-of-Aerial-Wildlife-Censuses/outputs/EW-IL22/HN"

### Run inference
This cell runs HerdNet on the images in `imgs_dir` (defined above) and automatically creates a folder inside it (HerdNet_results) in which the .csv file containing the detections is saved.
- `dets_file`: This is the path to the .csv file containing the HerdNet predictions. One row represents one detection and is expected to have the following columns:
    - `images`: Contains the file names of the images.
    - `x`: Contains the x-coordinate of the detection.
    - `y`: Contains the y-coordinate of the detection.
    - `scores`: Contains the confidence score of the prediction.
    - `species`: Contains the name of the detected species.

    **TODO (documentation still to be written):**
    - Add a point about patch size and overlap.
    - Add the names dicts.
    - Add an option for BOX_WH/PT_DEFAULT.

In [None]:
random.seed(0)

%run -m utils.inference_herdnet {imgs_dir} {model_weights} --size 640 --over 64 

from pathlib import Path

dets_file = next(Path(imgs_dir).rglob("*_HerdNet_detections.csv"), None)

### Read HerdNet output and run evaluation pipeline (no input required)

In [7]:
# Step 1: read the HerdNet detections and hand them, together with the test
# annotations and the per-class radii, to the detection reader.
read_detections_HN(
    dets_file=dets_file,
    cls_name2id=class_name2ID,
    imgs_dir=imgs_dir,
    dor_thresh=dor_thresh,
    radii=radii,
    class_ids=list(radii),
    output_dir=output_dir,
    ann_file=f"{imgs_dir}/test_annotations.json",
    ann_format="BX_WH",
)

# Inverse lookup (class ID -> class name) needed by the metric helpers below.
class_ID2name = dict(zip(class_name2ID.values(), class_name2ID.keys()))

# Step 2: image-level counting errors, comparing the ground-truth counts with
# the detections written to the output directory.
compute_errors_img_lvl(
    gt_counts_dir=f"{imgs_dir}/image_counts",
    pred_counts_dir=f"{output_dir}/detections",
    class_ids=list(class_ID2name),
    output_dir=output_dir,
)

# Step 3: image-level evaluation metrics. Kept as the last expression so that
# its return value is displayed as the cell output.
compute_em_img_lvl(
    preds_dir=f"{output_dir}/detections",
    class_id2name=class_ID2name,
    task="locate",
    output_dir=output_dir,
)

{0: {'confusion': 4.605263785491687,
  'precision': 86.82634678547097,
  'recall': 51.2367489355592,
  'f1': 64.44444369124939,
  'metrics/precision(B)': 0.8686255977716412,
  'metrics/recall(B)': 0.519434628975265,
  'metrics/mAP100(B)': 0.6942146881078664,
  'metrics/mAP100-10(B)': 0.6665091319674638},
 1: {'confusion': 1.8518527606309898,
  'precision': 79.6992475210583,
  'recall': 43.08943071914866,
  'f1': 55.9366747110087,
  'metrics/precision(B)': 0.7585153046336648,
  'metrics/recall(B)': 0.4341334425476499,
  'metrics/mAP100(B)': 0.5945180494330806,
  'metrics/mAP100-10(B)': 0.5847158845621653},
 2: {'confusion': 3.311258918468485,
  'precision': 85.3801164597654,
  'recall': 48.5049832275582,
  'f1': 61.86440605545283,
  'metrics/precision(B)': 0.8056913498057943,
  'metrics/recall(B)': 0.4850498338870432,
  'metrics/mAP100(B)': 0.6446622440538027,
  'metrics/mAP100-10(B)': 0.6197604018462411},
 'binary': {'confusion': 3.256125154863721,
  'precision': 86.87089696527156,
  '