From b0fa2a65b0f7a74a733ca32fdb68b2c2a0fad585 Mon Sep 17 00:00:00 2001 From: Valentin Deniaud Date: Thu, 24 May 2018 20:47:07 +0200 Subject: [PATCH 01/15] track the device list of users and download keys Use a thread to allow downloading the keys in a non-blocking way. Implement sort of a queue of users who need updating, which should be efficient (request is fired as soon as possible with as much data as possible) and avoid races (we always have only one download in progress). --- matrix_client/client.py | 9 + matrix_client/crypto/device_list.py | 269 ++++++++++++++++++++++++++++ matrix_client/crypto/olm_device.py | 4 + 3 files changed, 282 insertions(+) create mode 100644 matrix_client/crypto/device_list.py diff --git a/matrix_client/client.py b/matrix_client/client.py index 703b825c..f9a2d876 100644 --- a/matrix_client/client.py +++ b/matrix_client/client.py @@ -581,6 +581,15 @@ def _mkroom(self, room_id): # TODO better handling of the blocking I/O caused by update_one_time_key_counts def _sync(self, timeout_ms=30000): response = self.api.sync(self.sync_token, timeout_ms, filter=self.sync_filter) + + if self._encryption and 'device_lists' in response: + if response['device_lists'].get('changed'): + self.olm_device.device_list.update_user_device_keys( + response['device_lists']['changed'], self.sync_token) + if response['device_lists'].get('left'): + self.olm_device.device_list.stop_tracking_users( + response['device_lists']['left']) + self.sync_token = response["next_batch"] for presence_update in response['presence']['events']: diff --git a/matrix_client/crypto/device_list.py b/matrix_client/crypto/device_list.py new file mode 100644 index 00000000..c9a22966 --- /dev/null +++ b/matrix_client/crypto/device_list.py @@ -0,0 +1,269 @@ +import logging +from collections import defaultdict +from threading import Thread, Condition, Event + +from matrix_client.errors import MatrixHttpLibError, MatrixRequestError + +logger = logging.getLogger(__name__) + + +class DeviceList: + """Allows to maintain a list of devices up-to-date for an OlmDevice. + + Offers blocking and non-blocking methods to fetch device keys when appropriate. + NOTE: Spawns a thread that will last until program termination. + + Args: + olm_device (OlmDevice): Will be used to get additional info, such as device id. + api (MatrixHttpApi): The api object used to make requests. + device_keys (defaultdict(dict)): A map from user to device to keys. + """ + + def __init__(self, olm_device, api, device_keys): + self.olm_device = olm_device + self.api = api + self.device_keys = device_keys + # Stores the ids of users who need updating + self.outdated_user_ids = _OutdatedUsersSet() + # Stores the ids of users we are currently tracking. We can assume the device + # keys of these users are up-to-date as long as no downloading is in progress. + # We should track every user we share an encrypted room with. + self.tracked_user_ids = set() + # Allows to wake up the thread when there are new users to update, and to + # synchronise shared data. + self.thread_condition = Condition() + self.update_thread = _UpdateDeviceList( + self.thread_condition, self.outdated_user_ids, self._download_device_keys, + self.tracked_user_ids + ) + self.update_thread.start() + + def get_room_device_keys(self, room, blocking=True): + """Gets the keys of all devices present in the room. + + Makes sure not to download keys of users we are already tracking. + The users we were not yet tracking will get tracked automatically. + + Args: + room (Room): The room to use. + blocking (bool): Optional. Whether to wait for the keys to have been + downloaded before returning. + """ + logger.info('Fetching all missing keys in room %s.', room.room_id) + user_ids = {u.user_id for u in room.get_joined_members()} - self.tracked_user_ids + if not user_ids: + logger.info('Already had all the keys in room %s.', room.room_id) + if blocking: + # Wait on an eventual download to finish + self.update_thread.event.wait() + return + with self.thread_condition: + self.outdated_user_ids.update(user_ids) + if blocking: + # Will ensure the user_ids we just added are processed + event = Event() + self.outdated_user_ids.events.add(event) + self.thread_condition.notify() + if blocking: + event.wait() + + def add_users(self, user_ids): + """Add users to be tracked, and download their device keys. + + NOTE: this is non-blocking and will return before the keys are downloaded. + + Args: + user_ids (iterable): Any iterable containing user ids. + """ + user_ids = user_ids.difference(self.tracked_user_ids) + if user_ids: + self._add_outdated_users(user_ids) + + def stop_tracking_users(self, user_ids): + """Stop tracking users. + + NOTE: Keys will not be deleted. + + Args: + user_ids (iterable): Any iterable containing user ids. + """ + with self.thread_condition: + self.tracked_user_ids.difference_update(user_ids) + self.outdated_user_ids.difference_update(user_ids) + logger.info('Stopped tracking users: %s.', user_ids) + + def update_user_device_keys(self, user_ids, since_token=None): + """Triggers an update for users we already track. + + Args: + user_ids (iterable): Any iterable containing user ids. + since_token (str): Optional. Since token of a sync request, if triggering + the update as a result of that sync request. + """ + user_ids = self.tracked_user_ids.intersection(user_ids) + if not user_ids: + return + logger.info('Updating the device lists of users: %s, using token %s', + user_ids, since_token) + self._add_outdated_users(user_ids, since_token=since_token) + + def _add_outdated_users(self, user_ids, since_token=None): + """Stop tracking users. Keys will not be deleted. + + Args: + user_ids (iterable): Any iterable containing user ids. + since_token (str): Optional. Since token of a sync request. + """ + with self.thread_condition: + self.outdated_user_ids.update(user_ids) + if since_token: + self.outdated_user_ids.sync_token = since_token + self.thread_condition.notify() + + def _download_device_keys(self, user_devices, since_token=None): + """Download and store device keys, if they pass security checks. + + Args: + user_devices (dict): Format is ``user_id: [device_ids]``. + since_token (str): Optional. Since token of a sync request. + """ + changed = defaultdict(dict) + resp = self.api.query_keys(user_devices, token=since_token) + if resp.get('failures'): + logger.warning('Failed to download keys from the following unreachable ' + 'homeservers %s.', resp['failures']) + device_keys = resp['device_keys'] + for user_id in user_devices: + # The response might not contain every user_ids we requested + for device_id, payload in device_keys.get(user_id, {}).items(): + if device_id == self.olm_device.device_id: + continue + if payload['user_id'] != user_id or payload['device_id'] != device_id: + logger.warning('Mismatch in keys payload of device %s (%s) of user ' + '%s (%s).', payload['device_id'], device_id, + payload['user_id'], user_id) + continue + try: + signing_key = payload['keys']['ed25519:{}'.format(device_id)] + curve_key = payload['keys']['curve25519:{}'.format(device_id)] + except KeyError as e: + logger.warning('Invalid identity keys payload from device %s of' + 'user %s: %s.', device_id, user_id, e) + continue + verified = self.olm_device.verify_json( + payload, signing_key, user_id, device_id) + if not verified: + logger.warning('Signature verification failed for device %s of ' + 'user %s.', device_id, user_id) + continue + keys = self.device_keys[user_id].setdefault(device_id, {}) + if keys: + if keys['ed25519'] != signing_key: + logger.warning('Ed25519 key has changed for device %s of ' + 'user %s.', device_id, user_id) + continue + if keys['curve25519'] == curve_key: + continue + else: + keys['ed25519'] = signing_key + keys['curve25519'] = curve_key + changed[user_id][device_id] = keys + + logger.info('Successfully downloaded keys for devices: %s.', + {user_id: list(changed[user_id]) for user_id in changed}) + return changed + + +class _OutdatedUsersSet(set): + """Allows to know if elements in a set have been processed. + + This is done by adding elements along with an Event object. Then, functions + processing the set should set the events when they are done. + """ + + def __init__(self, iterable=()): + self.events = set() + self._sync_token = None + super(_OutdatedUsersSet, self).__init__(iterable) + + def mark_as_processed(self): + for event in self.events: + event.set() + + def copy(self): + new_set = _OutdatedUsersSet(self) + new_set.events = self.events.copy() + return new_set + + def clear(self): + self.events.clear() + super(_OutdatedUsersSet, self).clear() + + def update(self, iterable): + super(_OutdatedUsersSet, self).update(iterable) + if isinstance(iterable, _OutdatedUsersSet): + self.events.update(iterable.events) + + @property + def sync_token(self): + return self._sync_token + + @sync_token.setter + def sync_token(self, token): + if not self._sync_token or token > self._sync_token: + self._sync_token = token + + +class _UpdateDeviceList(Thread): + + def __init__(self, cond, user_ids, download_method, tracked_user_ids): + # We wait on this condition when there is nothing to do. Outside code should use + # it to notify us when they add data to be processed in outdated_user_ids so that + # we can wake up and process it. + self.cond = cond + self.outdated_user_ids = user_ids + self.download = download_method + self.tracked_user_ids = tracked_user_ids + # Cleared when we start a download, and set when we have finished it. This can be + # used by outside code in order to know if we are in the middle of a download, and + # allows to wait for it to complete by waiting on this event. + self.event = Event() + # Used internally to terminate gracefully on program exit. + self._should_terminate = Event() + super(_UpdateDeviceList, self).__init__() + + def run(self): + while True and not self._should_terminate.is_set(): + with self.cond: + while not self.outdated_user_ids: + # Avoid any deadlocks + self.outdated_user_ids.mark_as_processed() + self.event.set() + logger.debug('Update thread is going to sleep...') + self.cond.wait() + logger.debug('Update thread woke up!') + if self._should_terminate.is_set(): + return + to_download = self.outdated_user_ids.copy() + self.outdated_user_ids.clear() + self.event.clear() + self.tracked_user_ids.update(to_download) + payload = {user_id: [] for user_id in to_download} + logger.info('Downloading device keys for users: %s.', to_download) + try: + self.download(payload, self.outdated_user_ids.sync_token) + self.event.set() + to_download.mark_as_processed() + except (MatrixHttpLibError, MatrixRequestError) as e: + logger.warning('Network error when fetching device keys (will retry): %s', + e) + with self.cond: + self.outdated_user_ids.update(to_download) + + def join(self, timeout=None): + # If we are joined, this means that the main program is terminating. + # We should terminate too. + self._should_terminate.set() + with self.cond: + self.cond.notify() + super(_UpdateDeviceList, self).join(timeout=timeout) diff --git a/matrix_client/crypto/olm_device.py b/matrix_client/crypto/olm_device.py index 514965db..d2e114b1 100644 --- a/matrix_client/crypto/olm_device.py +++ b/matrix_client/crypto/olm_device.py @@ -1,10 +1,12 @@ import logging +from collections import defaultdict import olm from canonicaljson import encode_canonical_json from matrix_client.checks import check_user_id from matrix_client.crypto.one_time_keys import OneTimeKeysManager +from matrix_client.crypto.device_list import DeviceList logger = logging.getLogger(__name__) @@ -59,6 +61,8 @@ def __init__(self, self.one_time_keys_manager = OneTimeKeysManager(target_keys_number, signed_keys_proportion, keys_threshold) + self.device_keys = defaultdict(dict) + self.device_list = DeviceList(self, api, self.device_keys) def upload_identity_keys(self): """Uploads this device's identity keys to HS. From c8496d39e7459e2afd5d73baa4e51ccad7f173cf Mon Sep 17 00:00:00 2001 From: Valentin Deniaud Date: Sun, 24 Jun 2018 04:17:39 +0200 Subject: [PATCH 02/15] add device tracking tests --- test/crypto/device_list_test.py | 272 ++++++++++++++++++++++++++++++++ test/crypto/olm_device_test.py | 1 + test/response_examples.py | 31 ++++ 3 files changed, 304 insertions(+) create mode 100644 test/crypto/device_list_test.py diff --git a/test/crypto/device_list_test.py b/test/crypto/device_list_test.py new file mode 100644 index 00000000..32f75ce6 --- /dev/null +++ b/test/crypto/device_list_test.py @@ -0,0 +1,272 @@ +import pytest +pytest.importorskip("olm") # noqa + +import json +from copy import deepcopy +from threading import Event, Condition + +import responses + +from matrix_client.api import MATRIX_V2_API_PATH +from matrix_client.client import MatrixClient +from matrix_client.room import User +from matrix_client.errors import MatrixRequestError +from matrix_client.crypto.olm_device import OlmDevice +from matrix_client.crypto.device_list import (_OutdatedUsersSet as OutdatedUsersSet, + _UpdateDeviceList as UpdateDeviceList) +from test.response_examples import example_key_query_response + +HOSTNAME = 'http://example.com' + + +class TestDeviceList: + cli = MatrixClient(HOSTNAME) + user_id = '@test:example.com' + alice = '@alice:example.com' + room_id = '!test:example.com' + device_id = 'AUIETSRN' + device = OlmDevice(cli.api, user_id, device_id) + device_list = device.device_list + signing_key = device.olm_account.identity_keys['ed25519'] + query_url = HOSTNAME + MATRIX_V2_API_PATH + '/keys/query' + + @responses.activate + def test_download_device_keys(self): + # The method we want to test + download_device_keys = self.device_list._download_device_keys + bob = '@bob:example.com' + eve = '@eve:example.com' + user_devices = {self.alice: [], bob: [], self.user_id: []} + + # This response is correct for Alice's keys, but lacks Bob's + # There are no failures + resp = example_key_query_response + responses.add(responses.POST, self.query_url, json=resp) + + # Still correct, but Alice's identity key has changed + resp = deepcopy(example_key_query_response) + new_id_key = 'ijxGZqwB/UvMtKABdaCdrI0OtQI6NhHBYiknoCkdWng' + payload = resp['device_keys'][self.alice]['JLAFKJWSCS'] + payload['keys']['curve25519:JLAFKJWSCS'] = new_id_key + payload['signatures'][self.alice]['ed25519:JLAFKJWSCS'] = \ + ('D9oLtYefMIr4StiHTIzn3+bhtPCfrZNDU9jsUbMu3MicfZLl4d8WlYn3TPmbwDi8XMGcT' + 'nNnqfdi/tYUPvKfCA') + responses.add(responses.POST, self.query_url, json=resp) + + # Still correct, but Alice's signing key has changed + alice_device = OlmDevice(self.cli.api, self.alice, 'JLAFKJWSCS') + resp = deepcopy(example_key_query_response) + resp['device_keys'][self.alice]['JLAFKJWSCS']['keys']['ed25519:JLAFKJWSCS'] = \ + alice_device.identity_keys['ed25519'] + resp['device_keys'][self.alice]['JLAFKJWSCS'] = \ + alice_device.sign_json(resp['device_keys'][self.alice]['JLAFKJWSCS']) + responses.add(responses.POST, self.query_url, json=resp) + + # Response containing an unknown user + resp = deepcopy(example_key_query_response) + user_device = resp['device_keys'].pop(self.alice) + resp['device_keys'][eve] = user_device + responses.add(responses.POST, self.query_url, json=resp) + + # Response with an invalid signature + resp = deepcopy(example_key_query_response) + resp['device_keys'][self.alice]['JLAFKJWSCS']['test'] = 1 + responses.add(responses.POST, self.query_url, json=resp) + + # Response with a requested user and valid signature, but with a mismatch + resp = deepcopy(example_key_query_response) + user_device = resp['device_keys'].pop(self.alice) + resp['device_keys'][bob] = user_device + responses.add(responses.POST, self.query_url, json=resp) + + # Response with an invalid keys field + resp = deepcopy(example_key_query_response) + keys_field = resp['device_keys'][self.alice]['JLAFKJWSCS']['keys'] + key = keys_field.pop("ed25519:JLAFKJWSCS") + keys_field["ed25519:wrong"] = key + # Cover a missing branch by adding failures + resp["failures"]["other.com"] = {} + # And one more by adding ourself + resp['device_keys'][self.user_id] = {self.device_id: 'dummy'} + responses.add(responses.POST, self.query_url, json=resp) + + self.device.device_keys.clear() + assert download_device_keys(user_devices) + req = json.loads(responses.calls[0].request.body) + assert req['device_keys'] == {self.alice: [], bob: [], self.user_id: []} + expected_device_keys = { + self.alice: { + 'JLAFKJWSCS': { + 'curve25519': '3C5BFWi2Y8MaVvjM8M22DBmh24PmgR0nPvJOIArzgyI', + 'ed25519': 'VzJIYXQ85u19z2ZpEeLLVu8hUKTCE0VXYUn4IY4iFcA' + } + } + } + assert self.device.device_keys == expected_device_keys + + # Different curve25519, key should get updated + assert download_device_keys(user_devices) + expected_device_keys[self.alice]['JLAFKJWSCS']['curve25519'] = new_id_key + assert self.device.device_keys == expected_device_keys + + # Different ed25519, key should not get updated + assert not download_device_keys(user_devices) + assert self.device.device_keys == expected_device_keys + + self.device.device_keys.clear() + # All the remaining responses are wrong and we should not add the key + for _ in range(4): + assert not download_device_keys(user_devices) + assert self.device.device_keys == {} + + assert len(responses.calls) == 7 + + @responses.activate + def test_update_thread(self): + # Normal run + event = Event() + outdated_users = OutdatedUsersSet({self.user_id}) + outdated_users.events.add(event) + + def dummy_download(user_devices, since_token=None): + assert user_devices == {self.user_id: []} + return + thread = UpdateDeviceList(Condition(), outdated_users, dummy_download, set()) + + thread.start() + event.wait() + assert not thread.outdated_user_ids + assert thread.event.is_set() + assert thread.tracked_user_ids == {self.user_id} + thread.join() + assert not thread.is_alive() + + # Error run + outdated_users = OutdatedUsersSet({self.user_id}) + + def error_on_first_download(user_devices, since_token=None): + error_on_first_download.c += 1 + if error_on_first_download.c == 1: + raise MatrixRequestError + return + error_on_first_download.c = 0 + thread = UpdateDeviceList( + Condition(), outdated_users, error_on_first_download, set()) + thread.start() + thread.event.wait() + assert error_on_first_download.c == 2 + assert not thread.outdated_user_ids + thread.join() + + # Cover a missing branch + thread = UpdateDeviceList( + Condition(), outdated_users, error_on_first_download, set()) + thread._should_terminate.set() + thread.start() + thread.join() + assert not thread.is_alive() + + @responses.activate + def test_get_room_device_keys(self): + self.device_list.tracked_user_ids.clear() + room = self.cli._mkroom(self.room_id) + room._members[self.alice] = User(self.cli.api, self.alice) + + responses.add(responses.POST, self.query_url, json=example_key_query_response) + + # Blocking + self.device_list.get_room_device_keys(room) + assert self.device_list.tracked_user_ids == {self.alice} + assert self.device_list.device_keys[self.alice]['JLAFKJWSCS'] + + # Same, but we already track the user + self.device_list.get_room_device_keys(room) + + # Non-blocking + self.device_list.tracked_user_ids.clear() + # We have to block for testing purposes, though + self.device_list.update_thread.event.clear() + self.device_list.get_room_device_keys(room, blocking=False) + self.device_list.update_thread.event.wait() + + # Same, but we already track the user + self.device_list.get_room_device_keys(room, blocking=False) + + @responses.activate + def test_add_users(self): + self.device_list.tracked_user_ids.clear() + responses.add(responses.POST, self.query_url, json=example_key_query_response) + + self.device_list.update_thread.event.clear() + self.device_list.add_users({self.alice}) + self.device_list.update_thread.event.wait() + assert self.device_list.tracked_user_ids == {self.alice} + assert len(responses.calls) == 1 + + # Same, but we are already tracking Alice + self.device_list.add_users({self.alice}) + assert len(responses.calls) == 1 + + def test_stop_tracking_users(self): + self.device_list.tracked_user_ids.clear() + self.device_list.tracked_user_ids.add(self.alice) + self.device_list.outdated_user_ids.clear() + self.device_list.outdated_user_ids.add(self.alice) + + self.device_list.stop_tracking_users({self.alice}) + + assert not self.device_list.tracked_user_ids + assert not self.device_list.outdated_user_ids + + @responses.activate + def test_update_user_device_keys(self): + self.device_list.tracked_user_ids.clear() + responses.add(responses.POST, self.query_url, json=example_key_query_response) + + self.device_list.update_user_device_keys({self.alice}) + assert len(responses.calls) == 0 + + self.device_list.tracked_user_ids.add(self.alice) + + self.device_list.update_thread.event.clear() + self.device_list.update_user_device_keys({self.alice}, since_token='dummy') + self.device_list.update_thread.event.wait() + assert len(responses.calls) == 1 + + +def test_outdated_users_set(): + s = OutdatedUsersSet() + assert not s + + s = OutdatedUsersSet({1}) + event = Event() + s.events.add(event) + assert s == {1} + + # Make a manual copy of s + t = OutdatedUsersSet() + t.add(1) + t.events.add(event) + assert t == s and t.events == s.events + + u = s.copy() + event2 = Event() + u.add(2) + u.events.add(event2) + # Check that modifying u didn't change s + assert t == s and t.events == s.events + + s.update(u) + assert s == {1, 2} and s.events == {event, event2} + + s.mark_as_processed() + assert event.is_set() + + new = 's72594_4483_1935' + s.sync_token = new + old = 's72594_4483_1934' + s.sync_token = old + assert s.sync_token == new + + s.clear() + assert not s and not s.events diff --git a/test/crypto/olm_device_test.py b/test/crypto/olm_device_test.py index 84f56df6..cab7c56d 100644 --- a/test/crypto/olm_device_test.py +++ b/test/crypto/olm_device_test.py @@ -17,6 +17,7 @@ class TestOlmDevice: cli = MatrixClient(HOSTNAME) user_id = '@user:matrix.org' + room_id = '!test:example.com' device_id = 'QBUAZIFURK' device = OlmDevice(cli.api, user_id, device_id) signing_key = device.olm_account.identity_keys['ed25519'] diff --git a/test/response_examples.py b/test/response_examples.py index 2d45aa86..d827a05b 100644 --- a/test/response_examples.py +++ b/test/response_examples.py @@ -184,3 +184,34 @@ "og:image:width": 48, "og:title": "Matrix Blog Post" } + +example_key_query_response = { + "failures": {}, + "device_keys": { + "@alice:example.com": { + "JLAFKJWSCS": { + "user_id": "@alice:example.com", + "device_id": "JLAFKJWSCS", + "algorithms": [ + "m.olm.curve25519-aes-sha256", + "m.megolm.v1.aes-sha" + ], + "keys": { + "curve25519:JLAFKJWSCS": ("3C5BFWi2Y8MaVvjM8M22DBmh24PmgR0nPvJOIAr" + "zgyI"), + "ed25519:JLAFKJWSCS": "VzJIYXQ85u19z2ZpEeLLVu8hUKTCE0VXYUn4IY4iFcA" + }, + "signatures": { + "@alice:example.com": { + "ed25519:JLAFKJWSCS": + ("wux6Dhjtk7GYPMW54hnx0doVH0NvuUAFBleL5OW99jhbjIutufglAgrYAcu8" + "ueacgNyeSumvtzVIPZXgbB2BCg") + } + }, + "unsigned": { + "device_display_name": "Alice'smobilephone" + } + } + } + } +} From 69d300eec2c6d6f504060227e55987802c319652 Mon Sep 17 00:00:00 2001 From: Valentin Deniaud Date: Tue, 12 Jun 2018 22:56:34 +0200 Subject: [PATCH 03/15] track devices of users in encrypted rooms This has the effect of tracking the device lists of users proactively. --- matrix_client/client.py | 4 ++++ matrix_client/crypto/device_list.py | 27 +++++++++++++++++++++++++-- matrix_client/room.py | 7 +++++-- test/crypto/device_list_test.py | 16 +++++++++++++--- 4 files changed, 47 insertions(+), 7 deletions(-) diff --git a/matrix_client/client.py b/matrix_client/client.py index f9a2d876..0c2a8af6 100644 --- a/matrix_client/client.py +++ b/matrix_client/client.py @@ -636,6 +636,10 @@ def _sync(self, timeout_ms=30000): ): listener['callback'](event) + if self._encryption and room.encrypted: + # Track the new users in the room + self.olm_device.device_list.track_pending_users() + for event in sync_room['ephemeral']['events']: event['room_id'] = room_id room._put_ephemeral_event(event) diff --git a/matrix_client/crypto/device_list.py b/matrix_client/crypto/device_list.py index c9a22966..c0bf406a 100644 --- a/matrix_client/crypto/device_list.py +++ b/matrix_client/crypto/device_list.py @@ -1,6 +1,6 @@ import logging from collections import defaultdict -from threading import Thread, Condition, Event +from threading import Thread, Condition, Event, Lock from matrix_client.errors import MatrixHttpLibError, MatrixRequestError @@ -25,6 +25,9 @@ def __init__(self, olm_device, api, device_keys): self.device_keys = device_keys # Stores the ids of users who need updating self.outdated_user_ids = _OutdatedUsersSet() + # Stores the ids of users to fetch the device keys of eventually + self.pending_outdated_user_ids = set() + self.pending_users_lock = Lock() # Stores the ids of users we are currently tracking. We can assume the device # keys of these users are up-to-date as long as no downloading is in progress. # We should track every user we share an encrypted room with. @@ -67,7 +70,27 @@ def get_room_device_keys(self, room, blocking=True): if blocking: event.wait() - def add_users(self, user_ids): + def track_user_no_download(self, user_id): + """Add user to be tracked, but do not track it instantly. + + This should be used to avoid making calls to :func:`track_users` with only one + user repeatedly. Instead, using this should allow to passively queue the users, + and tracking can be triggered when there are sufficiently, using + :func:`track_pending_users`. + + Args: + user_id (str): A user id. + """ + with self.pending_users_lock: + self.pending_outdated_user_ids.add(user_id) + + def track_pending_users(self): + """Triggers the tracking of the user added with :func:`track_user_no_download`.""" + with self.pending_users_lock: + self.track_users(self.pending_outdated_user_ids) + self.pending_outdated_user_ids.clear() + + def track_users(self, user_ids): """Add users to be tracked, and download their device keys. NOTE: this is non-blocking and will return before the keys are downloaded. diff --git a/matrix_client/room.py b/matrix_client/room.py index 488f335b..deec52d6 100644 --- a/matrix_client/room.py +++ b/matrix_client/room.py @@ -660,11 +660,14 @@ def _process_state_event(self, state_event): self.encrypted = True elif etype == "m.room.member" and clevel == clevel.ALL: # tracking room members can be large e.g. #matrix:matrix.org + user_id = state_event["state_key"] if econtent["membership"] == "join": - user_id = state_event["state_key"] self._add_member(user_id, econtent.get("displayname")) + if self.client._encryption and self.encrypted: + # Track the device list of this user + self.client.olm_device.device_list.track_user_no_download(user_id) elif econtent["membership"] in ("leave", "kick", "invite"): - self._members.pop(state_event["state_key"], None) + self._members.pop(user_id, None) for listener in self.state_listeners: if ( diff --git a/test/crypto/device_list_test.py b/test/crypto/device_list_test.py index 32f75ce6..69df1ed3 100644 --- a/test/crypto/device_list_test.py +++ b/test/crypto/device_list_test.py @@ -193,18 +193,18 @@ def test_get_room_device_keys(self): self.device_list.get_room_device_keys(room, blocking=False) @responses.activate - def test_add_users(self): + def test_track_users(self): self.device_list.tracked_user_ids.clear() responses.add(responses.POST, self.query_url, json=example_key_query_response) self.device_list.update_thread.event.clear() - self.device_list.add_users({self.alice}) + self.device_list.track_users({self.alice}) self.device_list.update_thread.event.wait() assert self.device_list.tracked_user_ids == {self.alice} assert len(responses.calls) == 1 # Same, but we are already tracking Alice - self.device_list.add_users({self.alice}) + self.device_list.track_users({self.alice}) assert len(responses.calls) == 1 def test_stop_tracking_users(self): @@ -218,6 +218,16 @@ def test_stop_tracking_users(self): assert not self.device_list.tracked_user_ids assert not self.device_list.outdated_user_ids + def test_pending_users(self): + # Say Alice is already tracked to avoid triggering dowload process + self.device_list.tracked_user_ids.add(self.alice) + + self.device_list.track_user_no_download(self.alice) + assert self.alice in self.device_list.pending_outdated_user_ids + + self.device_list.track_pending_users() + assert self.alice not in self.device_list.pending_outdated_user_ids + @responses.activate def test_update_user_device_keys(self): self.device_list.tracked_user_ids.clear() From 9aadf1b6f755cba1b5fd3fdef771b8952e9fbed1 Mon Sep 17 00:00:00 2001 From: Valentin Deniaud Date: Fri, 29 Jun 2018 15:24:16 +0200 Subject: [PATCH 04/15] build doc for crypto.device_list --- docs/source/matrix_client.rst | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/docs/source/matrix_client.rst b/docs/source/matrix_client.rst index af1ebc41..1587a867 100644 --- a/docs/source/matrix_client.rst +++ b/docs/source/matrix_client.rst @@ -56,3 +56,8 @@ matrix_client.crypto :members: :undoc-members: :show-inheritance: + +.. automodule:: matrix_client.crypto.device_list + :members: + :undoc-members: + :show-inheritance: From 165d4f44173b6fe117f16be6fbe8c16ec3009b18 Mon Sep 17 00:00:00 2001 From: Valentin Deniaud Date: Wed, 23 May 2018 16:43:58 +0200 Subject: [PATCH 05/15] add olm encryption, decryption and session --- matrix_client/crypto/olm_device.py | 231 +++++++++++++++++++++++++++++ 1 file changed, 231 insertions(+) diff --git a/matrix_client/crypto/olm_device.py b/matrix_client/crypto/olm_device.py index d2e114b1..c8701f8c 100644 --- a/matrix_client/crypto/olm_device.py +++ b/matrix_client/crypto/olm_device.py @@ -1,3 +1,4 @@ +import json import logging from collections import defaultdict @@ -63,6 +64,7 @@ def __init__(self, keys_threshold) self.device_keys = defaultdict(dict) self.device_list = DeviceList(self, api, self.device_keys) + self.olm_sessions = defaultdict(list) def upload_identity_keys(self): """Uploads this device's identity keys to HS. @@ -140,6 +142,235 @@ def update_one_time_key_counts(self, counts): logger.info('Uploading new one-time keys.') self.upload_one_time_keys() + def olm_start_sessions(self, user_devices): + """Start olm sessions with the given devices. + + NOTE: those device keys must already be known. + + Args: + user_devices (dict): A map from user_id to an iterable of device_ids. + The format is ``: []``. + """ + logger.info('Trying to establish Olm sessions with devices: %s.', + dict(user_devices)) + payload = defaultdict(dict) + for user_id in user_devices: + for device_id in user_devices[user_id]: + payload[user_id][device_id] = 'signed_curve25519' + + resp = self.api.claim_keys(payload) + if resp.get('failures'): + logger.warning('Failed to claim one-time keys from the following unreachable ' + 'homeservers: %s.', resp['failures']) + keys = resp['one_time_keys'] + if logger.level >= logging.WARNING: + missing = {} + for user_id, device_ids in user_devices.items(): + if user_id not in keys: + missing[user_id] = device_ids + else: + missing_devices = set(device_ids) - set(keys[user_id]) + if missing_devices: + missing[user_id] = missing_devices + logger.warning('Failed to claim the keys of %s.', missing) + + for user_id in user_devices: + for device_id, one_time_key in keys.get(user_id, {}).items(): + try: + device_keys = self.device_keys[user_id][device_id] + except KeyError: + logger.warning('Key for device %s of user %s not found, could not ' + 'start Olm session.', device_id, user_id) + continue + key_object = next(iter(one_time_key.values())) + verified = self.verify_json(key_object, + device_keys['ed25519'], + user_id, + device_id) + if verified: + session = olm.OutboundSession(self.olm_account, + device_keys['curve25519'], + key_object['key']) + sessions = self.olm_sessions[device_keys['curve25519']] + sessions.append(session) + logger.info('Established Olm session %s with device %s of user ' + '%s.', device_id, session.id, user_id) + else: + logger.warning('Signature verification for one-time key of device %s ' + 'of user %s failed, could not start olm session.', + device_id, user_id) + + def olm_build_encrypted_event(self, event_type, content, user_id, device_id): + """Encrypt an event using Olm. + + NOTE: a session with this device must already be established. + + Args: + event_type (str): The event type, will be encrypted. + content (dict): The event content, will be encrypted. + user_id (str): The intended recipient of the event. + device_id (str): The device to encrypt to. + + Returns: + The Olm encrypted event, as JSON. + """ + try: + keys = self.device_keys[user_id][device_id] + except KeyError: + raise RuntimeError('Device is unknown, could not encrypt.') + + signing_key = keys['ed25519'] + identity_key = keys['curve25519'] + + payload = { + 'type': event_type, + 'content': content, + 'sender': self.user_id, + 'sender_device': self.device_id, + 'keys': { + 'ed25519': self.identity_keys['ed25519'] + }, + 'recipient': user_id, + 'recipient_keys': { + 'ed25519': signing_key + } + } + + sessions = self.olm_sessions[identity_key] + if sessions: + session = sorted(sessions, key=lambda s: s.id)[0] + else: + raise RuntimeError('No session for this device, could not encrypt.') + + encrypted_message = session.encrypt(json.dumps(payload)) + ciphertext_payload = { + identity_key: { + 'type': encrypted_message.message_type, + 'body': encrypted_message.ciphertext + } + } + + event = { + 'algorithm': self._olm_algorithm, + 'sender_key': self.identity_keys['curve25519'], + 'ciphertext': ciphertext_payload + } + return event + + def olm_decrypt_event(self, content, user_id): + """Decrypt an Olm encrypted event, and check its properties. + + Args: + event (dict): The content property of a m.room.encrypted event. + user_id (str): The sender of the event. + + Retuns: + The decrypted event held by the initial event. + + Raises: + RuntimeError: Error in the decryption process. Nothing can be done. The text + of the exception indicates what went wrong, and should be logged or + displayed to the user. + KeyError: The event is missing a required field. + """ + if content['algorithm'] != self._olm_algorithm: + raise RuntimeError('Event was not encrypted with {}.' + .format(self._olm_algorithm)) + + ciphertext = content['ciphertext'] + try: + payload = ciphertext[self.identity_keys['curve25519']] + except KeyError: + raise RuntimeError('This message was not encrypted for us.') + + msg_type = payload['type'] + if msg_type == 0: + encrypted_message = olm.OlmPreKeyMessage(payload['body']) + else: + encrypted_message = olm.OlmMessage(payload['body']) + + decrypted_event = self._olm_decrypt(encrypted_message, content['sender_key']) + + if decrypted_event['sender'] != user_id: + raise RuntimeError( + 'Found user {} instead of sender {} in Olm plaintext {}.' + .format(decrypted_event['sender'], user_id, decrypted_event) + ) + if decrypted_event['recipient'] != self.user_id: + raise RuntimeError( + 'Found user {} instead of us ({}) in Olm plaintext {}.' + .format(decrypted_event['recipient'], self.user_id, decrypted_event) + ) + our_key = decrypted_event['recipient_keys']['ed25519'] + if our_key != self.identity_keys['ed25519']: + raise RuntimeError( + 'Found key {} instead of ours own ed25519 key {} in Olm plaintext {}.' + .format(our_key, self.identity_keys['ed25519'], decrypted_event) + ) + + return decrypted_event + + def _olm_decrypt(self, olm_message, sender_key): + """Decrypt an Olm encrypted event. + + NOTE: This does no implement any security check. + + Try to decrypt using existing sessions. If it fails, start an new one when + possible. + + Args: + olm_message (OlmMessage): Olm encrypted payload. + sender_key (str): The sender's curve25519 identity key. + + Returns: + The decrypted event held by the initial payload, as JSON. + """ + + sessions = self.olm_sessions[sender_key] + + # Try to decrypt message body using one of the known sessions for that device + for session in sessions: + try: + event = session.decrypt(olm_message) + logger.info('Success decrypting Olm event using existing session %s.', + session.id) + break + except olm.session.OlmSessionError as e: + if olm_message.message_type == 0: + if session.matches(olm_message, sender_key): + # We had a matching session for a pre-key message, but it didn't + # work. This means something is wrong, so we fail now. + raise RuntimeError('Error decrypting pre-key message with ' + 'existing Olm session {}, reason: {}.' + .format(session.id, e)) + # Simply keep trying otherwise + else: + if olm_message.message_type > 0: + # Not a pre-key message, we should have had a matching session + if sessions: + raise RuntimeError('Error decrypting with existing sessions.') + raise RuntimeError('No existing sessions.') + + # We have a pre-key message without any matching session, in this case + # we should try to create one. + try: + session = olm.session.InboundSession( + self.olm_account, olm_message, sender_key) + except olm.session.OlmSessionError as e: + raise RuntimeError('Error decrypting pre-key message when trying to ' + 'establish a new session: {}.'.format(e)) + + logger.info('Created new Olm session %s.', session.id) + try: + event = session.decrypt(olm_message) + except olm.session.OlmSessionError as e: + raise RuntimeError('Error decrypting pre-key message with new session: ' + '{}.'.format(e)) + self.olm_account.remove_one_time_keys(session) + sessions.append(session) + + return json.loads(event) + def sign_json(self, json): """Signs a JSON object. From c689ec17d80d998f250fa757813333add66fa680 Mon Sep 17 00:00:00 2001 From: Valentin Deniaud Date: Mon, 25 Jun 2018 19:03:06 +0200 Subject: [PATCH 06/15] add olm tests --- test/crypto/olm_device_test.py | 195 ++++++++++++++++++++++++++++++++- test/response_examples.py | 21 ++++ 2 files changed, 214 insertions(+), 2 deletions(-) diff --git a/test/crypto/olm_device_test.py b/test/crypto/olm_device_test.py index cab7c56d..69f692c4 100644 --- a/test/crypto/olm_device_test.py +++ b/test/crypto/olm_device_test.py @@ -1,15 +1,18 @@ import pytest -pytest.importorskip("olm") # noqa +olm = pytest.importorskip("olm") # noqa import json +import logging from copy import deepcopy import responses +from matrix_client.crypto import olm_device from matrix_client.api import MATRIX_V2_API_PATH from matrix_client.client import MatrixClient from matrix_client.crypto.olm_device import OlmDevice -from test.response_examples import example_key_upload_response +from test.response_examples import (example_key_upload_response, + example_claim_keys_response) HOSTNAME = 'http://example.com' @@ -21,6 +24,13 @@ class TestOlmDevice: device_id = 'QBUAZIFURK' device = OlmDevice(cli.api, user_id, device_id) signing_key = device.olm_account.identity_keys['ed25519'] + alice = '@alice:example.com' + alice_device_id = 'JLAFKJWSCS' + alice_curve_key = 'mmFRSHuJVq3aTudx3KB3w5ZvSFQhgEcy8d+m+vkEfUQ' + alice_identity_keys = { + 'curve25519': alice_curve_key, + 'ed25519': '4VjV3OhFUxWFAcO5YOaQVmTIn29JdRmtNh9iAxoyhkc' + } def test_sign_json(self): example_payload = { @@ -205,3 +215,184 @@ def test_invalid_keys_threshold(self, threshold): self.user_id, self.device_id, keys_threshold=threshold) + + @responses.activate + def test_olm_start_sessions(self): + claim_url = HOSTNAME + MATRIX_V2_API_PATH + '/keys/claim' + responses.add(responses.POST, claim_url, json=example_claim_keys_response) + self.device.olm_sessions.clear() + self.device.device_keys.clear() + + user_devices = {self.alice: {self.alice_device_id}} + + # We don't have alice's keys + self.device.olm_start_sessions(user_devices) + assert not self.device.olm_sessions[self.alice_curve_key] + + # Cover logging part + olm_device.logger.setLevel(logging.WARNING) + # Now should be good + self.device.device_keys[self.alice][self.alice_device_id] = \ + self.alice_identity_keys + self.device.olm_start_sessions(user_devices) + assert self.device.olm_sessions[self.alice_curve_key] + + # With failures and wrong signature + self.device.olm_sessions.clear() + payload = deepcopy(example_claim_keys_response) + payload['failures'] = {'dummy': 1} + key = payload['one_time_keys'][self.alice][self.alice_device_id] + key['signed_curve25519:AAAAAQ']['test'] = 1 + responses.replace(responses.POST, claim_url, json=payload) + + self.device.olm_start_sessions(user_devices) + assert not self.device.olm_sessions[self.alice_curve_key] + + # Missing requested user and devices + user_devices[self.alice].add('test') + user_devices['test'] = 'test' + + self.device.olm_start_sessions(user_devices) + + @responses.activate + def test_olm_build_encrypted_event(self): + self.device.device_keys.clear() + self.device.olm_sessions.clear() + event_content = {'dummy': 'example'} + + # We don't have Alice's keys + with pytest.raises(RuntimeError): + self.device.olm_build_encrypted_event( + 'm.text', event_content, self.alice, self.alice_device_id) + + # We don't have a session with Alice + self.device.device_keys[self.alice][self.alice_device_id] = \ + self.alice_identity_keys + with pytest.raises(RuntimeError): + self.device.olm_build_encrypted_event( + 'm.text', event_content, self.alice, self.alice_device_id) + + claim_url = HOSTNAME + MATRIX_V2_API_PATH + '/keys/claim' + responses.add(responses.POST, claim_url, json=example_claim_keys_response) + user_devices = {self.alice: {self.alice_device_id}} + self.device.olm_start_sessions(user_devices) + assert self.device.olm_build_encrypted_event( + 'm.text', event_content, self.alice, self.alice_device_id) + + def test_olm_decrypt(self): + self.device.olm_sessions.clear() + # Since this method doesn't care about high-level event formatting, we will + # generate things at low level + our_account = self.device.olm_account + # Alice needs to start a session with us + alice = olm.Account() + sender_key = alice.identity_keys['curve25519'] + our_account.generate_one_time_keys(1) + otk = next(iter(our_account.one_time_keys['curve25519'].values())) + self.device.olm_account.mark_keys_as_published() + session = olm.OutboundSession(alice, our_account.identity_keys['curve25519'], otk) + + plaintext = {"test": "test"} + message = session.encrypt(json.dumps(plaintext)) + assert self.device._olm_decrypt(message, sender_key) == plaintext + + # New pre-key message, but the session exists this time + message = session.encrypt(json.dumps(plaintext)) + assert self.device._olm_decrypt(message, sender_key) == plaintext + + # Try to decrypt the same message twice + with pytest.raises(RuntimeError): + self.device._olm_decrypt(message, sender_key) + + # Answer Alice in order to have a type 1 message + message = self.device.olm_sessions[sender_key][0].encrypt(json.dumps(plaintext)) + session.decrypt(message) + message = session.encrypt(json.dumps(plaintext)) + assert self.device._olm_decrypt(message, sender_key) == plaintext + + # Try to decrypt the same message type 1 twice + with pytest.raises(RuntimeError): + self.device._olm_decrypt(message, sender_key) + + # Try to decrypt a message from a session that reused a one-time key + otk_reused_session = olm.OutboundSession( + alice, our_account.identity_keys['curve25519'], otk) + message = otk_reused_session.encrypt(json.dumps(plaintext)) + with pytest.raises(RuntimeError): + self.device._olm_decrypt(message, sender_key) + + # Try to decrypt an invalid type 0 message + our_account.generate_one_time_keys(1) + otk = next(iter(our_account.one_time_keys['curve25519'].values())) + wrong_session = olm.OutboundSession(alice, sender_key, otk) + message = wrong_session.encrypt(json.dumps(plaintext)) + with pytest.raises(RuntimeError): + self.device._olm_decrypt(message, sender_key) + + # Try to decrypt a type 1 message for which we have no sessions + message = session.encrypt(json.dumps(plaintext)) + self.device.olm_sessions.clear() + with pytest.raises(RuntimeError): + self.device._olm_decrypt(message, sender_key) + + def test_olm_decrypt_event(self): + self.device.device_keys.clear() + self.device.olm_sessions.clear() + alice_device = OlmDevice(self.device.api, self.alice, self.alice_device_id) + alice_device.device_keys[self.user_id][self.device_id] = self.device.identity_keys + self.device.device_keys[self.alice][self.alice_device_id] = \ + alice_device.identity_keys + + # Artificially start an Olm session from Alice + self.device.olm_account.generate_one_time_keys(1) + otk = next(iter(self.device.olm_account.one_time_keys['curve25519'].values())) + self.device.olm_account.mark_keys_as_published() + sender_key = self.device.identity_keys['curve25519'] + session = olm.OutboundSession(alice_device.olm_account, sender_key, otk) + alice_device.olm_sessions[sender_key] = [session] + + encrypted_event = alice_device.olm_build_encrypted_event( + 'example_type', {'content': 'test'}, self.user_id, self.device_id) + + # Now we can test + self.device.olm_decrypt_event(encrypted_event, self.alice) + + # Type 1 Olm payload + alice_device.olm_decrypt_event( + self.device.olm_build_encrypted_event( + 'example_type', {'content': 'test'}, self.alice, self.alice_device_id + ), + self.user_id) + encrypted_event = alice_device.olm_build_encrypted_event( + 'example_type', {'content': 'test'}, self.user_id, self.device_id) + self.device.olm_decrypt_event(encrypted_event, self.alice) + + encrypted_event = alice_device.olm_build_encrypted_event( + 'example_type', {'content': 'test'}, self.user_id, self.device_id) + with pytest.raises(RuntimeError): + self.device.olm_decrypt_event(encrypted_event, 'wrong') + + wrong_event = deepcopy(encrypted_event) + wrong_event['algorithm'] = 'wrong' + with pytest.raises(RuntimeError): + self.device.olm_decrypt_event(wrong_event, self.alice) + + wrong_event = deepcopy(encrypted_event) + wrong_event['ciphertext'] = {} + with pytest.raises(RuntimeError): + self.device.olm_decrypt_event(wrong_event, self.alice) + + encrypted_event = alice_device.olm_build_encrypted_event( + 'example_type', {'content': 'test'}, self.user_id, self.device_id) + self.device.user_id = 'wrong' + with pytest.raises(RuntimeError): + self.device.olm_decrypt_event(encrypted_event, self.alice) + self.device.user_id = self.user_id + + encrypted_event = alice_device.olm_build_encrypted_event( + 'example_type', {'content': 'test'}, self.user_id, self.device_id) + backup = self.device.identity_keys['ed25519'] + self.device.identity_keys['ed25519'] = 'wrong' + with pytest.raises(RuntimeError): + self.device.olm_decrypt_event(encrypted_event, self.alice) + self.device.identity_keys['ed25519'] = backup diff --git a/test/response_examples.py b/test/response_examples.py index d827a05b..d6a2ca49 100644 --- a/test/response_examples.py +++ b/test/response_examples.py @@ -215,3 +215,24 @@ } } } + +example_claim_keys_response = { + "failures": {}, + "one_time_keys": { + "@alice:example.com": { + "JLAFKJWSCS": { + 'signed_curve25519:AAAAAQ': { + 'key': '9UOzQjF2j2Xf8mBIiMgruuCkuWtD0ea9kvx63mO92Ws', + 'signatures': { + '@alice:example.com': { + 'ed25519:JLAFKJWSCS': ( + '6O+VYxN7mVcr/j66YdHASRrpW4ydC/0FcYmEWVAGIFzU4+yjzxxinhQD' + 'l7InhhdGuXeQlk4/w/CyU76TY6wdBA' + ) + } + } + } + } + } + } +} From eb43d6c6b1c1eef41464667435a9e032dcecc7b1 Mon Sep 17 00:00:00 2001 From: Valentin Deniaud Date: Wed, 30 May 2018 22:43:14 +0200 Subject: [PATCH 07/15] add olm_ensure_sessions --- matrix_client/crypto/olm_device.py | 19 +++++++++++++++++++ test/crypto/olm_device_test.py | 20 ++++++++++++++++++++ 2 files changed, 39 insertions(+) diff --git a/matrix_client/crypto/olm_device.py b/matrix_client/crypto/olm_device.py index c8701f8c..6f186030 100644 --- a/matrix_client/crypto/olm_device.py +++ b/matrix_client/crypto/olm_device.py @@ -371,6 +371,25 @@ def _olm_decrypt(self, olm_message, sender_key): return json.loads(event) + def olm_ensure_sessions(self, user_devices): + """Start Olm sessions with the given devices if one doesn't exist already. + + Args: + user_devices (dict): A map from user ids to a list of device ids. + """ + user_devices_no_session = defaultdict(list) + for user_id in user_devices: + for device_id in user_devices[user_id]: + curve_key = self.device_keys[user_id][device_id]['curve25519'] + # Check if we have a list of sessions for this device, which can be + # empty. Implicitely, an empty list will indicate that we already tried + # to establish a session with a device, but this attempt was + # unsuccessful. We do not retry to establish a session. + if curve_key not in self.olm_sessions: + user_devices_no_session[user_id].append(device_id) + if user_devices_no_session: + self.olm_start_sessions(user_devices_no_session) + def sign_json(self, json): """Signs a JSON object. diff --git a/test/crypto/olm_device_test.py b/test/crypto/olm_device_test.py index 69f692c4..c67cff93 100644 --- a/test/crypto/olm_device_test.py +++ b/test/crypto/olm_device_test.py @@ -396,3 +396,23 @@ def test_olm_decrypt_event(self): with pytest.raises(RuntimeError): self.device.olm_decrypt_event(encrypted_event, self.alice) self.device.identity_keys['ed25519'] = backup + + @responses.activate + def test_olm_ensure_sessions(self): + claim_url = HOSTNAME + MATRIX_V2_API_PATH + '/keys/claim' + responses.add(responses.POST, claim_url, json=example_claim_keys_response) + self.device.olm_sessions.clear() + alice_device_id = 'JLAFKJWSCS' + alice_curve_key = 'mmFRSHuJVq3aTudx3KB3w5ZvSFQhgEcy8d+m+vkEfUQ' + self.device.device_keys[self.alice][alice_device_id] = { + 'curve25519': alice_curve_key, + 'ed25519': '4VjV3OhFUxWFAcO5YOaQVmTIn29JdRmtNh9iAxoyhkc' + } + user_devices = {self.alice: [alice_device_id]} + + self.device.olm_ensure_sessions(user_devices) + assert self.device.olm_sessions[alice_curve_key] + assert len(responses.calls) == 1 + + self.device.olm_ensure_sessions(user_devices) + assert len(responses.calls) == 1 From d3896b4a54a5c5238d137fd7130030efe9e5c665 Mon Sep 17 00:00:00 2001 From: Valentin Deniaud Date: Mon, 28 May 2018 12:32:59 +0200 Subject: [PATCH 08/15] add megolm outbound session support --- .../crypto/megolm_outbound_session.py | 56 +++++++++ matrix_client/crypto/olm_device.py | 118 ++++++++++++++++++ 2 files changed, 174 insertions(+) create mode 100644 matrix_client/crypto/megolm_outbound_session.py diff --git a/matrix_client/crypto/megolm_outbound_session.py b/matrix_client/crypto/megolm_outbound_session.py new file mode 100644 index 00000000..5793c35f --- /dev/null +++ b/matrix_client/crypto/megolm_outbound_session.py @@ -0,0 +1,56 @@ +from datetime import datetime, timedelta + +from olm import OutboundGroupSession + + +class MegolmOutboundSession(OutboundGroupSession): + + """Outbound group session aware of the users it is shared with. + + Also remembers the time it was created and the number of messages it has encrypted, + in order to know if it needs to be rotated. + + Args: + max_age (datetime.timedelta): Optional. The maximum time the session should + exist. + max_messages (int): Optional. The maximum number of messages that should be sent. + A new message in considered sent each time there is a call to ``encrypt``. + """ + + def __init__(self, max_age=timedelta(days=7), max_messages=100): + self.devices = set() + self.max_age = max_age + self.max_messages = max_messages + self.creation_time = datetime.now() + self.message_count = 0 + super(MegolmOutboundSession, self).__init__() + + def __new__(cls, **kwargs): + return super(MegolmOutboundSession, cls).__new__(cls) + + def add_device(self, device_id): + """Adds a device the session is shared with.""" + self.devices.add(device_id) + + def add_devices(self, device_ids): + """Adds devices the session is shared with. + + Args: + device_ids (iterable): An iterable of device ids, preferably a set. + """ + self.devices.update(device_ids) + + def should_rotate(self): + """Wether the session should be rotated. + + Returns: + True if it should, False if not. + """ + if self.message_count >= self.max_messages or \ + datetime.now() - self.creation_time >= self.max_age: + return True + return False + + def encrypt(self, plaintext): + self.message_count += 1 + return super(MegolmOutboundSession, self).encrypt(plaintext) diff --git a/matrix_client/crypto/olm_device.py b/matrix_client/crypto/olm_device.py index 6f186030..0cb05b00 100644 --- a/matrix_client/crypto/olm_device.py +++ b/matrix_client/crypto/olm_device.py @@ -8,6 +8,7 @@ from matrix_client.checks import check_user_id from matrix_client.crypto.one_time_keys import OneTimeKeysManager from matrix_client.crypto.device_list import DeviceList +from matrix_client.crypto.megolm_outbound_session import MegolmOutboundSession logger = logging.getLogger(__name__) @@ -65,6 +66,7 @@ def __init__(self, self.device_keys = defaultdict(dict) self.device_list = DeviceList(self, api, self.device_keys) self.olm_sessions = defaultdict(list) + self.megolm_outbound_sessions = {} def upload_identity_keys(self): """Uploads this device's identity keys to HS. @@ -390,6 +392,122 @@ def olm_ensure_sessions(self, user_devices): if user_devices_no_session: self.olm_start_sessions(user_devices_no_session) + def megolm_start_session(self, room): + """Start a megolm session in a room, and share it with its members. + + Args: + room (Room): The room to use. + + Returns: + The newly created session. + """ + session = MegolmOutboundSession() + self.megolm_outbound_sessions[room.room_id] = session + logger.info('Starting a new Meglom outbound session %s in %s.', + session.id, room.room_id) + + users = room.get_joined_members() + user_devices = {user.user_id: list(self.device_keys[user.user_id]) + for user in users} + self.device_list.get_room_device_keys(room) + self.megolm_share_session(room.room_id, user_devices, session) + return session + + def megolm_share_session(self, room_id, user_devices, session): + """Share an already existing outbound megolm session with the specified devices. + + Args: + room_id (str): The room corresponding to the session. + user_devices (dict): A map from user ids to a list of device ids. + session (MegolmOutboundSession): The session object. + """ + + logger.info('Attempting to share Megolm session %s in %s with %s.', + session.id, room_id, user_devices) + self.olm_ensure_sessions(user_devices) + + event = { + 'algorithm': self._megolm_algorithm, + 'room_id': room_id, + 'session_id': session.id, + 'session_key': session.session_key + } + + messages = defaultdict(dict) + new_devices = set() + for user_id in user_devices: + for device_id in user_devices[user_id]: + try: + messages[user_id][device_id] = self.olm_build_encrypted_event( + 'm.room_key', event, user_id, device_id + ) + except RuntimeError as e: + logger.warning('Could not share megolm session %s with device %s of ' + 'user %s: %s', session.id, + device_id, user_id, e) + # We will not retry to share session with failed devices + new_devices.add(device_id) + self.api.send_to_device('m.room.encrypted', messages) + session.add_devices(new_devices) + + def megolm_share_session_with_new_devices(self, room, session): + """Share a megolm session with new devices in a room. + + Args: + room (Room): The room corresponding to the session. + session (MegolmOutboundSession): The session to share. + """ + user_devices = {} + users = room.get_joined_members() + for user in users: + user_id = user.user_id + missing_devices = list(set(self.device_keys[user_id].keys()) - + self.megolm_outbound_sessions[room.room_id].devices) + if missing_devices: + user_devices[user_id] = missing_devices + if user_devices: + logger.info('Sharing existing Megolm outbound session %s with new devices: ' + '%s', session.id, user_devices) + self.megolm_share_session(room.room_id, user_devices, session) + + def megolm_build_encrypted_event(self, room, event): + """Build an encrypted Megolm payload from a plaintext event. + + If no session exists in the room, a new one will be initiated. Also takes care + of rotating the session periodically. + + Args: + room (Room): The room the event will be sent in. + event (dict): Matrix event. + + Returns: + The encrypted event, as a dict. + """ + room_id = room.room_id + + session = self.megolm_outbound_sessions.get(room_id) + if not session or session.should_rotate(): + session = self.megolm_start_session(room) + else: + self.megolm_share_session_with_new_devices(room, session) + + payload = { + 'type': event['type'], + 'content': event['content'], + 'room_id': room_id + } + + encrypted_payload = session.encrypt(json.dumps(payload)) + + encrypted_event = { + 'algorithm': self._megolm_algorithm, + 'sender_key': self.identity_keys['curve25519'], + 'ciphertext': encrypted_payload, + 'session_id': session.id, + 'device_id': self.device_id + } + return encrypted_event + def sign_json(self, json): """Signs a JSON object. From f18f9b26808a8233f0d726b13753cca977e868da Mon Sep 17 00:00:00 2001 From: Valentin Deniaud Date: Mon, 18 Jun 2018 17:27:37 +0200 Subject: [PATCH 09/15] delete outbound sessions on leave --- matrix_client/crypto/olm_device.py | 14 ++++++++++++++ matrix_client/room.py | 6 ++++++ 2 files changed, 20 insertions(+) diff --git a/matrix_client/crypto/olm_device.py b/matrix_client/crypto/olm_device.py index 0cb05b00..4e09ea08 100644 --- a/matrix_client/crypto/olm_device.py +++ b/matrix_client/crypto/olm_device.py @@ -508,6 +508,20 @@ def megolm_build_encrypted_event(self, room, event): } return encrypted_event + def megolm_remove_outbound_session(self, room_id): + """Remove an existing Megolm outbound session in a room. + + If there is no such session, nothing will happen. + + Args: + room_id (str): The room to use. + """ + try: + self.megolm_outbound_sessions.pop(room_id) + logger.info('Removed Meglom outbound session in %s.', room_id) + except KeyError: + pass + def sign_json(self, json): """Signs a JSON object. diff --git a/matrix_client/room.py b/matrix_client/room.py index deec52d6..15c427e5 100644 --- a/matrix_client/room.py +++ b/matrix_client/room.py @@ -668,6 +668,12 @@ def _process_state_event(self, state_event): self.client.olm_device.device_list.track_user_no_download(user_id) elif econtent["membership"] in ("leave", "kick", "invite"): self._members.pop(user_id, None) + if econtent["membership"] != "invite": + if self.client._encryption and self.encrypted: + # Invalidate any outbound session we have in the room when + # someone leaves + self.client.olm_device.megolm_remove_outbound_session( + self.room_id) for listener in self.state_listeners: if ( From de73b9d69b70d6a413c20619297f20955efd7304 Mon Sep 17 00:00:00 2001 From: Valentin Deniaud Date: Tue, 5 Jun 2018 16:49:03 +0200 Subject: [PATCH 10/15] send encrypted group messages --- matrix_client/crypto/olm_device.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/matrix_client/crypto/olm_device.py b/matrix_client/crypto/olm_device.py index 4e09ea08..ace09c7d 100644 --- a/matrix_client/crypto/olm_device.py +++ b/matrix_client/crypto/olm_device.py @@ -522,6 +522,21 @@ def megolm_remove_outbound_session(self, room_id): except KeyError: pass + def send_encrypted_message(self, room, content): + """Send a m.room.encrypted event in a room. + + Args: + room (Room): The room to use. + content (dict): The content of the event, will be encrypted. + + Raises: + MatrixRequestError if there was an error sending the event. + """ + event = {'content': content, 'room_id': room.room_id, 'type': 'm.room.message'} + encrypted_event = self.megolm_build_encrypted_event(room, event) + return self.api.send_message_event( + room.room_id, 'm.room.encrypted', encrypted_event) + def sign_json(self, json): """Signs a JSON object. From c07ee438f168791d194f1098c49613d6cd4fdb7e Mon Sep 17 00:00:00 2001 From: Valentin Deniaud Date: Tue, 5 Jun 2018 16:49:30 +0200 Subject: [PATCH 11/15] automatically send encrypted messages in encrypted rooms --- matrix_client/api.py | 40 ++++++++++++++++++++------------- matrix_client/errors.py | 7 ++++++ matrix_client/room.py | 50 ++++++++++++++++++++++++++++++++++------- 3 files changed, 74 insertions(+), 23 deletions(-) diff --git a/matrix_client/api.py b/matrix_client/api.py index 8bd44ecc..102d3e5c 100644 --- a/matrix_client/api.py +++ b/matrix_client/api.py @@ -343,6 +343,18 @@ def send_content(self, room_id, item_url, item_name, msg_type, return self.send_message_event(room_id, "m.room.message", content_pack, timestamp=timestamp) + def get_location_body(self, geo_uri, name, thumb_url=None, thumb_info=None): + content_pack = { + "geo_uri": geo_uri, + "msgtype": "m.location", + "body": name, + } + if thumb_url: + content_pack["thumbnail_url"] = thumb_url + if thumb_info: + content_pack["thumbnail_info"] = thumb_info + return content_pack + # http://matrix.org/docs/spec/client_server/r0.2.0.html#m-location def send_location(self, room_id, geo_uri, name, thumb_url=None, thumb_info=None, timestamp=None): @@ -356,15 +368,8 @@ def send_location(self, room_id, geo_uri, name, thumb_url=None, thumb_info=None, thumb_info (dict): Metadata about the thumbnail, type ImageInfo. timestamp (int): Set origin_server_ts (For application services only) """ - content_pack = { - "geo_uri": geo_uri, - "msgtype": "m.location", - "body": name, - } - if thumb_url: - content_pack["thumbnail_url"] = thumb_url - if thumb_info: - content_pack["thumbnail_info"] = thumb_info + content_pack = self.get_location_body( + geo_uri, name, thumb_url, thumb_info) return self.send_message_event(room_id, "m.room.message", content_pack, timestamp=timestamp) @@ -405,12 +410,11 @@ def send_notice(self, room_id, text_content, timestamp=None): text_content (str): The m.notice body to send. timestamp (int): Set origin_server_ts (For application services only) """ - body = { - "msgtype": "m.notice", - "body": text_content - } - return self.send_message_event(room_id, "m.room.message", body, - timestamp=timestamp) + return self.send_message_event( + room_id, "m.room.message", + self.get_notice_body(text_content), + timestamp=timestamp + ) def get_room_messages(self, room_id, token, direction, limit=10, to=None): """Perform GET /rooms/{roomId}/messages. @@ -675,6 +679,12 @@ def get_emote_body(self, text): "body": text } + def get_notice_body(self, text): + return { + "msgtype": "m.notice", + "body": text + } + def get_filter(self, user_id, filter_id): return self._send("GET", "/user/{userId}/filter/{filterId}" .format(userId=user_id, filterId=filter_id)) diff --git a/matrix_client/errors.py b/matrix_client/errors.py index e9dc8fe3..98bd5eb0 100644 --- a/matrix_client/errors.py +++ b/matrix_client/errors.py @@ -46,3 +46,10 @@ def __init__(self, original_exception, method, endpoint): original_exception) ) self.original_exception = original_exception + + +class MatrixNoEncryptionError(MatrixError): + """Encryption was not available.""" + + def __init__(self, content=""): + super(MatrixNoEncryptionError, self).__init__(content) diff --git a/matrix_client/room.py b/matrix_client/room.py index 15c427e5..8a2f0d94 100644 --- a/matrix_client/room.py +++ b/matrix_client/room.py @@ -18,7 +18,7 @@ from .checks import check_room_id from .user import User -from .errors import MatrixRequestError +from .errors import MatrixRequestError, MatrixNoEncryptionError class Room(object): @@ -102,7 +102,10 @@ def display_name(self): def send_text(self, text): """Send a plain text message to the room.""" - return self.client.api.send_message(self.room_id, text) + if self.encrypted and self.client._encryption: + return self.send_encrypted(self.client.api.get_text_body(text)) + else: + return self.client.api.send_message(self.room_id, text) def get_html_content(self, html, body=None, msgtype="m.text"): return { @@ -119,8 +122,12 @@ def send_html(self, html, body=None, msgtype="m.text"): html (str): The html formatted message to be sent. body (str): The unformatted body of the message to be sent. """ - return self.client.api.send_message_event( - self.room_id, "m.room.message", self.get_html_content(html, body, msgtype)) + content = self.get_html_content(html, body, msgtype) + if self.encrypted and self.client._encryption: + return self.send_encrypted(content) + else: + return self.client.api.send_message_event( + self.room_id, "m.room.message", content) def set_account_data(self, type, account_data): return self.client.api.set_room_account_data( @@ -142,7 +149,10 @@ def add_tag(self, tag, order=None, content=None): def send_emote(self, text): """Send an emote (/me style) message to the room.""" - return self.client.api.send_emote(self.room_id, text) + if self.encrypted and self.client._encryption: + return self.send_encrypted(self.client.api.get_emote_body(text)) + else: + return self.client.api.send_emote(self.room_id, text) def send_file(self, url, name, **fileinfo): """Send a pre-uploaded file to the room. @@ -163,7 +173,10 @@ def send_file(self, url, name, **fileinfo): def send_notice(self, text): """Send a notice (from bot) message to the room.""" - return self.client.api.send_notice(self.room_id, text) + if self.encrypted and self.client._encryption: + return self.send_encrypted(self.client.api.get_notice_body(text)) + else: + return self.client.api.send_notice(self.room_id, text) # See http://matrix.org/docs/spec/r0.0.1/client_server.html#m-image for the # imageinfo args. @@ -195,8 +208,13 @@ def send_location(self, geo_uri, name, thumb_url=None, **thumb_info): thumb_url (str): URL to the thumbnail of the location. thumb_info (): Metadata about the thumbnail, type ImageInfo. """ - return self.client.api.send_location(self.room_id, geo_uri, name, - thumb_url, thumb_info) + if self.encrypted and self.client._encryption: + content = self.client.api.get_location_body( + geo_uri, name, thumb_url, thumb_info) + return self.send_encrypted(content) + else: + return self.client.api.send_location(self.room_id, geo_uri, name, + thumb_url, thumb_info) def send_video(self, url, name, **videoinfo): """Send a pre-uploaded video to the room. @@ -226,6 +244,22 @@ def send_audio(self, url, name, **audioinfo): return self.client.api.send_content(self.room_id, url, name, "m.audio", extra_information=audioinfo) + def send_encrypted(self, content): + """Send an arbitrary encrypted message to the room. + + Args: + content (dict): The content of a m.room.message event. + + Raises: + ``MatrixNoEncryptionError`` if encryption is not enabled in client, or if + the room is unencrypted. + """ + if not self.client._encryption: + raise MatrixNoEncryptionError('Encryption is not enabled in client.') + if not self.encrypted: + raise MatrixNoEncryptionError('Encryption is not enabled in the room.') + return self.client.olm_device.send_encrypted_message(self, content) + def redact_message(self, event_id, reason=None): """Redacts the message with specified event_id for the given reason. From 24b5a407eae0c9fe67d2637c5d9a5b9e083c12b9 Mon Sep 17 00:00:00 2001 From: Valentin Deniaud Date: Wed, 27 Jun 2018 16:04:14 +0200 Subject: [PATCH 12/15] add megolm outbound tests --- test/crypto/olm_device_test.py | 133 +++++++++++++++++++++++++++++++-- 1 file changed, 125 insertions(+), 8 deletions(-) diff --git a/test/crypto/olm_device_test.py b/test/crypto/olm_device_test.py index c67cff93..7e4967aa 100644 --- a/test/crypto/olm_device_test.py +++ b/test/crypto/olm_device_test.py @@ -4,13 +4,19 @@ import json import logging from copy import deepcopy +try: + from urllib import quote +except ImportError: + from urllib.parse import quote import responses from matrix_client.crypto import olm_device from matrix_client.api import MATRIX_V2_API_PATH from matrix_client.client import MatrixClient +from matrix_client.user import User from matrix_client.crypto.olm_device import OlmDevice +from matrix_client.crypto.megolm_outbound_session import MegolmOutboundSession from test.response_examples import (example_key_upload_response, example_claim_keys_response) @@ -31,6 +37,12 @@ class TestOlmDevice: 'curve25519': alice_curve_key, 'ed25519': '4VjV3OhFUxWFAcO5YOaQVmTIn29JdRmtNh9iAxoyhkc' } + alice_olm_session = olm.OutboundSession( + device.olm_account, alice_curve_key, alice_curve_key) + room = cli._mkroom(room_id) + room._members[alice] = User(cli.api, alice) + # allow to_device api call to work well with responses + device.api._make_txn_id = lambda: 1 def test_sign_json(self): example_payload = { @@ -402,17 +414,122 @@ def test_olm_ensure_sessions(self): claim_url = HOSTNAME + MATRIX_V2_API_PATH + '/keys/claim' responses.add(responses.POST, claim_url, json=example_claim_keys_response) self.device.olm_sessions.clear() - alice_device_id = 'JLAFKJWSCS' - alice_curve_key = 'mmFRSHuJVq3aTudx3KB3w5ZvSFQhgEcy8d+m+vkEfUQ' - self.device.device_keys[self.alice][alice_device_id] = { - 'curve25519': alice_curve_key, - 'ed25519': '4VjV3OhFUxWFAcO5YOaQVmTIn29JdRmtNh9iAxoyhkc' - } - user_devices = {self.alice: [alice_device_id]} + self.device.device_keys[self.alice][self.alice_device_id] = \ + self.alice_identity_keys + user_devices = {self.alice: [self.alice_device_id]} self.device.olm_ensure_sessions(user_devices) - assert self.device.olm_sessions[alice_curve_key] + assert self.device.olm_sessions[self.alice_curve_key] assert len(responses.calls) == 1 self.device.olm_ensure_sessions(user_devices) assert len(responses.calls) == 1 + + @responses.activate + def test_megolm_share_session(self): + claim_url = HOSTNAME + MATRIX_V2_API_PATH + '/keys/claim' + responses.add(responses.POST, claim_url, json=example_claim_keys_response) + to_device_url = HOSTNAME + MATRIX_V2_API_PATH + '/sendToDevice/m.room.encrypted/1' + responses.add(responses.PUT, to_device_url, json={}) + self.device.olm_sessions.clear() + self.device.device_keys[self.alice][self.alice_device_id] = \ + self.alice_identity_keys + self.device.device_keys['dummy']['dummy'] = {'curve25519': 'a', 'ed25519': 'a'} + user_devices = {self.alice: [self.alice_device_id], 'dummy': ['dummy']} + session = MegolmOutboundSession() + + # Sharing with Alice should succeed, but dummy will fail + self.device.megolm_share_session(self.room_id, user_devices, session) + assert session.devices == {self.alice_device_id, 'dummy'} + + req = json.loads(responses.calls[1].request.body)['messages'] + assert self.alice in req + assert 'dummy' not in req + + @responses.activate + def test_megolm_start_session(self): + to_device_url = HOSTNAME + MATRIX_V2_API_PATH + '/sendToDevice/m.room.encrypted/1' + responses.add(responses.PUT, to_device_url, json={}) + self.device.device_keys[self.alice][self.alice_device_id] = \ + self.alice_identity_keys + self.device.device_list.tracked_user_ids.add(self.alice) + self.device.olm_sessions[self.alice_curve_key] = [self.alice_olm_session] + + self.device.megolm_start_session(self.room) + session = self.device.megolm_outbound_sessions[self.room_id] + assert self.alice_device_id in session.devices + + @responses.activate + def test_megolm_share_session_with_new_devices(self): + to_device_url = HOSTNAME + MATRIX_V2_API_PATH + '/sendToDevice/m.room.encrypted/1' + responses.add(responses.PUT, to_device_url, json={}) + self.device.device_keys[self.alice][self.alice_device_id] = \ + self.alice_identity_keys + self.device.olm_sessions[self.alice_curve_key] = [self.alice_olm_session] + session = MegolmOutboundSession() + self.device.megolm_outbound_sessions[self.room_id] = session + + self.device.megolm_share_session_with_new_devices(self.room, session) + assert self.alice_device_id in session.devices + assert len(responses.calls) == 1 + + self.device.megolm_share_session_with_new_devices(self.room, session) + assert len(responses.calls) == 1 + + @responses.activate + def test_megolm_build_encrypted_event(self): + to_device_url = HOSTNAME + MATRIX_V2_API_PATH + '/sendToDevice/m.room.encrypted/1' + responses.add(responses.PUT, to_device_url, json={}) + self.device.megolm_outbound_sessions.clear() + self.device.device_keys[self.alice][self.alice_device_id] = \ + self.alice_identity_keys + self.device.device_list.tracked_user_ids.add(self.alice) + self.device.olm_sessions[self.alice_curve_key] = [self.alice_olm_session] + event = {'type': 'm.room.message', 'content': {'body': 'test'}} + + self.room.rotation_period_msgs = 1 + self.device.megolm_build_encrypted_event(self.room, event) + + self.device.megolm_build_encrypted_event(self.room, event) + + session = self.device.megolm_outbound_sessions[self.room_id] + session.encrypt('test') + self.device.megolm_build_encrypted_event(self.room, event) + assert self.device.megolm_outbound_sessions[self.room_id].id != session.id + + def test_megolm_remove_outbound_session(self): + session = MegolmOutboundSession() + self.device.megolm_outbound_sessions[self.room_id] = session + self.device.megolm_remove_outbound_session(self.room_id) + self.device.megolm_remove_outbound_session(self.room_id) + + @responses.activate + def test_send_encrypted_message(self): + message_url = HOSTNAME + MATRIX_V2_API_PATH + \ + '/rooms/{}/send/m.room.encrypted/1'.format(quote(self.room.room_id)) + responses.add(responses.PUT, message_url, json={}) + self.device.device_keys[self.alice][self.alice_device_id] = \ + self.alice_identity_keys + self.device.olm_sessions[self.alice_curve_key] = [self.alice_olm_session] + session = MegolmOutboundSession() + session.add_device(self.alice_device_id) + self.device.megolm_outbound_sessions[self.room_id] = session + + self.device.send_encrypted_message(self.room, {'test': 'test'}) + + +def test_megolm_outbound_session(): + session = MegolmOutboundSession(max_messages=1) + + assert not session.devices + + session.add_device('test') + assert 'test' in session.devices + + session.add_devices({'test2', 'test3'}) + assert 'test2' in session.devices and 'test3' in session.devices + + assert not session.should_rotate() + + session.encrypt('message') + assert session.should_rotate() From 4ce7cdf67f3803de78566d9ce73658e9831da071 Mon Sep 17 00:00:00 2001 From: Valentin Deniaud Date: Fri, 29 Jun 2018 10:59:24 +0200 Subject: [PATCH 13/15] handle m.room.encryption properties --- .../crypto/megolm_outbound_session.py | 12 ++++++++---- matrix_client/crypto/olm_device.py | 3 ++- matrix_client/room.py | 6 ++++++ test/client_test.py | 19 ++++++++++++++++++- test/crypto/olm_device_test.py | 14 +++++++++++++- 5 files changed, 47 insertions(+), 7 deletions(-) diff --git a/matrix_client/crypto/megolm_outbound_session.py b/matrix_client/crypto/megolm_outbound_session.py index 5793c35f..cff87f6a 100644 --- a/matrix_client/crypto/megolm_outbound_session.py +++ b/matrix_client/crypto/megolm_outbound_session.py @@ -12,15 +12,19 @@ class MegolmOutboundSession(OutboundGroupSession): Args: max_age (datetime.timedelta): Optional. The maximum time the session should - exist. + exist. Default to one week if not present. max_messages (int): Optional. The maximum number of messages that should be sent. A new message in considered sent each time there is a call to ``encrypt``. + Default to 100 if not present. """ - def __init__(self, max_age=timedelta(days=7), max_messages=100): + def __init__(self, max_age=None, max_messages=None): self.devices = set() - self.max_age = max_age - self.max_messages = max_messages + if max_age: + self.max_age = timedelta(milliseconds=max_age) + else: + self.max_age = timedelta(days=7) + self.max_messages = max_messages or 100 self.creation_time = datetime.now() self.message_count = 0 super(MegolmOutboundSession, self).__init__() diff --git a/matrix_client/crypto/olm_device.py b/matrix_client/crypto/olm_device.py index ace09c7d..8b2a344d 100644 --- a/matrix_client/crypto/olm_device.py +++ b/matrix_client/crypto/olm_device.py @@ -401,7 +401,8 @@ def megolm_start_session(self, room): Returns: The newly created session. """ - session = MegolmOutboundSession() + session = MegolmOutboundSession(max_age=room.rotation_period_ms, + max_messages=room.rotation_period_msgs) self.megolm_outbound_sessions[room.room_id] = session logger.info('Starting a new Meglom outbound session %s in %s.', session.id, room.room_id) diff --git a/matrix_client/room.py b/matrix_client/room.py index 8a2f0d94..38e9a9fe 100644 --- a/matrix_client/room.py +++ b/matrix_client/room.py @@ -50,6 +50,8 @@ def __init__(self, client, room_id): # user_id: displayname } self.encrypted = False + self.rotation_period_msgs = None + self.rotation_period_ms = None def set_user_profile(self, displayname=None, @@ -692,6 +694,10 @@ def _process_state_event(self, state_event): elif etype == "m.room.encryption": if econtent.get("algorithm") == "m.megolm.v1.aes-sha2": self.encrypted = True + if not self.rotation_period_ms: + self.rotation_period_ms = econtent.get("rotation_period_ms") + if not self.rotation_period_msgs: + self.rotation_period_msgs = econtent.get("rotation_period_msgs") elif etype == "m.room.member" and clevel == clevel.ALL: # tracking room members can be large e.g. #matrix:matrix.org user_id = state_event["state_key"] diff --git a/test/client_test.py b/test/client_test.py index c5884924..5532e1ae 100644 --- a/test/client_test.py +++ b/test/client_test.py @@ -155,13 +155,30 @@ def test_state_event(): # test encryption room.encrypted = False ev["type"] = "m.room.encryption" - ev["content"] = {"algorithm": "m.megolm.v1.aes-sha2"} + ev["content"] = { + "algorithm": "m.megolm.v1.aes-sha2", + "rotation_period_msgs": 50, + "rotation_period_ms": 100000, + } room._process_state_event(ev) assert room.encrypted + assert room.rotation_period_ms == 100000 + assert room.rotation_period_msgs == 50 # encrypted flag must not be cleared on configuration change ev["content"] = {"algorithm": None} room._process_state_event(ev) assert room.encrypted + assert room.rotation_period_ms == 100000 + assert room.rotation_period_msgs == 50 + # nor should the session parameters be changed + ev["content"] = { + "algorithm": "m.megolm.v1.aes-sha2", + "rotation_period_msgs": 5, + "rotation_period_ms": 10000, + } + room._process_state_event(ev) + assert room.rotation_period_ms == 100000 + assert room.rotation_period_msgs == 50 def test_get_user(): diff --git a/test/crypto/olm_device_test.py b/test/crypto/olm_device_test.py index 7e4967aa..c7aaa51b 100644 --- a/test/crypto/olm_device_test.py +++ b/test/crypto/olm_device_test.py @@ -4,6 +4,7 @@ import json import logging from copy import deepcopy +from datetime import timedelta, datetime try: from urllib import quote except ImportError: @@ -519,7 +520,13 @@ def test_send_encrypted_message(self): def test_megolm_outbound_session(): - session = MegolmOutboundSession(max_messages=1) + session = MegolmOutboundSession() + assert session.max_messages == 100 + assert session.max_age == timedelta(days=7) + + session = MegolmOutboundSession(max_messages=1, max_age=100000) + assert session.max_messages == 1 + assert session.max_age == timedelta(milliseconds=100000) assert not session.devices @@ -533,3 +540,8 @@ def test_megolm_outbound_session(): session.encrypt('message') assert session.should_rotate() + + session.max_messages = 2 + assert not session.should_rotate() + session.creation_time = datetime.now() - timedelta(milliseconds=100000) + assert session.should_rotate() From 71e8c524f2da32160f294893eff5730a35b53204 Mon Sep 17 00:00:00 2001 From: Valentin Deniaud Date: Fri, 29 Jun 2018 15:29:54 +0200 Subject: [PATCH 14/15] build doc for crypto.megolm_outbound_session --- docs/source/matrix_client.rst | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/docs/source/matrix_client.rst b/docs/source/matrix_client.rst index 1587a867..353eb5a5 100644 --- a/docs/source/matrix_client.rst +++ b/docs/source/matrix_client.rst @@ -61,3 +61,8 @@ matrix_client.crypto :members: :undoc-members: :show-inheritance: + +.. automodule:: matrix_client.crypto.megolm_outbound_session + :members: + :undoc-members: + :show-inheritance: From 2515d578f96847fea8785e0341b59dfe8b9f4666 Mon Sep 17 00:00:00 2001 From: Valentin Deniaud Date: Fri, 20 Jul 2018 12:58:35 +0200 Subject: [PATCH 15/15] pass rotation parameters when enabling encryption --- matrix_client/room.py | 25 +++++++++++++++++++++---- test/client_test.py | 9 +++++++++ 2 files changed, 30 insertions(+), 4 deletions(-) diff --git a/matrix_client/room.py b/matrix_client/room.py index 38e9a9fe..120f65d2 100644 --- a/matrix_client/room.py +++ b/matrix_client/room.py @@ -654,18 +654,35 @@ def set_guest_access(self, allow_guests): except MatrixRequestError: return False - def enable_encryption(self): + def enable_encryption(self, rotation_period_ms=None, rotation_period_msgs=None): """Enables encryption in the room. NOTE: Once enabled, encryption cannot be disabled. + Args: + rotation_period_ms (int): Optional. Lifetime of Megolm sessions in the room. + rotation_period_msgs (int): Optional. Number of messages that can be encrypted + by a Megolm session before it is rotated. + Returns: - True if successful, False if not + ``True`` if successful, ``False`` if not. + If the room was already encrypted, True is returned, but the + ``rotation_period_ms`` and ``rotation_period_ms`` parameters have no effect. """ + if self.encrypted: + return True + + content = {"algorithm": "m.megolm.v1.aes-sha2"} + if rotation_period_ms: + content["rotation_period_ms"] = rotation_period_ms + if rotation_period_msgs: + content["rotation_period_msgs"] = rotation_period_msgs + try: - self.send_state_event("m.room.encryption", - {"algorithm": "m.megolm.v1.aes-sha2"}) + self.send_state_event("m.room.encryption", content) self.encrypted = True + self.rotation_period_ms = rotation_period_ms + self.rotation_period_msgs = rotation_period_msgs return True except MatrixRequestError: return False diff --git a/test/client_test.py b/test/client_test.py index 5532e1ae..5fa84762 100644 --- a/test/client_test.py +++ b/test/client_test.py @@ -536,6 +536,15 @@ def test_enable_encryption_in_room(): assert room.enable_encryption() assert room.encrypted + room = client._mkroom(room_id) + assert room.enable_encryption(rotation_period_msgs=1, rotation_period_ms=1) + assert room.rotation_period_msgs == 1 + assert room.rotation_period_ms == 1 + + assert room.enable_encryption(rotation_period_msgs=2) + # The room was already encrypted, we should not have changed its attribute + assert room.rotation_period_msgs == 1 + @responses.activate def test_detect_encryption_state():