Skip to content

Commit

Permalink
Enable validation of empty predictions (#1530)
Browse files Browse the repository at this point in the history
Find a non-empty label/prediction to validate, and finish validation with no error if non exist

Co-authored-by: Itay Gabbay <itay@deepchecks.com>
  • Loading branch information
2 people authored and Matan Perlmutter committed May 31, 2022
1 parent 67ee47c commit 5107884
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 19 deletions.
4 changes: 1 addition & 3 deletions deepchecks/tabular/suites/README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,8 @@ Using an Existing Suite
List of Prebuilt Suites
---------------------------

- single_dataset_integrity - Runs a set of checks that are meant to
- dataset_integrity - Runs a set of checks that are meant to
detect integrity issues within a single dataset.
- train_test_leakage - Runs a set of checks that are meant to detect
data leakage from the training dataset to the test dataset.
- train_test_validation - Runs a set of checks that are meant to
validate correctness of train-test split, including integrity, drift
and leakage.
Expand Down
30 changes: 21 additions & 9 deletions deepchecks/vision/detection_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,15 +147,21 @@ def validate_label(self, batch):
raise ValidationError('Check requires object detection label to be a non-empty list')
if not isinstance(labels[0], torch.Tensor):
raise ValidationError('Check requires object detection label to be a list of torch.Tensor')
if len(labels[0].shape) != 2:
sample_idx = 0
# Find a non empty tensor to validate
while labels[sample_idx].shape[0] == 0:
sample_idx += 1
if sample_idx == len(labels):
return # No labels to validate
if len(labels[sample_idx].shape) != 2:
raise ValidationError('Check requires object detection label to be a list of 2D tensors')
if labels[0].shape[1] != 5:
if labels[sample_idx].shape[1] != 5:
raise ValidationError('Check requires object detection label to be a list of 2D tensors, when '
'each row has 5 columns: [class_id, x, y, width, height]')
if torch.min(labels[0]) < 0:
if torch.min(labels[sample_idx]) < 0:
raise ValidationError('Found one of coordinates to be negative, check requires object detection '
'bounding box coordinates to be of format [class_id, x, y, width, height].')
if torch.max(labels[0][:, 0] % 1) > 0:
if torch.max(labels[sample_idx][:, 0] % 1) > 0:
raise ValidationError('Class_id must be a positive integer. Object detection labels per image should '
'be a Bx5 tensor of format [class_id, x, y, width, height].')

Expand Down Expand Up @@ -183,18 +189,24 @@ def validate_infered_batch_predictions(batch_predictions):
raise ValidationError('Check requires detection predictions to be a non-empty list')
if not isinstance(batch_predictions[0], torch.Tensor):
raise ValidationError('Check requires detection predictions to be a list of torch.Tensor')
if len(batch_predictions[0].shape) != 2:
sample_idx = 0
# Find a non empty tensor to validate
while batch_predictions[sample_idx].shape[0] == 0:
sample_idx += 1
if sample_idx == len(batch_predictions):
return # No predictions to validate
if len(batch_predictions[sample_idx].shape) != 2:
raise ValidationError('Check requires detection predictions to be a list of 2D tensors')
if batch_predictions[0].shape[1] != 6:
if batch_predictions[sample_idx].shape[1] != 6:
raise ValidationError('Check requires detection predictions to be a list of 2D tensors, when '
'each row has 6 columns: [x, y, width, height, class_probability, class_id]')
if torch.min(batch_predictions[0]) < 0:
if torch.min(batch_predictions[sample_idx]) < 0:
raise ValidationError('Found one of coordinates to be negative, Check requires object detection '
'bounding box predictions to be of format [x, y, width, height, confidence,'
' class_id]. ')
if torch.min(batch_predictions[0][:, 4]) < 0 or torch.max(batch_predictions[0][:, 4]) > 1:
if torch.min(batch_predictions[sample_idx][:, 4]) < 0 or torch.max(batch_predictions[sample_idx][:, 4]) > 1:
raise ValidationError('Confidence must be between 0 and 1. Object detection predictions per image '
'should be a Bx6 tensor of format [x, y, width, height, confidence, class_id].')
if torch.max(batch_predictions[0][:, 5] % 1) > 0:
if torch.max(batch_predictions[sample_idx][:, 5] % 1) > 0:
raise ValidationError('Class_id must be a positive integer. Object detection predictions per image '
'should be a Bx6 tensor of format [x, y, width, height, confidence, class_id].')
2 changes: 1 addition & 1 deletion docs/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
git+https://github.com/sphinx-doc/sphinx.git@4.x#egg=sphinx
sphinx==4.5.0
nbsphinx>=0.8.7
pydata-sphinx-theme>=0.7.2
sphinx-copybutton>=0.4.0
Expand Down
8 changes: 2 additions & 6 deletions tests/vision/base/test_vision_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,9 +391,7 @@ def infer_on_batch(self, batch, model, device):
assert_that(calling(detection_data.validate_label).with_args([8]),
raises(ValidationError,
'Check requires object detection label to be a list of torch.Tensor'))
assert_that(calling(detection_data.validate_label).with_args([torch.Tensor([])]),
raises(ValidationError,
'Check requires object detection label to be a list of 2D tensors'))
assert_that(detection_data.validate_label([torch.Tensor([])]), equal_to(None))
assert_that(calling(detection_data.validate_label).with_args([torch.Tensor([[1, 2], [1, 2]])]),
raises(ValidationError,
'Check requires object detection label to be a list of 2D tensors, when '
Expand All @@ -408,9 +406,7 @@ def infer_on_batch(self, batch, model, device):
assert_that(calling(detection_data.validate_prediction).with_args([8], None, None),
raises(ValidationError,
'Check requires detection predictions to be a list of torch.Tensor'))
assert_that(calling(detection_data.validate_prediction).with_args([torch.Tensor([])], None, None),
raises(ValidationError,
'Check requires detection predictions to be a list of 2D tensors'))
assert_that(detection_data.validate_prediction([torch.Tensor([])], None, None), equal_to(None))
assert_that(calling(detection_data.validate_prediction).with_args([torch.Tensor([[1, 2], [1, 2]])], None, None),
raises(ValidationError,
'Check requires detection predictions to be a list of 2D tensors, when '
Expand Down

0 comments on commit 5107884

Please sign in to comment.