diff --git a/matrix_client/client.py b/matrix_client/client.py index 416c5539..f28e62ab 100644 --- a/matrix_client/client.py +++ b/matrix_client/client.py @@ -565,6 +565,7 @@ def _mkroom(self, room_id): self.rooms[room_id] = room return self.rooms[room_id] + # TODO better handling of the blocking I/O caused by update_one_time_key_counts def _sync(self, timeout_ms=30000): response = self.api.sync(self.sync_token, timeout_ms, filter=self.sync_filter) self.sync_token = response["next_batch"] @@ -583,6 +584,10 @@ def _sync(self, timeout_ms=30000): if room_id in self.rooms: del self.rooms[room_id] + if self._encryption and 'device_one_time_keys_count' in response: + self.olm_device.update_one_time_key_counts( + response['device_one_time_keys_count']) + for room_id, sync_room in response['rooms']['join'].items(): if room_id not in self.rooms: self._mkroom(room_id) diff --git a/matrix_client/crypto/olm_device.py b/matrix_client/crypto/olm_device.py index ed27553f..514965db 100644 --- a/matrix_client/crypto/olm_device.py +++ b/matrix_client/crypto/olm_device.py @@ -24,15 +24,26 @@ class OlmDevice(object): ``0`` means only unsigned keys. The actual amount of keys is determined at runtime from the given proportion and the maximum number of one-time keys we can physically hold. + keys_threshold (float): Optional. Threshold below which a one-time key + replenishment is triggered. Must be between ``0`` and ``1``. For example, + ``0.1`` means that new one-time keys will be uploaded when there is less than + 10% of the maximum number of one-time keys on the server. """ _olm_algorithm = 'm.olm.v1.curve25519-aes-sha2' _megolm_algorithm = 'm.megolm.v1.aes-sha2' _algorithms = [_olm_algorithm, _megolm_algorithm] - def __init__(self, api, user_id, device_id, signed_keys_proportion=1): + def __init__(self, + api, + user_id, + device_id, + signed_keys_proportion=1, + keys_threshold=0.1): if not 0 <= signed_keys_proportion <= 1: raise ValueError('signed_keys_proportion must be between 0 and 1.') + if not 0 <= keys_threshold <= 1: + raise ValueError('keys_threshold must be between 0 and 1.') self.api = api check_user_id(user_id) self.user_id = user_id @@ -46,7 +57,8 @@ def __init__(self, api, user_id, device_id, signed_keys_proportion=1): # and it starts discarding keys, starting by the oldest. target_keys_number = self.olm_account.max_one_time_keys // 2 self.one_time_keys_manager = OneTimeKeysManager(target_keys_number, - signed_keys_proportion) + signed_keys_proportion, + keys_threshold) def upload_identity_keys(self): """Uploads this device's identity keys to HS. @@ -113,6 +125,17 @@ def upload_one_time_keys(self, force_update=False): logger.info('Uploaded new one-time keys: %s.', keys_uploaded) return keys_uploaded + def update_one_time_key_counts(self, counts): + """Update data on one-time keys count and upload new ones if necessary. + + Args: + counts (dict): Counts of keys currently on the HS for each key type. + """ + self.one_time_keys_manager.server_counts = counts + if self.one_time_keys_manager.should_upload(): + logger.info('Uploading new one-time keys.') + self.upload_one_time_keys() + def sign_json(self, json): """Signs a JSON object. diff --git a/matrix_client/crypto/one_time_keys.py b/matrix_client/crypto/one_time_keys.py index 5706732d..131dc023 100644 --- a/matrix_client/crypto/one_time_keys.py +++ b/matrix_client/crypto/one_time_keys.py @@ -1,13 +1,14 @@ class OneTimeKeysManager(object): """Handles one-time keys accounting for an OlmDevice.""" - def __init__(self, target_keys_number, signed_keys_proportion): + def __init__(self, target_keys_number, signed_keys_proportion, keys_threshold): self.target_counts = { 'signed_curve25519': int(round(signed_keys_proportion * target_keys_number)), 'curve25519': int(round((1 - signed_keys_proportion) * target_keys_number)), } self._server_counts = {} self.to_upload = {} + self.keys_threshold = keys_threshold @property def server_counts(self): @@ -24,6 +25,14 @@ def update_keys_to_upload(self): num_to_create = max(target_number - num_keys, 0) self.to_upload[key_type] = num_to_create + def should_upload(self): + if not self._server_counts: + return True + for key_type, target_number in self.target_counts.items(): + if self._server_counts.get(key_type, 0) < target_number * self.keys_threshold: + return True + return False + @property def curve25519_to_upload(self): return self.to_upload.get('curve25519', 0) diff --git a/test/client_test.py b/test/client_test.py index 3ae7b720..bde056cb 100644 --- a/test/client_test.py +++ b/test/client_test.py @@ -515,3 +515,26 @@ def test_detect_encryption_state(): room = client._mkroom(room_id) assert not room.encrypted + + +@responses.activate +def test_one_time_keys_sync(): + client = MatrixClient(HOSTNAME, encryption=True) + sync_url = HOSTNAME + MATRIX_V2_API_PATH + "/sync" + sync_response = deepcopy(response_examples.example_sync) + payload = {'dummy': 1} + sync_response["device_one_time_keys_count"] = payload + sync_response['rooms']['join'] = {} + + class DummyDevice: + + def update_one_time_key_counts(self, payload): + self.payload = payload + + device = DummyDevice() + client.olm_device = device + + responses.add(responses.GET, sync_url, json=sync_response) + + client._sync() + assert device.payload == payload diff --git a/test/crypto/olm_device_test.py b/test/crypto/olm_device_test.py index 0067df8c..84f56df6 100644 --- a/test/crypto/olm_device_test.py +++ b/test/crypto/olm_device_test.py @@ -171,3 +171,36 @@ def test_upload_one_time_keys_force_update(self): self.device.upload_one_time_keys(force_update=True) assert len(responses.calls) == 3 + + @responses.activate + @pytest.mark.parametrize('count,should_upload', [(0, True), (25, False), (4, True)]) + def test_update_one_time_key_counts(self, count, should_upload): + upload_url = HOSTNAME + MATRIX_V2_API_PATH + '/keys/upload' + responses.add(responses.POST, upload_url, json={'one_time_key_counts': {}}) + self.device.one_time_keys_manager.target_counts['signed_curve25519'] = 50 + self.device.one_time_keys_manager.server_counts.clear() + + count_dict = {} + if count: + count_dict['signed_curve25519'] = count + + self.device.update_one_time_key_counts(count_dict) + + if should_upload: + if count: + req_otk = json.loads(responses.calls[0].request.body)['one_time_keys'] + assert len(responses.calls) == 1 + else: + req_otk = json.loads(responses.calls[1].request.body)['one_time_keys'] + assert len(responses.calls) == 2 + assert len(req_otk) == 50 - count + else: + assert not len(responses.calls) + + @pytest.mark.parametrize('threshold', [-1, 2]) + def test_invalid_keys_threshold(self, threshold): + with pytest.raises(ValueError): + OlmDevice(self.cli.api, + self.user_id, + self.device_id, + keys_threshold=threshold)