Skip to content
Permalink
Browse files

fix(base): fix ref to CompositionalTrainableBase

  • Loading branch information...
hanxiao committed Aug 2, 2019
1 parent 53dae6d commit 50fdc0414659d7fe0acf858fe23e67c1be1bee0b
Showing with 20 additions and 9 deletions.
  1. +2 −2 gnes/base/__init__.py
  2. +3 −2 gnes/cli/parser.py
  3. +2 −2 gnes/encoder/text/bert.py
  4. +2 −3 gnes/indexer/base.py
  5. +11 −0 tests/test_parser.py
@@ -295,7 +295,7 @@ def _get_instance_from_yaml(cls, constructor, node, stop_on_import_error=False):
if stop_on_import_error:
raise RuntimeError('Cannot import module, pip install may required') from ex

if node.tag in {'!PipelineEncoder', '!CompositionalEncoder'}:
if node.tag in {'!PipelineEncoder', '!CompositionalTrainableBase'}:
os.environ['GNES_WARN_UNNAMED_COMPONENT'] = '0'

data = ruamel.yaml.constructor.SafeConstructor.construct_mapping(
@@ -325,7 +325,7 @@ def _get_instance_from_yaml(cls, constructor, node, stop_on_import_error=False):
obj.logger.info('initialize %s from a yaml config' % cls.__name__)
cls.init_from_yaml = False

if node.tag in {'!PipelineEncoder', '!CompositionalEncoder'}:
if node.tag in {'!PipelineEncoder', '!CompositionalTrainableBase'}:
os.environ['GNES_WARN_UNNAMED_COMPONENT'] = '1'

return obj, data, load_from_dump
@@ -213,10 +213,11 @@ def set_grpc_frontend_parser(parser=None):
from ..service.base import SocketType
if not parser:
parser = set_base_parser()
_set_client_parser(parser)
set_service_parser(parser)
_set_grpc_parser(parser)
parser.set_defaults(socket_in=SocketType.PULL_BIND,
socket_out=SocketType.PUSH_BIND)
socket_out=SocketType.PUSH_BIND,
read_only=True)
parser.add_argument('--max_concurrency', type=int, default=10,
help='maximum concurrent client allowed')
parser.add_argument('--max_send_size', type=int, default=100,
@@ -20,7 +20,7 @@

import numpy as np

from ..base import CompositionalEncoder, BaseTextEncoder
from ..base import CompositionalTrainableBase, BaseTextEncoder
from ...helper import batching


@@ -45,7 +45,7 @@ def close(self):
self.bc_encoder.close()


class BertEncoderWithServer(CompositionalEncoder):
class BertEncoderWithServer(CompositionalTrainableBase):
def encode(self, text: List[str], *args, **kwargs) -> np.ndarray:
return self.component['bert_client'].encode(text, *args, **kwargs)

@@ -19,8 +19,7 @@

import numpy as np

from ..base import TrainableBase
from ..encoder.base import CompositionalEncoder
from ..base import TrainableBase, CompositionalTrainableBase


class BaseIndexer(TrainableBase):
@@ -71,7 +70,7 @@ def normalize_score(self, *args, **kwargs):
pass


class JointIndexer(CompositionalEncoder):
class JointIndexer(CompositionalTrainableBase):

@property
def component(self):
@@ -0,0 +1,11 @@
import unittest

from gnes.cli.parser import set_grpc_frontend_parser


class TestParser(unittest.TestCase):
def test_service_parser(self):
args1 = set_grpc_frontend_parser().parse_args([])
args2 = set_grpc_frontend_parser().parse_args([])
self.assertNotEqual(args1.port_in, args2.port_in)
self.assertNotEqual(args1.port_out, args2.port_out)

0 comments on commit 50fdc04

Please sign in to comment.
You can’t perform that action at this time.