Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Keep fallback key marked as used if it's re-uploaded #11382

Merged
merged 3 commits into from
Nov 19, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/11382.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Keep fallback key marked as used if it's re-uploaded.
51 changes: 40 additions & 11 deletions synapse/storage/databases/main/end_to_end_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -408,29 +408,58 @@ async def set_e2e_fallback_keys(
fallback_keys: the keys to set. This is a map from key ID (which is
of the form "algorithm:id") to key data.
"""
await self.db_pool.runInteraction(
"set_e2e_fallback_keys_txn",
self._set_e2e_fallback_keys_txn,
user_id,
device_id,
fallback_keys,
)

await self.invalidate_cache_and_stream(
"get_e2e_unused_fallback_key_types", (user_id, device_id)
)

def _set_e2e_fallback_keys_txn(
self, txn: Connection, user_id: str, device_id: str, fallback_keys: JsonDict
) -> None:
# fallback_keys will usually only have one item in it, so using a for
# loop (as opposed to calling simple_upsert_many_txn) won't be too bad
# FIXME: make sure that only one key per algorithm is uploaded
for key_id, fallback_key in fallback_keys.items():
algorithm, key_id = key_id.split(":", 1)
await self.db_pool.simple_upsert(
"e2e_fallback_keys_json",
old_key_json = self.db_pool.simple_select_one_onecol_txn(
txn,
table="e2e_fallback_keys_json",
keyvalues={
"user_id": user_id,
"device_id": device_id,
"algorithm": algorithm,
},
values={
"key_id": key_id,
"key_json": json_encoder.encode(fallback_key),
"used": False,
},
desc="set_e2e_fallback_key",
retcol="key_json",
allow_none=True,
)

await self.invalidate_cache_and_stream(
"get_e2e_unused_fallback_key_types", (user_id, device_id)
)
new_key_json = encode_canonical_json(fallback_key).decode("utf-8")

# If the uploaded key is the same as the current fallback key,
# don't do anything. This prevents marking the key as unused if it
# was already used.
if old_key_json != new_key_json:
self.db_pool.simple_upsert_txn(
txn,
table="e2e_fallback_keys_json",
keyvalues={
"user_id": user_id,
"device_id": device_id,
"algorithm": algorithm,
},
values={
"key_id": key_id,
"key_json": json_encoder.encode(fallback_key),
"used": False,
},
)

@cached(max_entries=10000)
async def get_e2e_unused_fallback_key_types(
Expand Down
32 changes: 31 additions & 1 deletion tests/handlers/test_e2e_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,7 @@ def test_fallback_key(self):
local_user = "@boris:" + self.hs.hostname
device_id = "xyz"
fallback_key = {"alg1:k1": "key1"}
fallback_key2 = {"alg1:k2": "key2"}
otk = {"alg1:k2": "key2"}

# we shouldn't have any unused fallback keys yet
Expand Down Expand Up @@ -213,6 +214,35 @@ def test_fallback_key(self):
{"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key}}},
)

# re-uploading the same fallback key should still result in no unused fallback
# keys
self.get_success(
self.handler.upload_keys_for_user(
local_user,
device_id,
{"org.matrix.msc2732.fallback_keys": fallback_key},
)
)

res = self.get_success(
self.store.get_e2e_unused_fallback_key_types(local_user, device_id)
)
self.assertEqual(res, [])

# uploading a new fallback key should result in an unused fallback key
self.get_success(
self.handler.upload_keys_for_user(
local_user,
device_id,
{"org.matrix.msc2732.fallback_keys": fallback_key2},
)
)

res = self.get_success(
self.store.get_e2e_unused_fallback_key_types(local_user, device_id)
)
self.assertEqual(res, ["alg1"])

# if the user uploads a one-time key, the next claim should fetch the
# one-time key, and then go back to the fallback
self.get_success(
Expand All @@ -238,7 +268,7 @@ def test_fallback_key(self):
)
self.assertEqual(
res,
{"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key}}},
{"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key2}}},
)

def test_replace_master_key(self):
Expand Down