Fix vision suite run (#982)
* Update suite run to remove code duplication and fix bug

* Fix typo

* Add missing check to suite

* Fix lint
matanper committed Mar 9, 2022
1 parent d676fde commit 0d48726
Showing 3 changed files with 53 additions and 79 deletions.
3 changes: 2 additions & 1 deletion deepchecks/core/errors.py
@@ -17,7 +17,8 @@
'DeepchecksProcessError',
'NumberOfFeaturesLimitError',
'DatasetValidationError',
'ModelValidationError'
'ModelValidationError',
'DeepchecksNotImplementedError'
]


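Since the new entry makes the error part of the module's public surface, callers can import and catch it directly. A small illustrative snippet (assuming deepchecks is installed; the message text is made up):

```python
from deepchecks.core.errors import DeepchecksNotImplementedError

# Illustrative only: a check that cannot handle the given task type can raise
# this error, and calling code can catch it via the public export added above.
try:
    raise DeepchecksNotImplementedError('property is not implemented for this task type')
except DeepchecksNotImplementedError as err:
    print(f'check skipped: {err}')
```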
124 changes: 48 additions & 76 deletions deepchecks/vision/base.py
@@ -16,7 +16,6 @@

import torch
from torch import nn
from torch.utils.data import DataLoader
from ignite.metrics import Metric

from deepchecks.core.check import (
@@ -366,116 +365,68 @@ def run(
random_state=random_state
)

# Create instances of SingleDatasetCheck for train and test if train and test exist.
# This is needed because in the vision package checks update their internal state with update, so it will be
# easier to iterate and keep the check order if we have an instance for each dataset.
checks: Dict[
Union[str, int],
Union[SingleDatasetCheck, TrainTestCheck, ModelOnlyCheck]
] = OrderedDict({})

results: Dict[
Union[str, int],
Union[CheckResult, CheckFailure]
] = OrderedDict({})

for check_idx, check in list(self.checks.items()):
if isinstance(check, (TrainTestCheck, ModelOnlyCheck)):
run_train_test_checks = train_dataset is not None and test_dataset is not None

# Initialize here all the checks that are not single dataset, since those are initialized inside the update loop
for index, check in self.checks.items():
if not isinstance(check, SingleDatasetCheck):
try:
check.initialize_run(context)
except Exception as exp:
results[check_idx] = CheckFailure(check, exp)
checks[check_idx] = check
elif isinstance(check, SingleDatasetCheck):
if train_dataset is not None:
checks[str(check_idx) + ' - Train'] = check
if test_dataset is not None:
checks[str(check_idx) + ' - Test'] = check
else:
raise DeepchecksNotSupportedError(f'Don\'t know to handle check type {type(check)}')

run_train_test_checks = train_dataset is not None and test_dataset is not None
results[index] = CheckFailure(check, exp)

if train_dataset is not None:
self._update_loop(
checks=checks,
data_loader=train_dataset.data_loader,
context=context,
run_train_test_checks=run_train_test_checks,
results=results,
dataset_kind=DatasetKind.TRAIN
)
for check_idx, check in checks.items():
if check_idx not in results:
if str(check_idx).endswith('Train'):
try:
results[check_idx] = check.compute(context, dataset_kind=DatasetKind.TRAIN)
except Exception as exp:
results[check_idx] = CheckFailure(check, exp, ' - Train')

if test_dataset is not None:
self._update_loop(
checks=checks,
data_loader=test_dataset.data_loader,
context=context,
run_train_test_checks=run_train_test_checks,
results=results,
dataset_kind=DatasetKind.TEST
)
for check_idx, check in checks.items():
if check_idx not in results:
if str(check_idx).endswith('Test'):
try:
results[check_idx] = check.compute(context, dataset_kind=DatasetKind.TEST)
except Exception as exp:
results[check_idx] = CheckFailure(check, exp, ' - Test')

for check_idx, check in checks.items():
if check_idx not in results:
try:
if not isinstance(check, SingleDatasetCheck):
results[check_idx] = check.compute(context)
except Exception as exp:
results[check_idx] = CheckFailure(check, exp)

# Update check result names for SingleDatasetChecks and finalize results
for check_idx, result in results.items():
if isinstance(result, CheckResult):
result = finalize_check_result(result, checks[check_idx])
results[check_idx] = result
# Update header only if both train and test ran
if run_train_test_checks:
result.header = (
f'{result.get_header()} - Train Dataset'
if str(check_idx).endswith(' - Train')
else f'{result.get_header()} - Test Dataset'
)

for check_idx, check in self.checks.items():
try:
# if check index in results we had failure, and SingleDatasetCheck have already been calculated inside
# the loops
if check_idx not in results and not isinstance(check, SingleDatasetCheck):
result = check.compute(context)
result = finalize_check_result(result, check)
results[check_idx] = result
except Exception as exp:
results[check_idx] = CheckFailure(check, exp)

# The results are ordered as they ran instead of in the order they were defined, therefore sort by key
sorted_result_values = [value for name, value in sorted(results.items(), key=lambda pair: str(pair[0]))]
return SuiteResult(self.name, sorted_result_values)

def _update_loop(
self,
checks: Dict[
Union[str, int],
Union[SingleDatasetCheck, TrainTestCheck, ModelOnlyCheck]
],
data_loader: DataLoader,
context: Context,
run_train_test_checks: bool,
results: Dict[Union[str, int], Union[CheckResult, CheckFailure]],
dataset_kind
dataset_kind: DatasetKind
):
if dataset_kind == DatasetKind.TEST:
type_suffix = ' - Test'
else:
type_suffix = ' - Train'
type_suffix = ' - Test Dataset' if dataset_kind == DatasetKind.TEST else ' - Train Dataset'
data_loader = context.get_data_by_kind(dataset_kind)
n_batches = len(data_loader)
progress_bar = ProgressBar(self.name + type_suffix, n_batches)

for idx, check in checks.items():
if str(idx).endswith(type_suffix):
# SingleDatasetChecks have different handling, need to initialize them here (to have them ready for different
# dataset kind)
for idx, check in self.checks.items():
if isinstance(check, SingleDatasetCheck):
try:
check.initialize_run(context, dataset_kind=dataset_kind)
except Exception as exp:
@@ -484,7 +435,10 @@ def _update_loop(
for batch_id, batch in enumerate(data_loader):
progress_bar.set_text(f'{100 * batch_id / (1. * n_batches):.0f}%')
batch = apply_to_tensor(batch, lambda it: it.to(context.device))
for check_idx, check in checks.items():
for check_idx, check in self.checks.items():
# If index in results the check already failed before
if check_idx in results:
continue
try:
if isinstance(check, TrainTestCheck):
if run_train_test_checks is True:
@@ -493,8 +447,7 @@
msg = 'Check is irrelevant if not supplied with both train and test datasets'
results[check_idx] = self._get_unsupported_failure(check, msg)
elif isinstance(check, SingleDatasetCheck):
if str(check_idx).endswith(type_suffix):
check.update(context, batch, dataset_kind=dataset_kind)
check.update(context, batch, dataset_kind=dataset_kind)
elif isinstance(check, ModelOnlyCheck):
pass
else:
@@ -506,6 +459,25 @@

progress_bar.close()

# SingleDatasetChecks have different handling. If we had failure in them need to add suffix to the index of
# the results, else need to compute it.
for idx, check in self.checks.items():
if isinstance(check, SingleDatasetCheck):
index_of_kind = str(idx) + type_suffix
# If index in results we had a failure
if idx in results:
results[index_of_kind] = results.pop(idx)
continue
try:
result = check.compute(context, dataset_kind=dataset_kind)
result = finalize_check_result(result, check)
# Update header with dataset type only if both train and test ran
if run_train_test_checks:
result.header = result.get_header() + type_suffix
results[index_of_kind] = result
except Exception as exp:
results[index_of_kind] = CheckFailure(check, exp, type_suffix)

@classmethod
def _get_unsupported_failure(cls, check, msg):
return CheckFailure(check, DeepchecksNotSupportedError(msg))
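The result-keying scheme is the heart of the deduplication in this refactor: each SingleDatasetCheck is computed once per dataset kind, its result is stored under an index suffixed with ' - Train Dataset' or ' - Test Dataset', an earlier failure is simply re-keyed with the same suffix, and the suite finally sorts results by their string key. Below is a minimal, self-contained sketch of that idea only; DatasetKind, the check objects and the compute callable are simplified stand-ins, not the real deepchecks classes.

```python
from collections import OrderedDict
from enum import Enum


class DatasetKind(Enum):
    TRAIN = 'Train'
    TEST = 'Test'


def suffix_for(kind):
    # Mirrors the type_suffix chosen in _update_loop above.
    return ' - Test Dataset' if kind == DatasetKind.TEST else ' - Train Dataset'


def collect_single_dataset_results(checks, results, kind, compute):
    """Re-key earlier failures and compute per-dataset results.

    checks:  mapping index -> check object (plain strings stand in here)
    results: mapping that may already hold a failure under the plain index
    compute: callable(check, kind) -> result, standing in for check.compute()
    """
    suffix = suffix_for(kind)
    for idx, check in checks.items():
        keyed = f'{idx}{suffix}'
        if idx in results:
            # A failure recorded during initialize/update keeps priority and
            # is simply re-keyed with the dataset suffix.
            results[keyed] = results.pop(idx)
            continue
        results[keyed] = compute(check, kind)
    return results


if __name__ == '__main__':
    checks = OrderedDict({0: 'image-property-check', 1: 'brightness-check'})
    results = OrderedDict({1: 'failure: update raised'})  # pre-existing failure
    compute = lambda check, kind: f'{check} computed on {kind.value}'
    collect_single_dataset_results(checks, results, DatasetKind.TRAIN, compute)
    collect_single_dataset_results(checks, results, DatasetKind.TEST, compute)
    # The suite then sorts by string key so train/test variants stay adjacent.
    for key in sorted(results, key=str):
        print(key, '->', results[key])
```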
5 changes: 3 additions & 2 deletions deepchecks/vision/suites/default_suites.py
Expand Up @@ -15,7 +15,7 @@
"""
from deepchecks.vision.checks import ClassPerformance, TrainTestLabelDrift, MeanAveragePrecisionReport, \
MeanAverageRecallReport, ImagePropertyDrift, ImageDatasetDrift, SimpleModelComparison, ConfusionMatrixReport, \
RobustnessReport, TrainTestPredictionDrift
RobustnessReport, TrainTestPredictionDrift, ImageSegmentPerformance
from deepchecks.vision import Suite


@@ -46,7 +46,8 @@ def model_evaluation() -> Suite:
MeanAverageRecallReport(),
SimpleModelComparison(),
ConfusionMatrixReport(),
RobustnessReport().add_condition_degradation_not_greater_than()
RobustnessReport().add_condition_degradation_not_greater_than(),
ImageSegmentPerformance()
)


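A short usage sketch of the updated default suite (assuming the deepchecks vision extra is installed); it only instantiates the suite and lists its checks, which should now include ImageSegmentPerformance:

```python
from deepchecks.vision.suites.default_suites import model_evaluation

suite = model_evaluation()
# The model evaluation suite should now list ImageSegmentPerformance alongside
# the performance, drift and robustness checks added above.
for check in suite.checks.values():
    print(type(check).__name__)
```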
