Skip to content

Commit

Permalink
better format validation for detection data class (#1428)
Browse files Browse the repository at this point in the history
better format validation for detection data class
Improved tutorial for vision data validation
  • Loading branch information
Nadav-Barak committed May 17, 2022
1 parent d80f851 commit fca4e27
Show file tree
Hide file tree
Showing 5 changed files with 99 additions and 72 deletions.
3 changes: 2 additions & 1 deletion .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ jobs:
with:
requirements: 'requirements-all.txt'
fail: 'Copyleft,Other,Error'
exclude: '(pyzmq.*22\.3\.0|debugpy.*1\.6\.0|certifi.*2021\.10\.8|tqdm.*4\.64\.0|webencodings.*0\.5\.1|torch.*1\.10\.2.*|torchaudio.*0\.10\.2.*|torchvision.*0\.11\.3.*)'
exclude: '(pyzmq.*22\.3\.0|debugpy.*1\.6\.0|certifi.*2021\.10\.8|tqdm.*4\.64\.0|webencodings.*0\.5\.1|torch.*1\.10\.2.*|torchaudio.*0\.10\.2.*|torchvision.*0\.11\.3.*|terminado.*0\.15\.0)'
# pyzmq is Revised BSD https://github.com/zeromq/pyzmq/blob/main/examples/LICENSE
# debugpy is MIT https://github.com/microsoft/debugpy/blob/main/LICENSE
# certifi is MPL-2.0 https://github.com/certifi/python-certifi/blob/master/LICENSE
Expand All @@ -102,6 +102,7 @@ jobs:
# torch is BSD https://github.com/pytorch/pytorch/blob/master/LICENSE
# torchvision is BSD https://github.com/pytorch/vision/blob/main/LICENSE
# torchaudio is BSD https://github.com/pytorch/audio/blob/main/LICENSE
# terminado is BSD https://github.com/jupyter/terminado/blob/main/LICENSE

- name: Print report
if: ${{ always() }}
Expand Down
33 changes: 25 additions & 8 deletions deepchecks/vision/detection_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import torch

from deepchecks.core.errors import (DeepchecksNotImplementedError,
DeepchecksValueError, ValidationError)
ValidationError)
from deepchecks.vision.vision_data import TaskType, VisionData

logger = logging.getLogger('deepchecks')
Expand Down Expand Up @@ -122,6 +122,7 @@ def infer_on_batch(self, batch, model, device) -> List[torch.Tensor]:

def get_classes(self, batch_labels: List[torch.Tensor]):
"""Get a labels batch and return classes inside it."""

def get_classes_from_single_label(tensor: torch.Tensor):
return list(tensor[:, 0].type(torch.IntTensor).tolist()) if len(tensor) > 0 else []

