Skip to content
Permalink
Browse files

fix(encoder): fix batching in encoder

  • Loading branch information...
hanxiao committed Aug 6, 2019
1 parent e35e3b3 commit e5fefcee9ea003c7d244bd58c889606a03e12936
@@ -31,6 +31,9 @@ def resolve_yaml_path(path):
elif path.isidentifier():
# possible class name
return io.StringIO('!%s {}' % path)
elif path.startswith('!'):
# possible YAML content
return io.StringIO(path)
else:
raise argparse.ArgumentTypeError('%s can not be resolved, it should be a readable stream,'
' or a valid file path, or a supported class name.' % path)
@@ -384,7 +384,7 @@ def rule3():
self._num_layer += 1
last_layer.components[0]['socket_out'] = str(SocketType.PUSH_CONNECT)
r = CommentedMap({'name': 'Router',
'yaml_path': None,
'yaml_path': 'BaseRouter',
'socket_in': str(SocketType.PULL_BIND),
'socket_out': str(SocketType.PUSH_BIND),
'port_in': last_layer.components[0]['port_out'],
@@ -405,6 +405,8 @@ def rule5():
# a shortcut fn: based on c3(): (N)-2-(N) with pub sub connection
rule3()
router_layers[0].components[0]['socket_out'] = str(SocketType.PUB_BIND)
router_layers[0].components[0]['yaml_path'] = '"!PublishRouter {parameter: {num_part: %d}}"' \
% len(layer.components)
for c in layer.components:
c['socket_in'] = str(SocketType.SUB_CONNECT)

@@ -415,7 +417,7 @@ def rule6():
for c in layer.components:
income = self.Layer.get_value(c, 'income')
r = CommentedMap({'name': 'Router',
'yaml_path': None,
'yaml_path': 'BaseReduceRouter',
'socket_in': str(SocketType.SUB_CONNECT),
'socket_out': str(SocketType.PUSH_BIND) if income == 'pull' else str(
SocketType.PUB_BIND),
@@ -432,7 +434,7 @@ def rule7():
router_layer = YamlComposer.Layer(layer_id=self._num_layer)
self._num_layer += 1
r0 = CommentedMap({'name': 'Router',
'yaml_path': None,
'yaml_path': '"!PublishRouter {parameter: {num_part: %d}}"' % len(layer.components),
'socket_in': str(SocketType.PULL_BIND),
'socket_out': str(SocketType.PUB_BIND),
'port_in': self._get_random_port(),
@@ -445,7 +447,7 @@ def rule7():
self._num_layer += 1
for c in layer.components:
r = CommentedMap({'name': 'Router',
'yaml_path': None,
'yaml_path': 'BaseRouter',
'socket_in': str(SocketType.SUB_CONNECT),
'socket_out': str(SocketType.PUSH_BIND),
'port_in': r0['port_out'],
@@ -461,7 +463,7 @@ def rule10():
router_layer = YamlComposer.Layer(layer_id=self._num_layer)
self._num_layer += 1
r0 = CommentedMap({'name': 'Router',
'yaml_path': None,
'yaml_path': '"!PublishRouter {parameter: {num_part: %d}}"' % len(layer.components),
'socket_in': str(SocketType.PULL_BIND),
'socket_out': str(SocketType.PUB_BIND),
'port_in': self._get_random_port(),
@@ -478,7 +480,7 @@ def rule8():
router_layer = YamlComposer.Layer(layer_id=self._num_layer)
self._num_layer += 1
r = CommentedMap({'name': 'Router',
'yaml_path': None,
'yaml_path': 'BaseReduceRouter',
'socket_in': str(SocketType.PULL_BIND),
'socket_out': str(SocketType.PUSH_BIND),
'port_in': self._get_random_port(),
@@ -489,7 +491,7 @@ def rule8():
if last_income == 'sub':
c['socket_out'] = str(SocketType.PUSH_CONNECT)
r_c = CommentedMap({'name': 'Router',
'yaml_path': None,
'yaml_path': 'BaseReduceRouter',
'socket_in': str(SocketType.PULL_BIND),
'socket_out': str(SocketType.PUSH_CONNECT),
'port_in': self._get_random_port(),
@@ -519,26 +521,6 @@ def rule9():
last_layer.components[0]['socket_out'] = str(SocketType.PUSH_CONNECT)
layer.components[0]['socket_in'] = str(SocketType.PULL_BIND)

def rule11():
# a shortcut fn: (N)-2-(N) with push pull connection
router_layer = YamlComposer.Layer(layer_id=self._num_layer)
self._num_layer += 1
r = CommentedMap({'name': 'Router',
'yaml_path': None,
'socket_in': str(SocketType.PULL_BIND),
'socket_out': str(SocketType.PUSH_BIND),
'port_in': self._get_random_port(),
'port_out': self._get_random_port()})

for c in last_layer.components:
c['socket_out'] = str(SocketType.PUSH_CONNECT)
c['port_out'] = r['port_in']
for c in layer.components:
c['socket_in'] = str(SocketType.PULL_CONNECT)
c['port_in'] = r['port_out']
router_layer.append(r)
router_layers.append(router_layer)

router_layers = [] # type: List['self.Layer']
# bind the last out to current in

@@ -118,9 +118,9 @@ def _encode(_, img: List['np.ndarray']):
# for video
if len(img[0].shape) == 4:
padding_image, max_lenth = _padding(img)
output = _encode(None, padding_image)
output = _encode(self, padding_image)
# for image
else:
output = _encode(None, img)
output = _encode(self, img)

return output
@@ -74,4 +74,4 @@ def _encode(_, data):
feed_dict={self.inputs: data})
return end_points_[self.select_layer]

return _encode(None, img).astype(np.float32)
return _encode(self, img).astype(np.float32)
@@ -118,7 +118,7 @@ def _encode1(_, data):
feed_dict={self.inputs: data})
return end_points_[self.select_layer]

v = [_ for vi in _encode1(None, img) for _ in vi]
v = [_ for vi in _encode1(self, img) for _ in vi]

v_input = [v[s:e] for s, e in zip(pos_start, pos_end)]
v_input = [(vi + [[0.0] * self.input_size] * (max_len - len(vi)))[:max_len] for vi in v_input]
@@ -129,4 +129,4 @@ def _encode2(_, data):
return self.sess2.run(self.mix_model.repre,
feed_dict={self.mix_model.feeds: data})

return _encode2(None, v_input).astype(np.float32)
return _encode2(self, v_input).astype(np.float32)
@@ -14,8 +14,8 @@ class TestProto(unittest.TestCase):

def setUp(self):
self.dirname = os.path.dirname(__file__)
self.publish_router_yaml = os.path.join(self.dirname, 'yaml', 'router-publish.yml')
self.batch_router_yaml = os.path.join(self.dirname, 'yaml', 'router-batch.yml')
self.publish_router_yaml = '!PublishRouter {parameter: {num_part: 2}}'
self.batch_router_yaml = '!DocBatchRouter {gnes_config: {batch_size: 2}}'
self.reduce_router_yaml = 'BaseReduceRouter'
self.chunk_router_yaml = 'ChunkReduceRouter'
self.doc_router_yaml = 'DocReduceRouter'

0 comments on commit e5fefce

Please sign in to comment.
You can’t perform that action at this time.