Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Make StateFilter frozen so we can hash it (#10816)
Browse files Browse the repository at this point in the history
Also enables Mypy for related tests.
  • Loading branch information
reivilibre committed Sep 14, 2021
1 parent 14b8c04 commit 8eb7cb2
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 30 deletions.
1 change: 1 addition & 0 deletions changelog.d/10816.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Make `StateFilter` frozen so it is hashable.
1 change: 1 addition & 0 deletions mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ files =
tests/handlers/test_sync.py,
tests/rest/client/test_login.py,
tests/rest/client/test_auth.py,
tests/storage/test_state.py,
tests/util/test_itertools.py,
tests/util/test_stream_change_cache.py

Expand Down
45 changes: 32 additions & 13 deletions synapse/storage/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,15 @@
)

import attr
from frozendict import frozendict

from synapse.api.constants import EventTypes
from synapse.events import EventBase
from synapse.types import MutableStateMap, StateMap

if TYPE_CHECKING:
from typing import FrozenSet # noqa: used within quoted type hint; flake8 sad

from synapse.server import HomeServer
from synapse.storage.databases import Databases

Expand All @@ -40,7 +43,7 @@
T = TypeVar("T")


@attr.s(slots=True)
@attr.s(slots=True, frozen=True)
class StateFilter:
"""A filter used when querying for state.
Expand All @@ -53,14 +56,19 @@ class StateFilter:
appear in `types`.
"""

types = attr.ib(type=Dict[str, Optional[Set[str]]])
types = attr.ib(type="frozendict[str, Optional[FrozenSet[str]]]")
include_others = attr.ib(default=False, type=bool)

def __attrs_post_init__(self):
# If `include_others` is set we canonicalise the filter by removing
# wildcards from the types dictionary
if self.include_others:
self.types = {k: v for k, v in self.types.items() if v is not None}
# this is needed to work around the fact that StateFilter is frozen
object.__setattr__(
self,
"types",
frozendict({k: v for k, v in self.types.items() if v is not None}),
)

@staticmethod
def all() -> "StateFilter":
Expand All @@ -69,7 +77,7 @@ def all() -> "StateFilter":
Returns:
The new state filter.
"""
return StateFilter(types={}, include_others=True)
return StateFilter(types=frozendict(), include_others=True)

@staticmethod
def none() -> "StateFilter":
Expand All @@ -78,7 +86,7 @@ def none() -> "StateFilter":
Returns:
The new state filter.
"""
return StateFilter(types={}, include_others=False)
return StateFilter(types=frozendict(), include_others=False)

@staticmethod
def from_types(types: Iterable[Tuple[str, Optional[str]]]) -> "StateFilter":
Expand All @@ -103,7 +111,12 @@ def from_types(types: Iterable[Tuple[str, Optional[str]]]) -> "StateFilter":

type_dict.setdefault(typ, set()).add(s) # type: ignore

return StateFilter(types=type_dict)
return StateFilter(
types=frozendict(
(k, frozenset(v) if v is not None else None)
for k, v in type_dict.items()
)
)

@staticmethod
def from_lazy_load_member_list(members: Iterable[str]) -> "StateFilter":
Expand All @@ -116,7 +129,10 @@ def from_lazy_load_member_list(members: Iterable[str]) -> "StateFilter":
Returns:
The new state filter
"""
return StateFilter(types={EventTypes.Member: set(members)}, include_others=True)
return StateFilter(
types=frozendict({EventTypes.Member: frozenset(members)}),
include_others=True,
)

def return_expanded(self) -> "StateFilter":
"""Creates a new StateFilter where type wild cards have been removed
Expand Down Expand Up @@ -173,7 +189,7 @@ def return_expanded(self) -> "StateFilter":
# We want to return all non-members, but only particular
# memberships
return StateFilter(
types={EventTypes.Member: self.types[EventTypes.Member]},
types=frozendict({EventTypes.Member: self.types[EventTypes.Member]}),
include_others=True,
)

Expand Down Expand Up @@ -245,14 +261,15 @@ def max_entries_returned(self) -> Optional[int]:

return len(self.concrete_types())

