Skip to content

Commit

Permalink
chore: Add the same routing improvements to lightwaveas well. #1484
Browse files Browse the repository at this point in the history
  • Loading branch information
mturoci committed Sep 19, 2023
1 parent 2cc0783 commit c126743
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 14 deletions.
90 changes: 79 additions & 11 deletions py/h2o_lightwave/h2o_lightwave/routing.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

from typing import Optional, Callable
from inspect import signature
import asyncio
import logging
from starlette.routing import compile_path
from .core import expando_to_dict
Expand All @@ -25,6 +24,8 @@
_event_handlers = {} # dictionary of event_source => [(event_type, predicate, handler)]
_arg_handlers = {} # dictionary of arg_name => [(predicate, handler)]
_path_handlers = []
_arg_with_params_handlers = []
_handle_on_deprecated_warning_printed = False


def _get_arity(func: Callable) -> int:
Expand Down Expand Up @@ -86,9 +87,8 @@ def wrap(func):
# if not asyncio.iscoroutinefunction(func):
# raise ValueError(f"@on function '{func_name}' must be async")

if predicate:
if not callable(predicate):
raise ValueError(f"@on predicate must be callable for '{func_name}'")
if predicate and not callable(predicate):
raise ValueError(f"@on predicate must be callable for '{func_name}'")
if isinstance(arg, str) and len(arg):
if arg.startswith('#'): # location hash
rx, _, conv = compile_path(arg[1:])
Expand All @@ -100,6 +100,9 @@ def wrap(func):
if not len(event):
raise ValueError(f"@on event type cannot be empty in '{arg}' for '{func_name}'")
_add_event_handler(source, event, func, predicate)
elif "{" in arg and "}" in arg:
rx, _, conv = compile_path(arg)
_arg_with_params_handlers.append((predicate, func, _get_arity(func), rx, conv))
else:
_add_handler(arg, func, predicate)
else:
Expand All @@ -110,28 +113,32 @@ def wrap(func):
return wrap


async def _invoke_handler(func: Callable, arity: int, q: Q, arg: any):
async def _invoke_handler(func: Callable, arity: int, q: Q, arg: any, **params: any):
if arity == 0:
await func()
elif arity == 1:
await func(q)
else:
elif len(params) == 0:
await func(q, arg)
elif arity == len(params) + 1:
await func(q, **params)
else:
await func(q, arg, **params)


async def _match_predicate(predicate: Callable, func: Callable, arity: int, q: Q, arg: any) -> bool:
async def _match_predicate(predicate: Callable, func: Callable, arity: int, q: Q, arg: any, **params: any) -> bool:
if predicate:
if predicate(arg):
await _invoke_handler(func, arity, q, arg)
await _invoke_handler(func, arity, q, arg, **params)
return True
else:
if arg:
await _invoke_handler(func, arity, q, arg)
if arg is not None:
await _invoke_handler(func, arity, q, arg, **params)
return True
return False


async def handle_on(q: Q) -> bool:
async def run_on(q: Q) -> bool:
"""
Handle the query using a query handler (a function annotated with `@on()`).
Expand All @@ -141,6 +148,67 @@ async def handle_on(q: Q) -> bool:
Returns:
True if a matching query handler was found and invoked, else False.
"""
submitted = str(q.args['__wave_submission_name__'])

# Event handlers.
for event_source in expando_to_dict(q.events):
for entry in _event_handlers.get(event_source, []):
event_type, predicate, func, arity = entry
event = q.events[event_source]
if event_type in event:
arg_value = event[event_type]
if await _match_predicate(predicate, func, arity, q, arg_value):
return True

# Hash handlers.
if submitted == '#':
for rx, conv, func, arity in _path_handlers:
match = rx.match(q.args[submitted])
if match:
params = match.groupdict()
for key, value in params.items():
params[key] = conv[key].convert(value)
if len(params):
if arity <= 1:
await _invoke_handler(func, arity, q, None)
else:
await func(q, **params)
else:
await _invoke_handler(func, arity, q, None)
return True

# Arg handlers.
for entry in _arg_handlers.get(submitted, []):
predicate, func, arity = entry
if await _match_predicate(predicate, func, arity, q, q.args[submitted]):
return True
for predicate, func, arity, rx, conv in _arg_with_params_handlers:
match = rx.match(submitted)
if match:
params = match.groupdict()
for key, value in params.items():
params[key] = conv[key].convert(value)
if await _match_predicate(predicate, func, arity, q, q.args[submitted], **params):
return True

return False


async def handle_on(q: Q) -> bool:
"""
DEPRECATED: Handle the query using a query handler (a function annotated with `@on()`).
Args:
q: The query context.
Returns:
True if a matching query handler was found and invoked, else False.
"""
global _handle_on_deprecated_warning_printed
if not _handle_on_deprecated_warning_printed:
print('\033[93m' + 'WARNING: handle_on() is deprecated, use run_on() instead.' + '\033[0m')
_handle_on_deprecated_warning_printed = True

event_sources = expando_to_dict(q.events)
for event_source in event_sources:
event = q.events[event_source]
Expand Down
5 changes: 2 additions & 3 deletions py/h2o_wave/h2o_wave/routing.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,9 +87,8 @@ def wrap(func):
# if not asyncio.iscoroutinefunction(func):
# raise ValueError(f"@on function '{func_name}' must be async")

if predicate:
if not callable(predicate):
raise ValueError(f"@on predicate must be callable for '{func_name}'")
if predicate and not callable(predicate):
raise ValueError(f"@on predicate must be callable for '{func_name}'")
if isinstance(arg, str) and len(arg):
if arg.startswith('#'): # location hash
rx, _, conv = compile_path(arg[1:])
Expand Down

0 comments on commit c126743

Please sign in to comment.