Skip to content
This repository has been archived by the owner. It is now read-only.
Permalink
Browse files
feat(reducer): add concat reducer
  • Loading branch information
jemmyshin committed Sep 6, 2019
1 parent 3d2a74f commit 928608483c63571df870bc1023503c16067d62c8
Showing with 44 additions and 5 deletions.
  1. +1 −0 gnes/router/__init__.py
  2. +31 −0 gnes/router/reduce.py
  3. +12 −5 tests/test_router.py
@@ -27,6 +27,7 @@
'DocFillReducer': 'reduce',
'PublishRouter': 'map',
'DocBatchRouter': 'map',
'ConcatEmbedRouter': 'reduce'
}

register_all_class(_cls2file_map, 'router')
@@ -14,8 +14,10 @@
# limitations under the License.

from typing import List
import numpy as np

from .base import BaseReduceRouter, BaseTopkReduceRouter
from ..proto import gnes_pb2, blob2array, array2blob


class DocFillReducer(BaseReduceRouter):
@@ -74,3 +76,32 @@ def get_key(self, x: 'gnes_pb2.Response.QueryResponse.ScoredResult') -> str:

def set_key(self, x: 'gnes_pb2.Response.QueryResponse.ScoredResult', k: str):
x.chunk.doc_id, x.chunk.offset = map(int, k.split('-'))


class ConcatEmbedRouter(BaseReduceRouter):
"""
Gather all embeddings from multiple encoders and concat them on a specific axis.
In default, concat will happen on the last axis.
"""

def apply(self, msg: 'gnes_pb2.Message', accum_msgs: List['gnes_pb2.Message'], *args, **kwargs):
body = getattr(msg, msg.WhichOneof('body'))
msg_type = type(getattr(body, body.WhichOneof('body')))
if msg_type == gnes_pb2.Request.QueryRequest:
for i in range(len(msg.request.search.query.chunks)):
concat_embedding = array2blob(
np.concatenate([blob2array(m.request.search.query.chunks[i].embedding) for m in accum_msgs],
axis=1))
msg.request.search.query.chunks[i].embedding.CopyFrom(concat_embedding)

elif msg_type == gnes_pb2.Request.IndexRequest:
for i in range(len(msg.request.index.docs)):
for j in range(len(msg.request.index.docs[i].chunks)):
concat_embedding = array2blob(
np.concatenate(
[blob2array(m.request.index.docs[i].chunks[j].embedding) for m in accum_msgs], axis=1))
msg.request.index.docs[i].chunks[j].embedding.CopyFrom(concat_embedding)
else:
self.logger.error('dont know how to handle %s' % msg_type)

super().apply(msg, accum_msgs)
@@ -334,7 +334,7 @@ def test_doc_sum_reduce_router(self):
self.assertGreaterEqual(r.response.search.topk_results[0].score.value,
r.response.search.topk_results[-1].score.value)

@unittest.SkipTest
# @unittest.SkipTest
def test_concat_router(self):
args = set_router_parser().parse_args([
'--yaml_path', self.concat_router_yaml,
@@ -345,29 +345,36 @@ def test_concat_router(self):
'--port_out', str(args.port_in),
'--socket_in', str(SocketType.PULL_CONNECT)
])
# 10 chunks in each doc, dimension of chunk embedding is (5, 2)
with RouterService(args), ZmqClient(c_args) as c1:
msg = gnes_pb2.Message()
msg.request.search.query.chunk_embeddings.CopyFrom(array2blob(np.random.random([5, 2])))
for i in range(10):
c = msg.request.search.query.chunks.add()
c.embedding.CopyFrom(array2blob(np.random.random([5, 2])))
msg.envelope.num_part.extend([1, 3])
c1.send_message(msg)
c1.send_message(msg)
c1.send_message(msg)
r = c1.recv_message()
self.assertSequenceEqual(r.envelope.num_part, [1])
print(r.envelope.routes)
self.assertEqual(r.request.search.query.chunk_embeddings.shape, [5, 6])
for i in range(10):
self.assertEqual(r.request.search.query.chunks[i].embedding.shape, [5, 6])

for j in range(1, 4):
d = msg.request.index.docs.add()
d.chunk_embeddings.CopyFrom(array2blob(np.random.random([5, 2 * j])))
for k in range(10):
c = d.chunks.add()
c.embedding.CopyFrom(array2blob(np.random.random([5, 2])))

c1.send_message(msg)
c1.send_message(msg)
c1.send_message(msg)
r = c1.recv_message()
self.assertSequenceEqual(r.envelope.num_part, [1])
for j in range(1, 4):
self.assertEqual(r.request.index.docs[j - 1].chunk_embeddings.shape, [5, 6 * j])
for i in range(10):
self.assertEqual(r.request.index.docs[j - 1].chunks[i].embedding.shape, [5, 6])

def test_multimap_multireduce(self):
# p1 ->

0 comments on commit 9286084

Please sign in to comment.