Skip to content

Commit

Permalink
Add progress bars to single checks runs (#1236)
Browse files Browse the repository at this point in the history
  • Loading branch information
matanper committed Apr 11, 2022
1 parent 1f10d97 commit e49958b
Showing 1 changed file with 37 additions and 3 deletions.
40 changes: 37 additions & 3 deletions deepchecks/vision/base_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
ModelOnlyBaseCheck,
DatasetKind
)
from deepchecks.core.display_suite import ProgressBar
from deepchecks.vision.context import Context
from deepchecks.vision.vision_data import VisionData
from deepchecks.vision.batch_wrapper import Batch
Expand Down Expand Up @@ -52,6 +53,8 @@ def run(
) -> CheckResult:
"""Run check."""
assert self.context_type is not None
p_bar = ProgressBar('Validating Input', 1, unit='')

# Context is copying the data object, then not using the original after the init
context: Context = self.context_type(dataset,
model=model,
Expand All @@ -60,20 +63,29 @@ def run(
n_samples=n_samples)

self.initialize_run(context, DatasetKind.TRAIN)
p_bar.inc_progress()
p_bar.close()

p_bar = ProgressBar('Ingesting Batches', len(context.train), unit='Batch')
context.train.init_cache()
batch_start_index = 0
for batch in context.train:
batch = Batch(batch, context, DatasetKind.TRAIN, batch_start_index)
context.train.update_cache(batch)
self.update(context, batch, DatasetKind.TRAIN)
batch_start_index += len(batch)
p_bar.inc_progress()
p_bar.close()

p_bar = ProgressBar('Computing Check', 1, unit='Check')
result = self.compute(context, DatasetKind.TRAIN)
footnote = context.get_is_sampled_footnote(DatasetKind.TRAIN)
if footnote:
result.display.append(footnote)
return self.finalize_check_result(result)
result = self.finalize_check_result(result)
p_bar.inc_progress()
p_bar.close()
return result

def initialize_run(self, context: Context, dataset_kind: DatasetKind):
"""Initialize run before starting updating on batches. Optional."""
Expand Down Expand Up @@ -107,6 +119,7 @@ def run(
) -> CheckResult:
"""Run check."""
assert self.context_type is not None
p_bar = ProgressBar('Validating Input', 1, unit='')
# Context is copying the data object, then not using the original after the init
context: Context = self.context_type(train_dataset,
test_dataset,
Expand All @@ -116,28 +129,40 @@ def run(
n_samples=n_samples)

self.initialize_run(context)
p_bar.inc_progress()
p_bar.close()

p_bar = ProgressBar('Ingesting Batches - Train Dataset', len(context.train), unit='Batch')
context.train.init_cache()
batch_start_index = 0
for batch in context.train:
batch = Batch(batch, context, DatasetKind.TRAIN, batch_start_index)
context.train.update_cache(batch)
self.update(context, batch, DatasetKind.TRAIN)
batch_start_index += len(batch)
p_bar.inc_progress()
p_bar.close()

p_bar = ProgressBar('Ingesting Batches - Test Dataset', len(context.train), unit='Batch')
context.test.init_cache()
batch_start_index = 0
for batch in context.test:
batch = Batch(batch, context, DatasetKind.TEST, batch_start_index)
context.test.update_cache(batch)
self.update(context, batch, DatasetKind.TEST)
batch_start_index += len(batch)
p_bar.inc_progress()
p_bar.close()

p_bar = ProgressBar('Computing Check', 1, unit='Check')
result = self.compute(context)
footnote = context.get_is_sampled_footnote()
if footnote:
result.display.append(footnote)
return self.finalize_check_result(result)
result = self.finalize_check_result(result)
p_bar.inc_progress()
p_bar.close()
return result

def initialize_run(self, context: Context):
"""Initialize run before starting updating on batches. Optional."""
Expand Down Expand Up @@ -165,10 +190,19 @@ def run(
) -> CheckResult:
"""Run check."""
assert self.context_type is not None
p_bar = ProgressBar('Validating Input', 1, unit='')
context: Context = self.context_type(model=model, device=device, random_state=random_state)

self.initialize_run(context)
return self.finalize_check_result(self.compute(context))

p_bar.inc_progress()
p_bar.close()

p_bar = ProgressBar('Computing Check', 1, unit='Check')
result = self.finalize_check_result(self.compute(context))
p_bar.inc_progress()
p_bar.close()
return result

def initialize_run(self, context: Context):
"""Initialize run before starting updating on batches. Optional."""
Expand Down

0 comments on commit e49958b

Please sign in to comment.