Skip to content
Permalink
Browse files

fix: bugs for integrated test

  • Loading branch information...
Jem
Jem committed Jul 12, 2019
1 parent 72a8bd9 commit 8780a4da75f943e5f43499ec047f69c948fcac55
Showing with 16 additions and 4 deletions.
  1. +13 −1 gnes/preprocessor/image/simple.py
  2. +1 −1 gnes/service/encoder.py
  3. +2 −2 gnes/service/indexer.py
@@ -34,6 +34,7 @@ def __init__(self, window_size: int = 64,
self.stride_wide = stride_wide

def apply(self, doc: 'gnes_pb2.Document'):
super().apply(doc)
if doc.raw_bytes:
img = np.array(Image.open(io.BytesIO(doc.raw_bytes)))
image_set = self._get_all_sliding_window(img)
@@ -87,7 +88,18 @@ def _get_all_chunks_weight(self, image_set) -> List[float]:
class WeightedSlidingPreprocessor(BaseSlidingPreprocessor):

def _get_all_chunks_weight(self, image_set) -> List[float]:
raise NotImplementedError
weight = np.zeros([len(image_set)])
# n_channel is usually 3 for RGB images
n_channel = image_set[0].shape[-1]
for i in range(len(image_set)):
# calcualte the variance of histgram of pixels
weight[i] = sum([np.histogram(image_set[i][:, :, _])[0].var()
for _ in range(n_channel)])
weight = weight / weight.sum()

# normalized result
weight = np.exp(- weight * 10)
return weight / weight.sum()


class SegmentPreprocessor(BaseImagePreprocessor):
@@ -50,7 +50,7 @@ def _handler_train(self, msg: 'gnes_pb2.Message'):
chunks = self.get_chunks_from_docs(msg.request.train.docs)
self.train_data.extend(chunks)
msg.response.train.status = gnes_pb2.Response.PENDING
raise BlockMessage
# raise BlockMessage
if msg.request.train.flush:
self._model.train(self.train_data)
self.logger.info('%d samples is flushed for training' % len(self.train_data))
@@ -77,9 +77,9 @@ def _handler_chunk_search(self, msg: 'gnes_pb2.Message'):
@handler.register(gnes_pb2.Response.QueryResponse)
def _handler_doc_search(self, msg: 'gnes_pb2.Message'):
if msg.response.search.level == gnes_pb2.Response.QueryResponse.DOCUMENT_NOT_FILLED:
doc_ids = [r.doc.doc_id for r in msg.response.topk_results]
doc_ids = [r.doc.doc_id for r in msg.response.search.topk_results]
docs = self._model.query(doc_ids)
for r, d in zip(msg.response.topk_results, docs):
for r, d in zip(msg.response.search.topk_results, docs):
if d is not None:
# fill in the doc if this shard returns non-empty
r.doc.CopyFrom(d)

0 comments on commit 8780a4d

Please sign in to comment.
You can’t perform that action at this time.