Skip to content

Commit

Permalink
refactor: change content field name (#53)
Browse files Browse the repository at this point in the history
  • Loading branch information
hanxiao committed Jan 16, 2022
1 parent fca35a0 commit 5b2ea94
Show file tree
Hide file tree
Showing 62 changed files with 599 additions and 557 deletions.
12 changes: 6 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -130,10 +130,10 @@ Let's do some standard computer vision pre-processing:
from docarray import Document

def preproc(d: Document):
return (d.load_uri_to_image_blob() # load
.set_image_blob_shape((200, 200)) # resize all to 200x200
.set_image_blob_normalization() # normalize color
.set_image_blob_channel_axis(-1, 0)) # switch color axis for the PyTorch model later
return (d.load_uri_to_image_tensor() # load
.set_image_tensor_shape((200, 200)) # resize all to 200x200
.set_image_tensor_normalization() # normalize color
.set_image_tensor_channel_axis(-1, 0)) # switch color axis for the PyTorch model later

left_da.apply(preproc)
```
Expand Down Expand Up @@ -206,8 +206,8 @@ Better see it.

```python
(DocumentArray(left_da[8].matches, copy=True)
.apply(lambda d: d.set_image_blob_channel_axis(0, -1)
.set_image_blob_inv_normalization())
.apply(lambda d: d.set_image_tensor_channel_axis(0, -1)
.set_image_tensor_inv_normalization())
.plot_image_sprites())
```

Expand Down
4 changes: 2 additions & 2 deletions docarray/array/document.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,8 +270,8 @@ def __setitem__(
value = (value,)

for _a, _v in zip(_attrs, value):
if _a == 'blob':
_docs.blobs = _v
if _a == 'tensor':
_docs.tensors = _v
elif _a == 'embedding':
_docs.embeddings = _v
else:
Expand Down
58 changes: 29 additions & 29 deletions docarray/array/mixins/content.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,37 +46,37 @@ def embeddings(self, value: 'ArrayType'):
ravel(value, self, 'embedding')

@property
def blobs(self) -> Optional['ArrayType']:
"""Return a :class:`ArrayType` stacking all :attr:`.blob`.
def tensors(self) -> Optional['ArrayType']:
"""Return a :class:`ArrayType` stacking all :attr:`.tensor`.
The `blob` attributes are stacked together along a newly created first
The `tensor` attributes are stacked together along a newly created first
dimension (as if you would stack using ``np.stack(X, axis=0)``).
.. warning:: This operation assumes all blobs have the same shape and dtype.
.. warning:: This operation assumes all tensors have the same shape and dtype.
All dtype and shape values are assumed to be equal to the values of the
first element in the DocumentArray
:return: a :class:`ArrayType` of blobs
:return: a :class:`ArrayType` of tensors
"""
if self and self[0].content_type == 'blob':
if self and self[0].content_type == 'tensor':
if self:
return unravel(self, 'blob')
return unravel(self, 'tensor')

@blobs.setter
def blobs(self, value: 'ArrayType'):
"""Set :attr:`.blob` of the Documents. To clear all :attr:`blob`, set it to ``None``.
@tensors.setter
def tensors(self, value: 'ArrayType'):
"""Set :attr:`.tensor` of the Documents. To clear all :attr:`tensor`, set it to ``None``.
:param value: The blob array to set. The first axis is the "row" axis.
:param value: The tensor array to set. The first axis is the "row" axis.
"""

if value is None:
for d in self:
d.blob = None
d.tensor = None
else:
blobs_shape0 = _get_len(value)
self._check_length(blobs_shape0)
tensors_shape0 = _get_len(value)
self._check_length(tensors_shape0)

ravel(value, self, 'blob')
ravel(value, self, 'tensor')

@property
def texts(self) -> Optional[List[str]]:
Expand Down Expand Up @@ -105,37 +105,37 @@ def texts(self, value: Sequence[str]):
doc.text = text

@property
def buffers(self) -> Optional[List[bytes]]:
"""Get the buffer attribute of all Documents.
def blobs(self) -> Optional[List[bytes]]:
"""Get the blob attribute of all Documents.
:return: a list of buffers
:return: a list of blobs
"""
if self and self[0].content_type == 'buffer':
if self and self[0].content_type == 'blob':
if self:
return [d.buffer for d in self]
return [d.blob for d in self]

@buffers.setter
def buffers(self, value: List[bytes]):
"""Set the buffer attribute for all Documents. To clear all :attr:`buffer`, set it to ``None``.
@blobs.setter
def blobs(self, value: List[bytes]):
"""Set the blob attribute for all Documents. To clear all :attr:`blob`, set it to ``None``.
:param value: A sequence of buffer to set, should be the same length as the
:param value: A sequence of blobs to set, should be the same length as the
number of Documents
"""

if value is None:
for d in self:
d.buffer = None
d.blob = None
else:
self._check_length(len(value))

for doc, buffer in zip(self, value):
doc.buffer = buffer
for doc, blob in zip(self, value):
doc.blob = blob

@property
def contents(self) -> Optional[Union[Sequence['DocumentContentType'], 'ArrayType']]:
"""Get the :attr:`.content` of all Documents.
:return: a list of texts, buffers or :class:`ArrayType`
:return: a list of texts, blobs or :class:`ArrayType`
"""
if self:
content_type = self[0].content_type
Expand All @@ -148,7 +148,7 @@ def contents(
):
"""Set the :attr:`.content` of all Documents.
:param value: a list of texts, buffers or :class:`ArrayType`
:param value: a list of texts, blobs or :class:`ArrayType`
"""
if self:
content_type = self[0].content_type
Expand Down
8 changes: 4 additions & 4 deletions docarray/array/mixins/embed.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def _set_embeddings_keras(
device = tf.device('/GPU:0') if device == 'cuda' else tf.device('/CPU:0')
with device:
for b in self.batch(batch_size):
r = embed_model(b.blobs, training=False)
r = embed_model(b.tensors, training=False)
b.embeddings = r.numpy() if to_numpy else r

def _set_embeddings_torch(
Expand All @@ -60,7 +60,7 @@ def _set_embeddings_torch(
embed_model.eval()
with torch.inference_mode():
for b in self.batch(batch_size):
batch_inputs = torch.tensor(b.blobs, device=device)
batch_inputs = torch.tensor(b.tensors, device=device)
r = embed_model(batch_inputs).cpu().detach()
b.embeddings = r.numpy() if to_numpy else r
if is_training_before:
Expand All @@ -79,7 +79,7 @@ def _set_embeddings_paddle(
embed_model.to(device=device)
embed_model.eval()
for b in self.batch(batch_size):
batch_inputs = paddle.to_tensor(b.blobs, place=device)
batch_inputs = paddle.to_tensor(b.tensors, place=device)
r = embed_model(batch_inputs)
b.embeddings = r.numpy() if to_numpy else r
if is_training_before:
Expand All @@ -105,7 +105,7 @@ def _set_embeddings_onnx(

for b in self.batch(batch_size):
b.embeddings = embed_model.run(
None, {embed_model.get_inputs()[0].name: b.blobs}
None, {embed_model.get_inputs()[0].name: b.tensors}
)[0]


Expand Down
18 changes: 9 additions & 9 deletions docarray/array/mixins/getattr.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,14 @@ def _get_attributes(self, *fields: str) -> List:
fields = list(fields)
if 'embedding' in fields:
e_index = fields.index('embedding')
if 'blob' in fields:
b_index = fields.index('blob')
fields.remove('blob')
if 'tensor' in fields:
b_index = fields.index('tensor')
fields.remove('tensor')

if 'embedding' in fields:
fields.remove('embedding')
if 'blob' in fields:
fields.remove('blob')
if 'tensor' in fields:
fields.remove('tensor')

if fields:
contents = [doc._get_attributes(*fields) for doc in self]
Expand All @@ -34,18 +34,18 @@ def _get_attributes(self, *fields: str) -> List:
if len(fields) == 1:
contents = [contents]
if b_index is not None:
contents.insert(b_index, self.blobs)
contents.insert(b_index, self.tensors)
if e_index is not None:
contents.insert(e_index, self.embeddings)
return contents

if b_index is not None and e_index is None:
return self.blobs
return self.tensors
if b_index is None and e_index is not None:
return self.embeddings
if b_index is not None and e_index is not None:
return (
[self.embeddings, self.blobs]
[self.embeddings, self.tensors]
if b_index > e_index
else [self.blobs, self.embeddings]
else [self.tensors, self.embeddings]
)
14 changes: 7 additions & 7 deletions docarray/array/mixins/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def plot_embeddings(
make sure to give different names each time and set ``path`` to the same value.
:param host: if set, bind the embedding-projector frontend to given host. Otherwise `localhost` is used.
:param port: if set, run the embedding-projector frontend at given port. Otherwise a random port is used.
:param image_sprites: if set, visualize the dots using :attr:`.uri` and :attr:`.blob`.
:param image_sprites: if set, visualize the dots using :attr:`.uri` and :attr:`.tensor`.
:param path: if set, then append the visualization to an existing folder, where you can compare multiple
embeddings at the same time. Make sure to use a different ``title`` each time.
:param min_image_size: only used when `image_sprites=True`. the minimum size of the image
Expand Down Expand Up @@ -158,7 +158,7 @@ def plot_embeddings(

self.save_embeddings_csv(os.path.join(path, emb_fn), delimiter='\t')

_exclude_fields = ('embedding', 'blob', 'scores')
_exclude_fields = ('embedding', 'tensor', 'scores')
with_header = True
if len(set(self[0].non_empty_fields).difference(set(_exclude_fields))) <= 1:
with_header = False
Expand Down Expand Up @@ -288,7 +288,7 @@ def plot_image_sprites(
min_size: int = 16,
channel_axis: int = -1,
) -> None:
"""Generate a sprite image for all image blobs in this DocumentArray-like object.
"""Generate a sprite image for all image tensors in this DocumentArray-like object.
An image sprite is a collection of images put into a single image. It is always square-sized.
Each sub-image is also square-sized and equally-sized.
Expand Down Expand Up @@ -318,11 +318,11 @@ def plot_image_sprites(
img_id = 0
for d in self:
_d = copy.deepcopy(d)
if _d.content_type != 'blob':
_d.load_uri_to_image_blob()
if _d.content_type != 'tensor':
_d.load_uri_to_image_tensor()
channel_axis = -1

_d.set_image_blob_channel_axis(channel_axis, -1).set_image_blob_shape(
_d.set_image_tensor_channel_axis(channel_axis, -1).set_image_tensor_shape(
shape=(img_size, img_size)
)

Expand All @@ -331,7 +331,7 @@ def plot_image_sprites(
sprite_img[
(row_id * img_size) : ((row_id + 1) * img_size),
(col_id * img_size) : ((col_id + 1) * img_size),
] = _d.blob
] = _d.tensor

img_id += 1
if img_id >= max_num_img:
Expand Down
4 changes: 2 additions & 2 deletions docarray/document/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,8 @@ def __init__(
parent_id: Optional[str] = None,
granularity: Optional[int] = None,
adjacency: Optional[int] = None,
buffer: Optional[bytes] = None,
blob: Optional['ArrayType'] = None,
blob: Optional[bytes] = None,
tensor: Optional['ArrayType'] = None,
mime_type: Optional[str] = None,
text: Optional[str] = None,
content: Optional['DocumentContentType'] = None,
Expand Down
14 changes: 7 additions & 7 deletions docarray/document/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
granularity=0,
adjacency=0,
parent_id='',
buffer=b'',
blob=b'',
text='',
weight=0.0,
uri='',
Expand All @@ -39,8 +39,8 @@ class DocumentData:
parent_id: Optional[str] = None
granularity: Optional[int] = None
adjacency: Optional[int] = None
buffer: Optional[bytes] = None
blob: Optional['ArrayType'] = field(default=None, hash=False, compare=False)
blob: Optional[bytes] = None
tensor: Optional['ArrayType'] = field(default=None, hash=False, compare=False)
mime_type: Optional[str] = None # must be put in front of `text` `content`
text: Optional[str] = None
content: Optional['DocumentContentType'] = None
Expand All @@ -58,13 +58,13 @@ class DocumentData:

def __setattr__(self, key, value):
if value is not None:
if key == 'text' or key == 'blob' or key == 'buffer':
if key == 'text' or key == 'tensor' or key == 'blob':
# enable mutual exclusivity for content field
dv = default_values.get(key)
if type(value) != type(dv) or value != dv:
self.text = None
self.tensor = None
self.blob = None
self.buffer = None
if key == 'text':
self.mime_type = 'text/plain'
elif key == 'uri':
Expand All @@ -79,11 +79,11 @@ def __setattr__(self, key, value):
value = r or value
elif key == 'content':
if isinstance(value, bytes):
self.buffer = value
self.blob = value
elif isinstance(value, str):
self.text = value
else:
self.blob = value
self.tensor = value
value = None
elif key == 'chunks':
from ..array.chunk import ChunkArray
Expand Down
4 changes: 2 additions & 2 deletions docarray/document/mixins/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from .attribute import GetAttributesMixin
from .audio import AudioDataMixin
from .buffer import BufferDataMixin
from .blob import BlobDataMixin
from .content import ContentPropertyMixin
from .convert import ConvertMixin
from .dump import UriFileMixin
Expand Down Expand Up @@ -28,7 +28,7 @@ class AllMixins(
TextDataMixin,
MeshDataMixin,
VideoDataMixin,
BufferDataMixin,
BlobDataMixin,
PlotMixin,
UriFileMixin,
SingletonSugarMixin,
Expand Down
22 changes: 11 additions & 11 deletions docarray/document/mixins/_property.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,23 +47,23 @@ def adjacency(self, value: int):
self._data.adjacency = value

@property
def buffer(self) -> Optional[bytes]:
self._data._set_default_value_if_none('buffer')
return self._data.buffer

@buffer.setter
def buffer(self, value: bytes):
self._data.buffer = value

@property
def blob(self) -> Optional['ArrayType']:
def blob(self) -> Optional[bytes]:
self._data._set_default_value_if_none('blob')
return self._data.blob

@blob.setter
def blob(self, value: 'ArrayType'):
def blob(self, value: bytes):
self._data.blob = value

@property
def tensor(self) -> Optional['ArrayType']:
self._data._set_default_value_if_none('tensor')
return self._data.tensor

@tensor.setter
def tensor(self, value: 'ArrayType'):
self._data.tensor = value

@property
def mime_type(self) -> Optional[str]:
self._data._set_default_value_if_none('mime_type')
Expand Down
Loading

0 comments on commit 5b2ea94

Please sign in to comment.