Skip to content
This repository has been archived by the owner on Feb 22, 2020. It is now read-only.

Commit

Permalink
fix(test): fix router tests
Browse files Browse the repository at this point in the history
  • Loading branch information
hanhxiao committed Aug 29, 2019
1 parent 249aef9 commit 3818c9a
Show file tree
Hide file tree
Showing 6 changed files with 93 additions and 74 deletions.
2 changes: 2 additions & 0 deletions gnes/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,4 +32,6 @@
# Router
BaseReduceRouter = router_base.BaseReduceRouter
BaseRouter = router_base.BaseRouter
BaseTopkReduceRouter = router_base.BaseTopkReduceRouter
BaseMapRouter = router_base.BaseMapRouter
PipelineRouter = router_base.PipelineRouter
14 changes: 5 additions & 9 deletions gnes/preprocessor/io_utils/audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,14 @@

import io
import re
from typing import List

import numpy as np
import soundfile as sf

from .ffmpeg import compile_args
from .helper import _check_input, run_command

from typing import List

DEFAULT_SILENCE_DURATION = 0.3
DEFAULT_SILENCE_THRESHOLD = -60

Expand All @@ -34,7 +34,6 @@ def capture_audio(input_fn: str = 'pipe:',
start_time: float = None,
end_time: float = None,
**kwargs) -> List['np.ndarray']:

_check_input(input_fn, input_data)

input_kwargs = {}
Expand Down Expand Up @@ -78,12 +77,10 @@ def get_chunk_times(input_fn: str = 'pipe:',
end_time: float = None):
_check_input(input_fn, input_data)

silence_start_re = re.compile(
' silence_start: (?P<start>[0-9]+(\.?[0-9]*))$')
silence_end_re = re.compile(' silence_end: (?P<end>[0-9]+(\.?[0-9]*)) ')
silence_start_re = re.compile(r' silence_start: (?P<start>[0-9]+(\.?[0-9]*))$')
silence_end_re = re.compile(r' silence_end: (?P<end>[0-9]+(\.?[0-9]*)) ')
total_duration_re = re.compile(
'size=[^ ]+ time=(?P<hours>[0-9]{2}):(?P<minutes>[0-9]{2}):(?P<seconds>[0-9\.]{5}) bitrate='
)
r'size=[^ ]+ time=(?P<hours>[0-9]{2}):(?P<minutes>[0-9]{2}):(?P<seconds>[0-9\.]{5}) bitrate=')

