Skip to content

Commit

Permalink
feat: add split prepare for trainer (#3167)
Browse files Browse the repository at this point in the history
  • Loading branch information
bwanglzu committed Aug 16, 2021
1 parent 25bb903 commit 7a700e5
Show file tree
Hide file tree
Showing 4 changed files with 97 additions and 1 deletion.
25 changes: 25 additions & 0 deletions .github/2.0/cookbooks/Document.md
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ Table of Contents
- [Filter a subset of `DocumentArray` using `.find`](#filter-a-subset-of-documentarray-using-find)
- [Sample a subset of `DocumentArray` using `sample`](#sample-a-subset-of-documentarray-using-sample)
- [Shuffle a `DocumentArray` using `shuffle`](#shuffle-a-documentarray-using-shuffle)
- [Split a `DocumentArray` by tag using `split`](#split-a-documentarray-by-tag-using-split)
- [Visualize the embeddings of a `DocumentArray`](#visualize-the-embeddings-of-a-documentarray)
- [`DocumentArrayMemmap` API](#documentarraymemmap-api)
- [Create `DocumentArrayMemmap`](#create-documentarraymemmap)
Expand Down Expand Up @@ -1255,6 +1256,29 @@ shuffled_da = da.shuffle() # shuffle the DocumentArray
shuffled_da_with_seed = da.shuffle(seed=1) # shuffle the DocumentArray with seed.
```

### Split a `DocumentArray` by tag using `split`

`DocumentArray` provides a `.split` function that splits the `DocumentArray` into multiple `DocumentArray`s according to the value of a given tag (stored in `tags`) of each `Document`.
It returns a Python `dict` in which `Document`s with the same value for `tag` are grouped together; their order is preserved from the original `DocumentArray`.

To make use of the function:

```python
from jina import Document, DocumentArray

da = DocumentArray()
da.append(Document(tags={'category': 'c'}))
da.append(Document(tags={'category': 'c'}))
da.append(Document(tags={'category': 'b'}))
da.append(Document(tags={'category': 'a'}))
da.append(Document(tags={'category': 'a'}))

rv = da.split(tag='category')
assert len(rv['c']) == 2 # category `c` maps to a DocumentArray with 2 Documents
```



### Visualize the embeddings of a `DocumentArray`

`DocumentArray` provides function `.visualize` to plot document embeddings in a 2D graph. `visualize` supports 2 methods
Expand Down Expand Up @@ -1537,6 +1561,7 @@ This table summarizes the interfaces of `DocumentArrayMemmap` and `DocumentArray
| `__eq__` |||
| `sample` |||
| `shuffle` |||
| `split` |||
| `match` (L/Rvalue) |||
| `visualize` |||

Expand Down
24 changes: 23 additions & 1 deletion jina/types/arrays/search_ops.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import re
import random
import operator
from typing import Dict, Optional, Union, Tuple
from collections import defaultdict
from typing import Dict, Optional, Union, Tuple, Any


if False:
Expand Down Expand Up @@ -112,3 +113,24 @@ def shuffle(self, seed: Optional[int] = None) -> 'DocumentArray':
from .document import DocumentArray

return DocumentArray(self.sample(len(self), seed=seed))

def split(self, tag: str) -> Dict[Any, 'DocumentArray']:
    """Group the Documents of this array by the value stored under ``tag`` in their tags.

    :param tag: the key to look up in the :attr:`tags` of each :class:`Document`.
    :return: a dict mapping each distinct tag value to a :class:`DocumentArray` holding the
        Documents that carry this value. Keys appear in order of first occurrence and the
        Documents keep their relative order from the original :class:`DocumentArray`.

    .. note::
        Documents whose :attr:`tags` do not contain :attr:`tag` (or hold a null value for it)
        are skipped; if no Document contains it, an empty dict is returned. Falsy values
        such as ``0``, ``False`` or ``''`` are valid group keys and are kept.
    """
    from .document import DocumentArray

    rv = defaultdict(DocumentArray)
    for doc in self:
        value = doc.tags.get(tag)
        # Skip only when the tag is absent. A plain truthiness test would also
        # discard legitimate values like the label ``0`` or an empty string.
        if value is None:
            continue
        rv[value].append(doc)
    return dict(rv)
26 changes: 26 additions & 0 deletions tests/unit/types/arrays/test_documentarray.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import random
from copy import deepcopy

import pytest
Expand Down Expand Up @@ -59,6 +60,17 @@ def docarray_for_cache():
return da


@pytest.fixture
def docarray_for_split():
    """Provide a ``DocumentArray`` of five Documents tagged ``c, c, b, a, a`` under ``category``."""
    da = DocumentArray()
    for category in ('c', 'c', 'b', 'a', 'a'):
        da.append(Document(tags={'category': category}))
    return da


def test_length(docarray, docs):
assert len(docs) == len(docarray) == 3

Expand Down Expand Up @@ -439,3 +451,17 @@ def test_shuffle_with_seed():
assert len(shuffled_1) == len(shuffled_2) == len(shuffled_3) == len(da)
assert shuffled_1 == shuffled_2
assert shuffled_1 != shuffled_3


def test_split(docarray_for_split):
    """Check grouping by tag: result type, key order, group sizes and the unknown-tag case."""
    groups = docarray_for_split.split('category')
    assert isinstance(groups, dict)
    assert sorted(groups) == ['a', 'b', 'c']
    # keys follow first appearance in the input, which is c, c, b, a, a
    assert list(groups) == ['c', 'b', 'a']
    for category, expected_size in (('c', 2), ('b', 1), ('a', 2)):
        assert len(groups[category]) == expected_size
    # a tag that no Document carries yields an empty dict
    assert not docarray_for_split.split('random')
23 changes: 23 additions & 0 deletions tests/unit/types/arrays/test_memmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,17 @@ def memmap_with_text_and_embedding(tmpdir):
dam.clear()


@pytest.fixture
def memmap_for_split(tmpdir):
    """Provide a ``DocumentArrayMemmap`` in ``tmpdir`` with Documents tagged ``c, c, b, a, a``."""
    da = DocumentArrayMemmap(tmpdir)
    for category in ('c', 'c', 'b', 'a', 'a'):
        da.append(Document(tags={'category': category}))
    return da


def test_memmap_append_extend(tmpdir):
dam = DocumentArrayMemmap(tmpdir)
docs = list(random_docs(100))
Expand Down Expand Up @@ -444,3 +455,15 @@ def test_memmap_mutate(tmpdir):

da.clear()
assert not len(da)


def test_split(memmap_for_split):
    """Check ``split`` on a ``DocumentArrayMemmap``: type, key order, group sizes, unknown tag."""
    rv = memmap_for_split.split('category')
    assert isinstance(rv, dict)
    assert sorted(rv) == ['a', 'b', 'c']
    # assure order is preserved c, b, a
    assert list(rv) == ['c', 'b', 'a']
    # original input c, c, b, a, a
    assert len(rv['c']) == 2
    assert len(rv['b']) == 1
    assert len(rv['a']) == 2
    # same missing-tag check as the DocumentArray test
    rv = memmap_for_split.split('random')
    assert not rv  # wrong tag returns empty dict

0 comments on commit 7a700e5

Please sign in to comment.