Skip to content

Commit

Permalink
Add Similar image leakage check (#1135)
Browse files Browse the repository at this point in the history
  • Loading branch information
noamzbr committed Apr 5, 2022
1 parent 000fd77 commit 826fef9
Show file tree
Hide file tree
Showing 12 changed files with 483 additions and 7 deletions.
2 changes: 1 addition & 1 deletion README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ can be found in our `API Reference`_.
train_df = pd.read_csv('train_data.csv')
test_df = pd.read_csv('test_data.csv')
# Initialize and run desired check
TrainTestFeatureDrift().run(train_data, test_data)
TrainTestFeatureDrift().run(train_df, test_df)
Will produce output of the type:

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -147,8 +147,8 @@ def run_logic(self, context: Context) -> CheckResult:
if self.sort_feature_by == 'feature importance' and features_importance is not None:
columns_order = features_importance.sort_values(ascending=False).head(self.n_top_columns).index
else:
columns_order = sorted(train_dataset.features, key=lambda col: values_dict[col]['Drift score'], reverse=True
)[:self.n_top_columns]
columns_order = sorted(list(values_dict.keys()), key=lambda col: values_dict[col]['Drift score'],
reverse=True)[:self.n_top_columns]

sorted_by = self.sort_feature_by if features_importance is not None else 'drift score'

Expand Down
5 changes: 3 additions & 2 deletions deepchecks/vision/checks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
RobustnessReport, ConfusionMatrixReport, SimpleModelComparison, ImageSegmentPerformance
from .distribution import TrainTestLabelDrift, ImageDatasetDrift, ImagePropertyDrift, TrainTestPredictionDrift, \
ImagePropertyOutliers, LabelPropertyOutliers, HeatmapComparison
from .methodology import SimpleFeatureContribution
from .methodology import SimpleFeatureContribution, SimilarImageLeakage

__all__ = [
'ClassPerformance',
Expand All @@ -31,5 +31,6 @@
'SimpleFeatureContribution',
'ImagePropertyOutliers',
'LabelPropertyOutliers',
'HeatmapComparison'
'HeatmapComparison',
'SimilarImageLeakage'
]
2 changes: 2 additions & 0 deletions deepchecks/vision/checks/methodology/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@
#
"""Module containing the distribution checks in the vision package."""
from .simple_feature_contribution import SimpleFeatureContribution
from .similar_image_leakage import SimilarImageLeakage

__all__ = [
'SimpleFeatureContribution',
'SimilarImageLeakage'
]
202 changes: 202 additions & 0 deletions deepchecks/vision/checks/methodology/similar_image_leakage.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,202 @@
# ----------------------------------------------------------------------------
# Copyright (C) 2021-2022 Deepchecks (https://www.deepchecks.com)
#
# This file is part of Deepchecks.
# Deepchecks is distributed under the terms of the GNU Affero General
# Public License (version 3 or later).
# You should have received a copy of the GNU Affero General Public License
# along with Deepchecks. If not, see <http://www.gnu.org/licenses/>.
# ----------------------------------------------------------------------------
#
"""Module contains the similar image leakage check."""
import random
from typing import TypeVar, List, Tuple
import numpy as np
from PIL.Image import fromarray
from imagehash import average_hash

from deepchecks import ConditionResult, ConditionCategory
from deepchecks.core import CheckResult, DatasetKind
from deepchecks.core.errors import DeepchecksValueError
from deepchecks.vision import Context, TrainTestCheck, Batch

__all__ = ['SimilarImageLeakage']

from deepchecks.vision.utils.image_functions import prepare_thumbnail

SIL = TypeVar('SIL', bound='SimilarImageLeakage')


class SimilarImageLeakage(TrainTestCheck):
"""Check for images in training that are similar to images in test.
Parameters
----------
n_top_show: int, default: 5
Number of images to show, sorted by the similarity score between them
hash_size: int, default: 8
Size of hashed image. Algorithm will hash the image to a hash_size*hash_size binary image. Increasing this
value will increase the accuracy of the algorithm, but will also increase the time and memory requirements.
similarity_threshold: float, default: 0.1
Similarity threshold (0,1). The similarity score defines what is the ratio of pixels that are different between
the two images. If the similarity score is below the threshold, the images are considered similar.
Note: The score is defined such that setting it to 1 will result in similarity being detected for all images
with up to half their pixels differing from each other. For a value of 1, random images (which on average
differ from each other by half their pixels) will be detected as similar half the time. To further illustrate,
for a hash of 8X8, setting the score to 1 will result with all images with up to 32 different pixels being
considered similar.
"""

_THUMBNAIL_SIZE = (200, 200)

def __init__(
self,
n_top_show: int = 10,
hash_size: int = 8,
similarity_threshold: float = 0.1
):
super().__init__()
if not (isinstance(n_top_show, int) and (n_top_show >= 0)):
raise DeepchecksValueError('n_top_show must be a positive integer')
self.n_top_show = n_top_show
if not (isinstance(hash_size, int) and (hash_size >= 0)):
raise DeepchecksValueError('hash_size must be a positive integer')
self.hash_size = hash_size
if not (isinstance(similarity_threshold, (float, int)) and (0 <= similarity_threshold <= 1)):
raise DeepchecksValueError('similarity_threshold must be a float in range (0,1)')
self.similarity_threshold = similarity_threshold
self.min_pixel_diff = int(np.ceil(similarity_threshold * (hash_size**2 / 2)))

def initialize_run(self, context: Context):
"""Initialize the run by initializing the lists of image hashes."""
self._hashed_train_images = []
self._hashed_test_images = []

def update(self, context: Context, batch: Batch, dataset_kind: DatasetKind):
"""Calculate image hashes for train and test."""
hashed_images = [average_hash(fromarray(img), hash_size=self.hash_size) for img in batch.images]

if dataset_kind == DatasetKind.TRAIN:
self._hashed_train_images += hashed_images
else:
self._hashed_test_images += hashed_images

def compute(self, context: Context) -> CheckResult:
"""Find similar images by comparing image hashes between train and test.
Returns
-------
CheckResult
value: list of tuples of similar image instances, in format (train_index, test_index). The index is by the
order of the images deepchecks received the images.
display: pairs of similar images
"""
train_hashes = np.array(self._hashed_train_images)

similar_indices = {
'train': [],
'test': []
}

for i, h in enumerate(self._hashed_test_images):
is_similar = (train_hashes - h) < self.min_pixel_diff
if any(is_similar):
for j in np.argwhere(is_similar): # Return indices where True
similar_indices['train'].append(j[0]) # append only the first similar image in train
similar_indices['test'].append(i)

display_indices = random.sample(range(len(similar_indices['test'])),
min(self.n_top_show, len(similar_indices['test'])))

display_images = {
'train': [],
'test': []
}

data_obj = {
'train': context.train,
'test': context.test
}

display = []
similar_pairs = []
if similar_indices['test']:
for similar_index in display_indices:
for dataset in ('train', 'test'):
image = data_obj[dataset].batch_to_images(
data_obj[dataset].batch_of_index(similar_indices[dataset][similar_index])
)[0]
image_thumbnail = prepare_thumbnail(
image=image,
size=self._THUMBNAIL_SIZE,
copy_image=False
)
display_images[dataset].append(image_thumbnail)

html = HTML_TEMPLATE.format(
count=len(similar_indices['test']),
n_of_images=len(display_indices),
train_images=''.join(display_images['train']),
test_images=''.join(display_images['test']),
)

display.append(html)

# return tuples of indices in original respective dataset objects
similar_pairs = list(zip(
context.train.to_dataset_index(*similar_indices['train']),
context.test.to_dataset_index(*similar_indices['test'])
))

return CheckResult(value=similar_pairs, display=display, header='Similar Image Leakage')

def add_condition_similar_images_not_more_than(self: SIL, threshold: int = 0) -> SIL:
"""Add new condition.
Add condition that will check the number of similar images is not greater than X.
The condition count how many unique images in test are similar to those in train.
Parameters
----------
threshold : int , default: 0
Number of allowed unique images in test that are similar to train
Returns
-------
SIL
"""

def condition(value: List[Tuple[int, int]]) -> ConditionResult:
num_similar_images = len(set(t[1] for t in value))

if num_similar_images > threshold:
message = f'Number of similar images between train and test datasets: {num_similar_images}'
return ConditionResult(ConditionCategory.FAIL, message)
else:
return ConditionResult(ConditionCategory.PASS)

return self.add_condition(f'Number of similar images between train and test is not greater than '
f'{threshold}', condition)


HTML_TEMPLATE = """
<h3><b>Similar Images</b></h3>
<div>
Total number of test samples with similar images in train: {count}
</div>
<h4>Samples</h4>
<div
style="
overflow-x: auto;
display: grid;
grid-template-rows: auto 1fr 1fr;
grid-template-columns: auto repeat({n_of_images}, 1fr);
grid-gap: 1.5rem;
justify-items: center;
align-items: center;
padding: 2rem;
width: max-content;">
<h5>Train</h5>{train_images}
<h5>Test</h5>{test_images}
</div>
"""
53 changes: 53 additions & 0 deletions deepchecks/vision/utils/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# ----------------------------------------------------------------------------
# Copyright (C) 2021-2022 Deepchecks (https://www.deepchecks.com)
#
# This file is part of Deepchecks.
# Deepchecks is distributed under the terms of the GNU Affero General
# Public License (version 3 or later).
# You should have received a copy of the GNU Affero General Public License
# along with Deepchecks. If not, see <http://www.gnu.org/licenses/>.
# ----------------------------------------------------------------------------
#
from deepchecks.vision import VisionData
from torch.utils.data import DataLoader
from typing import Callable


def get_modified_dataloader(vision_data: VisionData, func_to_apply: Callable, shuffle: bool = False) -> DataLoader:
"""Get a dataloader whose underlying dataset is modified by a function.
Parameters
----------
vision_data: VisionData
A vision data object of the type for which the modified dataloader is intended.
func_to_apply: Callable
A callable of the form func_to_apply(orig_dataset, idx) that returns a modified version of the original
dataset return value for the given index.
shuffle: bool, default: False
Whether return d dataloader with shuffling.
Returns
-------
DataLoader
The modified dataloader.
"""

class ModifiedDataset():
"""A modified dataset object, returning func_to_apply for each index."""

def __init__(self, orig_dataset):
self._orig_dataset = orig_dataset

def __getitem__(self, idx):
return func_to_apply(self._orig_dataset, idx)

def __len__(self):
return len(self._orig_dataset)

# Code needed to return a dataloader with the modified dataset that is otherwise identical to the original.
props = vision_data._get_data_loader_props(vision_data.data_loader) # pylint: disable=protected-access
props['dataset'] = ModifiedDataset(vision_data.data_loader.dataset)
props['shuffle'] = shuffle
data_loader = DataLoader(**props)
data_loader, _ = vision_data._get_data_loader_sequential(data_loader) # pylint: disable=protected-access
return data_loader

0 comments on commit 826fef9

Please sign in to comment.