Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions matrix_client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -565,6 +565,7 @@ def _mkroom(self, room_id):
self.rooms[room_id] = room
return self.rooms[room_id]

# TODO better handling of the blocking I/O caused by update_one_time_key_counts
def _sync(self, timeout_ms=30000):
response = self.api.sync(self.sync_token, timeout_ms, filter=self.sync_filter)
self.sync_token = response["next_batch"]
Expand All @@ -583,6 +584,10 @@ def _sync(self, timeout_ms=30000):
if room_id in self.rooms:
del self.rooms[room_id]

if self._encryption and 'device_one_time_keys_count' in response:
self.olm_device.update_one_time_key_counts(
response['device_one_time_keys_count'])

for room_id, sync_room in response['rooms']['join'].items():
if room_id not in self.rooms:
self._mkroom(room_id)
Expand Down
27 changes: 25 additions & 2 deletions matrix_client/crypto/olm_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,26 @@ class OlmDevice(object):
``0`` means only unsigned keys. The actual amount of keys is determined at
runtime from the given proportion and the maximum number of one-time keys
we can physically hold.
keys_threshold (float): Optional. Threshold below which a one-time key
replenishment is triggered. Must be between ``0`` and ``1``. For example,
``0.1`` means that new one-time keys will be uploaded when there is less than
10% of the maximum number of one-time keys on the server.
"""

_olm_algorithm = 'm.olm.v1.curve25519-aes-sha2'
_megolm_algorithm = 'm.megolm.v1.aes-sha2'
_algorithms = [_olm_algorithm, _megolm_algorithm]

def __init__(self, api, user_id, device_id, signed_keys_proportion=1):
def __init__(self,
api,
user_id,
device_id,
signed_keys_proportion=1,
keys_threshold=0.1):
if not 0 <= signed_keys_proportion <= 1:
raise ValueError('signed_keys_proportion must be between 0 and 1.')
if not 0 <= keys_threshold <= 1:
raise ValueError('keys_threshold must be between 0 and 1.')
self.api = api
check_user_id(user_id)
self.user_id = user_id
Expand All @@ -46,7 +57,8 @@ def __init__(self, api, user_id, device_id, signed_keys_proportion=1):
# and it starts discarding keys, starting by the oldest.
target_keys_number = self.olm_account.max_one_time_keys // 2
self.one_time_keys_manager = OneTimeKeysManager(target_keys_number,
signed_keys_proportion)
signed_keys_proportion,
keys_threshold)

def upload_identity_keys(self):
"""Uploads this device's identity keys to HS.
Expand Down Expand Up @@ -113,6 +125,17 @@ def upload_one_time_keys(self, force_update=False):
logger.info('Uploaded new one-time keys: %s.', keys_uploaded)
return keys_uploaded

def update_one_time_key_counts(self, counts):
"""Update data on one-time keys count and upload new ones if necessary.

Args:
counts (dict): Counts of keys currently on the HS for each key type.
"""
self.one_time_keys_manager.server_counts = counts
if self.one_time_keys_manager.should_upload():
logger.info('Uploading new one-time keys.')
self.upload_one_time_keys()

def sign_json(self, json):
"""Signs a JSON object.

Expand Down
11 changes: 10 additions & 1 deletion matrix_client/crypto/one_time_keys.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
class OneTimeKeysManager(object):
"""Handles one-time keys accounting for an OlmDevice."""

def __init__(self, target_keys_number, signed_keys_proportion):
def __init__(self, target_keys_number, signed_keys_proportion, keys_threshold):
self.target_counts = {
'signed_curve25519': int(round(signed_keys_proportion * target_keys_number)),
'curve25519': int(round((1 - signed_keys_proportion) * target_keys_number)),
}
self._server_counts = {}
self.to_upload = {}
self.keys_threshold = keys_threshold

@property
def server_counts(self):
Expand All @@ -24,6 +25,14 @@ def update_keys_to_upload(self):
num_to_create = max(target_number - num_keys, 0)
self.to_upload[key_type] = num_to_create

def should_upload(self):
if not self._server_counts:
return True
for key_type, target_number in self.target_counts.items():
if self._server_counts.get(key_type, 0) < target_number * self.keys_threshold:
return True
return False

@property
def curve25519_to_upload(self):
return self.to_upload.get('curve25519', 0)
Expand Down
23 changes: 23 additions & 0 deletions test/client_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -515,3 +515,26 @@ def test_detect_encryption_state():

room = client._mkroom(room_id)
assert not room.encrypted


@responses.activate
def test_one_time_keys_sync():
client = MatrixClient(HOSTNAME, encryption=True)
sync_url = HOSTNAME + MATRIX_V2_API_PATH + "/sync"
sync_response = deepcopy(response_examples.example_sync)
payload = {'dummy': 1}
sync_response["device_one_time_keys_count"] = payload
sync_response['rooms']['join'] = {}

class DummyDevice:

def update_one_time_key_counts(self, payload):
self.payload = payload

device = DummyDevice()
client.olm_device = device

responses.add(responses.GET, sync_url, json=sync_response)

client._sync()
assert device.payload == payload
33 changes: 33 additions & 0 deletions test/crypto/olm_device_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,3 +171,36 @@ def test_upload_one_time_keys_force_update(self):

self.device.upload_one_time_keys(force_update=True)
assert len(responses.calls) == 3

@responses.activate
@pytest.mark.parametrize('count,should_upload', [(0, True), (25, False), (4, True)])
def test_update_one_time_key_counts(self, count, should_upload):
upload_url = HOSTNAME + MATRIX_V2_API_PATH + '/keys/upload'
responses.add(responses.POST, upload_url, json={'one_time_key_counts': {}})
self.device.one_time_keys_manager.target_counts['signed_curve25519'] = 50
self.device.one_time_keys_manager.server_counts.clear()

count_dict = {}
if count:
count_dict['signed_curve25519'] = count

self.device.update_one_time_key_counts(count_dict)

if should_upload:
if count:
req_otk = json.loads(responses.calls[0].request.body)['one_time_keys']
assert len(responses.calls) == 1
else:
req_otk = json.loads(responses.calls[1].request.body)['one_time_keys']
assert len(responses.calls) == 2
assert len(req_otk) == 50 - count
else:
assert not len(responses.calls)

@pytest.mark.parametrize('threshold', [-1, 2])
def test_invalid_keys_threshold(self, threshold):
with pytest.raises(ValueError):
OlmDevice(self.cli.api,
self.user_id,
self.device_id,
keys_threshold=threshold)