In [1]:
import os
import shutil
import warnings
import csv
import yaml
import json
import torch
import pandas as pd
import numpy as np
from pathlib import Path
from typing import Any, TypedDict

from megadetector.detection.run_detector import load_detector, model_string_to_model_version
from megadetector.detection.run_detector_batch import process_images, write_results_to_file

from os import PathLike
from torch.utils.data import Dataset
from sklearn.model_selection import train_test_split


In [21]:
def load_path_config(path_to_config):
    with open(path_to_config, 'r') as f:
        path_config = yaml.safe_load(f)
    return {k: Path(v) for k, v in path_config.items()}

paths = load_path_config('/cfs/earth/scratch/kraftjul/BA/code/path_config.yml')

In [4]:
class MegaDetectorRunner:
    """
    A class to run the MegaDetector model on images. Designed to be used on a set of image sequences,
    only loading the model once and running it on all sequences.

    Parameters
    ----------
    model_path : str | PathLike
        Path to the MegaDetector model file. Or a string representing the model version available online.
    confidence : float
        Confidence threshold for the model. Default is 0.25.
    """
    def __init__(
            self, 
            model_path: str | PathLike, 
            confidence: float = 0.25
            ):
        
        self.model = load_detector(str(model_path))
        self.confidence = confidence

    def run_on_images(
            self,
            images: list[PathLike],
            output_file_path: PathLike = None,
            ):

        results = process_images(
            im_files=images,
            detector=self.model,
            confidence_threshold=self.confidence,
            quiet=True
        )

        all_confidences = []

        for r in results:
            r["file"] = r["file"].name

            r["detections"] = [
                det for det in r.get("detections", [])
                if det["category"] == "1"
            ]
        
            all_confidences.extend(det["conf"] for det in r["detections"])

        all_confidences.sort(reverse=True)
        
        if output_file_path is not None:
            with open(output_file_path, "w") as f:
                json.dump(results, f, indent=2)

        return all_confidences      


