New vis metrics (#1698)
* v0

* v0

* v0

* v0

* v0

* v1

* v1

* new_vis_metrics

* prrr

* prrr

* docs

* docs

* tests

* comment_fixes

* comment_fixes

* comment_fixes

* comment_fixes

* comment_fixes

* rename

* linting

* renames

* renames

* docs

* no_scorer_names

* linting
JKL98ISR committed Jul 3, 2022
1 parent d57bb29 commit ac32522
Showing 12 changed files with 347 additions and 92 deletions.
@@ -84,8 +84,10 @@ def __init__(self,
def initialize_run(self, context: Context):
"""Initialize run by creating the _state member with metrics for train and test."""
self._data_metrics = {}
self._data_metrics[DatasetKind.TRAIN] = get_scorers_list(context.train, self.alternative_metrics)
self._data_metrics[DatasetKind.TEST] = get_scorers_list(context.train, self.alternative_metrics)
self._data_metrics[DatasetKind.TRAIN] = get_scorers_list(context.train,
alternative_scorers=self.alternative_metrics)
self._data_metrics[DatasetKind.TEST] = get_scorers_list(context.train,
alternative_scorers=self.alternative_metrics)

if not self.metric_to_show_by:
self.metric_to_show_by = list(self._data_metrics[DatasetKind.TRAIN].keys())[0]
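For context, a hedged sketch of how one of the metrics added in this commit could be supplied through the alternative_metrics argument used in the hunk above. The ClassPerformance check name and its import path are assumptions for illustration, not something this hunk confirms:

# Illustrative only: the check name and import path below are assumed, not taken from this diff.
from deepchecks.vision.checks import ClassPerformance  # assumed check, for illustration
from deepchecks.vision.metrics_utils.detection_tp_fp_fn_calc import ObjectDetectionTpFpFn

# Per-class recall at IoU >= 0.5, supplied as a custom ignite metric.
check = ClassPerformance(
    alternative_metrics={'recall_50': ObjectDetectionTpFpFn(evaluting_function='recall')}
)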
@@ -20,7 +20,7 @@
from deepchecks.core.condition import ConditionCategory
from deepchecks.utils.strings import format_number
from deepchecks.vision import Batch, Context, SingleDatasetCheck
from deepchecks.vision.metrics_utils.object_detection_precision_recall import ObjectDetectionAveragePrecision
from deepchecks.vision.metrics_utils.detection_precision_recall import ObjectDetectionAveragePrecision
from deepchecks.vision.vision_data import TaskType

__all__ = ['MeanAveragePrecisionReport']
@@ -19,7 +19,7 @@
from deepchecks.core.condition import ConditionCategory
from deepchecks.utils.strings import format_number
from deepchecks.vision import Batch, Context, SingleDatasetCheck
from deepchecks.vision.metrics_utils.object_detection_precision_recall import ObjectDetectionAveragePrecision
from deepchecks.vision.metrics_utils.detection_precision_recall import ObjectDetectionAveragePrecision
from deepchecks.vision.vision_data import TaskType

__all__ = ['MeanAverageRecallReport']
@@ -21,7 +21,7 @@
from deepchecks.core.errors import DeepchecksValueError
from deepchecks.utils.strings import format_number
from deepchecks.vision import Batch, Context, SingleDatasetCheck
from deepchecks.vision.metrics_utils.object_detection_precision_recall import ObjectDetectionAveragePrecision
from deepchecks.vision.metrics_utils.detection_precision_recall import ObjectDetectionAveragePrecision
from deepchecks.vision.vision_data import TaskType

__all__ = ['SingleDatasetScalarPerformance']
63 changes: 63 additions & 0 deletions deepchecks/vision/metrics_utils/confusion_matrix_counts_metrics.py
@@ -0,0 +1,63 @@
# ----------------------------------------------------------------------------
# Copyright (C) 2021-2022 Deepchecks (https://www.deepchecks.com)
#
# This file is part of Deepchecks.
# Deepchecks is distributed under the terms of the GNU Affero General
# Public License (version 3 or later).
# You should have received a copy of the GNU Affero General Public License
# along with Deepchecks. If not, see <http://www.gnu.org/licenses/>.
# ----------------------------------------------------------------------------
#
"""Module for confusion matrix counts metrics."""


def _calc_recall(tp: float, fp: float, fn: float) -> float: # pylint: disable=unused-argument
"""Calculate recall for given matches and number of positives."""
if tp + fn == 0:
return -1
rc = tp / (tp + fn)
return rc


def _calc_precision(tp: float, fp: float, fn: float) -> float:
"""Calculate precision for given matches and number of positives."""
if tp + fn == 0:
return -1
if tp + fp == 0:
return 0
pr = tp / (tp + fp)
return pr


def _calc_f1(tp: float, fp: float, fn: float) -> float:
"""Calculate F1 for given matches and number of positives."""
if tp + fn == 0:
return -1
if tp + fp == 0:
return 0
rc = tp / (tp + fn)
pr = tp / (tp + fp)
f1 = (2 * rc * pr) / (rc + pr)
return f1


def _calc_fpr(tp: float, fp: float, fn: float) -> float:
"""Calculate FPR for given matches and number of positives."""
if tp + fn == 0:
return -1
if tp + fp == 0:
return 0
return fp / (tp + fn)


def _calc_fnr(tp: float, fp: float, fn: float) -> float:
"""Calculate FNR for given matches and number of positives."""
if tp + fn == 0:
return -1
if tp + fp == 0:
return 1
return fn / (tp + fn)


AVAILABLE_EVALUTING_FUNCTIONS = {"recall": _calc_recall, "fpr": _calc_fpr,
"fnr": _calc_fnr, "precision": _calc_precision, "f1": _calc_f1}
60 changes: 24 additions & 36 deletions deepchecks/vision/metrics_utils/detection_precision_recall.py
@@ -10,15 +10,16 @@
#
"""Module for calculating detection precision and recall."""
import warnings
from abc import abstractmethod
from collections import defaultdict
from typing import Dict, List, Optional, Tuple, Union
from typing import List, Optional, Tuple, Union

import numpy as np
import torch
from ignite.metrics import Metric
from ignite.metrics.metric import reinit__is_reduced, sync_all_reduce

from deepchecks.vision.metrics_utils.metric_mixin import MetricMixin, ObjectDetectionMetricMixin


def _dict_conc(test_list):
result = defaultdict(list)
@@ -35,7 +36,7 @@ def _dict_conc(test_list):
return result


class AveragePrecisionRecall(Metric):
class AveragePrecisionRecall(Metric, MetricMixin):
"""Abstract class to calculate average precision and recall for various vision tasks.
Parameters
@@ -44,16 +45,15 @@ class AveragePrecisionRecall(Metric):
Maximum number of detections per class.
area_range: tuple, default: (32**2, 96**2)
Slices for small/medium/large buckets.
return_option: int, default: 0
0: ap only, 1: ar only, None: all (not ignite complient)
return_option: str, default: 'ap'
ap: ap only, ar: ar only, None: all (not ignite compliant)
"""

def __init__(self, *args, max_dets: Union[List[int], Tuple[int]] = (1, 10, 100),
area_range: Tuple = (32**2, 96**2),
return_option: Optional[int] = 0, **kwargs):
return_option: Optional[str] = "ap", **kwargs):
super().__init__(*args, **kwargs)

self._evals = defaultdict(lambda: {"scores": [], "matched": [], "NP": []})
self.return_option = return_option
if self.return_option is not None:
max_dets = [max_dets[-1]]
@@ -63,7 +63,6 @@ def __init__(self, *args, max_dets: Union[List[int], Tuple[int]] = (1, 10, 100),
self.iou_thresholds = np.linspace(.5, 0.95, int(np.round((0.95 - .5) / .05)) + 1, endpoint=True)
self.max_detections_per_class = max_dets
self.area_range = area_range
self.i = 0

@reinit__is_reduced
def reset(self):
@@ -79,9 +78,9 @@ def update(self, output):

for detected, ground_truth in zip(y_pred, y):
if isinstance(detected, torch.Tensor):
detected = detected.cpu()
detected = detected.cpu().detach()
if isinstance(ground_truth, torch.Tensor):
ground_truth = ground_truth.cpu()
ground_truth = ground_truth.cpu().detach()

self._group_detections(detected, ground_truth)
self.i += 1
@@ -123,12 +122,12 @@ def compute(self):
recall_list[class_id] = recall
reses["precision"][iou_i, area_i, dets_i] = precision_list
reses["recall"][iou_i, area_i, dets_i] = recall_list
if self.return_option == 0:
if self.return_option == "ap":
return torch.tensor(self.get_classes_scores_at(reses["precision"],
max_dets=self.max_detections_per_class[0],
area=self.area_ranges_names[0],
get_mean_val=False))
elif self.return_option == 1:
elif self.return_option == "ar":
return torch.tensor(self.get_classes_scores_at(reses["recall"],
max_dets=self.max_detections_per_class[0],
area=self.area_ranges_names[0],
@@ -335,27 +334,16 @@ def get_classes_scores_at(self, res: np.ndarray, iou: float = None, area: str =
res = res.clip(min=0)
return res[0][0]

@abstractmethod
def get_confidences(self, detections) -> List[float]:
"""Get detections object of single image and should return confidence for each detection."""
pass

@abstractmethod
def calc_pairwise_ious(self, detections, labels) -> Dict[int, np.ndarray]:
"""Get single result from group_class_detection_label and return matrix of IOUs."""
pass

@abstractmethod
def group_class_detection_label(self, detections, labels) -> dict:
"""Group detection and labels in dict of format {class_id: {'detected' [...], 'ground_truth': [...]}}."""
pass

@abstractmethod
def get_detection_areas(self, detections) -> List[int]:
"""Get detection object of single image and should return area for each detection."""
pass

@abstractmethod
def get_labels_areas(self, labels) -> List[int]:
"""Get labels object of single image and should return area for each label."""
pass

class ObjectDetectionAveragePrecision(AveragePrecisionRecall, ObjectDetectionMetricMixin):
"""Calculate average precision and recall for object detection.
Parameters
----------
max_dets: Union[List[int], Tuple[int]], default: [1, 10, 100]
Maximum number of detections per class.
area_range: tuple, default: (32**2, 96**2)
Slices for small/medium/large buckets.
return_option: str, default: 'ap'
ap: ap only, ar: ar only, None: all (not ignite compliant)
"""
155 changes: 155 additions & 0 deletions deepchecks/vision/metrics_utils/detection_tp_fp_fn_calc.py
@@ -0,0 +1,155 @@
# ----------------------------------------------------------------------------
# Copyright (C) 2021-2022 Deepchecks (https://www.deepchecks.com)
#
# This file is part of Deepchecks.
# Deepchecks is distributed under the terms of the GNU Affero General
# Public License (version 3 or later).
# You should have received a copy of the GNU Affero General Public License
# along with Deepchecks. If not, see <http://www.gnu.org/licenses/>.
# ----------------------------------------------------------------------------
#
"""Module for calculating verious detection metrics."""
import typing as t
from collections import defaultdict

import numpy as np
import torch
from ignite.metrics import Metric
from ignite.metrics.metric import reinit__is_reduced, sync_all_reduce

from deepchecks.vision.metrics_utils.confusion_matrix_counts_metrics import AVAILABLE_EVALUTING_FUNCTIONS
from deepchecks.vision.metrics_utils.metric_mixin import MetricMixin, ObjectDetectionMetricMixin


class TpFpFn(Metric, MetricMixin):
"""Abstract class to calculate the TP, FP, FN and runs an evaluting function on the result.
Parameters
----------
iou_thres: float, default: 0.5
IoU below this threshold will be ignored.
confidence_thres: float, default: 0.5
Confidence below this threshold will be ignored.
evaluting_function: Union[Callable, str], default: "recall"
Function to run on each class result, i.e. `func(tp, fp, fn)`.
"""

def __init__(self, *args, iou_thres: float = 0.5, confidence_thres: float = 0.5,
evaluting_function: t.Union[t.Callable, str] = "recall", **kwargs):
super().__init__(*args, **kwargs)

self.iou_thres = iou_thres
self.confidence_thres = confidence_thres
if isinstance(evaluting_function, str):
evaluting_function = AVAILABLE_EVALUTING_FUNCTIONS.get(evaluting_function)
if evaluting_function is None:
raise ValueError(
f"Expected evaluting_function one of {list(AVAILABLE_EVALUTING_FUNCTIONS.keys())},"
f" recived: {evaluting_function}")
self.evaluting_function = evaluting_function

@reinit__is_reduced
def reset(self):
"""Reset metric state."""
super().reset()
self._evals = defaultdict(lambda: {"tp": 0, "fp": 0, "fn": 0})
self._i = 0

@reinit__is_reduced
def update(self, output):
"""Update metric with batch of samples."""
y_pred, y = output

for detected, ground_truth in zip(y_pred, y):
if isinstance(detected, torch.Tensor):
detected = detected.cpu().detach()
if isinstance(ground_truth, torch.Tensor):
ground_truth = ground_truth.cpu().detach()

self._group_detections(detected, ground_truth)
self._i += 1

@sync_all_reduce("_evals")
def compute(self):
"""Compute metric value."""
# now reduce accumulations
sorted_classes = [int(class_id) for class_id in sorted(self._evals.keys())]
max_class = max(sorted_classes)
res = -np.ones(max_class + 1)
for class_id in sorted_classes:
ev = self._evals[class_id]
res[class_id] = self.evaluting_function(ev["tp"], ev["fp"], ev["fn"])
return res

def _group_detections(self, detected, ground_truth):
"""Group gts and dts on a imageXclass basis."""
# Calculating pairwise IoUs on classes
bb_info = self.group_class_detection_label(detected, ground_truth)
ious = {k: self.calc_pairwise_ious(v["detected"], v["ground_truth"]) for k, v in bb_info.items()}

for class_id in ious.keys():
tp, fp, fn = self._evaluate_image(
np.array(self.get_confidences(bb_info[class_id]["detected"])),
bb_info[class_id]["ground_truth"],
ious[class_id]
)

acc = self._evals[class_id]
acc["tp"] += tp
acc["fp"] += fp
acc["fn"] += fn

def _evaluate_image(self, confidences: t.List[float], ground_truths: t.List, ious: np.ndarray) -> \
t.Tuple[float, float, float]:
"""Evaluate image."""
# Sort detections by decreasing confidence
confidences = confidences[confidences > self.confidence_thres]
sorted_confidence_ids = np.argsort(confidences, kind="stable")[::-1]
orig_ious = ious

# sort list of dts and chop by max dets
ious = orig_ious[sorted_confidence_ids]

detection_matches = self._get_best_matches(ground_truths, ious)
matched = np.array([d_idx in detection_matches for d_idx in range(len(ious))])
if len(matched) == 0:
tp, fp = 0, 0
else:
tp = np.sum(matched)
fp = len(matched) - tp
return tp, fp, len(ground_truths) - tp

def _get_best_matches(self, ground_truths: t.List, ious: np.ndarray) -> t.Dict[int, int]:
ground_truth_matched = {}
detection_matches = {}

for d_idx in range(len(ious)):
# information about best match so far (best_match=-1 -> unmatched)
best_iou = min(self.iou_thres, 1 - 1e-10)
best_match = -1
for g_idx in range(len(ground_truths)):
# if this gt already matched, continue
if g_idx in ground_truth_matched:
continue

if ious[d_idx, g_idx] >= best_iou:
best_iou = ious[d_idx, g_idx]
best_match = g_idx
if best_match != -1:
detection_matches[d_idx] = best_match
ground_truth_matched[best_match] = d_idx
return detection_matches
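# Editorial worked example (not part of this commit), illustrating the greedy
# matching above: with iou_thres=0.5 and pairwise IoUs
#   ious = [[0.7, 0.6],
#           [0.0, 0.55]]
# detection 0 takes ground truth 0 (IoU 0.7 >= 0.5), so detection 1 can only
# match ground truth 1 (IoU 0.55), giving detection_matches == {0: 0, 1: 1}.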


class ObjectDetectionTpFpFn(TpFpFn, ObjectDetectionMetricMixin):
"""Calculate the TP, FP, FN and runs an evaluting function on the result.
Parameters
----------
iou_thres: float, default: 0.5
Threshold of the IoU.
confidence_thres: float, default: 0.5
Threshold of the confidence.
evaluting_function: Union[Callable, str], default: "recall"
Function to run on each class result, i.e. `func(tp, fp, fn)`.
"""
