Skip to content

Commit

Permalink
1141 bug cv unclear if the prediction was correct in classification d…
Browse files Browse the repository at this point in the history
…ata validation (#1214)

* Add support for prediction bbox notation

* Update extractors validation

* Remove kaliedo requirement
  • Loading branch information
matanper committed Apr 7, 2022
1 parent 6793e13 commit 2419033
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 84 deletions.
21 changes: 16 additions & 5 deletions deepchecks/vision/utils/detection_formatters.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,10 @@
from PIL.Image import Image


__all__ = ['verify_bbox_format_notation', 'convert_batch_of_bboxes', 'convert_bbox', ]
__all__ = ['verify_bbox_format_notation', 'convert_batch_of_bboxes', 'convert_bbox', 'DEFAULT_PREDICTION_FORMAT']


DEFAULT_PREDICTION_FORMAT = 'xywhsl'


def verify_bbox_format_notation(notation: str) -> Tuple[bool, List[str]]:
Expand All @@ -32,7 +35,7 @@ def verify_bbox_format_notation(notation: str) -> Tuple[bool, List[str]]:
-------
Tuple[
bool,
List[Literal['label', 'width', 'height', 'xmin', 'ymin', 'xmax', 'ymax', 'xcenter', 'ycenter']]
List[Literal['label', 'score', 'width', 'height', 'xmin', 'ymin', 'xmax', 'ymax', 'xcenter', 'ycenter']]
]
first item indicates whether coordinates are normalized or not,
second represents format of the bbox
Expand All @@ -47,6 +50,10 @@ def verify_bbox_format_notation(notation: str) -> Tuple[bool, List[str]]:
tokens.append('l')
current = current[1:]
current_pos = current_pos + 1
elif current.startswith('s'):
tokens.append('s')
current = current[1:]
current_pos = current_pos + 1
elif current.startswith('wh'):
tokens.append('wh')
current = current[2:]
Expand Down Expand Up @@ -75,11 +82,13 @@ def verify_bbox_format_notation(notation: str) -> Tuple[bool, List[str]]:
)

received_combination = Counter(tokens)
allowed_combinations = (
allowed_combinations = [
{'l': 1, 'xy': 2},
{'l': 1, 'xy': 1, 'wh': 1},
{'l': 1, 'cxcy': 1, 'wh': 1}
)
]
# All allowed combinations are also allowed with or without score to support both label and prediction
allowed_combinations += [{**c, 's': 1} for c in allowed_combinations]