def filter_state(self, state_dict: StateMap[T]) -> StateMap[T]:
"""Returns the state filtered with by this StateFilter
def filter_state(self, state_dict: StateMap[T]) -> MutableStateMap[T]:
"""Returns the state filtered with by this StateFilter.
Args:
state: The state map to filter
Returns:
The filtered state map
The filtered state map.
This is a copy, so it's safe to mutate.
"""
if self.is_full():
return dict(state_dict)
Expand Down Expand Up @@ -324,14 +341,16 @@ def get_member_split(self) -> Tuple["StateFilter", "StateFilter"]:
if state_keys is None:
member_filter = StateFilter.all()
else:
member_filter = StateFilter({EventTypes.Member: state_keys})
member_filter = StateFilter(frozendict({EventTypes.Member: state_keys}))
elif self.include_others:
member_filter = StateFilter.all()
else:
member_filter = StateFilter.none()

non_member_filter = StateFilter(
types={k: v for k, v in self.types.items() if k != EventTypes.Member},
types=frozendict(
{k: v for k, v in self.types.items() if k != EventTypes.Member}
),
include_others=self.include_others,
)

Expand Down
46 changes: 29 additions & 17 deletions tests/storage/test_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

import logging

from frozendict import frozendict

from synapse.api.constants import EventTypes, Membership
from synapse.api.room_versions import RoomVersions
from synapse.storage.state import StateFilter
Expand Down Expand Up @@ -183,7 +185,9 @@ def test_get_state_for_event(self):
self.storage.state.get_state_for_event(
e5.event_id,
state_filter=StateFilter(
types={EventTypes.Member: {self.u_alice.to_string()}},
types=frozendict(
{EventTypes.Member: frozenset({self.u_alice.to_string()})}
),
include_others=True,
),
)
Expand All @@ -203,7 +207,8 @@ def test_get_state_for_event(self):
self.storage.state.get_state_for_event(
e5.event_id,
state_filter=StateFilter(
types={EventTypes.Member: set()}, include_others=True
types=frozendict({EventTypes.Member: frozenset()}),
include_others=True,
),
)
)
Expand All @@ -228,7 +233,7 @@ def test_get_state_for_event(self):
self.state_datastore._state_group_cache,
group,
state_filter=StateFilter(
types={EventTypes.Member: set()}, include_others=True
types=frozendict({EventTypes.Member: frozenset()}), include_others=True
),
)

Expand All @@ -245,7 +250,7 @@ def test_get_state_for_event(self):
self.state_datastore._state_group_members_cache,
group,
state_filter=StateFilter(
types={EventTypes.Member: set()}, include_others=True
types=frozendict({EventTypes.Member: frozenset()}), include_others=True
),
)

Expand All @@ -258,7 +263,7 @@ def test_get_state_for_event(self):
self.state_datastore._state_group_cache,
group,
state_filter=StateFilter(
types={EventTypes.Member: None}, include_others=True
types=frozendict({EventTypes.Member: None}), include_others=True
),
)

Expand All @@ -275,7 +280,7 @@ def test_get_state_for_event(self):
self.state_datastore._state_group_members_cache,
group,
state_filter=StateFilter(
types={EventTypes.Member: None}, include_others=True
types=frozendict({EventTypes.Member: None}), include_others=True
),
)

Expand All @@ -295,7 +300,8 @@ def test_get_state_for_event(self):
self.state_datastore._state_group_cache,
group,
state_filter=StateFilter(
types={EventTypes.Member: {e5.state_key}}, include_others=True
types=frozendict({EventTypes.Member: frozenset({e5.state_key})}),
include_others=True,
),
)

Expand All @@ -312,7 +318,8 @@ def test_get_state_for_event(self):
self.state_datastore._state_group_members_cache,
group,
state_filter=StateFilter(
types={EventTypes.Member: {e5.state_key}}, include_others=True
types=frozendict({EventTypes.Member: frozenset({e5.state_key})}),
include_others=True,
),
)

Expand All @@ -325,7 +332,8 @@ def test_get_state_for_event(self):
self.state_datastore._state_group_members_cache,
group,
state_filter=StateFilter(
types={EventTypes.Member: {e5.state_key}}, include_others=False
types=frozendict({EventTypes.Member: frozenset({e5.state_key})}),
include_others=False,
),
)

Expand Down Expand Up @@ -375,7 +383,7 @@ def test_get_state_for_event(self):
self.state_datastore._state_group_cache,
group,
state_filter=StateFilter(
types={EventTypes.Member: set()}, include_others=True
types=frozendict({EventTypes.Member: frozenset()}), include_others=True
),
)

Expand All @@ -387,7 +395,7 @@ def test_get_state_for_event(self):
self.state_datastore._state_group_members_cache,
group,
state_filter=StateFilter(
types={EventTypes.Member: set()}, include_others=True
types=frozendict({EventTypes.Member: frozenset()}), include_others=True
),
)

Expand All @@ -400,7 +408,7 @@ def test_get_state_for_event(self):
self.state_datastore._state_group_cache,
group,
state_filter=StateFilter(
types={EventTypes.Member: None}, include_others=True
types=frozendict({EventTypes.Member: None}), include_others=True
),
)

Expand All @@ -411,7 +419,7 @@ def test_get_state_for_event(self):
self.state_datastore._state_group_members_cache,
group,
state_filter=StateFilter(
types={EventTypes.Member: None}, include_others=True
types=frozendict({EventTypes.Member: None}), include_others=True
),
)

Expand All @@ -430,7 +438,8 @@ def test_get_state_for_event(self):
self.state_datastore._state_group_cache,
group,
state_filter=StateFilter(
types={EventTypes.Member: {e5.state_key}}, include_others=True
types=frozendict({EventTypes.Member: frozenset({e5.state_key})}),
include_others=True,
),
)

Expand All @@ -441,7 +450,8 @@ def test_get_state_for_event(self):
self.state_datastore._state_group_members_cache,
group,
state_filter=StateFilter(
types={EventTypes.Member: {e5.state_key}}, include_others=True
types=frozendict({EventTypes.Member: frozenset({e5.state_key})}),
include_others=True,
),
)

Expand All @@ -454,7 +464,8 @@ def test_get_state_for_event(self):
self.state_datastore._state_group_cache,
group,
state_filter=StateFilter(
types={EventTypes.Member: {e5.state_key}}, include_others=False
types=frozendict({EventTypes.Member: frozenset({e5.state_key})}),
include_others=False,
),
)

Expand All @@ -465,7 +476,8 @@ def test_get_state_for_event(self):
self.state_datastore._state_group_members_cache,
group,
state_filter=StateFilter(
types={EventTypes.Member: {e5.state_key}}, include_others=False
types=frozendict({EventTypes.Member: frozenset({e5.state_key})}),
include_others=False,
),
)

Expand Down

0 comments on commit 8eb7cb2

Please sign in to comment.