Change current_state_delta stream to use min(stream_id)
Fixes a race in the `get_rooms_for_user_with_stream_ordering` cache
invalidation. Hopefully, won't break anything else.
richvdh committed Mar 28, 2019
1 parent a6ee317 commit 10027b6
Showing 4 changed files with 190 additions and 55 deletions.
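For orientation, here is a minimal sketch of the race being fixed (illustrative stream_ids only; this is not Synapse code). Workers apply replication rows in stream_id order, and a multi-event persistence batch produces a single current_state_delta row; if that row carries the batch's maximum stream_id, a worker can process an early event from the batch while its membership caches are still stale.

# Illustrative only: hypothetical stream_ids for one persistence batch.
batch_stream_ids = [101, 102, 103]    # e.g. a join at 101, a message at 103

old_delta_id = max(batch_stream_ids)  # before this commit: 103
new_delta_id = min(batch_stream_ids)  # after this commit: 101

# A worker whose replication token has reached the join event (101) should
# also have seen the cache invalidation for the state change it implies.
worker_token = 101
assert old_delta_id > worker_token    # old scheme: invalidation not yet seen
assert new_delta_id <= worker_token   # new scheme: invalidation already seen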
22 changes: 14 additions & 8 deletions synapse/storage/events.py
@@ -79,7 +79,7 @@ def encode_json(json_object):
"""
out = frozendict_json_encoder.encode(json_object)
if isinstance(out, bytes):
- out = out.decode('utf8')
out = out.decode("utf8")
return out


@@ -813,9 +813,10 @@ def _persist_events_txn(
"""
all_events_and_contexts = events_and_contexts

min_stream_order = events_and_contexts[0][0].internal_metadata.stream_ordering
max_stream_order = events_and_contexts[-1][0].internal_metadata.stream_ordering

- self._update_current_state_txn(txn, state_delta_for_room, max_stream_order)
self._update_current_state_txn(txn, state_delta_for_room, min_stream_order)

self._update_forward_extremities_txn(
txn,
@@ -890,7 +891,7 @@ def _persist_events_txn(
backfilled=backfilled,
)

- def _update_current_state_txn(self, txn, state_delta_by_room, max_stream_order):
def _update_current_state_txn(self, txn, state_delta_by_room, stream_id):
for room_id, current_state_tuple in iteritems(state_delta_by_room):
to_delete, to_insert = current_state_tuple

@@ -899,6 +900,12 @@ def _update_current_state_txn(self, txn, state_delta_by_room, max_stream_order):
# that we can use it to calculate the `prev_event_id`. (This
# allows us to not have to pull out the existing state
# unnecessarily).
#
# The stream_id for the update is chosen to be the minimum of the stream_ids
# for the batch of events that we are persisting; that means we do not
# end up in a situation where workers see events before the
# current_state_delta updates.
#
sql = """
INSERT INTO current_state_delta_stream
(stream_id, room_id, type, state_key, event_id, prev_event_id)
@@ -911,7 +918,7 @@ def _update_current_state_txn(self, txn, state_delta_by_room, max_stream_order):
sql,
(
(
- max_stream_order,
stream_id,
room_id,
etype,
state_key,
@@ -929,7 +936,7 @@ def _update_current_state_txn(self, txn, state_delta_by_room, max_stream_order):
sql,
(
(
- max_stream_order,
stream_id,
room_id,
etype,
state_key,
@@ -970,7 +977,7 @@ def _update_current_state_txn(self, txn, state_delta_by_room, max_stream_order):
txn.call_after(
self._curr_state_delta_stream_cache.entity_has_changed,
room_id,
- max_stream_order,
stream_id,
)

# Invalidate the various caches
@@ -988,8 +995,7 @@ def _update_current_state_txn(self, txn, state_delta_by_room, max_stream_order):

for member in members_changed:
txn.call_after(
- self.get_rooms_for_user_with_stream_ordering.invalidate,
- (member,),
self.get_rooms_for_user_with_stream_ordering.invalidate, (member,)
)

self._invalidate_state_caches_and_stream(txn, room_id, members_changed)
28 changes: 24 additions & 4 deletions tests/replication/slave/storage/_base.py
@@ -56,7 +56,9 @@ def prepare(self, reactor, clock, hs):
client = client_factory.buildProtocol(None)

client.makeConnection(FakeTransport(server, reactor))
- server.makeConnection(FakeTransport(client, reactor))

self.server_to_client_transport = FakeTransport(client, reactor)
server.makeConnection(self.server_to_client_transport)

def replicate(self):
"""Tell the master side of replication that something has happened, and then
@@ -69,6 +71,24 @@ def check(self, method, args, expected_result=None):
master_result = self.get_success(getattr(self.master_store, method)(*args))
slaved_result = self.get_success(getattr(self.slaved_store, method)(*args))
if expected_result is not None:
- self.assertEqual(master_result, expected_result)
- self.assertEqual(slaved_result, expected_result)
- self.assertEqual(master_result, slaved_result)
self.assertEqual(
master_result,
expected_result,
"Expected master result to be %r but was %r" % (
expected_result, master_result
),
)
self.assertEqual(
slaved_result,
expected_result,
"Expected slave result to be %r but was %r" % (
expected_result, slaved_result
),
)
self.assertEqual(
master_result,
slaved_result,
"Slave result %r does not match master result %r" % (
slaved_result, master_result
),
)
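A usage sketch (the method name and arguments are illustrative, mirroring how the test below exercises it): check() runs the same store method on both stores and, with the messages added here, reports which side diverged.

# Sketch: compare master and slave views of a store method after replication.
self.replicate()
self.check(
    "get_rooms_for_user_with_stream_ordering",  # store method to call on both
    (USER_ID_2,),                               # positional arguments
    set(),                                      # expected result on each side
)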
139 changes: 117 additions & 22 deletions tests/replication/slave/storage/test_events.py
@@ -11,11 +11,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging

from canonicaljson import encode_canonical_json

from synapse.events import FrozenEvent, _EventInternalMetadata
from synapse.events.snapshot import EventContext
from synapse.handlers.room import RoomEventSource
from synapse.replication.slave.storage.events import SlavedEventStore
from synapse.storage.roommember import RoomsForUser

@@ -26,6 +28,8 @@
OUTLIER = {"outlier": True}
ROOM_ID = "!room:blue"

logger = logging.getLogger(__name__)


def dict_equals(self, other):
me = encode_canonical_json(self.get_pdu_json())
@@ -191,17 +195,123 @@ def test_get_rooms_for_user_with_stream_ordering(self):
{(ROOM_ID, j2.internal_metadata.stream_ordering)},
)

def test_get_rooms_for_user_with_stream_ordering_with_multi_event_persist(self):
"""Check that current_state invalidation happens correctly with multiple events
in the persistence batch.
This test attempts to reproduce a race condition between the event persistence
loop and a worker-based Sync handler.
The problem occurred when the master persisted several events in one batch. It
only updates the current_state at the end of each batch, so the obvious thing
to do is then to issue a current_state_delta stream update corresponding to the
last stream_id in the batch.
However, that raises the possibility that a worker will see the replication
notification for a join event before the current_state caches are invalidated.
The test involves:
* creating a join and a message event for a user, and persisting them in the
same batch
* controlling the replication stream so that updates are sent gradually
* between each bunch of replication updates, check that we see a consistent
snapshot of the state.
"""
self.persist(type="m.room.create", key="", creator=USER_ID)
self.persist(type="m.room.member", key=USER_ID, membership="join")
self.replicate()
self.check("get_rooms_for_user_with_stream_ordering", (USER_ID_2,), set())

# limit the replication rate
repl_transport = self.server_to_client_transport
repl_transport.autoflush = False

# build the join and message events and persist them in the same batch.
logger.info("----- build test events ------")
j2, j2ctx = self.build_event(
type="m.room.member", sender=USER_ID_2, key=USER_ID_2, membership="join"
)
msg, msgctx = self.build_event()
self.get_success(self.master_store.persist_events([
(j2, j2ctx),
(msg, msgctx),
]))
self.replicate()

event_source = RoomEventSource(self.hs)
event_source.store = self.slaved_store
current_token = self.get_success(event_source.get_current_key())

# gradually stream out the replication
while repl_transport.buffer:
logger.info("------ flush ------")
repl_transport.flush(30)
self.pump(0)

prev_token = current_token
current_token = self.get_success(event_source.get_current_key())

# attempt to replicate the behaviour of the sync handler.
#
# First, we get a list of the rooms we are joined to
joined_rooms = self.get_success(
self.slaved_store.get_rooms_for_user_with_stream_ordering(
USER_ID_2,
),
)

# Then, we get a list of the events since the last sync
membership_changes = self.get_success(
self.slaved_store.get_membership_changes_for_user(
USER_ID_2, prev_token, current_token,
)
)

logger.info(
"%s->%s: joined_rooms=%r membership_changes=%r",
prev_token,
current_token,
joined_rooms,
membership_changes,
)

# the membership change is only of any use to us if the room is in the
# joined_rooms list.
if membership_changes:
self.assertEqual(
joined_rooms, {(ROOM_ID, j2.internal_metadata.stream_ordering)}
)

event_id = 0

- def persist(
def persist(self, backfill=False, **kwargs):
"""
Returns:
synapse.events.FrozenEvent: The event that was persisted.
"""
event, context = self.build_event(**kwargs)

if backfill:
self.get_success(
self.master_store.persist_events([(event, context)], backfilled=True)
)
else:
self.get_success(
self.master_store.persist_event(event, context)
)

return event

def build_event(
self,
sender=USER_ID,
room_id=ROOM_ID,
type="m.room.message",
key=None,
internal={},
state=None,
- backfill=False,
depth=None,
prev_events=[],
auth_events=[],
@@ -210,10 +320,7 @@ def persist(
push_actions=[],
**content
):
"""
Returns:
synapse.events.FrozenEvent: The event that was persisted.
"""

if depth is None:
depth = self.event_id

@@ -252,23 +359,11 @@ def persist(
)
else:
state_handler = self.hs.get_state_handler()
- context = self.get_success(state_handler.compute_event_context(event))
context = self.get_success(state_handler.compute_event_context(
event
))

self.master_store.add_push_actions_to_staging(
event.event_id, {user_id: actions for user_id, actions in push_actions}
)

- ordering = None
- if backfill:
- self.get_success(
- self.master_store.persist_events([(event, context)], backfilled=True)
- )
- else:
- ordering, _ = self.get_success(
- self.master_store.persist_event(event, context)
- )

- if ordering:
- event.internal_metadata.stream_ordering = ordering

- return event
return event, context
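A hedged sketch of the two call paths after this split (the event types and the body= content key are illustrative): persist() chooses between the single-event path and the backfill path, while build_event() constructs an event and its context without persisting anything, which is what lets the new test hand several events to persist_events in one batch.

# Sketch: forward persistence vs. backfill, plus a manual batch.
msg = self.persist(type="m.room.message", body="hello")  # persist_event
old = self.persist(type="m.room.message", body="old", backfill=True)
# the backfill path goes through persist_events(..., backfilled=True)

ev, ctx = self.build_event(type="m.room.message", body="batched")
# nothing persisted yet; hand the pair to the store as part of a batch
self.get_success(self.master_store.persist_events([(ev, ctx)]))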
56 changes: 35 additions & 21 deletions tests/server.py
@@ -365,6 +365,7 @@ class FakeTransport(object):
disconnected = False
buffer = attr.ib(default=b'')
producer = attr.ib(default=None)
autoflush = attr.ib(default=True)

def getPeer(self):
return None
@@ -415,31 +416,44 @@ def _produce():
def write(self, byt):
self.buffer = self.buffer + byt

- def _write():
- if not self.buffer:
- # nothing to do. Don't write empty buffers: it upsets the
- # TLSMemoryBIOProtocol
- return

- if self.disconnected:
- return
- logger.info("%s->%s: %s", self._protocol, self.other, self.buffer)

- if getattr(self.other, "transport") is not None:
- try:
- self.other.dataReceived(self.buffer)
- self.buffer = b""
- except Exception as e:
- logger.warning("Exception writing to protocol: %s", e)
- return

- self._reactor.callLater(0.0, _write)

# always actually do the write asynchronously. Some protocols (notably the
# TLSMemoryBIOProtocol) get very confused if a read comes back while they are
# still doing a write. Doing a callLater here breaks the cycle.
- self._reactor.callLater(0.0, _write)
if self.autoflush:
self._reactor.callLater(0.0, self.flush)

def writeSequence(self, seq):
for x in seq:
self.write(x)

def flush(self, maxbytes=None):
if not self.buffer:
# nothing to do. Don't write empty buffers: it upsets the
# TLSMemoryBIOProtocol
return

if self.disconnected:
return

if getattr(self.other, "transport") is None:
# the other has no transport yet; reschedule
if self.autoflush:
self._reactor.callLater(0.0, self.flush)
return

if maxbytes is not None:
to_write = self.buffer[:maxbytes]
else:
to_write = self.buffer

logger.info("%s->%s: %s", self._protocol, self.other, to_write)

try:
self.other.dataReceived(to_write)
except Exception as e:
logger.warning("Exception writing to protocol: %s", e)
return

self.buffer = self.buffer[len(to_write):]
if self.buffer and self.autoflush:
self._reactor.callLater(0.0, self.flush)
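A short usage sketch of the new knobs (mirroring the test above; the 30-byte chunk size is arbitrary): with autoflush disabled, write() only appends to buffer, and an explicit flush() delivers at most maxbytes at a time, so a test can interleave assertions with partial replication.

# Sketch: drip-feed buffered replication traffic to the other protocol.
transport = self.server_to_client_transport
transport.autoflush = False    # write() now only buffers

# ... trigger some replication traffic ...

while transport.buffer:
    transport.flush(30)        # deliver at most 30 bytes
    self.pump(0)               # let any scheduled callLater(0) calls run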
