Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Handle threads when fetching events for push.
Browse files Browse the repository at this point in the history
  • Loading branch information
clokep committed Sep 26, 2022
1 parent 77c6dc7 commit 5a8e1ef
Show file tree
Hide file tree
Showing 3 changed files with 97 additions and 40 deletions.
1 change: 1 addition & 0 deletions changelog.d/13878.feature
@@ -0,0 +1 @@
Experimental support for thread-specific receipts ([MSC3771](https://github.com/matrix-org/matrix-spec-proposals/pull/3771)).
80 changes: 57 additions & 23 deletions synapse/storage/databases/main/event_push_actions.py
Expand Up @@ -120,6 +120,32 @@
]


@attr.s(slots=True, auto_attribs=True)
class _RoomReceipt:
"""
HttpPushAction instances include the information used to generate HTTP
requests to a push gateway.
"""

unthreaded_stream_ordering: int = 0
# threaded_stream_ordering includes the main pseudo-thread.
threaded_stream_ordering: Dict[str, int] = attr.Factory(dict)

def is_unread(self, thread_id: str, stream_ordering: int) -> bool:
"""Returns True if the stream ordering is unread according to the receipt information."""

# Only include push actions with a stream ordering after both the unthreaded
# and threaded receipt. Properly handles a user without any receipts present.
return (
self.unthreaded_stream_ordering < stream_ordering
and self.threaded_stream_ordering.get(thread_id, 0) < stream_ordering
)


# A _RoomReceipt with no receipts in it.
MISSING_ROOM_RECEIPT = _RoomReceipt()


@attr.s(slots=True, frozen=True, auto_attribs=True)
class HttpPushAction:
"""
Expand Down Expand Up @@ -705,7 +731,7 @@ def f(txn: LoggingTransaction) -> List[str]:

def _get_receipts_by_room_txn(
self, txn: LoggingTransaction, user_id: str
) -> Dict[str, int]:
) -> Dict[str, _RoomReceipt]:
"""
Generate a map of room ID to the latest stream ordering that has been
read by the given user.
Expand All @@ -715,7 +741,8 @@ def _get_receipts_by_room_txn(
user_id: The user to fetch receipts for.
Returns:
A map of room ID to stream ordering for all rooms the user has a receipt in.
A map including all rooms the user is in with a receipt. It maps
room IDs to _RoomReceipt instances
"""
receipt_types_clause, args = make_in_list_sql_clause(
self.database_engine,
Expand All @@ -727,20 +754,26 @@ def _get_receipts_by_room_txn(
)

sql = f"""
SELECT room_id, MAX(stream_ordering)
SELECT room_id, thread_id, MAX(stream_ordering)
FROM receipts_linearized
INNER JOIN events USING (room_id, event_id)
WHERE {receipt_types_clause}
AND user_id = ?
GROUP BY room_id
GROUP BY room_id, thread_id
"""

args.extend((user_id,))
txn.execute(sql, args)
return {
room_id: latest_stream_ordering
for room_id, latest_stream_ordering in txn.fetchall()
}

result: Dict[str, _RoomReceipt] = {}
for room_id, thread_id, stream_ordering in txn:
room_receipt = result.setdefault(room_id, _RoomReceipt())
if thread_id is None:
room_receipt.unthreaded_stream_ordering = stream_ordering
else:
room_receipt.threaded_stream_ordering[thread_id] = stream_ordering

return result

async def get_unread_push_actions_for_user_in_range_for_http(
self,
Expand Down Expand Up @@ -773,9 +806,10 @@ async def get_unread_push_actions_for_user_in_range_for_http(

def get_push_actions_txn(
txn: LoggingTransaction,
) -> List[Tuple[str, str, int, str, bool]]:
) -> List[Tuple[str, str, str, int, str, bool]]:
sql = """
SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions, ep.highlight
SELECT ep.event_id, ep.room_id, ep.thread_id, ep.stream_ordering,
ep.actions, ep.highlight
FROM event_push_actions AS ep
WHERE
ep.user_id = ?
Expand All @@ -785,7 +819,7 @@ def get_push_actions_txn(
ORDER BY ep.stream_ordering ASC LIMIT ?
"""
txn.execute(sql, (user_id, min_stream_ordering, max_stream_ordering, limit))
return cast(List[Tuple[str, str, int, str, bool]], txn.fetchall())
return cast(List[Tuple[str, str, str, int, str, bool]], txn.fetchall())

push_actions = await self.db_pool.runInteraction(
"get_unread_push_actions_for_user_in_range_http", get_push_actions_txn
Expand All @@ -798,10 +832,10 @@ def get_push_actions_txn(
stream_ordering=stream_ordering,
actions=_deserialize_action(actions, highlight),
)
for event_id, room_id, stream_ordering, actions, highlight in push_actions
# Only include push actions with a stream ordering after any receipt, or without any
# receipt present (invited to but never read rooms).
if stream_ordering > receipts_by_room.get(room_id, 0)
for event_id, room_id, thread_id, stream_ordering, actions, highlight in push_actions
if receipts_by_room.get(room_id, MISSING_ROOM_RECEIPT).is_unread(
thread_id, stream_ordering
)
]

# Now sort it so it's ordered correctly, since currently it will
Expand Down Expand Up @@ -845,10 +879,10 @@ async def get_unread_push_actions_for_user_in_range_for_email(

def get_push_actions_txn(
txn: LoggingTransaction,
) -> List[Tuple[str, str, int, str, bool, int]]:
) -> List[Tuple[str, str, str, int, str, bool, int]]:
sql = """
SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions,
ep.highlight, e.received_ts
SELECT ep.event_id, ep.room_id, ep.thread_id, ep.stream_ordering,
ep.actions, ep.highlight, e.received_ts
FROM event_push_actions AS ep
INNER JOIN events AS e USING (room_id, event_id)
WHERE
Expand All @@ -859,7 +893,7 @@ def get_push_actions_txn(
ORDER BY ep.stream_ordering DESC LIMIT ?
"""
txn.execute(sql, (user_id, min_stream_ordering, max_stream_ordering, limit))
return cast(List[Tuple[str, str, int, str, bool, int]], txn.fetchall())
return cast(List[Tuple[str, str, str, int, str, bool, int]], txn.fetchall())

push_actions = await self.db_pool.runInteraction(
"get_unread_push_actions_for_user_in_range_email", get_push_actions_txn
Expand All @@ -874,10 +908,10 @@ def get_push_actions_txn(
actions=_deserialize_action(actions, highlight),
received_ts=received_ts,
)
for event_id, room_id, stream_ordering, actions, highlight, received_ts in push_actions
# Only include push actions with a stream ordering after any receipt, or without any
# receipt present (invited to but never read rooms).
if stream_ordering > receipts_by_room.get(room_id, 0)
for event_id, room_id, thread_id, stream_ordering, actions, highlight, received_ts in push_actions
if receipts_by_room.get(room_id, MISSING_ROOM_RECEIPT).is_unread(
thread_id, stream_ordering
)
]

# Now sort it so it's ordered correctly, since currently it will
Expand Down
56 changes: 39 additions & 17 deletions tests/storage/test_event_push_actions.py
Expand Up @@ -16,6 +16,7 @@

from twisted.test.proto_helpers import MemoryReactor

from synapse.api.constants import RelationTypes
from synapse.rest import admin
from synapse.rest.client import login, room
from synapse.server import HomeServer
Expand Down Expand Up @@ -65,16 +66,23 @@ def test_get_unread_push_actions_for_user_in_range(self) -> None:
user_id, token, _, other_token, room_id = self._create_users_and_room()

# Create two events, one of which is a highlight.
self.helper.send_event(
first_event_id = self.helper.send_event(
room_id,
type="m.room.message",
content={"msgtype": "m.text", "body": "msg"},
tok=other_token,
)
event_id = self.helper.send_event(
)["event_id"]
second_event_id = self.helper.send_event(
room_id,
type="m.room.message",
content={"msgtype": "m.text", "body": user_id},
content={
"msgtype": "m.text",
"body": user_id,
"m.relates_to": {
"rel_type": RelationTypes.THREAD,
"event_id": first_event_id,
},
},
tok=other_token,
)["event_id"]

Expand All @@ -94,13 +102,13 @@ def test_get_unread_push_actions_for_user_in_range(self) -> None:
)
self.assertEqual(2, len(email_actions))

# Send a receipt, which should clear any actions.
# Send a receipt, which should clear the first action.
self.get_success(
self.store.insert_receipt(
room_id,
"m.read",
user_id=user_id,
event_ids=[event_id],
event_ids=[first_event_id],
thread_id=None,
data={},
)
Expand All @@ -110,6 +118,30 @@ def test_get_unread_push_actions_for_user_in_range(self) -> None:
user_id, 0, 1000, 20
)
)
self.assertEqual(1, len(http_actions))
email_actions = self.get_success(
self.store.get_unread_push_actions_for_user_in_range_for_email(
user_id, 0, 1000, 20
)
)
self.assertEqual(1, len(email_actions))

# Send a thread receipt to clear the thread action.
self.get_success(
self.store.insert_receipt(
room_id,
"m.read",
user_id=user_id,
event_ids=[second_event_id],
thread_id=first_event_id,
data={},
)
)
http_actions = self.get_success(
self.store.get_unread_push_actions_for_user_in_range_for_http(
user_id, 0, 1000, 20
)
)
self.assertEqual([], http_actions)
email_actions = self.get_success(
self.store.get_unread_push_actions_for_user_in_range_for_email(
Expand Down Expand Up @@ -416,17 +448,7 @@ def test_count_aggregation_mixed(self) -> None:
sends both unthreaded and threaded receipts.
"""

# Create a user to receive notifications and send receipts.
user_id = self.register_user("user1235", "pass")
token = self.login("user1235", "pass")

# And another users to send events.
other_id = self.register_user("other", "pass")
other_token = self.login("other", "pass")

# Create a room and put both users in it.
room_id = self.helper.create_room_as(user_id, tok=token)
self.helper.join(room_id, other_id, tok=other_token)
user_id, token, _, other_token, room_id = self._create_users_and_room()
thread_id: str

last_event_id: str
Expand Down

0 comments on commit 5a8e1ef

Please sign in to comment.