From edf5070c38ad8cd563a7a045d8d4370d4cbe810b Mon Sep 17 00:00:00 2001 From: Martin Turoci Date: Mon, 14 Aug 2023 14:29:08 +0200 Subject: [PATCH] chore: Add the same routing improvements to lightwaveas well. #1484 --- py/h2o_lightwave/h2o_lightwave/routing.py | 90 ++++++++++++++++++++--- py/h2o_wave/h2o_wave/routing.py | 5 +- 2 files changed, 81 insertions(+), 14 deletions(-) diff --git a/py/h2o_lightwave/h2o_lightwave/routing.py b/py/h2o_lightwave/h2o_lightwave/routing.py index ee20a3b598..b6400d9e81 100644 --- a/py/h2o_lightwave/h2o_lightwave/routing.py +++ b/py/h2o_lightwave/h2o_lightwave/routing.py @@ -14,7 +14,6 @@ from typing import Optional, Callable from inspect import signature -import asyncio import logging from starlette.routing import compile_path from .core import expando_to_dict @@ -25,6 +24,8 @@ _event_handlers = {} # dictionary of event_source => [(event_type, predicate, handler)] _arg_handlers = {} # dictionary of arg_name => [(predicate, handler)] _path_handlers = [] +_arg_with_params_handlers = [] +_handle_on_deprecated_warning_printed = False def _get_arity(func: Callable) -> int: @@ -86,9 +87,8 @@ def wrap(func): # if not asyncio.iscoroutinefunction(func): # raise ValueError(f"@on function '{func_name}' must be async") - if predicate: - if not callable(predicate): - raise ValueError(f"@on predicate must be callable for '{func_name}'") + if predicate and not callable(predicate): + raise ValueError(f"@on predicate must be callable for '{func_name}'") if isinstance(arg, str) and len(arg): if arg.startswith('#'): # location hash rx, _, conv = compile_path(arg[1:]) @@ -100,6 +100,9 @@ def wrap(func): if not len(event): raise ValueError(f"@on event type cannot be empty in '{arg}' for '{func_name}'") _add_event_handler(source, event, func, predicate) + elif "{" in arg and "}" in arg: + rx, _, conv = compile_path(arg) + _arg_with_params_handlers.append((predicate, func, _get_arity(func), rx, conv)) else: _add_handler(arg, func, predicate) else: @@ -110,28 +113,32 @@ def wrap(func): return wrap -async def _invoke_handler(func: Callable, arity: int, q: Q, arg: any): +async def _invoke_handler(func: Callable, arity: int, q: Q, arg: any, **params: any): if arity == 0: await func() elif arity == 1: await func(q) - else: + elif len(params) == 0: await func(q, arg) + elif arity == len(params) + 1: + await func(q, **params) + else: + await func(q, arg, **params) -async def _match_predicate(predicate: Callable, func: Callable, arity: int, q: Q, arg: any) -> bool: +async def _match_predicate(predicate: Callable, func: Callable, arity: int, q: Q, arg: any, **params: any) -> bool: if predicate: if predicate(arg): - await _invoke_handler(func, arity, q, arg) + await _invoke_handler(func, arity, q, arg, **params) return True else: - if arg: - await _invoke_handler(func, arity, q, arg) + if arg is not None: + await _invoke_handler(func, arity, q, arg, **params) return True return False -async def handle_on(q: Q) -> bool: +async def run_on(q: Q) -> bool: """ Handle the query using a query handler (a function annotated with `@on()`). @@ -141,6 +148,67 @@ async def handle_on(q: Q) -> bool: Returns: True if a matching query handler was found and invoked, else False. """ + submitted = str(q.args['__wave_submission_name__']) + + # Event handlers. + for event_source in expando_to_dict(q.events): + for entry in _event_handlers.get(event_source, []): + event_type, predicate, func, arity = entry + event = q.events[event_source] + if event_type in event: + arg_value = event[event_type] + if await _match_predicate(predicate, func, arity, q, arg_value): + return True + + # Hash handlers. + if submitted == '#': + for rx, conv, func, arity in _path_handlers: + match = rx.match(q.args[submitted]) + if match: + params = match.groupdict() + for key, value in params.items(): + params[key] = conv[key].convert(value) + if len(params): + if arity <= 1: + await _invoke_handler(func, arity, q, None) + else: + await func(q, **params) + else: + await _invoke_handler(func, arity, q, None) + return True + + # Arg handlers. + for entry in _arg_handlers.get(submitted, []): + predicate, func, arity = entry + if await _match_predicate(predicate, func, arity, q, q.args[submitted]): + return True + for predicate, func, arity, rx, conv in _arg_with_params_handlers: + match = rx.match(submitted) + if match: + params = match.groupdict() + for key, value in params.items(): + params[key] = conv[key].convert(value) + if await _match_predicate(predicate, func, arity, q, q.args[submitted], **params): + return True + + return False + + +async def handle_on(q: Q) -> bool: + """ + DEPRECATED: Handle the query using a query handler (a function annotated with `@on()`). + + Args: + q: The query context. + + Returns: + True if a matching query handler was found and invoked, else False. + """ + global _handle_on_deprecated_warning_printed + if not _handle_on_deprecated_warning_printed: + print('\033[93m' + 'WARNING: handle_on() is deprecated, use run_on() instead.' + '\033[0m') + _handle_on_deprecated_warning_printed = True + event_sources = expando_to_dict(q.events) for event_source in event_sources: event = q.events[event_source] diff --git a/py/h2o_wave/h2o_wave/routing.py b/py/h2o_wave/h2o_wave/routing.py index 13dbdcc63b..b6400d9e81 100644 --- a/py/h2o_wave/h2o_wave/routing.py +++ b/py/h2o_wave/h2o_wave/routing.py @@ -87,9 +87,8 @@ def wrap(func): # if not asyncio.iscoroutinefunction(func): # raise ValueError(f"@on function '{func_name}' must be async") - if predicate: - if not callable(predicate): - raise ValueError(f"@on predicate must be callable for '{func_name}'") + if predicate and not callable(predicate): + raise ValueError(f"@on predicate must be callable for '{func_name}'") if isinstance(arg, str) and len(arg): if arg.startswith('#'): # location hash rx, _, conv = compile_path(arg[1:])