Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Make StateFilter frozen so we can hash it #10816

Merged
merged 3 commits into from
Sep 14, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/10816.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Make `StateFilter` frozen so it is hashable.
1 change: 1 addition & 0 deletions mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ files =
tests/handlers/test_sync.py,
tests/rest/client/test_login.py,
tests/rest/client/test_auth.py,
tests/storage/test_state.py,
tests/util/test_itertools.py,
tests/util/test_stream_change_cache.py

Expand Down
45 changes: 32 additions & 13 deletions synapse/storage/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,15 @@
)

import attr
from frozendict import frozendict

from synapse.api.constants import EventTypes
from synapse.events import EventBase
from synapse.types import MutableStateMap, StateMap

if TYPE_CHECKING:
from typing import FrozenSet # noqa: used within quoted type hint; flake8 sad

from synapse.server import HomeServer
from synapse.storage.databases import Databases

Expand All @@ -40,7 +43,7 @@
T = TypeVar("T")


@attr.s(slots=True)
@attr.s(slots=True, frozen=True)
class StateFilter:
"""A filter used when querying for state.

Expand All @@ -53,14 +56,19 @@ class StateFilter:
appear in `types`.
"""

types = attr.ib(type=Dict[str, Optional[Set[str]]])
types = attr.ib(type="frozendict[str, Optional[FrozenSet[str]]]")
include_others = attr.ib(default=False, type=bool)

def __attrs_post_init__(self):
# If `include_others` is set we canonicalise the filter by removing
# wildcards from the types dictionary
if self.include_others:
self.types = {k: v for k, v in self.types.items() if v is not None}
# this is needed to work around the fact that StateFilter is frozen
object.__setattr__(
self,
"types",
frozendict({k: v for k, v in self.types.items() if v is not None}),
)

@staticmethod
def all() -> "StateFilter":
Expand All @@ -69,7 +77,7 @@ def all() -> "StateFilter":
Returns:
The new state filter.
"""
return StateFilter(types={}, include_others=True)
return StateFilter(types=frozendict(), include_others=True)

@staticmethod
def none() -> "StateFilter":
Expand All @@ -78,7 +86,7 @@ def none() -> "StateFilter":
Returns:
The new state filter.
"""
return StateFilter(types={}, include_others=False)
return StateFilter(types=frozendict(), include_others=False)

@staticmethod
def from_types(types: Iterable[Tuple[str, Optional[str]]]) -> "StateFilter":
Expand All @@ -103,7 +111,12 @@ def from_types(types: Iterable[Tuple[str, Optional[str]]]) -> "StateFilter":

type_dict.setdefault(typ, set()).add(s) # type: ignore

return StateFilter(types=type_dict)
return StateFilter(
types=frozendict(
(k, frozenset(v) if v is not None else None)
for k, v in type_dict.items()
)
)

@staticmethod
def from_lazy_load_member_list(members: Iterable[str]) -> "StateFilter":
Expand All @@ -116,7 +129,10 @@ def from_lazy_load_member_list(members: Iterable[str]) -> "StateFilter":
Returns:
The new state filter
"""
return StateFilter(types={EventTypes.Member: set(members)}, include_others=True)
return StateFilter(
types=frozendict({EventTypes.Member: frozenset(members)}),
include_others=True,
)

def return_expanded(self) -> "StateFilter":
"""Creates a new StateFilter where type wild cards have been removed
Expand Down Expand Up @@ -173,7 +189,7 @@ def return_expanded(self) -> "StateFilter":
# We want to return all non-members, but only particular
# memberships
return StateFilter(
types={EventTypes.Member: self.types[EventTypes.Member]},
types=frozendict({EventTypes.Member: self.types[EventTypes.Member]}),
include_others=True,
)

Expand Down Expand Up @@ -245,14 +261,15 @@ def max_entries_returned(self) -> Optional[int]:

return len(self.concrete_types())

def filter_state(self, state_dict: StateMap[T]) -> StateMap[T]:
"""Returns the state filtered with by this StateFilter
def filter_state(self, state_dict: StateMap[T]) -> MutableStateMap[T]:
"""Returns the state filtered with by this StateFilter.

Args:
state: The state map to filter

