Skip to content

Commit

Permalink
improve event handler state references (reflex-dev#2818)
Browse files Browse the repository at this point in the history
  • Loading branch information
benedikt-bartscher authored and Malte Klemm committed Mar 12, 2024
1 parent a4ae178 commit 5a5ac14
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 15 deletions.
4 changes: 4 additions & 0 deletions reflex/event.py
Expand Up @@ -147,6 +147,10 @@ class EventHandler(EventActionsMixin):
# The function to call in response to the event.
fn: Any

# The full name of the state class this event handler is attached to.
# Emtpy string means this event handler is a server side event.
state_full_name: str = ""

class Config:
"""The Pydantic config."""

Expand Down
25 changes: 20 additions & 5 deletions reflex/state.py
Expand Up @@ -484,7 +484,7 @@ def __init_subclass__(cls, **kwargs):
events[name] = value

for name, fn in events.items():
handler = EventHandler(fn=fn)
handler = cls._create_event_handler(fn)
cls.event_handlers[name] = handler
setattr(cls, name, handler)

Expand Down Expand Up @@ -689,7 +689,7 @@ def get_full_name(cls) -> str:

@classmethod
@functools.lru_cache()
def get_class_substate(cls, path: Sequence[str]) -> Type[BaseState]:
def get_class_substate(cls, path: Sequence[str] | str) -> Type[BaseState]:
"""Get the class substate.
Args:
Expand All @@ -701,6 +701,9 @@ def get_class_substate(cls, path: Sequence[str]) -> Type[BaseState]:
Raises:
ValueError: If the substate is not found.
"""
if isinstance(path, str):
path = tuple(path.split("."))

if len(path) == 0:
return cls
if path[0] == cls.get_name():
Expand Down Expand Up @@ -801,6 +804,18 @@ def _set_var(cls, prop: BaseVar):
"""
setattr(cls, prop._var_name, prop)

@classmethod
def _create_event_handler(cls, fn):
"""Create an event handler for the given function.
Args:
fn: The function to create an event handler for.
Returns:
The event handler.
"""
return EventHandler(fn=fn, state_full_name=cls.get_full_name())

@classmethod
def _create_setter(cls, prop: BaseVar):
"""Create a setter for the var.
Expand All @@ -810,7 +825,7 @@ def _create_setter(cls, prop: BaseVar):
"""
setter_name = prop.get_setter_name(include_state=False)
if setter_name not in cls.__dict__:
event_handler = EventHandler(fn=prop.get_setter())
event_handler = cls._create_event_handler(prop.get_setter())
cls.event_handlers[setter_name] = event_handler
setattr(cls, setter_name, event_handler)

Expand Down Expand Up @@ -1764,7 +1779,7 @@ async def update_vars_internal(self, vars: dict[str, Any]) -> None:
"""
for var, value in vars.items():
state_name, _, var_name = var.rpartition(".")
var_state_cls = State.get_class_substate(tuple(state_name.split(".")))
var_state_cls = State.get_class_substate(state_name)
var_state = await self.get_state(var_state_cls)
setattr(var_state, var_name, value)

Expand Down Expand Up @@ -2280,7 +2295,7 @@ async def get_state(
_, state_path = _split_substate_key(token)
if state_path:
# Get the State class associated with the given path.
state_cls = self.state.get_class_substate(tuple(state_path.split(".")))
state_cls = self.state.get_class_substate(state_path)
else:
raise RuntimeError(
"StateManagerRedis requires token to be specified in the form of {token}_{state_full_name}"
Expand Down
19 changes: 9 additions & 10 deletions reflex/utils/format.py
Expand Up @@ -6,7 +6,6 @@
import json
import os
import re
import sys
from typing import TYPE_CHECKING, Any, List, Union

from reflex import constants
Expand Down Expand Up @@ -470,18 +469,18 @@ def get_event_handler_parts(handler: EventHandler) -> tuple[str, str]:
if len(parts) == 1:
return ("", parts[-1])

# Get the state and the function name.
state_name, name = parts[-2:]
# Get the state full name
state_full_name = handler.state_full_name

# Construct the full event handler name.
try:
# Try to get the state from the module.
state = vars(sys.modules[handler.fn.__module__])[state_name]
except Exception:
# If the state isn't in the module, just return the function name.
# Get the function name
name = parts[-1]

from reflex.state import State

if state_full_name == "state" and name not in State.__dict__:
return ("", to_snake_case(handler.fn.__qualname__))

return (state.get_full_name(), name)
return (state_full_name, name)


def format_event_handler(handler: EventHandler) -> str:
Expand Down

0 comments on commit 5a5ac14

Please sign in to comment.