Expand All @@ -144,17 +145,23 @@ def validate_label(self, batch):
"""
labels = self.batch_to_labels(batch)
if not isinstance(labels, list):
raise DeepchecksValueError('Check requires object detection label to be a list with an entry for each '
'sample')
raise ValidationError('Check requires object detection label to be a list with an entry for each '
'sample')
if len(labels) == 0:
raise DeepchecksValueError('Check requires object detection label to be a non-empty list')
raise ValidationError('Check requires object detection label to be a non-empty list')
if not isinstance(labels[0], torch.Tensor):
raise DeepchecksValueError('Check requires object detection label to be a list of torch.Tensor')
raise ValidationError('Check requires object detection label to be a list of torch.Tensor')
if len(labels[0].shape) != 2:
raise DeepchecksValueError('Check requires object detection label to be a list of 2D tensors')
raise ValidationError('Check requires object detection label to be a list of 2D tensors')
if labels[0].shape[1] != 5:
raise DeepchecksValueError('Check requires object detection label to be a list of 2D tensors, when '
'each row has 5 columns: [class_id, x, y, width, height]')
raise ValidationError('Check requires object detection label to be a list of 2D tensors, when '
'each row has 5 columns: [class_id, x, y, width, height]')
if torch.min(labels[0]) < 0:
raise ValidationError('Found one of coordinates to be negative, check requires object detection '
'bounding box coordinates to be of format [class_id, x, y, width, height].')
if torch.max(labels[0][:, 0] % 1) > 0:
raise ValidationError('Class_id must be a positive integer. Object detection labels per image should '
'be a Bx5 tensor of format [class_id, x, y, width, height].')

def validate_prediction(self, batch, model, device):
"""
Expand Down Expand Up @@ -187,3 +194,13 @@ def validate_prediction(self, batch, model, device):
if batch_predictions[0].shape[1] != 6:
raise ValidationError('Check requires detection predictions to be a list of 2D tensors, when '
'each row has 6 columns: [x, y, width, height, class_probability, class_id]')
if torch.min(batch_predictions[0]) < 0:
raise ValidationError('Found one of coordinates to be negative, Check requires object detection '
'bounding box predictions to be of format [x, y, width, height, confidence,'
' class_id]. ')
if torch.min(batch_predictions[0][:, 4]) < 0 or torch.max(batch_predictions[0][:, 4]) > 1:
raise ValidationError('Confidence must be between 0 and 1. Object detection predictions per image '
'should be a Bx6 tensor of format [x, y, width, height, confidence, class_id].')
if torch.max(batch_predictions[0][:, 5] % 1) > 0:
raise ValidationError('Class_id must be a positive integer. Object detection predictions per image '
'should be a Bx6 tensor of format [x, y, width, height, confidence, class_id].')
4 changes: 2 additions & 2 deletions deepchecks/vision/utils/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,15 +82,15 @@ def validate_extractors(dataset: VisionData, model, device=None, image_save_loca
dataset.validate_label(batch)
labels = dataset.batch_to_labels(batch)
except ValidationError as ex:
label_formatter_error = str(ex)
label_formatter_error = 'Fail! ' + str(ex)
except Exception: # pylint: disable=broad-except
label_formatter_error = 'Got exception \n' + traceback.format_exc()

try:
dataset.validate_image_data(batch)
images = dataset.batch_to_images(batch)
except ValidationError as ex:
image_formatter_error = str(ex)
image_formatter_error = 'Fail! ' + str(ex)
except Exception: # pylint: disable=broad-except
image_formatter_error = 'Got exception \n' + traceback.format_exc()

Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-
"""
Test Your Vision Data Class During Development
Test Your Deepchecks Vision Data Class
================================================
"Data Classes" are used to transform the structure of your data to the
Expand All @@ -18,7 +18,7 @@
* `Understand validation results <#understand-validation-results>`__
* `The end result <#the-end-result>`__
"""
#%%
# %%
# Load data and model
# -------------------
# In the first step we load the DataLoader and our model
Expand All @@ -28,7 +28,7 @@
data_loader = load_dataset(train=False, batch_size=1000, object_type='DataLoader')
model = load_model()

#%%
# %%
# Create simple DetectionData object
# ----------------------------------
# In the second step since this is an object detection task we will override a
Expand All @@ -38,19 +38,21 @@
# are not passing, and then we will implement a correct functions.

from deepchecks.vision.detection_data import DetectionData
import torch


class CocoDetectionData(DetectionData):
def batch_to_images(self, batch):
return batch[0]

def batch_to_labels(self, batch):
return batch[1]
return [torch.round(x) for x in batch[1]]

def infer_on_batch(self, batch, model, device):
return model.to(device)(batch[0])

#%%

# %%
# Running the extractors validation
# ---------------------------------
# Now we will load our validate function and see the results while running
Expand All @@ -62,33 +64,29 @@ def infer_on_batch(self, batch, model, device):

validate_extractors(CocoDetectionData(data_loader), model)

#%%

# %%
# Understand validation results
# -----------------------------
# When looking at the result first thing we see is that it's separated into 2 parts.
# When looking at the result we can see is that it is separated into 2 parts.
#
# First one is about the structure we expect to get. This validation is automatic
# First part is about the structure we expect to get. This validation is automatic
# since it's purely technical and doesn't check content correctness. For example,
# in our validation above we see that the label extractor is passing, meaning the
# labels are returning in the expected format. Second part is about the content,
# which can't be automatically validated and requires your attention. This part
# labels are provided in the expected format.
#
# Second part is about the content, which cannot be automatically validated and requires your attention. This part
# includes looking visually at data outputted by the formatters to validate it is
# right. In the validation above we see a list of classes that doesn't seem to make
# sense. This is because although our labels are in the right structure, the content
# inside is not valid.
# correct. In the validation above we see a list of classes that doesn't seem to make much
# sense - it contains class_ids in values ranging from 0 to 596 while in the COCO dataset there are only 80 classes.
#
# We know that the classes in our data are represented by class id which is an int,
# therefore we understand the labels does not contain the data in the right order.
# For the next step we'll fix the label extractor and then validate again:

import torch


class CocoDetectionData(DetectionData):
def batch_to_labels(self, batch):
# Translate labels to deepchecks format.
# the label_id here is in the last position of the tensor, and the DetectionLabelFormatter expects it
# at the first position.
# Originally the label_id was at the last position of the tensor while Deepchecks expects it
# to be at the first position.
formatted_labels = []
for tensor in batch[1]:
tensor = torch.index_select(tensor, 1, torch.LongTensor([4, 0, 1, 2, 3])) if len(tensor) > 0 else tensor
Expand All @@ -101,41 +99,38 @@ def batch_to_images(self, batch):
def infer_on_batch(self, batch, model, device):
return model.to(device)(batch[0])


validate_extractors(CocoDetectionData(data_loader), model)

#%%

# %%
# Now we can see in the content section that our classes are indeed as we expect
# them to be, class ids of type int. Now we can continue and fix the prediction extractor
# them to be, values between 0 and 79. Now we can continue and fix the prediction extractor

class CocoDetectionData(DetectionData):
def infer_on_batch(self, batch, model, device):
# Convert from yolo Detections object to List (per image) of Tensors of the shape [N, 6]"""
# Convert from yolo Detections object to List (per image) of Tensors of the shape [B, 6]"""
return_list = []
predictions = model.to(device)(batch[0])

# yolo Detections objects have List[torch.Tensor] xyxy output in .pred
for single_image_tensor in predictions.pred:
return_list.append(single_image_tensor)

return return_list

# using the same label extractor
def batch_to_labels(self, batch):
# Translate labels to deepchecks format.
# the label_id here is in the last position of the tensor, and the DetectionLabelFormatter expects it
# at the first position.
formatted_labels = []
for tensor in batch[1]:
tensor = torch.index_select(tensor, 1, torch.LongTensor([4, 0, 1, 2, 3])) if len(tensor) > 0 else tensor
formatted_labels.append(tensor)
return formatted_labels

def batch_to_images(self, batch):
return batch[0]


validate_extractors(CocoDetectionData(data_loader), model)

#%%
# %%
# Now our prediction formatter also have valid structure. But in order to really
# validate it we also need visual assertion and for that we need the image extractor to work.

Expand All @@ -152,18 +147,12 @@ def infer_on_batch(self, batch, model, device):
# Convert from yolo Detections object to List (per image) of Tensors of the shape [N, 6]"""
return_list = []
predictions = model.to(device)(batch[0])

# yolo Detections objects have List[torch.Tensor] xyxy output in .pred
for single_image_tensor in predictions.pred:
return_list.append(single_image_tensor)

return return_list

# using the same label extractor
def batch_to_labels(self, batch):
# Translate labels to deepchecks format.
# the label_id here is in the last position of the tensor, and the DetectionLabelFormatter expects it
# at the first position.
formatted_labels = []
for tensor in batch[1]:
tensor = torch.index_select(tensor, 1, torch.LongTensor([4, 0, 1, 2, 3])) if len(tensor) > 0 else tensor
Expand All @@ -173,13 +162,13 @@ def batch_to_labels(self, batch):

validate_extractors(CocoDetectionData(data_loader), model)

#%%

# %%
# Now that that image extractor is valid it displays for us visually the label and prediction.
# When we look at the label we see it is correct, but when we look at the prediction something
# is broken.
# When we look at the label we see it is correct, but when we look at the bounding box predictions something
# seems broken.
#
# We need to fix the prediction so the prediction will be returned in
# [x, y, w, h, confidence, class] format.
# We need to fix the prediction so the prediction will be returned in [x, y, w, h, confidence, class] format.

class CocoDetectionData(DetectionData):
def infer_on_batch(self, batch, model, device):
Expand All @@ -199,9 +188,6 @@ def infer_on_batch(self, batch, model, device):

# using the same label extractor
def batch_to_labels(self, batch):
# Translate labels to deepchecks format.
# the label_id here is in the last position of the tensor, and the DetectionLabelFormatter expects it
# at the first position.
formatted_labels = []
for tensor in batch[1]:
tensor = torch.index_select(tensor, 1, torch.LongTensor([4, 0, 1, 2, 3])) if len(tensor) > 0 else tensor
Expand All @@ -210,10 +196,10 @@ def batch_to_labels(self, batch):

# using the same image extractor
def batch_to_images(self, batch):
# Yolo works on PIL and ImageFormatter expects images as numpy arrays
return [np.array(x) for x in batch[0]]

#%%

# %%
# The end result
# --------------
validate_extractors(CocoDetectionData(data_loader), model)

0 comments on commit fca4e27

Please sign in to comment.