diff --git a/matrix_client/api.py b/matrix_client/api.py index 3cff5c8e..f58f2902 100644 --- a/matrix_client/api.py +++ b/matrix_client/api.py @@ -164,13 +164,23 @@ def logout(self): """ return self._send("POST", "/logout") - def create_room(self, alias=None, is_public=False, invitees=None): + def create_room( + self, + alias=None, + name=None, + is_public=False, + invitees=None, + federate=None + ): """Perform /createRoom. Args: alias (str): Optional. The room alias name to set for this room. + name (str): Optional. Name for new room. is_public (bool): Optional. The public/private visibility. invitees (list): Optional. The list of user IDs to invite. + federate (bool): Optional. Сan a room be federated. + Default to True. """ content = { "visibility": "public" if is_public else "private" @@ -179,6 +189,10 @@ def create_room(self, alias=None, is_public=False, invitees=None): content["room_alias_name"] = alias if invitees: content["invite"] = invitees + if name: + content["name"] = name + if federate is not None: + content["creation_content"] = {'m.federate': federate} return self._send("POST", "/createRoom", content) def join_room(self, room_id_or_alias): @@ -233,6 +247,18 @@ def send_state_event(self, room_id, event_type, content, state_key="", params["ts"] = timestamp return self._send("PUT", path, content, query_params=params) + def get_state_event(self, room_id, event_type): + """Perform GET /rooms/$room_id/state/$event_type + + Args: + room_id(str): The room ID. + event_type (str): The type of the event. + + Raises: + MatrixRequestError(code=404) if the state event is not found. + """ + return self._send("GET", "/rooms/{}/state/{}".format(quote(room_id), event_type)) + def send_message_event(self, room_id, event_type, content, txn_id=None, timestamp=None): """Perform PUT /rooms/$room_id/send/$event_type @@ -393,7 +419,7 @@ def get_room_name(self, room_id): Args: room_id(str): The room ID """ - return self._send("GET", "/rooms/" + room_id + "/state/m.room.name") + return self.get_state_event(room_id, "m.room.name") def set_room_name(self, room_id, name, timestamp=None): """Perform PUT /rooms/$room_id/state/m.room.name @@ -412,7 +438,7 @@ def get_room_topic(self, room_id): Args: room_id (str): The room ID """ - return self._send("GET", "/rooms/" + room_id + "/state/m.room.topic") + return self.get_state_event(room_id, "m.room.topic") def set_room_topic(self, room_id, topic, timestamp=None): """Perform PUT /rooms/$room_id/state/m.room.topic @@ -432,8 +458,7 @@ def get_power_levels(self, room_id): Args: room_id(str): The room ID """ - return self._send("GET", "/rooms/" + quote(room_id) + - "/state/m.room.power_levels") + return self.get_state_event(room_id, "m.room.power_levels") def set_power_levels(self, room_id, content): """Perform PUT /rooms/$room_id/state/m.room.power_levels diff --git a/matrix_client/client.py b/matrix_client/client.py index 6182e46e..fb543fc9 100644 --- a/matrix_client/client.py +++ b/matrix_client/client.py @@ -17,6 +17,11 @@ from .errors import MatrixRequestError, MatrixUnexpectedResponse from .room import Room from .user import User +try: + from .crypto.olm_device import OlmDevice + ENCRYPTION_SUPPORT = True +except ImportError: + ENCRYPTION_SUPPORT = False from threading import Thread from time import sleep from uuid import uuid4 @@ -54,6 +59,13 @@ class MatrixClient(object): the token) if supplying a token; otherwise, ignored. valid_cert_check (bool): Check the homeservers certificate on connections? + cache_level (CACHE): One of CACHE.NONE, CACHE.SOME, or + CACHE.ALL (defined in module namespace). + encryption (bool): Optional. Whether or not to enable end-to-end encryption + support. + encryption_conf (dict): Optional. Configuration parameters for encryption. + Refer to :func:`~matrix_client.crypto.olm_device.OlmDevice` for supported + options, since it will be passed to this class. Returns: `MatrixClient` @@ -95,30 +107,12 @@ def global_callback(incoming_event): def __init__(self, base_url, token=None, user_id=None, valid_cert_check=True, sync_filter_limit=20, - cache_level=CACHE.ALL): - """ Create a new Matrix Client object. - - Args: - base_url (str): The url of the HS preceding /_matrix. - e.g. (ex: https://localhost:8008 ) - token (str): Optional. If you have an access token - supply it here. - user_id (str): Optional. You must supply the user_id - (as obtained when initially logging in to obtain - the token) if supplying a token; otherwise, ignored. - valid_cert_check (bool): Check the homeservers - certificate on connections? - cache_level (CACHE): One of CACHE.NONE, CACHE.SOME, or - CACHE.ALL (defined in module namespace). - - Returns: - MatrixClient - - Raises: - MatrixRequestError, ValueError - """ + cache_level=CACHE.ALL, encryption=False, encryption_conf=None): if token is not None and user_id is None: raise ValueError("must supply user_id along with token") + if encryption and not ENCRYPTION_SUPPORT: + raise ValueError("Failed to enable encryption. Please make sure the olm " + "library is available.") self.api = MatrixHttpApi(base_url, token) self.api.validate_certificate(valid_cert_check) @@ -127,6 +121,10 @@ def __init__(self, base_url, token=None, user_id=None, self.invite_listeners = [] self.left_listeners = [] self.ephemeral_listeners = [] + self.device_id = None + self._encryption = encryption + self.encryption_conf = encryption_conf or {} + self.olm_device = None if isinstance(cache_level, CACHE): self._cache_level = cache_level else: @@ -273,6 +271,13 @@ def login(self, username, password, limit=10, sync=True, device_id=None): self.token = response["access_token"] self.hs = response["home_server"] self.api.token = self.token + self.device_id = response["device_id"] + + if self._encryption: + self.olm_device = OlmDevice( + self.api, self.user_id, self.device_id, **self.encryption_conf) + self.olm_device.upload_identity_keys() + self.olm_device.upload_one_time_keys() if sync: """ Limit Filter """ @@ -548,9 +553,19 @@ def upload(self, content, content_type): ) def _mkroom(self, room_id): - self.rooms[room_id] = Room(self, room_id) + room = Room(self, room_id) + if self._encryption: + try: + event = self.api.get_state_event(room_id, "m.room.encryption") + if event["algorithm"] == "m.megolm.v1.aes-sha2": + room.encrypted = True + except MatrixRequestError as e: + if e.code != 404: + raise + self.rooms[room_id] = room return self.rooms[room_id] + # TODO better handling of the blocking I/O caused by update_one_time_key_counts def _sync(self, timeout_ms=30000): response = self.api.sync(self.sync_token, timeout_ms, filter=self.sync_filter) self.sync_token = response["next_batch"] @@ -569,6 +584,10 @@ def _sync(self, timeout_ms=30000): if room_id in self.rooms: del self.rooms[room_id] + if self._encryption and 'device_one_time_keys_count' in response: + self.olm_device.update_one_time_key_counts( + response['device_one_time_keys_count']) + for room_id, sync_room in response['rooms']['join'].items(): if room_id not in self.rooms: self._mkroom(room_id) diff --git a/matrix_client/crypto/olm_device.py b/matrix_client/crypto/olm_device.py index 330a18f4..514965db 100644 --- a/matrix_client/crypto/olm_device.py +++ b/matrix_client/crypto/olm_device.py @@ -4,6 +4,7 @@ from canonicaljson import encode_canonical_json from matrix_client.checks import check_user_id +from matrix_client.crypto.one_time_keys import OneTimeKeysManager logger = logging.getLogger(__name__) @@ -17,15 +18,123 @@ class OlmDevice(object): api (MatrixHttpApi): The api object used to make requests. user_id (str): Matrix user ID. Must match the one used when logging in. device_id (str): Must match the one used when logging in. + signed_keys_proportion (float): Optional. The proportion of signed one-time keys + we should maintain on the HS compared to unsigned keys. The maximum value of + ``1`` means only signed keys will be uploaded, while the minimum value of + ``0`` means only unsigned keys. The actual amount of keys is determined at + runtime from the given proportion and the maximum number of one-time keys + we can physically hold. + keys_threshold (float): Optional. Threshold below which a one-time key + replenishment is triggered. Must be between ``0`` and ``1``. For example, + ``0.1`` means that new one-time keys will be uploaded when there is less than + 10% of the maximum number of one-time keys on the server. """ - def __init__(self, api, user_id, device_id): + _olm_algorithm = 'm.olm.v1.curve25519-aes-sha2' + _megolm_algorithm = 'm.megolm.v1.aes-sha2' + _algorithms = [_olm_algorithm, _megolm_algorithm] + + def __init__(self, + api, + user_id, + device_id, + signed_keys_proportion=1, + keys_threshold=0.1): + if not 0 <= signed_keys_proportion <= 1: + raise ValueError('signed_keys_proportion must be between 0 and 1.') + if not 0 <= keys_threshold <= 1: + raise ValueError('keys_threshold must be between 0 and 1.') self.api = api check_user_id(user_id) self.user_id = user_id self.device_id = device_id self.olm_account = olm.Account() logger.info('Initialised Olm Device.') + self.identity_keys = self.olm_account.identity_keys + # Try to maintain half the number of one-time keys libolm can hold uploaded + # on the HS. This is because some keys will be claimed by peers but not + # used instantly, and we want them to stay in libolm, until the limit is reached + # and it starts discarding keys, starting by the oldest. + target_keys_number = self.olm_account.max_one_time_keys // 2 + self.one_time_keys_manager = OneTimeKeysManager(target_keys_number, + signed_keys_proportion, + keys_threshold) + + def upload_identity_keys(self): + """Uploads this device's identity keys to HS. + + This device must be the one used when logging in. + """ + device_keys = { + 'user_id': self.user_id, + 'device_id': self.device_id, + 'algorithms': self._algorithms, + 'keys': {'{}:{}'.format(alg, self.device_id): key + for alg, key in self.identity_keys.items()} + } + self.sign_json(device_keys) + ret = self.api.upload_keys(device_keys=device_keys) + self.one_time_keys_manager.server_counts = ret['one_time_key_counts'] + logger.info('Uploaded identity keys.') + + def upload_one_time_keys(self, force_update=False): + """Uploads new one-time keys to the HS, if needed. + + Args: + force_update (bool): Fetch the number of one-time keys currently on the HS + before uploading, even if we already know one. In most cases this should + not be necessary, as we get this value from sync responses. + + Returns: + A dict containg the number of new keys that were uploaded for each key type + (signed_curve25519 or curve25519). The format is + ``: ``. If no keys of a given type have been + uploaded, the corresponding key will not be present. Consequently, an + empty dict indicates that no keys were uploaded. + """ + if force_update or not self.one_time_keys_manager.server_counts: + counts = self.api.upload_keys()['one_time_key_counts'] + self.one_time_keys_manager.server_counts = counts + + signed_keys_to_upload = self.one_time_keys_manager.signed_curve25519_to_upload + unsigned_keys_to_upload = self.one_time_keys_manager.curve25519_to_upload + + self.olm_account.generate_one_time_keys(signed_keys_to_upload + + unsigned_keys_to_upload) + + one_time_keys = {} + keys = self.olm_account.one_time_keys['curve25519'] + for i, key_id in enumerate(keys): + if i < signed_keys_to_upload: + key = self.sign_json({'key': keys[key_id]}) + key_type = 'signed_curve25519' + else: + key = keys[key_id] + key_type = 'curve25519' + one_time_keys['{}:{}'.format(key_type, key_id)] = key + + ret = self.api.upload_keys(one_time_keys=one_time_keys) + self.one_time_keys_manager.server_counts = ret['one_time_key_counts'] + self.olm_account.mark_keys_as_published() + + keys_uploaded = {} + if unsigned_keys_to_upload: + keys_uploaded['curve25519'] = unsigned_keys_to_upload + if signed_keys_to_upload: + keys_uploaded['signed_curve25519'] = signed_keys_to_upload + logger.info('Uploaded new one-time keys: %s.', keys_uploaded) + return keys_uploaded + + def update_one_time_key_counts(self, counts): + """Update data on one-time keys count and upload new ones if necessary. + + Args: + counts (dict): Counts of keys currently on the HS for each key type. + """ + self.one_time_keys_manager.server_counts = counts + if self.one_time_keys_manager.should_upload(): + logger.info('Uploading new one-time keys.') + self.upload_one_time_keys() def sign_json(self, json): """Signs a JSON object. diff --git a/matrix_client/crypto/one_time_keys.py b/matrix_client/crypto/one_time_keys.py new file mode 100644 index 00000000..131dc023 --- /dev/null +++ b/matrix_client/crypto/one_time_keys.py @@ -0,0 +1,42 @@ +class OneTimeKeysManager(object): + """Handles one-time keys accounting for an OlmDevice.""" + + def __init__(self, target_keys_number, signed_keys_proportion, keys_threshold): + self.target_counts = { + 'signed_curve25519': int(round(signed_keys_proportion * target_keys_number)), + 'curve25519': int(round((1 - signed_keys_proportion) * target_keys_number)), + } + self._server_counts = {} + self.to_upload = {} + self.keys_threshold = keys_threshold + + @property + def server_counts(self): + return self._server_counts + + @server_counts.setter + def server_counts(self, server_counts): + self._server_counts = server_counts + self.update_keys_to_upload() + + def update_keys_to_upload(self): + for key_type, target_number in self.target_counts.items(): + num_keys = self._server_counts.get(key_type, 0) + num_to_create = max(target_number - num_keys, 0) + self.to_upload[key_type] = num_to_create + + def should_upload(self): + if not self._server_counts: + return True + for key_type, target_number in self.target_counts.items(): + if self._server_counts.get(key_type, 0) < target_number * self.keys_threshold: + return True + return False + + @property + def curve25519_to_upload(self): + return self.to_upload.get('curve25519', 0) + + @property + def signed_curve25519_to_upload(self): + return self.to_upload.get('signed_curve25519', 0) diff --git a/matrix_client/room.py b/matrix_client/room.py index 2fc30356..af9e348c 100644 --- a/matrix_client/room.py +++ b/matrix_client/room.py @@ -22,14 +22,13 @@ class Room(object): - """Call room-specific functions after joining a room from the client.""" + """Call room-specific functions after joining a room from the client. - def __init__(self, client, room_id): - """Create a blank Room object. + NOTE: This should ideally be called from within the Client. + NOTE: This does not verify the room with the Home Server. + """ - NOTE: This should ideally be called from within the Client. - NOTE: This does not verify the room with the Home Server. - """ + def __init__(self, client, room_id): check_room_id(room_id) self.room_id = room_id @@ -47,6 +46,7 @@ def __init__(self, client, room_id): self.guest_access = None self._prev_batch = None self._members = [] + self.encrypted = False def set_user_profile(self, displayname=None, @@ -617,6 +617,22 @@ def set_guest_access(self, allow_guests): except MatrixRequestError: return False + def enable_encryption(self): + """Enables encryption in the room. + + NOTE: Once enabled, encryption cannot be disabled. + + Returns: + True if successful, False if not + """ + try: + self.send_state_event("m.room.encryption", + {"algorithm": "m.megolm.v1.aes-sha2"}) + self.encrypted = True + return True + except MatrixRequestError: + return False + def _process_state_event(self, state_event): if "type" not in state_event: return # Ignore event @@ -638,6 +654,9 @@ def _process_state_event(self, state_event): self.invite_only = econtent["join_rule"] == "invite" elif etype == "m.room.guest_access": self.guest_access = econtent["guest_access"] == "can_join" + elif etype == "m.room.encryption": + if econtent.get("algorithm") == "m.megolm.v1.aes-sha2": + self.encrypted = True elif etype == "m.room.member" and clevel == clevel.ALL: # tracking room members can be large e.g. #matrix:matrix.org if econtent["membership"] == "join": diff --git a/test/api_test.py b/test/api_test.py index 01369662..a88f45d5 100644 --- a/test/api_test.py +++ b/test/api_test.py @@ -1,5 +1,6 @@ import responses import pytest +import json from matrix_client import client from matrix_client.errors import MatrixRequestError @@ -232,3 +233,93 @@ def test_send_to_device(self): req = responses.calls[0].request assert req.url == send_to_device_url assert req.method == 'PUT' + + +class TestRoomApi: + cli = client.MatrixClient("http://example.com") + user_id = "@user:matrix.org" + room_id = "#foo:matrix.org" + + @responses.activate + def test_create_room_visibility_public(self): + create_room_url = "http://example.com" \ + "/_matrix/client/r0/createRoom" + responses.add( + responses.POST, + create_room_url, + json='{"room_id": "!sefiuhWgwghwWgh:example.com"}' + ) + self.cli.api.create_room( + name="test", + alias="#test:example.com", + is_public=True + ) + req = responses.calls[0].request + assert req.url == create_room_url + assert req.method == 'POST' + j = json.loads(req.body) + assert j["room_alias_name"] == "#test:example.com" + assert j["visibility"] == "public" + assert j["name"] == "test" + + @responses.activate + def test_create_room_visibility_private(self): + create_room_url = "http://example.com" \ + "/_matrix/client/r0/createRoom" + responses.add( + responses.POST, + create_room_url, + json='{"room_id": "!sefiuhWgwghwWgh:example.com"}' + ) + self.cli.api.create_room( + name="test", + alias="#test:example.com", + is_public=False + ) + req = responses.calls[0].request + assert req.url == create_room_url + assert req.method == 'POST' + j = json.loads(req.body) + assert j["room_alias_name"] == "#test:example.com" + assert j["visibility"] == "private" + assert j["name"] == "test" + + @responses.activate + def test_create_room_federate_true(self): + create_room_url = "http://example.com" \ + "/_matrix/client/r0/createRoom" + responses.add( + responses.POST, + create_room_url, + json='{"room_id": "!sefiuhWgwghwWgh:example.com"}' + ) + self.cli.api.create_room( + name="test2", + alias="#test2:example.com", + federate=True + ) + req = responses.calls[0].request + assert req.url == create_room_url + assert req.method == 'POST' + j = json.loads(req.body) + assert j["creation_content"]["m.federate"] + + @responses.activate + def test_create_room_federate_false(self): + create_room_url = "http://example.com" \ + "/_matrix/client/r0/createRoom" + responses.add( + responses.POST, + create_room_url, + json='{"room_id": "!sefiuhWgwghwWgh:example.com"}' + ) + self.cli.api.create_room( + name="test", + alias="#test:example.com", + federate=False + ) + req = responses.calls[0].request + assert req.url == create_room_url + assert req.method == 'POST' + j = json.loads(req.body) + assert not j["creation_content"]["m.federate"] diff --git a/test/client_test.py b/test/client_test.py index 4071d5bf..4076dbcf 100644 --- a/test/client_test.py +++ b/test/client_test.py @@ -130,6 +130,17 @@ def test_state_event(): room._process_state_event(ev) assert room.guest_access + # test encryption + room.encrypted = False + ev["type"] = "m.room.encryption" + ev["content"] = {"algorithm": "m.megolm.v1.aes-sha2"} + room._process_state_event(ev) + assert room.encrypted + # encrypted flag must not be cleared on configuration change + ev["content"] = {"algorithm": None} + room._process_state_event(ev) + assert room.encrypted + def test_get_user(): client = MatrixClient("http://example.com") @@ -452,3 +463,78 @@ def test_room_guest_access(): assert room.set_guest_access(True) assert room.guest_access + + +@responses.activate +def test_enable_encryption(): + pytest.importorskip('olm') + client = MatrixClient(HOSTNAME, encryption=True) + + login_path = HOSTNAME + MATRIX_V2_API_PATH + "/login" + responses.add(responses.POST, login_path, + json=response_examples.example_success_login_response) + + upload_path = HOSTNAME + MATRIX_V2_API_PATH + '/keys/upload' + responses.add(responses.POST, upload_path, body='{"one_time_key_counts": {}}') + + client.login("@example:localhost", "password", sync=False) + + assert client.olm_device + + +@responses.activate +def test_enable_encryption_in_room(): + client = MatrixClient(HOSTNAME) + room_id = "!UcYsUzyxTGDxLBEvLz:matrix.org" + room = client._mkroom(room_id) + assert not room.encrypted + encryption_state_path = HOSTNAME + MATRIX_V2_API_PATH + \ + "/rooms/" + quote(room_id) + "/state/m.room.encryption" + + responses.add(responses.PUT, encryption_state_path, + json=response_examples.example_event_response) + + assert room.enable_encryption() + assert room.encrypted + + +@responses.activate +def test_detect_encryption_state(): + client = MatrixClient(HOSTNAME, encryption=True) + room_id = "!UcYsUzyxTGDxLBEvLz:matrix.org" + + encryption_state_path = HOSTNAME + MATRIX_V2_API_PATH + \ + "/rooms/" + quote(room_id) + "/state/m.room.encryption" + responses.add(responses.GET, encryption_state_path, + json={"algorithm": "m.megolm.v1.aes-sha2"}) + responses.add(responses.GET, encryption_state_path, + json={}, status=404) + + room = client._mkroom(room_id) + assert room.encrypted + + room = client._mkroom(room_id) + assert not room.encrypted + + +@responses.activate +def test_one_time_keys_sync(): + client = MatrixClient(HOSTNAME, encryption=True) + sync_url = HOSTNAME + MATRIX_V2_API_PATH + "/sync" + sync_response = deepcopy(response_examples.example_sync) + payload = {'dummy': 1} + sync_response["device_one_time_keys_count"] = payload + sync_response['rooms']['join'] = {} + + class DummyDevice: + + def update_one_time_key_counts(self, payload): + self.payload = payload + + device = DummyDevice() + client.olm_device = device + + responses.add(responses.GET, sync_url, json=sync_response) + + client._sync() + assert device.payload == payload diff --git a/test/crypto/olm_device_test.py b/test/crypto/olm_device_test.py index a32eea08..84f56df6 100644 --- a/test/crypto/olm_device_test.py +++ b/test/crypto/olm_device_test.py @@ -1,10 +1,15 @@ import pytest pytest.importorskip("olm") # noqa +import json from copy import deepcopy +import responses + +from matrix_client.api import MATRIX_V2_API_PATH from matrix_client.client import MatrixClient from matrix_client.crypto.olm_device import OlmDevice +from test.response_examples import example_key_upload_response HOSTNAME = 'http://example.com' @@ -75,3 +80,127 @@ def test_sign_verify(self): signed_payload = self.device.sign_json(example_payload) assert self.device.verify_json(signed_payload, self.signing_key, self.user_id, self.device_id) + + @responses.activate + def test_upload_identity_keys(self): + upload_url = HOSTNAME + MATRIX_V2_API_PATH + '/keys/upload' + self.device.one_time_keys_manager.server_counts = {} + resp = deepcopy(example_key_upload_response) + + responses.add(responses.POST, upload_url, json=resp) + + assert self.device.upload_identity_keys() is None + assert self.device.one_time_keys_manager.server_counts == \ + resp['one_time_key_counts'] + + req_device_keys = json.loads(responses.calls[0].request.body)['device_keys'] + assert req_device_keys['user_id'] == self.user_id + assert req_device_keys['device_id'] == self.device_id + assert req_device_keys['algorithms'] == self.device._algorithms + assert 'keys' in req_device_keys + assert 'signatures' in req_device_keys + assert self.device.verify_json(req_device_keys, self.signing_key, self.user_id, + self.device_id) + + @pytest.mark.parametrize('proportion', [-1, 2]) + def test_upload_identity_keys_invalid(self, proportion): + with pytest.raises(ValueError): + OlmDevice(self.cli.api, + self.user_id, + self.device_id, + signed_keys_proportion=proportion) + + @responses.activate + @pytest.mark.parametrize('proportion', [0, 1, 0.5, 0.33]) + def test_upload_one_time_keys(self, proportion): + upload_url = HOSTNAME + MATRIX_V2_API_PATH + '/keys/upload' + resp = deepcopy(example_key_upload_response) + counts = resp['one_time_key_counts'] + counts['curve25519'] = counts['signed_curve25519'] = 10 + responses.add(responses.POST, upload_url, json=resp) + + device = OlmDevice( + self.cli.api, self.user_id, self.device_id, signed_keys_proportion=proportion) + assert not device.one_time_keys_manager.server_counts + + max_keys = device.olm_account.max_one_time_keys // 2 + signed_keys_to_upload = \ + max(round(max_keys * proportion) - counts['signed_curve25519'], 0) + unsigned_keys_to_upload = \ + max(round(max_keys * (1 - proportion)) - counts['curve25519'], 0) + expected_return = {} + if signed_keys_to_upload: + expected_return['signed_curve25519'] = signed_keys_to_upload + if unsigned_keys_to_upload: + expected_return['curve25519'] = unsigned_keys_to_upload + + assert device.upload_one_time_keys() == expected_return + assert len(responses.calls) == 2 + assert device.one_time_keys_manager.server_counts == resp['one_time_key_counts'] + + req_otk = json.loads(responses.calls[1].request.body)['one_time_keys'] + assert len(req_otk) == unsigned_keys_to_upload + signed_keys_to_upload + assert len([key for key in req_otk if not key.startswith('signed')]) == \ + unsigned_keys_to_upload + assert len([key for key in req_otk if key.startswith('signed')]) == \ + signed_keys_to_upload + for k in req_otk: + if k == 'signed_curve25519': + device.verify_json(req_otk[k], device.signing_key, device.user_id, + device.device_id) + + @responses.activate + def test_upload_one_time_keys_enough(self): + upload_url = HOSTNAME + MATRIX_V2_API_PATH + '/keys/upload' + self.device.one_time_keys_manager.server_counts = {} + limit = self.device.olm_account.max_one_time_keys // 2 + resp = {'one_time_key_counts': {'signed_curve25519': limit}} + responses.add(responses.POST, upload_url, json=resp) + + assert not self.device.upload_one_time_keys() + + @responses.activate + def test_upload_one_time_keys_force_update(self): + upload_url = HOSTNAME + MATRIX_V2_API_PATH + '/keys/upload' + self.device.one_time_keys_manager.server_counts = {'curve25519': 10} + resp = deepcopy(example_key_upload_response) + responses.add(responses.POST, upload_url, json=resp) + + self.device.upload_one_time_keys() + assert len(responses.calls) == 1 + + self.device.upload_one_time_keys(force_update=True) + assert len(responses.calls) == 3 + + @responses.activate + @pytest.mark.parametrize('count,should_upload', [(0, True), (25, False), (4, True)]) + def test_update_one_time_key_counts(self, count, should_upload): + upload_url = HOSTNAME + MATRIX_V2_API_PATH + '/keys/upload' + responses.add(responses.POST, upload_url, json={'one_time_key_counts': {}}) + self.device.one_time_keys_manager.target_counts['signed_curve25519'] = 50 + self.device.one_time_keys_manager.server_counts.clear() + + count_dict = {} + if count: + count_dict['signed_curve25519'] = count + + self.device.update_one_time_key_counts(count_dict) + + if should_upload: + if count: + req_otk = json.loads(responses.calls[0].request.body)['one_time_keys'] + assert len(responses.calls) == 1 + else: + req_otk = json.loads(responses.calls[1].request.body)['one_time_keys'] + assert len(responses.calls) == 2 + assert len(req_otk) == 50 - count + else: + assert not len(responses.calls) + + @pytest.mark.parametrize('threshold', [-1, 2]) + def test_invalid_keys_threshold(self, threshold): + with pytest.raises(ValueError): + OlmDevice(self.cli.api, + self.user_id, + self.device_id, + keys_threshold=threshold) diff --git a/test/response_examples.py b/test/response_examples.py index 2ae565fc..7e4787bb 100644 --- a/test/response_examples.py +++ b/test/response_examples.py @@ -160,3 +160,17 @@ example_event_response = { "event_id": "YUwRidLecu" } + +example_key_upload_response = { + "one_time_key_counts": { + "curve25519": 10, + "signed_curve25519": 20 + } +} + +example_success_login_response = { + "user_id": "@example:localhost", + "access_token": "abc123", + "home_server": "matrix.org", + "device_id": "GHTYAJCE" +}