Returns:
The filtered state map
The filtered state map.
This is a copy, so it's safe to mutate.
"""
if self.is_full():
return dict(state_dict)
Expand Down Expand Up @@ -324,14 +341,16 @@ def get_member_split(self) -> Tuple["StateFilter", "StateFilter"]:
if state_keys is None:
member_filter = StateFilter.all()
else:
member_filter = StateFilter({EventTypes.Member: state_keys})
member_filter = StateFilter(frozendict({EventTypes.Member: state_keys}))
elif self.include_others:
member_filter = StateFilter.all()
else:
member_filter = StateFilter.none()

non_member_filter = StateFilter(
types={k: v for k, v in self.types.items() if k != EventTypes.Member},
types=frozendict(
{k: v for k, v in self.types.items() if k != EventTypes.Member}
),
include_others=self.include_others,
)

Expand Down
46 changes: 29 additions & 17 deletions tests/storage/test_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

import logging

from frozendict import frozendict

from synapse.api.constants import EventTypes, Membership
from synapse.api.room_versions import RoomVersions
from synapse.storage.state import StateFilter
Expand Down Expand Up @@ -183,7 +185,9 @@ def test_get_state_for_event(self):
self.storage.state.get_state_for_event(
e5.event_id,
state_filter=StateFilter(
types={EventTypes.Member: {self.u_alice.to_string()}},
types=frozendict(
{EventTypes.Member: frozenset({self.u_alice.to_string()})}
),
include_others=True,
),
)
Expand All @@ -203,7 +207,8 @@ def test_get_state_for_event(self):
self.storage.state.get_state_for_event(
e5.event_id,
state_filter=StateFilter(
types={EventTypes.Member: set()}, include_others=True
types=frozendict({EventTypes.Member: frozenset()}),
include_others=True,
),
)
)
Expand All @@ -228,7 +233,7 @@ def test_get_state_for_event(self):
self.state_datastore._state_group_cache,
group,
state_filter=StateFilter(
types={EventTypes.Member: set()}, include_others=True
types=frozendict({EventTypes.Member: frozenset()}), include_others=True
),
)

Expand All @@ -245,7 +250,7 @@ def test_get_state_for_event(self):
self.state_datastore._state_group_members_cache,
group,
state_filter=StateFilter(
types={EventTypes.Member: set()}, include_others=True
types=frozendict({EventTypes.Member: frozenset()}), include_others=True
),
)

Expand All @@ -258,7 +263,7 @@ def test_get_state_for_event(self):
self.state_datastore._state_group_cache,
group,
state_filter=StateFilter(
types={EventTypes.Member: None}, include_others=True
types=frozendict({EventTypes.Member: None}), include_others=True
),
)

Expand All @@ -275,7 +280,7 @@ def test_get_state_for_event(self):
self.state_datastore._state_group_members_cache,
group,
state_filter=StateFilter(
types={EventTypes.Member: None}, include_others=True
types=frozendict({EventTypes.Member: None}), include_others=True
),
)

Expand All @@ -295,7 +300,8 @@ def test_get_state_for_event(self):
self.state_datastore._state_group_cache,
group,
state_filter=StateFilter(
types={EventTypes.Member: {e5.state_key}}, include_others=True
types=frozendict({EventTypes.Member: frozenset({e5.state_key})}),
include_others=True,
),
)

Expand All @@ -312,7 +318,8 @@ def test_get_state_for_event(self):
self.state_datastore._state_group_members_cache,
group,
state_filter=StateFilter(
types={EventTypes.Member: {e5.state_key}}, include_others=True
types=frozendict({EventTypes.Member: frozenset({e5.state_key})}),
include_others=True,
),
)

Expand All @@ -325,7 +332,8 @@ def test_get_state_for_event(self):
self.state_datastore._state_group_members_cache,
group,
state_filter=StateFilter(
types={EventTypes.Member: {e5.state_key}}, include_others=False
types=frozendict({EventTypes.Member: frozenset({e5.state_key})}),
include_others=False,
),
)

Expand Down Expand Up @@ -375,7 +383,7 @@ def test_get_state_for_event(self):
self.state_datastore._state_group_cache,
group,
state_filter=StateFilter(
types={EventTypes.Member: set()}, include_others=True
types=frozendict({EventTypes.Member: frozenset()}), include_others=True
),
)

Expand All @@ -387,7 +395,7 @@ def test_get_state_for_event(self):
self.state_datastore._state_group_members_cache,
group,
state_filter=StateFilter(
types={EventTypes.Member: set()}, include_others=True
types=frozendict({EventTypes.Member: frozenset()}), include_others=True
),
)

Expand All @@ -400,7 +408,7 @@ def test_get_state_for_event(self):
self.state_datastore._state_group_cache,
group,
state_filter=StateFilter(
types={EventTypes.Member: None}, include_others=True
types=frozendict({EventTypes.Member: None}), include_others=True
),
)

Expand All @@ -411,7 +419,7 @@ def test_get_state_for_event(self):
self.state_datastore._state_group_members_cache,
group,
state_filter=StateFilter(
types={EventTypes.Member: None}, include_others=True
types=frozendict({EventTypes.Member: None}), include_others=True
),
)

Expand All @@ -430,7 +438,8 @@ def test_get_state_for_event(self):
self.state_datastore._state_group_cache,
group,
state_filter=StateFilter(
types={EventTypes.Member: {e5.state_key}}, include_others=True
types=frozendict({EventTypes.Member: frozenset({e5.state_key})}),
include_others=True,
),
)

Expand All @@ -441,7 +450,8 @@ def test_get_state_for_event(self):
self.state_datastore._state_group_members_cache,
group,
state_filter=StateFilter(
types={EventTypes.Member: {e5.state_key}}, include_others=True
types=frozendict({EventTypes.Member: frozenset({e5.state_key})}),
include_others=True,
),
)

Expand All @@ -454,7 +464,8 @@ def test_get_state_for_event(self):
self.state_datastore._state_group_cache,
group,
state_filter=StateFilter(
types={EventTypes.Member: {e5.state_key}}, include_others=False
types=frozendict({EventTypes.Member: frozenset({e5.state_key})}),
include_others=False,
),
)

Expand All @@ -465,7 +476,8 @@ def test_get_state_for_event(self):
self.state_datastore._state_group_members_cache,
group,
state_filter=StateFilter(
types={EventTypes.Member: {e5.state_key}}, include_others=False
types=frozendict({EventTypes.Member: frozenset({e5.state_key})}),
include_others=False,
),
)

Expand Down