Skip to content

Commit

Permalink
fix: proper usage of importextensions (#2237)
Browse files Browse the repository at this point in the history
  • Loading branch information
cristianmtr committed Mar 26, 2021
1 parent d7e069f commit f2def8a
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 21 deletions.
14 changes: 11 additions & 3 deletions jina/drivers/debug.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import numpy as np

from jina.drivers import FlatRecursiveMixin, BaseRecursiveDriver
from jina.importer import ImportExtensions

if False:
# noinspection PyUnreachableCode
Expand Down Expand Up @@ -43,10 +44,17 @@ def _move_channel_axis(
return np.moveaxis(img, channel_axis_to_move, target_channel_axis)

def _load_image(blob: 'np.ndarray', channel_axis: int):
from PIL import Image
with ImportExtensions(
required=True,
pkg_name='Pillow',
verbose=True,
logger=self.logger,
help_text='PIL is missing. Install it with `pip install Pillow`',
):
from PIL import Image

img = _move_channel_axis(blob, channel_axis)
return Image.fromarray(img.astype('uint8'))
img = _move_channel_axis(blob, channel_axis)
return Image.fromarray(img.astype('uint8'))

for d in docs:
if self.done < self.top:
Expand Down
34 changes: 21 additions & 13 deletions tests/integration/debug/components.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from jina.executors.crafters import BaseCrafter
from jina.executors.decorators import single
from jina.importer import ImportExtensions
from .helper import _crop_image, _move_channel_axis, _load_image


Expand Down Expand Up @@ -35,19 +36,26 @@ def craft(self, buffer: bytes, uri: str, *args, **kwargs) -> Dict:
:param uri: the image file path
"""
from PIL import Image

if buffer:
raw_img = Image.open(io.BytesIO(buffer))
elif uri:
raw_img = Image.open(uri)
else:
raise ValueError('no value found in "buffer" and "uri"')
raw_img = raw_img.convert('RGB')
img = np.array(raw_img).astype('float32')
if self.channel_axis != -1:
img = np.moveaxis(img, -1, self.channel_axis)
return dict(blob=img)
with ImportExtensions(
required=True,
verbose=True,
pkg_name='Pillow',
logger=self.logger,
help_text='PIL is missing. Install it with `pip install Pillow`',
):
from PIL import Image

if buffer:
raw_img = Image.open(io.BytesIO(buffer))
elif uri:
raw_img = Image.open(uri)
else:
raise ValueError('no value found in "buffer" and "uri"')
raw_img = raw_img.convert('RGB')
img = np.array(raw_img).astype('float32')
if self.channel_axis != -1:
img = np.moveaxis(img, -1, self.channel_axis)
return dict(blob=img)


class CenterImageCropper(BaseCrafter):
Expand Down
17 changes: 12 additions & 5 deletions tests/integration/debug/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

import numpy as np

from jina.importer import ImportExtensions


def _move_channel_axis(
img: 'np.ndarray', channel_axis_to_move: int, target_channel_axis: int = -1
Expand All @@ -21,11 +23,16 @@ def _load_image(blob: 'np.ndarray', channel_axis: int):
"""
Load an image array and return a `PIL.Image` object.
"""

from PIL import Image

img = _move_channel_axis(blob, channel_axis)
return Image.fromarray(img.astype('uint8'))
with ImportExtensions(
required=True,
verbose=True,
pkg_name='Pillow',
help_text='PIL is missing. Install it with `pip install Pillow`',
):
from PIL import Image

img = _move_channel_axis(blob, channel_axis)
return Image.fromarray(img.astype('uint8'))


def _crop_image(
Expand Down

0 comments on commit f2def8a

Please sign in to comment.