add mars common API (#26)
* add mars api
hekaisheng authored and qinxuye committed Dec 13, 2018
1 parent ffe306b commit eb946d4
Showing 13 changed files with 199 additions and 217 deletions.
97 changes: 97 additions & 0 deletions mars/api.py
@@ -0,0 +1,97 @@
# Copyright 1999-2018 Alibaba Group Holding Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging

from .actors import new_client
from .cluster_info import ClusterInfoActor
from .node_info import NodeInfoActor
from .scheduler import SessionActor, GraphActor, KVStoreActor
from .scheduler.session import SessionManagerActor
from .scheduler.graph import ResultReceiverActor

logger = logging.getLogger(__name__)


class MarsAPI(object):
    def __init__(self, scheduler_ip):
        self.actor_client = new_client()
        self.cluster_info = self.actor_client.actor_ref(
            ClusterInfoActor.default_name(), address=scheduler_ip)
        self.kv_store = self.get_actor_ref(KVStoreActor.default_name())
        self.session_manager = self.get_actor_ref(SessionManagerActor.default_name())

    def get_actor_ref(self, uid):
        actor_address = self.cluster_info.get_scheduler(uid)
        return self.actor_client.actor_ref(uid, address=actor_address)

    def get_schedulers_info(self):
        schedulers = self.cluster_info.get_schedulers()
        infos = []
        for scheduler in schedulers:
            info_ref = self.actor_client.actor_ref(NodeInfoActor.default_name(),
                                                   address=scheduler)
            infos.append(info_ref.get_info())
        return infos

    def count_workers(self):
        try:
            worker_info = self.kv_store.read('/workers/meta')
            workers_num = len(worker_info.children)
            return workers_num
        except KeyError:
            return 0

    def create_session(self, session_id, **kw):
        self.session_manager.create_session(session_id, **kw)

    def delete_session(self, session_id):
        session_uid = SessionActor.gen_name(session_id)
        session_ref = self.get_actor_ref(session_uid)
        session_ref.destroy()

    def submit_graph(self, session_id, serialized_graph, graph_key, target):
        session_uid = SessionActor.gen_name(session_id)
        session_ref = self.get_actor_ref(session_uid)
        session_ref.submit_tensor_graph(serialized_graph, graph_key, target, _tell=True)

    def delete_graph(self, session_id, graph_key):
        graph_uid = GraphActor.gen_name(session_id, graph_key)
        graph_ref = self.get_actor_ref(graph_uid)
        graph_ref.destroy()

    def stop_graph(self, session_id, graph_key):
        graph_uid = GraphActor.gen_name(session_id, graph_key)
        graph_ref = self.get_actor_ref(graph_uid)
        graph_ref.stop_graph()

    def get_graph_state(self, session_id, graph_key):
        from .scheduler.utils import GraphState

        state_obj = self.kv_store.read(
            '/sessions/%s/graph/%s/state' % (session_id, graph_key), silent=True)
        state = state_obj.value if state_obj else 'preparing'
        state = GraphState(state.lower())
        return state

    def fetch_data(self, session_id, graph_key, tensor_key):
        graph_uid = GraphActor.gen_name(session_id, graph_key)
        graph_address = self.cluster_info.get_scheduler(graph_uid)
        result_ref = self.actor_client.create_actor(ResultReceiverActor, address=graph_address)
        return result_ref.fetch_tensor(session_id, graph_key, tensor_key)

    def delete_data(self, session_id, graph_key, tensor_key):
        graph_uid = GraphActor.gen_name(session_id, graph_key)
        graph_ref = self.get_actor_ref(graph_uid)
        graph_ref.free_tensor_data(tensor_key, _tell=True)
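
For orientation, a minimal usage sketch of the new MarsAPI follows (not part of this commit); the scheduler address is a placeholder, and it assumes a scheduler with its ClusterInfoActor and KV store is already running.

# Hedged sketch of driving MarsAPI directly; the address and the session
# bookkeeping below are illustrative assumptions, not code from this commit.
import uuid

from mars.api import MarsAPI

api = MarsAPI('127.0.0.1:7112')            # placeholder scheduler address
print('workers:', api.count_workers())     # 0 until workers register under /workers/meta
print('schedulers:', api.get_schedulers_info())

session_id = str(uuid.uuid4())
api.create_session(session_id)
# submit_graph / get_graph_state / fetch_data would follow here once a
# serialized tensor graph is available.
api.delete_session(session_id)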
1 change: 0 additions & 1 deletion mars/scheduler/graph.py
@@ -37,7 +37,6 @@ class ResultReceiverActor(SchedulerActor):
    def __init__(self):
        super(ResultReceiverActor, self).__init__()
        self._kv_store_ref = None
        self.chunks = dict()

    @classmethod
    def default_name(cls):
12 changes: 9 additions & 3 deletions mars/session.py
@@ -14,8 +14,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import uuid
import json
import time
import numpy as np

from .api import MarsAPI
from .scheduler.graph import GraphState
from .serialize import dataserializer


class LocalSession(object):
    def __init__(self):
@@ -43,10 +50,9 @@ class Session(object):

    def __init__(self, endpoint=None):
        if endpoint is not None:
            from .web import get_client
            from .web.session import Session

            client = get_client(endpoint)
            self._sess = client.create_session()
            self._sess = Session(endpoint)
        else:
            self._sess = LocalSession()

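The practical effect of this change: Session(endpoint) now wraps mars.web.session.Session directly instead of going through the deleted MarsApiClient. A hedged sketch of both entry points (the endpoint URL is a placeholder):

# Illustrative only; the web endpoint URL is an assumption.
from mars.session import Session

local_sess = Session()                          # no endpoint: backed by LocalSession
remote_sess = Session('http://127.0.0.1:7103')  # endpoint: backed by mars.web.session.Session
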
1 change: 0 additions & 1 deletion mars/web/__init__.py
@@ -14,7 +14,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from .api_client import MarsApiClient, get_client
from .session import Session

try:
2 changes: 1 addition & 1 deletion mars/web/__main__.py
@@ -15,7 +15,7 @@
# limitations under the License.

import gevent.monkey
gevent.monkey.patch_all(thread=False)
gevent.monkey.patch_all()

import logging
import time
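A note on the patch_all() change above: gevent's monkey patching should run before any module imports the standard-library primitives it replaces, which is why the call stays at the very top of __main__.py. A generic sketch of that ordering (not project-specific):

# Generic gevent pattern: patch first, import everything else afterwards.
import gevent.monkey
gevent.monkey.patch_all()   # with no arguments, the thread module is patched as well

import socket   # noqa: E402  (deliberately imported after patching)

print(socket.socket)        # now gevent's cooperative socket class
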
86 changes: 13 additions & 73 deletions mars/web/api.py
@@ -24,10 +24,7 @@
from .server import register_api_handler
from ..compat import six, futures
from ..lib.tblib import pickling_support
from ..scheduler import SessionActor, GraphActor, KVStoreActor
from ..scheduler.session import SessionManagerActor
from ..actors import new_client
from .. import resource

pickling_support.install()
_actor_client = new_client()
@@ -50,33 +47,8 @@ def _is_future(x):


class ApiRequestHandler(web.RequestHandler):
    def initialize(self, sessions, cluster_info):
        self.sessions = sessions
        self.cluster_info = cluster_info

        uid = SessionManagerActor.default_name()
        scheduler_address = self.cluster_info.get_scheduler(uid)
        self.session_manager_ref = _actor_client.actor_ref(uid, address=scheduler_address)

        uid = KVStoreActor.default_name()
        scheduler_address = self.cluster_info.get_scheduler(uid)
        self.kv_store_ref = _actor_client.actor_ref(uid, address=scheduler_address)

    def get_session_ref(self, session_id):
        try:
            return self.sessions[session_id]
        except KeyError:
            uid = SessionActor.gen_name(session_id)
            scheduler_ip = self.cluster_info.get_scheduler(uid)
            actor_ref = _actor_client.actor_ref(uid, address=scheduler_ip)
            self.sessions[session_id] = actor_ref
            return actor_ref

    def get_graph_ref(self, session_id, graph_key):
        uid = GraphActor.gen_name(session_id, graph_key)
        scheduler_ip = self.cluster_info.get_scheduler(uid)
        actor_ref = _actor_client.actor_ref(uid, address=scheduler_ip)
        return actor_ref
    def initialize(self, web_api):
        self.web_api = web_api


class ApiEntryHandler(ApiRequestHandler):
@@ -88,23 +60,13 @@ class SessionsApiHandler(ApiRequestHandler):
    def post(self):
        args = {k: self.get_argument(k) for k in self.request.arguments}
        session_id = str(uuid.uuid1())

        session_ref = self.session_manager_ref.create_session(session_id, **args)
        session_ref = _actor_client.actor_ref(session_ref)
        self.sessions[session_id] = session_ref
        logger.info('Session %s created.' % session_id)

        self.web_api.create_session(session_id, **args)
        self.write(json.dumps(dict(session_id=session_id)))


class SessionApiHandler(ApiRequestHandler):
    def delete(self, session_id):
        session_ref = self.get_session_ref(session_id)
        session_ref.destroy()
        try:
            del self.sessions[session_id]
        except KeyError:
            pass
        self.web_api.delete_session(session_id)


class GraphsApiHandler(ApiRequestHandler):
@@ -118,8 +80,7 @@ def post(self, session_id):

        try:
            graph_key = str(uuid.uuid4())
            session_ref = self.get_session_ref(session_id)
            session_ref.submit_tensor_graph(graph, graph_key, target, _tell=True)
            self.web_api.submit_graph(session_id, graph, graph_key, target)
            self.write(json.dumps(dict(graph_key=graph_key)))
        except:
            pickled_exc = pickle.dumps(sys.exc_info())
@@ -134,11 +95,7 @@ class GraphApiHandler(ApiRequestHandler):
    def get(self, session_id, graph_key):
        from ..scheduler.utils import GraphState

        state_obj = self.kv_store_ref.read(
            '/sessions/%s/graph/%s/state' % (session_id, graph_key), silent=True)
        state = state_obj.value if state_obj else 'preparing'
        state = GraphState(state.lower())

        state = self.web_api.get_graph_state(session_id, graph_key)
        if state == GraphState.RUNNING:
            self.write(json.dumps(dict(state='running')))
        elif state == GraphState.SUCCEEDED:
@@ -151,42 +108,25 @@ def get(self, session_id, graph_key):
            self.write(json.dumps(dict(state='preparing')))

    def delete(self, session_id, graph_key):
        graph_ref = self.get_graph_ref(session_id, graph_key)
        graph_ref.stop_graph()
        self.web_api.stop_graph(session_id, graph_key)


class GraphDataHandler(ApiRequestHandler):
    _executor = futures.ThreadPoolExecutor(resource.cpu_count())

    @gen.coroutine
    def get(self, session_id, graph_key, tensor_key):
        from ..scheduler.graph import ResultReceiverActor
        uid = GraphActor.gen_name(session_id, graph_key)
        scheduler_ip = self.cluster_info.get_scheduler(uid)

        def _fetch_fun():
            client = new_client()
            merge_ref = client.create_actor(ResultReceiverActor, address=scheduler_ip)
            return merge_ref.fetch_tensor(session_id, graph_key, tensor_key)

        data = yield self._executor.submit(_fetch_fun)
        executor = futures.ThreadPoolExecutor(1)
        data = yield executor.submit(
            self.web_api.fetch_data, session_id, graph_key, tensor_key)
        self.write(data)

    def delete(self, session_id, graph_key, tensor_key):
        uid = GraphActor.gen_name(session_id, graph_key)
        scheduler_ip = self.cluster_info.get_scheduler(uid)
        graph_ref = _actor_client.actor_ref(uid, address=scheduler_ip)
        graph_ref.free_tensor_data(tensor_key, _tell=True)
        self.web_api.delete_data(session_id, graph_key, tensor_key)


class WorkersApiHandler(ApiRequestHandler):
    def get(self):
        try:
            worker_info = self.kv_store_ref.read('/workers/meta')
            workers_num = len(worker_info.children)
            self.write(json.dumps(workers_num))
        except KeyError:
            self.write(json.dumps(0))
        workers_num = self.web_api.count_workers()
        self.write(json.dumps(workers_num))


register_api_handler('/api', ApiEntryHandler)
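For readers unfamiliar with how web_api reaches initialize(): in Tornado, per-handler keyword arguments come from the route table. The project wires this through its own register_api_handler (whose internals are not shown in this diff), so the snippet below is only a plain-Tornado illustration of the pattern, with an assumed scheduler address.

# Plain-Tornado illustration, not the project's register_api_handler:
# one shared MarsAPI instance is handed to every handler via initialize().
from tornado import web

from mars.api import MarsAPI
from mars.web.api import ApiEntryHandler   # handler class defined above

def make_app(scheduler_ip):
    web_api = MarsAPI(scheduler_ip)         # single client shared by all handlers
    return web.Application([
        (r'/api', ApiEntryHandler, dict(web_api=web_api)),
    ])
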
60 changes: 0 additions & 60 deletions mars/web/api_client.py

This file was deleted.

12 changes: 2 additions & 10 deletions mars/web/dashboard.py
@@ -13,20 +13,12 @@
# limitations under the License.

from .server import register_ui_handler, get_jinja_env
from ..node_info import NodeInfoActor
from ..actors import new_client


def dashboard(cluster_ref, doc):
def dashboard(web_api, doc):
    doc.title = 'Mars UI'

    actor_client = new_client()
    schedulers = cluster_ref.get_schedulers()
    infos = []
    for scheduler in schedulers:
        info_ref = actor_client.actor_ref(NodeInfoActor.default_name(),
                                          address=scheduler)
        infos.append(info_ref.get_info())
    infos = web_api.get_schedulers_info()
    doc.template_variables['infos'] = infos
    jinja_env = get_jinja_env()
    doc.template = jinja_env.get_template('dashboard.html')
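Since dashboard() now takes a web_api instead of a raw cluster actor ref, wiring it into a Bokeh document handler can stay generic. A hedged sketch using plain Bokeh (the project's register_ui_handler is not shown here, and the scheduler address is a placeholder):

# Illustrative Bokeh wiring only; not the project's register_ui_handler.
from functools import partial

from bokeh.application import Application
from bokeh.application.handlers import FunctionHandler

from mars.api import MarsAPI
from mars.web.dashboard import dashboard

web_api = MarsAPI('127.0.0.1:7112')        # placeholder scheduler address
app = Application(FunctionHandler(partial(dashboard, web_api)))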
