Skip to content

Commit

Permalink
Support set endpoint for web session to reuse the session but with a …
Browse files Browse the repository at this point in the history
…different web (#54)

* add endpoint setter in web session so that it still works if web endpoint changed.
  • Loading branch information
hekaisheng authored and qinxuye committed Dec 21, 2018
1 parent 0778a5c commit 70dae9e
Show file tree
Hide file tree
Showing 4 changed files with 101 additions and 57 deletions.
9 changes: 9 additions & 0 deletions mars/deploy/local/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,15 @@ def __init__(self, endpoint):
# create session on the cluster side
self._api.create_session(self._session_id)

@property
def endpoint(self):
return self._endpoint

@endpoint.setter
def endpoint(self, endpoint):
self._endpoint = endpoint
self._api = MarsAPI(self._endpoint)

def run(self, *tensors, **kw):
timeout = kw.pop('timeout', -1)
if kw:
Expand Down
19 changes: 19 additions & 0 deletions mars/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,17 @@ def __init__(self):
from .tensor.execution.core import Executor

self._executor = Executor()
self._endpoint = None

@property
def endpoint(self):
return self._endpoint

@endpoint.setter
def endpoint(self, endpoint):
if endpoint is not None:
raise ValueError('Local session cannot set endpoint')
self._endpoint = endpoint

def run(self, *tensors, **kw):
if self._executor is None:
Expand Down Expand Up @@ -87,6 +98,14 @@ def run(self, *tensors, **kw):
return ret
return ret[0]

@property
def endpoint(self):
return self._sess.endpoint

@endpoint.setter
def endpoint(self, endpoint):
self._sess.endpoint = endpoint

def decref(self, *keys):
if hasattr(self._sess, 'decref'):
self._sess.decref(*keys)
Expand Down
8 changes: 8 additions & 0 deletions mars/web/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,14 @@ def __init__(self, endpoint, args=None):
def session_id(self):
return self._session_id

@property
def endpoint(self):
return self._endpoint

@endpoint.setter
def endpoint(self, url):
self._endpoint = url

def _main(self):
resp = self._req_session.post(self._endpoint + '/api/session', self._args)
if resp.status_code >= 400:
Expand Down
122 changes: 65 additions & 57 deletions mars/web/tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,20 +22,17 @@
import sys
import signal
import subprocess
import uuid

import gevent
import numpy as np
from numpy.testing import assert_array_equal


from mars import tensor as mt
from mars.tensor.execution.core import Executor
from mars.actors import create_actor_pool, new_client
from mars.actors import new_client
from mars.utils import get_next_port
from mars.cluster_info import ClusterInfoActor
from mars.scheduler import SessionManagerActor, KVStoreActor, ResourceActor
from mars.scheduler.graph import GraphActor
from mars.web import MarsWeb
from mars.scheduler import KVStoreActor
from mars.session import new_session
from mars.serialize.dataserializer import dumps, loads
from mars.config import options
Expand Down Expand Up @@ -189,66 +186,77 @@ def testApi(self):
self.assertEqual(res.status_code, 200)


class TestWithMockServer(unittest.TestCase):
def setUp(self):
self._executor = Executor('numpy')
class MockResponse:
def __init__(self, status_code, json_text=None, data=None):
self._json_text = json_text
self._content = data
self._status_code = status_code

# create scheduler pool with needed actor
scheduler_address = '127.0.0.1:' + str(get_next_port())
pool = create_actor_pool(address=scheduler_address, n_process=1, backend='gevent')
pool.create_actor(ClusterInfoActor, [scheduler_address], uid=ClusterInfoActor.default_name())
pool.create_actor(ResourceActor, uid=ResourceActor.default_name())
pool.create_actor(SessionManagerActor, uid=SessionManagerActor.default_name())
self._kv_store_ref = pool.create_actor(KVStoreActor, uid=KVStoreActor.default_name())
self._pool = pool
@property
def text(self):
return json.dumps(self._json_text)

self.start_web(scheduler_address)
@property
def content(self):
return self._content

check_time = time.time()
while True:
if time.time() - check_time > 30:
raise SystemError('Wait for service start timeout')
try:
resp = requests.get(self._service_ep + '/api', timeout=1)
except (requests.ConnectionError, requests.Timeout):
time.sleep(1)
continue
if resp.status_code >= 400:
time.sleep(1)
continue
break
@property
def status_code(self):
return self._status_code

def tearDown(self):
self._web.stop()
self._pool.stop()

def start_web(self, scheduler_address):
import gevent.monkey
gevent.monkey.patch_all()
class MockedServer(object):
def __init__(self):
self._data = None

web_port = str(get_next_port())
mars_web = MarsWeb(port=int(web_port), scheduler_ip=scheduler_address)
mars_web.start()
service_ep = 'http://127.0.0.1:' + web_port
self._web = mars_web
self._service_ep = service_ep

@mock.patch(GraphActor.__module__ + '.GraphActor.execute_graph')
@mock.patch(GraphActor.__module__ + '.ResultReceiverActor.fetch_tensor')
def testApi(self, mock_fetch_tensor, _):
@property
def data(self):
return self._data

@data.setter
def data(self, data):
self._data = data

@staticmethod
def mocked_requests_get(*arg, **_):
url = arg[0]
if url.endswith('worker'):
return MockResponse(200, json_text=1)
if url.split('/')[-2] == 'graph':
return MockResponse(200, json_text={"state": 'success'})
elif url.split('/')[-2] == 'data':
data = dumps(np.ones((100, 100)) * 100)
return MockResponse(200, data=data)

@staticmethod
def mocked_requests_post(*arg, **_):
url = arg[0]
if url.endswith('session'):
return MockResponse(200, json_text={"session_id": str(uuid.uuid4())})
elif url.endswith('graph'):
return MockResponse(200, json_text={"graph_key": str(uuid.uuid4())})
else:
return MockResponse(404)

@staticmethod
def mocked_requests_delete(*_):
return MockResponse(200)


class TestWithMockServer(unittest.TestCase):
def setUp(self):
self._service_ep = 'http://mock.com'

@mock.patch('requests.Session.get', side_effect=MockedServer.mocked_requests_get)
@mock.patch('requests.Session.post', side_effect=MockedServer.mocked_requests_post)
@mock.patch('requests.Session.delete', side_effect=MockedServer.mocked_requests_delete)
def testApi(self, *_):
with new_session(self._service_ep) as sess:
self._kv_store_ref.write('/workers/meta/%s' % 'mock_endpoint', 'mock_meta')
self.assertEqual(sess.count_workers(), 1)

a = mt.ones((100, 100), chunks=30)
b = mt.ones((100, 100), chunks=30)
c = a.dot(b)
graph_key = sess.run(c, timeout=120, wait=False)
self._kv_store_ref.write('/sessions/%s/graph/%s/state' % (sess.session_id, graph_key), 'SUCCEEDED')
graph_url = '%s/api/session/%s/graph/%s' % (self._service_ep, sess.session_id, graph_key)
graph_state = json.loads(requests.get(graph_url).text)
self.assertEqual(graph_state['state'], 'success')
mock_fetch_tensor.return_value = dumps(self._executor.execute_tensor(c, concat=True)[0])
data_url = graph_url + '/data/' + c.key
data = loads(requests.get(data_url).content)
assert_array_equal(data, np.ones((100, 100)) * 100)

result = sess.run(c, timeout=120)
assert_array_equal(result, np.ones((100, 100)) * 100)

0 comments on commit 70dae9e

Please sign in to comment.