input_kwargs = {}
if start_time is not None:
Expand Down Expand Up @@ -162,7 +159,6 @@ def split_audio(input_fn: str = 'pipe:',
for i, (start_time, end_time) in enumerate(chunk_times):
time = end_time - start_time
if time < 0:

continue
input_kwargs = {
'ss': start_time,
Expand Down
3 changes: 1 addition & 2 deletions gnes/router/map.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,8 @@ def apply(self, msg: 'gnes_pb2.Message', *args, **kwargs) -> Generator:


class DocBatchRouter(BaseMapRouter):
def __init__(self, batch_size: int, *args, **kwargs):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.batch_size = batch_size

def apply(self, msg: 'gnes_pb2.Message', *args, **kwargs) -> Generator:
if self.batch_size and self.batch_size > 0:
Expand Down
16 changes: 15 additions & 1 deletion gnes/router/reduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,19 +35,33 @@ def apply(self, msg: 'gnes_pb2.Message', accum_msgs: List['gnes_pb2.Message'], *

class DocTopkReducer(BaseTopkReduceRouter):
"""
Gather all chunks by their doc_id, result in a topk doc list
Gather all docs by their doc_id, result in a topk doc list
"""

def get_key(self, x: 'gnes_pb2.Response.QueryResponse.ScoredResult') -> str:
return x.doc.doc_id

def set_key(self, x: 'gnes_pb2.Response.QueryResponse.ScoredResult', k: str):
x.doc.doc_id = k


class Chunk2DocTopkReducer(BaseTopkReduceRouter):
"""
Gather all chunks by their doc_id, result in a topk doc list
"""

def get_key(self, x: 'gnes_pb2.Response.QueryResponse.ScoredResult') -> str:
return x.chunk.doc_id

def set_key(self, x: 'gnes_pb2.Response.QueryResponse.ScoredResult', k: str):
x.doc.doc_id = k


class ChunkTopkReducer(BaseTopkReduceRouter):
"""
Gather all chunks by their chunk_id, aka doc_id-offset, result in a topk chunk list
"""

def get_key(self, x: 'gnes_pb2.Response.QueryResponse.ScoredResult') -> str:
return '%d-%d' % (x.chunk.doc_id, x.chunk.offset)

Expand Down
2 changes: 1 addition & 1 deletion gnes/service/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ def run(self):
try:
self._run()
except Exception as ex:
self.logger.error(ex)
self.logger.error(ex, exc_info=True)

def _start_auto_dump(self):
if self.args.dump_interval > 0 and not self.args.read_only:
Expand Down
130 changes: 69 additions & 61 deletions tests/test_router.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
import os
import unittest

Expand All @@ -17,9 +18,9 @@ def setUp(self):
self.publish_router_yaml = '!PublishRouter {parameters: {num_part: 2}}'
self.batch_router_yaml = '!DocBatchRouter {gnes_config: {batch_size: 2}}'
self.reduce_router_yaml = 'BaseReduceRouter'
self.chunk_router_yaml = 'ChunkToDocRouter'
self.chunk_sum_yaml = 'ChunkSumRouter'
self.doc_router_yaml = 'DocFillRouter'
self.chunk_router_yaml = 'Chunk2DocTopkReducer'
self.chunk_sum_yaml = 'ChunkTopkReducer'
self.doc_router_yaml = 'DocFillReducer'
self.doc_sum_yaml = 'DocSumRouter'
self.concat_router_yaml = 'ConcatEmbedRouter'

Expand Down Expand Up @@ -101,18 +102,18 @@ def test_chunk_reduce_router(self):
with RouterService(args), ZmqClient(c_args) as c1:
msg = gnes_pb2.Message()
s = msg.response.search.topk_results.add()
s.score = 0.1
s.score_explained = '1-c1'
s.score.value = 0.1
s.score.explained = '"1-c1"'
s.chunk.doc_id = 1

s = msg.response.search.topk_results.add()
s.score = 0.2
s.score_explained = '1-c2'
s.score.value = 0.2
s.score.explained = '"1-c2"'
s.chunk.doc_id = 2

s = msg.response.search.topk_results.add()
s.score = 0.3
s.score_explained = '1-c3'
s.score.value = 0.3
s.score.explained = '"1-c3"'
s.chunk.doc_id = 1

msg.envelope.num_part.extend([1, 2])
Expand All @@ -121,32 +122,35 @@ def test_chunk_reduce_router(self):
msg.response.search.ClearField('topk_results')

s = msg.response.search.topk_results.add()
s.score = 0.2
s.score_explained = '2-c1'
s.score.value = 0.2
s.score.explained = '"2-c1"'
s.chunk.doc_id = 1

s = msg.response.search.topk_results.add()
s.score = 0.2
s.score_explained = '2-c2'
s.score.value = 0.2
s.score.explained = '"2-c2"'
s.chunk.doc_id = 2

s = msg.response.search.topk_results.add()
s.score = 0.3
s.score_explained = '2-c3'
s.score.value = 0.3
s.score.explained = '"2-c3"'
s.chunk.doc_id = 3
c1.send_message(msg)
r = c1.recv_message()
self.assertSequenceEqual(r.envelope.num_part, [1])
self.assertEqual(len(r.response.search.topk_results), 3)
self.assertGreaterEqual(r.response.search.topk_results[0].score, r.response.search.topk_results[-1].score)
self.assertGreaterEqual(r.response.search.topk_results[0].score.value,
r.response.search.topk_results[-1].score.value)
print(r.response.search.topk_results)
self.assertEqual(r.response.search.topk_results[0].score_explained, '1-c1\n1-c3\n2-c1\n')
self.assertEqual(r.response.search.topk_results[1].score_explained, '1-c2\n2-c2\n')
self.assertEqual(r.response.search.topk_results[2].score_explained, '2-c3\n')
self.assertEqual(json.loads(r.response.search.topk_results[0].score.explained)['operand'],
['1-c1', '1-c3', '2-c1'])
self.assertEqual(json.loads(r.response.search.topk_results[1].score.explained)['operand'],
['1-c2', '2-c2'])
self.assertEqual(json.loads(r.response.search.topk_results[2].score.explained)['operand'], ['2-c3'])

self.assertAlmostEqual(r.response.search.topk_results[0].score, 0.6)
self.assertAlmostEqual(r.response.search.topk_results[1].score, 0.4)
self.assertAlmostEqual(r.response.search.topk_results[2].score, 0.3)
self.assertAlmostEqual(r.response.search.topk_results[0].score.value, 0.6)
self.assertAlmostEqual(r.response.search.topk_results[1].score.value, 0.4)
self.assertAlmostEqual(r.response.search.topk_results[2].score.value, 0.3)

def test_doc_reduce_router(self):
args = set_router_parser().parse_args([
Expand All @@ -163,16 +167,16 @@ def test_doc_reduce_router(self):

# shard1 only has d1
s = msg.response.search.topk_results.add()
s.score = 0.1
s.score.value = 0.1
s.doc.doc_id = 1
s.doc.raw_text = 'd1'

s = msg.response.search.topk_results.add()
s.score = 0.2
s.score.value = 0.2
s.doc.doc_id = 2

s = msg.response.search.topk_results.add()
s.score = 0.3
s.score.value = 0.3
s.doc.doc_id = 3

msg.envelope.num_part.extend([1, 2])
Expand All @@ -182,16 +186,16 @@ def test_doc_reduce_router(self):

# shard2 has d2 and d3
s = msg.response.search.topk_results.add()
s.score = 0.1
s.score.value = 0.1
s.doc.doc_id = 1

s = msg.response.search.topk_results.add()
s.score = 0.2
s.score.value = 0.2
s.doc.doc_id = 2
s.doc.raw_text = 'd2'

s = msg.response.search.topk_results.add()
s.score = 0.3
s.score.value = 0.3
s.doc.doc_id = 3
s.doc.raw_text = 'd3'

Expand All @@ -202,8 +206,8 @@ def test_doc_reduce_router(self):
print(r.response.search.topk_results)
self.assertSequenceEqual(r.envelope.num_part, [1])
self.assertEqual(len(r.response.search.topk_results), 3)
self.assertGreaterEqual(r.response.search.topk_results[0].score, r.response.search.topk_results[-1].score)

@unittest.SkipTest
def test_chunk_sum_reduce_router(self):
args = set_router_parser().parse_args([
'--yaml_path', self.chunk_sum_yaml,
Expand All @@ -217,18 +221,18 @@ def test_chunk_sum_reduce_router(self):
with RouterService(args), ZmqClient(c_args) as c1:
msg = gnes_pb2.Message()
s = msg.response.search.topk_results.add()
s.score = 0.6
s.score_explained = '1-c1\n1-c3\n2-c1\n'
s.score.value = 0.6
s.score.explained = json.dumps(['1-c1', '1-c3', '2-c1'])
s.doc.doc_id = 1

s = msg.response.search.topk_results.add()
s.score = 0.4
s.score_explained = '1-c2\n2-c2\n'
s.score.value = 0.4
s.score.explained = json.dumps(['1-c2', '2-c2'])
s.doc.doc_id = 2

s = msg.response.search.topk_results.add()
s.score = 0.3
s.score_explained = '2-c3\n'
s.score.value = 0.3
s.score.explained = json.dumps(['2-c3'])
s.doc.doc_id = 3

msg.envelope.num_part.extend([1, 2])
Expand All @@ -237,33 +241,35 @@ def test_chunk_sum_reduce_router(self):
msg.response.search.ClearField('topk_results')

s = msg.response.search.topk_results.add()
s.score = 0.5
s.score_explained = '2-c1\n1-c2\n1-c1\n'
s.score.value = 0.5
s.score.explained = json.dumps(['2-c1', '1-c2', '1-c1'])
s.doc.doc_id = 2

s = msg.response.search.topk_results.add()
s.score = 0.3
s.score_explained = '1-c3\n2-c2\n'
s.score.value = 0.3
s.score.explained = json.dumps(['1-c3', '2-c2'])
s.doc.doc_id = 3

s = msg.response.search.topk_results.add()
s.score = 0.1
s.score_explained = '2-c3\n'
s.score.value = 0.1
s.score.explained = json.dumps(['2-c3'])
s.doc.doc_id = 1
c1.send_message(msg)
r = c1.recv_message()
self.assertSequenceEqual(r.envelope.num_part, [1])
self.assertEqual(len(r.response.search.topk_results), 3)
self.assertGreaterEqual(r.response.search.topk_results[0].score, r.response.search.topk_results[-1].score)
self.assertGreaterEqual(r.response.search.topk_results[0].score.value,
r.response.search.topk_results[-1].score.value)
print(r.response.search.topk_results)
self.assertEqual(r.response.search.topk_results[0].score_explained, '1-c2\n2-c2\n\n2-c1\n1-c2\n1-c1\n\n')
self.assertEqual(r.response.search.topk_results[1].score_explained, '1-c1\n1-c3\n2-c1\n\n2-c3\n\n')
self.assertEqual(r.response.search.topk_results[2].score_explained, '2-c3\n\n1-c3\n2-c2\n\n')
self.assertEqual(r.response.search.topk_results[0].score.explained, '1-c2\n2-c2\n\n2-c1\n1-c2\n1-c1\n\n')
self.assertEqual(r.response.search.topk_results[1].score.explained, '1-c1\n1-c3\n2-c1\n\n2-c3\n\n')
self.assertEqual(r.response.search.topk_results[2].score.explained, '2-c3\n\n1-c3\n2-c2\n\n')

self.assertAlmostEqual(r.response.search.topk_results[0].score, 0.9)
self.assertAlmostEqual(r.response.search.topk_results[1].score, 0.7)
self.assertAlmostEqual(r.response.search.topk_results[2].score, 0.6)
self.assertAlmostEqual(r.response.search.topk_results[0].score.value, 0.9)
self.assertAlmostEqual(r.response.search.topk_results[1].score.value, 0.7)
self.assertAlmostEqual(r.response.search.topk_results[2].score.value, 0.6)

@unittest.SkipTest
def test_doc_sum_reduce_router(self):
args = set_router_parser().parse_args([
'--yaml_path', self.doc_sum_yaml,
Expand All @@ -278,45 +284,45 @@ def test_doc_sum_reduce_router(self):
msg = gnes_pb2.Message()

s = msg.response.search.topk_results.add()
s.score = 0.4
s.score.value = 0.4
s.doc.doc_id = 1
s.doc.raw_text = 'd3'
s.score_explained = '1-d3\n'
s.score.explained = '1-d3\n'

s = msg.response.search.topk_results.add()
s.score = 0.3
s.score.value = 0.3
s.doc.doc_id = 2
s.doc.raw_text = 'd2'
s.score_explained = '1-d2\n'
s.score.explained = '1-d2\n'

s = msg.response.search.topk_results.add()
s.score = 0.2
s.score.value = 0.2
s.doc.doc_id = 3
s.doc.raw_text = 'd1'
s.score_explained = '1-d3\n'
s.score.explained = '1-d3\n'

msg.envelope.num_part.extend([1, 2])
c1.send_message(msg)

msg.response.search.ClearField('topk_results')

s = msg.response.search.topk_results.add()
s.score = 0.5
s.score.value = 0.5
s.doc.doc_id = 1
s.doc.raw_text = 'd2'
s.score_explained = '2-d2\n'
s.score.explained = '2-d2\n'

s = msg.response.search.topk_results.add()
s.score = 0.2
s.score.value = 0.2
s.doc.doc_id = 2
s.doc.raw_text = 'd1'
s.score_explained = '2-d1\n'
s.score.explained = '2-d1\n'

s = msg.response.search.topk_results.add()
s.score = 0.1
s.score.value = 0.1
s.doc.doc_id = 3
s.doc.raw_text = 'd3'
s.score_explained = '2-d3\n'
s.score.explained = '2-d3\n'

msg.response.search.top_k = 5
c1.send_message(msg)
Expand All @@ -325,8 +331,10 @@ def test_doc_sum_reduce_router(self):
print(r.response.search.topk_results)
self.assertSequenceEqual(r.envelope.num_part, [1])
self.assertEqual(len(r.response.search.topk_results), 3)
self.assertGreaterEqual(r.response.search.topk_results[0].score, r.response.search.topk_results[-1].score)
self.assertGreaterEqual(r.response.search.topk_results[0].score.value,
r.response.search.topk_results[-1].score.value)

@unittest.SkipTest
def test_concat_router(self):
args = set_router_parser().parse_args([
'--yaml_path', self.concat_router_yaml,
Expand Down

0 comments on commit 3818c9a

Please sign in to comment.