if sum(c == received_combination for c in allowed_combinations) != 1:
raise ValueError(
Expand All @@ -102,6 +111,8 @@ def verify_bbox_format_notation(notation: str) -> Tuple[bool, List[str]]:
for t in tokens:
if t == 'l':
normalized_tokens.append('label')
elif t == 's':
normalized_tokens.append('score')
elif t == 'wh':
normalized_tokens.extend(('width', 'height'))
elif t == 'cxcy':
Expand Down Expand Up @@ -265,7 +276,7 @@ def _convert_bbox(
(image_width is not None and image_height is not None) \
or (image_width is None and image_height is None)

data = dict(zip(notation_tokens, bbox[:5]))
data = dict(zip(notation_tokens, bbox))

if 'xcenter' in data and 'ycenter' in data:
if image_width is not None and image_height is not None:
Expand Down
26 changes: 1 addition & 25 deletions deepchecks/vision/utils/image_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from .detection_formatters import convert_bbox


__all__ = ['ImageInfo', 'numpy_to_image_figure', 'label_bbox_add_to_figure', 'numpy_grayscale_to_heatmap_figure',
__all__ = ['ImageInfo', 'numpy_grayscale_to_heatmap_figure', 'ensure_image',
'apply_heatmap_image_properties', 'draw_bboxes', 'prepare_thumbnail', 'crop_image']


Expand All @@ -50,17 +50,6 @@ def is_equals(self, img_b) -> bool:
return np.array_equal(self.img, img_b)


def numpy_to_image_figure(data: np.ndarray):
"""Create image graph object from given numpy array data."""
dimension = data.shape[2]
if dimension == 1:
data = cv2.cvtColor(data, cv2.COLOR_GRAY2RGB)
elif dimension != 3:
raise DeepchecksValueError(f'Don\'t know to plot images with {dimension} dimensions')

return go.Image(z=data, hoverinfo='skip')


def ensure_image(
image: t.Union[pilimage.Image, np.ndarray, torch.Tensor],
copy: bool = True
Expand Down Expand Up @@ -215,19 +204,6 @@ def apply_heatmap_image_properties(fig):
fig.update_xaxes(constrain='domain')


def label_bbox_add_to_figure(labels: torch.Tensor, figure, row=None, col=None, color='red',
prediction=False):
"""Add a bounding box label and rectangle to given figure."""
for single in labels:
if prediction:
x, y, w, h, _, clazz = single.tolist()
else:
clazz, x, y, w, h = single.tolist()
figure.add_shape(type='rect', x0=x, y0=y, x1=x+w, y1=y+h, row=row, col=col, line=dict(color=color))
figure.add_annotation(x=x + w / 2, y=y, text=str(clazz), showarrow=False, yshift=10, row=row, col=col,
font=dict(color=color))


def crop_image(img: np.ndarray, x, y, w, h) -> np.ndarray:
"""Return the cropped numpy array image by x, y, w, h coordinates (top left corner, width and height."""
# Convert x, y, w, h to integers if not integers already:
Expand Down
81 changes: 28 additions & 53 deletions deepchecks/vision/utils/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,12 @@
from deepchecks.core.errors import ValidationError
from deepchecks.utils.ipython import is_headless, is_notebook
from deepchecks.utils.strings import create_new_file_name
from deepchecks.vision.utils.detection_formatters import DEFAULT_PREDICTION_FORMAT
from deepchecks.vision.batch_wrapper import apply_to_tensor
from deepchecks.vision.vision_data import TaskType
from deepchecks.vision.utils.image_functions import numpy_to_image_figure, label_bbox_add_to_figure
from deepchecks.vision.utils.image_functions import ensure_image, draw_bboxes, prepare_thumbnail
from deepchecks.vision.vision_data import VisionData

import plotly.graph_objects as go
from plotly.subplots import make_subplots
from PIL import Image
from io import BytesIO
from IPython.display import display, HTML


Expand Down Expand Up @@ -112,57 +109,33 @@ def validate_extractors(dataset: VisionData, model, device=None, image_save_loca
classes = None
# Plot
if image_formatter_error is None:
sample_image = images[0]
image = ensure_image(images[0], copy=False)
image_title = 'Visual example of an image.'
if dataset.task_type == TaskType.OBJECT_DETECTION:
# In case both label and prediction are valid show image side by side
if prediction_formatter_error is None and label_formatter_error is None:
fig = make_subplots(rows=1, cols=2)
fig.add_trace(numpy_to_image_figure(sample_image), row=1, col=1)
fig.add_trace(numpy_to_image_figure(sample_image), row=1, col=2)
label_bbox_add_to_figure(labels[0], fig, row=1, col=1)
label_bbox_add_to_figure(predictions[0], fig, prediction=True, color='orange', row=1, col=2)
fig.update_xaxes(title_text='Label', row=1, col=1)
fig.update_xaxes(title_text='Prediction', row=1, col=2)
fig.update_layout(title='Visual examples of an image with prediction and label data')
else:
fig = go.Figure(numpy_to_image_figure(sample_image))
# In here only label formatter or prediction formatter are valid (or none of them)
if label_formatter_error is None:
label_bbox_add_to_figure(labels[0], fig)
fig.update_xaxes(title='Label')
fig.update_layout(title='Visual example of an image with label data')
elif prediction_formatter_error is None:
label_bbox_add_to_figure(predictions[0], fig, prediction=True, color='orange')
fig.update_xaxes(title='Prediction')
fig.update_layout(title='Visual example of an image with prediction data')
if label_formatter_error is None:
image = draw_bboxes(image, labels[0], copy_image=False)
if prediction_formatter_error is None:
image = draw_bboxes(image, predictions[0], copy_image=False, color='blue',
bbox_notation=DEFAULT_PREDICTION_FORMAT)

elif dataset.task_type == TaskType.CLASSIFICATION:
fig = go.Figure(numpy_to_image_figure(sample_image))
# Create figure title
title = 'Visual example of an image'
if label_formatter_error is None and prediction_formatter_error is None:
title += ' with prediction and label data'
image_title = 'Visual examples of an image with prediction and label data. Label is red, ' \
'prediction is blue, and deepchecks loves you.'
elif label_formatter_error is None:
title += ' with label data'
image_title = 'Visual example of an image with label data. Could not display prediction.'
elif prediction_formatter_error is None:
title += ' with prediction data'
# Create x-axis title
x_title = []
image_title = 'Visual example of an image with prediction data. Could not display label.'
else:
image_title = 'Visual example of an image. Could not display label or prediction.'
elif dataset.task_type == TaskType.CLASSIFICATION:
if label_formatter_error is None:
x_title.append(f'Label: {labels[0]}')
image_title += f' Label class {labels[0]}'
if prediction_formatter_error is None:
x_title.append(f'Prediction: {predictions[0]}')

fig.update_layout(title=title)
fig.update_xaxes(title=', '.join(x_title))
else:
fig = go.Figure(numpy_to_image_figure(sample_image))
fig.update_layout(title='Visual example of an image')

fig.update_yaxes(showticklabels=False, visible=True, fixedrange=True, automargin=True)
fig.update_xaxes(showticklabels=False, visible=True, fixedrange=True, automargin=True)
pred_class = predictions[0].argmax()
image_title += f' Prediction class {pred_class}'
else:
fig = None
image = None
image_title = None

def get_header(x):
if is_notebook():
Expand All @@ -186,20 +159,21 @@ def get_header(x):
else:
msg += f'Unable to show due to invalid label formatter.{line_break}'

if fig:
if image:
if not is_notebook():
msg += 'Visual images & label & prediction: should open in a new window'
else:
msg += 'Visual images & label & prediction: Unable to show due to invalid image formatter.'

if is_notebook():
display(HTML(msg))
if fig:
display(HTML(fig.to_image('svg').decode('utf-8')))
if image:
image_html = '<div style="display:flex;flex-direction:column;align-items:baseline;">' \
f'{prepare_thumbnail(image, size=(200,200))}<p>{image_title}</p></div>'
display(HTML(image_html))
else:
print(msg)
if fig:
image = Image.open(BytesIO(fig.to_image('jpg')))
if image:
if is_headless():
if save_images:
if image_save_location is None:
Expand All @@ -213,6 +187,7 @@ def get_header(x):
print('This machine does not support GUI')
print('The formatted image was saved in:')
print(full_image_path)
print(image_title)
print('validate_extractors can be set to skip the image saving or change the save path')
print('*******************************************************************************')
else:
Expand Down
1 change: 0 additions & 1 deletion requirements/vision-requirements.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
pytorch-ignite>=0.4.8
opencv-python>=4.5.5.62
albumentations>=1.1.0
kaleido>=0.2.1
imgaug>=0.4.0
requests>=2.22.0
seaborn>=0.1.0
Expand Down

0 comments on commit 2419033

Please sign in to comment.