Skip to content

Commit

Permalink
test: test _extract content from docset
Browse files Browse the repository at this point in the history
  • Loading branch information
JoanFM committed Mar 3, 2021
1 parent e581f89 commit 939408b
Showing 1 changed file with 56 additions and 0 deletions.
56 changes: 56 additions & 0 deletions tests/unit/types/sets/test_documentset.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from copy import deepcopy

import pytest
import numpy as np

from jina import Document
from jina.types.sets import DocumentSet
Expand Down Expand Up @@ -61,6 +62,7 @@ def test_union(docset, document_factory):
for idx in range(0, 6):
assert union[idx + 3].id == additional_docset[idx].id


def test_union_inplace(docset, document_factory):
additional_docset = DocumentSet([])
for idx in range(4, 10):
Expand Down Expand Up @@ -220,3 +222,57 @@ def callback_fn(docs, *args, **kwargs) -> None:
add_chunk(doc)
add_match(doc)
add_match(doc)


@pytest.mark.parametrize('num_rows', [1, 2, 3])
@pytest.mark.parametrize('field', ['content', 'blob', 'embedding'])
def test_get_content(num_rows, field):
    """Check the stacked shape returned by ``_extract_docs`` for array fields.

    Ten documents are built carrying the same random
    ``(num_rows, embed_size)`` array under ``field``; one additional
    ``Document()`` with no fields set is appended afterwards. The extracted
    contents must have shape ``(batch_size, num_rows, embed_size)``, i.e. the
    first dimension counts only the ten populated documents.
    """
    batch_size = 10
    embed_size = 20

    # A single array is generated once and shared by every populated document.
    payload = np.random.random((num_rows, embed_size))
    populated = [Document(**{field: payload}) for _ in range(batch_size)]

    docs = DocumentSet(populated)
    # Extra document without the field set; the expected shape below does not
    # include a row for it.
    docs.append(Document())

    contents, _pointers = docs._extract_docs(field)

    expected_shape = (batch_size, num_rows, embed_size)
    assert contents.shape == expected_shape


@pytest.mark.parametrize('field', ['id', 'text'])
def test_get_content_text_fields(field):
    """Check ``_extract_docs`` on string-valued fields.

    Every document in the set carries the string ``'text'`` under ``field``.
    The extracted contents must be one-dimensional with one entry per
    document, and each entry must compare equal to ``'text'``.
    """
    batch_size = 10
    expected = 'text'

    docs = DocumentSet(
        [Document(**{field: expected}) for _ in range(batch_size)]
    )

    contents, _pointers = docs._extract_docs(field)

    assert contents.shape == (batch_size,)
    assert len(contents) == batch_size
    assert all(content == expected for content in contents)


@pytest.mark.parametrize('field', ['content', 'buffer'])
def test_get_content_bytes_fields(field):
    """Check ``_extract_docs`` on bytes-valued fields.

    Every document in the set carries ``b'bytes'`` under ``field``. The
    extracted contents must be one-dimensional with one entry per document,
    and each entry must compare equal to ``b'bytes'``.
    """
    batch_size = 10
    expected = b'bytes'

    docs = DocumentSet(
        [Document(**{field: expected}) for _ in range(batch_size)]
    )

    contents, _pointers = docs._extract_docs(field)

    assert contents.shape == (batch_size,)
    assert len(contents) == batch_size
    assert all(content == expected for content in contents)

0 comments on commit 939408b

Please sign in to comment.