Skip to content

Commit

Permalink
1586 add condition to image dataset drift (#1665)
Browse files Browse the repository at this point in the history
The condition pps_less_than is now implemented for the check
  • Loading branch information
TheSolY committed Jun 20, 2022
1 parent 4478be7 commit 06bef87
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,9 @@

import pandas as pd

from deepchecks.core import CheckResult, DatasetKind
from deepchecks.core import CheckResult, ConditionCategory, ConditionResult, DatasetKind
from deepchecks.core.check_utils.whole_dataset_drift_utils import run_whole_dataset_drift
from deepchecks.utils.strings import format_number
from deepchecks.vision import Batch, Context, TrainTestCheck
from deepchecks.vision.utils.image_properties import default_image_properties, get_column_type, validate_properties

Expand Down Expand Up @@ -141,3 +142,26 @@ def compute(self, context: Context) -> CheckResult:
displays.insert(0, headnote)

return CheckResult(value=values_dict, display=displays, header='Image Dataset Drift')

def add_condition_drift_score_less_than(self, threshold: float = 0.1):
    """Add condition - require drift score to be less than the threshold.

    The drift score used here is the ``domain_classifier_drift_score`` entry of the check result.

    Parameters
    ----------
    threshold : float , default: 0.1
        The max threshold for the drift score.

    Returns
    -------
    The value returned by ``self.add_condition`` for the newly registered condition.
    """
    def condition(result):
        drift_score = result['domain_classifier_drift_score']
        # Strict comparison: a score equal to the threshold fails the condition.
        if drift_score < threshold:
            return ConditionResult(ConditionCategory.PASS,
                                   f'Drift score {format_number(drift_score, 3)} is less than '
                                   f'{format_number(threshold)}')
        return ConditionResult(ConditionCategory.FAIL,
                               f'Drift score {format_number(drift_score, 3)} is not less than '
                               f'{format_number(threshold)}')

    return self.add_condition(f'Drift score is less than {threshold}', condition)
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
* `Which Image Properties Are Used? <#which-image-properties-are-used>`__
* `Loading The Data <#loading-the-data>`__
* `Run The Check <#run-the-check>`__
* `Define a Condition <#define-a-condition>`__
What Is Image Dataset Drift?
------------------------------------
Expand Down Expand Up @@ -109,10 +110,21 @@ def batch_to_images(self, batch):
test_dataloader = load_dataset(train=False, object_type='DataLoader')

drifted_train_ds = DriftedCOCO(train_dataloader)
test_ds = COCOData(test_dataloader)
test_ds_coco = COCOData(test_dataloader)

#%%
# Run the check again
# ^^^^^^^^^^^^^^^^^^^
check = ImageDatasetDrift()
check.run(train_dataset=drifted_train_ds, test_dataset=test_ds)
check.run(train_dataset=drifted_train_ds, test_dataset=test_ds_coco)


#%%
# Define a Condition
# ------------------
# Now, we will define a condition that requires the drift score to be less than a certain threshold. In this example
# we will set the threshold at 0.2.
# In order to demonstrate the condition, we will again use the original (not drifted) train dataset.

check = ImageDatasetDrift().add_condition_drift_score_less_than(0.2)
check.run(train_dataset=train_ds, test_dataset=test_ds).show(show_additional_outputs=False)
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from deepchecks.vision.checks import ImageDatasetDrift
from deepchecks.vision.datasets.detection.coco import COCOData
from tests.base.utils import equal_condition_result
from tests.vision.vision_conftest import *


Expand Down Expand Up @@ -143,3 +144,40 @@ def batch_to_images(self, batch):
'Mean Blue Relative Intensity': close_to(0, 0.01),
})
}))


def test_condition_fail(coco_train_dataloader, coco_test_dataloader, device):
    """Check that the default drift-score condition fails when the train images are drifted."""
    # Arrange
    class DriftCoco(COCOData):
        # Route the train images through pil_drift_formatter so they differ from the test images.
        def batch_to_images(self, batch):
            return pil_drift_formatter(batch)

    train = DriftCoco(coco_train_dataloader)
    test = COCOData(coco_test_dataloader)
    # No argument is passed, so the condition's default threshold (0.1) applies.
    check = ImageDatasetDrift().add_condition_drift_score_less_than()

    # Act
    result = check.run(train, test, random_state=42, device=device)

    # Assert
    assert_that(result.conditions_results[0], equal_condition_result(
        is_pass=False,
        name='Drift score is less than 0.1',
        details='Drift score 0.816 is not less than 0.1',
    ))


def test_condition_pass(coco_train_dataloader, coco_test_dataloader, device):
    """Check that the drift-score condition passes when train and test are the same dataset."""
    # Arrange
    test = COCOData(coco_test_dataloader)
    check = ImageDatasetDrift().add_condition_drift_score_less_than(0.3)

    # Act
    # The same dataset is passed as both train and test; the expected drift score below is 0.
    result = check.run(test, test, random_state=42, device=device)

    # Assert
    assert_that(result.conditions_results[0], equal_condition_result(
        is_pass=True,
        name='Drift score is less than 0.3',
        details='Drift score 0 is less than 0.3',
    ))

0 comments on commit 06bef87

Please sign in to comment.