Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
update eventcontext unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
richvdh committed Nov 5, 2019
1 parent a4c9735 commit ba3a5e8
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 14 deletions.
10 changes: 7 additions & 3 deletions synapse/state/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import logging
from collections import namedtuple
from typing import Iterable, Optional

from six import iteritems, itervalues

Expand All @@ -27,6 +28,7 @@

from synapse.api.constants import EventTypes
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, StateResolutionVersions
from synapse.events import EventBase
from synapse.events.snapshot import EventContext
from synapse.logging.utils import log_function
from synapse.state import v1, v2
Expand Down Expand Up @@ -212,15 +214,17 @@ def get_hosts_in_room_at_events(self, room_id, event_ids):
return joined_hosts

@defer.inlineCallbacks
def compute_event_context(self, event, old_state=None):
def compute_event_context(
self, event: EventBase, old_state: Optional[Iterable[EventBase]] = None
):
"""Build an EventContext structure for the event.
This works out what the current state should be for the event, and
generates a new state group if necessary.
Args:
event (synapse.events.EventBase):
old_state (dict|None): The state at the event if it can't be
event:
old_state: The state at the event if it can't be
calculated from existing events. This is normally only specified
when receiving an event from federation where we don't have the
prev events for, e.g. when backfilling.
Expand Down
61 changes: 50 additions & 11 deletions tests/test_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from synapse.api.constants import EventTypes, Membership
from synapse.api.room_versions import RoomVersions
from synapse.events import FrozenEvent
from synapse.events.snapshot import EventContext
from synapse.state import StateHandler, StateResolutionHandler

from tests import unittest
Expand Down Expand Up @@ -198,16 +199,22 @@ def test_branch_no_conflict(self):

self.store.register_events(graph.walk())

context_store = {}
context_store = {} # type: dict[str, EventContext]

for event in graph.walk():
context = yield self.state.compute_event_context(event)
self.store.register_event_context(event, context)
context_store[event.event_id] = context

prev_state_ids = yield context_store["D"].get_prev_state_ids(self.store)
ctx_c = context_store["C"]
ctx_d = context_store["D"]

prev_state_ids = yield ctx_d.get_prev_state_ids(self.store)
self.assertEqual(2, len(prev_state_ids))

self.assertEqual(ctx_c.state_group, ctx_d.state_group_before_event)
self.assertEqual(ctx_d.state_group_before_event, ctx_d.state_group)

@defer.inlineCallbacks
def test_branch_basic_conflict(self):
graph = Graph(
Expand Down Expand Up @@ -241,12 +248,19 @@ def test_branch_basic_conflict(self):
self.store.register_event_context(event, context)
context_store[event.event_id] = context

prev_state_ids = yield context_store["D"].get_prev_state_ids(self.store)
# C ends up winning the resolution between B and C

ctx_c = context_store["C"]
ctx_d = context_store["D"]

prev_state_ids = yield ctx_d.get_prev_state_ids(self.store)
self.assertSetEqual(
{"START", "A", "C"}, {e_id for e_id in prev_state_ids.values()}
)

self.assertEqual(ctx_c.state_group, ctx_d.state_group_before_event)
self.assertEqual(ctx_d.state_group_before_event, ctx_d.state_group)

@defer.inlineCallbacks
def test_branch_have_banned_conflict(self):
graph = Graph(
Expand Down Expand Up @@ -292,11 +306,18 @@ def test_branch_have_banned_conflict(self):
self.store.register_event_context(event, context)
context_store[event.event_id] = context

prev_state_ids = yield context_store["E"].get_prev_state_ids(self.store)
# C ends up winning the resolution between C and D because bans win over other
# changes

ctx_c = context_store["C"]
ctx_e = context_store["E"]

prev_state_ids = yield ctx_e.get_prev_state_ids(self.store)
self.assertSetEqual(
{"START", "A", "B", "C"}, {e for e in prev_state_ids.values()}
)
self.assertEqual(ctx_c.state_group, ctx_e.state_group_before_event)
self.assertEqual(ctx_e.state_group_before_event, ctx_e.state_group)

@defer.inlineCallbacks
def test_branch_have_perms_conflict(self):
Expand Down Expand Up @@ -360,12 +381,20 @@ def test_branch_have_perms_conflict(self):
self.store.register_event_context(event, context)
context_store[event.event_id] = context

prev_state_ids = yield context_store["D"].get_prev_state_ids(self.store)
# B ends up winning the resolution between B and C because power levels
# win over other changes.

ctx_b = context_store["B"]
ctx_d = context_store["D"]

prev_state_ids = yield ctx_d.get_prev_state_ids(self.store)
self.assertSetEqual(
{"A1", "A2", "A3", "A5", "B"}, {e for e in prev_state_ids.values()}
)

self.assertEqual(ctx_b.state_group, ctx_d.state_group_before_event)
self.assertEqual(ctx_d.state_group_before_event, ctx_d.state_group)

def _add_depths(self, nodes, edges):
def _get_depth(ev):
node = nodes[ev]
Expand All @@ -390,13 +419,16 @@ def test_annotate_with_old_message(self):

context = yield self.state.compute_event_context(event, old_state=old_state)

current_state_ids = yield context.get_current_state_ids(self.store)
prev_state_ids = yield context.get_prev_state_ids(self.store)
self.assertCountEqual((e.event_id for e in old_state), prev_state_ids.values())

self.assertEqual(
set(e.event_id for e in old_state), set(current_state_ids.values())
current_state_ids = yield context.get_current_state_ids(self.store)
self.assertCountEqual(
(e.event_id for e in old_state), current_state_ids.values()
)

self.assertIsNotNone(context.state_group)
self.assertIsNotNone(context.state_group_before_event)
self.assertEqual(context.state_group_before_event, context.state_group)

@defer.inlineCallbacks
def test_annotate_with_old_state(self):
Expand All @@ -411,11 +443,18 @@ def test_annotate_with_old_state(self):
context = yield self.state.compute_event_context(event, old_state=old_state)

prev_state_ids = yield context.get_prev_state_ids(self.store)
self.assertCountEqual((e.event_id for e in old_state), prev_state_ids.values())

self.assertEqual(
set(e.event_id for e in old_state), set(prev_state_ids.values())
current_state_ids = yield context.get_current_state_ids(self.store)
self.assertCountEqual(
(e.event_id for e in old_state + [event]), current_state_ids.values()
)

self.assertIsNotNone(context.state_group_before_event)
self.assertNotEqual(context.state_group_before_event, context.state_group)
self.assertEqual(context.state_group_before_event, context.prev_group)
self.assertEqual({("state", ""): event.event_id}, context.delta_ids)

@defer.inlineCallbacks
def test_trivial_annotate_message(self):
prev_event_id = "prev_event_id"
Expand Down

0 comments on commit ba3a5e8

Please sign in to comment.