RobustnessReport & ClassPerformance work on custom task (#1235)
* Fix robustness report to work on custom task

* Fix class performance to work on custom task

* Fix lint

* Fix test
matanper committed Apr 10, 2022
1 parent 7a0b9de commit 0fc6e9c
Showing 7 changed files with 33 additions and 18 deletions.
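
For orientation, this commit removes the task-type guard from both checks so they also run on datasets that declare TaskType.OTHER. Below is a hedged sketch of the usage the change enables, mirroring the fixtures and tests added later in this diff; the import paths and the mnist_train_loader, mnist_test_loader, and mnist_model placeholders are assumptions for illustration, not code from the repository.

# Hedged usage sketch -- mirrors the new fixtures/tests below, not repository code.
from ignite.metrics import Precision

from deepchecks.vision.checks import ClassPerformance, RobustnessReport      # assumed import path
from deepchecks.vision.datasets.classification.mnist import MNISTData        # assumed import path
from deepchecks.vision.vision_data import TaskType


class CustomTask(MNISTData):
    def task_type(self) -> TaskType:
        # Neither classification nor object detection -> treated as a custom task
        return TaskType.OTHER


# torchvision's MNIST dataset keeps its augmentation pipeline on `.transform`,
# hence transform_field='transform' (the same change made to the conftest fixtures below).
# `mnist_train_loader` / `mnist_test_loader` are placeholder torch DataLoaders.
train_ds = CustomTask(mnist_train_loader, transform_field='transform')
test_ds = CustomTask(mnist_test_loader, transform_field='transform')

# Custom tasks have no built-in scorers, so metrics must be supplied explicitly.
metrics = {'precision': Precision()}
ClassPerformance(alternative_metrics=metrics).run(train_ds, test_ds, model=mnist_model)
RobustnessReport(alternative_metrics=metrics).run(train_ds, model=mnist_model)

Because TaskType.OTHER datasets have no default scorers, both checks are expected to receive alternative_metrics explicitly, exactly as the new tests in this commit do.
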
3 changes: 0 additions & 3 deletions deepchecks/vision/checks/performance/class_performance.py
@@ -21,7 +21,6 @@
 from deepchecks.utils import plot
 from deepchecks.utils.strings import format_percent, format_number
 from deepchecks.vision import TrainTestCheck, Context, Batch
-from deepchecks.vision.vision_data import TaskType
 from deepchecks.vision.metrics_utils.metrics import get_scorers_list, metric_results_to_df, \
     filter_classes_for_display

@@ -84,8 +83,6 @@ def __init__(self,

     def initialize_run(self, context: Context):
         """Initialize run by creating the _state member with metrics for train and test."""
-        context.assert_task_type(TaskType.CLASSIFICATION, TaskType.OBJECT_DETECTION)
-
         self._data_metrics = {}
         self._data_metrics[DatasetKind.TRAIN] = get_scorers_list(context.train, self.alternative_metrics)
         self._data_metrics[DatasetKind.TEST] = get_scorers_list(context.train, self.alternative_metrics)
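
For context on what both checks dropped: the assert_task_type call refused to run a check unless the dataset declared one of the listed task types, which is what previously blocked TaskType.OTHER datasets. A rough, illustrative sketch of that guard follows (not the actual deepchecks Context implementation; the error class location is an assumption):

# Illustrative sketch only -- not deepchecks' real code.
from deepchecks.core.errors import DeepchecksValueError  # assumed location of the error class


def assert_task_type(actual_task_type, *allowed_task_types):
    """Raise if the dataset's task type is not one of the allowed task types."""
    if actual_task_type not in allowed_task_types:
        raise DeepchecksValueError(
            f'Check is irrelevant for task of type {actual_task_type}')
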
11 changes: 4 additions & 7 deletions deepchecks/vision/checks/performance/robustness_report.py
@@ -66,7 +66,6 @@ def __init__(self,

     def initialize_run(self, context: Context, dataset_kind):
         """Initialize the metrics for the check, and validate task type is relevant."""
-        context.assert_task_type(TaskType.CLASSIFICATION, TaskType.OBJECT_DETECTION)
         dataset = context.get_data_by_kind(dataset_kind)
         # Set empty version of metrics
         self._state = {'metrics': get_scorers_list(dataset, self.alternative_metrics)}
@@ -182,8 +181,8 @@ def _validate_augmenting_affects(self, transform_handler, dataset: VisionData):
                   f'Dataset.__getitem__'
             raise DeepchecksValueError(msg)

-        # For classification does not check label for difference
-        if dataset.task_type != TaskType.CLASSIFICATION:
+        # For object detection check that the label is affected
+        if dataset.task_type == TaskType.OBJECT_DETECTION:
             labels = dataset.batch_to_labels(batch)
             if torch.equal(labels[0], labels[1]):
                 msg = f'Found that labels have not been affected by adding augmentation to field ' \
@@ -210,7 +209,7 @@ def calc_percent(a, b):
                 aug_top_affected[metric].append({'class': index_class,
                                                  'value': single_metric_scores.at[index_class, 'Value'],
                                                  'diff': diff_value,
-                                                 'samples': dataset.n_of_samples_per_class[index_class]})
+                                                 'samples': dataset.n_of_samples_per_class.get(index_class, 0)})
         return aug_top_affected

     def _calc_performance_diff(self, mean_base, augmented_metrics):
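
The switch from n_of_samples_per_class[index_class] to .get(index_class, 0) above matters because, on a custom or partial dataset, the per-class counts may be missing a class id that a metric still reports on. A minimal illustration with made-up counts:

# Made-up counts, not repository data.
n_of_samples_per_class = {0: 120, 1: 98}   # only the classes actually present in the data

# Direct indexing (the old code) would raise KeyError for class 7;
# .get keeps the report going and simply shows 0 samples for that class.
assert n_of_samples_per_class.get(7, 0) == 0
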
@@ -388,10 +387,8 @@ def get_random_image_pairs_from_dataset(original_dataset: VisionData,
             base_class_label = [x for x in base_label if x[0] == class_id]
             aug_class_label = [x for x in aug_label if x[0] == class_id]
             samples.append((images[0], images[1], class_id, (base_class_label, aug_class_label)))
-        elif original_dataset.task_type == TaskType.CLASSIFICATION:
-            samples.append((images[0], images[1], class_id))
         else:
-            raise DeepchecksValueError('Not implemented')
+            samples.append((images[0], images[1], class_id))

     return samples

6 changes: 5 additions & 1 deletion deepchecks/vision/metrics_utils/metrics.py
@@ -135,14 +135,18 @@ def _validate_metric_type(metric_name: str, score: t.Any) -> bool:

 def metric_results_to_df(results: dict, dataset: VisionData) -> pd.DataFrame:
     """Get dict of metric name to tensor of classes scores, and convert it to dataframe."""
+    # The data might contain fewer classes than the model was trained on. Filter out any class id which is not
+    # present in the data.
+    data_classes = dataset.classes_indices.keys()
+
     per_class_result = [
         [metric, class_id, dataset.label_id_to_name(class_id),
          class_score.item() if isinstance(class_score, torch.Tensor) else class_score]
         for metric, score in results.items()
         if _validate_metric_type(metric, score)
         # scorer returns results as array, containing result per class
         for class_id, class_score in enumerate(score)
-        if not np.isnan(class_score)
+        if not np.isnan(class_score) and class_id in data_classes
     ]

     return pd.DataFrame(per_class_result, columns=['Metric',
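
To make the new filter in metric_results_to_df concrete, a self-contained sketch; classes_indices is assumed, per the code above, to be keyed by the class ids actually present in the data, and the scores are invented:

import numpy as np

results = {'precision': [0.9, 0.8, float('nan'), 0.7]}  # hypothetical per-class scores (index == class id)
data_classes = {0, 1, 2}                                 # stand-in for dataset.classes_indices.keys()

kept = [(metric, class_id, score)
        for metric, scores in results.items()
        for class_id, score in enumerate(scores)
        if not np.isnan(score) and class_id in data_classes]

# Class 2 is dropped as NaN; class 3 is dropped because it never appears in the data.
assert kept == [('precision', 0, 0.9), ('precision', 1, 0.8)]
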
9 changes: 9 additions & 0 deletions tests/vision/checks/performance/class_performance_test.py
@@ -227,3 +227,12 @@ def test_condition_class_performance_imbalance_ratio_not_greater_than_fail(mnist
                         device=device)

     assert_that(result.conditions_results[0].is_pass, is_(False))
+
+
+def test_custom_task(mnist_train_custom_task, mnist_test_custom_task, device, mock_trained_mnist):
+    # Arrange
+    metrics = {'metric': Precision()}
+    check = ClassPerformance(alternative_metrics=metrics)
+
+    # Act & Assert - check runs without errors
+    check.run(mnist_train_custom_task, mnist_test_custom_task, model=mock_trained_mnist, device=device)
2 changes: 1 addition & 1 deletion
@@ -50,7 +50,7 @@ def test_coco_and_condition(coco_train_visiondata, mock_trained_yolov5_object_de
         }),
         has_entries({
             'start': 0.6671875, 'stop': 0.75, 'count': 11, 'display_range': '[0.67, 0.75)',
-            'metrics': has_entries({'AP': close_to(0.364, 0.001), 'AR': close_to(0.366, 0.001)})
+            'metrics': has_entries({'AP': close_to(0.367, 0.001), 'AR': close_to(0.4, 0.001)})
         }),
         has_entries({
             'start': 0.75, 'stop': close_to(1.102, 0.001), 'count': 28, 'display_range': '[0.75, 1.1)',
11 changes: 11 additions & 0 deletions tests/vision/checks/performance/robustness_report_test.py
@@ -12,6 +12,8 @@

 import albumentations
 import numpy as np
+from ignite.metrics import Precision
+
 from deepchecks.vision.datasets.detection.coco import COCOData, CocoDataset

 from tests.checks.utils import equal_condition_result
@@ -116,3 +118,12 @@ def new_apply(self, img, bboxes):
     assert_that(calling(check.run).with_args(vision_data, mock_trained_yolov5_object_detection,
                                              device=device),
                 raises(DeepchecksValueError, msg))
+
+
+def test_custom_task(mnist_train_custom_task, device, mock_trained_mnist):
+    # Arrange
+    metrics = {'metric': Precision()}
+    check = RobustnessReport(alternative_metrics=metrics)
+
+    # Act & Assert - check runs without errors
+    check.run(mnist_train_custom_task, model=mock_trained_mnist, device=device)
9 changes: 3 additions & 6 deletions tests/vision/vision_conftest.py
@@ -16,6 +16,7 @@
 import torch
 from torch.utils.data import DataLoader, Dataset
 from torch.utils.data.dataloader import default_collate
+from PIL import Image

 from deepchecks.core import DatasetKind
 from deepchecks.vision import VisionData, Context, Batch
@@ -35,10 +36,6 @@
 from tests.vision.assets.mnist_predictions_dict import mnist_predictions_dict


-from PIL import Image
-
-
-
 # Fix bug with torch.hub path on windows
 PROJECT_DIR = pathlib.Path(__file__).absolute().parent.parent.parent
 torch.hub.set_dir(str(PROJECT_DIR))
@@ -288,7 +285,7 @@ class CustomTask(MNISTData):
         def task_type(self) -> TaskType:
             return TaskType.OTHER

-    return CustomTask(mnist_data_loader_train)
+    return CustomTask(mnist_data_loader_train, transform_field='transform')


 @pytest.fixture(scope='session')
@@ -298,7 +295,7 @@ class CustomTask(MNISTData):
         def task_type(self) -> TaskType:
             return TaskType.OTHER

-    return CustomTask(mnist_data_loader_test)
+    return CustomTask(mnist_data_loader_test, transform_field='transform')


 @pytest.fixture(scope='session')