Skip to content

Commit

Permalink
refactor: hide get_attributes to private (#18)
Browse files (browse the repository at this point in the history)
  • Loading branch information
hanxiao authored Jan 10, 2022
1 parent 9112c5a commit e62c1ad
Show file tree
Hide file tree
Showing 10 changed files with 28 additions and 28 deletions.
2 changes: 1 addition & 1 deletion docarray/array/document.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ def __getitem__(
if isinstance(_attrs, str):
_attrs = (index[1],)

return _docs.get_attributes(*_attrs)
return _docs._get_attributes(*_attrs)
elif isinstance(index[0], bool):
return DocumentArray(itertools.compress(self._data, index))
elif isinstance(index[0], int):
Expand Down
4 changes: 2 additions & 2 deletions docarray/array/mixins/getattr.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
class GetAttributeMixin:
"""Helpers that provide attributes getter in bulk """

def get_attributes(self, *fields: str) -> List:
def _get_attributes(self, *fields: str) -> List:
"""Return all nonempty values of the fields from all docs this array contains
:param fields: Variable length argument with the name of the fields to extract
Expand All @@ -25,7 +25,7 @@ def get_attributes(self, *fields: str) -> List:
fields.remove('blob')

if fields:
contents = [doc.get_attributes(*fields) for doc in self]
contents = [doc._get_attributes(*fields) for doc in self]
if len(fields) > 1:
contents = list(map(list, zip(*contents)))
if b_index is None and e_index is None:
Expand Down
4 changes: 2 additions & 2 deletions docarray/array/mixins/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def summary(self):
from rich.console import Console
from rich import box

all_attrs = self.get_attributes('non_empty_fields')
all_attrs = self._get_attributes('non_empty_fields')
attr_counter = Counter(all_attrs)

table = Table(box=box.SIMPLE, title='Documents Summary')
Expand Down Expand Up @@ -74,7 +74,7 @@ def summary(self):
attr_table.add_column('#Unique values')
attr_table.add_column('Has empty value')

all_attrs_values = self.get_attributes(*all_attrs_names)
all_attrs_values = self._get_attributes(*all_attrs_names)
if len(all_attrs_names) == 1:
all_attrs_values = [all_attrs_values]
for _a, _a_name in zip(all_attrs_values, all_attrs_names):
Expand Down
2 changes: 1 addition & 1 deletion docarray/document/mixins/attribute.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
class GetAttributesMixin:
"""Provide helper functions for :class:`Document` to allow advanced set and get attributes """

def get_attributes(self, *fields: str) -> Union[Any, List[Any]]:
def _get_attributes(self, *fields: str) -> Union[Any, List[Any]]:
"""Bulk fetch Document fields and return a list of the values of these fields
:param fields: the variable length values to extract from the document
Expand Down
6 changes: 3 additions & 3 deletions tests/unit/array/mixins/test_getset.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@ def test_set_embeddings_multi_kind(array):

@pytest.mark.parametrize('da', da_and_dam())
def test_da_get_embeddings(da):
np.testing.assert_almost_equal(da.get_attributes('embedding'), da.embeddings)
np.testing.assert_almost_equal(da._get_attributes('embedding'), da.embeddings)
np.testing.assert_almost_equal(da[:, 'embedding'], da.embeddings)


@pytest.mark.parametrize('da', da_and_dam())
Expand Down Expand Up @@ -65,7 +66,6 @@ def test_blobs_getter_da(da):
blobs = np.random.random((100, 10, 10))
da.blobs = blobs
assert len(da) == 100
np.testing.assert_almost_equal(da.get_attributes('blob'), da.blobs)
np.testing.assert_almost_equal(da.blobs, blobs)

da.blobs = None
Expand All @@ -77,7 +77,7 @@ def test_blobs_getter_da(da):
@pytest.mark.parametrize('da', da_and_dam())
def test_texts_getter_da(da):
assert len(da.texts) == 100
assert da.texts == da.get_attributes('text')
assert da.texts == da[:, 'text']
texts = ['text' for _ in range(100)]
da.texts = texts
assert da.texts == texts
Expand Down
6 changes: 3 additions & 3 deletions tests/unit/array/mixins/test_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,9 +82,9 @@ def test_batching(da, batch_size, shuffle):
all_ids = []
for v in da.batch(batch_size=batch_size, shuffle=shuffle):
assert len(v) <= batch_size
all_ids.extend(v.get_attributes('id'))
all_ids.extend(v[:, 'id'])

if shuffle:
assert all_ids != da.get_attributes('id')
assert all_ids != da[:, 'id']
else:
assert all_ids == da.get_attributes('id')
assert all_ids == da[:, 'id']
14 changes: 7 additions & 7 deletions tests/unit/array/mixins/test_match.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def test_matching_retrieves_correct_number(
D1.match(
D2, metric='sqeuclidean', limit=limit, batch_size=batch_size, only_id=only_id
)
for m in D1.get_attributes('matches'):
for m in D1[:, 'matches']:
if limit is None:
assert len(m) == len(D2)
else:
Expand All @@ -106,14 +106,14 @@ def test_matching_same_results_with_sparse(
# use match with numpy arrays
D1.match(D2, metric=metric, only_id=only_id)
distances = []
for m in D1.get_attributes('matches'):
for m in D1[:, 'matches']:
for d in m:
distances.extend([d.scores[metric].value])

# use match with sparse arrays
D1_sp.match(D2_sp, metric=metric, is_sparse=True)
distances_sparse = []
for m in D1.get_attributes('matches'):
for m in D1[:, 'matches']:
for d in m:
distances_sparse.extend([d.scores[metric].value])

Expand All @@ -132,15 +132,15 @@ def test_matching_same_results_with_batch(
# use match without batches
D1.match(D2, metric=metric, only_id=only_id)
distances = []
for m in D1.get_attributes('matches'):
for m in D1[:, 'matches']:
for d in m:
distances.extend([d.scores[metric].value])

# use match with batches
D1_batch.match(D2_batch, metric=metric, batch_size=10)

distances_batch = []
for m in D1.get_attributes('matches'):
for m in D1[:, 'matches']:
for d in m:
distances_batch.extend([d.scores[metric].value])

Expand All @@ -161,14 +161,14 @@ def scipy_cdist_metric(X, Y, *args):
# match with our custom metric
D1.match(D2, metric=metric)
distances = []
for m in D1.get_attributes('matches'):
for m in D1[:, 'matches']:
for d in m:
distances.extend([d.scores[metric].value])

# match with callable cdist function from scipy
D1_scipy.match(D2, metric=scipy_cdist_metric, only_id=only_id)
distances_scipy = []
for m in D1.get_attributes('matches'):
for m in D1[:, 'matches']:
for d in m:
distances_scipy.extend([d.scores[metric].value])

Expand Down
4 changes: 2 additions & 2 deletions tests/unit/array/mixins/test_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def test_plot_embeddings_same_path(tmpdir):

def test_summary_homo_hetero():
da = DocumentArray.empty(100)
da.get_attributes()
da._get_attributes()
da.summary()

da[0].pop('id')
Expand All @@ -69,4 +69,4 @@ def test_summary_homo_hetero():
def test_empty_get_attributes():
da = DocumentArray.empty(10)
da[0].pop('id')
print(da.get_attributes('id'))
print(da[:, 'id'])
6 changes: 3 additions & 3 deletions tests/unit/array/mixins/test_traverse.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def test_traverse_root_match_chunk(doc_req, filter_fn):
@pytest.mark.parametrize('filter_fn', [(lambda d: True), None])
def test_traverse_flatten_embedding(doc_req, filter_fn):
flattened_results = doc_req.traverse_flat('r,c', filter_fn=filter_fn)
ds = np.stack(flattened_results.get_attributes('embedding'))
ds = flattened_results.embeddings
assert ds.shape == (num_docs + num_chunks_per_doc * num_docs, 10)


Expand Down Expand Up @@ -137,10 +137,10 @@ def test_traverse_flatten_root_match_chunk(doc_req, filter_fn):
@pytest.mark.parametrize('filter_fn', [(lambda d: True), None])
def test_traverse_flattened_per_path_embedding(doc_req, filter_fn):
flattened_results = list(doc_req.traverse_flat_per_path('r,c', filter_fn=filter_fn))
ds = np.stack(flattened_results[0].get_attributes('embedding'))
ds = flattened_results[0].embeddings
assert ds.shape == (num_docs, 10)

ds = np.stack(flattened_results[1].get_attributes('embedding'))
ds = flattened_results[1].embeddings
assert ds.shape == (num_docs * num_chunks_per_doc, 10)


Expand Down
8 changes: 4 additions & 4 deletions tests/unit/document/test_docdata.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ def test_get_attr_values():
'tags__id',
'tags__e__2__f',
]
res = d.get_attributes(*required_keys)
res = d._get_attributes(*required_keys)
assert len(res) == len(required_keys)
assert res[required_keys.index('id')] == '123'
assert res[required_keys.index('tags__feature1')] == 121
Expand All @@ -176,18 +176,18 @@ def test_get_attr_values():
assert res[required_keys.index('tags__e__2__f')] == 'g'

required_keys_2 = ['tags', 'text']
res2 = d.get_attributes(*required_keys_2)
res2 = d._get_attributes(*required_keys_2)
assert len(res2) == 2
assert res2[required_keys_2.index('text')] == 'document'
assert res2[required_keys_2.index('tags')] == d.tags

d = Document({'id': '123', 'tags': {'outterkey': {'innerkey': 'real_value'}}})
required_keys_3 = ['tags__outterkey__innerkey']
res3 = d.get_attributes(*required_keys_3)
res3 = d._get_attributes(*required_keys_3)
assert res3 == 'real_value'

d = Document(content=np.array([1, 2, 3]))
res4 = np.stack(d.get_attributes(*['blob']))
res4 = np.stack(d._get_attributes(*['blob']))
np.testing.assert_equal(res4, np.array([1, 2, 3]))


Expand Down

0 comments on commit e62c1ad

Please sign in to comment.