Skip to content

Commit

Permalink
fix suite progress bar in single checks (#1232)
Browse files Browse the repository at this point in the history
* v1

* v1

* add_single_tests

* add_single_tests

* real_fix
  • Loading branch information
JKL98ISR committed Apr 10, 2022
1 parent 6fa04e6 commit 7a0b9de
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 21 deletions.
26 changes: 13 additions & 13 deletions deepchecks/vision/suite.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,8 +185,7 @@ def _update_loop(

# Run on all the batches
batch_start_index = 0
for batch_id, batch in enumerate(vision_data):
progress_bar.set_text(f'{100 * batch_id / (1. * n_batches):.0f}%')
for batch in vision_data:
batch = Batch(batch, context, dataset_kind, batch_start_index)
vision_data.update_cache(batch)
for check_idx, check in self.checks.items():
Expand Down Expand Up @@ -215,7 +214,8 @@ def _update_loop(
# SingleDatasetChecks have different handling. If we had failure in them need to add suffix to the index of
# the results, else need to compute it.
if single_dataset_checks:
progress_bar = ProgressBar('Computing Single Dataset Checks' + type_suffix, len(single_dataset_checks),
progress_bar = ProgressBar('Computing Single Dataset Checks' + type_suffix,
len(single_dataset_checks),
unit='Check')
progress_bars.append(progress_bar)
for idx, check in single_dataset_checks.items():
Expand All @@ -224,16 +224,16 @@ def _update_loop(
# If index in results we had a failure
if idx in results:
results[index_of_kind] = results.pop(idx)
continue
try:
result = check.compute(context, dataset_kind=dataset_kind)
result = check.finalize_check_result(result)
# Update header with dataset type only if both train and test ran
if run_train_test_checks:
result.header = result.get_header() + type_suffix
results[index_of_kind] = result
except Exception as exp:
results[index_of_kind] = CheckFailure(check, exp, type_suffix)
else:
try:
result = check.compute(context, dataset_kind=dataset_kind)
result = check.finalize_check_result(result)
# Update header with dataset type only if both train and test ran
if run_train_test_checks:
result.header = result.get_header() + type_suffix
results[index_of_kind] = result
except Exception as exp:
results[index_of_kind] = CheckFailure(check, exp, type_suffix)
progress_bar.inc_progress()

@classmethod
Expand Down
32 changes: 24 additions & 8 deletions tests/vision/base/test_suite.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,19 +241,35 @@ def compute(self, context) -> CheckResult:

assert_that(executions, is_({'initialize_run': 1, 'compute': 1}))


def test_full_suite_execution_mnist(mnist_dataset_train, mnist_dataset_test, mock_trained_mnist, device):
    """Run the full suite on MNIST, once with train+test and once train-only.

    The train-only run exercises the single-dataset-check code path
    (presumably the path targeted by the progress-bar fix in this commit —
    confirm against suite.py). For each argument set, the number of results
    must match the expected length computed from the suite and arguments.
    """
    suite = full_suite()
    # Two scenarios: both datasets supplied, and train dataset only.
    arguments = (
        dict(train_dataset=mnist_dataset_train, test_dataset=mnist_dataset_test,
             model=mock_trained_mnist, device=device),
        dict(train_dataset=mnist_dataset_train,
             model=mock_trained_mnist, device=device),
    )

    for args in arguments:
        result = suite.run(**args)
        length = get_expected_results_length(suite, args)
        validate_suite_result(result, length)


def test_full_suite_execution_coco(coco_train_visiondata, coco_test_visiondata,
                                   mock_trained_yolov5_object_detection, device):
    """Run the full suite on COCO, once with train+test and once train-only.

    Mirrors the MNIST full-suite test: the train-only scenario covers
    single-dataset checks, and each run's result count is validated against
    the expected length for the supplied arguments.
    """
    suite = full_suite()
    # Two scenarios: both datasets supplied, and train dataset only.
    arguments = (
        dict(train_dataset=coco_train_visiondata, test_dataset=coco_test_visiondata,
             model=mock_trained_yolov5_object_detection, device=device),
        dict(train_dataset=coco_train_visiondata,
             model=mock_trained_yolov5_object_detection, device=device),
    )

    for args in arguments:
        result = suite.run(**args)
        length = get_expected_results_length(suite, args)
        validate_suite_result(result, length)

0 comments on commit 7a0b9de

Please sign in to comment.