Skip to content

Commit

Permalink
test: refactor driver test (#1452)
Browse files Browse the repository at this point in the history
* test: refactor driver

* test: refactor driver test

* test: refactor driver test

* test: refactor driver test
  • Loading branch information
bwanglzu committed Dec 17, 2020
1 parent 79fd844 commit 2d0c489
Show file tree
Hide file tree
Showing 6 changed files with 55 additions and 52 deletions.
9 changes: 1 addition & 8 deletions jina/drivers/querylang/queryset/helper.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,8 @@
from functools import partial
from typing import Callable, List, Type, Union, Iterable

from ....excepts import LookupyError

## Exceptions

class LookupyError(Exception):
"""Base exception class for all exceptions raised by lookupy"""
pass


## utility functions

def iff(precond: Callable, val: Union[int, str], f: Callable) -> bool:
"""If and only if the precond is True
Expand Down
4 changes: 4 additions & 0 deletions jina/excepts.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,5 +183,9 @@ class BadFlowYAMLVersion(Exception):
""" Exception when Flow YAML config specifies a wrong version number"""


class LookupyError(Exception):
"""Base exception class for all exceptions raised by lookupy"""


class EventLoopError(Exception):
""" Exception when a running event loop is found but not under jupyter or ipython """
15 changes: 11 additions & 4 deletions tests/unit/drivers/querylang/queryset/test_dunderkeys.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,20 @@
import pytest

from jina.drivers.querylang.queryset.dunderkey import (
dunderkey,
dunder_init,
dunder_get,
dunder_partition,
undunder_keys,
dunder_truncate,
)
from jina.proto.jina_pb2 import DocumentProto
from jina import Document

def test_dunderkey():
assert dunderkey('a', 'b', 'c') == 'a__b__c'

def test_dunder_init():
assert dunder_init('a__b__c') == 'a__b'

def test_dunder_get():
assert dunder_get({'a': {'b': 5}}, 'a__b') == 5
Expand All @@ -22,9 +29,9 @@ class A:

assert dunder_get(A, 'b__c') == 5

d = DocumentProto()
d.tags['a'] = 'hello'
assert dunder_get(d, 'tags__a') == 'hello'
with Document() as d:
d.tags['a'] = 'hello'
assert dunder_get(d, 'tags__a') == 'hello'

# Error on invalid key

Expand Down
33 changes: 32 additions & 1 deletion tests/unit/drivers/querylang/queryset/test_lookup.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
import pytest

from jina.types.document import Document
from jina.drivers.querylang.queryset.lookup import LookupLeaf
from jina.drivers.querylang.queryset.lookup import LookupLeaf, Q, QuerySet

from tests import random_docs


class MockId:
Expand All @@ -16,6 +20,9 @@ class MockIter:
def __init__(self, iterable):
self.iter = iterable

@pytest.fixture(scope='function')
def docs():
return random_docs(num_docs=10)

def test_lookup_leaf_exact():
leaf = LookupLeaf(id__exact=1)
Expand Down Expand Up @@ -173,3 +180,27 @@ def test_lookup_leaf_None():
assert leaf.evaluate(mock0)
mock1 = MockId(4)
assert not leaf.evaluate(mock1)

def test_docs_filter(docs):
filtered_docs = QuerySet(docs).filter(tags__id__lt=5, tags__id__gt=3)
filtered_docs = list(filtered_docs)
assert len(filtered_docs) == 1
for d in filtered_docs:
assert (3 < d.tags['id'] < 5)


def test_docs_filter_equal(docs):
filtered_docs = QuerySet(docs).filter(tags__id=4)
filtered_docs = list(filtered_docs)
assert len(filtered_docs) == 1
for d in filtered_docs:
assert int(d.tags['id']) == 4
assert len(d.chunks) == 5


def test_nested_chunks_filter(docs):
filtered_docs = QuerySet(docs).filter(Q(chunks__filter=Q(tags__id__lt=35, tags__id__gt=33)))
filtered_docs = list(filtered_docs)
assert len(filtered_docs) == 1
for d in filtered_docs:
assert len(d.chunks) == 5
30 changes: 0 additions & 30 deletions tests/unit/drivers/querylang/queryset/test_queryset.py

This file was deleted.

16 changes: 7 additions & 9 deletions tests/unit/drivers/test_matches2doc_rank_drivers.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import numpy as np

from jina import Document
from jina.drivers.rank import Matches2DocRankDriver
from jina.executors.rankers import Match2DocRanker
from jina.proto import jina_pb2
from jina.types.sets import DocumentSet


Expand Down Expand Up @@ -41,17 +41,15 @@ def create_document_to_score():
# |- matches: (id: 3, parent_id: 1, score.value: 3),
# |- matches: (id: 4, parent_id: 1, score.value: 4),
# |- matches: (id: 5, parent_id: 1, score.value: 5),

doc = jina_pb2.DocumentProto()
doc = Document()
doc.id = '1' * 16
doc.length = 5
for match_id, match_score in [(2, 3), (3, 6), (4, 1), (5, 8)]:
match = doc.matches.add()
match.id = str(match_id) * 16
match.parent_id = '1' * 16
match.length = match_score
match.score.ref_id = doc.id
match.score.value = match_score
with Document() as match:
match.id = str(match_id) * 16
match.length = match_score
match.score.value = match_score
doc.matches.append(match)
return doc


Expand Down

0 comments on commit 2d0c489

Please sign in to comment.