In [1]:
import os
from pathlib import Path
from sklearn.metrics import ConfusionMatrixDisplay

# from ba_dev.eval_helpers import *
from ba_dev.eval_helpers import set_custom_plot_style, plot_image_with_bbox

set_custom_plot_style()

In [2]:
import ast
import json
import yaml
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import sklearn.metrics as skm
import matplotlib.patches as patches

from matplotlib.figure import Figure
from matplotlib.axes import Axes
from IPython.display import display
from PIL import Image
from pathlib import Path
from os import PathLike
from typing import Dict, Any

from ba_dev.dataset import MammaliaDataImage
from ba_dev.utils import BBox

In [3]:
class LoadRun:
    def __init__(
        self,
        log_path: str | PathLike
    ):
        self.log_path = Path(log_path)
        self.info = self.get_experiment_info()
        self.cross_val = self.info['cross_val']['apply']
        if self.cross_val:
            self.folds = range(2 if self.info.get('dev_run', False) else self.info['cross_val']['n_folds'])
        else:
            self.folds = [None]

        self.ds_path = Path(self.info['paths']['dataset'])
        self.run_dataset = self.get_run_dataset()

        self.decoder = self.info['output']['label_decoder']

    def show_sample(
            self,
            idx: int,
            info_to_print: str | list[str] | None = None,
            show_figures: bool = True
            ) -> Figure:

        sample = self.get_sample(idx=idx)

        if isinstance(info_to_print, str):
            info_to_print = [info_to_print]

        if info_to_print:
            for info in info_to_print:
                if info in sample:
                    print(f"{info}: {sample[info]}")
                else:
                    print(f"Info '{info}' not an available key.")

        figure = plot_image_with_bbox(
                image=Image.open(sample['path']),
                bbox=sample['bbox'],
                conf=sample['conf']
                )

        if show_figures:
            display(figure)

        return figure

    def show_all_bboxes_for_image(
            self,
            idx: int,
            first_n: int = -1,
            show_figures: bool = True
            ) -> list[Figure]:

        sample = self.get_sample(idx=idx)
        detections = self.get_bb_for_file(idx=idx)
        if not first_n == -1:
            detections = detections[:first_n]
        image = Image.open(sample['path'])

        figures = []

        for det in detections:
            figure = plot_image_with_bbox(
                    image=image,
                    bbox=det[0],
                    conf=det[1]
                    )

            if show_figures:
                display(figure)

            figures.append(figure)

        return figures

    def get_bb_for_file(
            self,
            idx: int
            ):

        seq_id = self.run_dataset.iloc[idx]['seq_id']
        file_name = Path(self.run_dataset.iloc[idx]['file_path']).name

        path_to_detection_results = Path(self.info['paths']['md_output']) / f"{seq_id}.json"
        with open(path_to_detection_results, 'r') as f:
            data = json.load(f)

        for entry in data:
            if entry['file'] == file_name:
                detections = entry['detections']

        pairs = [
                (det['bbox'], det['conf'])
                for det in detections
                if int(det['category']) == 1
                ]

        pairs_sorted = sorted(pairs, key=lambda x: x[1], reverse=True)

        return pairs_sorted

    def get_dataset(
            self,
            fold: int = 0,
            ) -> MammaliaDataImage:

        return MammaliaDataImage(
            path_labelfiles=self.info['paths']['labels'],
            path_to_dataset=self.info['paths']['dataset'],
            path_to_detector_output=self.info['paths']['md_output'],
            n_folds=self.info['cross_val']['n_folds'],
            test_fold=fold,
            mode='eval',
            image_pipeline=None,
            **self.info['dataset']
            )

    def calculate_metrics(
            self,
            metric: str,
            set_selection: list[str] | str = 'test',
            **kwargs
            ):

        func = getattr(skm, metric, None)
        if func is None:
            raise ValueError(f"Metric '{metric}' not found in sklearn.metrics")

        def compute(df_pred: pd.DataFrame):
            y_true = df_pred['class_id']
            y_pred = df_pred['pred_id']
            return func(y_true, y_pred, **kwargs)

        if self.cross_val:
            results = []
            for fold in self.folds:
                df_pred = self.get_predictions(fold=fold, set_selection=set_selection)
                results.append(compute(df_pred))
            return results
        else:
            df_pred = self.get_predictions(set_selection=set_selection)
            return compute(df_pred)

    def get_sample(
            self,
            idx: int
            ) -> Dict[str, Any]:

        return {
            'idx': idx,
            'class_label': self.run_dataset.iloc[idx]['class_label'],
            'class_id': self.run_dataset.iloc[idx]['class_id'],
            'seq_id': self.run_dataset.iloc[idx]['seq_id'],
            'path': self.ds_path / self.run_dataset.iloc[idx]['file_path'],
            'bbox': self.run_dataset.iloc[idx]['bbox'],
            'conf': self.run_dataset.iloc[idx]['conf']
            }

    def get_experiment_info(self) -> Dict:
        yaml_path = self.log_path / 'experiment_info.yaml'
        if not yaml_path.exists():
            raise FileNotFoundError(f"Experiment info file not found at {yaml_path}")
        with open(yaml_path, 'r') as f:
            return yaml.safe_load(f)

    def get_run_dataset(self) -> pd.DataFrame:
        csv_path = self.log_path / 'dataset.csv'
        if not csv_path.exists():
            raise FileNotFoundError(f"Dataset file not found at {csv_path}")
        df = pd.read_csv(csv_path)
        return self._enforce_dtypes_and_idx(df)

    def get_predictions(
            self,
            fold: int | None = None,
            set_selection: list[str] | str | None = None,
            filter_by: str | None = None,
            sort: str | None = None
            ) -> pd.DataFrame:
        
        if set_selection:
            if isinstance(set_selection, str):
                set_selection = [set_selection]
            
            for set_sel in set_selection:
                if set_sel not in ['train', 'val', 'test']:
                    raise ValueError("set_selection must be 'train', 'val', or 'test'")

        prediction_path = self._handle_crossval_or_not('predictions', fold)

        df = pd.read_csv(prediction_path)
        ds = self.get_run_dataset()
        df['seq_id'] = ds['seq_id']

        df['correct'] = df['class_id'] == df['pred_id']

        df = self._enforce_dtypes_and_idx(df)
        
        if set_selection:
            mask = df['set'].isin(set_selection)
            df = df[mask]

        if 'probs' in df.columns:
            df['probs_max'] = [
                prob_list[pred]
                for prob_list, pred in zip(df['probs'], df['pred_id'])
            ]

        if filter_by:
            if filter_by == 'correct':
                df = df[df['correct']]
            elif filter_by == 'incorrect':
                df = df[~df['correct']]
            else:
                raise ValueError("filter_by must be either 'correct' or 'incorrect'")

        if sort:
            if sort == 'probs_max':
                df = df.sort_values(by='probs_max', ascending=False)

        return df

    def get_metrics(
            self,
            fold: int | None = None,
            ) -> pd.DataFrame:

        metrics_path = self._handle_crossval_or_not('metrics', fold)

        df = pd.read_csv(metrics_path)
        df = self._enforce_dtypes_and_idx(df)

        return df

    def _handle_crossval_or_not(
            self,
            type: None | str = None,
            fold: None | int = None,
            ) -> Path:

        options = {'metrics': 'metrics.csv', 'predictions': 'predictions.csv'}
        if type not in options:
            raise ValueError(f"Type must be one of {options.keys()}")

        if self.cross_val:
            if fold is None:
                raise ValueError("Fold number must be provided for cross-validation runs.")
            predictions_path = self.log_path / f'fold_{fold}' / options[type]
        else:
            predictions_path = self.log_path / options[type]

        if not predictions_path.exists():
            raise FileNotFoundError(f"{type} file not found at {predictions_path}")

        return predictions_path

    def _enforce_dtypes_and_idx(
            self,
            df: pd.DataFrame
            ) -> pd.DataFrame:
        df.insert(0, 'idx', df.index)
        cast_map = {
            'seq_id': 'int64',
            'class_id': 'int8',
            'fold': 'int8',
        }
        existing_casts = {k: v for k, v in cast_map.items() if k in df.columns}
        df = df.astype(existing_casts)

        def to_float_list(x):
            if isinstance(x, str):
                x = ast.literal_eval(x)
            return [float(i) for i in x]

        for col in ['bbox', 'probs']:
            if col in df.columns:
                df[col] = df[col].apply(to_float_list)

        return df


