diff --git a/inngest/_internal/comm_lib/README.md b/inngest/_internal/comm_lib/README.md new file mode 100644 index 0000000..850ccb7 --- /dev/null +++ b/inngest/_internal/comm_lib/README.md @@ -0,0 +1,5 @@ +The `comm_lib` library contains the framework-agnostic, HTTP-aware communication layer. It's responsible for high-level handling of each request kind: + +- Execution +- Inspection +- Synchronization (a.k.a. registration) diff --git a/inngest/_internal/comm_lib/__init__.py b/inngest/_internal/comm_lib/__init__.py new file mode 100644 index 0000000..a6b27bc --- /dev/null +++ b/inngest/_internal/comm_lib/__init__.py @@ -0,0 +1,7 @@ +from .handler import CommHandler +from .models import CommResponse + +__all__ = [ + "CommHandler", + "CommResponse", +] diff --git a/inngest/_internal/comm.py b/inngest/_internal/comm_lib/handler.py similarity index 58% rename from inngest/_internal/comm.py rename to inngest/_internal/comm_lib/handler.py index c05ea81..881c97c 100644 --- a/inngest/_internal/comm.py +++ b/inngest/_internal/comm_lib/handler.py @@ -7,6 +7,7 @@ import urllib.parse import httpx +import typing_extensions from inngest._internal import ( client_lib, @@ -22,176 +23,74 @@ types, ) +from .models import ( + AuthenticatedInspection, + CommResponse, + UnauthenticatedInspection, +) +from .utils import parse_query_params -class _ErrorData(types.BaseModel): - code: server_lib.ErrorCode - message: str - name: str - stack: typing.Optional[str] - - @classmethod - def from_error(cls, err: Exception) -> _ErrorData: - if isinstance(err, errors.Error): - code = err.code - message = err.message - name = err.name - stack = err.stack - else: - code = server_lib.ErrorCode.UNKNOWN - message = str(err) - name = type(err).__name__ - stack = transforms.get_traceback(err) - - return cls( - code=code, - message=message, - name=name, - stack=stack, - ) - - -def _prep_call_result( - call_res: execution.CallResult, -) -> types.MaybeError[object]: - """ - Convert a CallResult to the shape the Inngest Server expects. For step-level - results this is a dict and for function-level results this is the output or - error. - """ - - if call_res.step is not None: - d = call_res.step.to_dict() - if isinstance(d, Exception): - # Unreachable - return d - else: - d = {} - - if call_res.error is not None: - e = _ErrorData.from_error(call_res.error).to_dict() - if isinstance(e, Exception): - return e - d["error"] = e - - if call_res.output is not types.empty_sentinel: - err = transforms.dump_json(call_res.output) - if isinstance(err, Exception): - msg = "returned unserializable data" - if call_res.step is not None: - msg = f'"{call_res.step.display_name}" {msg}' - - return errors.OutputUnserializableError(msg) - - d["data"] = call_res.output - - is_function_level = call_res.step is None - if is_function_level: - # Don't nest function-level results - return d.get("error") or d.get("data") - - return d +_ParamsT = typing_extensions.ParamSpec("_ParamsT") -class CommResponse: - def __init__( - self, - *, - body: object = None, - headers: typing.Optional[dict[str, str]] = None, - status_code: int = http.HTTPStatus.OK.value, - ) -> None: - self.headers = headers or {} - self.body = body - self.status_code = status_code - - @classmethod - def from_call_result( - cls, - logger: types.Logger, - call_res: execution.CallResult, +def _prep_response( + f: typing.Callable[ + _ParamsT, typing.Awaitable[typing.Union[CommResponse, Exception]] + ], +) -> typing.Callable[_ParamsT, typing.Awaitable[CommResponse]]: + async def inner( + *args: _ParamsT.args, + **kwargs: _ParamsT.kwargs, ) -> CommResponse: - headers = { - server_lib.HeaderKey.SERVER_TIMING.value: "handler", + comm_handler = args[0] + if not isinstance(comm_handler, CommHandler): + raise ValueError("First argument must be a CommHandler instance.") + + res = await f(*args, **kwargs) + if isinstance(res, Exception): + res = CommResponse.from_error(comm_handler._client.logger, res) + + res.headers = { + **res.headers, + **net.create_headers( + env=comm_handler._client.env, + framework=comm_handler._framework, + server_kind=comm_handler._client._mode, + ), } - if call_res.multi: - multi_body: list[object] = [] - for item in call_res.multi: - d = _prep_call_result(item) - if isinstance(d, Exception): - return cls.from_error(logger, d) - multi_body.append(d) - - if item.error is not None: - if errors.is_retriable(item.error) is False: - headers[server_lib.HeaderKey.NO_RETRY.value] = "true" - - return cls( - body=multi_body, - headers=headers, - status_code=http.HTTPStatus.PARTIAL_CONTENT.value, - ) - - body = _prep_call_result(call_res) - status_code = http.HTTPStatus.OK.value - if isinstance(body, Exception): - return cls.from_error(logger, body) - - if call_res.error is not None: - status_code = http.HTTPStatus.INTERNAL_SERVER_ERROR.value - if errors.is_retriable(call_res.error) is False: - headers[server_lib.HeaderKey.NO_RETRY.value] = "true" + return res - if isinstance(call_res.error, errors.RetryAfterError): - headers[ - server_lib.HeaderKey.RETRY_AFTER.value - ] = transforms.to_iso_utc(call_res.error.retry_after) + return inner - return cls( - body=body, - headers=headers, - status_code=status_code, - ) - @classmethod - def from_error( - cls, - logger: types.Logger, - err: Exception, - status: http.HTTPStatus = http.HTTPStatus.INTERNAL_SERVER_ERROR, +def _prep_response_sync( + f: typing.Callable[_ParamsT, typing.Union[CommResponse, Exception]], +) -> typing.Callable[_ParamsT, CommResponse]: + def inner( + *args: _ParamsT.args, + **kwargs: _ParamsT.kwargs, ) -> CommResponse: - code: typing.Optional[str] = None - if isinstance(err, errors.Error): - code = err.code.value - else: - code = server_lib.ErrorCode.UNKNOWN.value - - if errors.is_quiet(err) is False: - logger.error(f"{code}: {err!s}") - - return cls( - body={ - "code": code, - "message": str(err), - "name": type(err).__name__, - }, - status_code=status.value, - ) + comm_handler = args[0] + if not isinstance(comm_handler, CommHandler): + raise ValueError("First argument must be a CommHandler instance.") + + res = f(*args, **kwargs) + if isinstance(res, Exception): + res = CommResponse.from_error(comm_handler._client.logger, res) + + res.headers = { + **res.headers, + **net.create_headers( + env=comm_handler._client.env, + framework=comm_handler._framework, + server_kind=comm_handler._client._mode, + ), + } - @classmethod - def from_error_code( - cls, - code: server_lib.ErrorCode, - message: str, - status: http.HTTPStatus = http.HTTPStatus.INTERNAL_SERVER_ERROR, - ) -> CommResponse: - return cls( - body={ - "code": code.value, - "message": message, - }, - status_code=status.value, - ) + return res + + return inner class CommHandler: @@ -243,7 +142,7 @@ def __init__( self._signing_key_fallback = client.signing_key_fallback - def _build_registration_request( + def _build_register_request( self, *, app_url: str, @@ -259,7 +158,7 @@ def _build_registration_request( if isinstance(fn_configs, Exception): return fn_configs - body = server_lib.RegisterRequest( + body = server_lib.SynchronizeRequest( app_name=self._client.app_id, deploy_type=server_lib.DeployType.PING, framework=self._framework, @@ -290,21 +189,23 @@ def _build_registration_request( timeout=30, ) + @_prep_response async def call_function( self, *, - call: server_lib.ServerRequest, - fn_id: str, + body: bytes, + headers: typing.Union[dict[str, str], dict[str, list[str]]], + query_params: typing.Union[dict[str, str], dict[str, list[str]]], raw_request: object, - req_sig: net.RequestSignature, - target_hashed_id: str, - ) -> CommResponse: + ) -> typing.Union[CommResponse, Exception]: """Handle a function call from the Executor.""" - if target_hashed_id == server_lib.UNSPECIFIED_STEP_ID: - target_step_id = None - else: - target_step_id = target_hashed_id + headers = net.normalize_headers(headers) + + server_kind = transforms.get_server_kind(headers) + if isinstance(server_kind, Exception): + self._client.logger.error(server_kind) + server_kind = None middleware = middleware_lib.MiddlewareManager.from_client( self._client, @@ -312,74 +213,96 @@ async def call_function( ) # Validate the request signature. - err = req_sig.validate( + err = net.validate_request( + body=body, + headers=headers, + mode=self._client._mode, signing_key=self._signing_key, signing_key_fallback=self._signing_key_fallback, ) if isinstance(err, Exception): - return await self._respond(err) + return err + + request = server_lib.ServerRequest.from_raw(body) + if isinstance(request, Exception): + return request + + params = parse_query_params(query_params) + if isinstance(params, Exception): + return params + if params.fn_id is None: + return errors.QueryParamMissingError( + server_lib.QueryParamKey.FUNCTION_ID.value + ) # Get the function we should call. - fn = self._get_function(fn_id) + fn = self._get_function(params.fn_id) if isinstance(fn, Exception): - return await self._respond(fn) + return fn - events = call.events - steps = call.steps - if call.use_api: + events = request.events + steps = request.steps + if request.use_api: # Putting the batch and memoized steps in the request would make it # to big, so the Executor is telling the SDK to fetch them from the # API fetched_events, fetched_steps = await asyncio.gather( - self._client._get_batch(call.ctx.run_id), - self._client._get_steps(call.ctx.run_id), + self._client._get_batch(request.ctx.run_id), + self._client._get_steps(request.ctx.run_id), ) if isinstance(fetched_events, Exception): - return await self._respond(fetched_events) + return fetched_events events = fetched_events if isinstance(fetched_steps, Exception): - return await self._respond(fetched_steps) + return fetched_steps steps = fetched_steps if events is None: # Should be unreachable. The Executor should always either send the # batch or tell the SDK to fetch the batch - return await self._respond(Exception("events not in request")) - + return Exception("events not in request") call_res = await fn.call( self._client, execution.Context( - attempt=call.ctx.attempt, - event=call.event, + attempt=request.ctx.attempt, + event=request.event, events=events, logger=self._client.logger, - run_id=call.ctx.run_id, + run_id=request.ctx.run_id, ), - fn_id, + params.fn_id, middleware, - call.ctx.stack.stack or [], + request.ctx.stack.stack or [], step_lib.StepMemos.from_raw(steps), - target_step_id, + params.step_id, ) - return await self._respond(call_res) + return CommResponse.from_call_result( + self._client.logger, + call_res, + self._client.env, + self._framework, + server_kind, + ) + @_prep_response_sync def call_function_sync( self, *, - call: server_lib.ServerRequest, - fn_id: str, + body: bytes, + headers: typing.Union[dict[str, str], dict[str, list[str]]], + query_params: typing.Union[dict[str, str], dict[str, list[str]]], raw_request: object, - req_sig: net.RequestSignature, - target_hashed_id: str, - ) -> CommResponse: + ) -> typing.Union[CommResponse, Exception]: """Handle a function call from the Executor.""" - if target_hashed_id == server_lib.UNSPECIFIED_STEP_ID: - target_step_id = None - else: - target_step_id = target_hashed_id + headers = net.normalize_headers(headers) + + server_kind = transforms.get_server_kind(headers) + if isinstance(server_kind, Exception): + self._client.logger.error(server_kind) + server_kind = None middleware = middleware_lib.MiddlewareManager.from_client( self._client, @@ -387,56 +310,77 @@ def call_function_sync( ) # Validate the request signature. - err = req_sig.validate( + err = net.validate_request( + body=body, + headers=headers, + mode=self._client._mode, signing_key=self._signing_key, signing_key_fallback=self._signing_key_fallback, ) if isinstance(err, Exception): - return self._respond_sync(err) + return err + + request = server_lib.ServerRequest.from_raw(body) + if isinstance(request, Exception): + return request + + params = parse_query_params(query_params) + if isinstance(params, Exception): + return params + if params.fn_id is None: + return errors.QueryParamMissingError( + server_lib.QueryParamKey.FUNCTION_ID.value + ) # Get the function we should call. - fn = self._get_function(fn_id) + fn = self._get_function(params.fn_id) if isinstance(fn, Exception): - return self._respond_sync(fn) + return fn - events = call.events - steps = call.steps - if call.use_api: + events = request.events + steps = request.steps + if request.use_api: # Putting the batch and memoized steps in the request would make it # to big, so the Executor is telling the SDK to fetch them from the # API - fetched_events = self._client._get_batch_sync(call.ctx.run_id) + fetched_events = self._client._get_batch_sync(request.ctx.run_id) if isinstance(fetched_events, Exception): - return self._respond_sync(fetched_events) + return fetched_events events = fetched_events - fetched_steps = self._client._get_steps_sync(call.ctx.run_id) + fetched_steps = self._client._get_steps_sync(request.ctx.run_id) if isinstance(fetched_steps, Exception): - return self._respond_sync(fetched_steps) + return fetched_steps steps = fetched_steps if events is None: # Should be unreachable. The Executor should always either send the # batch or tell the SDK to fetch the batch - return self._respond_sync(Exception("events not in request")) + return Exception("events not in request") call_res = fn.call_sync( self._client, execution.Context( - attempt=call.ctx.attempt, - event=call.event, + attempt=request.ctx.attempt, + event=request.event, events=events, logger=self._client.logger, - run_id=call.ctx.run_id, + run_id=request.ctx.run_id, ), - fn_id, + params.fn_id, middleware, step_lib.StepMemos.from_raw(steps), - target_step_id, + params.step_id, ) - return self._respond_sync(call_res) + return CommResponse.from_call_result( + self._client.logger, + call_res, + self._client.env, + self._framework, + server_kind, + ) def _get_function(self, fn_id: str) -> types.MaybeError[function.Function]: # Look for the function ID in the list of user functions, but also @@ -477,27 +421,36 @@ def get_function_configs( return errors.FunctionConfigInvalidError("no functions found") return configs + @_prep_response_sync def inspect( self, *, - req_sig: net.RequestSignature, + body: bytes, + headers: typing.Union[dict[str, str], dict[str, list[str]]], serve_origin: typing.Optional[str], serve_path: typing.Optional[str], - server_kind: typing.Optional[server_lib.ServerKind], ) -> CommResponse: """Handle Dev Server's auto-discovery.""" + headers = net.normalize_headers(headers) + + server_kind = transforms.get_server_kind(headers) + if isinstance(server_kind, Exception): + self._client.logger.error(server_kind) + server_kind = None + if server_kind is not None and server_kind != self._mode: # Tell Dev Server to leave the app alone since it's in production # mode. return CommResponse( body={}, - headers={}, status_code=403, ) - # Validate the request signature. - err = req_sig.validate( + err = net.validate_request( + body=body, + headers=headers, + mode=self._client._mode, signing_key=self._signing_key, signing_key_fallback=self._signing_key_fallback, ) @@ -508,7 +461,7 @@ def inspect( if isinstance(err, Exception): authentication_succeeded = False - body = _UnauthenticatedIntrospection( + res_body = UnauthenticatedInspection( authentication_succeeded=authentication_succeeded, function_count=len(self._fns), has_event_key=self._client.event_key is not None, @@ -535,7 +488,7 @@ def inspect( else None ) - body = _AuthenticatedIntrospection( + res_body = AuthenticatedInspection( api_origin=self._client.api_origin, app_id=self._client.app_id, authentication_succeeded=True, @@ -554,7 +507,7 @@ def inspect( signing_key_hash=signing_key_hash, ) - body_json = body.to_dict() + body_json = res_body.to_dict() if isinstance(body, Exception): body_json = { "error": "failed to serialize inspection data", @@ -562,18 +515,12 @@ def inspect( return CommResponse( body=body_json, - headers=net.create_headers( - env=self._client.env, - framework=self._framework, - server_kind=server_kind, - ), status_code=200, ) def _parse_registration_response( self, server_res: httpx.Response, - server_kind: typing.Optional[server_lib.ServerKind], ) -> CommResponse: try: server_res_body = server_res.json() @@ -592,11 +539,6 @@ def _parse_registration_response( if server_res.status_code < 400: return CommResponse( body=server_res_body, - headers=net.create_headers( - env=self._client.env, - framework=self._framework, - server_kind=server_kind, - ), status_code=http.HTTPStatus.OK, ) @@ -610,26 +552,46 @@ def _parse_registration_response( comm_res.status_code = server_res.status_code return comm_res + @_prep_response async def register( - self, + self: CommHandler, *, - app_url: str, - server_kind: typing.Optional[server_lib.ServerKind], - sync_id: typing.Optional[str], - ) -> CommResponse: + headers: typing.Union[dict[str, str], dict[str, list[str]]], + query_params: typing.Union[dict[str, str], dict[str, list[str]]], + request_url: str, + serve_origin: typing.Optional[str], + serve_path: typing.Optional[str], + ) -> typing.Union[CommResponse, Exception]: """Handle a registration call.""" + headers = net.normalize_headers(headers) + + app_url = net.create_serve_url( + request_url=request_url, + serve_origin=serve_origin, + serve_path=serve_path, + ) + + server_kind = transforms.get_server_kind(headers) + if isinstance(server_kind, Exception): + self._client.logger.error(server_kind) + server_kind = None + comm_res = self._validate_registration(server_kind) if comm_res is not None: return comm_res - req = self._build_registration_request( + params = parse_query_params(query_params) + if isinstance(params, Exception): + return params + + req = self._build_register_request( app_url=app_url, server_kind=server_kind, - sync_id=sync_id, + sync_id=params.sync_id, ) if isinstance(req, Exception): - return CommResponse.from_error(self._client.logger, req) + return req res = await net.fetch_with_auth_fallback( self._client._http_client, @@ -639,31 +601,48 @@ async def register( signing_key_fallback=self._signing_key_fallback, ) - return self._parse_registration_response( - res, - server_kind, - ) + return self._parse_registration_response(res) + @_prep_response_sync def register_sync( - self, + self: CommHandler, *, - app_url: str, - server_kind: typing.Optional[server_lib.ServerKind], - sync_id: typing.Optional[str], - ) -> CommResponse: + headers: typing.Union[dict[str, str], dict[str, list[str]]], + query_params: typing.Union[dict[str, str], dict[str, list[str]]], + request_url: str, + serve_origin: typing.Optional[str], + serve_path: typing.Optional[str], + ) -> typing.Union[CommResponse, Exception]: """Handle a registration call.""" + headers = net.normalize_headers(headers) + + app_url = net.create_serve_url( + request_url=request_url, + serve_origin=serve_origin, + serve_path=serve_path, + ) + + server_kind = transforms.get_server_kind(headers) + if isinstance(server_kind, Exception): + self._client.logger.error(server_kind) + server_kind = None + comm_res = self._validate_registration(server_kind) if comm_res is not None: return comm_res - req = self._build_registration_request( + params = parse_query_params(query_params) + if isinstance(params, Exception): + return params + + req = self._build_register_request( app_url=app_url, server_kind=server_kind, - sync_id=sync_id, + sync_id=params.sync_id, ) if isinstance(req, Exception): - return CommResponse.from_error(self._client.logger, req) + return req res = net.fetch_with_auth_fallback_sync( self._client._http_client_sync, @@ -672,28 +651,7 @@ def register_sync( signing_key_fallback=self._signing_key_fallback, ) - return self._parse_registration_response( - res, - server_kind, - ) - - async def _respond( - self, - value: typing.Union[execution.CallResult, Exception], - ) -> CommResponse: - if isinstance(value, Exception): - return CommResponse.from_error(self._client.logger, value) - - return CommResponse.from_call_result(self._client.logger, value) - - def _respond_sync( - self, - value: typing.Union[execution.CallResult, Exception], - ) -> CommResponse: - if isinstance(value, Exception): - return CommResponse.from_error(self._client.logger, value) - - return CommResponse.from_call_result(self._client.logger, value) + return self._parse_registration_response(res) def _validate_registration( self, @@ -714,30 +672,3 @@ def _validate_registration( ) return None - - -class _UnauthenticatedIntrospection(types.BaseModel): - schema_version: str = "2024-05-24" - - authentication_succeeded: typing.Optional[bool] - function_count: int - has_event_key: bool - has_signing_key: bool - has_signing_key_fallback: bool - mode: server_lib.ServerKind - - -class _AuthenticatedIntrospection(_UnauthenticatedIntrospection): - api_origin: str - app_id: str - authentication_succeeded: bool = True - env: typing.Optional[str] - event_api_origin: str - event_key_hash: typing.Optional[str] - framework: str - sdk_language: str = const.LANGUAGE - sdk_version: str = const.VERSION - serve_origin: typing.Optional[str] - serve_path: typing.Optional[str] - signing_key_fallback_hash: typing.Optional[str] - signing_key_hash: typing.Optional[str] diff --git a/inngest/_internal/comm_test.py b/inngest/_internal/comm_lib/handler_test.py similarity index 95% rename from inngest/_internal/comm_test.py rename to inngest/_internal/comm_lib/handler_test.py index cbce57b..0fd93a8 100644 --- a/inngest/_internal/comm_test.py +++ b/inngest/_internal/comm_lib/handler_test.py @@ -7,7 +7,7 @@ import inngest from inngest._internal import errors, server_lib -from . import comm +from .handler import CommHandler logger = logging.getLogger(__name__) logger.setLevel(logging.DEBUG) @@ -51,7 +51,7 @@ def fn( ) -> int: return 1 - handler = comm.CommHandler( + handler = CommHandler( api_base_url="http://foo.bar", client=client, framework=server_lib.Framework.FLASK, @@ -66,7 +66,7 @@ def fn( def test_no_functions(self) -> None: functions: list[inngest.Function] = [] - handler = comm.CommHandler( + handler = CommHandler( api_base_url="http://foo.bar", client=client, framework=server_lib.Framework.FLASK, diff --git a/inngest/_internal/comm_lib/models.py b/inngest/_internal/comm_lib/models.py new file mode 100644 index 0000000..06b5f8f --- /dev/null +++ b/inngest/_internal/comm_lib/models.py @@ -0,0 +1,262 @@ +from __future__ import annotations + +import http +import typing + +from inngest._internal import ( + const, + errors, + execution, + net, + server_lib, + transforms, + types, +) + + +class CommResponse: + def __init__( + self, + *, + body: object = None, + headers: typing.Optional[dict[str, str]] = None, + status_code: int = http.HTTPStatus.OK.value, + ) -> None: + self.headers = headers or {} + self.body = body + self.status_code = status_code + + @classmethod + def from_call_result( + cls, + logger: types.Logger, + call_res: execution.CallResult, + env: typing.Optional[str], + framework: server_lib.Framework, + server_kind: typing.Optional[server_lib.ServerKind], + ) -> CommResponse: + headers = { + server_lib.HeaderKey.SERVER_TIMING.value: "handler", + **net.create_headers( + env=env, + framework=framework, + server_kind=server_kind, + ), + } + + if call_res.multi: + multi_body: list[object] = [] + for item in call_res.multi: + d = _prep_call_result(item) + if isinstance(d, Exception): + return cls.from_error(logger, d) + multi_body.append(d) + + if item.error is not None: + if errors.is_retriable(item.error) is False: + headers[server_lib.HeaderKey.NO_RETRY.value] = "true" + + return cls( + body=multi_body, + headers=headers, + status_code=http.HTTPStatus.PARTIAL_CONTENT.value, + ) + + body = _prep_call_result(call_res) + status_code = http.HTTPStatus.OK.value + if isinstance(body, Exception): + return cls.from_error(logger, body) + + if call_res.error is not None: + status_code = http.HTTPStatus.INTERNAL_SERVER_ERROR.value + if errors.is_retriable(call_res.error) is False: + headers[server_lib.HeaderKey.NO_RETRY.value] = "true" + + if isinstance(call_res.error, errors.RetryAfterError): + headers[ + server_lib.HeaderKey.RETRY_AFTER.value + ] = transforms.to_iso_utc(call_res.error.retry_after) + + return cls( + body=body, + headers=headers, + status_code=status_code, + ) + + @classmethod + def from_error( + cls, + logger: types.Logger, + err: Exception, + status: http.HTTPStatus = http.HTTPStatus.INTERNAL_SERVER_ERROR, + ) -> CommResponse: + code: typing.Optional[str] = None + if isinstance(err, errors.Error): + code = err.code.value + else: + code = server_lib.ErrorCode.UNKNOWN.value + + if errors.is_quiet(err) is False: + logger.error(f"{code}: {err!s}") + + return cls( + body={ + "code": code, + "message": str(err), + "name": type(err).__name__, + }, + status_code=status.value, + ) + + @classmethod + def from_error_code( + cls, + code: server_lib.ErrorCode, + message: str, + status: http.HTTPStatus = http.HTTPStatus.INTERNAL_SERVER_ERROR, + ) -> CommResponse: + return cls( + body={ + "code": code.value, + "message": message, + }, + status_code=status.value, + ) + + def prep_call_result( + self, + call_res: execution.CallResult, + ) -> types.MaybeError[object]: + """ + Convert a CallResult to the shape the Inngest Server expects. For step-level + results this is a dict and for function-level results this is the output or + error. + """ + + if call_res.step is not None: + d = call_res.step.to_dict() + if isinstance(d, Exception): + # Unreachable + return d + else: + d = {} + + if call_res.error is not None: + e = ErrorData.from_error(call_res.error).to_dict() + if isinstance(e, Exception): + return e + d["error"] = e + + if call_res.output is not types.empty_sentinel: + err = transforms.dump_json(call_res.output) + if isinstance(err, Exception): + msg = "returned unserializable data" + if call_res.step is not None: + msg = f'"{call_res.step.display_name}" {msg}' + + return errors.OutputUnserializableError(msg) + + d["data"] = call_res.output + + is_function_level = call_res.step is None + if is_function_level: + # Don't nest function-level results + return d.get("error") or d.get("data") + + return d + + +def _prep_call_result( + call_res: execution.CallResult, +) -> types.MaybeError[object]: + """ + Convert a CallResult to the shape the Inngest Server expects. For step-level + results this is a dict and for function-level results this is the output or + error. + """ + + if call_res.step is not None: + d = call_res.step.to_dict() + if isinstance(d, Exception): + # Unreachable + return d + else: + d = {} + + if call_res.error is not None: + e = ErrorData.from_error(call_res.error).to_dict() + if isinstance(e, Exception): + return e + d["error"] = e + + if call_res.output is not types.empty_sentinel: + err = transforms.dump_json(call_res.output) + if isinstance(err, Exception): + msg = "returned unserializable data" + if call_res.step is not None: + msg = f'"{call_res.step.display_name}" {msg}' + + return errors.OutputUnserializableError(msg) + + d["data"] = call_res.output + + is_function_level = call_res.step is None + if is_function_level: + # Don't nest function-level results + return d.get("error") or d.get("data") + + return d + + +class ErrorData(types.BaseModel): + code: server_lib.ErrorCode + message: str + name: str + stack: typing.Optional[str] + + @classmethod + def from_error(cls, err: Exception) -> ErrorData: + if isinstance(err, errors.Error): + code = err.code + message = err.message + name = err.name + stack = err.stack + else: + code = server_lib.ErrorCode.UNKNOWN + message = str(err) + name = type(err).__name__ + stack = transforms.get_traceback(err) + + return cls( + code=code, + message=message, + name=name, + stack=stack, + ) + + +class UnauthenticatedInspection(types.BaseModel): + schema_version: str = "2024-05-24" + + authentication_succeeded: typing.Optional[bool] + function_count: int + has_event_key: bool + has_signing_key: bool + has_signing_key_fallback: bool + mode: server_lib.ServerKind + + +class AuthenticatedInspection(UnauthenticatedInspection): + api_origin: str + app_id: str + authentication_succeeded: bool = True + env: typing.Optional[str] + event_api_origin: str + event_key_hash: typing.Optional[str] + framework: str + sdk_language: str = const.LANGUAGE + sdk_version: str = const.VERSION + serve_origin: typing.Optional[str] + serve_path: typing.Optional[str] + signing_key_fallback_hash: typing.Optional[str] + signing_key_hash: typing.Optional[str] diff --git a/inngest/_internal/comm_lib/utils.py b/inngest/_internal/comm_lib/utils.py new file mode 100644 index 0000000..401ff78 --- /dev/null +++ b/inngest/_internal/comm_lib/utils.py @@ -0,0 +1,32 @@ +from __future__ import annotations + +import typing + +from inngest._internal import server_lib, types + + +class _QueryParams(types.BaseModel): + fn_id: typing.Optional[str] + step_id: typing.Optional[str] + sync_id: typing.Optional[str] + + +def parse_query_params( + query_params: typing.Union[dict[str, str], dict[str, list[str]]], +) -> typing.Union[_QueryParams, Exception]: + normalized: dict[str, str] = {} + for k, v in query_params.items(): + if isinstance(v, list): + normalized[k] = v[0] + else: + normalized[k] = v + + step_id = normalized.get(server_lib.QueryParamKey.STEP_ID.value) + if step_id == server_lib.UNSPECIFIED_STEP_ID: + step_id = None + + return _QueryParams( + fn_id=normalized.get(server_lib.QueryParamKey.FUNCTION_ID.value), + step_id=step_id, + sync_id=normalized.get(server_lib.QueryParamKey.SYNC_ID.value), + ) diff --git a/inngest/_internal/net.py b/inngest/_internal/net.py index 4b663ef..4f324cb 100644 --- a/inngest/_internal/net.py +++ b/inngest/_internal/net.py @@ -55,6 +55,7 @@ def create_headers( headers = { server_lib.HeaderKey.CONTENT_TYPE.value: "application/json", server_lib.HeaderKey.SDK.value: f"inngest-{const.LANGUAGE}:v{const.VERSION}", + server_lib.HeaderKey.REQUEST_VERSION.value: server_lib.PREFERRED_EXECUTION_VERSION, server_lib.HeaderKey.USER_AGENT.value: f"inngest-{const.LANGUAGE}:v{const.VERSION}", } @@ -188,7 +189,9 @@ def fetch_with_auth_fallback_sync( return res -def normalize_headers(headers: dict[str, str]) -> dict[str, str]: +def normalize_headers( + headers: typing.Union[dict[str, str], dict[str, list[str]]], +) -> dict[str, str]: """ Ensure that known headers are in the correct casing. """ @@ -200,7 +203,10 @@ def normalize_headers(headers: dict[str, str]) -> dict[str, str]: if k.lower() == header_key.value.lower(): k = header_key.value - new_headers[k] = v + if isinstance(v, list): + new_headers[k] = v[0] + else: + new_headers[k] = v return new_headers @@ -240,76 +246,91 @@ async def fetch_with_thready_safety( ) -class RequestSignature: - _signature: typing.Optional[str] = None - _timestamp: typing.Optional[int] = None - - def __init__( - self, - body: bytes, - headers: dict[str, str], - mode: server_lib.ServerKind, - ) -> None: - self._body = body - self._mode = mode - - sig_header = headers.get(server_lib.HeaderKey.SIGNATURE.value) - if sig_header is not None: - parsed = urllib.parse.parse_qs(sig_header) - if "t" in parsed: - self._timestamp = int(parsed["t"][0]) - if "s" in parsed: - self._signature = parsed["s"][0] - - def _validate( - self, - signing_key: typing.Optional[str], - ) -> types.MaybeError[None]: - if self._mode == server_lib.ServerKind.DEV_SERVER: - return None - - if signing_key is None: - return errors.SigningKeyMissingError( - "cannot validate signature in production mode without a signing key" - ) - - if self._signature is None: - return errors.HeaderMissingError( - f"cannot validate signature in production mode without a {server_lib.HeaderKey.SIGNATURE.value} header" - ) - - mac = hmac.new( - transforms.remove_signing_key_prefix(signing_key).encode("utf-8"), - self._body, - hashlib.sha256, +def _validate_request( + *, + body: bytes, + headers: dict[str, str], + mode: server_lib.ServerKind, + signing_key: typing.Optional[str], +) -> types.MaybeError[None]: + if mode == server_lib.ServerKind.DEV_SERVER: + return None + + timestamp = None + signature = None + sig_header = headers.get(server_lib.HeaderKey.SIGNATURE.value) + if sig_header is None: + return errors.HeaderMissingError( + f"cannot validate signature in production mode without a {server_lib.HeaderKey.SIGNATURE.value} header" + ) + else: + parsed = urllib.parse.parse_qs(sig_header) + if "t" in parsed: + timestamp = int(parsed["t"][0]) + if "s" in parsed: + signature = parsed["s"][0] + + if signing_key is None: + return errors.SigningKeyMissingError( + "cannot validate signature in production mode without a signing key" + ) + + if signature is None: + return Exception( + f"{server_lib.HeaderKey.SIGNATURE.value} header is malformed" ) - mac.update(str(self._timestamp).encode("utf-8")) - if not hmac.compare_digest(self._signature, mac.hexdigest()): - return errors.SigVerificationFailedError() - return None + mac = hmac.new( + transforms.remove_signing_key_prefix(signing_key).encode("utf-8"), + body, + hashlib.sha256, + ) + + if timestamp: + mac.update(str(timestamp).encode("utf-8")) + + if not hmac.compare_digest(signature, mac.hexdigest()): + return errors.SigVerificationFailedError() + + return None + + +def validate_request( + *, + body: bytes, + headers: dict[str, str], + mode: server_lib.ServerKind, + signing_key: typing.Optional[str], + signing_key_fallback: typing.Optional[str], +) -> types.MaybeError[None]: + """ + Validate the request signature. Falls back to the fallback signing key if + signature validation fails with the primary signing key. + + Args: + ---- + body: Request body. + headers: Request headers. + mode: Server mode. + signing_key: Primary signing key. + signing_key_fallback: Fallback signing key. + """ + + err = _validate_request( + body=body, + headers=headers, + mode=mode, + signing_key=signing_key, + ) + if err is not None and signing_key_fallback is not None: + # If the signature validation failed but there's a "fallback" + # signing key, attempt to validate the signature with the fallback + # key + err = _validate_request( + body=body, + headers=headers, + mode=mode, + signing_key=signing_key_fallback, + ) - def validate( - self, - *, - signing_key: typing.Optional[str], - signing_key_fallback: typing.Optional[str], - ) -> types.MaybeError[None]: - """ - Validate the request signature. Falls back to the fallback signing key if - signature validation fails with the primary signing key. - - Args: - ---- - signing_key: The primary signing key. - signing_key_fallback: The fallback signing key. - """ - - err = self._validate(signing_key) - if err is not None and signing_key_fallback is not None: - # If the signature validation failed but there's a "fallback" - # signing key, attempt to validate the signature with the fallback - # key - err = self._validate(signing_key_fallback) - - return err + return err diff --git a/inngest/_internal/net_test.py b/inngest/_internal/net_test.py index 0b16ea3..25c56ab 100644 --- a/inngest/_internal/net_test.py +++ b/inngest/_internal/net_test.py @@ -107,11 +107,11 @@ def test_success(self) -> None: server_lib.HeaderKey.SIGNATURE.value: f"s={sig}&t={unix_ms}", } - req_sig = net.RequestSignature( - body, headers, mode=server_lib.ServerKind.CLOUD - ) assert not isinstance( - req_sig.validate( + net.validate_request( + body=body, + headers=headers, + mode=server_lib.ServerKind.CLOUD, signing_key=_signing_key, signing_key_fallback=None, ), @@ -131,11 +131,11 @@ def test_body_tamper(self) -> None: } body = json.dumps({"msg": "you've been hacked"}).encode("utf-8") - req_sig = net.RequestSignature( - body, headers, mode=server_lib.ServerKind.CLOUD - ) - validation = req_sig.validate( + validation = net.validate_request( + body=body, + headers=headers, + mode=server_lib.ServerKind.CLOUD, signing_key=_signing_key, signing_key_fallback=None, ) @@ -154,11 +154,11 @@ def test_rotation(self) -> None: server_lib.HeaderKey.SIGNATURE.value: f"s={sig}&t={unix_ms}", } - req_sig = net.RequestSignature( - body, headers, mode=server_lib.ServerKind.CLOUD - ) assert not isinstance( - req_sig.validate( + net.validate_request( + body=body, + headers=headers, + mode=server_lib.ServerKind.CLOUD, signing_key=_signing_key, signing_key_fallback=_signing_key_fallback, ), @@ -177,11 +177,11 @@ def test_fails_for_both_signing_keys(self) -> None: server_lib.HeaderKey.SIGNATURE.value: f"s={sig}&t={unix_ms}", } - req_sig = net.RequestSignature( - body, headers, mode=server_lib.ServerKind.CLOUD - ) assert isinstance( - req_sig.validate( + net.validate_request( + body=body, + headers=headers, + mode=server_lib.ServerKind.CLOUD, signing_key=_signing_key, signing_key_fallback=_signing_key_fallback, ), diff --git a/inngest/_internal/server_lib/__init__.py b/inngest/_internal/server_lib/__init__.py index d209394..973c735 100644 --- a/inngest/_internal/server_lib/__init__.py +++ b/inngest/_internal/server_lib/__init__.py @@ -1,4 +1,5 @@ from .consts import ( + PREFERRED_EXECUTION_VERSION, ROOT_STEP_ID, UNSPECIFIED_STEP_ID, DeployType, @@ -20,10 +21,10 @@ FunctionConfig, Priority, RateLimit, - RegisterRequest, Retries, Runtime, Step, + SynchronizeRequest, Throttle, TriggerCron, TriggerEvent, @@ -34,12 +35,6 @@ "Cancel", "Concurrency", "Debounce", - "Priority", - "RateLimit", - "Throttle", - "TriggerCron", - "Runtime", - "TriggerEvent", "DeployType", "ErrorCode", "Event", @@ -48,12 +43,19 @@ "HeaderKey", "InternalEvents", "Opcode", + "PREFERRED_EXECUTION_VERSION", + "Priority", "QueryParamKey", - "Retries", "ROOT_STEP_ID", - "RegisterRequest", + "RateLimit", + "SynchronizeRequest", + "Retries", + "Runtime", "ServerKind", - "Step", "ServerRequest", + "Step", + "Throttle", + "TriggerCron", + "TriggerEvent", "UNSPECIFIED_STEP_ID", ] diff --git a/inngest/_internal/server_lib/consts.py b/inngest/_internal/server_lib/consts.py index 0c21f5e..dd5fd68 100644 --- a/inngest/_internal/server_lib/consts.py +++ b/inngest/_internal/server_lib/consts.py @@ -27,6 +27,14 @@ class ErrorCode(enum.Enum): URL_INVALID = "url_invalid" +class ExecutionVersion(enum.Enum): + V0 = "0" + V1 = "1" + + +PREFERRED_EXECUTION_VERSION: typing.Final = ExecutionVersion.V1.value + + class Framework(enum.Enum): DIGITAL_OCEAN = "digitalocean" DJANGO = "django" @@ -42,6 +50,7 @@ class HeaderKey(enum.Enum): EXPECTED_SERVER_KIND = "X-Inngest-Expected-Server-Kind" FRAMEWORK = "X-Inngest-Framework" NO_RETRY = "X-Inngest-No-Retry" + REQUEST_VERSION = "X-Inngest-Req-Version" RETRY_AFTER = "Retry-After" SDK = "X-Inngest-SDK" SERVER_KIND = "X-Inngest-Server-Kind" diff --git a/inngest/_internal/server_lib/registration.py b/inngest/_internal/server_lib/registration.py index 0d06c5c..924fbcd 100644 --- a/inngest/_internal/server_lib/registration.py +++ b/inngest/_internal/server_lib/registration.py @@ -161,7 +161,7 @@ class TriggerEvent(_BaseConfig): expression: typing.Optional[str] = None -class RegisterRequest(types.BaseModel): +class SynchronizeRequest(types.BaseModel): app_name: str = pydantic.Field(..., serialization_alias="appname") deploy_type: DeployType framework: Framework diff --git a/inngest/_internal/types.py b/inngest/_internal/types.py index e0c8d51..b631ce4 100644 --- a/inngest/_internal/types.py +++ b/inngest/_internal/types.py @@ -87,6 +87,9 @@ def from_raw( raw: object, ) -> typing.Union[BaseModelT, Exception]: try: + if isinstance(raw, (str, bytes)): + raw = cls.model_validate_json(raw) + return cls.model_validate(raw) except Exception as err: return err diff --git a/inngest/digital_ocean.py b/inngest/digital_ocean.py index 390eb30..13f5683 100644 --- a/inngest/digital_ocean.py +++ b/inngest/digital_ocean.py @@ -8,16 +8,7 @@ import typing import urllib.parse -from ._internal import ( - client_lib, - comm, - errors, - function, - net, - server_lib, - transforms, - types, -) +from ._internal import client_lib, comm_lib, errors, function, server_lib, types FRAMEWORK = server_lib.Framework.DIGITAL_OCEAN @@ -41,7 +32,7 @@ def serve( serve_path: The entire function path (e.g. /api/v1/web/fn-b094417f/sample/hello). """ - handler = comm.CommHandler( + handler = comm_lib.CommHandler( api_base_url=client.api_origin, client=client, framework=FRAMEWORK, @@ -49,8 +40,6 @@ def serve( ) def main(event: dict[str, object], context: _Context) -> _Response: - server_kind: typing.Optional[server_lib.ServerKind] = None - try: if "http" not in event: raise errors.BodyInvalidError('missing "http" key in event') @@ -68,33 +57,16 @@ def main(event: dict[str, object], context: _Context) -> _Response: 'missing "queryString" event.http; have you set "web: raw"?' ) - headers = net.normalize_headers(http.headers) - - _server_kind = transforms.get_server_kind(headers) - if not isinstance(_server_kind, Exception): - server_kind = _server_kind - else: - client.logger.error(_server_kind) - server_kind = None - - req_sig = net.RequestSignature( - body=_to_body_bytes(http.body), - headers=headers, - mode=client._mode, - ) - query_params = urllib.parse.parse_qs(http.queryString) if http.method == "GET": return _to_response( - client, handler.inspect( + body=_to_body_bytes(http.body), + headers=http.headers, serve_origin=serve_origin, serve_path=serve_path, - server_kind=server_kind, - req_sig=req_sig, ), - server_kind, ) if http.method == "POST": @@ -130,18 +102,15 @@ def main(event: dict[str, object], context: _Context) -> _Response: raise call return _to_response( - client, handler.call_function_sync( - call=call, - fn_id=fn_id, + body=_to_body_bytes(http.body), + headers=http.headers, + query_params=query_params, raw_request={ "context": context, "event": event, }, - req_sig=req_sig, - target_hashed_id=step_id, ), - server_kind, ) if http.method == "PUT": @@ -156,37 +125,26 @@ def main(event: dict[str, object], context: _Context) -> _Response: path = "/api/v1/web" + context.function_name request_url = urllib.parse.urljoin(context.api_host, path) - sync_id = _get_first( - query_params.get(server_lib.QueryParamKey.SYNC_ID.value), - ) return _to_response( - client, handler.register_sync( - app_url=net.create_serve_url( - request_url=request_url, - serve_origin=serve_origin, - serve_path=serve_path, - ), - server_kind=server_kind, - sync_id=sync_id, + headers=http.headers, + query_params=urllib.parse.parse_qs(http.queryString), + request_url=request_url, + serve_origin=serve_origin, + serve_path=serve_path, ), - server_kind, ) raise Exception(f"unsupported method: {http.method}") except Exception as e: - comm_res = comm.CommResponse.from_error(client.logger, e) + comm_res = comm_lib.CommResponse.from_error(client.logger, e) if isinstance( e, (errors.BodyInvalidError, errors.QueryParamMissingError) ): comm_res.status_code = 400 - return _to_response( - client, - comm_res, - server_kind, - ) + return _to_response(comm_res) return main @@ -206,20 +164,11 @@ def _to_body_bytes(body: typing.Optional[str]) -> bytes: def _to_response( - client: client_lib.Inngest, - comm_res: comm.CommResponse, - server_kind: typing.Union[server_lib.ServerKind, None], + comm_res: comm_lib.CommResponse, ) -> _Response: return { "body": comm_res.body, # type: ignore - "headers": { - **comm_res.headers, - **net.create_headers( - env=client.env, - framework=FRAMEWORK, - server_kind=server_kind, - ), - }, + "headers": comm_res.headers, "statusCode": comm_res.status_code, } diff --git a/inngest/django.py b/inngest/django.py index 5ea004d..d94151c 100644 --- a/inngest/django.py +++ b/inngest/django.py @@ -11,15 +11,7 @@ import django.urls import django.views.decorators.csrf -from ._internal import ( - client_lib, - comm, - errors, - function, - net, - server_lib, - transforms, -) +from ._internal import client_lib, comm_lib, function, server_lib, transforms FRAMEWORK = server_lib.Framework.DJANGO @@ -44,7 +36,7 @@ def serve( serve_path: Path to serve Inngest from. """ - handler = comm.CommHandler( + handler = comm_lib.CommHandler( api_base_url=client.api_origin, client=client, framework=FRAMEWORK, @@ -74,7 +66,7 @@ def serve( def _create_handler_sync( client: client_lib.Inngest, - handler: comm.CommHandler, + handler: comm_lib.CommHandler, *, serve_origin: typing.Optional[str], serve_path: typing.Optional[str], @@ -82,79 +74,38 @@ def _create_handler_sync( def inngest_api( request: django.http.HttpRequest, ) -> django.http.HttpResponse: - headers = net.normalize_headers(dict(request.headers.items())) - - server_kind = transforms.get_server_kind(headers) - if isinstance(server_kind, Exception): - client.logger.error(server_kind) - server_kind = None - - req_sig = net.RequestSignature( - body=request.body, - headers=headers, - mode=client._mode, - ) - if request.method == "GET": return _to_response( client, handler.inspect( + body=request.body, + headers=dict(request.headers.items()), serve_origin=serve_origin, serve_path=serve_path, - server_kind=server_kind, - req_sig=req_sig, ), - server_kind, ) if request.method == "POST": - fn_id = request.GET.get(server_lib.QueryParamKey.FUNCTION_ID.value) - if fn_id is None: - raise errors.QueryParamMissingError( - server_lib.QueryParamKey.FUNCTION_ID.value - ) - - step_id = request.GET.get(server_lib.QueryParamKey.STEP_ID.value) - if step_id is None: - raise errors.QueryParamMissingError( - server_lib.QueryParamKey.STEP_ID.value - ) - - call = server_lib.ServerRequest.from_raw(json.loads(request.body)) - if isinstance(call, Exception): - return _to_response( - client, - comm.CommResponse.from_error(client.logger, call), - server_kind, - ) - return _to_response( client, handler.call_function_sync( - call=call, - fn_id=fn_id, + body=request.body, + headers=dict(request.headers.items()), + query_params=dict(request.GET.items()), raw_request=request, - req_sig=req_sig, - target_hashed_id=step_id, ), - server_kind, ) if request.method == "PUT": - sync_id = request.GET.get(server_lib.QueryParamKey.SYNC_ID.value) - return _to_response( client, handler.register_sync( - app_url=net.create_serve_url( - request_url=request.build_absolute_uri(), - serve_origin=serve_origin, - serve_path=serve_path, - ), - server_kind=server_kind, - sync_id=sync_id, + headers=dict(request.headers.items()), + query_params=dict(request.GET.items()), + request_url=request.build_absolute_uri(), + serve_origin=serve_origin, + serve_path=serve_path, ), - server_kind, ) return django.http.JsonResponse( @@ -170,7 +121,7 @@ def inngest_api( def _create_handler_async( client: client_lib.Inngest, - handler: comm.CommHandler, + handler: comm_lib.CommHandler, *, serve_origin: typing.Optional[str], serve_path: typing.Optional[str], @@ -188,79 +139,38 @@ def _create_handler_async( async def inngest_api( request: django.http.HttpRequest, ) -> django.http.HttpResponse: - headers = net.normalize_headers(dict(request.headers.items())) - - server_kind = transforms.get_server_kind(headers) - if isinstance(server_kind, Exception): - client.logger.error(server_kind) - server_kind = None - - req_sig = net.RequestSignature( - body=request.body, - headers=headers, - mode=client._mode, - ) - if request.method == "GET": return _to_response( client, handler.inspect( + body=json.loads(request.body), + headers=dict(request.headers.items()), serve_origin=serve_origin, serve_path=serve_path, - server_kind=server_kind, - req_sig=req_sig, ), - server_kind, ) if request.method == "POST": - fn_id = request.GET.get(server_lib.QueryParamKey.FUNCTION_ID.value) - if fn_id is None: - raise errors.QueryParamMissingError( - server_lib.QueryParamKey.FUNCTION_ID.value - ) - - step_id = request.GET.get(server_lib.QueryParamKey.STEP_ID.value) - if step_id is None: - raise errors.QueryParamMissingError( - server_lib.QueryParamKey.STEP_ID.value - ) - - call = server_lib.ServerRequest.from_raw(json.loads(request.body)) - if isinstance(call, Exception): - return _to_response( - client, - comm.CommResponse.from_error(client.logger, call), - server_kind, - ) - return _to_response( client, await handler.call_function( - call=call, - fn_id=fn_id, + body=json.loads(request.body), + headers=dict(request.headers.items()), + query_params=dict(request.GET.items()), raw_request=request, - req_sig=req_sig, - target_hashed_id=step_id, ), - server_kind, ) if request.method == "PUT": - sync_id = request.GET.get(server_lib.QueryParamKey.SYNC_ID.value) - return _to_response( client, await handler.register( - app_url=net.create_serve_url( - request_url=request.build_absolute_uri(), - serve_origin=serve_origin, - serve_path=serve_path, - ), - server_kind=server_kind, - sync_id=sync_id, + headers=dict(request.headers.items()), + query_params=dict(request.GET.items()), + request_url=request.build_absolute_uri(), + serve_origin=serve_origin, + serve_path=serve_path, ), - server_kind, ) return django.http.JsonResponse( @@ -276,23 +186,15 @@ async def inngest_api( def _to_response( client: client_lib.Inngest, - comm_res: comm.CommResponse, - server_kind: typing.Optional[server_lib.ServerKind], + comm_res: comm_lib.CommResponse, ) -> django.http.HttpResponse: body = transforms.dump_json(comm_res.body) if isinstance(body, Exception): - comm_res = comm.CommResponse.from_error(client.logger, body) + comm_res = comm_lib.CommResponse.from_error(client.logger, body) body = json.dumps(comm_res.body) return django.http.HttpResponse( body.encode("utf-8"), - headers={ - **comm_res.headers, - **net.create_headers( - env=client.env, - framework=FRAMEWORK, - server_kind=server_kind, - ), - }, + headers=comm_res.headers, status=comm_res.status_code, ) diff --git a/inngest/fast_api.py b/inngest/fast_api.py index bbe5b46..57767c1 100644 --- a/inngest/fast_api.py +++ b/inngest/fast_api.py @@ -5,7 +5,7 @@ import fastapi -from ._internal import client_lib, comm, function, net, server_lib, transforms +from ._internal import client_lib, comm_lib, function, server_lib, transforms FRAMEWORK = server_lib.Framework.FAST_API @@ -31,7 +31,7 @@ def serve( serve_path: Path to serve the functions from. """ - handler = comm.CommHandler( + handler = comm_lib.CommHandler( api_base_url=client.api_origin, client=client, framework=FRAMEWORK, @@ -42,116 +42,57 @@ def serve( async def get_api_inngest( request: fastapi.Request, ) -> fastapi.Response: - body = await request.body() - headers = net.normalize_headers(dict(request.headers.items())) - - server_kind = transforms.get_server_kind(headers) - if isinstance(server_kind, Exception): - client.logger.error(server_kind) - server_kind = None - return _to_response( client, handler.inspect( + body=await request.body(), + headers=dict(request.headers.items()), serve_origin=serve_origin, serve_path=serve_path, - server_kind=server_kind, - req_sig=net.RequestSignature( - body=body, - headers=headers, - mode=client._mode, - ), ), - server_kind, ) @app.post("/api/inngest") async def post_inngest_api( - fnId: str, # noqa: N803 - stepId: str, # noqa: N803 request: fastapi.Request, ) -> fastapi.Response: - body = await request.body() - headers = net.normalize_headers(dict(request.headers.items())) - - server_kind = transforms.get_server_kind(headers) - if isinstance(server_kind, Exception): - client.logger.error(server_kind) - server_kind = None - - call = server_lib.ServerRequest.from_raw(json.loads(body)) - if isinstance(call, Exception): - return _to_response( - client, - comm.CommResponse.from_error(client.logger, call), - server_kind, - ) - return _to_response( client, await handler.call_function( - call=call, - fn_id=fnId, + body=await request.body(), + headers=dict(request.headers.items()), + query_params=dict(request.query_params.items()), raw_request=request, - req_sig=net.RequestSignature( - body=body, - headers=headers, - mode=client._mode, - ), - target_hashed_id=stepId, ), - server_kind, ) @app.put("/api/inngest") async def put_inngest_api( request: fastapi.Request, ) -> fastapi.Response: - headers = net.normalize_headers(dict(request.headers.items())) - - server_kind = transforms.get_server_kind(headers) - if isinstance(server_kind, Exception): - client.logger.error(server_kind) - server_kind = None - - sync_id = request.query_params.get( - server_lib.QueryParamKey.SYNC_ID.value - ) - return _to_response( client, await handler.register( - app_url=net.create_serve_url( - request_url=str(request.url), - serve_origin=serve_origin, - serve_path=serve_path, - ), - server_kind=server_kind, - sync_id=sync_id, + headers=dict(request.headers.items()), + query_params=dict(request.query_params.items()), + request_url=str(request.url), + serve_origin=serve_origin, + serve_path=serve_path, ), - server_kind, ) def _to_response( client: client_lib.Inngest, - comm_res: comm.CommResponse, - server_kind: typing.Union[server_lib.ServerKind, None], + comm_res: comm_lib.CommResponse, ) -> fastapi.responses.Response: body = transforms.dump_json(comm_res.body) if isinstance(body, Exception): - comm_res = comm.CommResponse.from_error(client.logger, body) + comm_res = comm_lib.CommResponse.from_error(client.logger, body) body = json.dumps(comm_res.body) return fastapi.responses.Response( content=body.encode("utf-8"), - headers={ - **comm_res.headers, - **net.create_headers( - env=client.env, - framework=FRAMEWORK, - server_kind=server_kind, - ), - }, + headers=comm_res.headers, status_code=comm_res.status_code, ) diff --git a/inngest/flask.py b/inngest/flask.py index afa89ce..7756e68 100644 --- a/inngest/flask.py +++ b/inngest/flask.py @@ -5,12 +5,10 @@ import flask -from ._internal import ( +from inngest._internal import ( client_lib, - comm, - errors, + comm_lib, function, - net, server_lib, transforms, ) @@ -39,7 +37,7 @@ def serve( serve_path: Path to serve the functions from. """ - handler = comm.CommHandler( + handler = comm_lib.CommHandler( api_base_url=client.api_origin, client=client, framework=FRAMEWORK, @@ -71,94 +69,45 @@ def serve( def _create_handler_async( app: flask.Flask, client: client_lib.Inngest, - handler: comm.CommHandler, + handler: comm_lib.CommHandler, *, serve_origin: typing.Optional[str], serve_path: typing.Optional[str], ) -> None: @app.route("/api/inngest", methods=["GET", "POST", "PUT"]) async def inngest_api() -> typing.Union[flask.Response, str]: - headers = net.normalize_headers(dict(flask.request.headers.items())) - - server_kind = transforms.get_server_kind(headers) - if isinstance(server_kind, Exception): - client.logger.error(server_kind) - server_kind = None - - req_sig = net.RequestSignature( - body=flask.request.data, - headers=headers, - mode=client._mode, - ) - if flask.request.method == "GET": return _to_response( client, handler.inspect( + body=flask.request.data, + headers=dict(flask.request.headers.items()), serve_origin=serve_origin, serve_path=serve_path, - server_kind=server_kind, - req_sig=req_sig, ), - server_kind, ) if flask.request.method == "POST": - fn_id = flask.request.args.get( - server_lib.QueryParamKey.FUNCTION_ID.value - ) - if fn_id is None: - raise errors.QueryParamMissingError( - server_lib.QueryParamKey.FUNCTION_ID.value - ) - - step_id = flask.request.args.get( - server_lib.QueryParamKey.STEP_ID.value - ) - if step_id is None: - raise errors.QueryParamMissingError( - server_lib.QueryParamKey.STEP_ID.value - ) - - call = server_lib.ServerRequest.from_raw( - json.loads(flask.request.data) - ) - if isinstance(call, Exception): - return _to_response( - client, - comm.CommResponse.from_error(client.logger, call), - server_kind, - ) - return _to_response( client, await handler.call_function( - call=call, - fn_id=fn_id, + body=flask.request.data, + headers=dict(flask.request.headers.items()), + query_params=flask.request.args, raw_request=flask.request, - req_sig=req_sig, - target_hashed_id=step_id, ), - server_kind, ) if flask.request.method == "PUT": - sync_id = flask.request.args.get( - server_lib.QueryParamKey.SYNC_ID.value - ) - return _to_response( client, await handler.register( - app_url=net.create_serve_url( - request_url=flask.request.url, - serve_origin=serve_origin, - serve_path=serve_path, - ), - server_kind=server_kind, - sync_id=sync_id, + headers=dict(flask.request.headers.items()), + query_params=flask.request.args, + request_url=flask.request.url, + serve_origin=serve_origin, + serve_path=serve_path, ), - server_kind, ) # Should be unreachable @@ -168,94 +117,45 @@ async def inngest_api() -> typing.Union[flask.Response, str]: def _create_handler_sync( app: flask.Flask, client: client_lib.Inngest, - handler: comm.CommHandler, + handler: comm_lib.CommHandler, *, serve_origin: typing.Optional[str], serve_path: typing.Optional[str], ) -> None: @app.route("/api/inngest", methods=["GET", "POST", "PUT"]) def inngest_api() -> typing.Union[flask.Response, str]: - headers = net.normalize_headers(dict(flask.request.headers.items())) - - server_kind = transforms.get_server_kind(headers) - if isinstance(server_kind, Exception): - client.logger.error(server_kind) - server_kind = None - - req_sig = net.RequestSignature( - body=flask.request.data, - headers=headers, - mode=client._mode, - ) - if flask.request.method == "GET": return _to_response( client, handler.inspect( + body=flask.request.data, + headers=dict(flask.request.headers.items()), serve_origin=serve_origin, serve_path=serve_path, - server_kind=server_kind, - req_sig=req_sig, ), - server_kind, ) if flask.request.method == "POST": - fn_id = flask.request.args.get( - server_lib.QueryParamKey.FUNCTION_ID.value - ) - if fn_id is None: - raise errors.QueryParamMissingError( - server_lib.QueryParamKey.FUNCTION_ID.value - ) - - step_id = flask.request.args.get( - server_lib.QueryParamKey.STEP_ID.value - ) - if step_id is None: - raise errors.QueryParamMissingError( - server_lib.QueryParamKey.STEP_ID.value - ) - - call = server_lib.ServerRequest.from_raw( - json.loads(flask.request.data) - ) - if isinstance(call, Exception): - return _to_response( - client, - comm.CommResponse.from_error(client.logger, call), - server_kind, - ) - return _to_response( client, handler.call_function_sync( - call=call, - fn_id=fn_id, + body=flask.request.data, + headers=dict(flask.request.headers.items()), + query_params=flask.request.args, raw_request=flask.request, - req_sig=req_sig, - target_hashed_id=step_id, ), - server_kind, ) if flask.request.method == "PUT": - sync_id = flask.request.args.get( - server_lib.QueryParamKey.SYNC_ID.value - ) - return _to_response( client, handler.register_sync( - app_url=net.create_serve_url( - request_url=flask.request.url, - serve_origin=serve_origin, - serve_path=serve_path, - ), - server_kind=server_kind, - sync_id=sync_id, + headers=dict(flask.request.headers.items()), + query_params=flask.request.args, + request_url=flask.request.url, + serve_origin=serve_origin, + serve_path=serve_path, ), - server_kind, ) # Should be unreachable @@ -264,23 +164,15 @@ def inngest_api() -> typing.Union[flask.Response, str]: def _to_response( client: client_lib.Inngest, - comm_res: comm.CommResponse, - server_kind: typing.Optional[server_lib.ServerKind], + comm_res: comm_lib.CommResponse, ) -> flask.Response: body = transforms.dump_json(comm_res.body) if isinstance(body, Exception): - comm_res = comm.CommResponse.from_error(client.logger, body) + comm_res = comm_lib.CommResponse.from_error(client.logger, body) body = json.dumps(comm_res.body) return flask.Response( - headers={ - **comm_res.headers, - **net.create_headers( - env=client.env, - framework=FRAMEWORK, - server_kind=server_kind, - ), - }, + headers=comm_res.headers, response=body.encode("utf-8"), status=comm_res.status_code, ) diff --git a/inngest/tornado.py b/inngest/tornado.py index 6363d43..b2f00a7 100644 --- a/inngest/tornado.py +++ b/inngest/tornado.py @@ -7,10 +7,8 @@ from inngest._internal import ( client_lib, - comm, - errors, + comm_lib, function, - net, server_lib, transforms, ) @@ -38,7 +36,7 @@ def serve( serve_origin: Origin to serve the functions from. serve_path: Path to serve the functions from. """ - handler = comm.CommHandler( + handler = comm_lib.CommHandler( api_base_url=client.api_origin, client=client, framework=FRAMEWORK, @@ -52,129 +50,54 @@ def data_received( return None def get(self) -> None: - headers = net.normalize_headers(dict(self.request.headers.items())) - - server_kind = transforms.get_server_kind(headers) - if isinstance(server_kind, Exception): - client.logger.error(server_kind) - server_kind = None - - req_sig = net.RequestSignature( - body=self.request.body, - headers=headers, - mode=client._mode, - ) - comm_res = handler.inspect( + body=self.request.body, + headers=dict(self.request.headers.items()), serve_origin=serve_origin, serve_path=serve_path, - server_kind=server_kind, - req_sig=req_sig, ) - self._write_comm_response(comm_res, server_kind) + self._write_comm_response(comm_res) def post(self) -> None: - fn_id: typing.Optional[str] - raw_fn_id = self.request.query_arguments.get( - server_lib.QueryParamKey.FUNCTION_ID.value - ) - if raw_fn_id is None or len(raw_fn_id) == 0: - raise errors.QueryParamMissingError( - server_lib.QueryParamKey.FUNCTION_ID.value - ) - fn_id = raw_fn_id[0].decode("utf-8") - - step_id: typing.Optional[str] - raw_step_id = self.request.query_arguments.get( - server_lib.QueryParamKey.STEP_ID.value - ) - if raw_step_id is None or len(raw_step_id) == 0: - raise errors.QueryParamMissingError( - server_lib.QueryParamKey.STEP_ID.value - ) - step_id = raw_step_id[0].decode("utf-8") - - headers = net.normalize_headers(dict(self.request.headers.items())) - - server_kind = transforms.get_server_kind(headers) - if isinstance(server_kind, Exception): - client.logger.error(server_kind) - server_kind = None - - req_sig = net.RequestSignature( - body=self.request.body, - headers=headers, - mode=client._mode, - ) - - call = server_lib.ServerRequest.from_raw( - json.loads(self.request.body) - ) - if isinstance(call, Exception): - return self._write_comm_response( - comm.CommResponse.from_error(client.logger, call), - server_kind, - ) - comm_res = handler.call_function_sync( - call=call, - fn_id=fn_id, + body=self.request.body, + headers=dict(self.request.headers.items()), + query_params=_parse_query_params(self.request.query_arguments), raw_request=self.request, - req_sig=req_sig, - target_hashed_id=step_id, ) - self._write_comm_response(comm_res, server_kind) + self._write_comm_response(comm_res) def put(self) -> None: - headers = net.normalize_headers(dict(self.request.headers.items())) - - server_kind = transforms.get_server_kind(headers) - if isinstance(server_kind, Exception): - client.logger.error(server_kind) - server_kind = None - - sync_id: typing.Optional[str] = None - raw_sync_id = self.request.query_arguments.get( - server_lib.QueryParamKey.SYNC_ID.value - ) - if raw_sync_id is not None: - sync_id = raw_sync_id[0].decode("utf-8") - comm_res = handler.register_sync( - app_url=net.create_serve_url( - request_url=self.request.full_url(), - serve_origin=serve_origin, - serve_path=serve_path, - ), - server_kind=server_kind, - sync_id=sync_id, + headers=dict(self.request.headers.items()), + query_params=_parse_query_params(self.request.query_arguments), + request_url=self.request.full_url(), + serve_origin=serve_origin, + serve_path=serve_path, ) - self._write_comm_response(comm_res, server_kind) + self._write_comm_response(comm_res) def _write_comm_response( self, - comm_res: comm.CommResponse, - server_kind: typing.Optional[server_lib.ServerKind], + comm_res: comm_lib.CommResponse, ) -> None: body = transforms.dump_json(comm_res.body) if isinstance(body, Exception): - comm_res = comm.CommResponse.from_error(client.logger, body) + comm_res = comm_lib.CommResponse.from_error(client.logger, body) body = json.dumps(comm_res.body) self.write(body) for k, v in comm_res.headers.items(): self.add_header(k, v) - for k, v in net.create_headers( - env=client.env, - framework=FRAMEWORK, - server_kind=server_kind, - ).items(): - self.add_header(k, v) self.set_status(comm_res.status_code) app.add_handlers(r".*", [("/api/inngest", InngestHandler)]) + + +def _parse_query_params(raw: dict[str, list[bytes]]) -> dict[str, str]: + return {k: v[0].decode("utf-8") for k, v in raw.items()}