diff --git a/python/ClipDetection/clip_component/clip_component.py b/python/ClipDetection/clip_component/clip_component.py
index 9a6df34f..85d5b669 100644
--- a/python/ClipDetection/clip_component/clip_component.py
+++ b/python/ClipDetection/clip_component/clip_component.py
@@ -28,6 +28,7 @@
 import os
 import csv
 from pkg_resources import resource_filename
+from typing import Mapping, Iterable
 
 from PIL import Image
 import cv2
@@ -54,15 +55,18 @@ def __init__(self):
         self._wrapper = ClipWrapper()
 
     def get_detections_from_image_reader(self, image_job, image_reader):
+        num_detections = 0
         try:
             logger.info("received image job: %s", image_job)
 
             image = image_reader.get_image()
 
-            detections = self._wrapper.get_classifications(image, image_job.job_properties)
-            logger.info(f"Job complete. Found {len(detections)} detections.")
-            return detections
+            detections = self._wrapper.get_classifications((image,), image_job.job_properties)
+            for detection in detections:
+                yield detection
+                num_detections += 1
+            logger.info(f"Job complete. Found {num_detections} detection{'s' if num_detections > 1 else ''}.")
 
-        except Exception:
-            logger.exception(f"Failed to complete job {image_job.job_name} due to the following exception:")
+        except Exception as e:
+            logger.exception(f'Job failed due to: {e}')
             raise
 
 class ClipWrapper(object):
@@ -71,6 +75,7 @@ def __init__(self):
         model, _ = clip.load('ViT-B/32', device=device, download_root='/models')
         logger.info("Model loaded.")
         self._model = model
+        self._preprocessor = None
 
         self._classification_path = ''
         self._template_path = ''
@@ -83,50 +88,64 @@ def __init__(self):
         self._inferencing_server = None
         self._triton_server_url = None
 
-    def get_classifications(self, image, job_properties):
-        image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
+    def get_classifications(self, images, job_properties: Mapping[str, str]) -> Iterable[mpf.ImageLocation]:
         kwargs = self._parse_properties(job_properties)
-        image_width, image_height = image.size
-
         self._check_template_list(kwargs['template_path'], kwargs['num_templates'])
         self._check_class_list(kwargs['classification_path'], kwargs['classification_list'])
-        image = ImagePreprocessor(kwargs['enable_cropping']).preprocess(image).to(device)
+        self._preprocessor = ImagePreprocessor(kwargs['enable_cropping'])
 
-        if kwargs['enable_triton']:
-            if self._inferencing_server is None or kwargs['triton_server'] != self._triton_server_url:
-                self._inferencing_server = CLIPInferencingServer(kwargs['triton_server'])
-                self._triton_server_url = kwargs['triton_server']
-
-            results = self._inferencing_server.get_responses(image)
-            image_tensors= torch.Tensor(np.copy(results)).to(device=device)
-            image_features = torch.mean(image_tensors, 0)
-        else:
+        for image in images:
+            image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
+            image_width, image_height = image.size
+
+            image = self._preprocessor.preprocess(image).to(device)
+
+            if kwargs['enable_triton']:
+                if self._inferencing_server is None or kwargs['triton_server'] != self._triton_server_url:
+                    self._inferencing_server = CLIPInferencingServer(kwargs['triton_server'])
+                    self._triton_server_url = kwargs['triton_server']
+
+                results = self._inferencing_server.get_responses(image)
+                image_tensors= torch.Tensor(np.copy(results)).to(device=device)
+                image_features = torch.mean(image_tensors, 0)
+            else:
+                with torch.no_grad():
+                    image_features = self._model.encode_image(image).float()
+                    image_features = torch.mean(image_features, 0).unsqueeze(0)
+
+            with torch.no_grad():
-            image_features = self._model.encode_image(image).float()
-            image_features = torch.mean(image_features, 0).unsqueeze(0)
-
-        with torch.no_grad():
-            image_features /= image_features.norm(dim=-1, keepdim=True)
-
-            similarity = (100.0 * image_features @ self._text_features).softmax(dim=-1).to(device)
-            similarity = torch.mean(similarity, 0)
-            values, indices = similarity.topk(kwargs['num_classifications'])
-
-            classification_list = '; '.join([self._class_mapping[list(self._class_mapping.keys())[int(index)]] for index in indices])
-            classification_confidence_list = '; '.join([str(value.item()) for value in values])
-
-            detection_properties = {
-                "CLASSIFICATION": classification_list.split('; ')[0],
-                "CLASSIFICATION CONFIDENCE LIST": classification_confidence_list,
-                "CLASSIFICATION LIST": classification_list
-            }
-
-            if kwargs['include_features']:
-                detection_properties['FEATURE'] = base64.b64encode(image_features.cpu().numpy()).decode()
+                image_features /= image_features.norm(dim=-1, keepdim=True)
+
+                similarity = (100.0 * image_features @ self._text_features).softmax(dim=-1).to(device)
+                similarity = torch.mean(similarity, 0)
+                values, indices = similarity.topk(len(self._class_mapping))
+
+                classification_list = []
+                classification_confidence_list = []
+                count = 0
+                for value, index in zip(values, indices):
+                    if count >= kwargs['num_classifications']:
+                        break
+                    class_name = self._class_mapping[list(self._class_mapping.keys())[int(index)]]
+                    if class_name not in classification_list:
+                        classification_list.append(class_name)
+                        classification_confidence_list.append(str(value.item()))
+                        count += 1
+
+                classification_list = '; '.join(classification_list)
+                classification_confidence_list = '; '.join(classification_confidence_list)
+
+                detection_properties = {
+                    "CLASSIFICATION": classification_list.split('; ')[0],
+                    "CLASSIFICATION CONFIDENCE LIST": classification_confidence_list,
+                    "CLASSIFICATION LIST": classification_list
+                }
+
+                if kwargs['include_features']:
+                    detection_properties['FEATURE'] = base64.b64encode(image_features.cpu().numpy()).decode()
 
-        return [
-            mpf.ImageLocation(
+            yield mpf.ImageLocation(
                 x_left_upper = 0,
                 y_left_upper = 0,
                 width = image_width,
@@ -134,17 +153,16 @@ def get_classifications(self, image, job_properties):
                 confidence = float(classification_confidence_list.split('; ')[0]),
                 detection_properties = detection_properties
             )
-        ]
 
     def _parse_properties(self, job_properties):
         classification_list = self._get_prop(job_properties, "CLASSIFICATION_LIST", 'coco', ['coco', 'imagenet'])
-        classification_path = self._get_prop(job_properties, "CLASSIFICATION_PATH", '')
+        classification_path = os.path.expandvars(self._get_prop(job_properties, "CLASSIFICATION_PATH", ''))
         enable_cropping = self._get_prop(job_properties, "ENABLE_CROPPING", True)
         enable_triton = self._get_prop(job_properties, "ENABLE_TRITON", False)
         include_features = self._get_prop(job_properties, "INCLUDE_FEATURES", False)
         num_classifications = self._get_prop(job_properties, "NUMBER_OF_CLASSIFICATIONS", 1)
         num_templates = self._get_prop(job_properties, "NUMBER_OF_TEMPLATES", 80, [1, 7, 80])
-        template_path = self._get_prop(job_properties, "TEMPLATE_PATH", '')
+        template_path = os.path.expandvars(self._get_prop(job_properties, "TEMPLATE_PATH", ''))
         triton_server = self._get_prop(job_properties, "TRITON_SERVER", 'clip-detection-server:8001')
 
         return dict(
@@ -401,4 +419,4 @@ def _get_crops(imgs):
 
         return crops
 
-EXPORT_MPF_COMPONENT = ClipComponent
\ No newline at end of file
+EXPORT_MPF_COMPONENT = ClipComponent
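For context, a minimal caller sketch (not part of the patch) of how the reworked `get_classifications` is consumed: it now takes an iterable of OpenCV (BGR) images plus a job-properties mapping and yields one `mpf.ImageLocation` per input image rather than returning a list. The import path, image file, and empty property mapping below are illustrative assumptions.

```python
# Hypothetical usage sketch, not part of the patch. Assumes the component
# package is importable, the CLIP model is available, and 'dog.jpg' exists
# on disk; job properties are left at their defaults.
import cv2
from clip_component.clip_component import ClipWrapper

wrapper = ClipWrapper()
frame = cv2.imread('dog.jpg')  # OpenCV loads images as BGR, as the component expects
for location in wrapper.get_classifications((frame,), {}):
    print(location.detection_properties['CLASSIFICATION'], location.confidence)
```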
diff --git a/python/ClipDetection/tests/data/rollup.csv b/python/ClipDetection/tests/data/rollup.csv
new file mode 100644
index 00000000..d251e69c
--- /dev/null
+++ b/python/ClipDetection/tests/data/rollup.csv
@@ -0,0 +1,7 @@
+dog,indoor animal
+cat,indoor animal
+lion,wild animal
+sedan,vehicle
+truck,vehicle
+guitar,musical instrument
+house,building
\ No newline at end of file
diff --git a/python/ClipDetection/tests/test_clip.py b/python/ClipDetection/tests/test_clip.py
index f29a6bc8..50171345 100644
--- a/python/ClipDetection/tests/test_clip.py
+++ b/python/ClipDetection/tests/test_clip.py
@@ -63,7 +63,7 @@ def test_image_file(self):
 
     def test_image_file_custom(self):
         job = mpf.ImageJob(
-            job_name='test-image',
+            job_name='test-image-custom',
             data_uri=self._get_test_file('riot.jpg'),
             job_properties=dict(
                 NUMBER_OF_CLASSIFICATIONS = 4,
@@ -77,6 +77,22 @@ def test_image_file_custom(self):
         self.assertEqual(job.job_properties["NUMBER_OF_CLASSIFICATIONS"], len(self._output_to_list(result.detection_properties["CLASSIFICATION LIST"])))
         self.assertTrue("violent scene" in self._output_to_list(result.detection_properties["CLASSIFICATION LIST"]))
         self.assertEqual("violent scene", result.detection_properties["CLASSIFICATION"])
+
+    def test_image_file_rollup(self):
+        job = mpf.ImageJob(
+            job_name='test-image-rollup',
+            data_uri=self._get_test_file('dog.jpg'),
+            job_properties=dict(
+                NUMBER_OF_CLASSIFICATIONS = 4,
+                NUMBER_OF_TEMPLATES = 1,
+                CLASSIFICATION_PATH = self._get_test_file("rollup.csv"),
+                ENABLE_CROPPING='False'
+            ),
+            media_properties={},
+            feed_forward_location=None
+        )
+        result = list(ClipComponent().get_detections_from_image(job))[0]
+        self.assertEqual("indoor animal", result.detection_properties["CLASSIFICATION"])
 
     @staticmethod
     def _get_test_file(filename):
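To illustrate the roll-up behavior the new test exercises: assuming `_check_class_list` loads the two-column CSV into `self._class_mapping` as an ordered mapping of fine-grained class to roll-up label (that loading code is outside this diff), the de-duplicating top-k loop keeps only the first occurrence of each roll-up label, so "dog" and "cat" collapse into a single "indoor animal" entry. A standalone sketch with made-up scores:

```python
# Standalone sketch of the de-duplicating top-k loop added in this patch,
# using the rollup.csv categories. The similarity scores are invented for
# illustration only.
class_mapping = {            # fine-grained class -> roll-up label (from rollup.csv)
    'dog': 'indoor animal',
    'cat': 'indoor animal',
    'lion': 'wild animal',
    'sedan': 'vehicle',
}
scores = {'dog': 0.61, 'cat': 0.22, 'lion': 0.11, 'sedan': 0.06}
num_classifications = 2

classification_list = []
classification_confidence_list = []
for class_name, score in sorted(scores.items(), key=lambda kv: -kv[1]):
    if len(classification_list) >= num_classifications:
        break
    rollup = class_mapping[class_name]
    if rollup not in classification_list:   # skip roll-up labels already emitted
        classification_list.append(rollup)
        classification_confidence_list.append(str(score))

print('; '.join(classification_list))              # indoor animal; wild animal
print('; '.join(classification_confidence_list))   # 0.61; 0.11
```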