Skip to content

Commit

Permalink
Static pred vision to use dicts (#1653)
Browse files Browse the repository at this point in the history
* static_pred_vision_v2

* static_pred_vision_v2

* static_pred_vision_v2

* static_pred_vision_v2

* seed

* seq

* Apply suggestions from code review

* fixes

* fixes
  • Loading branch information
JKL98ISR committed Jun 19, 2022
1 parent 379db4b commit 902ee4b
Show file tree
Hide file tree
Showing 8 changed files with 82 additions and 94 deletions.
22 changes: 8 additions & 14 deletions deepchecks/vision/base_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,17 +63,15 @@ def run(
self.initialize_run(context, DatasetKind.TRAIN)

context.train.init_cache()
batch_start_index = 0

for batch in progressbar_factory.create(
for i, batch in enumerate(progressbar_factory.create(
iterable=context.train,
name='Ingesting Batches',
unit='Batch'
):
batch = Batch(batch, context, DatasetKind.TRAIN, batch_start_index)
)):
batch = Batch(batch, context, DatasetKind.TRAIN, i)
context.train.update_cache(batch)
self.update(context, batch, DatasetKind.TRAIN)
batch_start_index += len(batch)

with progressbar_factory.create_dummy(name='Computing Check', unit='Check'):
result = self.compute(context, DatasetKind.TRAIN)
Expand Down Expand Up @@ -139,26 +137,22 @@ def run(
)

context.train.init_cache()
batch_start_index = 0

for batch in train_pbar:
batch = Batch(batch, context, DatasetKind.TRAIN, batch_start_index)
for i, batch in enumerate(train_pbar):
batch = Batch(batch, context, DatasetKind.TRAIN, i)
context.train.update_cache(batch)
self.update(context, batch, DatasetKind.TRAIN)
batch_start_index += len(batch)

context.test.init_cache()
batch_start_index = 0

for batch in progressbar_factory.create(
for i, batch in enumerate(progressbar_factory.create(
iterable=context.test,
name='Ingesting Batches - Test Dataset',
unit='Batch'
):
batch = Batch(batch, context, DatasetKind.TEST, batch_start_index)
)):
batch = Batch(batch, context, DatasetKind.TEST, i)
context.test.update_cache(batch)
self.update(context, batch, DatasetKind.TEST)
batch_start_index += len(batch)

with progressbar_factory.create_dummy(name='Computing Check', unit='Check'):
result = self.compute(context)
Expand Down
28 changes: 9 additions & 19 deletions deepchecks/vision/batch_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import torch

from deepchecks.core import DatasetKind
from deepchecks.vision.task_type import TaskType

if TYPE_CHECKING:
from deepchecks.vision.context import Context
Expand All @@ -31,11 +32,11 @@ def __init__(
batch: Tuple[Iterable[Any], Iterable[Any]],
context: 'Context', # noqa
dataset_kind: DatasetKind,
batch_start_index: int
batch_index: int
):
self._context = context
self._dataset_kind = dataset_kind
self.batch_start_index = batch_start_index
self.batch_index = batch_index
self._batch = apply_to_tensor(batch, lambda it: it.to(self._context.device))
self._labels = None
self._predictions = None
Expand All @@ -53,11 +54,11 @@ def labels(self):
def _do_static_pred(self):
    """Return the static (user-supplied) predictions belonging to this batch.

    The sample indices of the batch are read from the data loader's batch
    sampler at position ``self.batch_index`` and used as keys into the static
    predictions stored on the context for this dataset kind.

    Returns
    -------
    torch.Tensor or tuple
        For classification datasets, the per-sample predictions combined with
        ``torch.stack``; otherwise a tuple holding one prediction per sample.
    """
    preds = self._context.static_predictions[self._dataset_kind]
    dataset = self._context.get_data_by_kind(self._dataset_kind)
    # NOTE(review): the batch sampler is re-iterated on every call; this assumes
    # it yields the same batches, in the same order, as the pass that produced
    # this batch - confirm for shuffling samplers.
    indexes = list(dataset.data_loader.batch_sampler)[self.batch_index]
    # Build the tuple explicitly instead of using itemgetter(*indexes): itemgetter
    # returns a bare item (not a 1-tuple) for a single index and raises TypeError
    # for none, which breaks batches holding exactly one sample.
    batch_preds = tuple(preds[index] for index in indexes)
    if dataset.task_type == TaskType.CLASSIFICATION:
        return torch.stack(batch_preds)
    return batch_preds

@property
def predictions(self):
Expand Down Expand Up @@ -91,18 +92,7 @@ def __getitem__(self, index: int):
def __len__(self):
"""Return length of batch."""
dataset = self._context.get_data_by_kind(self._dataset_kind)
dataset_len = dataset.num_samples
dataloader_len = len(dataset.data_loader)
max_len = int(dataset_len / dataloader_len + 0.5)
if self.batch_start_index + max_len > dataset_len: # last batch
return dataset_len - self.batch_start_index
return max_len

def get_index_in_dataset(self, index: int) -> int:
"""For given index in this batch returns the real index in the underlying dataset object. Can be used to \
later get samples for display."""
dataset = self._context.get_data_by_kind(self._dataset_kind)
return dataset.to_dataset_index(self.batch_start_index + index)[0]
return len(list(dataset.data_loader.batch_sampler)[self.batch_index])


T = TypeVar('T')
Expand Down
12 changes: 8 additions & 4 deletions deepchecks/vision/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
# ----------------------------------------------------------------------------
#
"""Module for base vision context."""
from typing import Dict, List, Mapping, Union
from operator import itemgetter
from typing import Dict, Mapping, Sequence, Union

import torch
from ignite.metrics import Metric
Expand Down Expand Up @@ -63,8 +64,8 @@ def __init__(self,
device: Union[str, torch.device, None] = None,
random_state: int = 42,
n_samples: int = None,
train_predictions: Union[List[torch.Tensor], torch.Tensor] = None,
test_predictions: Union[List[torch.Tensor], torch.Tensor] = None,
train_predictions: Dict[int, Union[Sequence[torch.Tensor], torch.Tensor]] = None,
test_predictions: Dict[int, Union[Sequence[torch.Tensor], torch.Tensor]] = None,
):
# Validations
if train is None and test is None and model is None:
Expand Down Expand Up @@ -116,7 +117,10 @@ def __init__(self,
[train_predictions, test_predictions]):
if dataset is not None:
try:
dataset.validate_infered_batch_predictions(predictions)
preds = itemgetter(*list(dataset.data_loader.batch_sampler)[0])(predictions)
if dataset.task_type == TaskType.CLASSIFICATION:
preds = torch.stack(preds)
dataset.validate_infered_batch_predictions(preds)
msg = None
self._static_predictions[dataset_type] = predictions
except ValidationError as ex:
Expand Down
20 changes: 10 additions & 10 deletions deepchecks/vision/detection_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
#
"""The vision/dataset module containing the vision Dataset class and its functions."""
from abc import abstractmethod
from typing import List
from typing import List, Sequence

import torch

Expand Down Expand Up @@ -68,7 +68,7 @@ def batch_to_labels(self, batch) -> List[torch.Tensor]:
raise DeepchecksNotImplementedError('batch_to_labels() must be implemented in a subclass')

@abstractmethod
def infer_on_batch(self, batch, model, device) -> List[torch.Tensor]:
def infer_on_batch(self, batch, model, device) -> Sequence[torch.Tensor]:
"""Return the predictions of the model on a batch of data.
Parameters
Expand All @@ -82,8 +82,8 @@ def infer_on_batch(self, batch, model, device) -> List[torch.Tensor]:
Returns
-------
List[torch.Tensor]
The predictions of the model on the batch. The predictions should be in a List of length N containing
Sequence[torch.Tensor]
The predictions of the model on the batch. The predictions should be in a sequence of length N containing
tensors of shape (B, 6), where N is the number of images, B is the number of bounding boxes detected in the
sample and each bounding box is represented by 6 values. See the notes for more info.
Expand Down Expand Up @@ -182,23 +182,23 @@ def validate_infered_batch_predictions(batch_predictions):
DeepchecksNotImplementedError
If infer_on_batch not implemented
"""
if not isinstance(batch_predictions, list):
raise ValidationError('Check requires detection predictions to be a list with an entry for each'
if not isinstance(batch_predictions, Sequence):
raise ValidationError('Check requires detection predictions to be a sequence with an entry for each'
' sample')
if len(batch_predictions) == 0:
raise ValidationError('Check requires detection predictions to be a non-empty list')
raise ValidationError('Check requires detection predictions to be a non-empty sequence')
if not isinstance(batch_predictions[0], torch.Tensor):
raise ValidationError('Check requires detection predictions to be a list of torch.Tensor')
raise ValidationError('Check requires detection predictions to be a sequence of torch.Tensor')
sample_idx = 0
# Find a non empty tensor to validate
while batch_predictions[sample_idx].shape[0] == 0:
sample_idx += 1
if sample_idx == len(batch_predictions):
return # No predictions to validate
if len(batch_predictions[sample_idx].shape) != 2:
raise ValidationError('Check requires detection predictions to be a list of 2D tensors')
raise ValidationError('Check requires detection predictions to be a sequence of 2D tensors')
if batch_predictions[sample_idx].shape[1] != 6:
raise ValidationError('Check requires detection predictions to be a list of 2D tensors, when '
raise ValidationError('Check requires detection predictions to be a sequence of 2D tensors, when '
'each row has 6 columns: [x, y, width, height, class_probability, class_id]')
if torch.min(batch_predictions[sample_idx]) < 0:
raise ValidationError('Found one of coordinates to be negative, Check requires object detection '
Expand Down
7 changes: 2 additions & 5 deletions deepchecks/vision/suite.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,9 +184,8 @@ def _update_loop(
)

# Run on all the batches
batch_start_index = 0
for batch in batches_pbar:
batch = Batch(batch, context, dataset_kind, batch_start_index)
for i, batch in enumerate(batches_pbar):
batch = Batch(batch, context, dataset_kind, i)
vision_data.update_cache(batch)
for check_idx, check in self.checks.items():
# If index in results the check already failed before
Expand All @@ -208,8 +207,6 @@ def _update_loop(
except Exception as exp:
results[check_idx] = CheckFailure(check, exp, type_suffix)

batch_start_index += len(batch)

# SingleDatasetChecks have different handling. If we had failure in them need to add suffix to the index of
# the results, else need to compute it.
if single_dataset_checks:
Expand Down
73 changes: 39 additions & 34 deletions tests/vision/base/test_static_predictions.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,48 +9,27 @@
# ----------------------------------------------------------------------------
#
#
import copy

import numpy as np
import torch
from hamcrest import (assert_that, calling, close_to, equal_to, has_entries, has_items, has_length, has_properties,
has_property, instance_of, is_, raises)
from hamcrest import (assert_that, close_to, equal_to, has_entries, has_items, has_length)

from deepchecks.core.check_result import CheckResult
from deepchecks.vision.base_checks import SingleDatasetCheck
from deepchecks.vision.batch_wrapper import Batch
from deepchecks.vision.checks.model_evaluation.class_performance import ClassPerformance
from deepchecks.vision.checks.model_evaluation.image_segment_performance import ImageSegmentPerformance
from deepchecks.vision.checks.model_evaluation.train_test_prediction_drift import TrainTestPredictionDrift
from deepchecks.vision.context import Context
from deepchecks.vision.suites.default_suites import full_suite
from deepchecks.vision.task_type import TaskType
from deepchecks.vision.vision_data import VisionData
from tests.base.utils import equal_condition_result
from tests.conftest import get_expected_results_length, validate_suite_result


class _StaticPred(SingleDatasetCheck):
def initialize_run(self, context: Context, dataset_kind):
self._pred_index = {}

def update(self, context: Context, batch: Batch, dataset_kind):
predictions = batch.predictions
indexes = [batch.get_index_in_dataset(index) for index in range(len(predictions))]
self._pred_index.update(dict(zip(indexes, predictions)))

def compute(self, context: Context, dataset_kind) -> CheckResult:
sorted_values = [v for _, v in sorted(self._pred_index.items(), key=lambda item: item[0])]
if context.get_data_by_kind(dataset_kind).task_type == TaskType.CLASSIFICATION:
sorted_values = torch.stack(sorted_values)
return CheckResult(sorted_values)


def _create_static_predictions(train: VisionData, test: VisionData, model):
def _create_static_predictions(train: VisionData, test: VisionData, model, device):
static_preds = []
for vision_data in [train, test]:
if vision_data is not None:
static_pred = _StaticPred().run(vision_data, model=model, n_samples=None).value
static_pred = {}
for i, batch in enumerate(vision_data):
predictions = vision_data.infer_on_batch(batch, model, device)
indexes = list(vision_data.data_loader.batch_sampler)[i]
static_pred.update(dict(zip(indexes, predictions)))
else:
static_pred = None
static_preds.append(static_pred)
Expand All @@ -61,7 +40,8 @@ def _create_static_predictions(train: VisionData, test: VisionData, model):
# copied from class_performance_test
def test_class_performance_mnist_largest(mnist_dataset_train, mnist_dataset_test, mock_trained_mnist, device):
# Arrange
train_preds, tests_preds = _create_static_predictions(mnist_dataset_train, mnist_dataset_test, mock_trained_mnist)
train_preds, tests_preds = _create_static_predictions(mnist_dataset_train, mnist_dataset_test,
mock_trained_mnist, device)
check = ClassPerformance(n_to_show=2, show_only='largest')
# Act
result = check.run(mnist_dataset_train, mnist_dataset_test,
Expand All @@ -76,10 +56,32 @@ def test_class_performance_mnist_largest(mnist_dataset_train, mnist_dataset_test
assert_that(first_row['Class'], equal_to(1))


# copied from class_performance_test, but the datasets are sampled before the static predictions are created
def test_class_performance_mnist_largest_sampled_before(mnist_dataset_train, mnist_dataset_test,
                                                        mock_trained_mnist, device):
    """Run ClassPerformance with static predictions built from pre-sampled MNIST datasets."""
    # Arrange: sample both datasets first, then build the static predictions from the samples
    sampling_kwargs = dict(shuffle=True, n_samples=1000, random_state=42)
    train_subset = mnist_dataset_train.copy(**sampling_kwargs)
    test_subset = mnist_dataset_test.copy(**sampling_kwargs)
    train_preds, tests_preds = _create_static_predictions(train_subset, test_subset,
                                                          mock_trained_mnist, device)
    check = ClassPerformance(n_to_show=2, show_only='largest')

    # Act
    result = check.run(
        train_subset,
        test_subset,
        train_predictions=train_preds,
        test_predictions=tests_preds,
        device=device,
        n_samples=None,
    )
    ordered_by_count = result.value.sort_values(by='Number of samples', ascending=False)
    first_row = ordered_by_count.iloc[0]

    # Assert
    shown_classes = set(result.value['Class'])
    assert_that(len(shown_classes), equal_to(2))
    assert_that(len(result.value), equal_to(8))
    assert_that(first_row['Value'], close_to(0.991, 0.001))
    assert_that(first_row['Number of samples'], equal_to(123))
    assert_that(first_row['Class'], equal_to(2))


# copied from class_performance_test but sampled
def test_class_performance_mnist_largest_sampled(mnist_dataset_train, mnist_dataset_test, mock_trained_mnist, device):
# Arrange
train_preds, tests_preds = _create_static_predictions(mnist_dataset_train, mnist_dataset_test, mock_trained_mnist)
train_preds, tests_preds = _create_static_predictions(mnist_dataset_train, mnist_dataset_test,
mock_trained_mnist, device)
check = ClassPerformance(n_to_show=2, show_only='largest')
# Act
result = check.run(mnist_dataset_train, mnist_dataset_test,
Expand All @@ -95,9 +97,10 @@ def test_class_performance_mnist_largest_sampled(mnist_dataset_train, mnist_data


# copied from image_segment_performance_test
def test_image_segment_performance_coco_and_condition(coco_train_visiondata, mock_trained_yolov5_object_detection):
def test_image_segment_performance_coco_and_condition(coco_train_visiondata, mock_trained_yolov5_object_detection, device):
# Arrange
train_preds, _ = _create_static_predictions(coco_train_visiondata, None, mock_trained_yolov5_object_detection)
train_preds, _ = _create_static_predictions(coco_train_visiondata, None,
mock_trained_yolov5_object_detection, device)
check = ImageSegmentPerformance().add_condition_score_from_mean_ratio_greater_than(0.5) \
.add_condition_score_from_mean_ratio_greater_than(0.1)
# Act
Expand Down Expand Up @@ -151,7 +154,8 @@ def test_train_test_prediction_with_drift_object_detection_change_max_cat(coco_t
# Arrange
train_preds, test_preds = _create_static_predictions(coco_train_visiondata,
coco_test_visiondata,
mock_trained_yolov5_object_detection)
mock_trained_yolov5_object_detection,
device)
check = TrainTestPredictionDrift(categorical_drift_method='PSI', max_num_categories_for_drift=100)

# Act
Expand All @@ -178,7 +182,8 @@ def test_suite(coco_train_visiondata, coco_test_visiondata,
mock_trained_yolov5_object_detection, device):
train_preds, test_preds = _create_static_predictions(coco_train_visiondata,
coco_test_visiondata,
mock_trained_yolov5_object_detection)
mock_trained_yolov5_object_detection,
device)

args = dict(train_dataset=coco_train_visiondata, test_dataset=coco_test_visiondata,
train_predictions=train_preds, test_predictions=test_preds)
Expand Down
8 changes: 4 additions & 4 deletions tests/vision/base/test_vision_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,17 +399,17 @@ def infer_on_batch(self, batch, model, device):

assert_that(calling(detection_data.validate_prediction).with_args(7, None, None),
raises(ValidationError,
'Check requires detection predictions to be a list with an entry for each sample'))
'Check requires detection predictions to be a sequence with an entry for each sample'))
assert_that(calling(detection_data.validate_prediction).with_args([], None, None),
raises(ValidationError,
'Check requires detection predictions to be a non-empty list'))
'Check requires detection predictions to be a non-empty sequence'))
assert_that(calling(detection_data.validate_prediction).with_args([8], None, None),
raises(ValidationError,
'Check requires detection predictions to be a list of torch.Tensor'))
'Check requires detection predictions to be a sequence of torch.Tensor'))
assert_that(detection_data.validate_prediction([torch.Tensor([])], None, None), equal_to(None))
assert_that(calling(detection_data.validate_prediction).with_args([torch.Tensor([[1, 2], [1, 2]])], None, None),
raises(ValidationError,
'Check requires detection predictions to be a list of 2D tensors, when '
'Check requires detection predictions to be a sequence of 2D tensors, when '
'each row has 6 columns: \[x, y, width, height, class_probability, class_id\]'))


Expand Down

0 comments on commit 902ee4b

Please sign in to comment.