Skip to content

Commit

Permalink
[BACKPORT] Add support for HTTP request rewriter (#2664) (#2665)
Browse files Browse the repository at this point in the history
  • Loading branch information
wjsi authored Jan 30, 2022
1 parent b1bb988 commit b47e27d
Show file tree
Hide file tree
Showing 12 changed files with 92 additions and 31 deletions.
33 changes: 24 additions & 9 deletions mars/deploy/oscar/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -725,6 +725,7 @@ def __init__(
web_api: Optional[OscarWebAPI],
client: ClientType = None,
timeout: float = None,
request_rewriter: Callable = None,
):
super().__init__(address, session_id)
self._session_api = session_api
Expand All @@ -736,6 +737,7 @@ def __init__(
self._web_api = web_api
self.client = client
self.timeout = timeout
self._request_rewriter = request_rewriter

self._tileable_to_fetch = WeakKeyDictionary()
self._asyncio_task_timeout_detector_task = (
Expand Down Expand Up @@ -785,6 +787,7 @@ async def init(
**kwargs,
) -> "AbstractAsyncSession":
init_local = kwargs.pop("init_local", False)
request_rewriter = kwargs.pop("request_rewriter", None)
if init_local:
from .local import new_cluster_in_isolation

Expand All @@ -800,7 +803,11 @@ async def init(

if urlparse(address).scheme == "http":
return await _IsolatedWebSession._init(
address, session_id, new=new, timeout=timeout
address,
session_id,
new=new,
timeout=timeout,
request_rewriter=request_rewriter,
)
else:
return await cls._init(address, session_id, new=new, timeout=timeout)
Expand Down Expand Up @@ -973,7 +980,9 @@ async def _get_storage_api(self, band: BandType):
if urlparse(self.address).scheme == "http":
from ...services.storage.api import WebStorageAPI

storage_api = WebStorageAPI(self._session_id, self.address, band[1])
storage_api = WebStorageAPI(
self._session_id, self.address, band[1], self._request_rewriter
)
else:
storage_api = await StorageAPI.create(self._session_id, band[0], band[1])
return storage_api
Expand Down Expand Up @@ -1216,7 +1225,12 @@ async def stop_server(self):
class _IsolatedWebSession(_IsolatedSession):
@classmethod
async def _init(
cls, address: str, session_id: str, new: bool = True, timeout: float = None
cls,
address: str,
session_id: str,
new: bool = True,
timeout: float = None,
request_rewriter: Callable = None,
):
from ...services.session import WebSessionAPI
from ...services.lifecycle import WebLifecycleAPI
Expand All @@ -1225,15 +1239,15 @@ async def _init(
from ...services.mutable import WebMutableAPI
from ...services.cluster import WebClusterAPI

session_api = WebSessionAPI(address)
session_api = WebSessionAPI(address, request_rewriter)
if new:
# create new session
await session_api.create_session(session_id)
lifecycle_api = WebLifecycleAPI(session_id, address)
meta_api = WebMetaAPI(session_id, address)
task_api = WebTaskAPI(session_id, address)
mutable_api = WebMutableAPI(session_id, address)
cluster_api = WebClusterAPI(address)
lifecycle_api = WebLifecycleAPI(session_id, address, request_rewriter)
meta_api = WebMetaAPI(session_id, address, request_rewriter)
task_api = WebTaskAPI(session_id, address, request_rewriter)
mutable_api = WebMutableAPI(session_id, address, request_rewriter)
cluster_api = WebClusterAPI(address, request_rewriter)

return cls(
address,
Expand All @@ -1246,6 +1260,7 @@ async def _init(
cluster_api,
None,
timeout=timeout,
request_rewriter=request_rewriter,
)

async def get_web_endpoint(self) -> Optional[str]:
Expand Down
4 changes: 3 additions & 1 deletion mars/deploy/oscar/tests/test_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,9 @@ async def test_web_session(create_cluster):
client = create_cluster[0]
session_id = str(uuid.uuid4())
web_address = client.web_address
session = await AsyncSession.init(web_address, session_id)
session = await AsyncSession.init(
web_address, session_id, request_rewriter=lambda x: x
)
assert await session.get_web_endpoint() == web_address
session.as_default()
assert isinstance(session._isolated_session, _IsolatedWebSession)
Expand Down
5 changes: 3 additions & 2 deletions mars/services/cluster/api/web.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.

import json
from typing import Dict, List, Optional, Set
from typing import Callable, Dict, List, Optional, Set

from ....lib.aio import alru_cache
from ....typing import BandType
Expand Down Expand Up @@ -143,8 +143,9 @@ async def get_mars_versions(self):


class WebClusterAPI(AbstractClusterAPI, MarsWebAPIClientMixin):
def __init__(self, address: str):
def __init__(self, address: str, request_rewriter: Callable = None):
self._address = address.rstrip("/")
self.request_rewriter = request_rewriter

@staticmethod
def _convert_node_dict(node_info_list: Dict[str, Dict]):
Expand Down
7 changes: 5 additions & 2 deletions mars/services/lifecycle/api/web.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict, List
from typing import Callable, Dict, List

from ....utils import serialize_serializable, deserialize_serializable
from ...web import web_api, MarsServiceWebAPIHandler, MarsWebAPIClientMixin
Expand Down Expand Up @@ -45,9 +45,12 @@ async def get_all_chunk_ref_counts(self, session_id: str):


class WebLifecycleAPI(AbstractLifecycleAPI, MarsWebAPIClientMixin):
def __init__(self, session_id: str, address: str):
def __init__(
self, session_id: str, address: str, request_rewriter: Callable = None
):
self._session_id = session_id
self._address = address.rstrip("/")
self.request_rewriter = request_rewriter

async def decref_tileables(self, tileable_keys: List[str]):
path = f"{self._address}/api/session/{self._session_id}/lifecycle"
Expand Down
7 changes: 5 additions & 2 deletions mars/services/meta/api/web.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict, List, Optional
from typing import Callable, Dict, List, Optional

from .... import oscar as mo
from ....utils import serialize_serializable, deserialize_serializable
Expand Down Expand Up @@ -53,9 +53,12 @@ async def get_chunks_meta(self, session_id: str):


class WebMetaAPI(AbstractMetaAPI, MarsWebAPIClientMixin):
def __init__(self, session_id: str, address: str):
def __init__(
self, session_id: str, address: str, request_rewriter: Callable = None
):
self._session_id = session_id
self._address = address.rstrip("/")
self.request_rewriter = request_rewriter

@mo.extensible
async def get_chunk_meta(
Expand Down
7 changes: 5 additions & 2 deletions mars/services/mutable/api/web.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Union
from typing import Union, Callable

import numpy as np

Expand Down Expand Up @@ -105,9 +105,12 @@ async def write_mutable(self, session_id: str, name: str): # pragma: no cover


class WebMutableAPI(AbstractMutableAPI, MarsWebAPIClientMixin):
def __init__(self, session_id: str, address: str):
def __init__(
self, session_id: str, address: str, request_rewriter: Callable = None
):
self._session_id = session_id
self._address = address.rstrip("/")
self.request_rewriter = request_rewriter

async def create_mutable_tensor(
self,
Expand Down
7 changes: 5 additions & 2 deletions mars/services/scheduling/api/web.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.

import json
from typing import List, Optional
from typing import Callable, List, Optional

from ....lib.aio import alru_cache
from ...web import web_api, MarsServiceWebAPIHandler, MarsWebAPIClientMixin
Expand Down Expand Up @@ -71,9 +71,12 @@ async def get_subtask_schedule_summaries(self, session_id: str):


class WebSchedulingAPI(AbstractSchedulingAPI, MarsWebAPIClientMixin):
def __init__(self, session_id: str, address: str):
def __init__(
self, session_id: str, address: str, request_rewriter: Callable = None
):
self._session_id = session_id
self._address = address.rstrip("/")
self.request_rewriter = request_rewriter

async def get_subtask_schedule_summaries(
self, task_id: Optional[str] = None
Expand Down
5 changes: 3 additions & 2 deletions mars/services/session/api/web.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.

import json
from typing import Dict, List, Union
from typing import Callable, Dict, List, Union

from ....utils import parse_readable_size
from ...web import web_api, MarsServiceWebAPIHandler, MarsWebAPIClientMixin
Expand Down Expand Up @@ -116,8 +116,9 @@ async def fetch_tileable_op_logs(self, session_id: str, op_key: str):


class WebSessionAPI(AbstractSessionAPI, MarsWebAPIClientMixin):
def __init__(self, address: str):
def __init__(self, address: str, request_rewriter: Callable = None):
self._address = address.rstrip("/")
self.request_rewriter = request_rewriter

async def get_sessions(self) -> List[SessionInfo]:
addr = f"{self._address}/api/session"
Expand Down
11 changes: 9 additions & 2 deletions mars/services/storage/api/web.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.

from collections import defaultdict
from typing import Any, List
from typing import Any, Callable, List

from .... import oscar as mo
from ....storage import StorageLevel
Expand Down Expand Up @@ -98,10 +98,17 @@ async def get_infos(self, session_id: str, data_key: str):


class WebStorageAPI(AbstractStorageAPI, MarsWebAPIClientMixin):
def __init__(self, session_id: str, address: str, band_name: str):
def __init__(
self,
session_id: str,
address: str,
band_name: str,
request_rewriter: Callable = None,
):
self._session_id = session_id
self._address = address.rstrip("/")
self._band_name = band_name
self.request_rewriter = request_rewriter

@mo.extensible
async def get(
Expand Down
9 changes: 6 additions & 3 deletions mars/services/task/api/web.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

import base64
import json
from typing import List, Optional, Union
from typing import Callable, List, Optional, Union

from ....core import TileableGraph, Tileable
from ....utils import serialize_serializable, deserialize_serializable
Expand Down Expand Up @@ -162,9 +162,12 @@ async def cancel_task(self, session_id: str, task_id: str):


class WebTaskAPI(AbstractTaskAPI, MarsWebAPIClientMixin):
def __init__(self, session_id: str, address: str):
def __init__(
self, session_id: str, address: str, request_rewriter: Callable = None
):
self._session_id = session_id
self._address = address.rstrip("/")
self.request_rewriter = request_rewriter

async def get_task_results(self, progress: bool = False) -> List[TaskResult]:
path = f"{self._address}/api/session/{self._session_id}/task"
Expand Down Expand Up @@ -200,7 +203,7 @@ async def submit_tileable_graph(
headers={"Content-Type": "application/octet-stream"},
data=body,
)
return res.body.decode()
return res.body.decode().strip()

async def get_fetch_tileables(self, task_id: str) -> List[Tileable]:
path = (
Expand Down
20 changes: 16 additions & 4 deletions mars/services/web/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from typing import Callable, Dict, List, NamedTuple, Optional, Type, Union

from tornado import httpclient, web
from tornado.simple_httpclient import HTTPTimeoutError
from tornado.simple_httpclient import HTTPRequest, HTTPTimeoutError

from ...lib.aio import alru_cache
from ...utils import serialize_serializable, deserialize_serializable
Expand Down Expand Up @@ -135,6 +135,9 @@ def _collect_services(cls):
for api_def in web_api_defs:
cls._method_to_handlers[api_def.method.lower()][handle_func] = api_def

# Set the response "Content-Type" header to "application/octet-stream".
# NOTE(review): presumably this overrides tornado's RequestHandler.prepare()
# hook, which tornado invokes before the HTTP-verb method, making this a
# default that individual handlers may still replace -- confirm against the
# enclosing handler class, which is not visible in this hunk.
def prepare(self):
self.set_header("Content-Type", "application/octet-stream")

@classmethod
def get_root_pattern(cls):
return cls._root_pattern + "(?:/(?P<sub_path>.*)$|$)"
Expand Down Expand Up @@ -200,6 +203,14 @@ def _client(self):
self._client_obj = httpclient.AsyncHTTPClient()
return self._client_obj

# Optional callable that `_request_url` applies to each outgoing HTTPRequest
# before fetching it; the callable's return value is what gets sent.
# Read through getattr with a None default, so an instance that never
# assigned a rewriter yields None instead of raising AttributeError.
@property
def request_rewriter(self) -> Optional[Callable]:
return getattr(self, "_request_rewriter", None)

# Store the rewriter (or None to disable rewriting) on the instance under
# the private `_request_rewriter` attribute read by the getter above.
@request_rewriter.setter
def request_rewriter(self, value: Optional[Callable]):
self._request_rewriter = value

def __del__(self):
if hasattr(self, "_client_obj"):
self._client_obj.close()
Expand All @@ -220,9 +231,10 @@ async def _request_url(self, method, path, **kwargs):
path += path_connector + url_params

try:
res = await self._client.fetch(
path, method=method, raise_error=False, **kwargs
)
request = HTTPRequest(path, method=method, **kwargs)
if self.request_rewriter:
request = self.request_rewriter(request)
res = await self._client.fetch(request, raise_error=False)
except HTTPTimeoutError as ex:
raise TimeoutError(str(ex)) from None

Expand Down
8 changes: 8 additions & 0 deletions mars/services/web/tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,14 @@ async def fetch(self, path, method="GET", **kwargs):
@pytest.mark.asyncio
async def test_web_api(actor_pool):
_pool, web_port = actor_pool
recorded_urls = []

def url_recorder(request):
recorded_urls.append(request.url)
return request

client = SimpleWebClient()
client.request_rewriter = url_recorder

res = await client.fetch(f"http://localhost:{web_port}/")
assert res.body.decode()
Expand Down Expand Up @@ -139,3 +145,5 @@ async def test_web_api(actor_pool):

res = await client.fetch(f"http://localhost:{web_port}/api/extra_test")
assert "Test" in res.body.decode()

assert len(recorded_urls) > 0

0 comments on commit b47e27d

Please sign in to comment.