Skip to content

Commit

Permalink
fix fetch tensor data (#23)
Browse files Browse the repository at this point in the history
* fix chunk merging when fetching tensor data
  • Loading branch information
hekaisheng authored and wjsi committed Dec 11, 2018
1 parent 4b8ad0e commit 7db71f6
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 12 deletions.
49 changes: 38 additions & 11 deletions mars/scheduler/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
class ResultReceiverActor(SchedulerActor):
def __init__(self):
    """Set up the receiver with an unresolved kv-store ref and an empty chunk cache."""
    super(ResultReceiverActor, self).__init__()
    # resolved later in post_create(), once cluster info is attached
    self._kv_store_ref = None
    self.chunks = {}

@classmethod
Expand All @@ -46,19 +47,30 @@ def post_create(self):
self.set_cluster_info_ref()
self._kv_store_ref = self.get_actor_ref(KVStoreActor.default_name())

def fetch_tensor(self, session_id, graph_key, tensor_key):
    """Fetch all chunk data of a tiled tensor and return the merged result.

    Asks the GraphActor for a fetch-and-concatenate chunk graph, pulls each
    distinct chunk's bytes from a worker that holds it (as non-blocking
    futures so transfers overlap), then executes the graph locally to
    produce the single concatenated tensor.

    :param session_id: id of the session owning the tensor
    :param graph_key: key of the graph the tensor belongs to
    :param tensor_key: key of the tensor to fetch
    :return: serialized (dumps-ed) concatenated tensor data
    """
    from ..tensor.expressions.datasource import TensorFetchChunk
    from ..tensor.execution.core import Executor

    graph_actor = self.ctx.actor_ref(GraphActor.gen_name(session_id, graph_key))
    fetch_graph = deserialize_graph(graph_actor.build_tensor_merge_graph(tensor_key))

    ctx = dict()
    target_keys = set()
    for c in fetch_graph:
        if isinstance(c.op, TensorFetchChunk):
            # duplicate chunk keys only need to be fetched once
            if c.key in ctx:
                continue
            # locate a worker holding this chunk via the kv-store registry
            endpoints = self._kv_store_ref.read('/sessions/%s/chunks/%s/workers'
                                                % (session_id, c.key))
            worker_ip = endpoints.children[0].key.rsplit('/', 1)[-1]
            sender_ref = self.ctx.actor_ref('ResultSenderActor', address=worker_ip)
            # _wait=False returns a future so fetches from different workers overlap
            future = sender_ref.fetch_data(session_id, c.key, _wait=False)
            ctx[c.key] = future
        else:
            # non-fetch nodes (the concat node) are the execution targets
            target_keys.add(c.key)
    # resolve every future into raw chunk data before executing the graph
    ctx = dict((k, loads(future.result())) for k, future in six.iteritems(ctx))
    executor = Executor(storage=ctx)
    concat_result = executor.execute_graph(fetch_graph, keys=target_keys)
    return dumps(concat_result[0])


class GraphActor(SchedulerActor):
Expand Down Expand Up @@ -641,10 +653,6 @@ def calc_stats(self):
transposed[state].append(data_src[op][sid])
return ops, transposed, finished * 100.0 / total_count

def get_tiled_tensor(self, tensor_key):
    """Return the most recent tiled form recorded for ``tensor_key``."""
    return self._tensor_to_tiled[tensor_key][-1]

def free_tensor_data(self, tensor_key):
from .operand import OperandActor
tiled_tensor = self._tensor_to_tiled[tensor_key][-1]
Expand All @@ -654,6 +662,25 @@ def free_tensor_data(self, tensor_key):
op_ref = self.ctx.actor_ref(op_uid, address=scheduler_addr)
op_ref.free_data(_tell=True)

def build_tensor_merge_graph(self, tensor_key):
    """Build and serialize a chunk graph that fetches every chunk of the
    tiled tensor and concatenates them into one chunk.

    :param tensor_key: key of the (already tiled) tensor
    :return: serialized graph, consumable by ``deserialize_graph``
    """
    from ..tensor.expressions.merge.concatenate import TensorConcatenate
    from ..tensor.expressions.datasource import TensorFetchChunk

    tiled_tensor = self._tensor_to_tiled[tensor_key][-1]
    graph = DAG()
    fetch_chunks = []
    for c in tiled_tensor.chunks:
        # one fetch node per stored chunk; keep the original chunk key so the
        # receiver can look the chunk up in the kv-store worker registry
        op = TensorFetchChunk(dtype=c.dtype, to_fetch_key=c.key)
        fetch_chunk = op.new_chunk(None, c.shape, index=c.index, _key=c.key).data
        graph.add_node(fetch_chunk)
        fetch_chunks.append(fetch_chunk)
    chunk = TensorConcatenate(dtype=tiled_tensor.op.dtype).new_chunk(
        fetch_chunks, tiled_tensor.shape).data
    graph.add_node(chunk)
    # plain loop, not a side-effect list comprehension, to wire the edges
    for fetch_chunk in fetch_chunks:
        graph.add_edge(fetch_chunk, chunk)

    return serialize_graph(graph)

def fetch_tensor_result(self, tensor_key):
# TODO for test
tiled_tensor = self._tensor_to_tiled[tensor_key][-1]
Expand Down
2 changes: 1 addition & 1 deletion mars/web/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ def get(self, session_id, graph_key, tensor_key):
def _fetch_fun():
client = new_client()
merge_ref = client.create_actor(ResultReceiverActor, address=scheduler_ip)
return merge_ref.merge_chunks(session_id, graph_key, tensor_key)
return merge_ref.fetch_tensor(session_id, graph_key, tensor_key)

data = yield self._executor.submit(_fetch_fun)
self.write(data)
Expand Down

0 comments on commit 7db71f6

Please sign in to comment.