Skip to content

Commit

Permalink
Suggestion: Add assert in vision objects to Batch (#1078)
Browse files Browse the repository at this point in the history
* Add assert in vision objects to Batch

* fix docstring
  • Loading branch information
matanper committed Mar 22, 2022
1 parent 58c421c commit df47639
Show file tree
Hide file tree
Showing 7 changed files with 56 additions and 26 deletions.
9 changes: 5 additions & 4 deletions deepchecks/vision/checks/distribution/image_property_drift.py
Expand Up @@ -70,14 +70,15 @@ def __init__(
self.max_num_categories = max_num_categories
self.classes_to_display = classes_to_display
self.min_samples = min_samples
self._train_properties = defaultdict(list)
self._test_properties = defaultdict(list)
self._train_properties = None
self._test_properties = None
self._class_to_string = None

def initialize_run(self, context: Context):
"""Initialize self state, and validate the run context."""
context.train.assert_image_formatter_valid()
context.test.assert_image_formatter_valid()
self._class_to_string = context.train.label_id_to_name
self._train_properties = defaultdict(list)
self._test_properties = defaultdict(list)

def update(
self,
Expand Down
39 changes: 30 additions & 9 deletions deepchecks/vision/context.py
Expand Up @@ -21,7 +21,7 @@
from deepchecks.vision.utils.validation import apply_to_tensor
from deepchecks.core.errors import (
DatasetValidationError, DeepchecksNotImplementedError, ModelValidationError,
DeepchecksNotSupportedError, DeepchecksValueError
DeepchecksNotSupportedError, DeepchecksValueError, ValidationError
)


Expand Down Expand Up @@ -52,6 +52,7 @@ def labels(self):
"""Return labels for the batch, formatted in deepchecks format."""
if self._labels is None:
dataset = self._context.get_data_by_kind(self._dataset_kind)
dataset.assert_labels_valid()
self._labels = dataset.batch_to_labels(self._batch)
return self._labels

Expand All @@ -60,14 +61,18 @@ def predictions(self):
"""Return predictions for the batch, formatted in deepchecks format."""
if self._predictions is None:
dataset = self._context.get_data_by_kind(self._dataset_kind)
self._predictions = dataset.infer_on_batch(self._batch, self._context.model, self._context.device)
# Calling model will raise error if model was not given
model = self._context.model
self._context.assert_predictions_valid(self._dataset_kind)
self._predictions = dataset.infer_on_batch(self._batch, model, self._context.device)
return self._predictions

@property
def images(self):
    """Return the batch images in deepchecks format, converting them on first access.

    The converted images are stored in ``self._images``; subsequent accesses
    return that stored value without touching the dataset again.
    """
    if self._images is not None:
        return self._images
    # Run the dataset's image validation before converting, so any failure it
    # reports surfaces here rather than inside batch_to_images().
    vision_data = self._context.get_data_by_kind(self._dataset_kind)
    vision_data.assert_images_valid()
    self._images = vision_data.batch_to_images(self._batch)
    return self._images

Expand Down Expand Up @@ -124,22 +129,32 @@ def __init__(self,
train.validate_shared_label(test)

self._device = torch.device(device) if isinstance(device, str) else (device if device else torch.device('cpu'))
self._prediction_formatter_error = {}

if model is not None:
if not isinstance(model, nn.Module):
logger.warning('Model is not a torch.nn.Module. Deepchecks can\'t validate that model is in '
'evaluation state.')
else:
if model.training:
raise DatasetValidationError('Model is not in evaluation state. Please set model training '
'parameter to False or run model.eval() before passing it.')
for dataset, dataset_type in zip([train, test], ['train', 'test']):
elif model.training:
raise DatasetValidationError('Model is not in evaluation state. Please set model training '
'parameter to False or run model.eval() before passing it.')

for dataset, dataset_type in zip([train, test], [DatasetKind.TRAIN, DatasetKind.TEST]):
if dataset is not None:
try:
dataset.validate_prediction(next(iter(dataset.data_loader)), model, self._device)
msg = None
except DeepchecksNotImplementedError:
logger.warning('validate_prediction() was not implemented in %s dataset, '
'some checks will not run', dataset_type)
msg = f'infer_on_batch() was not implemented in {dataset_type} ' \
f'dataset, some checks will not run'
except ValidationError as ex:
msg = f'batch_to_images() was not implemented correctly in {dataset_type}, the ' \
f'validation has failed with the error: {ex}. To test your prediction formatting use the ' \
f'function `vision_data.validate_prediction(batch, model, device)`'

if msg:
self._prediction_formatter_error[dataset_type] = msg
logger.warning(msg)

# The copy does 2 things: Sample n_samples if parameter exists, and shuffle the data.
# we shuffle because the data in VisionData is set to be sampled in a fixed order (in the init), so if the user
Expand Down Expand Up @@ -203,6 +218,12 @@ def assert_task_type(self, *expected_types: TaskType):
f'Check is irrelevant for task of type {self.train.task_type}')
return True

def assert_predictions_valid(self, kind: DatasetKind = None):
    """Raise if a prediction-format error was recorded for the given dataset kind.

    Looks up ``kind`` in ``self._prediction_formatter_error``. When a non-empty
    message is stored there it is raised as a DeepchecksValueError; when no
    message is stored (or ``kind`` is not a key) the method returns silently.
    """
    message = self._prediction_formatter_error.get(kind)
    if not message:
        return
    raise DeepchecksValueError(message)

def get_data_by_kind(self, kind: DatasetKind):
"""Return the relevant VisionData by given kind."""
if kind == DatasetKind.TRAIN:
Expand Down
18 changes: 13 additions & 5 deletions deepchecks/vision/detection_data.py
Expand Up @@ -122,7 +122,7 @@ def infer_on_batch(self, batch, model, device) -> List[torch.Tensor]:
def get_classes(self, batch_labels: List[torch.Tensor]):
"""Get a labels batch and return classes inside it."""
def get_classes_from_single_label(tensor: torch.Tensor):
return list(tensor[:, 0].tolist()) if len(tensor) > 0 else []
return list(tensor[:, 0].type(torch.IntTensor).tolist()) if len(tensor) > 0 else []

return [get_classes_from_single_label(x) for x in batch_labels]

Expand All @@ -134,11 +134,12 @@ def validate_label(self, batch):
----------
batch
Returns
Raises
-------
Optional[str]
None if the label is valid, otherwise a string containing the error message.
DeepchecksValueError
If labels format is invalid
DeepchecksNotImplementedError
If batch_to_labels not implemented
"""
labels = self.batch_to_labels(batch)
if not isinstance(labels, list):
Expand All @@ -164,6 +165,13 @@ def validate_prediction(self, batch, model, device):
Batch from DataLoader
model : t.Any
device : torch.Device
Raises
------
DeepchecksValueError
If predictions format is invalid
DeepchecksNotImplementedError
If infer_on_batch not implemented
"""
batch_predictions = self.infer_on_batch(batch, model, device)
if not isinstance(batch_predictions, list):
Expand Down
2 changes: 1 addition & 1 deletion deepchecks/vision/utils/validation.py
Expand Up @@ -58,7 +58,7 @@ def set_seeds(seed: int):
def apply_to_tensor(
x: T,
fn: t.Callable[[torch.Tensor], torch.Tensor]
) -> T:
) -> t.Any:
"""Apply provided function to tensor instances recursively."""
if isinstance(x, torch.Tensor):
return t.cast(T, fn(x))
Expand Down
4 changes: 2 additions & 2 deletions deepchecks/vision/vision_data.py
Expand Up @@ -421,12 +421,12 @@ def __len__(self):
"""Return the number of batches in the dataset dataloader."""
return len(self._data_loader)

def assert_image_formatter_valid(self):
def assert_images_valid(self):
"""Assert the image formatter defined is valid. Else raise exception."""
if self._image_formatter_error is not None:
raise DeepchecksValueError(self._image_formatter_error)

def assert_label_formatter_valid(self):
def assert_labels_valid(self):
"""Assert the label formatter defined is valid. Else raise exception."""
if self._label_formatter_error is not None:
raise DeepchecksValueError(self._label_formatter_error)
Expand Down
4 changes: 2 additions & 2 deletions tests/vision/base/test_custom_task.py
Expand Up @@ -25,12 +25,12 @@ class CustomData(VisionData):

# Assert
assert_that(
calling(data.assert_image_formatter_valid).with_args(),
calling(data.assert_images_valid).with_args(),
raises(DeepchecksValueError, r'batch_to_images\(\) was not implemented, some checks will not run')
)

assert_that(
calling(data.assert_label_formatter_valid).with_args(),
calling(data.assert_labels_valid).with_args(),
raises(DeepchecksValueError, r'batch_to_labels\(\) was not implemented, some checks will not run')
)

Expand Down
6 changes: 3 additions & 3 deletions tests/vision/base/test_vision_data.py
Expand Up @@ -304,7 +304,7 @@ def get_classes(self, batch_labels):

# Assert
assert_that(
calling(data.assert_label_formatter_valid).with_args(),
calling(data.assert_labels_valid).with_args(),
raises(DeepchecksValueError,
r'get_classes\(\) was not implemented correctly, the validation has failed with the error: "The classes '
r'must be a sequence\."\. '
Expand All @@ -323,7 +323,7 @@ def get_classes(self, batch_labels):

# Assert
assert_that(
calling(data.assert_label_formatter_valid).with_args(),
calling(data.assert_labels_valid).with_args(),
raises(DeepchecksValueError,
r'get_classes\(\) was not implemented correctly, the validation has failed with the error: "The '
r'classes sequence must contain as values sequences of ints \(sequence per sample\).". To test your '
Expand All @@ -342,7 +342,7 @@ def get_classes(self, batch_labels):

# Assert
assert_that(
calling(data.assert_label_formatter_valid).with_args(),
calling(data.assert_labels_valid).with_args(),
raises(DeepchecksValueError,
r'get_classes\(\) was not implemented correctly, the validation has failed with the error: "The '
r'samples sequence must contain only int values.". To test your formatting use the function '
Expand Down

0 comments on commit df47639

Please sign in to comment.