Skip to content
Permalink
Browse files

feat(tests): add unittest for EncoderService and IndexerService

  • Loading branch information...
raccoonliukai committed Sep 5, 2019
1 parent d805b56 commit 5a745b1e8f9a158ab3ce488ff0eb53d59364291b
Showing with 91 additions and 0 deletions.
  1. +51 −0 tests/test_encoder_service.py
  2. +40 −0 tests/test_indexer_service.py
@@ -0,0 +1,51 @@
import os
import unittest

import numpy as np

from gnes.cli.parser import set_encoder_parser, _set_client_parser
from gnes.client.base import ZmqClient
from gnes.proto import gnes_pb2, array2blob
from gnes.service.base import ServiceManager
from gnes.service.encoder import EncoderService
from gnes.encoder.base import BaseEncoder
from gnes.helper import train_required


class DummyEncoder(BaseEncoder):

def train(self, *args, **kwargs):
pass

@train_required
def encode(self, x):
return np.array(x)


class TestEncoderService(unittest.TestCase):

def setUp(self):
self.test_numeric = np.random.randint(0, 255, (1000, 1024)).astype('float32')

def test_empty_service(self):
args = set_encoder_parser().parse_args(['--yaml_path', '!DummyEncoder {gnes_config: {name: EncoderService, is_trained: True}}'])
c_args = _set_client_parser().parse_args([
'--port_in', str(args.port_out),
'--port_out', str(args.port_in)])

with ServiceManager(EncoderService, args), ZmqClient(c_args) as client:
msg = gnes_pb2.Message()
d = msg.request.index.docs.add()
d.doc_type = gnes_pb2.Document.IMAGE

c = d.chunks.add()
c.blob.CopyFrom(array2blob(self.test_numeric))

client.send_message(msg)
r = client.recv_message()
self.assertEqual(len(r.request.index.docs), 1)
self.assertEqual(r.response.index.status, gnes_pb2.Response.SUCCESS)

def tearDown(self):
if os.path.exists('EncoderService.bin'):
os.remove('EncoderService.bin')
@@ -0,0 +1,40 @@
import os
import unittest

import numpy as np

from gnes.cli.parser import set_indexer_parser, _set_client_parser
from gnes.client.base import ZmqClient
from gnes.proto import gnes_pb2, array2blob
from gnes.service.base import ServiceManager
from gnes.service.indexer import IndexerService


class TestIndexerService(unittest.TestCase):

def setUp(self):
self.test_numeric = np.random.randint(0, 255, (1000, 1024)).astype('float32')

def test_empty_service(self):
args = set_indexer_parser().parse_args(['--yaml_path', '!BaseChunkIndexer {gnes_config: {name: IndexerService}}'])
c_args = _set_client_parser().parse_args([
'--port_in', str(args.port_out),
'--port_out', str(args.port_in)])

with ServiceManager(IndexerService, args), ZmqClient(c_args) as client:
msg = gnes_pb2.Message()
d = msg.request.index.docs.add()

c = d.chunks.add()
c.doc_id = 0
c.embedding.CopyFrom(array2blob(self.test_numeric))
c.offset = 0
c.weight = 1.0

client.send_message(msg)
r = client.recv_message()
self.assertEqual(r.response.index.status, gnes_pb2.Response.SUCCESS)

def tearDown(self):
if os.path.exists('IndexerService.bin'):
os.remove('IndexerService.bin')

0 comments on commit 5a745b1

Please sign in to comment.
You can’t perform that action at this time.