diff --git a/examples/functions/invoke.py b/examples/functions/invoke.py index 6bba9fd..c7ef543 100644 --- a/examples/functions/invoke.py +++ b/examples/functions/invoke.py @@ -25,6 +25,7 @@ async def invoker( "invoke", app_id=client.app_id, function_id="invokee", + timeout=60_000, ) print(res) @@ -56,6 +57,7 @@ def invoker( "invoke", app_id=client.app_id, function_id="invokee", + timeout=60_000, ) print(type(res)) print(res) diff --git a/inngest/__init__.py b/inngest/__init__.py index 40a8852..37c2f4a 100644 --- a/inngest/__init__.py +++ b/inngest/__init__.py @@ -1,10 +1,9 @@ """Public entrypoint for the Inngest SDK.""" -from ._internal.client_lib import Inngest +from ._internal.client_lib import Inngest, SendEventsResult from ._internal.errors import NonRetriableError, RetryAfterError, StepError from ._internal.event_lib import Event -from ._internal.execution import Output from ._internal.function import Context, Function from ._internal.function_config import ( Batch, @@ -16,8 +15,12 @@ TriggerCron, TriggerEvent, ) -from ._internal.middleware_lib import Middleware, MiddlewareSync -from ._internal.step_lib import FunctionID, Step, StepSync +from ._internal.middleware_lib import ( + Middleware, + MiddlewareSync, + TransformOutputResult, +) +from ._internal.step_lib import Step, StepMemos, StepSync from ._internal.types import JSON __all__ = [ @@ -28,19 +31,20 @@ "Debounce", "Event", "Function", - "FunctionID", "Inngest", "JSON", "Middleware", "MiddlewareSync", "NonRetriableError", - "Output", "RateLimit", "RetryAfterError", + "SendEventsResult", "Step", "StepError", + "StepMemos", "StepSync", "Throttle", + "TransformOutputResult", "TriggerCron", "TriggerEvent", ] diff --git a/inngest/_internal/client_lib/__init__.py b/inngest/_internal/client_lib/__init__.py new file mode 100644 index 0000000..8a56bb3 --- /dev/null +++ b/inngest/_internal/client_lib/__init__.py @@ -0,0 +1,4 @@ +from .client import Inngest +from .models import SendEventsResult + +__all__ = ["Inngest", "SendEventsResult"] diff --git a/inngest/_internal/client_lib.py b/inngest/_internal/client_lib/client.py similarity index 90% rename from inngest/_internal/client_lib.py rename to inngest/_internal/client_lib/client.py index afea9d7..50b002f 100644 --- a/inngest/_internal/client_lib.py +++ b/inngest/_internal/client_lib/client.py @@ -8,7 +8,7 @@ import httpx -from . import ( +from inngest._internal import ( const, env_lib, errors, @@ -20,6 +20,8 @@ types, ) +from . import models + # Dummy value _DEV_SERVER_EVENT_KEY = "NO_EVENT_KEY_SET" @@ -404,21 +406,39 @@ async def send( if not isinstance(events, list): events = [events] + middleware = None if not skip_middleware: - middleware = middleware_lib.MiddlewareManager.from_client(self) + middleware = middleware_lib.MiddlewareManager.from_client( + self, + raw_request=None, + ) await middleware.before_send_events(events) req = self._build_send_request(events) if isinstance(req, Exception): raise req - res = await net.fetch_with_thready_safety( - self._http_client, - self._http_client_sync, - req, + result = models.SendEventsResult.from_raw( + ( + await net.fetch_with_thready_safety( + self._http_client, + self._http_client_sync, + req, + ) + ).json() ) + if isinstance(result, Exception): + raise result + + if middleware is not None: + err = await middleware.after_send_events(result) + if isinstance(err, Exception): + raise err + + if result.error is not None: + raise errors.SendEventsError(result.error, result.ids) - return _extract_ids(res.json()) + return result.ids def send_sync( self, @@ -438,28 +458,38 @@ def send_sync( if not isinstance(events, list): events = [events] + middleware = None if not skip_middleware: - middleware = middleware_lib.MiddlewareManager.from_client(self) - middleware.before_send_events_sync(events) + middleware = middleware_lib.MiddlewareManager.from_client( + self, + raw_request=None, + ) + err = middleware.before_send_events_sync(events) + if isinstance(err, Exception): + raise err req = self._build_send_request(events) if isinstance(req, Exception): raise req - return _extract_ids((self._http_client_sync.send(req)).json()) - def set_logger(self, logger: types.Logger) -> None: - self.logger = logger + result = models.SendEventsResult.from_raw( + (self._http_client_sync.send(req)).json(), + ) + if isinstance(result, Exception): + raise result + if middleware is not None: + err = middleware.after_send_events_sync(result) + if isinstance(err, Exception): + raise err -def _extract_ids(body: object) -> list[str]: - if not isinstance(body, dict) or "ids" not in body: - raise errors.BodyInvalidError("unexpected response when sending events") + if result.error is not None: + raise errors.SendEventsError(result.error, result.ids) - ids = body["ids"] - if not isinstance(ids, list): - raise errors.BodyInvalidError("unexpected response when sending events") + return result.ids - return ids + def set_logger(self, logger: types.Logger) -> None: + self.logger = logger def _get_mode( diff --git a/inngest/_internal/client_lib_test.py b/inngest/_internal/client_lib/client_test.py similarity index 98% rename from inngest/_internal/client_lib_test.py rename to inngest/_internal/client_lib/client_test.py index 00747bf..3e5d334 100644 --- a/inngest/_internal/client_lib_test.py +++ b/inngest/_internal/client_lib/client_test.py @@ -3,7 +3,7 @@ import pytest -from . import client_lib, const, errors, event_lib +from inngest._internal import client_lib, const, errors, event_lib class Test(unittest.TestCase): diff --git a/inngest/_internal/client_lib/models.py b/inngest/_internal/client_lib/models.py new file mode 100644 index 0000000..7253ecc --- /dev/null +++ b/inngest/_internal/client_lib/models.py @@ -0,0 +1,8 @@ +import typing + +from inngest._internal import types + + +class SendEventsResult(types.BaseModel): + error: typing.Optional[str] = None + ids: list[str] diff --git a/inngest/_internal/comm.py b/inngest/_internal/comm.py index 3d5d590..3be9607 100644 --- a/inngest/_internal/comm.py +++ b/inngest/_internal/comm.py @@ -24,6 +24,75 @@ ) +class _ErrorData(types.BaseModel): + code: const.ErrorCode + message: str + name: str + stack: typing.Optional[str] + + @classmethod + def from_error(cls, err: Exception) -> _ErrorData: + if isinstance(err, errors.Error): + code = err.code + message = err.message + name = err.name + stack = err.stack + else: + code = const.ErrorCode.UNKNOWN + message = str(err) + name = type(err).__name__ + stack = transforms.get_traceback(err) + + return cls( + code=code, + message=message, + name=name, + stack=stack, + ) + + +def _prep_call_result( + call_res: execution.CallResult, +) -> types.MaybeError[object]: + """ + Convert a CallResult to the shape the Inngest Server expects. For step-level + results this is a dict and for function-level results this is the output or + error. + """ + + if call_res.step is not None: + d = call_res.step.to_dict() + if isinstance(d, Exception): + # Unreachable + return d + else: + d = {} + + if call_res.error is not None: + e = _ErrorData.from_error(call_res.error).to_dict() + if isinstance(e, Exception): + return e + d["error"] = e + + if call_res.output is not types.empty_sentinel: + err = transforms.dump_json(call_res.output) + if isinstance(err, Exception): + msg = "returned unserializable data" + if call_res.step is not None: + msg = f'"{call_res.step.display_name}" {msg}' + + return errors.OutputUnserializableError(msg) + + d["data"] = call_res.output + + is_function_level = call_res.step is None + if is_function_level: + # Don't nest function-level results + return d.get("error") or d.get("data") + + return d + + class CommResponse: def __init__( self, @@ -46,65 +115,43 @@ def from_call_result( const.HeaderKey.SERVER_TIMING.value: "handler", } - if execution.is_step_call_responses(call_res): - out: list[dict[str, object]] = [] - for item in call_res: - d = item.to_dict() + if call_res.multi: + multi_body: list[object] = [] + for item in call_res.multi: + d = _prep_call_result(item) if isinstance(d, Exception): - return cls.from_error( - logger, - errors.OutputUnserializableError( - f'"{item.display_name}" returned unserializable data' - ), - ) - - # Unnest data and error fields to work with the StepRun opcode. - # They should probably be unnested lower in the code, but this - # is a quick fix that doesn't break middleware contracts - nested_data = d.get("data") - if isinstance(nested_data, dict): - d["data"] = nested_data.get("data") - d["error"] = nested_data.get("error") - - out.append(d) + return cls.from_error(logger, d) + multi_body.append(d) + + if item.error is not None: + if errors.is_retriable(item.error) is False: + headers[const.HeaderKey.NO_RETRY.value] = "true" return cls( - body=transforms.prep_body(out), + body=multi_body, headers=headers, status_code=http.HTTPStatus.PARTIAL_CONTENT.value, ) - if isinstance(call_res, execution.CallError): - if call_res.quiet is False: - logger.error(call_res.stack) - - d = call_res.to_dict() - if isinstance(d, Exception): - return cls.from_error(logger, d) + body = _prep_call_result(call_res) + status_code = http.HTTPStatus.OK.value + if isinstance(body, Exception): + return cls.from_error(logger, body) - if call_res.is_retriable is False: + if call_res.error is not None: + status_code = http.HTTPStatus.INTERNAL_SERVER_ERROR.value + if errors.is_retriable(call_res.error) is False: headers[const.HeaderKey.NO_RETRY.value] = "true" - if call_res.retry_after is not None: + if isinstance(call_res.error, errors.RetryAfterError): headers[ const.HeaderKey.RETRY_AFTER.value - ] = transforms.to_iso_utc(call_res.retry_after) + ] = transforms.to_iso_utc(call_res.error.retry_after) - return cls( - body=transforms.prep_body(d), - headers=headers, - status_code=http.HTTPStatus.INTERNAL_SERVER_ERROR.value, - ) - - if isinstance(call_res, execution.FunctionCallResponse): - return cls( - body=call_res.data, - headers=headers, - ) - - return cls.from_error( - logger, - errors.UnknownError("unknown call result"), + return cls( + body=body, + headers=headers, + status_code=status_code, ) @classmethod @@ -120,7 +167,8 @@ def from_error( else: code = const.ErrorCode.UNKNOWN.value - logger.error(f"{code}: {err!s}") + if errors.is_quiet(err) is False: + logger.error(f"{code}: {err!s}") return cls( body={ @@ -238,7 +286,7 @@ def _build_registration_request( "POST", registration_url, headers=headers, - json=transforms.prep_body(body), + json=transforms.deep_strip_none(body), params=params, timeout=30, ) @@ -248,6 +296,7 @@ async def call_function( *, call: execution.Call, fn_id: str, + raw_request: object, req_sig: net.RequestSignature, target_hashed_id: str, ) -> CommResponse: @@ -258,7 +307,10 @@ async def call_function( else: target_step_id = target_hashed_id - middleware = middleware_lib.MiddlewareManager.from_client(self._client) + middleware = middleware_lib.MiddlewareManager.from_client( + self._client, + raw_request, + ) # Validate the request signature. err = req_sig.validate( @@ -266,12 +318,12 @@ async def call_function( signing_key_fallback=self._signing_key_fallback, ) if isinstance(err, Exception): - return await self._respond(middleware, err) + return await self._respond(err) # Get the function we should call. fn = self._get_function(fn_id) if isinstance(fn, Exception): - return await self._respond(middleware, fn) + return await self._respond(fn) events = call.events steps = call.steps @@ -286,19 +338,16 @@ async def call_function( self._client._get_steps(call.ctx.run_id), ) except Exception as err: - return await self._respond(middleware, err) + return await self._respond(err) if events is None: # Should be unreachable. The Executor should always either send the # batch or tell the SDK to fetch the batch - return await self._respond( - middleware, Exception("events not in request") - ) + return await self._respond(Exception("events not in request")) call_res = await fn.call( self._client, function.Context( - _steps=step_lib.StepMemos.from_raw(steps), attempt=call.ctx.attempt, event=call.event, events=events, @@ -307,16 +356,18 @@ async def call_function( ), fn_id, middleware, + step_lib.StepMemos.from_raw(steps), target_step_id, ) - return await self._respond(middleware, call_res) + return await self._respond(call_res) def call_function_sync( self, *, call: execution.Call, fn_id: str, + raw_request: object, req_sig: net.RequestSignature, target_hashed_id: str, ) -> CommResponse: @@ -327,7 +378,10 @@ def call_function_sync( else: target_step_id = target_hashed_id - middleware = middleware_lib.MiddlewareManager.from_client(self._client) + middleware = middleware_lib.MiddlewareManager.from_client( + self._client, + raw_request, + ) # Validate the request signature. err = req_sig.validate( @@ -335,12 +389,12 @@ def call_function_sync( signing_key_fallback=self._signing_key_fallback, ) if isinstance(err, Exception): - return self._respond_sync(middleware, err) + return self._respond_sync(err) # Get the function we should call. fn = self._get_function(fn_id) if isinstance(fn, Exception): - return self._respond_sync(middleware, fn) + return self._respond_sync(fn) events = call.events steps = call.steps @@ -353,19 +407,16 @@ def call_function_sync( events = self._client._get_batch_sync(call.ctx.run_id) steps = self._client._get_steps_sync(call.ctx.run_id) except Exception as err: - return self._respond_sync(middleware, err) + return self._respond_sync(err) if events is None: # Should be unreachable. The Executor should always either send the # batch or tell the SDK to fetch the batch - return self._respond_sync( - middleware, Exception("events not in request") - ) + return self._respond_sync(Exception("events not in request")) call_res = fn.call_sync( self._client, function.Context( - _steps=step_lib.StepMemos.from_raw(steps), attempt=call.ctx.attempt, event=call.event, events=events, @@ -374,10 +425,11 @@ def call_function_sync( ), fn_id, middleware, + step_lib.StepMemos.from_raw(steps), target_step_id, ) - return self._respond_sync(middleware, call_res) + return self._respond_sync(call_res) def _get_function(self, fn_id: str) -> types.MaybeError[function.Function]: # Look for the function ID in the list of user functions, but also @@ -620,13 +672,8 @@ def register_sync( async def _respond( self, - middleware: middleware_lib.MiddlewareManager, value: typing.Union[execution.CallResult, Exception], ) -> CommResponse: - err = await middleware.before_response() - if isinstance(err, Exception): - return CommResponse.from_error(self._client.logger, err) - if isinstance(value, Exception): return CommResponse.from_error(self._client.logger, value) @@ -634,13 +681,8 @@ async def _respond( def _respond_sync( self, - middleware: middleware_lib.MiddlewareManager, value: typing.Union[execution.CallResult, Exception], ) -> CommResponse: - err = middleware.before_response_sync() - if isinstance(err, Exception): - return CommResponse.from_error(self._client.logger, err) - if isinstance(value, Exception): return CommResponse.from_error(self._client.logger, value) diff --git a/inngest/_internal/errors.py b/inngest/_internal/errors.py index 5117fb5..21376c5 100644 --- a/inngest/_internal/errors.py +++ b/inngest/_internal/errors.py @@ -131,11 +131,9 @@ class NonRetriableError(Error): def __init__( self, message: typing.Optional[str] = None, - cause: typing.Optional[typing.Mapping[str, object]] = None, quiet: bool = False, ) -> None: super().__init__(message) - self.cause = cause self.quiet = quiet @@ -171,6 +169,19 @@ def __init__( self.quiet: bool = quiet +class SendEventsError(Error): + def __init__(self, message: str, ids: list[str]) -> None: + """ + Args: + ---- + message: Error message + ids: List of event IDs that successfully sent + """ + + super().__init__(message) + self.ids = ids + + class StepError(Error): """ Wraps a userland error. This is necessary because the Executor sends @@ -227,6 +238,18 @@ def __init__( self._stack = stack +def is_retriable(err: Exception) -> bool: + if isinstance(err, Error): + return err.is_retriable + return True + + +def is_quiet(err: Exception) -> bool: + if isinstance(err, _Quietable): + return err.quiet + return False + + @typing.runtime_checkable -class Quietable(typing.Protocol): +class _Quietable(typing.Protocol): quiet: bool diff --git a/inngest/_internal/execution.py b/inngest/_internal/execution.py index 2f11c05..7fec227 100644 --- a/inngest/_internal/execution.py +++ b/inngest/_internal/execution.py @@ -1,13 +1,13 @@ from __future__ import annotations -import datetime +import dataclasses import enum import typing import pydantic import typing_extensions -from . import const, errors, event_lib, transforms, types +from . import event_lib, transforms, types class Call(types.BaseModel): @@ -28,70 +28,7 @@ class CallStack(types.BaseModel): stack: list[str] -class CallError(types.BaseModel): - """ - When an error that occurred during a call. Used for both function- and step-level - errors. - """ - - code: const.ErrorCode - is_retriable: bool - message: str - name: str - original_error: object = pydantic.Field(exclude=True) - quiet: bool = pydantic.Field(exclude=True) - retry_after: typing.Optional[datetime.datetime] - stack: typing.Optional[str] - step_id: typing.Optional[str] - - @classmethod - def from_error( - cls, - err: Exception, - step_id: typing.Optional[str] = None, - ) -> CallError: - code = const.ErrorCode.UNKNOWN - if isinstance(err, errors.Error): - code = err.code - is_retriable = err.is_retriable - message = err.message - name = err.name - stack = err.stack - else: - is_retriable = True - message = str(err) - name = type(err).__name__ - stack = transforms.get_traceback(err) - - retry_after = None - if isinstance(err, errors.RetryAfterError): - retry_after = err.retry_after - - quiet = False - if isinstance(err, errors.Quietable): - quiet = err.quiet - - return cls( - code=code, - is_retriable=is_retriable, - message=message, - name=name, - original_error=err, - quiet=quiet, - retry_after=retry_after, - stack=stack, - step_id=step_id, - ) - - -class FunctionCallResponse(types.BaseModel): - """When a function successfully returns.""" - - data: object - - -class StepResponse(types.BaseModel): - data: typing.Optional[Output] = None +class StepInfo(types.BaseModel): display_name: str = pydantic.Field(..., serialization_alias="displayName") id: str @@ -102,6 +39,12 @@ class StepResponse(types.BaseModel): opts: typing.Optional[dict[str, object]] = None +class StepResponse(types.BaseModel): + output: object = None + original_error: object = pydantic.Field(default=None, exclude=True) + step: StepInfo + + class MemoizedError(types.BaseModel): message: str name: str @@ -122,10 +65,7 @@ class Output(types.BaseModel): model_config = pydantic.ConfigDict(extra="forbid") data: object = None - - # TODO: Change the type to MemoizedError. But that requires a breaking - # change, so do it in version 0.4 - error: typing.Optional[dict[str, object]] = None + error: typing.Optional[MemoizedError] = None def is_step_call_responses( @@ -136,9 +76,53 @@ def is_step_call_responses( return all(isinstance(item, StepResponse) for item in value) -CallResult: typing_extensions.TypeAlias = typing.Union[ - list[StepResponse], FunctionCallResponse, CallError -] +@dataclasses.dataclass +class CallResult: + error: typing.Optional[Exception] = None + + # Multiple results from a single call (only used for steps). This will only + # be longer than 1 for parallel steps. Otherwise, it will be 1 long for + # sequential steps + multi: typing.Optional[list[CallResult]] = None + + # Need a sentinel value to differentiate between None and unset + output: object = types.empty_sentinel + + # Step metadata (e.g. user-specified ID) + step: typing.Optional[StepInfo] = None + + @property + def is_empty(self) -> bool: + return all( + [ + self.error is None, + self.multi is None, + self.output is types.empty_sentinel, + self.step is None, + ] + ) + + @classmethod + def from_responses( + cls, + responses: list[StepResponse], + ) -> CallResult: + multi = [] + + for response in responses: + error = None + if isinstance(response.original_error, Exception): + error = response.original_error + + multi.append( + cls( + error=error, + output=response.output, + step=response.step, + ) + ) + + return cls(multi=multi) class Opcode(enum.Enum): diff --git a/inngest/_internal/function.py b/inngest/_internal/function.py index 58ff44e..370ebc4 100644 --- a/inngest/_internal/function.py +++ b/inngest/_internal/function.py @@ -24,10 +24,6 @@ @dataclasses.dataclass class Context: - # TODO: Remove this in v0.4.0. It's only here to avoid a breaking change to - # Middleware.transform_input - _steps: step_lib.StepMemos - attempt: int event: event_lib.Event events: list[event_lib.Event] @@ -132,6 +128,14 @@ def is_on_failure_handler_async(self) -> typing.Optional[bool]: return None return _is_function_handler_async(self._opts.on_failure) + @property + def local_id(self) -> str: + return self._opts.local_id + + @property + def name(self) -> str: + return self._opts.name + @property def on_failure_fn_id(self) -> typing.Optional[str]: return self._on_failure_fn_id @@ -167,24 +171,54 @@ async def call( ctx: Context, fn_id: str, middleware: middleware_lib.MiddlewareManager, + steps: step_lib.StepMemos, target_hashed_id: typing.Optional[str], ) -> execution.CallResult: middleware = middleware_lib.MiddlewareManager.from_manager(middleware) for m in self._middleware: middleware.add(m) + # Move business logic to a private method to make it simpler to run the + # transform_output hook on every call result code path + call_res = await self._call( + client, + ctx, + fn_id, + middleware, + steps, + target_hashed_id, + ) + + err = await middleware.transform_output(call_res) + if isinstance(err, Exception): + return execution.CallResult(err) + + err = await middleware.before_response() + if isinstance(err, Exception): + return execution.CallResult(err) + + return call_res + + async def _call( + self, + client: client_lib.Inngest, + ctx: Context, + fn_id: str, + middleware: middleware_lib.MiddlewareManager, + steps: step_lib.StepMemos, + target_hashed_id: typing.Optional[str], + ) -> execution.CallResult: # Give middleware the opportunity to change some of params passed to the # user's handler. - new_ctx = await middleware.transform_input(ctx) - if isinstance(new_ctx, Exception): - return execution.CallError.from_error(new_ctx) - ctx = new_ctx + middleware_err = await middleware.transform_input(ctx, self, steps) + if isinstance(middleware_err, Exception): + return execution.CallResult(middleware_err) # No memoized data means we're calling the function for the first time. - if ctx._steps.size == 0: + if steps.size == 0: err = await middleware.before_execution() if isinstance(err, Exception): - return execution.CallError.from_error(err) + return execution.CallResult(err) try: handler: typing.Union[FunctionHandlerAsync, FunctionHandlerSync] @@ -192,12 +226,12 @@ async def call( handler = self._handler elif self.on_failure_fn_id == fn_id: if self._opts.on_failure is None: - return execution.CallError.from_error( + return execution.CallResult( errors.FunctionNotFoundError("on_failure not defined") ) handler = self._opts.on_failure else: - return execution.CallError.from_error( + return execution.CallResult( errors.FunctionNotFoundError("function ID mismatch") ) @@ -212,7 +246,7 @@ async def call( ctx=ctx, step=step_lib.Step( client, - ctx._steps, + steps, middleware, step_lib.StepIDCounter(), target_hashed_id, @@ -223,7 +257,7 @@ async def call( ctx=ctx, step=step_lib.StepSync( client, - ctx._steps, + steps, middleware, step_lib.StepIDCounter(), target_hashed_id, @@ -232,7 +266,7 @@ async def call( else: # Should be unreachable but Python's custom type guards don't # support negative checks :( - return execution.CallError.from_error( + return execution.CallResult( errors.UnknownError( "unable to determine function handler type" ) @@ -243,48 +277,29 @@ async def call( err = await middleware.after_execution() if isinstance(err, Exception): - return execution.CallError.from_error(err) - - output = await middleware.transform_output( - # Function output isn't wrapped in an Output object, so we need - # to wrap it to make it compatible with middleware. - execution.Output(data=output) - ) - if isinstance(output, Exception): - return execution.CallError.from_error(output) + return execution.CallResult(err) - if output is None: - return execution.FunctionCallResponse(data=None) - return execution.FunctionCallResponse(data=output.data) + return execution.CallResult(output=output) except step_lib.ResponseInterrupt as interrupt: err = await middleware.after_execution() if isinstance(err, Exception): - return execution.CallError.from_error(err) + return execution.CallResult(err) - # TODO: How should transform_output work with multiple responses? - if len(interrupt.responses) == 1: - output = await middleware.transform_output( - interrupt.responses[0].data - ) - if isinstance(output, Exception): - return execution.CallError.from_error(output) - interrupt.responses[0].data = output - - return interrupt.responses + return execution.CallResult.from_responses(interrupt.responses) except _UserError as err: - return execution.CallError.from_error(err.err) + return execution.CallResult(err.err) except step_lib.SkipInterrupt as err: # This should only happen in a non-deterministic scenario, where # step targeting is enabled and an unexpected step is encountered. # We don't currently have a way to recover from this scenario. - return execution.CallError.from_error( + return execution.CallResult( errors.StepUnexpectedError( f'found step "{err.step_id}" when targeting a different step' ) ) except Exception as err: - return execution.CallError.from_error(err) + return execution.CallResult(err) def call_sync( self, @@ -292,22 +307,54 @@ def call_sync( ctx: Context, fn_id: str, middleware: middleware_lib.MiddlewareManager, + steps: step_lib.StepMemos, target_hashed_id: typing.Optional[str], ) -> execution.CallResult: middleware = middleware_lib.MiddlewareManager.from_manager(middleware) for m in self._middleware: middleware.add(m) + # Move business logic to a private method to make it simpler to run the + # transform_output hook on every call result code path + call_res = self._call_sync( + client, + ctx, + fn_id, + middleware, + steps, + target_hashed_id, + ) + + err = middleware.transform_output_sync(call_res) + if isinstance(err, Exception): + return execution.CallResult(err) + + err = middleware.before_response_sync() + if isinstance(err, Exception): + return execution.CallResult(err) + + return call_res + + def _call_sync( + self, + client: client_lib.Inngest, + ctx: Context, + fn_id: str, + middleware: middleware_lib.MiddlewareManager, + steps: step_lib.StepMemos, + target_hashed_id: typing.Optional[str], + ) -> execution.CallResult: # Give middleware the opportunity to change some of params passed to the # user's handler. - new_ctx = middleware.transform_input_sync(ctx) - if isinstance(new_ctx, Exception): - return execution.CallError.from_error(new_ctx) - ctx = new_ctx + middleware_err = middleware.transform_input_sync(ctx, self, steps) + if isinstance(middleware_err, Exception): + return execution.CallResult(middleware_err) # No memoized data means we're calling the function for the first time. - if ctx._steps.size == 0: - middleware.before_execution_sync() + if steps.size == 0: + err = middleware.before_execution_sync() + if isinstance(err, Exception): + return execution.CallResult(err) try: handler: typing.Union[FunctionHandlerAsync, FunctionHandlerSync] @@ -315,12 +362,12 @@ def call_sync( handler = self._handler elif self.on_failure_fn_id == fn_id: if self._opts.on_failure is None: - return execution.CallError.from_error( + return execution.CallResult( errors.FunctionNotFoundError("on_failure not defined") ) handler = self._opts.on_failure else: - return execution.CallError.from_error( + return execution.CallResult( errors.FunctionNotFoundError("function ID mismatch") ) @@ -330,7 +377,7 @@ def call_sync( ctx=ctx, step=step_lib.StepSync( client, - ctx._steps, + steps, middleware, step_lib.StepIDCounter(), target_hashed_id, @@ -340,7 +387,7 @@ def call_sync( transforms.remove_first_traceback_frame(user_err) raise _UserError(user_err) else: - return execution.CallError.from_error( + return execution.CallResult( errors.AsyncUnsupportedError( "encountered async function in non-async context" ) @@ -348,48 +395,29 @@ def call_sync( err = middleware.after_execution_sync() if isinstance(err, Exception): - return execution.CallError.from_error(err) + return execution.CallResult(err) - output = middleware.transform_output_sync( - # Function output isn't wrapped in an Output object, so we need - # to wrap it to make it compatible with middleware. - execution.Output(data=output) - ) - if isinstance(output, Exception): - return execution.CallError.from_error(output) - - if output is None: - return execution.FunctionCallResponse(data=None) - return execution.FunctionCallResponse(data=output.data) + return execution.CallResult(output=output) except step_lib.ResponseInterrupt as interrupt: err = middleware.after_execution_sync() if isinstance(err, Exception): - return execution.CallError.from_error(err) - - # TODO: How should transform_output work with multiple responses? - if len(interrupt.responses) == 1: - output = middleware.transform_output_sync( - interrupt.responses[0].data - ) - if isinstance(output, Exception): - return execution.CallError.from_error(output) - interrupt.responses[0].data = output + return execution.CallResult(err) - return interrupt.responses + return execution.CallResult.from_responses(interrupt.responses) except _UserError as err: - return execution.CallError.from_error(err.err) + return execution.CallResult(err.err) except step_lib.SkipInterrupt as err: # This should only happen in a non-deterministic scenario, where # step targeting is enabled and an unexpected step is encountered. # We don't currently have a way to recover from this scenario. - return execution.CallError.from_error( + return execution.CallResult( errors.StepUnexpectedError( f'found step "{err.step_id}" when targeting a different step' ) ) except Exception as err: - return execution.CallError.from_error(err) + return execution.CallResult(err) def get_config(self, app_url: str) -> _Config: fn_id = self._opts.fully_qualified_id diff --git a/inngest/_internal/log.py b/inngest/_internal/log.py index a460117..987e57a 100644 --- a/inngest/_internal/log.py +++ b/inngest/_internal/log.py @@ -1,6 +1,6 @@ from __future__ import annotations -from . import client_lib, function, middleware_lib, types +from . import types class LoggerProxy: @@ -35,21 +35,3 @@ def __getattr__(self, name: str) -> object: def enable(self) -> None: self._is_enabled = True - - -class LoggerMiddleware(middleware_lib.MiddlewareSync): - def __init__(self, client: client_lib.Inngest) -> None: - super().__init__(client) - self.logger = LoggerProxy(client.logger) - - def before_execution(self) -> None: - # Enable logging because we've encountered new code. - self.logger.enable() - - def transform_input( - self, - ctx: function.Context, - ) -> function.Context: - self.logger.logger = ctx.logger - ctx.logger = self.logger # type: ignore - return ctx diff --git a/inngest/_internal/middleware_lib/__init__.py b/inngest/_internal/middleware_lib/__init__.py index b4c411d..942f65f 100644 --- a/inngest/_internal/middleware_lib/__init__.py +++ b/inngest/_internal/middleware_lib/__init__.py @@ -1,9 +1,17 @@ from .manager import MiddlewareManager -from .middleware import Middleware, MiddlewareSync, UninitializedMiddleware +from .middleware import ( + Middleware, + MiddlewareSync, + TransformOutputResult, + TransformOutputStepInfo, + UninitializedMiddleware, +) __all__ = [ "Middleware", "MiddlewareManager", "MiddlewareSync", + "TransformOutputResult", + "TransformOutputStepInfo", "UninitializedMiddleware", ] diff --git a/inngest/_internal/middleware_lib/log.py b/inngest/_internal/middleware_lib/log.py index 8fff953..d0ed602 100644 --- a/inngest/_internal/middleware_lib/log.py +++ b/inngest/_internal/middleware_lib/log.py @@ -1,6 +1,6 @@ from __future__ import annotations -from inngest._internal import client_lib, function, types +from inngest._internal import client_lib, function, step_lib, types from .middleware import MiddlewareSync @@ -34,8 +34,8 @@ def enable(self) -> None: class LoggerMiddleware(MiddlewareSync): - def __init__(self, client: client_lib.Inngest) -> None: - super().__init__(client) + def __init__(self, client: client_lib.Inngest, raw_request: object) -> None: + super().__init__(client, raw_request) self.logger = LoggerProxy(client.logger) def before_execution(self) -> None: @@ -44,7 +44,8 @@ def before_execution(self) -> None: def transform_input( self, ctx: function.Context, - ) -> function.Context: + function: function.Function, + steps: step_lib.StepMemos, + ) -> None: self.logger.logger = ctx.logger ctx.logger = self.logger # type: ignore - return ctx diff --git a/inngest/_internal/middleware_lib/manager.py b/inngest/_internal/middleware_lib/manager.py index c6495a1..7b8334f 100644 --- a/inngest/_internal/middleware_lib/manager.py +++ b/inngest/_internal/middleware_lib/manager.py @@ -8,12 +8,19 @@ event_lib, execution, function, + step_lib, transforms, types, ) from .log import LoggerMiddleware -from .middleware import Middleware, MiddlewareSync, UninitializedMiddleware +from .middleware import ( + Middleware, + MiddlewareSync, + TransformOutputResult, + TransformOutputStepInfo, + UninitializedMiddleware, +) if typing.TYPE_CHECKING: from inngest._internal import client_lib @@ -30,18 +37,23 @@ class MiddlewareManager: def middleware(self) -> list[typing.Union[Middleware, MiddlewareSync]]: return [*self._middleware] - def __init__(self, client: client_lib.Inngest) -> None: + def __init__(self, client: client_lib.Inngest, raw_request: object) -> None: self.client = client self._disabled_hooks = set[str]() self._middleware = list[typing.Union[Middleware, MiddlewareSync]]() + self._raw_request = raw_request @classmethod - def from_client(cls, client: client_lib.Inngest) -> MiddlewareManager: + def from_client( + cls, + client: client_lib.Inngest, + raw_request: object, + ) -> MiddlewareManager: """ Create a new manager from an Inngest client, using the middleware on the client. """ - mgr = cls(client) + mgr = cls(client, raw_request) for m in DEFAULT_CLIENT_MIDDLEWARE: mgr.add(m) @@ -57,13 +69,16 @@ def from_manager(cls, manager: MiddlewareManager) -> MiddlewareManager: Create a new manager from another manager, using the middleware on the passed manager. Effectively wraps a manager. """ - new_mgr = cls(manager.client) + new_mgr = cls(manager.client, manager._raw_request) for m in manager.middleware: new_mgr._middleware = [*new_mgr._middleware, m] return new_mgr def add(self, middleware: UninitializedMiddleware) -> None: - self._middleware = [*self._middleware, middleware(self.client)] + self._middleware = [ + *self._middleware, + middleware(self.client, self._raw_request), + ] async def after_execution(self) -> types.MaybeError[None]: try: @@ -83,40 +98,82 @@ def after_execution_sync(self) -> types.MaybeError[None]: except Exception as err: return err + async def after_send_events( + self, + result: client_lib.SendEventsResult, + ) -> types.MaybeError[None]: + try: + for m in self._middleware: + await transforms.maybe_await(m.after_send_events(result)) + return None + except Exception as err: + return err + + def after_send_events_sync( + self, + result: client_lib.SendEventsResult, + ) -> types.MaybeError[None]: + try: + for m in self._middleware: + if inspect.iscoroutinefunction(m.after_execution): + return _mismatched_sync + m.after_send_events(result) + return None + except Exception as err: + return err + async def before_execution(self) -> types.MaybeError[None]: hook = "before_execution" if hook in self._disabled_hooks: # Only allow before_execution to be called once. This simplifies # code since execution can start at the function or step level. return None + self._disabled_hooks.add(hook) + # Also handle after_memoization here since it's always called + # immediately before before_execution try: for m in self._middleware: - await transforms.maybe_await(m.before_execution()) + await transforms.maybe_await(m.after_memoization()) + except Exception as err: + return err - self._disabled_hooks.add(hook) - return None + try: + for m in self._middleware: + await transforms.maybe_await(m.before_execution()) except Exception as err: return err + return None + def before_execution_sync(self) -> types.MaybeError[None]: hook = "before_execution" if hook in self._disabled_hooks: # Only allow before_execution to be called once. This simplifies # code since execution can start at the function or step level. return None + self._disabled_hooks.add(hook) + try: + for m in self._middleware: + if inspect.iscoroutinefunction(m.after_memoization): + return _mismatched_sync + m.after_memoization() + except Exception as err: + return err + + # Also handle after_memoization here since it's always called + # immediately before before_execution try: for m in self._middleware: if inspect.iscoroutinefunction(m.before_execution): return _mismatched_sync m.before_execution() - - self._disabled_hooks.add(hook) - return None except Exception as err: return err + return None + async def before_response(self) -> types.MaybeError[None]: try: for m in self._middleware: @@ -162,59 +219,133 @@ def before_send_events_sync( async def transform_input( self, ctx: function.Context, - ) -> types.MaybeError[function.Context]: + function: function.Function, + steps: step_lib.StepMemos, + ) -> types.MaybeError[None]: try: for m in self._middleware: - ctx = await transforms.maybe_await( - m.transform_input(ctx), + await transforms.maybe_await( + m.transform_input(ctx, function, steps), ) - return ctx except Exception as err: return err + # Also handle before_memoization here since it's always called + # immediately after transform_input + try: + for m in self._middleware: + await transforms.maybe_await(m.before_memoization()) + except Exception as err: + return err + + return None + def transform_input_sync( self, ctx: function.Context, - ) -> types.MaybeError[function.Context]: + function: function.Function, + steps: step_lib.StepMemos, + ) -> types.MaybeError[None]: try: for m in self._middleware: - if isinstance(m, Middleware): + if inspect.iscoroutinefunction(m.transform_input): return _mismatched_sync - ctx = m.transform_input(ctx) - return ctx + m.transform_input(ctx, function, steps) except Exception as err: return err + # Also handle before_memoization here since it's always called + # immediately after transform_input + try: + for m in self._middleware: + if inspect.iscoroutinefunction(m.before_memoization): + return _mismatched_sync + m.before_memoization() + except Exception as err: + return err + + return None + async def transform_output( self, - output: typing.Optional[execution.Output], - ) -> types.MaybeError[typing.Optional[execution.Output]]: - # Nothing to transform - if output is None: + call_res: execution.CallResult, + ) -> types.MaybeError[None]: + # This should only happen when planning parallel steps + if call_res.multi is not None: + if len(call_res.multi) > 1: + return None + call_res = call_res.multi[0] + + # Not sure how this can happen, but we should handle it + if call_res.is_empty: return None + # Create a new result object to pass to the middleware. We don't want to + # pass the CallResult object because it exposes too many internal + # implementation details + result = TransformOutputResult( + error=call_res.error, + output=call_res.output, + step=None, + ) + if call_res.step is not None: + result.step = TransformOutputStepInfo( + id=call_res.step.display_name, + op=call_res.step.op, + opts=call_res.step.opts, + ) + try: for m in self._middleware: - output = await transforms.maybe_await( - m.transform_output(output) - ) - return output + await transforms.maybe_await(m.transform_output(result)) + + # Update the original call result with the (possibly) mutated fields + call_res.error = result.error + call_res.output = result.output + + return None except Exception as err: return err def transform_output_sync( self, - output: typing.Optional[execution.Output], - ) -> types.MaybeError[typing.Optional[execution.Output]]: - # Nothing to transform - if output is None: + call_res: execution.CallResult, + ) -> types.MaybeError[None]: + # This should only happen when planning parallel steps + if call_res.multi is not None: + if len(call_res.multi) > 1: + return None + call_res = call_res.multi[0] + + # Not sure how this can happen, but we should handle it + if call_res.is_empty: return None + # Create a new result object to pass to the middleware. We don't want to + # pass the CallResult object because it exposes too many internal + # implementation details + result = TransformOutputResult( + error=call_res.error, + output=call_res.output, + step=None, + ) + if call_res.step is not None: + result.step = TransformOutputStepInfo( + id=call_res.step.display_name, + op=call_res.step.op, + opts=call_res.step.opts, + ) + try: for m in self._middleware: if isinstance(m, Middleware): return _mismatched_sync - output = m.transform_output(output) - return output + m.transform_output(result) + + # Update the original call result with the (possibly) mutated fields + call_res.error = result.error + call_res.output = result.output + + return None except Exception as err: return err diff --git a/inngest/_internal/middleware_lib/middleware.py b/inngest/_internal/middleware_lib/middleware.py index b32c467..97395d5 100644 --- a/inngest/_internal/middleware_lib/middleware.py +++ b/inngest/_internal/middleware_lib/middleware.py @@ -1,16 +1,25 @@ from __future__ import annotations +import dataclasses import typing -from inngest._internal import event_lib, execution, function +from inngest._internal import event_lib, execution, function, step_lib if typing.TYPE_CHECKING: from inngest._internal import client_lib class Middleware: - def __init__(self, client: client_lib.Inngest) -> None: - self._client = client + def __init__(self, client: client_lib.Inngest, raw_request: object) -> None: + """ + Args: + ---- + client: Inngest client. + raw_request: Framework/platform specific request object. + """ + + self.client = client + self.raw_request = raw_request async def after_execution(self) -> None: """ @@ -19,6 +28,22 @@ async def after_execution(self) -> None: """ return None + async def after_memoization(self) -> None: + """ + After exhausting memoized step data. Always called immediately before + before_execution. + """ + return None + + async def after_send_events( + self, + result: client_lib.SendEventsResult, + ) -> None: + """ + After sending events. + """ + return None + async def before_execution(self) -> None: """ Before executing new code. Called multiple times per run when using @@ -26,12 +51,19 @@ async def before_execution(self) -> None: """ return None + async def before_memoization(self) -> None: + """ + Before checking memoized step data. Always called immediately after + transform_input. + """ + return None + async def before_response(self) -> None: """ After the output has been set and before the response is sent back to Inngest. This is where you can perform any final actions before the response is sent back to Inngest. Called multiple times per run when - using steps. Not called for function middleware. + using steps. """ return None @@ -44,30 +76,37 @@ async def before_send_events(self, events: list[event_lib.Event]) -> None: async def transform_input( self, ctx: function.Context, - ) -> function.Context: + function: function.Function, + steps: step_lib.StepMemos, + ) -> None: """ Before calling a function or step. Used to replace certain arguments in the function. Called multiple times per run when using steps. """ - return ctx + return None - async def transform_output( - self, - output: execution.Output, - ) -> execution.Output: + async def transform_output(self, result: TransformOutputResult) -> None: """ After a function or step returns. Used to modify the returned data. Called multiple times per run when using steps. Not called when an error is thrown. """ - return output + return None class MiddlewareSync: client: client_lib.Inngest - def __init__(self, client: client_lib.Inngest) -> None: + def __init__(self, client: client_lib.Inngest, raw_request: object) -> None: + """ + Args: + ---- + client: Inngest client. + raw_request: Framework/platform specific request object. + """ + self.client = client + self.raw_request = raw_request def after_execution(self) -> None: """ @@ -76,6 +115,22 @@ def after_execution(self) -> None: """ return None + def after_memoization(self) -> None: + """ + After exhausting memoized step data. Always called immediately before + before_execution. + """ + return None + + def after_send_events( + self, + result: client_lib.SendEventsResult, + ) -> None: + """ + After sending events. + """ + return None + def before_execution(self) -> None: """ Before executing new code. Called multiple times per run when using @@ -83,12 +138,19 @@ def before_execution(self) -> None: """ return None + def before_memoization(self) -> None: + """ + Before checking memoized step data. Always called immediately after + transform_input. + """ + return None + def before_response(self) -> None: """ After the output has been set and before the response is sent back to Inngest. This is where you can perform any final actions before the response is sent back to Inngest. Called multiple times per run when - using steps. Not called for function middleware. + using steps. """ return None @@ -101,26 +163,44 @@ def before_send_events(self, events: list[event_lib.Event]) -> None: def transform_input( self, ctx: function.Context, - ) -> function.Context: + function: function.Function, + steps: step_lib.StepMemos, + ) -> None: """ Before calling a function or step. Used to replace certain arguments in the function. Called multiple times per run when using steps. """ - return ctx + return None - def transform_output( - self, - output: execution.Output, - ) -> execution.Output: + def transform_output(self, result: TransformOutputResult) -> None: """ After a function or step returns. Used to modify the returned data. Called multiple times per run when using steps. Not called when an error is thrown. """ - return output + return None UninitializedMiddleware = typing.Callable[ # Used a "client_lib.Inngest" string to avoid a circular import - ["client_lib.Inngest"], typing.Union[Middleware, MiddlewareSync] + ["client_lib.Inngest", object], typing.Union[Middleware, MiddlewareSync] ] + + +@dataclasses.dataclass +class TransformOutputResult: + # Mutations to these fields within middleware will be kept after running + # middleware + error: typing.Optional[Exception] + output: object + + # Mutations to these fields within middleware will be discarded after + # running middleware + step: typing.Optional[TransformOutputStepInfo] + + +@dataclasses.dataclass +class TransformOutputStepInfo: + id: str + op: execution.Opcode + opts: typing.Optional[dict[str, object]] diff --git a/inngest/_internal/step_lib/__init__.py b/inngest/_internal/step_lib/__init__.py index 24ce57c..3f828df 100644 --- a/inngest/_internal/step_lib/__init__.py +++ b/inngest/_internal/step_lib/__init__.py @@ -1,15 +1,8 @@ -from .base import ( - FunctionID, - ResponseInterrupt, - SkipInterrupt, - StepIDCounter, - StepMemos, -) +from .base import ResponseInterrupt, SkipInterrupt, StepIDCounter, StepMemos from .step_async import Step from .step_sync import StepSync __all__ = [ - "FunctionID", "ResponseInterrupt", "SkipInterrupt", "Step", diff --git a/inngest/_internal/step_lib/base.py b/inngest/_internal/step_lib/base.py index c478935..07d1208 100644 --- a/inngest/_internal/step_lib/base.py +++ b/inngest/_internal/step_lib/base.py @@ -88,16 +88,12 @@ async def _get_memo( if not isinstance(memo, types.EmptySentinel): if memo.error is not None: - error = execution.MemoizedError.from_raw(memo.error) - if isinstance(error, Exception): - raise error - # If there's a memoized error then raise an error, since the # step exhausted its retries raise errors.StepError( - message=error.message, - name=error.name, - stack=error.stack, + message=memo.error.message, + name=memo.error.name, + stack=memo.error.stack, ) return memo @@ -114,16 +110,12 @@ def _get_memo_sync( if not isinstance(memo, types.EmptySentinel): if memo.error is not None: - error = execution.MemoizedError.from_raw(memo.error) - if isinstance(error, Exception): - raise error - # If there's a memoized error then raise an error, since the # step exhausted its retries raise errors.StepError( - message=error.message, - name=error.name, - stack=error.stack, + message=memo.error.message, + name=memo.error.name, + stack=memo.error.stack, ) return memo @@ -207,12 +199,6 @@ def __init__(self, step_id: str) -> None: self.step_id = step_id -@dataclasses.dataclass -class FunctionID: - app_id: str - function_id: str - - class InvokeOpts(types.BaseModel): function_id: str payload: InvokeOptsPayload diff --git a/inngest/_internal/step_lib/step_async.py b/inngest/_internal/step_lib/step_async.py index 498a7ed..8cbe9a4 100644 --- a/inngest/_internal/step_lib/step_async.py +++ b/inngest/_internal/step_lib/step_async.py @@ -4,6 +4,7 @@ import typing_extensions from inngest._internal import errors, event_lib, execution, transforms, types +from inngest._internal.client_lib import models as client_models from . import base @@ -21,7 +22,7 @@ async def invoke( *, function: Function, data: typing.Optional[types.JSON] = None, - timeout: typing.Union[int, datetime.timedelta, None] = None, + timeout: typing.Union[int, datetime.timedelta], user: typing.Optional[types.JSON] = None, v: typing.Optional[str] = None, ) -> object: @@ -60,7 +61,7 @@ async def invoke_by_id( app_id: typing.Optional[str] = None, function_id: str, data: typing.Optional[types.JSON] = None, - timeout: typing.Union[int, datetime.timedelta, None] = None, + timeout: typing.Union[int, datetime.timedelta], user: typing.Optional[types.JSON] = None, v: typing.Optional[str] = None, ) -> object: @@ -98,7 +99,7 @@ async def invoke_by_id( if isinstance(err, Exception): raise err - timeout_str = transforms.to_maybe_duration_str(timeout) + timeout_str = transforms.to_duration_str(timeout) if isinstance(timeout_str, Exception): raise timeout_str @@ -116,11 +117,13 @@ async def invoke_by_id( raise base.ResponseInterrupt( execution.StepResponse( - display_name=parsed_step_id.user_facing, - id=parsed_step_id.hashed, - name=parsed_step_id.user_facing, - op=execution.Opcode.INVOKE, - opts=opts, + step=execution.StepInfo( + display_name=parsed_step_id.user_facing, + id=parsed_step_id.hashed, + name=parsed_step_id.user_facing, + op=execution.Opcode.INVOKE, + opts=opts, + ) ) ) @@ -215,10 +218,12 @@ async def run( # Plan this step because we're in parallel mode. raise base.ResponseInterrupt( execution.StepResponse( - display_name=parsed_step_id.user_facing, - id=parsed_step_id.hashed, - name=parsed_step_id.user_facing, - op=execution.Opcode.PLANNED, + step=execution.StepInfo( + display_name=parsed_step_id.user_facing, + id=parsed_step_id.hashed, + name=parsed_step_id.user_facing, + op=execution.Opcode.PLANNED, + ) ) ) @@ -231,11 +236,13 @@ async def run( raise base.ResponseInterrupt( execution.StepResponse( - data=execution.Output(data=output), - display_name=parsed_step_id.user_facing, - id=parsed_step_id.hashed, - name=parsed_step_id.user_facing, - op=execution.Opcode.STEP_RUN, + output=output, + step=execution.StepInfo( + display_name=parsed_step_id.user_facing, + id=parsed_step_id.hashed, + name=parsed_step_id.user_facing, + op=execution.Opcode.STEP_RUN, + ), ) ) except (errors.NonRetriableError, errors.RetryAfterError) as err: @@ -244,16 +251,14 @@ async def run( except Exception as err: transforms.remove_first_traceback_frame(err) - error_dict = execution.MemoizedError.from_error(err).to_dict() - if isinstance(error_dict, Exception): - raise error_dict - raise base.ResponseInterrupt( execution.StepResponse( - data=execution.Output(error=error_dict), - display_name=parsed_step_id.user_facing, - id=parsed_step_id.hashed, - op=execution.Opcode.STEP_ERROR, + original_error=err, + step=execution.StepInfo( + display_name=parsed_step_id.user_facing, + id=parsed_step_id.hashed, + op=execution.Opcode.STEP_ERROR, + ), ) ) @@ -277,13 +282,35 @@ async def fn() -> list[str]: else: _events = [events] - await self._middleware.before_send_events(_events) - return await self._client.send( - events, - # Skip middleware since we're already running it above. Without - # this, we'll double-call middleware hooks - skip_middleware=True, - ) + middleware_err = await self._middleware.before_send_events(_events) + if isinstance(middleware_err, Exception): + raise middleware_err + + try: + result = client_models.SendEventsResult( + ids=( + await self._client.send( + events, + # Skip middleware since we're already running it above. Without + # this, we'll double-call middleware hooks + skip_middleware=True, + ) + ) + ) + except errors.SendEventsError as err: + result = client_models.SendEventsResult( + error=str(err), + ids=err.ids, + ) + raise err + finally: + middleware_err = await self._middleware.after_send_events( + result + ) + if isinstance(middleware_err, Exception): + raise middleware_err + + return result.ids return await self.run(step_id, fn) @@ -337,10 +364,12 @@ async def sleep_until( raise base.ResponseInterrupt( execution.StepResponse( - display_name=parsed_step_id.user_facing, - id=parsed_step_id.hashed, - name=transforms.to_iso_utc(until), - op=execution.Opcode.SLEEP, + step=execution.StepInfo( + display_name=parsed_step_id.user_facing, + id=parsed_step_id.hashed, + name=transforms.to_iso_utc(until), + op=execution.Opcode.SLEEP, + ) ) ) @@ -393,10 +422,12 @@ async def wait_for_event( raise base.ResponseInterrupt( execution.StepResponse( - id=parsed_step_id.hashed, - display_name=parsed_step_id.user_facing, - name=event, - op=execution.Opcode.WAIT_FOR_EVENT, - opts=opts, + step=execution.StepInfo( + id=parsed_step_id.hashed, + display_name=parsed_step_id.user_facing, + name=event, + op=execution.Opcode.WAIT_FOR_EVENT, + opts=opts, + ) ) ) diff --git a/inngest/_internal/step_lib/step_sync.py b/inngest/_internal/step_lib/step_sync.py index 100df6f..15b6f6b 100644 --- a/inngest/_internal/step_lib/step_sync.py +++ b/inngest/_internal/step_lib/step_sync.py @@ -4,6 +4,7 @@ import typing_extensions from inngest._internal import errors, event_lib, execution, transforms, types +from inngest._internal.client_lib import models as client_models from . import base @@ -21,7 +22,7 @@ def invoke( *, function: Function, data: typing.Optional[types.JSON] = None, - timeout: typing.Union[int, datetime.timedelta, None] = None, + timeout: typing.Union[int, datetime.timedelta], user: typing.Optional[types.JSON] = None, v: typing.Optional[str] = None, ) -> object: @@ -60,7 +61,7 @@ def invoke_by_id( app_id: typing.Optional[str] = None, function_id: str, data: typing.Optional[types.JSON] = None, - timeout: typing.Union[int, datetime.timedelta, None] = None, + timeout: typing.Union[int, datetime.timedelta], user: typing.Optional[types.JSON] = None, v: typing.Optional[str] = None, ) -> object: @@ -98,7 +99,7 @@ def invoke_by_id( if isinstance(err, Exception): raise err - timeout_str = transforms.to_maybe_duration_str(timeout) + timeout_str = transforms.to_duration_str(timeout) if isinstance(timeout_str, Exception): raise timeout_str @@ -116,11 +117,13 @@ def invoke_by_id( raise base.ResponseInterrupt( execution.StepResponse( - display_name=parsed_step_id.user_facing, - id=parsed_step_id.hashed, - name=parsed_step_id.user_facing, - op=execution.Opcode.INVOKE, - opts=opts, + step=execution.StepInfo( + display_name=parsed_step_id.user_facing, + id=parsed_step_id.hashed, + name=parsed_step_id.user_facing, + op=execution.Opcode.INVOKE, + opts=opts, + ) ) ) @@ -187,10 +190,12 @@ def run( # Plan this step because we're in parallel mode. raise base.ResponseInterrupt( execution.StepResponse( - display_name=parsed_step_id.user_facing, - id=parsed_step_id.hashed, - name=parsed_step_id.user_facing, - op=execution.Opcode.PLANNED, + step=execution.StepInfo( + display_name=parsed_step_id.user_facing, + id=parsed_step_id.hashed, + name=parsed_step_id.user_facing, + op=execution.Opcode.PLANNED, + ) ) ) @@ -203,11 +208,13 @@ def run( raise base.ResponseInterrupt( execution.StepResponse( - data=execution.Output(data=output), - display_name=parsed_step_id.user_facing, - id=parsed_step_id.hashed, - name=parsed_step_id.user_facing, - op=execution.Opcode.STEP_RUN, + output=output, + step=execution.StepInfo( + display_name=parsed_step_id.user_facing, + id=parsed_step_id.hashed, + name=parsed_step_id.user_facing, + op=execution.Opcode.STEP_RUN, + ), ) ) except (errors.NonRetriableError, errors.RetryAfterError) as err: @@ -216,16 +223,14 @@ def run( except Exception as err: transforms.remove_first_traceback_frame(err) - error_dict = execution.MemoizedError.from_error(err).to_dict() - if isinstance(error_dict, Exception): - raise error_dict - raise base.ResponseInterrupt( execution.StepResponse( - display_name=parsed_step_id.user_facing, - data=execution.Output(error=error_dict), - id=parsed_step_id.hashed, - op=execution.Opcode.STEP_ERROR, + original_error=err, + step=execution.StepInfo( + display_name=parsed_step_id.user_facing, + id=parsed_step_id.hashed, + op=execution.Opcode.STEP_ERROR, + ), ) ) @@ -249,13 +254,31 @@ def fn() -> list[str]: else: _events = [events] - self._middleware.before_send_events_sync(_events) - return self._client.send_sync( - events, - # Skip middleware since we're already running it above. Without - # this, we'll double-call middleware hooks - skip_middleware=True, - ) + middleware_err = self._middleware.before_send_events_sync(_events) + if isinstance(middleware_err, Exception): + raise middleware_err + + try: + result = client_models.SendEventsResult( + ids=self._client.send_sync( + events, + # Skip middleware since we're already running it above. Without + # this, we'll double-call middleware hooks + skip_middleware=True, + ) + ) + except errors.SendEventsError as err: + result = client_models.SendEventsResult( + error=str(err), + ids=err.ids, + ) + raise err + finally: + middleware_err = self._middleware.after_send_events_sync(result) + if isinstance(middleware_err, Exception): + raise middleware_err + + return result.ids return self.run(step_id, fn) @@ -309,10 +332,12 @@ def sleep_until( raise base.ResponseInterrupt( execution.StepResponse( - display_name=parsed_step_id.user_facing, - id=parsed_step_id.hashed, - name=transforms.to_iso_utc(until), - op=execution.Opcode.SLEEP, + step=execution.StepInfo( + display_name=parsed_step_id.user_facing, + id=parsed_step_id.hashed, + name=transforms.to_iso_utc(until), + op=execution.Opcode.SLEEP, + ) ) ) @@ -368,10 +393,12 @@ def wait_for_event( raise base.ResponseInterrupt( execution.StepResponse( - display_name=parsed_step_id.user_facing, - id=parsed_step_id.hashed, - name=event, - op=execution.Opcode.WAIT_FOR_EVENT, - opts=opts, + step=execution.StepInfo( + display_name=parsed_step_id.user_facing, + id=parsed_step_id.hashed, + name=event, + op=execution.Opcode.WAIT_FOR_EVENT, + opts=opts, + ) ) ) diff --git a/inngest/_internal/transforms.py b/inngest/_internal/transforms.py index c055872..66f4f4c 100644 --- a/inngest/_internal/transforms.py +++ b/inngest/_internal/transforms.py @@ -45,17 +45,15 @@ def remove_signing_key_prefix(key: str) -> str: return key[len(prefix) :] -def prep_body(obj: types.T) -> types.T: +def deep_strip_none(obj: types.T) -> types.T: """ - Prep body before sending to the Inngest server. This function will: - - Remove items whose value is None. - - Convert keys to camelCase. + Recursively remove items whose value is None. """ if isinstance(obj, dict): - return {k: prep_body(v) for k, v in obj.items() if v is not None} # type: ignore + return {k: deep_strip_none(v) for k, v in obj.items() if v is not None} # type: ignore if isinstance(obj, list): - return [prep_body(v) for v in obj if v is not None] # type: ignore + return [deep_strip_none(v) for v in obj if v is not None] # type: ignore return obj diff --git a/inngest/digital_ocean.py b/inngest/digital_ocean.py index bedcc13..e235726 100644 --- a/inngest/digital_ocean.py +++ b/inngest/digital_ocean.py @@ -133,6 +133,10 @@ def main(event: dict[str, object], context: _Context) -> _Response: handler.call_function_sync( call=call, fn_id=fn_id, + raw_request={ + "context": context, + "event": event, + }, req_sig=req_sig, target_hashed_id=step_id, ), diff --git a/inngest/django.py b/inngest/django.py index 0b78201..c594e0a 100644 --- a/inngest/django.py +++ b/inngest/django.py @@ -29,7 +29,6 @@ def serve( client: client_lib.Inngest, functions: list[function.Function], *, - async_mode: typing.Optional[bool] = None, serve_origin: typing.Optional[str] = None, serve_path: typing.Optional[str] = None, ) -> django.urls.URLPattern: @@ -53,12 +52,10 @@ def serve( functions=functions, ) - # TODO: Remove async_mode kwarg in v0.4.0 - if async_mode is None: - async_mode = any( - function.is_handler_async or function.is_on_failure_handler_async - for function in functions - ) + async_mode = any( + function.is_handler_async or function.is_on_failure_handler_async + for function in functions + ) if async_mode: return _create_handler_async( @@ -137,6 +134,7 @@ def inngest_api( handler.call_function_sync( call=call, fn_id=fn_id, + raw_request=request, req_sig=req_sig, target_hashed_id=step_id, ), @@ -242,6 +240,7 @@ async def inngest_api( await handler.call_function( call=call, fn_id=fn_id, + raw_request=request, req_sig=req_sig, target_hashed_id=step_id, ), diff --git a/inngest/experimental/encryption_middleware.py b/inngest/experimental/encryption_middleware.py index 40a1371..8922c80 100644 --- a/inngest/experimental/encryption_middleware.py +++ b/inngest/experimental/encryption_middleware.py @@ -27,16 +27,18 @@ class EncryptionMiddleware(inngest.MiddlewareSync): def __init__( self, client: inngest.Inngest, + raw_request: object, secret_key: typing.Union[bytes, str], ) -> None: """ Args: ---- client: Inngest client. + raw_request: Framework/platform specific request object. secret_key: Fernet secret key used for encryption and decryption. """ - super().__init__(client) + super().__init__(client, raw_request) if isinstance(secret_key, str): secret_key = bytes.fromhex(secret_key) @@ -47,7 +49,7 @@ def __init__( def factory( cls, secret_key: typing.Union[bytes, str], - ) -> typing.Callable[[inngest.Inngest], EncryptionMiddleware]: + ) -> typing.Callable[[inngest.Inngest, object], EncryptionMiddleware]: """ Create an encryption middleware factory that can be passed to an Inngest client or function. @@ -57,8 +59,11 @@ def factory( secret_key: Fernet secret key used for encryption and decryption. """ - def _factory(client: inngest.Inngest) -> EncryptionMiddleware: - return cls(client, secret_key) + def _factory( + client: inngest.Inngest, + raw_request: object, + ) -> EncryptionMiddleware: + return cls(client, raw_request, secret_key) return _factory @@ -114,12 +119,17 @@ def before_send_events(self, events: list[inngest.Event]) -> None: for event in events: event.data = self._encrypt(event.data) - def transform_input(self, ctx: inngest.Context) -> inngest.Context: + def transform_input( + self, + ctx: inngest.Context, + function: inngest.Function, + steps: inngest.StepMemos, + ) -> None: """ Decrypt data from the Inngest server. """ - for step in ctx._steps.values(): + for step in steps.values(): step.data = self._decrypt(step.data) ctx.event.data = self._decrypt_event_data(ctx.event.data) @@ -127,15 +137,13 @@ def transform_input(self, ctx: inngest.Context) -> inngest.Context: for event in ctx.events: event.data = self._decrypt_event_data(event.data) - return ctx - - def transform_output(self, output: inngest.Output) -> inngest.Output: + def transform_output(self, result: inngest.TransformOutputResult) -> None: """ Encrypt data before sending it to the Inngest server. """ - output.data = self._encrypt(output.data) - return output + if result.output is not None: + result.output = self._encrypt(result.output) def _is_encrypted(value: object) -> bool: diff --git a/inngest/experimental/sentry_middleware.py b/inngest/experimental/sentry_middleware.py new file mode 100644 index 0000000..80d77b2 --- /dev/null +++ b/inngest/experimental/sentry_middleware.py @@ -0,0 +1,56 @@ +""" +Sentry middleware for Inngest. + +NOT STABLE! This is an experimental feature and may change in the future. +""" + +from __future__ import annotations + +import sentry_sdk + +import inngest + + +class SentryMiddleware(inngest.MiddlewareSync): + """ + Middleware that adds Sentry tags and captures exceptions. + """ + + def __init__( + self, + client: inngest.Inngest, + raw_request: object, + ) -> None: + """ + Args: + ---- + client: Inngest client. + raw_request: Framework/platform specific request object. + """ + + super().__init__(client, raw_request) + + if sentry_sdk.is_initialized() is False: + client.logger.warning("Sentry SDK is not initialized") + + sentry_sdk.set_tag("inngest.app.id", client.app_id) + + def before_response(self) -> None: + sentry_sdk.flush() + + def transform_input( + self, + ctx: inngest.Context, + function: inngest.Function, + steps: inngest.StepMemos, + ) -> None: + sentry_sdk.set_tag("inngest.event.count", len(ctx.events)) + sentry_sdk.set_tag("inngest.event.id", ctx.event.id) + sentry_sdk.set_tag("inngest.event.name", ctx.event.name) + sentry_sdk.set_tag("inngest.function.id", function.local_id) + sentry_sdk.set_tag("inngest.function.name", function.name) + sentry_sdk.set_tag("inngest.run.id", ctx.run_id) + + def transform_output(self, output: inngest.TransformOutputResult) -> None: + if output.error: + sentry_sdk.capture_exception(output.error) diff --git a/inngest/fast_api.py b/inngest/fast_api.py index 1a26b3d..e918daf 100644 --- a/inngest/fast_api.py +++ b/inngest/fast_api.py @@ -100,6 +100,7 @@ async def post_inngest_api( await handler.call_function( call=call, fn_id=fnId, + raw_request=request, req_sig=net.RequestSignature( body=body, headers=headers, diff --git a/inngest/flask.py b/inngest/flask.py index f24fe4d..0dca1c5 100644 --- a/inngest/flask.py +++ b/inngest/flask.py @@ -132,6 +132,7 @@ async def inngest_api() -> typing.Union[flask.Response, str]: await handler.call_function( call=call, fn_id=fn_id, + raw_request=flask.request, req_sig=req_sig, target_hashed_id=step_id, ), @@ -222,6 +223,7 @@ def inngest_api() -> typing.Union[flask.Response, str]: handler.call_function_sync( call=call, fn_id=fn_id, + raw_request=flask.request, req_sig=req_sig, target_hashed_id=step_id, ), diff --git a/inngest/tornado.py b/inngest/tornado.py index 5616580..9e3860a 100644 --- a/inngest/tornado.py +++ b/inngest/tornado.py @@ -119,6 +119,7 @@ def post(self) -> None: comm_res = handler.call_function_sync( call=call, fn_id=fn_id, + raw_request=self.request, req_sig=req_sig, target_hashed_id=step_id, ) diff --git a/pyproject.toml b/pyproject.toml index e43048d..053b30c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,6 +32,7 @@ extra = [ "pytest-django==4.7.0", "pytest-xdist[psutil]==3.3.1", "ruff==0.1.9", + "sentry-sdk==2.1.1", "toml==0.10.2", "tornado==6.3", "types-toml==0.10.8.7", @@ -109,6 +110,7 @@ mccabe = { max-complexity = 21 } [tool.ruff.extend-per-file-ignores] "examples/**/*.py" = ['D', 'T20'] "inngest/**/*_test.py" = ['C901', 'D', 'N', 'S', 'T20'] +"inngest/experimental/**/*.py" = ['D102', 'D417'] "tests/**/*.py" = ['C901', 'D', 'N', 'S', 'T20'] [tool.ruff.lint] diff --git a/tests/test_client/__init__.py b/tests/test_client/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_client/test_client_middleware.py b/tests/test_client/test_client_middleware.py index 4f2a290..e96fc9c 100644 --- a/tests/test_client/test_client_middleware.py +++ b/tests/test_client/test_client_middleware.py @@ -42,21 +42,29 @@ def assert_state( # Assert that the middleware hooks were called in the correct order assert state.hook_list == [ "before_send_events", + "after_send_events", # Entry 1 "transform_input", + "before_memoization", + "after_memoization", "before_execution", "after_execution", "transform_output", "before_response", # Entry 2 "transform_input", + "before_memoization", + "after_memoization", "before_execution", "before_send_events", + "after_send_events", "after_execution", "transform_output", "before_response", # Entry 3 "transform_input", + "before_memoization", + "after_memoization", "before_execution", "after_execution", "transform_output", @@ -102,12 +110,24 @@ class Middleware(inngest.Middleware): async def after_execution(self) -> None: state.hook_list.append("after_execution") - async def before_response(self) -> None: - state.hook_list.append("before_response") + async def after_memoization(self) -> None: + state.hook_list.append("after_memoization") + + async def after_send_events( + self, + result: inngest.SendEventsResult, + ) -> None: + state.hook_list.append("after_send_events") async def before_execution(self) -> None: state.hook_list.append("before_execution") + async def before_memoization(self) -> None: + state.hook_list.append("before_memoization") + + async def before_response(self) -> None: + state.hook_list.append("before_response") + async def before_send_events( self, events: list[inngest.Event], @@ -117,18 +137,18 @@ async def before_send_events( async def transform_input( self, ctx: inngest.Context, - ) -> inngest.Context: + function: inngest.Function, + steps: inngest.StepMemos, + ) -> None: state.hook_list.append("transform_input") - return ctx async def transform_output( self, - output: inngest.Output, - ) -> inngest.Output: + result: inngest.TransformOutputResult, + ) -> None: state.hook_list.append("transform_output") - if output.data == "original output": - output.data = "transformed output" - return output + if result.output == "original output": + result.output = "transformed output" client = inngest.Inngest( api_base_url=dev_server.origin, @@ -191,12 +211,24 @@ class Middleware(inngest.MiddlewareSync): def after_execution(self) -> None: state.hook_list.append("after_execution") - def before_response(self) -> None: - state.hook_list.append("before_response") + def after_memoization(self) -> None: + state.hook_list.append("after_memoization") + + def after_send_events( + self, + result: inngest.SendEventsResult, + ) -> None: + state.hook_list.append("after_send_events") def before_execution(self) -> None: state.hook_list.append("before_execution") + def before_memoization(self) -> None: + state.hook_list.append("before_memoization") + + def before_response(self) -> None: + state.hook_list.append("before_response") + def before_send_events( self, events: list[inngest.Event], @@ -206,18 +238,18 @@ def before_send_events( def transform_input( self, ctx: inngest.Context, - ) -> inngest.Context: + function: inngest.Function, + steps: inngest.StepMemos, + ) -> None: state.hook_list.append("transform_input") - return ctx def transform_output( self, - output: inngest.Output, - ) -> inngest.Output: + result: inngest.TransformOutputResult, + ) -> None: state.hook_list.append("transform_output") - if output.data == "original output": - output.data = "transformed output" - return output + if result.output == "original output": + result.output = "transformed output" client = inngest.Inngest( api_base_url=dev_server.origin, diff --git a/tests/test_client/test_send.py b/tests/test_client/test_send.py index 6cf47c7..468b742 100644 --- a/tests/test_client/test_send.py +++ b/tests/test_client/test_send.py @@ -7,13 +7,13 @@ import inngest import inngest.flask from inngest._internal import const, errors -from tests import http_proxy +from tests import dev_server, http_proxy -class TestSend(unittest.TestCase): - def test_send_event_to_cloud_branch_env(self) -> None: +class TestSend(unittest.IsolatedAsyncioTestCase): + async def test_send_event_to_cloud_branch_env(self) -> None: """ - Test that the SDK correctly syncs itself with Cloud. + Test that the SDK sends the correct headers to Cloud. We need to use a mock Cloud since the Dev Server doesn't have a mode that simulates Cloud. @@ -55,6 +55,17 @@ def on_request( event_key=event_key, ) + await client.send(inngest.Event(name="foo")) + assert state.headers.get("X-Inngest-Env") == ["my-env"] + assert state.headers.get("X-Inngest-SDK") == [ + f"inngest-py:v{const.VERSION}" + ] + assert event_key in state.path + + # Clear test state + state.headers = {} + state.path = "" + client.send_sync(inngest.Event(name="foo")) assert state.headers.get("X-Inngest-Env") == ["my-env"] assert state.headers.get("X-Inngest-SDK") == [ @@ -71,6 +82,7 @@ async def test_many_parallel_sends(self) -> None: method_name = self._testMethodName client = inngest.Inngest( app_id=f"{class_name}-{method_name}", + event_api_base_url=f"http://localhost:{dev_server.PORT}", is_production=False, ) @@ -82,12 +94,41 @@ async def test_many_parallel_sends(self) -> None: await asyncio.gather(*sends) - def test_cloud_mode_without_event_key(self) -> None: + async def test_cloud_mode_without_event_key(self) -> None: client = inngest.Inngest(app_id="my-app") + with self.assertRaises(errors.EventKeyUnspecifiedError): + await client.send(inngest.Event(name="foo")) + with self.assertRaises(errors.EventKeyUnspecifiedError): client.send_sync(inngest.Event(name="foo")) + async def test_partial_send_error(self) -> None: + """ + Sending bulk events can result in a partial error. For example, sending + a valid event and an invalid event will result in 1 successfully sent + event and 1 error + """ + + client = inngest.Inngest( + app_id="my-app", + event_key="event-key-123abc", + event_api_base_url=f"http://localhost:{dev_server.PORT}", + is_production=False, + ) + + with self.assertRaises(errors.SendEventsError) as ctx: + await client.send( + [ + inngest.Event(name="foo"), + inngest.Event(name=""), + # This event will not be processed since the previous event + # is invalid + inngest.Event(name=""), + ] + ) + assert len(ctx.exception.ids) == 1 + if __name__ == "__main__": unittest.main() diff --git a/tests/test_function/cases/__init__.py b/tests/test_function/cases/__init__.py index d546950..0cb3964 100644 --- a/tests/test_function/cases/__init__.py +++ b/tests/test_function/cases/__init__.py @@ -20,6 +20,7 @@ invoke_failure, invoke_timeout, logger, + middleware_parallel_steps, multiple_triggers, no_cancel_if_exp_not_match, no_steps, @@ -60,6 +61,7 @@ invoke_failure, invoke_timeout, logger, + middleware_parallel_steps, multiple_triggers, no_cancel_if_exp_not_match, no_steps, diff --git a/tests/test_function/cases/change_step_error.py b/tests/test_function/cases/change_step_error.py index 596769a..ba3adc0 100644 --- a/tests/test_function/cases/change_step_error.py +++ b/tests/test_function/cases/change_step_error.py @@ -77,7 +77,6 @@ def run_test(self: base.TestClass) -> None: output = json.loads(run.output) assert output == { "code": "unknown", - "is_retriable": True, "message": "I am new", "name": "MyError", "stack": unittest.mock.ANY, diff --git a/tests/test_function/cases/function_middleware.py b/tests/test_function/cases/function_middleware.py index 8c4158d..a798ef5 100644 --- a/tests/test_function/cases/function_middleware.py +++ b/tests/test_function/cases/function_middleware.py @@ -1,7 +1,13 @@ import json +import django.core.handlers.wsgi +import fastapi +import tornado.httputil +import werkzeug.local + import inngest import tests.helper +from inngest._internal import const from . import base @@ -10,7 +16,8 @@ class _State(base.BaseState): def __init__(self) -> None: - self.hook_list: list[str] = [] + self.messages: list[str] = [] + self.raw_request: object = None def create( @@ -24,69 +31,105 @@ def create( state = _State() class _MiddlewareSync(inngest.MiddlewareSync): + def __init__( + self, + client: inngest.Inngest, + raw_request: object, + ) -> None: + super().__init__(client, raw_request) + state.raw_request = raw_request + def after_execution(self) -> None: - state.hook_list.append("after_execution") + state.messages.append("hook:after_execution") - def before_response(self) -> None: - # This hook is not called for function middleware but we'll include - # in anyway to verify that. - state.hook_list.append("before_response") + def after_memoization(self) -> None: + state.messages.append("hook:after_memoization") + + def after_send_events( + self, + result: inngest.SendEventsResult, + ) -> None: + state.messages.append("hook:after_send_events") def before_execution(self) -> None: - state.hook_list.append("before_execution") + state.messages.append("hook:before_execution") + + def before_memoization(self) -> None: + state.messages.append("hook:before_memoization") + + def before_response(self) -> None: + state.messages.append("hook:before_response") def before_send_events(self, events: list[inngest.Event]) -> None: - state.hook_list.append("before_send_events") + state.messages.append("hook:before_send_events") def transform_input( self, ctx: inngest.Context, - ) -> inngest.Context: - state.hook_list.append("transform_input") - return ctx + function: inngest.Function, + steps: inngest.StepMemos, + ) -> None: + state.messages.append("hook:transform_input") def transform_output( self, - output: inngest.Output, - ) -> inngest.Output: - state.hook_list.append("transform_output") - if output.data == "original output": - output.data = "transformed output" - return output + result: inngest.TransformOutputResult, + ) -> None: + state.messages.append("hook:transform_output") + if result.output == "original output": + result.output = "transformed output" class _MiddlewareAsync(inngest.Middleware): + def __init__( + self, + client: inngest.Inngest, + raw_request: object, + ) -> None: + super().__init__(client, raw_request) + state.raw_request = raw_request + async def after_execution(self) -> None: - state.hook_list.append("after_execution") + state.messages.append("hook:after_execution") - async def before_response(self) -> None: - # This hook is not called for function middleware but we'll include - # in anyway to verify that. - state.hook_list.append("before_response") + async def after_memoization(self) -> None: + state.messages.append("hook:after_memoization") + + async def after_send_events( + self, + result: inngest.SendEventsResult, + ) -> None: + state.messages.append("hook:after_send_events") async def before_execution(self) -> None: - state.hook_list.append("before_execution") + state.messages.append("hook:before_execution") + + async def before_memoization(self) -> None: + state.messages.append("hook:before_memoization") + + async def before_response(self) -> None: + state.messages.append("hook:before_response") async def before_send_events( self, events: list[inngest.Event], ) -> None: - state.hook_list.append("before_send_events") + state.messages.append("hook:before_send_events") async def transform_input( self, ctx: inngest.Context, - ) -> inngest.Context: - state.hook_list.append("transform_input") - return ctx + function: inngest.Function, + steps: inngest.StepMemos, + ) -> None: + state.messages.append("hook:transform_input") async def transform_output( self, - output: inngest.Output, - ) -> inngest.Output: - state.hook_list.append("transform_output") - if output.data == "original output": - output.data = "transformed output" - return output + result: inngest.TransformOutputResult, + ) -> None: + state.messages.append("hook:transform_output") + if result.output == "original output": + result.output = "transformed output" @client.create_function( fn_id=fn_id, @@ -103,8 +146,11 @@ def fn_sync( def _step_1() -> str: return "original output" + state.messages.append("fn_logic: before step_1") step.run("step_1", _step_1) + state.messages.append("fn_logic: after step_1") step.send_event("send", [inngest.Event(name="dummy")]) + state.messages.append("fn_logic: after send") @client.create_function( fn_id=fn_id, @@ -121,8 +167,11 @@ async def fn_async( async def _step_1() -> str: return "original output" + state.messages.append("fn_logic: before step_1") await step.run("step_1", _step_1) + state.messages.append("fn_logic: after step_1") await step.send_event("send", [inngest.Event(name="dummy")]) + state.messages.append("fn_logic: after send") def run_test(self: base.TestClass) -> None: self.client.send_sync(inngest.Event(name=event_name)) @@ -132,24 +181,57 @@ def run_test(self: base.TestClass) -> None: tests.helper.RunStatus.COMPLETED, ) + if framework == const.Framework.DIGITAL_OCEAN.value: + assert isinstance(state.raw_request, dict) + elif framework == const.Framework.DJANGO.value: + assert isinstance( + state.raw_request, django.core.handlers.wsgi.WSGIRequest + ) + elif framework == const.Framework.FAST_API.value: + assert isinstance(state.raw_request, fastapi.Request) + elif framework == const.Framework.FLASK.value: + assert isinstance(state.raw_request, werkzeug.local.LocalProxy) + elif framework == const.Framework.TORNADO.value: + assert isinstance( + state.raw_request, tornado.httputil.HTTPServerRequest + ) + else: + raise ValueError(f"unknown framework: {framework}") + # Assert that the middleware hooks were called in the correct order - assert state.hook_list == [ + assert state.messages == [ # Entry 1 - "transform_input", - "before_execution", - "after_execution", - "transform_output", + "hook:transform_input", + "hook:before_memoization", + "hook:after_memoization", + "hook:before_execution", + "fn_logic: before step_1", + "hook:after_execution", + "hook:transform_output", + "hook:before_response", # Entry 2 - "transform_input", - "before_execution", - "before_send_events", - "after_execution", - "transform_output", + "hook:transform_input", + "hook:before_memoization", + "fn_logic: before step_1", + "hook:after_memoization", + "hook:before_execution", + "fn_logic: after step_1", + "hook:before_send_events", + "hook:after_send_events", + "hook:after_execution", + "hook:transform_output", + "hook:before_response", # Entry 3 - "transform_input", - "before_execution", - "after_execution", - "transform_output", + "hook:transform_input", + "hook:before_memoization", + "fn_logic: before step_1", + "fn_logic: after step_1", + "hook:after_memoization", + "hook:before_execution", + "fn_logic: after send", + "hook:after_execution", + "hook:transform_output", + "hook:before_response", ] step_1_output = json.loads( diff --git a/tests/test_function/cases/invoke_by_id.py b/tests/test_function/cases/invoke_by_id.py index 383eca7..6090450 100644 --- a/tests/test_function/cases/invoke_by_id.py +++ b/tests/test_function/cases/invoke_by_id.py @@ -45,6 +45,7 @@ def fn_sender_sync( "invoke", app_id=client.app_id, function_id=f"{fn_id}/invokee", + timeout=60_000, ) @client.create_function( @@ -72,6 +73,7 @@ async def fn_sender_async( "invoke", app_id=client.app_id, function_id=f"{fn_id}/invokee", + timeout=60_000, ) def run_test(self: base.TestClass) -> None: diff --git a/tests/test_function/cases/invoke_by_object.py b/tests/test_function/cases/invoke_by_object.py index 1b642c6..5e323c5 100644 --- a/tests/test_function/cases/invoke_by_object.py +++ b/tests/test_function/cases/invoke_by_object.py @@ -44,6 +44,7 @@ def fn_sender_sync( state.step_output = step.invoke( "invoke", function=fn_receiver_sync, + timeout=60_000, ) @client.create_function( @@ -70,6 +71,7 @@ async def fn_sender_async( state.step_output = await step.invoke( "invoke", function=fn_receiver_async, + timeout=60_000, ) def run_test(self: base.TestClass) -> None: diff --git a/tests/test_function/cases/invoke_failure.py b/tests/test_function/cases/invoke_failure.py index ae77c48..0589c61 100644 --- a/tests/test_function/cases/invoke_failure.py +++ b/tests/test_function/cases/invoke_failure.py @@ -57,6 +57,7 @@ def fn_sender_sync( step.invoke( "invoke", function=fn_receiver_sync, + timeout=60_000, ) except inngest.StepError as err: state.raised_error = err @@ -88,6 +89,7 @@ async def fn_sender_async( await step.invoke( "invoke", function=fn_receiver_async, + timeout=60_000, ) except inngest.StepError as err: state.raised_error = err diff --git a/tests/test_function/cases/invoke_timeout.py b/tests/test_function/cases/invoke_timeout.py index 3e96abe..94b22bd 100644 --- a/tests/test_function/cases/invoke_timeout.py +++ b/tests/test_function/cases/invoke_timeout.py @@ -88,9 +88,9 @@ def run_test(self: base.TestClass) -> None: assert run.output is not None assert json.loads(run.output) == { "code": "step_errored", - "is_retriable": False, "message": "Timed out waiting for invoked function to complete", "name": "InngestInvokeTimeoutError", + "stack": None, } if is_sync: diff --git a/tests/test_function/cases/middleware_parallel_steps.py b/tests/test_function/cases/middleware_parallel_steps.py new file mode 100644 index 0000000..206fe13 --- /dev/null +++ b/tests/test_function/cases/middleware_parallel_steps.py @@ -0,0 +1,139 @@ +import inngest +import tests.helper +from inngest._internal import execution, middleware_lib + +from . import base + +_TEST_NAME = "middleware_parallel_steps" + + +class _State(base.BaseState): + def __init__(self) -> None: + self.results: list[inngest.TransformOutputResult] = [] + self.messages: list[str] = [] + + +def create( + client: inngest.Inngest, + framework: str, + is_sync: bool, +) -> base.Case: + test_name = base.create_test_name(_TEST_NAME, is_sync) + event_name = base.create_event_name(framework, test_name) + fn_id = base.create_fn_id(test_name) + state = _State() + + class _Middleware(inngest.MiddlewareSync): + def transform_output( + self, + result: inngest.TransformOutputResult, + ) -> None: + state.results.append(result) + + @client.create_function( + fn_id=fn_id, + middleware=[_Middleware], + retries=0, + trigger=inngest.TriggerEvent(event=event_name), + ) + def fn_sync( + ctx: inngest.Context, + step: inngest.StepSync, + ) -> str: + state.run_id = ctx.run_id + + step.parallel( + ( + lambda: step.run("1.1", lambda: "1.1 (step)"), + lambda: step.run("1.2", lambda: "1.2 (step)"), + ) + ) + + return "2 (fn)" + + @client.create_function( + fn_id=fn_id, + middleware=[_Middleware], + retries=0, + trigger=inngest.TriggerEvent(event=event_name), + ) + async def fn_async( + ctx: inngest.Context, + step: inngest.Step, + ) -> str: + state.run_id = ctx.run_id + + await step.parallel( + ( + lambda: step.run("1.1", lambda: "1.1 (step)"), + lambda: step.run("1.2", lambda: "1.2 (step)"), + ) + ) + + return "2 (fn)" + + def run_test(self: base.TestClass) -> None: + self.client.send_sync(inngest.Event(name=event_name)) + run_id = state.wait_for_run_id() + tests.helper.client.wait_for_run_status( + run_id, + tests.helper.RunStatus.COMPLETED, + ) + + results = sorted(state.results, key=lambda x: str(x.output)) + + if len(results) == 4: + # The last request (the function return) usually happens twice but + # sometimes only once. This is probably a race condition between the + # Executor and SDK, so we'll pop the "extra" result if it exists + results.pop() + + _assert_results( + results, + [ + inngest.TransformOutputResult( + error=None, + output="1.1 (step)", + step=middleware_lib.TransformOutputStepInfo( + id="1.1", + op=execution.Opcode.STEP_RUN, + opts=None, + ), + ), + inngest.TransformOutputResult( + error=None, + output="1.2 (step)", + step=middleware_lib.TransformOutputStepInfo( + id="1.2", + op=execution.Opcode.STEP_RUN, + opts=None, + ), + ), + inngest.TransformOutputResult( + error=None, + output="2 (fn)", + step=None, + ), + ], + ) + + if is_sync: + fn = fn_sync + else: + fn = fn_async + + return base.Case( + fn=fn, + run_test=run_test, + name=test_name, + ) + + +def _assert_results( + actual: list[inngest.TransformOutputResult], + expected: list[inngest.TransformOutputResult], +) -> None: + assert len(actual) == len(expected) + + for i, (a, e) in enumerate(zip(actual, expected)): + assert a.__dict__ == e.__dict__, f"index={i}" diff --git a/tests/test_function/cases/non_retriable_error.py b/tests/test_function/cases/non_retriable_error.py index e9c3ac5..0e150e7 100644 --- a/tests/test_function/cases/non_retriable_error.py +++ b/tests/test_function/cases/non_retriable_error.py @@ -72,7 +72,6 @@ def assert_output() -> None: assert output == { "code": "non_retriable_error", - "is_retriable": False, "message": "foo", "name": "NonRetriableError", "stack": unittest.mock.ANY, diff --git a/tests/test_function/cases/on_failure.py b/tests/test_function/cases/on_failure.py index edf0cdc..b60ef69 100644 --- a/tests/test_function/cases/on_failure.py +++ b/tests/test_function/cases/on_failure.py @@ -104,7 +104,6 @@ def assert_is_done() -> None: output = json.loads(run.output) assert output == { "code": "unknown", - "is_retriable": True, "message": "intentional failure", "name": "MyError", "stack": unittest.mock.ANY, diff --git a/tests/test_function/cases/pydantic_output.py b/tests/test_function/cases/pydantic_output.py index a4f27b7..04de1bd 100644 --- a/tests/test_function/cases/pydantic_output.py +++ b/tests/test_function/cases/pydantic_output.py @@ -1,12 +1,10 @@ """ -We don't officially support returning a Pydantic object from a step. Returning a -Pydantic object fails a type check, however it'll be converted to a dict at -runtime. Users may be relying on this behavior, so it's probably best to avoid -fixing it. - -Note that returning a Pydantic object from a function will fail at runtime. +We don't support returning Pydantic models in steps or functions. This may +change in the future. """ +import json + import pydantic import inngest @@ -67,13 +65,17 @@ async def fn_async( def run_test(self: base.TestClass) -> None: self.client.send_sync(inngest.Event(name=event_name)) - tests.helper.client.wait_for_run_status( + run = tests.helper.client.wait_for_run_status( state.wait_for_run_id(), - tests.helper.RunStatus.COMPLETED, + tests.helper.RunStatus.FAILED, ) - user = _User.model_validate(state.step_output) - assert user.name == "Alice" + assert run.output is not None + assert json.loads(run.output) == { + "code": "output_unserializable", + "message": '"a" returned unserializable data', + "name": "OutputUnserializableError", + } if is_sync: fn = fn_sync diff --git a/tests/test_function/cases/two_steps.py b/tests/test_function/cases/two_steps.py index c0c0eea..a9f4231 100644 --- a/tests/test_function/cases/two_steps.py +++ b/tests/test_function/cases/two_steps.py @@ -35,9 +35,9 @@ def fn_sync( ) -> None: state.run_id = ctx.run_id - def step_1() -> list[dict[str, dict[str, int]]]: + def step_1() -> list[dict[str, inngest.JSON]]: state.step_1_counter += 1 - return [{"foo": {"bar": 1}}] + return [{"foo": {"bar": 1}, "empty": None}] state.step_1_output = step.run("step_1", step_1) @@ -57,9 +57,9 @@ async def fn_async( ) -> None: state.run_id = ctx.run_id - async def step_1() -> list[dict[str, dict[str, int]]]: + async def step_1() -> list[dict[str, inngest.JSON]]: state.step_1_counter += 1 - return [{"foo": {"bar": 1}}] + return [{"foo": {"bar": 1}, "empty": None}] state.step_1_output = await step.run("step_1", step_1) @@ -78,7 +78,7 @@ def run_test(self: base.TestClass) -> None: assert state.step_1_counter == 1 assert state.step_2_counter == 1 - assert state.step_1_output == [{"foo": {"bar": 1}}] + assert state.step_1_output == [{"empty": None, "foo": {"bar": 1}}] step_1_output_in_api = json.loads( tests.helper.client.get_step_output( @@ -86,7 +86,9 @@ def run_test(self: base.TestClass) -> None: step_id="step_1", ) ) - assert step_1_output_in_api == {"data": [{"foo": {"bar": 1}}]} + assert step_1_output_in_api == { + "data": [{"empty": None, "foo": {"bar": 1}}] + } if is_sync: fn = fn_sync diff --git a/tests/test_function/cases/unexpected_step_during_targeting.py b/tests/test_function/cases/unexpected_step_during_targeting.py index 2d2555a..744968f 100644 --- a/tests/test_function/cases/unexpected_step_during_targeting.py +++ b/tests/test_function/cases/unexpected_step_during_targeting.py @@ -106,9 +106,9 @@ def run_test(self: base.TestClass) -> None: output = json.loads(run.output) assert output == { "code": "step_unexpected", - "is_retriable": True, "message": 'found step "unexpected" when targeting a different step', "name": "StepUnexpectedError", + "stack": None, } # None of the step callbacks were called