In [28]:
class MammaliaData(Dataset):
    """
    A class to load and process the Mammalia dataset. It can be uset for the initial detection of the images
    utilizing the MegaDetector model, or for training a custom model for classification on the detected images.
    The dataset is divided into training and testing sets based on the sequence IDs.
    
    Parameters
    ----------
    path_labelfiles : str | PathLike
        Path to the directory containing the label files.
    path_to_dataset : str | PathLike
        Path to the main directory of the dataset, referenced in the labelfiles.
    path_to_detector_output : str | PathLike
        Path to the directory where the detector output is available for training or where the output will be saved
        if detection is applied.
    categories_to_drop : list[str], optional
        By default all non-empty labels are used. To drop certain labels from the dataset, provide a list of labels to drop.
        In detect mode, this parameter is ignored.
    detector_model : str
        If a detector model is provided, the detection will be applied to the whole dataset and stored for training.
        The model must be one of the available models in the MegaDetector repository.
        The default is None. A valid detection output must be available at the path_to_detector_output.
    detection_confidence : float
        The detection is done with a confidence of 0.25 by default to provide some flexibility
        with the training. The confidence can be set to a higher value to reduce the number of detections used from
        the output. The default is 0.25.
    sample_length : int
        For trainig this parameter specifies the range (1 - sample_length) of randomly seletded samples per sequence.
        For testing this parameter specifies the maximum number of samples per sequence.
        The default is 10.
    sample_img_size : [int, int]
        The size to which the detected areas are resized. The default is [224, 224].
    mode : str
        The mode in which the dataset is used. Can be either 'train', 'test' or 'init' defining which data will be 
        sampled and adjusting how it is sampled. The default is 'train'.
    """
    
    def __init__(
            self,
            path_labelfiles: str | PathLike,
            path_to_dataset: str | PathLike,
            path_to_detector_output: str | PathLike,
            categories_to_drop: list[str] = None,
            detector_model: str = None,
            detection_confidence: float = 0.25,
            sample_length: int = 10,
            sample_img_size: [int, int] = [224, 224],
            mode: str = 'train',
            ):
        super().__init__()

        if mode in ['train', 'test', 'init']:
            self.mode = mode
        else:
            raise ValueError("Please choose a mode from ['train', 'test'].")
        
        if detection_confidence < 0.25:
            raise ValueError("Detection confidence must be at least 0.25.")
        
        self.categories_to_drop = categories_to_drop if categories_to_drop is not None else []
        self.detection_confidence = detection_confidence
        self.sample_length = sample_length
        self.sample_img_size = sample_img_size

        self.path_labelfiles = Path(path_labelfiles)
        if not self.path_labelfiles.exists():
            raise ValueError("The path to the label files does not exist.")
        
        self.path_to_dataset = Path(path_to_dataset)
        if not path_to_dataset.exists():
            raise ValueError("The path to the dataset does not exist.")

        self.path_to_detector_output = Path(path_to_detector_output)
        self.detector_model = detector_model

        self.ds_full = self.reading_all_metadata(
                    list_of_files = self.getting_all_files_of_type(self.path_labelfiles, file_type='.csv'),
                    categories_to_drop = self.categories_to_drop
                    )        

        if self.mode == 'init':
            if self.detector_model is not None:
                self.run_detector()
            else:
                if not any(self.path_to_detector_output.glob("*.json")):
                    raise ValueError('A valid detection output must be available at the path_to_detector_output.')
        
        if self.ds_full['seq_id'].duplicated().any():
            duplicates = self.ds_full['seq_id'][self.ds_full['seq_id'].duplicated()].unique()
            raise ValueError(f"Duplicate seq_id(s) found in metadata: {duplicates[:5]} ...")

        train_seq_ids, test_seq_ids = train_test_split(
                                            self.ds_full['seq_id'],
                                            test_size=0.2,
                                            random_state=55,
                                            stratify=self.ds_full['label2']
                                            )

        filtered_train_seq_ids = self.exclude_ids_with_no_detections(
            set_type='train',
            sequences_to_filter=train_seq_ids
        )

        filtered_test_seq_ids = self.exclude_ids_with_no_detections(
            set_type='test',
            sequences_to_filter=test_seq_ids
        )
 
        if self.mode in ['train', 'init']:
            active_seq_ids = filtered_train_seq_ids
        elif self.mode == 'test':
            active_seq_ids = filtered_test_seq_ids
           
        self.ds = self.ds_full[self.ds_full['seq_id'].isin(active_seq_ids)]
        self.seq_ids = self.ds['seq_id'].tolist()

    def getting_all_files_of_type(
            self, 
            path: str | PathLike, 
            file_type: str = None, 
            get_full_path: bool = True
            ) -> list[str]:
        
        path = Path(path)
        files = []
        for file in os.listdir(path):
            if file_type is None or file.endswith(file_type):
                if get_full_path:
                    files.append(path / file)
                else:
                    files.append(file)
        return files
    
    def reading_all_metadata(
            self,
            list_of_files: list[PathLike],
            categories_to_drop: list[str] = []
            ) -> pd.DataFrame:
        
        metadata = pd.DataFrame()
        for file in list_of_files:
            metadata = pd.concat([metadata, pd.read_csv(file)], ignore_index=True)
            metadata = metadata.dropna(subset=['label2'])
            metadata = metadata[~metadata['label2'].isin(categories_to_drop)]
        return metadata
    
    def exclude_ids_with_no_detections(
            self,
            set_type: str,
            sequences_to_filter: list[int],
            ) -> list[int]:
        
        detection_summary = self.get_detection_summary(
            usecols=["seq_id", "max_conf"]
            )
        
        seq_ids_to_exclude_set = set(detection_summary[detection_summary["max_conf"] < self.detection_confidence]["seq_id"].tolist())
        seq_ids_to_filter_set = set(sequences_to_filter)

        excluded_seq_ids = list(seq_ids_to_filter_set & seq_ids_to_exclude_set)

        if excluded_seq_ids:
            suffix = "" if len(excluded_seq_ids) <= 10 else " ..."
            warnings.warn(
                f"With the current detection confidence of {self.detection_confidence},\n"
                f"{len(excluded_seq_ids)} sequences of the {set_type} set had no detections and will be excluded.\n"
                f"Excluded sequences: {excluded_seq_ids[:10]}{suffix}",
                UserWarning
            )
        
        return list(seq_ids_to_filter_set - seq_ids_to_exclude_set)

    def get_detection_summary(
            self,
            usecols: list[str] = None,
            ) -> pd.DataFrame:
        
        return pd.read_csv(
                self.path_to_detector_output / "detection_summary.csv",
                usecols=usecols
                )
    
    def get_clss_weight(                                        # still to be implemented
            self
            ) -> torch.Tensor:
        
        if self.mode == 'test':
            raise ValueError("Class weights are not available in test mode.")
        
        class_weights = 5
        
        return class_weights
    
    def get_all_images_of_sequence(
            self, 
            seq_id: int,
            dataframe: pd.DataFrame = None,
            )-> dict[str, PathLike]:
        
        if dataframe is None:
            dataframe = self.ds_full

        image_dict = {}
        row = dataframe.loc[dataframe['seq_id'] == seq_id].squeeze()
        seq_path = Path(row['Directory'])
        all_files = row['all_files'].split(',')
        for file in all_files:
            image_dict[file] = self.path_to_dataset / seq_path / file
        return image_dict

    def run_detector(
            self,
            ) -> None:
        
        if self.detector_model is None:
            raise ValueError('Method not available - No detector model provided.')
        elif self.detector_model not in model_string_to_model_version.keys():
            raise ValueError(f"The model {self.detector_model} is not supported. Please choose from {model_string_to_model_version.keys()}.")
        elif not self.path_to_detector_output.exists():
                os.makedirs(self.path_to_detector_output)
        elif any(self.path_to_detector_output.iterdir()):
            raise ValueError("The path to the detector output contains files. Please clear or choose a different path.")
          
        runner = MegaDetectorRunner(
            model_path=self.detector_model,
            confidence=0.25
            )

        metadata = self.reading_all_metadata(
                    list_of_files = self.getting_all_files_of_type(self.path_labelfiles, file_type='.csv'),
                    )
            
        sequences = metadata['seq_id'].unique().tolist()

        detection_rows = []

        for seq_id in sequences:
            seq_images = list(self.get_all_images_of_sequence(seq_id).values())
            output_file_path = self.path_to_detector_output / f"{seq_id}.json"
            detections = runner.run_on_images(
                images=seq_images,
                output_file_path=output_file_path
                )

            detection_row = {
                    "seq_id": seq_id,
                    "max_conf": max(detections) if len(detections) > 0 else 0,
                    "n_detections": len(detections),
                    "conf_list": json.dumps(detections)
                }
            
            detection_rows.append(detection_row)
        
        all_detections = pd.DataFrame(detection_rows, columns=["seq_id", "max_conf", "n_detections", "conf_list"])

        all_detections.to_csv(
            self.path_to_detector_output / "detection_summary.csv", 
            index=False,
            quoting=csv.QUOTE_NONNUMERIC
            )
            
    def getting_bb_list_for_seq(
            self,
            seq_id: int,
            confidence: float = None,
            ) -> list[dict]:
        
        if self.mode != 'detect':
            raise ValueError("Only available if dataset is in detect mode.")
        
        if confidence is None:
            confidence = self.detection_confidence

        path_to_detection_results = self.path_to_detector_output / f"{seq_id}.json"
        with open(path_to_detection_results, 'r') as f:
            data = json.load(f)

        bb_list = []

        for entry in data:
            file_name = entry['file']
            detections = entry.get('detections', [])

            for det in detections:
                if det['category'] == "1" and det['conf'] >= confidence:
                    bb_list({
                        'file': file_name,
                        'conf': det['conf'],
                        'bbox': det['bbox']
                    })
        
        bb_list = sorted(bb_list, key=lambda x: x['conf'], reverse=True)

        return bb_list

    def __len__(self) -> int:
        return len(self.ds)

    def __getitem__(self, index: int) -> Any:               # still to be implemented
        seq_id = self.seq_ids[index]

        images = self.get_all_images_of_sequence(seq_id)
        bounding_boxes = self.getting_bb_list_for_seq(seq_id)


