python/ClipDetection/clip_component/clip_component.py (110 changes: 64 additions & 46 deletions)
@@ -28,6 +28,7 @@
 import os
 import csv
 from pkg_resources import resource_filename
+from typing import Mapping, Iterable

 from PIL import Image
 import cv2
@@ -54,15 +55,18 @@ def __init__(self):
         self._wrapper = ClipWrapper()

     def get_detections_from_image_reader(self, image_job, image_reader):
+        num_detections = 0
         try:
             logger.info("received image job: %s", image_job)
             image = image_reader.get_image()
-            detections = self._wrapper.get_classifications(image, image_job.job_properties)
-            logger.info(f"Job complete. Found {len(detections)} detections.")
-            return detections
+            detections = self._wrapper.get_classifications((image,), image_job.job_properties)
+            for detection in detections:
+                yield detection
+                num_detections += 1
+            logger.info(f"Job complete. Found {num_detections} detection{'s' if num_detections != 1 else ''}.")

-        except Exception:
-            logger.exception(f"Failed to complete job {image_job.job_name} due to the following exception:")
+        except Exception as e:
+            logger.exception(f'Job failed due to: {e}')
             raise

 class ClipWrapper(object):
@@ -71,6 +75,7 @@ def __init__(self):
         model, _ = clip.load('ViT-B/32', device=device, download_root='/models')
         logger.info("Model loaded.")
         self._model = model
+        self._preprocessor = None

         self._classification_path = ''
         self._template_path = ''
@@ -83,68 +88,81 @@ def __init__(self):
         self._inferencing_server = None
         self._triton_server_url = None

-    def get_classifications(self, image, job_properties):
-        image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
+    def get_classifications(self, images, job_properties: Mapping[str, str]) -> Iterable[mpf.ImageLocation]:
         kwargs = self._parse_properties(job_properties)
-        image_width, image_height = image.size

         self._check_template_list(kwargs['template_path'], kwargs['num_templates'])
         self._check_class_list(kwargs['classification_path'], kwargs['classification_list'])

-        image = ImagePreprocessor(kwargs['enable_cropping']).preprocess(image).to(device)
+        self._preprocessor = ImagePreprocessor(kwargs['enable_cropping'])

-        if kwargs['enable_triton']:
-            if self._inferencing_server is None or kwargs['triton_server'] != self._triton_server_url:
-                self._inferencing_server = CLIPInferencingServer(kwargs['triton_server'])
-                self._triton_server_url = kwargs['triton_server']
-
-            results = self._inferencing_server.get_responses(image)
-            image_tensors= torch.Tensor(np.copy(results)).to(device=device)
-            image_features = torch.mean(image_tensors, 0)
-        else:
+        for image in images:
+            image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
+            image_width, image_height = image.size
+
+            image = self._preprocessor.preprocess(image).to(device)
+
+            if kwargs['enable_triton']:
+                if self._inferencing_server is None or kwargs['triton_server'] != self._triton_server_url:
+                    self._inferencing_server = CLIPInferencingServer(kwargs['triton_server'])
+                    self._triton_server_url = kwargs['triton_server']
+
+                results = self._inferencing_server.get_responses(image)
+                image_tensors= torch.Tensor(np.copy(results)).to(device=device)
+                image_features = torch.mean(image_tensors, 0)
+            else:
+                with torch.no_grad():
+                    image_features = self._model.encode_image(image).float()
+                    image_features = torch.mean(image_features, 0).unsqueeze(0)

             with torch.no_grad():
-                image_features = self._model.encode_image(image).float()
-                image_features = torch.mean(image_features, 0).unsqueeze(0)
-
-        with torch.no_grad():
-            image_features /= image_features.norm(dim=-1, keepdim=True)
-
-            similarity = (100.0 * image_features @ self._text_features).softmax(dim=-1).to(device)
-            similarity = torch.mean(similarity, 0)
-            values, indices = similarity.topk(kwargs['num_classifications'])
-
-            classification_list = '; '.join([self._class_mapping[list(self._class_mapping.keys())[int(index)]] for index in indices])
-            classification_confidence_list = '; '.join([str(value.item()) for value in values])
-
-            detection_properties = {
-                "CLASSIFICATION": classification_list.split('; ')[0],
-                "CLASSIFICATION CONFIDENCE LIST": classification_confidence_list,
-                "CLASSIFICATION LIST": classification_list
-            }
-
-            if kwargs['include_features']:
-                detection_properties['FEATURE'] = base64.b64encode(image_features.cpu().numpy()).decode()
+                image_features /= image_features.norm(dim=-1, keepdim=True)
+
+                similarity = (100.0 * image_features @ self._text_features).softmax(dim=-1).to(device)
+                similarity = torch.mean(similarity, 0)
+                values, indices = similarity.topk(len(self._class_mapping))
+
+                classification_list = []
+                classification_confidence_list = []
+                count = 0
+                for value, index in zip(values, indices):
+                    if count >= kwargs['num_classifications']:
+                        break
+                    class_name = self._class_mapping[list(self._class_mapping.keys())[int(index)]]
+                    if class_name not in classification_list:
+                        classification_list.append(class_name)
+                        classification_confidence_list.append(str(value.item()))
+                        count += 1
+
+                classification_list = '; '.join(classification_list)
+                classification_confidence_list = '; '.join(classification_confidence_list)
+
+                detection_properties = {
+                    "CLASSIFICATION": classification_list.split('; ')[0],
+                    "CLASSIFICATION CONFIDENCE LIST": classification_confidence_list,
+                    "CLASSIFICATION LIST": classification_list
+                }
+
+                if kwargs['include_features']:
+                    detection_properties['FEATURE'] = base64.b64encode(image_features.cpu().numpy()).decode()

-        return [
-            mpf.ImageLocation(
+            yield mpf.ImageLocation(
                 x_left_upper = 0,
                 y_left_upper = 0,
                 width = image_width,
                 height = image_height,
                 confidence = float(classification_confidence_list.split('; ')[0]),
                 detection_properties = detection_properties
             )
-        ]

     def _parse_properties(self, job_properties):
         classification_list = self._get_prop(job_properties, "CLASSIFICATION_LIST", 'coco', ['coco', 'imagenet'])
-        classification_path = self._get_prop(job_properties, "CLASSIFICATION_PATH", '')
+        classification_path = os.path.expandvars(self._get_prop(job_properties, "CLASSIFICATION_PATH", ''))
         enable_cropping = self._get_prop(job_properties, "ENABLE_CROPPING", True)
         enable_triton = self._get_prop(job_properties, "ENABLE_TRITON", False)
         include_features = self._get_prop(job_properties, "INCLUDE_FEATURES", False)
         num_classifications = self._get_prop(job_properties, "NUMBER_OF_CLASSIFICATIONS", 1)
         num_templates = self._get_prop(job_properties, "NUMBER_OF_TEMPLATES", 80, [1, 7, 80])
-        template_path = self._get_prop(job_properties, "TEMPLATE_PATH", '')
+        template_path = os.path.expandvars(self._get_prop(job_properties, "TEMPLATE_PATH", ''))
         triton_server = self._get_prop(job_properties, "TRITON_SERVER", 'clip-detection-server:8001')

         return dict(
@@ -401,4 +419,4 @@ def _get_crops(imgs):
         return crops


-EXPORT_MPF_COMPONENT = ClipComponent
\ No newline at end of file
+EXPORT_MPF_COMPONENT = ClipComponent
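
A note on the selection logic above: similarity.topk() now ranks every class in the mapping, and duplicate rolled-up labels are skipped, so NUMBER_OF_CLASSIFICATIONS counts unique labels rather than raw top-k entries. A minimal standalone sketch of that step, assuming a ranked list of (confidence, class) pairs; select_rollup_classes is a hypothetical helper, not part of this PR:

def select_rollup_classes(ranked, class_mapping, num_classifications):
    # ranked: (confidence, class_key) pairs, highest confidence first.
    # class_mapping: class_key -> rolled-up label, e.g. 'dog' -> 'indoor animal'.
    labels, confidences = [], []
    for confidence, class_key in ranked:
        if len(labels) >= num_classifications:
            break
        label = class_mapping[class_key]
        if label not in labels:  # duplicate labels roll up into one entry
            labels.append(label)
            confidences.append(confidence)
    return labels, confidences

# 'dog' and 'cat' share a label, so only the higher-scoring one is kept:
mapping = {'dog': 'indoor animal', 'cat': 'indoor animal', 'sedan': 'vehicle'}
print(select_rollup_classes([(0.61, 'dog'), (0.25, 'cat'), (0.14, 'sedan')], mapping, 2))
# (['indoor animal', 'vehicle'], [0.61, 0.14])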
python/ClipDetection/tests/data/rollup.csv (7 changes: 7 additions & 0 deletions)
@@ -0,0 +1,7 @@
+dog,indoor animal
+cat,indoor animal
+lion,wild animal
+sedan,vehicle
+truck,vehicle
+guitar,musical instrument
+house,building
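
Each row of rollup.csv pairs a fine-grained CLIP class with the rolled-up label the component should report. A minimal sketch of reading such a two-column file into a class-to-label mapping; the component's actual loader may differ:

import csv

def load_class_mapping(path):
    # Map class name -> rolled-up label; a one-column row maps to itself.
    with open(path, newline='') as f:
        return {row[0]: row[1] if len(row) > 1 else row[0]
                for row in csv.reader(f) if row}

mapping = load_class_mapping('rollup.csv')
print(mapping['dog'])  # 'indoor animal'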
python/ClipDetection/tests/test_clip.py (18 changes: 17 additions & 1 deletion)
@@ -63,7 +63,7 @@ def test_image_file(self):

     def test_image_file_custom(self):
         job = mpf.ImageJob(
-            job_name='test-image',
+            job_name='test-image-custom',
             data_uri=self._get_test_file('riot.jpg'),
             job_properties=dict(
                 NUMBER_OF_CLASSIFICATIONS = 4,
@@ -77,6 +77,22 @@ def test_image_file_custom(self):
         self.assertEqual(job.job_properties["NUMBER_OF_CLASSIFICATIONS"], len(self._output_to_list(result.detection_properties["CLASSIFICATION LIST"])))
         self.assertTrue("violent scene" in self._output_to_list(result.detection_properties["CLASSIFICATION LIST"]))
         self.assertEqual("violent scene", result.detection_properties["CLASSIFICATION"])
+
+    def test_image_file_rollup(self):
+        job = mpf.ImageJob(
+            job_name='test-image-rollup',
+            data_uri=self._get_test_file('dog.jpg'),
+            job_properties=dict(
+                NUMBER_OF_CLASSIFICATIONS = 4,
+                NUMBER_OF_TEMPLATES = 1,
+                CLASSIFICATION_PATH = self._get_test_file("rollup.csv"),
+                ENABLE_CROPPING='False'
+            ),
+            media_properties={},
+            feed_forward_location=None
+        )
+        result = list(ClipComponent().get_detections_from_image(job))[0]
+        self.assertEqual("indoor animal", result.detection_properties["CLASSIFICATION"])

     @staticmethod
     def _get_test_file(filename):
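
Because get_detections_from_image_reader is now a generator, nothing runs until the detections are iterated; callers that need a concrete result must exhaust it, as test_image_file_rollup does with list(). A usage sketch along the lines of the tests (paths and the module import are illustrative):

import mpf_component_api as mpf
from clip_component import ClipComponent

job = mpf.ImageJob(
    job_name='rollup-demo',
    data_uri='tests/data/dog.jpg',
    job_properties=dict(
        CLASSIFICATION_PATH='tests/data/rollup.csv',
        NUMBER_OF_TEMPLATES=1),
    media_properties={},
    feed_forward_location=None)

# Force the lazy generator; each ImageLocation covers the whole frame.
detections = list(ClipComponent().get_detections_from_image(job))
print(detections[0].detection_properties['CLASSIFICATION'])  # e.g. 'indoor animal'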