Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(types): add to_blob converter to document #1929

Merged
merged 2 commits into from
Feb 11, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
11 changes: 11 additions & 0 deletions jina/drivers/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ class ConvertDriver(BaseRecursiveDriver):
.. note::
The list of functions that can be applied can be found in `:class:`Document`
"""

def __init__(self, convert_fn: str, *args, **kwargs):
"""
:param convert_fn: the method name from `:class:`Document` to be applied
Expand Down Expand Up @@ -45,6 +46,16 @@ def __init__(self, convert_fn: str = 'convert_buffer_to_uri', *args, **kwargs):
super().__init__(convert_fn, *args, **kwargs)


class BufferImage2Blob(ConvertDriver):
def __init__(self, convert_fn: str = 'convert_buffer_image_to_blob', *args, **kwargs):
super().__init__(convert_fn, *args, **kwargs)


class URI2Blob(ConvertDriver):
def __init__(self, convert_fn: str = 'convert_uri_to_blob', *args, **kwargs):
super().__init__(convert_fn, *args, **kwargs)


class Text2URI(ConvertDriver):
def __init__(self, convert_fn: str = 'convert_text_to_uri', *args, **kwargs):
super().__init__(convert_fn, *args, **kwargs)
Expand Down
2 changes: 1 addition & 1 deletion jina/drivers/segment.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def _apply_all(self, docs: 'DocumentSet', *args, **kwargs):
_args_dict = doc.get_attrs(*self.exec.required_keys)
ret = self.exec_fn(**_args_dict)
if ret:
SegmentDriver._update(doc, ret)
self._update(doc, ret)

@staticmethod
def _update(doc, ret):
Expand Down
19 changes: 18 additions & 1 deletion jina/types/document/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import base64
import io
import json
import mimetypes
import os
Expand All @@ -12,7 +13,7 @@
from google.protobuf import json_format
from google.protobuf.field_mask_pb2 import FieldMask

from .converters import png_to_buffer, to_datauri, guess_mime
from .converters import png_to_buffer, to_datauri, guess_mime, to_image_blob
from ..mixin import ProtoTypeMixin
from ..ndarray.generic import NdArray
from ..score import NamedScore
Expand Down Expand Up @@ -612,11 +613,27 @@ def convert_buffer_to_blob(self, **kwargs):
"""
self.blob = np.frombuffer(self.buffer)

def convert_buffer_image_to_blob(self, color_axis: int = -1, **kwargs):
""" Convert an image buffer to blob

:param color_axis: the axis id of the color channel, ``-1`` indicates the color channel info at the last axis
:param kwargs: reserved for maximum compatibility when using with ConvertDriver
"""
self.blob = to_image_blob(io.BytesIO(self.buffer), color_axis)

def convert_blob_to_uri(self, width: int, height: int, resize_method: str = 'BILINEAR', **kwargs):
"""Assuming :attr:`blob` is a _valid_ image, set :attr:`uri` accordingly"""
png_bytes = png_to_buffer(self.blob, width, height, resize_method)
self.uri = 'data:image/png;base64,' + base64.b64encode(png_bytes).decode()

def convert_uri_to_blob(self, color_axis: int = -1, **kwargs):
""" Convert uri to blob

:param color_axis: the axis id of the color channel, ``-1`` indicates the color channel info at the last axis
:param kwargs: reserved for maximum compatibility when using with ConvertDriver
"""
self.blob = to_image_blob(self.uri, color_axis)

def convert_uri_to_buffer(self, **kwargs):
"""Convert uri to buffer
Internally it downloads from the URI and set :attr:`buffer`.
Expand Down
9 changes: 9 additions & 0 deletions jina/types/document/converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,15 @@ def png_to_buffer(arr: 'np.ndarray', width: int, height: int, resize_method: str
return png_bytes


def to_image_blob(source, color_axis: int = -1) -> 'np.ndarray':
from PIL import Image
raw_img = Image.open(source).convert('RGB')
img = np.array(raw_img).astype('float32')
if color_axis != -1:
img = np.moveaxis(img, -1, color_axis)
return img


def to_datauri(mimetype, data, charset: str = 'utf-8', base64: bool = False, binary: bool = True):
parts = ['data:', mimetype]
if charset is not None:
Expand Down
Binary file added tests/unit/types/document/test.png
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
29 changes: 24 additions & 5 deletions tests/unit/types/document/test_converters.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,31 @@
import os

import numpy as np
import pytest

from jina import Document

cur_dir = os.path.dirname(os.path.abspath(__file__))


def test_uri_to_blob():
doc = Document(uri=os.path.join(cur_dir, 'test.png'))
doc.convert_uri_to_blob()
assert isinstance(doc.blob, np.ndarray)
assert doc.blob.shape == (85, 152, 3) # h,w,c


def test_buffer_to_blob():
doc = Document(uri=os.path.join(cur_dir, 'test.png'))
doc.convert_uri_to_buffer()
doc.convert_buffer_image_to_blob()
assert isinstance(doc.blob, np.ndarray)
assert doc.blob.shape == (85, 152, 3) # h,w,c


def test_convert_buffer_to_blob():
rand_state = np.random.RandomState(0)
array = rand_state.random([10,10])
array = rand_state.random([10, 10])
doc = Document(content=array.tobytes())
assert doc.content_type == 'buffer'
intialiazed_buffer = doc.buffer
Expand Down Expand Up @@ -62,7 +81,7 @@ def test_convert_text_to_uri(converter):
def test_convert_uri_to_text(uri, mimetype):
doc = Document(uri=uri, mime_type=mimetype)
doc.convert_uri_to_text()
if mimetype == 'text/html':
if mimetype == 'text/html':
assert '<!doctype html>' in doc.text
elif mimetype == 'text/x-python':
text_from_file = open(__file__).read()
Expand All @@ -89,11 +108,11 @@ def test_convert_content_to_uri():
('https://google.com/index.html', 'text/html')])
def test_convert_uri_to_data_uri(uri, mimetype):
doc = Document(uri=uri, mime_type=mimetype)
intialiazed_buffer = doc.buffer
intialiazed_buffer = doc.buffer
intialiazed_uri = doc.uri
doc.convert_uri_to_data_uri()
converted_buffer = doc.buffer
converted_uri = doc.uri
converted_buffer = doc.buffer
converted_uri = doc.uri
print(doc.content_type)
assert doc.uri.startswith(f'data:{mimetype}')
assert intialiazed_uri != converted_uri
Expand Down