In [4]:
path_to_models = Path('/cfs/earth/scratch/kraftjul/BA/output/complete')

model_paths = [p for p in path_to_models.iterdir() if p.is_dir()]

In [5]:
model_paths[0]

PosixPath('/cfs/earth/scratch/kraftjul/BA/output/complete/resNet50_v1_no_pretrained_cross_val')

In [6]:
run = LoadRun(
        log_path=model_paths[0]
        )

In [19]:
ds = run.get_predictions(fold=1)

In [20]:
ds.keys()

Index(['idx', 'class_id', 'set', 'pred_id', 'probs', 'seq_id', 'correct',
       'probs_max'],
      dtype='object')

In [None]:
def agg_probs(ps):
    summed = [sum(col) for col in zip(*ps)]
    total = sum(summed)
    return [v/total for v in summed]

In [50]:
aggregated = (
    ds
    .groupby('seq_id')
    .agg(
        class_id = ('class_id', 'first'),
        set = ('set', 'first'),
        count = ('pred_id', 'size'),
        pred_id_majority = ('pred_id', lambda x: x.mode()),
        probs   = ('probs',   agg_probs)
    )
    .reset_index()
)

aggregated['prob_max'] = aggregated['probs'].apply(max)
aggregated['pred_id_max'] = aggregated['probs'].apply(lambda p: p.index(max(p)))


In [51]:
aggregated

Unnamed: 0,seq_id,class_id,set,count,pred_id_majority,probs,prob_max,pred_id_max
0,1000001,0,train,6,0,"[1.0, 0.0, 0.0, 0.0]",1.000000,0
1,1000002,1,train,3,1,"[3.333333333333333e-05, 0.9999666666666666, 0....",0.999967,1
2,1000003,1,test,3,1,"[0.3169438981299376, 0.6687222907430247, 0.014...",0.668722,1
3,1000004,1,test,2,"[0, 1]","[0.51025, 0.4813, 0.0078, 0.00065]",0.510250,0
4,1000005,1,train,2,1,"[0.0, 1.0, 0.0, 0.0]",1.000000,1
...,...,...,...,...,...,...,...,...
21781,7000005,2,train,2,2,"[0.0, 0.0, 1.0, 0.0]",1.000000,2
21782,7000006,2,val,16,2,"[0.0, 0.0, 1.0, 0.0]",1.000000,2
21783,7000007,2,train,5,2,"[0.0, 0.0, 1.0, 0.0]",1.000000,2
21784,7000008,2,train,4,2,"[0.0, 0.0, 1.0, 0.0]",1.000000,2