### Running Tests

In [None]:

path_to_testset = Path('/cfs/earth/scratch/kraftjul/BA/data/test_set')
output_path = Path('/cfs/earth/scratch/kraftjul/BA/output')
categories_to_drop=['other', 'glis_glis']

In [29]:
path_to_dataset = paths['dataset']
path_labelfiles = Path('/cfs/earth/scratch/kraftjul/BA/data/test_set_large')
path_to_detector_output = path_labelfiles / 'MD_out'
detector_model = 'mdv5a'
mode = 'init'


dataset = MammaliaData(
    path_to_dataset = path_to_dataset,
    path_labelfiles = path_labelfiles,
    path_to_detector_output = path_to_detector_output,
    detector_model = detector_model,
    mode = mode,
    )

Bypassing download of already-downloaded file md_v5a.0.0.pt
Model v5a.0.0 available at /tmp/megadetector_models/md_v5a.0.0.pt
Bypassing imports for model type yolov5
Loading PT detector with compatibility mode classic


Fusing layers... 
Fusing layers... 
Model summary: 733 layers, 140054656 parameters, 0 gradients, 208.8 GFLOPs
Model summary: 733 layers, 140054656 parameters, 0 gradients, 208.8 GFLOPs
7 sequences of the train set had no detections and will be excluded.
Excluded sequences: [6000161, 6000163, 6000293, 6000530, 6000691, 6000372, 6000953]
1 sequences of the test set had no detections and will be excluded.
Excluded sequences: [6000186]


