diff --git a/docs/interceptors.md b/docs/interceptors.md index f722c6e..8f4da83 100644 --- a/docs/interceptors.md +++ b/docs/interceptors.md @@ -175,65 +175,91 @@ to be able to intercept RPC messages. However, many interceptors, such as for au tracing, only need access to headers and not messages. Connect provides a metadata interceptor protocol that can be implemented to work with any RPC type. -An authentication interceptor checking bearer tokens may look like this: +An authentication interceptor checking bearer tokens and storing them to a context variable may look like this: === "ASGI" ```python - class AuthInterceptor: + from contextvars import ContextVar, Token + + _auth_token = ContextVar["auth_token"]("current_auth_token") + + class ServerAuthInterceptor: def __init__(self, valid_tokens: list[str]): self._valid_tokens = valid_tokens - async def on_start(self, ctx: RequestContext): + async def on_start(self, ctx: RequestContext) -> Token["auth_token"]: authorization = ctx.request_headers().get("authorization") if not authorization or not authorization.startswith("Bearer "): raise ConnectError(Code.UNAUTHENTICATED) token = authorization[len("Bearer "):] - if token not in valid_tokens: + if token not in self._valid_tokens: raise ConnectError(Code.PERMISSION_DENIED) + return _auth_token.set(token) + + async def on_end(self, token: Token["auth_token"], ctx: RequestContext): + _auth_token.reset(token) ``` === "WSGI" ```python - class AuthInterceptor: + from contextvars import ContextVar, Token + + _auth_token = ContextVar["auth_token"]("current_auth_token") + + class ServerAuthInterceptor: def __init__(self, valid_tokens: list[str]): self._valid_tokens = valid_tokens - def on_start(self, ctx: RequestContext): + def on_start(self, ctx: RequestContext) -> Token["auth_token"]: authorization = ctx.request_headers().get("authorization") if not authorization or not authorization.startswith("Bearer "): raise ConnectError(Code.UNAUTHENTICATED) token = authorization[len("Bearer "):] - if token not in valid_tokens: + if token not in self._valid_tokens: raise ConnectError(Code.PERMISSION_DENIED) + return _auth_token.set(token) + + def on_end(self, token: Token["auth_token"], ctx: RequestContext): + _auth_token.reset(token) ``` -`on_start` can return any value, which is passed to the optional `on_end` method. This can be -used, for example, to record the time of execution for the method. +`on_start` can return any value, which is passed to the optional `on_end` method. Here, we +return the token to reset the context variable. -=== "ASGI" +Clients can add an interceptor that reads the token from the context variable and populates +the authorization header. + +=== "Async" ```python - import time + from contextvars import ContextVar - class TimingInterceptor: - async def on_start(self, ctx: RequestContext) -> float: - return time.perf_counter() + _auth_token = ContextVar["auth_token"]("current_auth_token") - async def on_end(self, token: float, ctx: RequestContext): - print(f"Method took {} seconds.", token - time.perf_counter()) + class ClientAuthInterceptor: + async def on_start(self, ctx: RequestContext): + auth_token = _auth_token.get(None) + if auth_token: + ctx.request_headers()["authorization"] = f"Bearer {auth_token}" ``` -=== "WSGI" +=== "Sync" ```python - import time + from contextvars import ContextVar - class TimingInterceptor: - def on_start(self, ctx: RequestContext): - return time.perf_counter() + _auth_token = ContextVar["auth_token"]("current_auth_token") - def on_end(self, token: float, ctx: RequestContext): - print(f"Method took {} seconds.", token - time.perf_counter()) + class ClientAuthInterceptor: + def on_start(self, ctx: RequestContext): + auth_token = _auth_token.get(None) + if auth_token: + ctx.request_headers()["authorization"] = f"Bearer {auth_token}" ``` + +Note that in the client interceptor, we do not need to define `on_end`. + +The above interceptors would allow a server to receive and validate an auth token and automatically +propagate it to the authorization header of backend calls.