Refactor for labels & image outlier #1091

Merged: 8 commits, Mar 23, 2022
Changes from 5 commits
3 changes: 2 additions & 1 deletion deepchecks/core/errors.py
@@ -19,7 +19,8 @@
'DatasetValidationError',
'ModelValidationError',
'DeepchecksNotImplementedError',
'ValidationError'
'ValidationError',
'DeepchecksBaseError'
]


3 changes: 2 additions & 1 deletion deepchecks/utils/strings.py
@@ -42,7 +42,8 @@
'format_datetime',
'get_docs_summary',
'get_ellipsis',
'to_snake_case'
'to_snake_case',
'create_new_file_name'
]


5 changes: 3 additions & 2 deletions deepchecks/vision/__init__.py
@@ -10,11 +10,12 @@
#
"""Package for vision functionality."""
import logging
from .batch_wrapper import Batch
from .vision_data import VisionData
from .context import Context
from .suite import Suite
from .classification_data import ClassificationData
from .detection_data import DetectionData
from .context import Context
from .suite import Suite, Batch
from .base_checks import SingleDatasetCheck, TrainTestCheck, ModelOnlyCheck


21 changes: 13 additions & 8 deletions deepchecks/vision/base_checks.py
@@ -24,8 +24,7 @@
)
from deepchecks.vision.context import Context
from deepchecks.vision.vision_data import VisionData

from .context import Batch
from deepchecks.vision.batch_wrapper import Batch


logger = logging.getLogger('deepchecks')
@@ -61,10 +60,12 @@ def run(
self.initialize_run(context, DatasetKind.TRAIN)

context.train.init_cache()
batch_start_index = 0
for batch in context.train:
batch = Batch(batch, context, DatasetKind.TRAIN)
context.train.update_cache(batch.labels)
batch = Batch(batch, context, DatasetKind.TRAIN, batch_start_index)
context.train.update_cache(batch)
self.update(context, batch, DatasetKind.TRAIN)
batch_start_index += len(batch)

return self.finalize_check_result(self.compute(context, DatasetKind.TRAIN))
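
The new bookkeeping above threads a running batch_start_index through the loop so that each Batch can translate a batch-local index into a dataset-wide one (see get_index_in_dataset in batch_wrapper.py below). A minimal sketch of the arithmetic, using a plain list of lists in place of a real dataloader (illustrative only, not deepchecks API):

# Illustrative: how batch_start_index maps batch-local indices to dataset-wide ones.
batches = [['a', 'b', 'c'], ['d', 'e'], ['f']]  # stand-in for a dataloader
flat = [sample for batch in batches for sample in batch]

batch_start_index = 0
for batch in batches:
    for i in range(len(batch)):
        # This is the offset Batch.get_index_in_dataset applies before any
        # dataset-specific re-indexing via batch_index_to_dataset_index.
        assert batch[i] == flat[batch_start_index + i]
    batch_start_index += len(batch)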

@@ -109,16 +110,20 @@ def run(
self.initialize_run(context)

context.train.init_cache()
batch_start_index = 0
for batch in context.train:
batch = Batch(batch, context, DatasetKind.TRAIN)
context.train.update_cache(batch.labels)
batch = Batch(batch, context, DatasetKind.TRAIN, batch_start_index)
context.train.update_cache(batch)
self.update(context, batch, DatasetKind.TRAIN)
batch_start_index += len(batch)

context.test.init_cache()
batch_start_index = 0
for batch in context.test:
batch = Batch(batch, context, DatasetKind.TEST)
context.test.update_cache(batch.labels)
batch = Batch(batch, context, DatasetKind.TEST, batch_start_index)
context.test.update_cache(batch)
self.update(context, batch, DatasetKind.TEST)
batch_start_index += len(batch)

return self.finalize_check_result(self.compute(context))

100 changes: 100 additions & 0 deletions deepchecks/vision/batch_wrapper.py
@@ -0,0 +1,100 @@
# ----------------------------------------------------------------------------
# Copyright (C) 2021-2022 Deepchecks (https://www.deepchecks.com)
#
# This file is part of Deepchecks.
# Deepchecks is distributed under the terms of the GNU Affero General
# Public License (version 3 or later).
# You should have received a copy of the GNU Affero General Public License
# along with Deepchecks. If not, see <http://www.gnu.org/licenses/>.
# ----------------------------------------------------------------------------
#
"""Contains code for BatchWrapper."""
from typing import Tuple, Iterable, Any, TypeVar, Callable, cast

import torch

from deepchecks.core import DatasetKind


__all__ = ['Batch']


class Batch:
    """Represent a dataset batch returned by the dataloader during iteration."""

    def __init__(
        self,
        batch: Tuple[Iterable[Any], Iterable[Any]],
        context: 'Context',  # noqa
        dataset_kind: DatasetKind,
        batch_start_index: int
    ):
        self._context = context
        self._dataset_kind = dataset_kind
        self.batch_start_index = batch_start_index
        self._batch = apply_to_tensor(batch, lambda it: it.to(self._context.device))
        self._labels = None
        self._predictions = None
        self._images = None

    @property
    def labels(self):
        """Return labels for the batch, formatted in deepchecks format."""
        if self._labels is None:
            dataset = self._context.get_data_by_kind(self._dataset_kind)
            dataset.assert_labels_valid()
            self._labels = dataset.batch_to_labels(self._batch)
        return self._labels

    @property
    def predictions(self):
        """Return predictions for the batch, formatted in deepchecks format."""
        if self._predictions is None:
            dataset = self._context.get_data_by_kind(self._dataset_kind)
            # Accessing the model will raise an error if no model was given
            model = self._context.model
            self._context.assert_predictions_valid(self._dataset_kind)
            self._predictions = dataset.infer_on_batch(self._batch, model, self._context.device)
        return self._predictions

    @property
    def images(self):
        """Return images for the batch, formatted in deepchecks format."""
        if self._images is None:
            dataset = self._context.get_data_by_kind(self._dataset_kind)
            dataset.assert_images_valid()
            self._images = dataset.batch_to_images(self._batch)
        return self._images

    def __getitem__(self, index):
        """Return batch item by index."""
        return self._batch[index]

    def __len__(self):
        """Return the length of the batch."""
        return len(self._batch)

    def get_index_in_dataset(self, index):
        """Return the real index in the underlying dataset for a given index in this batch.

        Can be used later to get samples for display.
        """
        dataset = self._context.get_data_by_kind(self._dataset_kind)
        return dataset.batch_index_to_dataset_index(self.batch_start_index + index)


T = TypeVar('T')


def apply_to_tensor(
    x: T,
    fn: Callable[[torch.Tensor], torch.Tensor]
) -> Any:
    """Apply the provided function to tensor instances recursively."""
    if isinstance(x, torch.Tensor):
        return cast(T, fn(x))
    elif isinstance(x, (str, bytes, bytearray)):
        return x
    elif isinstance(x, (list, tuple, set)):
        return type(x)(apply_to_tensor(it, fn) for it in x)
    elif isinstance(x, dict):
        return type(x)((k, apply_to_tensor(v, fn)) for k, v in x.items())
    return x
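
Since apply_to_tensor is self-contained, its behavior is easy to check in isolation: it maps the function over tensors found anywhere in a nested list/tuple/set/dict structure and passes everything else through. A quick illustrative example (the data here is made up):

# Illustrative check of apply_to_tensor on a nested batch-like structure.
import torch

batch = {'images': [torch.zeros(2, 3), torch.ones(2, 3)], 'ids': ['a', 'b']}
doubled = apply_to_tensor(batch, lambda t: t * 2)

assert torch.equal(doubled['images'][1], torch.full((2, 3), 2.0))
assert doubled['ids'] == ['a', 'b']  # non-tensor leaves are returned unchanged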
@@ -24,7 +24,8 @@
DEFAULT_CLASSIFICATION_LABEL_PROPERTIES,
DEFAULT_OBJECT_DETECTION_LABEL_PROPERTIES,
validate_properties,
get_column_type
get_column_type,
properties_flatten
)


@@ -127,7 +128,8 @@ def update(self, context: Context, batch: Batch, dataset_kind):
raise DeepchecksNotSupportedError(f'Unsupported dataset kind {dataset_kind}')

for label_property in self._label_properties:
properties[label_property['name']] += label_property['method'](batch.labels)
# Flatten the properties, since this check does not care about the property-per-sample coupling
properties[label_property['name']] += properties_flatten(label_property['method'](batch.labels))

def compute(self, context: Context) -> CheckResult:
"""Calculate drift on label properties samples that were collected during update() calls.
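Both this label-drift check and the prediction-drift check below route property values through properties_flatten (imported from deepchecks.vision.utils.label_prediction_properties). The library's implementation is not shown in this diff; as a hypothetical sketch of the intended behavior, flattening one level of per-sample nesting:

# Hypothetical sketch (not the library code): flatten one level of
# per-sample nesting, leaving already-flat values untouched.
def properties_flatten_sketch(values):
    flat = []
    for value in values:
        if isinstance(value, (list, tuple)):
            flat.extend(value)  # e.g. per-bbox values for one sample
        else:
            flat.append(value)  # scalar value for one sample
    return flat

assert properties_flatten_sketch([[1, 2], [3]]) == [1, 2, 3]
assert properties_flatten_sketch([1, 2, 3]) == [1, 2, 3]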
@@ -21,7 +21,8 @@
from deepchecks.vision import Context, TrainTestCheck, Batch
from deepchecks.vision.vision_data import TaskType
from deepchecks.vision.utils.label_prediction_properties import validate_properties, \
DEFAULT_CLASSIFICATION_PREDICTION_PROPERTIES, DEFAULT_OBJECT_DETECTION_PREDICTION_PROPERTIES, get_column_type
DEFAULT_CLASSIFICATION_PREDICTION_PROPERTIES, DEFAULT_OBJECT_DETECTION_PREDICTION_PROPERTIES, get_column_type, \
properties_flatten

__all__ = ['TrainTestPredictionDrift']

@@ -120,7 +121,10 @@ def update(self, context: Context, batch: Batch, dataset_kind):
raise DeepchecksNotSupportedError(f'Unsupported dataset kind {dataset_kind}')

for prediction_property in self._prediction_properties:
properties[prediction_property['name']] += prediction_property['method'](batch.predictions)
# Flatten the properties, since this check does not care about the property-per-sample coupling
properties[prediction_property['name']] += properties_flatten(
prediction_property['method'](batch.predictions)
)

def compute(self, context: Context) -> CheckResult:
"""Calculate drift on prediction properties samples that were collected during update() calls.
55 changes: 2 additions & 53 deletions deepchecks/vision/context.py
@@ -10,77 +10,26 @@
#
"""Module for base vision context."""
import logging
from typing import Mapping, Union, Iterable, Any, Tuple
from typing import Mapping, Union

import torch
from torch import nn
from ignite.metrics import Metric

from deepchecks.core import DatasetKind
from deepchecks.vision.vision_data import VisionData, TaskType
from deepchecks.vision.utils.validation import apply_to_tensor
from deepchecks.core.errors import (
DatasetValidationError, DeepchecksNotImplementedError, ModelValidationError,
DeepchecksNotSupportedError, DeepchecksValueError, ValidationError
)


__all__ = ['Context', 'Batch']
__all__ = ['Context']


logger = logging.getLogger('deepchecks')


class Batch:
"""Represents dataset batch returned by the dataloader during iteration."""

def __init__(
self,
batch: Tuple[Iterable[Any], Iterable[Any]],
context: 'Context',
dataset_kind: DatasetKind
):
self._context = context
self._dataset_kind = dataset_kind
self._batch = apply_to_tensor(batch, lambda it: it.to(self._context.device))
self._labels = None
self._predictions = None
self._images = None

@property
def labels(self):
"""Return labels for the batch, formatted in deepchecks format."""
if self._labels is None:
dataset = self._context.get_data_by_kind(self._dataset_kind)
dataset.assert_labels_valid()
self._labels = dataset.batch_to_labels(self._batch)
return self._labels

@property
def predictions(self):
"""Return predictions for the batch, formatted in deepchecks format."""
if self._predictions is None:
dataset = self._context.get_data_by_kind(self._dataset_kind)
# Calling model will raise error if model was not given
model = self._context.model
self._context.assert_predictions_valid(self._dataset_kind)
self._predictions = dataset.infer_on_batch(self._batch, model, self._context.device)
return self._predictions

@property
def images(self):
"""Return images for the batch, formatted in deepchecks format."""
if self._images is None:
dataset = self._context.get_data_by_kind(self._dataset_kind)
dataset.assert_images_valid()
self._images = dataset.batch_to_images(self._batch)
return self._images

def __getitem__(self, index):
"""Return batch item by index."""
return self._batch[index]


class Context:
"""Contains all the data + properties the user has passed to a check/suite, and validates it seamlessly.

10 changes: 6 additions & 4 deletions deepchecks/vision/suite.py
@@ -26,8 +26,7 @@
from deepchecks.vision.base_checks import ModelOnlyCheck, SingleDatasetCheck, TrainTestCheck
from deepchecks.vision.context import Context
from deepchecks.vision.vision_data import VisionData

from .context import Batch
from deepchecks.vision.batch_wrapper import Batch


__all__ = ['Suite']
@@ -181,10 +180,11 @@ def _update_loop(
progress_bars.append(progress_bar)

# Run on all the batches
batch_start_index = 0
for batch_id, batch in enumerate(vision_data):
progress_bar.set_text(f'{100 * batch_id / (1. * n_batches):.0f}%')
batch = Batch(batch, context, dataset_kind)
vision_data.update_cache(batch.labels)
batch = Batch(batch, context, dataset_kind, batch_start_index)
vision_data.update_cache(batch)
for check_idx, check in self.checks.items():
# If index in results the check already failed before
if check_idx in results:
@@ -204,6 +204,8 @@
raise TypeError(f'Don\'t know how to handle type {check.__class__.__name__} in suite.')
except Exception as exp:
results[check_idx] = CheckFailure(check, exp, type_suffix)

batch_start_index += len(batch)
progress_bar.inc_progress()

# SingleDatasetChecks have different handling. If we had failure in them need to add suffix to the index of
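Taken together, the refactor means a check's update hook now receives the Batch wrapper rather than raw dataloader output, with labels, predictions, and images computed lazily and cached per batch. A hedged sketch of a minimal custom check under this API: the hook signatures follow the run() methods in this diff, while the check itself (CountSamples and its counting logic) is invented for illustration.

# Illustrative sketch of a custom check against the new Batch wrapper.
# Hook signatures mirror base_checks.run() above; the logic is made up.
from deepchecks.core import CheckResult, DatasetKind
from deepchecks.vision import SingleDatasetCheck

class CountSamples(SingleDatasetCheck):
    def initialize_run(self, context, dataset_kind: DatasetKind):
        self._count = 0

    def update(self, context, batch, dataset_kind: DatasetKind):
        # batch.labels is computed lazily on first access, then cached
        self._count += len(batch.labels)

    def compute(self, context, dataset_kind: DatasetKind) -> CheckResult:
        return CheckResult(self._count)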