PosixPath('/cfs/earth/scratch/iunr/shared/iunr-mammaliabox/dataset')

### Sampling dataset

In [None]:
# Paths
path_to_labelfiles = Path("/cfs/earth/scratch/iunr/shared/iunr-mammaliabox/dataset/info/labels")
dataset_root = Path("/cfs/earth/scratch/iunr/shared/iunr-mammaliabox/dataset")
target_dir = Path("/cfs/earth/scratch/kraftjul/BA/data/test_set_large")
output_metadata_csv = target_dir / "metadata_larger_sample_set.csv"


# Load metadata
metadata = dataset.reading_all_metadata(
    list_of_files=dataset.getting_all_files_of_type(path_to_labelfiles, file_type='.csv'),
    categories_to_drop=['other', 'glis_glis']
)

metadata_filtered = metadata[metadata['n_files']<60]

metadata_sampled = metadata_filtered.groupby("label2", group_keys=False).sample(n=40, random_state=42)

metadata_sampled.to_csv(output_metadata_csv, index=False)

In [17]:
metadata_sampled

Unnamed: 0,session,SerialNumber,seq_nr,seq_id,Directory,DateTime_start,DateTime_end,duration_seconds,first_file,last_file,n_files,all_files,label,duplicate_label,label2
11567,4,H550HG09194945,233,4007156,sessions/session_04/W2-WK02,2020-06-09T23:21:46Z,2020-06-09T23:22:18Z,32.0,IMG_6154.JPG,IMG_6180.JPG,27,"IMG_6154.JPG,IMG_6155.JPG,IMG_6156.JPG,IMG_615...",apodemus_sp,False,apodemus_sp
15877,4,H550HF07158832,180,4011466,sessions/session_04/W5-KH08,2020-06-28T22:25:32Z,2020-06-28T22:25:38Z,6.0,IMG_5446.JPG,IMG_5454.JPG,9,"IMG_5446.JPG,IMG_5447.JPG,IMG_5448.JPG,IMG_544...",apodemus_sp,False,apodemus_sp
1815,1,H550HF08161305,229,1001887,sessions/session_01/H550HF08161305_2,2019-09-10T02:06:30Z,2019-09-10T02:07:31Z,61.0,IMG_3034.JPG,IMG_3051.JPG,18,"IMG_3034.JPG,IMG_3035.JPG,IMG_3036.JPG,IMG_303...",apodemus_sp,0.0,apodemus_sp
15095,4,H550HF07158933,34,4010684,sessions/session_04/W4-WK02,2020-06-20T23:40:40Z,2020-06-20T23:42:00Z,80.0,IMG_0607.JPG,IMG_0654.JPG,48,"IMG_0607.JPG,IMG_0608.JPG,IMG_0609.JPG,IMG_061...",apodemus_sp,False,apodemus_sp
4586,4,H,77,4000175,sessions/session_04/Testwoche1/KH08,2020-05-11T21:16:28Z,2020-05-11T21:16:30Z,2.0,RCNX1125.JPG,RCNX1127.JPG,3,"RCNX1125.JPG,RCNX1126.JPG,RCNX1127.JPG",apodemus_sp,False,apodemus_sp
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
21906,4,H550HF07158839,222,4017495,sessions/session_04/W7-R25,2020-07-16T22:03:52Z,2020-07-16T22:03:52Z,0.0,IMG_3736.JPG,IMG_3738.JPG,3,"IMG_3736.JPG,IMG_3737.JPG,IMG_3738.JPG",sorex_sp,False,soricidae
12723,4,H550HG09194886,174,4008312,sessions/session_04/W3-M7,2020-06-18T04:43:02Z,2020-06-18T04:43:02Z,0.0,IMG_3646.JPG,IMG_3648.JPG,3,"IMG_3646.JPG,IMG_3647.JPG,IMG_3648.JPG",crocidura_sp,False,soricidae
20378,4,H550HG09194894,161,4015967,sessions/session_04/W6-R26,2020-07-11T22:25:42Z,2020-07-11T22:25:44Z,2.0,IMG_2272.JPG,IMG_2274.JPG,3,"IMG_2272.JPG,IMG_2273.JPG,IMG_2274.JPG",sorex_sp,False,soricidae
18825,4,H550HF07158832,147,4014414,sessions/session_04/W6-M2,2020-07-05T05:16:24Z,2020-07-05T05:16:26Z,2.0,IMG_3880.JPG,IMG_3882.JPG,3,"IMG_3880.JPG,IMG_3881.JPG,IMG_3882.JPG",crocidura_sp,False,soricidae
