From b0fa2a65b0f7a74a733ca32fdb68b2c2a0fad585 Mon Sep 17 00:00:00 2001 From: Valentin Deniaud Date: Thu, 24 May 2018 20:47:07 +0200 Subject: [PATCH 01/66] track the device list of users and download keys Use a thread to allow downloading the keys in a non-blocking way. Implement sort of a queue of users who need updating, which should be efficient (request is fired as soon as possible with as much data as possible) and avoid races (we always have only one download in progress). --- matrix_client/client.py | 9 + matrix_client/crypto/device_list.py | 269 ++++++++++++++++++++++++++++ matrix_client/crypto/olm_device.py | 4 + 3 files changed, 282 insertions(+) create mode 100644 matrix_client/crypto/device_list.py diff --git a/matrix_client/client.py b/matrix_client/client.py index 703b825c..f9a2d876 100644 --- a/matrix_client/client.py +++ b/matrix_client/client.py @@ -581,6 +581,15 @@ def _mkroom(self, room_id): # TODO better handling of the blocking I/O caused by update_one_time_key_counts def _sync(self, timeout_ms=30000): response = self.api.sync(self.sync_token, timeout_ms, filter=self.sync_filter) + + if self._encryption and 'device_lists' in response: + if response['device_lists'].get('changed'): + self.olm_device.device_list.update_user_device_keys( + response['device_lists']['changed'], self.sync_token) + if response['device_lists'].get('left'): + self.olm_device.device_list.stop_tracking_users( + response['device_lists']['left']) + self.sync_token = response["next_batch"] for presence_update in response['presence']['events']: diff --git a/matrix_client/crypto/device_list.py b/matrix_client/crypto/device_list.py new file mode 100644 index 00000000..c9a22966 --- /dev/null +++ b/matrix_client/crypto/device_list.py @@ -0,0 +1,269 @@ +import logging +from collections import defaultdict +from threading import Thread, Condition, Event + +from matrix_client.errors import MatrixHttpLibError, MatrixRequestError + +logger = logging.getLogger(__name__) + + +class DeviceList: + """Allows to maintain a list of devices up-to-date for an OlmDevice. + + Offers blocking and non-blocking methods to fetch device keys when appropriate. + NOTE: Spawns a thread that will last until program termination. + + Args: + olm_device (OlmDevice): Will be used to get additional info, such as device id. + api (MatrixHttpApi): The api object used to make requests. + device_keys (defaultdict(dict)): A map from user to device to keys. + """ + + def __init__(self, olm_device, api, device_keys): + self.olm_device = olm_device + self.api = api + self.device_keys = device_keys + # Stores the ids of users who need updating + self.outdated_user_ids = _OutdatedUsersSet() + # Stores the ids of users we are currently tracking. We can assume the device + # keys of these users are up-to-date as long as no downloading is in progress. + # We should track every user we share an encrypted room with. + self.tracked_user_ids = set() + # Allows to wake up the thread when there are new users to update, and to + # synchronise shared data. + self.thread_condition = Condition() + self.update_thread = _UpdateDeviceList( + self.thread_condition, self.outdated_user_ids, self._download_device_keys, + self.tracked_user_ids + ) + self.update_thread.start() + + def get_room_device_keys(self, room, blocking=True): + """Gets the keys of all devices present in the room. + + Makes sure not to download keys of users we are already tracking. + The users we were not yet tracking will get tracked automatically. + + Args: + room (Room): The room to use. + blocking (bool): Optional. Whether to wait for the keys to have been + downloaded before returning. + """ + logger.info('Fetching all missing keys in room %s.', room.room_id) + user_ids = {u.user_id for u in room.get_joined_members()} - self.tracked_user_ids + if not user_ids: + logger.info('Already had all the keys in room %s.', room.room_id) + if blocking: + # Wait on an eventual download to finish + self.update_thread.event.wait() + return + with self.thread_condition: + self.outdated_user_ids.update(user_ids) + if blocking: + # Will ensure the user_ids we just added are processed + event = Event() + self.outdated_user_ids.events.add(event) + self.thread_condition.notify() + if blocking: + event.wait() + + def add_users(self, user_ids): + """Add users to be tracked, and download their device keys. + + NOTE: this is non-blocking and will return before the keys are downloaded. + + Args: + user_ids (iterable): Any iterable containing user ids. + """ + user_ids = user_ids.difference(self.tracked_user_ids) + if user_ids: + self._add_outdated_users(user_ids) + + def stop_tracking_users(self, user_ids): + """Stop tracking users. + + NOTE: Keys will not be deleted. + + Args: + user_ids (iterable): Any iterable containing user ids. + """ + with self.thread_condition: + self.tracked_user_ids.difference_update(user_ids) + self.outdated_user_ids.difference_update(user_ids) + logger.info('Stopped tracking users: %s.', user_ids) + + def update_user_device_keys(self, user_ids, since_token=None): + """Triggers an update for users we already track. + + Args: + user_ids (iterable): Any iterable containing user ids. + since_token (str): Optional. Since token of a sync request, if triggering + the update as a result of that sync request. + """ + user_ids = self.tracked_user_ids.intersection(user_ids) + if not user_ids: + return + logger.info('Updating the device lists of users: %s, using token %s', + user_ids, since_token) + self._add_outdated_users(user_ids, since_token=since_token) + + def _add_outdated_users(self, user_ids, since_token=None): + """Stop tracking users. Keys will not be deleted. + + Args: + user_ids (iterable): Any iterable containing user ids. + since_token (str): Optional. Since token of a sync request. + """ + with self.thread_condition: + self.outdated_user_ids.update(user_ids) + if since_token: + self.outdated_user_ids.sync_token = since_token + self.thread_condition.notify() + + def _download_device_keys(self, user_devices, since_token=None): + """Download and store device keys, if they pass security checks. + + Args: + user_devices (dict): Format is ``user_id: [device_ids]``. + since_token (str): Optional. Since token of a sync request. + """ + changed = defaultdict(dict) + resp = self.api.query_keys(user_devices, token=since_token) + if resp.get('failures'): + logger.warning('Failed to download keys from the following unreachable ' + 'homeservers %s.', resp['failures']) + device_keys = resp['device_keys'] + for user_id in user_devices: + # The response might not contain every user_ids we requested + for device_id, payload in device_keys.get(user_id, {}).items(): + if device_id == self.olm_device.device_id: + continue + if payload['user_id'] != user_id or payload['device_id'] != device_id: + logger.warning('Mismatch in keys payload of device %s (%s) of user ' + '%s (%s).', payload['device_id'], device_id, + payload['user_id'], user_id) + continue + try: + signing_key = payload['keys']['ed25519:{}'.format(device_id)] + curve_key = payload['keys']['curve25519:{}'.format(device_id)] + except KeyError as e: + logger.warning('Invalid identity keys payload from device %s of' + 'user %s: %s.', device_id, user_id, e) + continue + verified = self.olm_device.verify_json( + payload, signing_key, user_id, device_id) + if not verified: + logger.warning('Signature verification failed for device %s of ' + 'user %s.', device_id, user_id) + continue + keys = self.device_keys[user_id].setdefault(device_id, {}) + if keys: + if keys['ed25519'] != signing_key: + logger.warning('Ed25519 key has changed for device %s of ' + 'user %s.', device_id, user_id) + continue + if keys['curve25519'] == curve_key: + continue + else: + keys['ed25519'] = signing_key + keys['curve25519'] = curve_key + changed[user_id][device_id] = keys + + logger.info('Successfully downloaded keys for devices: %s.', + {user_id: list(changed[user_id]) for user_id in changed}) + return changed + + +class _OutdatedUsersSet(set): + """Allows to know if elements in a set have been processed. + + This is done by adding elements along with an Event object. Then, functions + processing the set should set the events when they are done. + """ + + def __init__(self, iterable=()): + self.events = set() + self._sync_token = None + super(_OutdatedUsersSet, self).__init__(iterable) + + def mark_as_processed(self): + for event in self.events: + event.set() + + def copy(self): + new_set = _OutdatedUsersSet(self) + new_set.events = self.events.copy() + return new_set + + def clear(self): + self.events.clear() + super(_OutdatedUsersSet, self).clear() + + def update(self, iterable): + super(_OutdatedUsersSet, self).update(iterable) + if isinstance(iterable, _OutdatedUsersSet): + self.events.update(iterable.events) + + @property + def sync_token(self): + return self._sync_token + + @sync_token.setter + def sync_token(self, token): + if not self._sync_token or token > self._sync_token: + self._sync_token = token + + +class _UpdateDeviceList(Thread): + + def __init__(self, cond, user_ids, download_method, tracked_user_ids): + # We wait on this condition when there is nothing to do. Outside code should use + # it to notify us when they add data to be processed in outdated_user_ids so that + # we can wake up and process it. + self.cond = cond + self.outdated_user_ids = user_ids + self.download = download_method + self.tracked_user_ids = tracked_user_ids + # Cleared when we start a download, and set when we have finished it. This can be + # used by outside code in order to know if we are in the middle of a download, and + # allows to wait for it to complete by waiting on this event. + self.event = Event() + # Used internally to terminate gracefully on program exit. + self._should_terminate = Event() + super(_UpdateDeviceList, self).__init__() + + def run(self): + while True and not self._should_terminate.is_set(): + with self.cond: + while not self.outdated_user_ids: + # Avoid any deadlocks + self.outdated_user_ids.mark_as_processed() + self.event.set() + logger.debug('Update thread is going to sleep...') + self.cond.wait() + logger.debug('Update thread woke up!') + if self._should_terminate.is_set(): + return + to_download = self.outdated_user_ids.copy() + self.outdated_user_ids.clear() + self.event.clear() + self.tracked_user_ids.update(to_download) + payload = {user_id: [] for user_id in to_download} + logger.info('Downloading device keys for users: %s.', to_download) + try: + self.download(payload, self.outdated_user_ids.sync_token) + self.event.set() + to_download.mark_as_processed() + except (MatrixHttpLibError, MatrixRequestError) as e: + logger.warning('Network error when fetching device keys (will retry): %s', + e) + with self.cond: + self.outdated_user_ids.update(to_download) + + def join(self, timeout=None): + # If we are joined, this means that the main program is terminating. + # We should terminate too. + self._should_terminate.set() + with self.cond: + self.cond.notify() + super(_UpdateDeviceList, self).join(timeout=timeout) diff --git a/matrix_client/crypto/olm_device.py b/matrix_client/crypto/olm_device.py index 514965db..d2e114b1 100644 --- a/matrix_client/crypto/olm_device.py +++ b/matrix_client/crypto/olm_device.py @@ -1,10 +1,12 @@ import logging +from collections import defaultdict import olm from canonicaljson import encode_canonical_json from matrix_client.checks import check_user_id from matrix_client.crypto.one_time_keys import OneTimeKeysManager +from matrix_client.crypto.device_list import DeviceList logger = logging.getLogger(__name__) @@ -59,6 +61,8 @@ def __init__(self, self.one_time_keys_manager = OneTimeKeysManager(target_keys_number, signed_keys_proportion, keys_threshold) + self.device_keys = defaultdict(dict) + self.device_list = DeviceList(self, api, self.device_keys) def upload_identity_keys(self): """Uploads this device's identity keys to HS. From c8496d39e7459e2afd5d73baa4e51ccad7f173cf Mon Sep 17 00:00:00 2001 From: Valentin Deniaud Date: Sun, 24 Jun 2018 04:17:39 +0200 Subject: [PATCH 02/66] add device tracking tests --- test/crypto/device_list_test.py | 272 ++++++++++++++++++++++++++++++++ test/crypto/olm_device_test.py | 1 + test/response_examples.py | 31 ++++ 3 files changed, 304 insertions(+) create mode 100644 test/crypto/device_list_test.py diff --git a/test/crypto/device_list_test.py b/test/crypto/device_list_test.py new file mode 100644 index 00000000..32f75ce6 --- /dev/null +++ b/test/crypto/device_list_test.py @@ -0,0 +1,272 @@ +import pytest +pytest.importorskip("olm") # noqa + +import json +from copy import deepcopy +from threading import Event, Condition + +import responses + +from matrix_client.api import MATRIX_V2_API_PATH +from matrix_client.client import MatrixClient +from matrix_client.room import User +from matrix_client.errors import MatrixRequestError +from matrix_client.crypto.olm_device import OlmDevice +from matrix_client.crypto.device_list import (_OutdatedUsersSet as OutdatedUsersSet, + _UpdateDeviceList as UpdateDeviceList) +from test.response_examples import example_key_query_response + +HOSTNAME = 'http://example.com' + + +class TestDeviceList: + cli = MatrixClient(HOSTNAME) + user_id = '@test:example.com' + alice = '@alice:example.com' + room_id = '!test:example.com' + device_id = 'AUIETSRN' + device = OlmDevice(cli.api, user_id, device_id) + device_list = device.device_list + signing_key = device.olm_account.identity_keys['ed25519'] + query_url = HOSTNAME + MATRIX_V2_API_PATH + '/keys/query' + + @responses.activate + def test_download_device_keys(self): + # The method we want to test + download_device_keys = self.device_list._download_device_keys + bob = '@bob:example.com' + eve = '@eve:example.com' + user_devices = {self.alice: [], bob: [], self.user_id: []} + + # This response is correct for Alice's keys, but lacks Bob's + # There are no failures + resp = example_key_query_response + responses.add(responses.POST, self.query_url, json=resp) + + # Still correct, but Alice's identity key has changed + resp = deepcopy(example_key_query_response) + new_id_key = 'ijxGZqwB/UvMtKABdaCdrI0OtQI6NhHBYiknoCkdWng' + payload = resp['device_keys'][self.alice]['JLAFKJWSCS'] + payload['keys']['curve25519:JLAFKJWSCS'] = new_id_key + payload['signatures'][self.alice]['ed25519:JLAFKJWSCS'] = \ + ('D9oLtYefMIr4StiHTIzn3+bhtPCfrZNDU9jsUbMu3MicfZLl4d8WlYn3TPmbwDi8XMGcT' + 'nNnqfdi/tYUPvKfCA') + responses.add(responses.POST, self.query_url, json=resp) + + # Still correct, but Alice's signing key has changed + alice_device = OlmDevice(self.cli.api, self.alice, 'JLAFKJWSCS') + resp = deepcopy(example_key_query_response) + resp['device_keys'][self.alice]['JLAFKJWSCS']['keys']['ed25519:JLAFKJWSCS'] = \ + alice_device.identity_keys['ed25519'] + resp['device_keys'][self.alice]['JLAFKJWSCS'] = \ + alice_device.sign_json(resp['device_keys'][self.alice]['JLAFKJWSCS']) + responses.add(responses.POST, self.query_url, json=resp) + + # Response containing an unknown user + resp = deepcopy(example_key_query_response) + user_device = resp['device_keys'].pop(self.alice) + resp['device_keys'][eve] = user_device + responses.add(responses.POST, self.query_url, json=resp) + + # Response with an invalid signature + resp = deepcopy(example_key_query_response) + resp['device_keys'][self.alice]['JLAFKJWSCS']['test'] = 1 + responses.add(responses.POST, self.query_url, json=resp) + + # Response with a requested user and valid signature, but with a mismatch + resp = deepcopy(example_key_query_response) + user_device = resp['device_keys'].pop(self.alice) + resp['device_keys'][bob] = user_device + responses.add(responses.POST, self.query_url, json=resp) + + # Response with an invalid keys field + resp = deepcopy(example_key_query_response) + keys_field = resp['device_keys'][self.alice]['JLAFKJWSCS']['keys'] + key = keys_field.pop("ed25519:JLAFKJWSCS") + keys_field["ed25519:wrong"] = key + # Cover a missing branch by adding failures + resp["failures"]["other.com"] = {} + # And one more by adding ourself + resp['device_keys'][self.user_id] = {self.device_id: 'dummy'} + responses.add(responses.POST, self.query_url, json=resp) + + self.device.device_keys.clear() + assert download_device_keys(user_devices) + req = json.loads(responses.calls[0].request.body) + assert req['device_keys'] == {self.alice: [], bob: [], self.user_id: []} + expected_device_keys = { + self.alice: { + 'JLAFKJWSCS': { + 'curve25519': '3C5BFWi2Y8MaVvjM8M22DBmh24PmgR0nPvJOIArzgyI', + 'ed25519': 'VzJIYXQ85u19z2ZpEeLLVu8hUKTCE0VXYUn4IY4iFcA' + } + } + } + assert self.device.device_keys == expected_device_keys + + # Different curve25519, key should get updated + assert download_device_keys(user_devices) + expected_device_keys[self.alice]['JLAFKJWSCS']['curve25519'] = new_id_key + assert self.device.device_keys == expected_device_keys + + # Different ed25519, key should not get updated + assert not download_device_keys(user_devices) + assert self.device.device_keys == expected_device_keys + + self.device.device_keys.clear() + # All the remaining responses are wrong and we should not add the key + for _ in range(4): + assert not download_device_keys(user_devices) + assert self.device.device_keys == {} + + assert len(responses.calls) == 7 + + @responses.activate + def test_update_thread(self): + # Normal run + event = Event() + outdated_users = OutdatedUsersSet({self.user_id}) + outdated_users.events.add(event) + + def dummy_download(user_devices, since_token=None): + assert user_devices == {self.user_id: []} + return + thread = UpdateDeviceList(Condition(), outdated_users, dummy_download, set()) + + thread.start() + event.wait() + assert not thread.outdated_user_ids + assert thread.event.is_set() + assert thread.tracked_user_ids == {self.user_id} + thread.join() + assert not thread.is_alive() + + # Error run + outdated_users = OutdatedUsersSet({self.user_id}) + + def error_on_first_download(user_devices, since_token=None): + error_on_first_download.c += 1 + if error_on_first_download.c == 1: + raise MatrixRequestError + return + error_on_first_download.c = 0 + thread = UpdateDeviceList( + Condition(), outdated_users, error_on_first_download, set()) + thread.start() + thread.event.wait() + assert error_on_first_download.c == 2 + assert not thread.outdated_user_ids + thread.join() + + # Cover a missing branch + thread = UpdateDeviceList( + Condition(), outdated_users, error_on_first_download, set()) + thread._should_terminate.set() + thread.start() + thread.join() + assert not thread.is_alive() + + @responses.activate + def test_get_room_device_keys(self): + self.device_list.tracked_user_ids.clear() + room = self.cli._mkroom(self.room_id) + room._members[self.alice] = User(self.cli.api, self.alice) + + responses.add(responses.POST, self.query_url, json=example_key_query_response) + + # Blocking + self.device_list.get_room_device_keys(room) + assert self.device_list.tracked_user_ids == {self.alice} + assert self.device_list.device_keys[self.alice]['JLAFKJWSCS'] + + # Same, but we already track the user + self.device_list.get_room_device_keys(room) + + # Non-blocking + self.device_list.tracked_user_ids.clear() + # We have to block for testing purposes, though + self.device_list.update_thread.event.clear() + self.device_list.get_room_device_keys(room, blocking=False) + self.device_list.update_thread.event.wait() + + # Same, but we already track the user + self.device_list.get_room_device_keys(room, blocking=False) + + @responses.activate + def test_add_users(self): + self.device_list.tracked_user_ids.clear() + responses.add(responses.POST, self.query_url, json=example_key_query_response) + + self.device_list.update_thread.event.clear() + self.device_list.add_users({self.alice}) + self.device_list.update_thread.event.wait() + assert self.device_list.tracked_user_ids == {self.alice} + assert len(responses.calls) == 1 + + # Same, but we are already tracking Alice + self.device_list.add_users({self.alice}) + assert len(responses.calls) == 1 + + def test_stop_tracking_users(self): + self.device_list.tracked_user_ids.clear() + self.device_list.tracked_user_ids.add(self.alice) + self.device_list.outdated_user_ids.clear() + self.device_list.outdated_user_ids.add(self.alice) + + self.device_list.stop_tracking_users({self.alice}) + + assert not self.device_list.tracked_user_ids + assert not self.device_list.outdated_user_ids + + @responses.activate + def test_update_user_device_keys(self): + self.device_list.tracked_user_ids.clear() + responses.add(responses.POST, self.query_url, json=example_key_query_response) + + self.device_list.update_user_device_keys({self.alice}) + assert len(responses.calls) == 0 + + self.device_list.tracked_user_ids.add(self.alice) + + self.device_list.update_thread.event.clear() + self.device_list.update_user_device_keys({self.alice}, since_token='dummy') + self.device_list.update_thread.event.wait() + assert len(responses.calls) == 1 + + +def test_outdated_users_set(): + s = OutdatedUsersSet() + assert not s + + s = OutdatedUsersSet({1}) + event = Event() + s.events.add(event) + assert s == {1} + + # Make a manual copy of s + t = OutdatedUsersSet() + t.add(1) + t.events.add(event) + assert t == s and t.events == s.events + + u = s.copy() + event2 = Event() + u.add(2) + u.events.add(event2) + # Check that modifying u didn't change s + assert t == s and t.events == s.events + + s.update(u) + assert s == {1, 2} and s.events == {event, event2} + + s.mark_as_processed() + assert event.is_set() + + new = 's72594_4483_1935' + s.sync_token = new + old = 's72594_4483_1934' + s.sync_token = old + assert s.sync_token == new + + s.clear() + assert not s and not s.events diff --git a/test/crypto/olm_device_test.py b/test/crypto/olm_device_test.py index 84f56df6..cab7c56d 100644 --- a/test/crypto/olm_device_test.py +++ b/test/crypto/olm_device_test.py @@ -17,6 +17,7 @@ class TestOlmDevice: cli = MatrixClient(HOSTNAME) user_id = '@user:matrix.org' + room_id = '!test:example.com' device_id = 'QBUAZIFURK' device = OlmDevice(cli.api, user_id, device_id) signing_key = device.olm_account.identity_keys['ed25519'] diff --git a/test/response_examples.py b/test/response_examples.py index 2d45aa86..d827a05b 100644 --- a/test/response_examples.py +++ b/test/response_examples.py @@ -184,3 +184,34 @@ "og:image:width": 48, "og:title": "Matrix Blog Post" } + +example_key_query_response = { + "failures": {}, + "device_keys": { + "@alice:example.com": { + "JLAFKJWSCS": { + "user_id": "@alice:example.com", + "device_id": "JLAFKJWSCS", + "algorithms": [ + "m.olm.curve25519-aes-sha256", + "m.megolm.v1.aes-sha" + ], + "keys": { + "curve25519:JLAFKJWSCS": ("3C5BFWi2Y8MaVvjM8M22DBmh24PmgR0nPvJOIAr" + "zgyI"), + "ed25519:JLAFKJWSCS": "VzJIYXQ85u19z2ZpEeLLVu8hUKTCE0VXYUn4IY4iFcA" + }, + "signatures": { + "@alice:example.com": { + "ed25519:JLAFKJWSCS": + ("wux6Dhjtk7GYPMW54hnx0doVH0NvuUAFBleL5OW99jhbjIutufglAgrYAcu8" + "ueacgNyeSumvtzVIPZXgbB2BCg") + } + }, + "unsigned": { + "device_display_name": "Alice'smobilephone" + } + } + } + } +} From 69d300eec2c6d6f504060227e55987802c319652 Mon Sep 17 00:00:00 2001 From: Valentin Deniaud Date: Tue, 12 Jun 2018 22:56:34 +0200 Subject: [PATCH 03/66] track devices of users in encrypted rooms This has the effect of tracking the device lists of users proactively. --- matrix_client/client.py | 4 ++++ matrix_client/crypto/device_list.py | 27 +++++++++++++++++++++++++-- matrix_client/room.py | 7 +++++-- test/crypto/device_list_test.py | 16 +++++++++++++--- 4 files changed, 47 insertions(+), 7 deletions(-) diff --git a/matrix_client/client.py b/matrix_client/client.py index f9a2d876..0c2a8af6 100644 --- a/matrix_client/client.py +++ b/matrix_client/client.py @@ -636,6 +636,10 @@ def _sync(self, timeout_ms=30000): ): listener['callback'](event) + if self._encryption and room.encrypted: + # Track the new users in the room + self.olm_device.device_list.track_pending_users() + for event in sync_room['ephemeral']['events']: event['room_id'] = room_id room._put_ephemeral_event(event) diff --git a/matrix_client/crypto/device_list.py b/matrix_client/crypto/device_list.py index c9a22966..c0bf406a 100644 --- a/matrix_client/crypto/device_list.py +++ b/matrix_client/crypto/device_list.py @@ -1,6 +1,6 @@ import logging from collections import defaultdict -from threading import Thread, Condition, Event +from threading import Thread, Condition, Event, Lock from matrix_client.errors import MatrixHttpLibError, MatrixRequestError @@ -25,6 +25,9 @@ def __init__(self, olm_device, api, device_keys): self.device_keys = device_keys # Stores the ids of users who need updating self.outdated_user_ids = _OutdatedUsersSet() + # Stores the ids of users to fetch the device keys of eventually + self.pending_outdated_user_ids = set() + self.pending_users_lock = Lock() # Stores the ids of users we are currently tracking. We can assume the device # keys of these users are up-to-date as long as no downloading is in progress. # We should track every user we share an encrypted room with. @@ -67,7 +70,27 @@ def get_room_device_keys(self, room, blocking=True): if blocking: event.wait() - def add_users(self, user_ids): + def track_user_no_download(self, user_id): + """Add user to be tracked, but do not track it instantly. + + This should be used to avoid making calls to :func:`track_users` with only one + user repeatedly. Instead, using this should allow to passively queue the users, + and tracking can be triggered when there are sufficiently, using + :func:`track_pending_users`. + + Args: + user_id (str): A user id. + """ + with self.pending_users_lock: + self.pending_outdated_user_ids.add(user_id) + + def track_pending_users(self): + """Triggers the tracking of the user added with :func:`track_user_no_download`.""" + with self.pending_users_lock: + self.track_users(self.pending_outdated_user_ids) + self.pending_outdated_user_ids.clear() + + def track_users(self, user_ids): """Add users to be tracked, and download their device keys. NOTE: this is non-blocking and will return before the keys are downloaded. diff --git a/matrix_client/room.py b/matrix_client/room.py index 488f335b..deec52d6 100644 --- a/matrix_client/room.py +++ b/matrix_client/room.py @@ -660,11 +660,14 @@ def _process_state_event(self, state_event): self.encrypted = True elif etype == "m.room.member" and clevel == clevel.ALL: # tracking room members can be large e.g. #matrix:matrix.org + user_id = state_event["state_key"] if econtent["membership"] == "join": - user_id = state_event["state_key"] self._add_member(user_id, econtent.get("displayname")) + if self.client._encryption and self.encrypted: + # Track the device list of this user + self.client.olm_device.device_list.track_user_no_download(user_id) elif econtent["membership"] in ("leave", "kick", "invite"): - self._members.pop(state_event["state_key"], None) + self._members.pop(user_id, None) for listener in self.state_listeners: if ( diff --git a/test/crypto/device_list_test.py b/test/crypto/device_list_test.py index 32f75ce6..69df1ed3 100644 --- a/test/crypto/device_list_test.py +++ b/test/crypto/device_list_test.py @@ -193,18 +193,18 @@ def test_get_room_device_keys(self): self.device_list.get_room_device_keys(room, blocking=False) @responses.activate - def test_add_users(self): + def test_track_users(self): self.device_list.tracked_user_ids.clear() responses.add(responses.POST, self.query_url, json=example_key_query_response) self.device_list.update_thread.event.clear() - self.device_list.add_users({self.alice}) + self.device_list.track_users({self.alice}) self.device_list.update_thread.event.wait() assert self.device_list.tracked_user_ids == {self.alice} assert len(responses.calls) == 1 # Same, but we are already tracking Alice - self.device_list.add_users({self.alice}) + self.device_list.track_users({self.alice}) assert len(responses.calls) == 1 def test_stop_tracking_users(self): @@ -218,6 +218,16 @@ def test_stop_tracking_users(self): assert not self.device_list.tracked_user_ids assert not self.device_list.outdated_user_ids + def test_pending_users(self): + # Say Alice is already tracked to avoid triggering dowload process + self.device_list.tracked_user_ids.add(self.alice) + + self.device_list.track_user_no_download(self.alice) + assert self.alice in self.device_list.pending_outdated_user_ids + + self.device_list.track_pending_users() + assert self.alice not in self.device_list.pending_outdated_user_ids + @responses.activate def test_update_user_device_keys(self): self.device_list.tracked_user_ids.clear() From 9aadf1b6f755cba1b5fd3fdef771b8952e9fbed1 Mon Sep 17 00:00:00 2001 From: Valentin Deniaud Date: Fri, 29 Jun 2018 15:24:16 +0200 Subject: [PATCH 04/66] build doc for crypto.device_list --- docs/source/matrix_client.rst | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/docs/source/matrix_client.rst b/docs/source/matrix_client.rst index af1ebc41..1587a867 100644 --- a/docs/source/matrix_client.rst +++ b/docs/source/matrix_client.rst @@ -56,3 +56,8 @@ matrix_client.crypto :members: :undoc-members: :show-inheritance: + +.. automodule:: matrix_client.crypto.device_list + :members: + :undoc-members: + :show-inheritance: From 165d4f44173b6fe117f16be6fbe8c16ec3009b18 Mon Sep 17 00:00:00 2001 From: Valentin Deniaud Date: Wed, 23 May 2018 16:43:58 +0200 Subject: [PATCH 05/66] add olm encryption, decryption and session --- matrix_client/crypto/olm_device.py | 231 +++++++++++++++++++++++++++++ 1 file changed, 231 insertions(+) diff --git a/matrix_client/crypto/olm_device.py b/matrix_client/crypto/olm_device.py index d2e114b1..c8701f8c 100644 --- a/matrix_client/crypto/olm_device.py +++ b/matrix_client/crypto/olm_device.py @@ -1,3 +1,4 @@ +import json import logging from collections import defaultdict @@ -63,6 +64,7 @@ def __init__(self, keys_threshold) self.device_keys = defaultdict(dict) self.device_list = DeviceList(self, api, self.device_keys) + self.olm_sessions = defaultdict(list) def upload_identity_keys(self): """Uploads this device's identity keys to HS. @@ -140,6 +142,235 @@ def update_one_time_key_counts(self, counts): logger.info('Uploading new one-time keys.') self.upload_one_time_keys() + def olm_start_sessions(self, user_devices): + """Start olm sessions with the given devices. + + NOTE: those device keys must already be known. + + Args: + user_devices (dict): A map from user_id to an iterable of device_ids. + The format is ``: []``. + """ + logger.info('Trying to establish Olm sessions with devices: %s.', + dict(user_devices)) + payload = defaultdict(dict) + for user_id in user_devices: + for device_id in user_devices[user_id]: + payload[user_id][device_id] = 'signed_curve25519' + + resp = self.api.claim_keys(payload) + if resp.get('failures'): + logger.warning('Failed to claim one-time keys from the following unreachable ' + 'homeservers: %s.', resp['failures']) + keys = resp['one_time_keys'] + if logger.level >= logging.WARNING: + missing = {} + for user_id, device_ids in user_devices.items(): + if user_id not in keys: + missing[user_id] = device_ids + else: + missing_devices = set(device_ids) - set(keys[user_id]) + if missing_devices: + missing[user_id] = missing_devices + logger.warning('Failed to claim the keys of %s.', missing) + + for user_id in user_devices: + for device_id, one_time_key in keys.get(user_id, {}).items(): + try: + device_keys = self.device_keys[user_id][device_id] + except KeyError: + logger.warning('Key for device %s of user %s not found, could not ' + 'start Olm session.', device_id, user_id) + continue + key_object = next(iter(one_time_key.values())) + verified = self.verify_json(key_object, + device_keys['ed25519'], + user_id, + device_id) + if verified: + session = olm.OutboundSession(self.olm_account, + device_keys['curve25519'], + key_object['key']) + sessions = self.olm_sessions[device_keys['curve25519']] + sessions.append(session) + logger.info('Established Olm session %s with device %s of user ' + '%s.', device_id, session.id, user_id) + else: + logger.warning('Signature verification for one-time key of device %s ' + 'of user %s failed, could not start olm session.', + device_id, user_id) + + def olm_build_encrypted_event(self, event_type, content, user_id, device_id): + """Encrypt an event using Olm. + + NOTE: a session with this device must already be established. + + Args: + event_type (str): The event type, will be encrypted. + content (dict): The event content, will be encrypted. + user_id (str): The intended recipient of the event. + device_id (str): The device to encrypt to. + + Returns: + The Olm encrypted event, as JSON. + """ + try: + keys = self.device_keys[user_id][device_id] + except KeyError: + raise RuntimeError('Device is unknown, could not encrypt.') + + signing_key = keys['ed25519'] + identity_key = keys['curve25519'] + + payload = { + 'type': event_type, + 'content': content, + 'sender': self.user_id, + 'sender_device': self.device_id, + 'keys': { + 'ed25519': self.identity_keys['ed25519'] + }, + 'recipient': user_id, + 'recipient_keys': { + 'ed25519': signing_key + } + } + + sessions = self.olm_sessions[identity_key] + if sessions: + session = sorted(sessions, key=lambda s: s.id)[0] + else: + raise RuntimeError('No session for this device, could not encrypt.') + + encrypted_message = session.encrypt(json.dumps(payload)) + ciphertext_payload = { + identity_key: { + 'type': encrypted_message.message_type, + 'body': encrypted_message.ciphertext + } + } + + event = { + 'algorithm': self._olm_algorithm, + 'sender_key': self.identity_keys['curve25519'], + 'ciphertext': ciphertext_payload + } + return event + + def olm_decrypt_event(self, content, user_id): + """Decrypt an Olm encrypted event, and check its properties. + + Args: + event (dict): The content property of a m.room.encrypted event. + user_id (str): The sender of the event. + + Retuns: + The decrypted event held by the initial event. + + Raises: + RuntimeError: Error in the decryption process. Nothing can be done. The text + of the exception indicates what went wrong, and should be logged or + displayed to the user. + KeyError: The event is missing a required field. + """ + if content['algorithm'] != self._olm_algorithm: + raise RuntimeError('Event was not encrypted with {}.' + .format(self._olm_algorithm)) + + ciphertext = content['ciphertext'] + try: + payload = ciphertext[self.identity_keys['curve25519']] + except KeyError: + raise RuntimeError('This message was not encrypted for us.') + + msg_type = payload['type'] + if msg_type == 0: + encrypted_message = olm.OlmPreKeyMessage(payload['body']) + else: + encrypted_message = olm.OlmMessage(payload['body']) + + decrypted_event = self._olm_decrypt(encrypted_message, content['sender_key']) + + if decrypted_event['sender'] != user_id: + raise RuntimeError( + 'Found user {} instead of sender {} in Olm plaintext {}.' + .format(decrypted_event['sender'], user_id, decrypted_event) + ) + if decrypted_event['recipient'] != self.user_id: + raise RuntimeError( + 'Found user {} instead of us ({}) in Olm plaintext {}.' + .format(decrypted_event['recipient'], self.user_id, decrypted_event) + ) + our_key = decrypted_event['recipient_keys']['ed25519'] + if our_key != self.identity_keys['ed25519']: + raise RuntimeError( + 'Found key {} instead of ours own ed25519 key {} in Olm plaintext {}.' + .format(our_key, self.identity_keys['ed25519'], decrypted_event) + ) + + return decrypted_event + + def _olm_decrypt(self, olm_message, sender_key): + """Decrypt an Olm encrypted event. + + NOTE: This does no implement any security check. + + Try to decrypt using existing sessions. If it fails, start an new one when + possible. + + Args: + olm_message (OlmMessage): Olm encrypted payload. + sender_key (str): The sender's curve25519 identity key. + + Returns: + The decrypted event held by the initial payload, as JSON. + """ + + sessions = self.olm_sessions[sender_key] + + # Try to decrypt message body using one of the known sessions for that device + for session in sessions: + try: + event = session.decrypt(olm_message) + logger.info('Success decrypting Olm event using existing session %s.', + session.id) + break + except olm.session.OlmSessionError as e: + if olm_message.message_type == 0: + if session.matches(olm_message, sender_key): + # We had a matching session for a pre-key message, but it didn't + # work. This means something is wrong, so we fail now. + raise RuntimeError('Error decrypting pre-key message with ' + 'existing Olm session {}, reason: {}.' + .format(session.id, e)) + # Simply keep trying otherwise + else: + if olm_message.message_type > 0: + # Not a pre-key message, we should have had a matching session + if sessions: + raise RuntimeError('Error decrypting with existing sessions.') + raise RuntimeError('No existing sessions.') + + # We have a pre-key message without any matching session, in this case + # we should try to create one. + try: + session = olm.session.InboundSession( + self.olm_account, olm_message, sender_key) + except olm.session.OlmSessionError as e: + raise RuntimeError('Error decrypting pre-key message when trying to ' + 'establish a new session: {}.'.format(e)) + + logger.info('Created new Olm session %s.', session.id) + try: + event = session.decrypt(olm_message) + except olm.session.OlmSessionError as e: + raise RuntimeError('Error decrypting pre-key message with new session: ' + '{}.'.format(e)) + self.olm_account.remove_one_time_keys(session) + sessions.append(session) + + return json.loads(event) + def sign_json(self, json): """Signs a JSON object. From c689ec17d80d998f250fa757813333add66fa680 Mon Sep 17 00:00:00 2001 From: Valentin Deniaud Date: Mon, 25 Jun 2018 19:03:06 +0200 Subject: [PATCH 06/66] add olm tests --- test/crypto/olm_device_test.py | 195 ++++++++++++++++++++++++++++++++- test/response_examples.py | 21 ++++ 2 files changed, 214 insertions(+), 2 deletions(-) diff --git a/test/crypto/olm_device_test.py b/test/crypto/olm_device_test.py index cab7c56d..69f692c4 100644 --- a/test/crypto/olm_device_test.py +++ b/test/crypto/olm_device_test.py @@ -1,15 +1,18 @@ import pytest -pytest.importorskip("olm") # noqa +olm = pytest.importorskip("olm") # noqa import json +import logging from copy import deepcopy import responses +from matrix_client.crypto import olm_device from matrix_client.api import MATRIX_V2_API_PATH from matrix_client.client import MatrixClient from matrix_client.crypto.olm_device import OlmDevice -from test.response_examples import example_key_upload_response +from test.response_examples import (example_key_upload_response, + example_claim_keys_response) HOSTNAME = 'http://example.com' @@ -21,6 +24,13 @@ class TestOlmDevice: device_id = 'QBUAZIFURK' device = OlmDevice(cli.api, user_id, device_id) signing_key = device.olm_account.identity_keys['ed25519'] + alice = '@alice:example.com' + alice_device_id = 'JLAFKJWSCS' + alice_curve_key = 'mmFRSHuJVq3aTudx3KB3w5ZvSFQhgEcy8d+m+vkEfUQ' + alice_identity_keys = { + 'curve25519': alice_curve_key, + 'ed25519': '4VjV3OhFUxWFAcO5YOaQVmTIn29JdRmtNh9iAxoyhkc' + } def test_sign_json(self): example_payload = { @@ -205,3 +215,184 @@ def test_invalid_keys_threshold(self, threshold): self.user_id, self.device_id, keys_threshold=threshold) + + @responses.activate + def test_olm_start_sessions(self): + claim_url = HOSTNAME + MATRIX_V2_API_PATH + '/keys/claim' + responses.add(responses.POST, claim_url, json=example_claim_keys_response) + self.device.olm_sessions.clear() + self.device.device_keys.clear() + + user_devices = {self.alice: {self.alice_device_id}} + + # We don't have alice's keys + self.device.olm_start_sessions(user_devices) + assert not self.device.olm_sessions[self.alice_curve_key] + + # Cover logging part + olm_device.logger.setLevel(logging.WARNING) + # Now should be good + self.device.device_keys[self.alice][self.alice_device_id] = \ + self.alice_identity_keys + self.device.olm_start_sessions(user_devices) + assert self.device.olm_sessions[self.alice_curve_key] + + # With failures and wrong signature + self.device.olm_sessions.clear() + payload = deepcopy(example_claim_keys_response) + payload['failures'] = {'dummy': 1} + key = payload['one_time_keys'][self.alice][self.alice_device_id] + key['signed_curve25519:AAAAAQ']['test'] = 1 + responses.replace(responses.POST, claim_url, json=payload) + + self.device.olm_start_sessions(user_devices) + assert not self.device.olm_sessions[self.alice_curve_key] + + # Missing requested user and devices + user_devices[self.alice].add('test') + user_devices['test'] = 'test' + + self.device.olm_start_sessions(user_devices) + + @responses.activate + def test_olm_build_encrypted_event(self): + self.device.device_keys.clear() + self.device.olm_sessions.clear() + event_content = {'dummy': 'example'} + + # We don't have Alice's keys + with pytest.raises(RuntimeError): + self.device.olm_build_encrypted_event( + 'm.text', event_content, self.alice, self.alice_device_id) + + # We don't have a session with Alice + self.device.device_keys[self.alice][self.alice_device_id] = \ + self.alice_identity_keys + with pytest.raises(RuntimeError): + self.device.olm_build_encrypted_event( + 'm.text', event_content, self.alice, self.alice_device_id) + + claim_url = HOSTNAME + MATRIX_V2_API_PATH + '/keys/claim' + responses.add(responses.POST, claim_url, json=example_claim_keys_response) + user_devices = {self.alice: {self.alice_device_id}} + self.device.olm_start_sessions(user_devices) + assert self.device.olm_build_encrypted_event( + 'm.text', event_content, self.alice, self.alice_device_id) + + def test_olm_decrypt(self): + self.device.olm_sessions.clear() + # Since this method doesn't care about high-level event formatting, we will + # generate things at low level + our_account = self.device.olm_account + # Alice needs to start a session with us + alice = olm.Account() + sender_key = alice.identity_keys['curve25519'] + our_account.generate_one_time_keys(1) + otk = next(iter(our_account.one_time_keys['curve25519'].values())) + self.device.olm_account.mark_keys_as_published() + session = olm.OutboundSession(alice, our_account.identity_keys['curve25519'], otk) + + plaintext = {"test": "test"} + message = session.encrypt(json.dumps(plaintext)) + assert self.device._olm_decrypt(message, sender_key) == plaintext + + # New pre-key message, but the session exists this time + message = session.encrypt(json.dumps(plaintext)) + assert self.device._olm_decrypt(message, sender_key) == plaintext + + # Try to decrypt the same message twice + with pytest.raises(RuntimeError): + self.device._olm_decrypt(message, sender_key) + + # Answer Alice in order to have a type 1 message + message = self.device.olm_sessions[sender_key][0].encrypt(json.dumps(plaintext)) + session.decrypt(message) + message = session.encrypt(json.dumps(plaintext)) + assert self.device._olm_decrypt(message, sender_key) == plaintext + + # Try to decrypt the same message type 1 twice + with pytest.raises(RuntimeError): + self.device._olm_decrypt(message, sender_key) + + # Try to decrypt a message from a session that reused a one-time key + otk_reused_session = olm.OutboundSession( + alice, our_account.identity_keys['curve25519'], otk) + message = otk_reused_session.encrypt(json.dumps(plaintext)) + with pytest.raises(RuntimeError): + self.device._olm_decrypt(message, sender_key) + + # Try to decrypt an invalid type 0 message + our_account.generate_one_time_keys(1) + otk = next(iter(our_account.one_time_keys['curve25519'].values())) + wrong_session = olm.OutboundSession(alice, sender_key, otk) + message = wrong_session.encrypt(json.dumps(plaintext)) + with pytest.raises(RuntimeError): + self.device._olm_decrypt(message, sender_key) + + # Try to decrypt a type 1 message for which we have no sessions + message = session.encrypt(json.dumps(plaintext)) + self.device.olm_sessions.clear() + with pytest.raises(RuntimeError): + self.device._olm_decrypt(message, sender_key) + + def test_olm_decrypt_event(self): + self.device.device_keys.clear() + self.device.olm_sessions.clear() + alice_device = OlmDevice(self.device.api, self.alice, self.alice_device_id) + alice_device.device_keys[self.user_id][self.device_id] = self.device.identity_keys + self.device.device_keys[self.alice][self.alice_device_id] = \ + alice_device.identity_keys + + # Artificially start an Olm session from Alice + self.device.olm_account.generate_one_time_keys(1) + otk = next(iter(self.device.olm_account.one_time_keys['curve25519'].values())) + self.device.olm_account.mark_keys_as_published() + sender_key = self.device.identity_keys['curve25519'] + session = olm.OutboundSession(alice_device.olm_account, sender_key, otk) + alice_device.olm_sessions[sender_key] = [session] + + encrypted_event = alice_device.olm_build_encrypted_event( + 'example_type', {'content': 'test'}, self.user_id, self.device_id) + + # Now we can test + self.device.olm_decrypt_event(encrypted_event, self.alice) + + # Type 1 Olm payload + alice_device.olm_decrypt_event( + self.device.olm_build_encrypted_event( + 'example_type', {'content': 'test'}, self.alice, self.alice_device_id + ), + self.user_id) + encrypted_event = alice_device.olm_build_encrypted_event( + 'example_type', {'content': 'test'}, self.user_id, self.device_id) + self.device.olm_decrypt_event(encrypted_event, self.alice) + + encrypted_event = alice_device.olm_build_encrypted_event( + 'example_type', {'content': 'test'}, self.user_id, self.device_id) + with pytest.raises(RuntimeError): + self.device.olm_decrypt_event(encrypted_event, 'wrong') + + wrong_event = deepcopy(encrypted_event) + wrong_event['algorithm'] = 'wrong' + with pytest.raises(RuntimeError): + self.device.olm_decrypt_event(wrong_event, self.alice) + + wrong_event = deepcopy(encrypted_event) + wrong_event['ciphertext'] = {} + with pytest.raises(RuntimeError): + self.device.olm_decrypt_event(wrong_event, self.alice) + + encrypted_event = alice_device.olm_build_encrypted_event( + 'example_type', {'content': 'test'}, self.user_id, self.device_id) + self.device.user_id = 'wrong' + with pytest.raises(RuntimeError): + self.device.olm_decrypt_event(encrypted_event, self.alice) + self.device.user_id = self.user_id + + encrypted_event = alice_device.olm_build_encrypted_event( + 'example_type', {'content': 'test'}, self.user_id, self.device_id) + backup = self.device.identity_keys['ed25519'] + self.device.identity_keys['ed25519'] = 'wrong' + with pytest.raises(RuntimeError): + self.device.olm_decrypt_event(encrypted_event, self.alice) + self.device.identity_keys['ed25519'] = backup diff --git a/test/response_examples.py b/test/response_examples.py index d827a05b..d6a2ca49 100644 --- a/test/response_examples.py +++ b/test/response_examples.py @@ -215,3 +215,24 @@ } } } + +example_claim_keys_response = { + "failures": {}, + "one_time_keys": { + "@alice:example.com": { + "JLAFKJWSCS": { + 'signed_curve25519:AAAAAQ': { + 'key': '9UOzQjF2j2Xf8mBIiMgruuCkuWtD0ea9kvx63mO92Ws', + 'signatures': { + '@alice:example.com': { + 'ed25519:JLAFKJWSCS': ( + '6O+VYxN7mVcr/j66YdHASRrpW4ydC/0FcYmEWVAGIFzU4+yjzxxinhQD' + 'l7InhhdGuXeQlk4/w/CyU76TY6wdBA' + ) + } + } + } + } + } + } +} From eb43d6c6b1c1eef41464667435a9e032dcecc7b1 Mon Sep 17 00:00:00 2001 From: Valentin Deniaud Date: Wed, 30 May 2018 22:43:14 +0200 Subject: [PATCH 07/66] add olm_ensure_sessions --- matrix_client/crypto/olm_device.py | 19 +++++++++++++++++++ test/crypto/olm_device_test.py | 20 ++++++++++++++++++++ 2 files changed, 39 insertions(+) diff --git a/matrix_client/crypto/olm_device.py b/matrix_client/crypto/olm_device.py index c8701f8c..6f186030 100644 --- a/matrix_client/crypto/olm_device.py +++ b/matrix_client/crypto/olm_device.py @@ -371,6 +371,25 @@ def _olm_decrypt(self, olm_message, sender_key): return json.loads(event) + def olm_ensure_sessions(self, user_devices): + """Start Olm sessions with the given devices if one doesn't exist already. + + Args: + user_devices (dict): A map from user ids to a list of device ids. + """ + user_devices_no_session = defaultdict(list) + for user_id in user_devices: + for device_id in user_devices[user_id]: + curve_key = self.device_keys[user_id][device_id]['curve25519'] + # Check if we have a list of sessions for this device, which can be + # empty. Implicitely, an empty list will indicate that we already tried + # to establish a session with a device, but this attempt was + # unsuccessful. We do not retry to establish a session. + if curve_key not in self.olm_sessions: + user_devices_no_session[user_id].append(device_id) + if user_devices_no_session: + self.olm_start_sessions(user_devices_no_session) + def sign_json(self, json): """Signs a JSON object. diff --git a/test/crypto/olm_device_test.py b/test/crypto/olm_device_test.py index 69f692c4..c67cff93 100644 --- a/test/crypto/olm_device_test.py +++ b/test/crypto/olm_device_test.py @@ -396,3 +396,23 @@ def test_olm_decrypt_event(self): with pytest.raises(RuntimeError): self.device.olm_decrypt_event(encrypted_event, self.alice) self.device.identity_keys['ed25519'] = backup + + @responses.activate + def test_olm_ensure_sessions(self): + claim_url = HOSTNAME + MATRIX_V2_API_PATH + '/keys/claim' + responses.add(responses.POST, claim_url, json=example_claim_keys_response) + self.device.olm_sessions.clear() + alice_device_id = 'JLAFKJWSCS' + alice_curve_key = 'mmFRSHuJVq3aTudx3KB3w5ZvSFQhgEcy8d+m+vkEfUQ' + self.device.device_keys[self.alice][alice_device_id] = { + 'curve25519': alice_curve_key, + 'ed25519': '4VjV3OhFUxWFAcO5YOaQVmTIn29JdRmtNh9iAxoyhkc' + } + user_devices = {self.alice: [alice_device_id]} + + self.device.olm_ensure_sessions(user_devices) + assert self.device.olm_sessions[alice_curve_key] + assert len(responses.calls) == 1 + + self.device.olm_ensure_sessions(user_devices) + assert len(responses.calls) == 1 From d3896b4a54a5c5238d137fd7130030efe9e5c665 Mon Sep 17 00:00:00 2001 From: Valentin Deniaud Date: Mon, 28 May 2018 12:32:59 +0200 Subject: [PATCH 08/66] add megolm outbound session support --- .../crypto/megolm_outbound_session.py | 56 +++++++++ matrix_client/crypto/olm_device.py | 118 ++++++++++++++++++ 2 files changed, 174 insertions(+) create mode 100644 matrix_client/crypto/megolm_outbound_session.py diff --git a/matrix_client/crypto/megolm_outbound_session.py b/matrix_client/crypto/megolm_outbound_session.py new file mode 100644 index 00000000..5793c35f --- /dev/null +++ b/matrix_client/crypto/megolm_outbound_session.py @@ -0,0 +1,56 @@ +from datetime import datetime, timedelta + +from olm import OutboundGroupSession + + +class MegolmOutboundSession(OutboundGroupSession): + + """Outbound group session aware of the users it is shared with. + + Also remembers the time it was created and the number of messages it has encrypted, + in order to know if it needs to be rotated. + + Args: + max_age (datetime.timedelta): Optional. The maximum time the session should + exist. + max_messages (int): Optional. The maximum number of messages that should be sent. + A new message in considered sent each time there is a call to ``encrypt``. + """ + + def __init__(self, max_age=timedelta(days=7), max_messages=100): + self.devices = set() + self.max_age = max_age + self.max_messages = max_messages + self.creation_time = datetime.now() + self.message_count = 0 + super(MegolmOutboundSession, self).__init__() + + def __new__(cls, **kwargs): + return super(MegolmOutboundSession, cls).__new__(cls) + + def add_device(self, device_id): + """Adds a device the session is shared with.""" + self.devices.add(device_id) + + def add_devices(self, device_ids): + """Adds devices the session is shared with. + + Args: + device_ids (iterable): An iterable of device ids, preferably a set. + """ + self.devices.update(device_ids) + + def should_rotate(self): + """Wether the session should be rotated. + + Returns: + True if it should, False if not. + """ + if self.message_count >= self.max_messages or \ + datetime.now() - self.creation_time >= self.max_age: + return True + return False + + def encrypt(self, plaintext): + self.message_count += 1 + return super(MegolmOutboundSession, self).encrypt(plaintext) diff --git a/matrix_client/crypto/olm_device.py b/matrix_client/crypto/olm_device.py index 6f186030..0cb05b00 100644 --- a/matrix_client/crypto/olm_device.py +++ b/matrix_client/crypto/olm_device.py @@ -8,6 +8,7 @@ from matrix_client.checks import check_user_id from matrix_client.crypto.one_time_keys import OneTimeKeysManager from matrix_client.crypto.device_list import DeviceList +from matrix_client.crypto.megolm_outbound_session import MegolmOutboundSession logger = logging.getLogger(__name__) @@ -65,6 +66,7 @@ def __init__(self, self.device_keys = defaultdict(dict) self.device_list = DeviceList(self, api, self.device_keys) self.olm_sessions = defaultdict(list) + self.megolm_outbound_sessions = {} def upload_identity_keys(self): """Uploads this device's identity keys to HS. @@ -390,6 +392,122 @@ def olm_ensure_sessions(self, user_devices): if user_devices_no_session: self.olm_start_sessions(user_devices_no_session) + def megolm_start_session(self, room): + """Start a megolm session in a room, and share it with its members. + + Args: + room (Room): The room to use. + + Returns: + The newly created session. + """ + session = MegolmOutboundSession() + self.megolm_outbound_sessions[room.room_id] = session + logger.info('Starting a new Meglom outbound session %s in %s.', + session.id, room.room_id) + + users = room.get_joined_members() + user_devices = {user.user_id: list(self.device_keys[user.user_id]) + for user in users} + self.device_list.get_room_device_keys(room) + self.megolm_share_session(room.room_id, user_devices, session) + return session + + def megolm_share_session(self, room_id, user_devices, session): + """Share an already existing outbound megolm session with the specified devices. + + Args: + room_id (str): The room corresponding to the session. + user_devices (dict): A map from user ids to a list of device ids. + session (MegolmOutboundSession): The session object. + """ + + logger.info('Attempting to share Megolm session %s in %s with %s.', + session.id, room_id, user_devices) + self.olm_ensure_sessions(user_devices) + + event = { + 'algorithm': self._megolm_algorithm, + 'room_id': room_id, + 'session_id': session.id, + 'session_key': session.session_key + } + + messages = defaultdict(dict) + new_devices = set() + for user_id in user_devices: + for device_id in user_devices[user_id]: + try: + messages[user_id][device_id] = self.olm_build_encrypted_event( + 'm.room_key', event, user_id, device_id + ) + except RuntimeError as e: + logger.warning('Could not share megolm session %s with device %s of ' + 'user %s: %s', session.id, + device_id, user_id, e) + # We will not retry to share session with failed devices + new_devices.add(device_id) + self.api.send_to_device('m.room.encrypted', messages) + session.add_devices(new_devices) + + def megolm_share_session_with_new_devices(self, room, session): + """Share a megolm session with new devices in a room. + + Args: + room (Room): The room corresponding to the session. + session (MegolmOutboundSession): The session to share. + """ + user_devices = {} + users = room.get_joined_members() + for user in users: + user_id = user.user_id + missing_devices = list(set(self.device_keys[user_id].keys()) - + self.megolm_outbound_sessions[room.room_id].devices) + if missing_devices: + user_devices[user_id] = missing_devices + if user_devices: + logger.info('Sharing existing Megolm outbound session %s with new devices: ' + '%s', session.id, user_devices) + self.megolm_share_session(room.room_id, user_devices, session) + + def megolm_build_encrypted_event(self, room, event): + """Build an encrypted Megolm payload from a plaintext event. + + If no session exists in the room, a new one will be initiated. Also takes care + of rotating the session periodically. + + Args: + room (Room): The room the event will be sent in. + event (dict): Matrix event. + + Returns: + The encrypted event, as a dict. + """ + room_id = room.room_id + + session = self.megolm_outbound_sessions.get(room_id) + if not session or session.should_rotate(): + session = self.megolm_start_session(room) + else: + self.megolm_share_session_with_new_devices(room, session) + + payload = { + 'type': event['type'], + 'content': event['content'], + 'room_id': room_id + } + + encrypted_payload = session.encrypt(json.dumps(payload)) + + encrypted_event = { + 'algorithm': self._megolm_algorithm, + 'sender_key': self.identity_keys['curve25519'], + 'ciphertext': encrypted_payload, + 'session_id': session.id, + 'device_id': self.device_id + } + return encrypted_event + def sign_json(self, json): """Signs a JSON object. From f18f9b26808a8233f0d726b13753cca977e868da Mon Sep 17 00:00:00 2001 From: Valentin Deniaud Date: Mon, 18 Jun 2018 17:27:37 +0200 Subject: [PATCH 09/66] delete outbound sessions on leave --- matrix_client/crypto/olm_device.py | 14 ++++++++++++++ matrix_client/room.py | 6 ++++++ 2 files changed, 20 insertions(+) diff --git a/matrix_client/crypto/olm_device.py b/matrix_client/crypto/olm_device.py index 0cb05b00..4e09ea08 100644 --- a/matrix_client/crypto/olm_device.py +++ b/matrix_client/crypto/olm_device.py @@ -508,6 +508,20 @@ def megolm_build_encrypted_event(self, room, event): } return encrypted_event + def megolm_remove_outbound_session(self, room_id): + """Remove an existing Megolm outbound session in a room. + + If there is no such session, nothing will happen. + + Args: + room_id (str): The room to use. + """ + try: + self.megolm_outbound_sessions.pop(room_id) + logger.info('Removed Meglom outbound session in %s.', room_id) + except KeyError: + pass + def sign_json(self, json): """Signs a JSON object. diff --git a/matrix_client/room.py b/matrix_client/room.py index deec52d6..15c427e5 100644 --- a/matrix_client/room.py +++ b/matrix_client/room.py @@ -668,6 +668,12 @@ def _process_state_event(self, state_event): self.client.olm_device.device_list.track_user_no_download(user_id) elif econtent["membership"] in ("leave", "kick", "invite"): self._members.pop(user_id, None) + if econtent["membership"] != "invite": + if self.client._encryption and self.encrypted: + # Invalidate any outbound session we have in the room when + # someone leaves + self.client.olm_device.megolm_remove_outbound_session( + self.room_id) for listener in self.state_listeners: if ( From de73b9d69b70d6a413c20619297f20955efd7304 Mon Sep 17 00:00:00 2001 From: Valentin Deniaud Date: Tue, 5 Jun 2018 16:49:03 +0200 Subject: [PATCH 10/66] send encrypted group messages --- matrix_client/crypto/olm_device.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/matrix_client/crypto/olm_device.py b/matrix_client/crypto/olm_device.py index 4e09ea08..ace09c7d 100644 --- a/matrix_client/crypto/olm_device.py +++ b/matrix_client/crypto/olm_device.py @@ -522,6 +522,21 @@ def megolm_remove_outbound_session(self, room_id): except KeyError: pass + def send_encrypted_message(self, room, content): + """Send a m.room.encrypted event in a room. + + Args: + room (Room): The room to use. + content (dict): The content of the event, will be encrypted. + + Raises: + MatrixRequestError if there was an error sending the event. + """ + event = {'content': content, 'room_id': room.room_id, 'type': 'm.room.message'} + encrypted_event = self.megolm_build_encrypted_event(room, event) + return self.api.send_message_event( + room.room_id, 'm.room.encrypted', encrypted_event) + def sign_json(self, json): """Signs a JSON object. From c07ee438f168791d194f1098c49613d6cd4fdb7e Mon Sep 17 00:00:00 2001 From: Valentin Deniaud Date: Tue, 5 Jun 2018 16:49:30 +0200 Subject: [PATCH 11/66] automatically send encrypted messages in encrypted rooms --- matrix_client/api.py | 40 ++++++++++++++++++++------------- matrix_client/errors.py | 7 ++++++ matrix_client/room.py | 50 ++++++++++++++++++++++++++++++++++------- 3 files changed, 74 insertions(+), 23 deletions(-) diff --git a/matrix_client/api.py b/matrix_client/api.py index 8bd44ecc..102d3e5c 100644 --- a/matrix_client/api.py +++ b/matrix_client/api.py @@ -343,6 +343,18 @@ def send_content(self, room_id, item_url, item_name, msg_type, return self.send_message_event(room_id, "m.room.message", content_pack, timestamp=timestamp) + def get_location_body(self, geo_uri, name, thumb_url=None, thumb_info=None): + content_pack = { + "geo_uri": geo_uri, + "msgtype": "m.location", + "body": name, + } + if thumb_url: + content_pack["thumbnail_url"] = thumb_url + if thumb_info: + content_pack["thumbnail_info"] = thumb_info + return content_pack + # http://matrix.org/docs/spec/client_server/r0.2.0.html#m-location def send_location(self, room_id, geo_uri, name, thumb_url=None, thumb_info=None, timestamp=None): @@ -356,15 +368,8 @@ def send_location(self, room_id, geo_uri, name, thumb_url=None, thumb_info=None, thumb_info (dict): Metadata about the thumbnail, type ImageInfo. timestamp (int): Set origin_server_ts (For application services only) """ - content_pack = { - "geo_uri": geo_uri, - "msgtype": "m.location", - "body": name, - } - if thumb_url: - content_pack["thumbnail_url"] = thumb_url - if thumb_info: - content_pack["thumbnail_info"] = thumb_info + content_pack = self.get_location_body( + geo_uri, name, thumb_url, thumb_info) return self.send_message_event(room_id, "m.room.message", content_pack, timestamp=timestamp) @@ -405,12 +410,11 @@ def send_notice(self, room_id, text_content, timestamp=None): text_content (str): The m.notice body to send. timestamp (int): Set origin_server_ts (For application services only) """ - body = { - "msgtype": "m.notice", - "body": text_content - } - return self.send_message_event(room_id, "m.room.message", body, - timestamp=timestamp) + return self.send_message_event( + room_id, "m.room.message", + self.get_notice_body(text_content), + timestamp=timestamp + ) def get_room_messages(self, room_id, token, direction, limit=10, to=None): """Perform GET /rooms/{roomId}/messages. @@ -675,6 +679,12 @@ def get_emote_body(self, text): "body": text } + def get_notice_body(self, text): + return { + "msgtype": "m.notice", + "body": text + } + def get_filter(self, user_id, filter_id): return self._send("GET", "/user/{userId}/filter/{filterId}" .format(userId=user_id, filterId=filter_id)) diff --git a/matrix_client/errors.py b/matrix_client/errors.py index e9dc8fe3..98bd5eb0 100644 --- a/matrix_client/errors.py +++ b/matrix_client/errors.py @@ -46,3 +46,10 @@ def __init__(self, original_exception, method, endpoint): original_exception) ) self.original_exception = original_exception + + +class MatrixNoEncryptionError(MatrixError): + """Encryption was not available.""" + + def __init__(self, content=""): + super(MatrixNoEncryptionError, self).__init__(content) diff --git a/matrix_client/room.py b/matrix_client/room.py index 15c427e5..8a2f0d94 100644 --- a/matrix_client/room.py +++ b/matrix_client/room.py @@ -18,7 +18,7 @@ from .checks import check_room_id from .user import User -from .errors import MatrixRequestError +from .errors import MatrixRequestError, MatrixNoEncryptionError class Room(object): @@ -102,7 +102,10 @@ def display_name(self): def send_text(self, text): """Send a plain text message to the room.""" - return self.client.api.send_message(self.room_id, text) + if self.encrypted and self.client._encryption: + return self.send_encrypted(self.client.api.get_text_body(text)) + else: + return self.client.api.send_message(self.room_id, text) def get_html_content(self, html, body=None, msgtype="m.text"): return { @@ -119,8 +122,12 @@ def send_html(self, html, body=None, msgtype="m.text"): html (str): The html formatted message to be sent. body (str): The unformatted body of the message to be sent. """ - return self.client.api.send_message_event( - self.room_id, "m.room.message", self.get_html_content(html, body, msgtype)) + content = self.get_html_content(html, body, msgtype) + if self.encrypted and self.client._encryption: + return self.send_encrypted(content) + else: + return self.client.api.send_message_event( + self.room_id, "m.room.message", content) def set_account_data(self, type, account_data): return self.client.api.set_room_account_data( @@ -142,7 +149,10 @@ def add_tag(self, tag, order=None, content=None): def send_emote(self, text): """Send an emote (/me style) message to the room.""" - return self.client.api.send_emote(self.room_id, text) + if self.encrypted and self.client._encryption: + return self.send_encrypted(self.client.api.get_emote_body(text)) + else: + return self.client.api.send_emote(self.room_id, text) def send_file(self, url, name, **fileinfo): """Send a pre-uploaded file to the room. @@ -163,7 +173,10 @@ def send_file(self, url, name, **fileinfo): def send_notice(self, text): """Send a notice (from bot) message to the room.""" - return self.client.api.send_notice(self.room_id, text) + if self.encrypted and self.client._encryption: + return self.send_encrypted(self.client.api.get_notice_body(text)) + else: + return self.client.api.send_notice(self.room_id, text) # See http://matrix.org/docs/spec/r0.0.1/client_server.html#m-image for the # imageinfo args. @@ -195,8 +208,13 @@ def send_location(self, geo_uri, name, thumb_url=None, **thumb_info): thumb_url (str): URL to the thumbnail of the location. thumb_info (): Metadata about the thumbnail, type ImageInfo. """ - return self.client.api.send_location(self.room_id, geo_uri, name, - thumb_url, thumb_info) + if self.encrypted and self.client._encryption: + content = self.client.api.get_location_body( + geo_uri, name, thumb_url, thumb_info) + return self.send_encrypted(content) + else: + return self.client.api.send_location(self.room_id, geo_uri, name, + thumb_url, thumb_info) def send_video(self, url, name, **videoinfo): """Send a pre-uploaded video to the room. @@ -226,6 +244,22 @@ def send_audio(self, url, name, **audioinfo): return self.client.api.send_content(self.room_id, url, name, "m.audio", extra_information=audioinfo) + def send_encrypted(self, content): + """Send an arbitrary encrypted message to the room. + + Args: + content (dict): The content of a m.room.message event. + + Raises: + ``MatrixNoEncryptionError`` if encryption is not enabled in client, or if + the room is unencrypted. + """ + if not self.client._encryption: + raise MatrixNoEncryptionError('Encryption is not enabled in client.') + if not self.encrypted: + raise MatrixNoEncryptionError('Encryption is not enabled in the room.') + return self.client.olm_device.send_encrypted_message(self, content) + def redact_message(self, event_id, reason=None): """Redacts the message with specified event_id for the given reason. From 24b5a407eae0c9fe67d2637c5d9a5b9e083c12b9 Mon Sep 17 00:00:00 2001 From: Valentin Deniaud Date: Wed, 27 Jun 2018 16:04:14 +0200 Subject: [PATCH 12/66] add megolm outbound tests --- test/crypto/olm_device_test.py | 133 +++++++++++++++++++++++++++++++-- 1 file changed, 125 insertions(+), 8 deletions(-) diff --git a/test/crypto/olm_device_test.py b/test/crypto/olm_device_test.py index c67cff93..7e4967aa 100644 --- a/test/crypto/olm_device_test.py +++ b/test/crypto/olm_device_test.py @@ -4,13 +4,19 @@ import json import logging from copy import deepcopy +try: + from urllib import quote +except ImportError: + from urllib.parse import quote import responses from matrix_client.crypto import olm_device from matrix_client.api import MATRIX_V2_API_PATH from matrix_client.client import MatrixClient +from matrix_client.user import User from matrix_client.crypto.olm_device import OlmDevice +from matrix_client.crypto.megolm_outbound_session import MegolmOutboundSession from test.response_examples import (example_key_upload_response, example_claim_keys_response) @@ -31,6 +37,12 @@ class TestOlmDevice: 'curve25519': alice_curve_key, 'ed25519': '4VjV3OhFUxWFAcO5YOaQVmTIn29JdRmtNh9iAxoyhkc' } + alice_olm_session = olm.OutboundSession( + device.olm_account, alice_curve_key, alice_curve_key) + room = cli._mkroom(room_id) + room._members[alice] = User(cli.api, alice) + # allow to_device api call to work well with responses + device.api._make_txn_id = lambda: 1 def test_sign_json(self): example_payload = { @@ -402,17 +414,122 @@ def test_olm_ensure_sessions(self): claim_url = HOSTNAME + MATRIX_V2_API_PATH + '/keys/claim' responses.add(responses.POST, claim_url, json=example_claim_keys_response) self.device.olm_sessions.clear() - alice_device_id = 'JLAFKJWSCS' - alice_curve_key = 'mmFRSHuJVq3aTudx3KB3w5ZvSFQhgEcy8d+m+vkEfUQ' - self.device.device_keys[self.alice][alice_device_id] = { - 'curve25519': alice_curve_key, - 'ed25519': '4VjV3OhFUxWFAcO5YOaQVmTIn29JdRmtNh9iAxoyhkc' - } - user_devices = {self.alice: [alice_device_id]} + self.device.device_keys[self.alice][self.alice_device_id] = \ + self.alice_identity_keys + user_devices = {self.alice: [self.alice_device_id]} self.device.olm_ensure_sessions(user_devices) - assert self.device.olm_sessions[alice_curve_key] + assert self.device.olm_sessions[self.alice_curve_key] assert len(responses.calls) == 1 self.device.olm_ensure_sessions(user_devices) assert len(responses.calls) == 1 + + @responses.activate + def test_megolm_share_session(self): + claim_url = HOSTNAME + MATRIX_V2_API_PATH + '/keys/claim' + responses.add(responses.POST, claim_url, json=example_claim_keys_response) + to_device_url = HOSTNAME + MATRIX_V2_API_PATH + '/sendToDevice/m.room.encrypted/1' + responses.add(responses.PUT, to_device_url, json={}) + self.device.olm_sessions.clear() + self.device.device_keys[self.alice][self.alice_device_id] = \ + self.alice_identity_keys + self.device.device_keys['dummy']['dummy'] = {'curve25519': 'a', 'ed25519': 'a'} + user_devices = {self.alice: [self.alice_device_id], 'dummy': ['dummy']} + session = MegolmOutboundSession() + + # Sharing with Alice should succeed, but dummy will fail + self.device.megolm_share_session(self.room_id, user_devices, session) + assert session.devices == {self.alice_device_id, 'dummy'} + + req = json.loads(responses.calls[1].request.body)['messages'] + assert self.alice in req + assert 'dummy' not in req + + @responses.activate + def test_megolm_start_session(self): + to_device_url = HOSTNAME + MATRIX_V2_API_PATH + '/sendToDevice/m.room.encrypted/1' + responses.add(responses.PUT, to_device_url, json={}) + self.device.device_keys[self.alice][self.alice_device_id] = \ + self.alice_identity_keys + self.device.device_list.tracked_user_ids.add(self.alice) + self.device.olm_sessions[self.alice_curve_key] = [self.alice_olm_session] + + self.device.megolm_start_session(self.room) + session = self.device.megolm_outbound_sessions[self.room_id] + assert self.alice_device_id in session.devices + + @responses.activate + def test_megolm_share_session_with_new_devices(self): + to_device_url = HOSTNAME + MATRIX_V2_API_PATH + '/sendToDevice/m.room.encrypted/1' + responses.add(responses.PUT, to_device_url, json={}) + self.device.device_keys[self.alice][self.alice_device_id] = \ + self.alice_identity_keys + self.device.olm_sessions[self.alice_curve_key] = [self.alice_olm_session] + session = MegolmOutboundSession() + self.device.megolm_outbound_sessions[self.room_id] = session + + self.device.megolm_share_session_with_new_devices(self.room, session) + assert self.alice_device_id in session.devices + assert len(responses.calls) == 1 + + self.device.megolm_share_session_with_new_devices(self.room, session) + assert len(responses.calls) == 1 + + @responses.activate + def test_megolm_build_encrypted_event(self): + to_device_url = HOSTNAME + MATRIX_V2_API_PATH + '/sendToDevice/m.room.encrypted/1' + responses.add(responses.PUT, to_device_url, json={}) + self.device.megolm_outbound_sessions.clear() + self.device.device_keys[self.alice][self.alice_device_id] = \ + self.alice_identity_keys + self.device.device_list.tracked_user_ids.add(self.alice) + self.device.olm_sessions[self.alice_curve_key] = [self.alice_olm_session] + event = {'type': 'm.room.message', 'content': {'body': 'test'}} + + self.room.rotation_period_msgs = 1 + self.device.megolm_build_encrypted_event(self.room, event) + + self.device.megolm_build_encrypted_event(self.room, event) + + session = self.device.megolm_outbound_sessions[self.room_id] + session.encrypt('test') + self.device.megolm_build_encrypted_event(self.room, event) + assert self.device.megolm_outbound_sessions[self.room_id].id != session.id + + def test_megolm_remove_outbound_session(self): + session = MegolmOutboundSession() + self.device.megolm_outbound_sessions[self.room_id] = session + self.device.megolm_remove_outbound_session(self.room_id) + self.device.megolm_remove_outbound_session(self.room_id) + + @responses.activate + def test_send_encrypted_message(self): + message_url = HOSTNAME + MATRIX_V2_API_PATH + \ + '/rooms/{}/send/m.room.encrypted/1'.format(quote(self.room.room_id)) + responses.add(responses.PUT, message_url, json={}) + self.device.device_keys[self.alice][self.alice_device_id] = \ + self.alice_identity_keys + self.device.olm_sessions[self.alice_curve_key] = [self.alice_olm_session] + session = MegolmOutboundSession() + session.add_device(self.alice_device_id) + self.device.megolm_outbound_sessions[self.room_id] = session + + self.device.send_encrypted_message(self.room, {'test': 'test'}) + + +def test_megolm_outbound_session(): + session = MegolmOutboundSession(max_messages=1) + + assert not session.devices + + session.add_device('test') + assert 'test' in session.devices + + session.add_devices({'test2', 'test3'}) + assert 'test2' in session.devices and 'test3' in session.devices + + assert not session.should_rotate() + + session.encrypt('message') + assert session.should_rotate() From 4ce7cdf67f3803de78566d9ce73658e9831da071 Mon Sep 17 00:00:00 2001 From: Valentin Deniaud Date: Fri, 29 Jun 2018 10:59:24 +0200 Subject: [PATCH 13/66] handle m.room.encryption properties --- .../crypto/megolm_outbound_session.py | 12 ++++++++---- matrix_client/crypto/olm_device.py | 3 ++- matrix_client/room.py | 6 ++++++ test/client_test.py | 19 ++++++++++++++++++- test/crypto/olm_device_test.py | 14 +++++++++++++- 5 files changed, 47 insertions(+), 7 deletions(-) diff --git a/matrix_client/crypto/megolm_outbound_session.py b/matrix_client/crypto/megolm_outbound_session.py index 5793c35f..cff87f6a 100644 --- a/matrix_client/crypto/megolm_outbound_session.py +++ b/matrix_client/crypto/megolm_outbound_session.py @@ -12,15 +12,19 @@ class MegolmOutboundSession(OutboundGroupSession): Args: max_age (datetime.timedelta): Optional. The maximum time the session should - exist. + exist. Default to one week if not present. max_messages (int): Optional. The maximum number of messages that should be sent. A new message in considered sent each time there is a call to ``encrypt``. + Default to 100 if not present. """ - def __init__(self, max_age=timedelta(days=7), max_messages=100): + def __init__(self, max_age=None, max_messages=None): self.devices = set() - self.max_age = max_age - self.max_messages = max_messages + if max_age: + self.max_age = timedelta(milliseconds=max_age) + else: + self.max_age = timedelta(days=7) + self.max_messages = max_messages or 100 self.creation_time = datetime.now() self.message_count = 0 super(MegolmOutboundSession, self).__init__() diff --git a/matrix_client/crypto/olm_device.py b/matrix_client/crypto/olm_device.py index ace09c7d..8b2a344d 100644 --- a/matrix_client/crypto/olm_device.py +++ b/matrix_client/crypto/olm_device.py @@ -401,7 +401,8 @@ def megolm_start_session(self, room): Returns: The newly created session. """ - session = MegolmOutboundSession() + session = MegolmOutboundSession(max_age=room.rotation_period_ms, + max_messages=room.rotation_period_msgs) self.megolm_outbound_sessions[room.room_id] = session logger.info('Starting a new Meglom outbound session %s in %s.', session.id, room.room_id) diff --git a/matrix_client/room.py b/matrix_client/room.py index 8a2f0d94..38e9a9fe 100644 --- a/matrix_client/room.py +++ b/matrix_client/room.py @@ -50,6 +50,8 @@ def __init__(self, client, room_id): # user_id: displayname } self.encrypted = False + self.rotation_period_msgs = None + self.rotation_period_ms = None def set_user_profile(self, displayname=None, @@ -692,6 +694,10 @@ def _process_state_event(self, state_event): elif etype == "m.room.encryption": if econtent.get("algorithm") == "m.megolm.v1.aes-sha2": self.encrypted = True + if not self.rotation_period_ms: + self.rotation_period_ms = econtent.get("rotation_period_ms") + if not self.rotation_period_msgs: + self.rotation_period_msgs = econtent.get("rotation_period_msgs") elif etype == "m.room.member" and clevel == clevel.ALL: # tracking room members can be large e.g. #matrix:matrix.org user_id = state_event["state_key"] diff --git a/test/client_test.py b/test/client_test.py index c5884924..5532e1ae 100644 --- a/test/client_test.py +++ b/test/client_test.py @@ -155,13 +155,30 @@ def test_state_event(): # test encryption room.encrypted = False ev["type"] = "m.room.encryption" - ev["content"] = {"algorithm": "m.megolm.v1.aes-sha2"} + ev["content"] = { + "algorithm": "m.megolm.v1.aes-sha2", + "rotation_period_msgs": 50, + "rotation_period_ms": 100000, + } room._process_state_event(ev) assert room.encrypted + assert room.rotation_period_ms == 100000 + assert room.rotation_period_msgs == 50 # encrypted flag must not be cleared on configuration change ev["content"] = {"algorithm": None} room._process_state_event(ev) assert room.encrypted + assert room.rotation_period_ms == 100000 + assert room.rotation_period_msgs == 50 + # nor should the session parameters be changed + ev["content"] = { + "algorithm": "m.megolm.v1.aes-sha2", + "rotation_period_msgs": 5, + "rotation_period_ms": 10000, + } + room._process_state_event(ev) + assert room.rotation_period_ms == 100000 + assert room.rotation_period_msgs == 50 def test_get_user(): diff --git a/test/crypto/olm_device_test.py b/test/crypto/olm_device_test.py index 7e4967aa..c7aaa51b 100644 --- a/test/crypto/olm_device_test.py +++ b/test/crypto/olm_device_test.py @@ -4,6 +4,7 @@ import json import logging from copy import deepcopy +from datetime import timedelta, datetime try: from urllib import quote except ImportError: @@ -519,7 +520,13 @@ def test_send_encrypted_message(self): def test_megolm_outbound_session(): - session = MegolmOutboundSession(max_messages=1) + session = MegolmOutboundSession() + assert session.max_messages == 100 + assert session.max_age == timedelta(days=7) + + session = MegolmOutboundSession(max_messages=1, max_age=100000) + assert session.max_messages == 1 + assert session.max_age == timedelta(milliseconds=100000) assert not session.devices @@ -533,3 +540,8 @@ def test_megolm_outbound_session(): session.encrypt('message') assert session.should_rotate() + + session.max_messages = 2 + assert not session.should_rotate() + session.creation_time = datetime.now() - timedelta(milliseconds=100000) + assert session.should_rotate() From 71e8c524f2da32160f294893eff5730a35b53204 Mon Sep 17 00:00:00 2001 From: Valentin Deniaud Date: Fri, 29 Jun 2018 15:29:54 +0200 Subject: [PATCH 14/66] build doc for crypto.megolm_outbound_session --- docs/source/matrix_client.rst | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/docs/source/matrix_client.rst b/docs/source/matrix_client.rst index 1587a867..353eb5a5 100644 --- a/docs/source/matrix_client.rst +++ b/docs/source/matrix_client.rst @@ -61,3 +61,8 @@ matrix_client.crypto :members: :undoc-members: :show-inheritance: + +.. automodule:: matrix_client.crypto.megolm_outbound_session + :members: + :undoc-members: + :show-inheritance: From 2515d578f96847fea8785e0341b59dfe8b9f4666 Mon Sep 17 00:00:00 2001 From: Valentin Deniaud Date: Fri, 20 Jul 2018 12:58:35 +0200 Subject: [PATCH 15/66] pass rotation parameters when enabling encryption --- matrix_client/room.py | 25 +++++++++++++++++++++---- test/client_test.py | 9 +++++++++ 2 files changed, 30 insertions(+), 4 deletions(-) diff --git a/matrix_client/room.py b/matrix_client/room.py index 38e9a9fe..120f65d2 100644 --- a/matrix_client/room.py +++ b/matrix_client/room.py @@ -654,18 +654,35 @@ def set_guest_access(self, allow_guests): except MatrixRequestError: return False - def enable_encryption(self): + def enable_encryption(self, rotation_period_ms=None, rotation_period_msgs=None): """Enables encryption in the room. NOTE: Once enabled, encryption cannot be disabled. + Args: + rotation_period_ms (int): Optional. Lifetime of Megolm sessions in the room. + rotation_period_msgs (int): Optional. Number of messages that can be encrypted + by a Megolm session before it is rotated. + Returns: - True if successful, False if not + ``True`` if successful, ``False`` if not. + If the room was already encrypted, True is returned, but the + ``rotation_period_ms`` and ``rotation_period_ms`` parameters have no effect. """ + if self.encrypted: + return True + + content = {"algorithm": "m.megolm.v1.aes-sha2"} + if rotation_period_ms: + content["rotation_period_ms"] = rotation_period_ms + if rotation_period_msgs: + content["rotation_period_msgs"] = rotation_period_msgs + try: - self.send_state_event("m.room.encryption", - {"algorithm": "m.megolm.v1.aes-sha2"}) + self.send_state_event("m.room.encryption", content) self.encrypted = True + self.rotation_period_ms = rotation_period_ms + self.rotation_period_msgs = rotation_period_msgs return True except MatrixRequestError: return False diff --git a/test/client_test.py b/test/client_test.py index 5532e1ae..5fa84762 100644 --- a/test/client_test.py +++ b/test/client_test.py @@ -536,6 +536,15 @@ def test_enable_encryption_in_room(): assert room.enable_encryption() assert room.encrypted + room = client._mkroom(room_id) + assert room.enable_encryption(rotation_period_msgs=1, rotation_period_ms=1) + assert room.rotation_period_msgs == 1 + assert room.rotation_period_ms == 1 + + assert room.enable_encryption(rotation_period_msgs=2) + # The room was already encrypted, we should not have changed its attribute + assert room.rotation_period_msgs == 1 + @responses.activate def test_detect_encryption_state(): From e649a23f0e4783e9fd2cee4ae28feed1b3fed387 Mon Sep 17 00:00:00 2001 From: Valentin Deniaud Date: Fri, 8 Jun 2018 23:36:07 +0200 Subject: [PATCH 16/66] add megolm inbound session support --- matrix_client/client.py | 5 ++ matrix_client/crypto/olm_device.py | 136 +++++++++++++++++++++++++++++ 2 files changed, 141 insertions(+) diff --git a/matrix_client/client.py b/matrix_client/client.py index 0c2a8af6..6cc829b6 100644 --- a/matrix_client/client.py +++ b/matrix_client/client.py @@ -606,6 +606,11 @@ def _sync(self, timeout_ms=30000): if room_id in self.rooms: del self.rooms[room_id] + if 'to_device' in response: + for event in response['to_device']['events']: + if event['type'] == 'm.room.encrypted' and self._encryption: + self.olm_device.olm_handle_encrypted_event(event) + if self._encryption and 'device_one_time_keys_count' in response: self.olm_device.update_one_time_key_counts( response['device_one_time_keys_count']) diff --git a/matrix_client/crypto/olm_device.py b/matrix_client/crypto/olm_device.py index 8b2a344d..1b89504b 100644 --- a/matrix_client/crypto/olm_device.py +++ b/matrix_client/crypto/olm_device.py @@ -67,6 +67,8 @@ def __init__(self, self.device_list = DeviceList(self, api, self.device_keys) self.olm_sessions = defaultdict(list) self.megolm_outbound_sessions = {} + self.megolm_inbound_sessions = defaultdict(lambda: defaultdict(dict)) + self.megolm_index_record = defaultdict(dict) def upload_identity_keys(self): """Uploads this device's identity keys to HS. @@ -412,6 +414,10 @@ def megolm_start_session(self, room): for user in users} self.device_list.get_room_device_keys(room) self.megolm_share_session(room.room_id, user_devices, session) + # Store a corresponding inbound session, so that we can decrypt our own messages + self.megolm_add_inbound_session(room.room_id, self.identity_keys['curve25519'], + session.id, + session.session_key) return session def megolm_share_session(self, room_id, user_devices, session): @@ -538,6 +544,136 @@ def send_encrypted_message(self, room, content): return self.api.send_message_event( room.room_id, 'm.room.encrypted', encrypted_event) + def olm_handle_encrypted_event(self, encrypted_event): + """Decrypt and process an Olm m.room.encrypted event. + + Once decrypted, the event is processed according to its type. + + Args: + encrypted_event (dict): m.room.encrypted event. + """ + content = encrypted_event['content'] + if 'algorithm' not in content or content['algorithm'] != self._olm_algorithm: + return + + try: + event = self.olm_decrypt_event(content, encrypted_event['sender']) + except RuntimeError as e: + logger.warning('Failed to decrypt m.room_key event sent by user %s: %s', + encrypted_event['sender'], e) + return + + if event['type'] == 'm.room_key': + self.handle_room_key_event(event, encrypted_event['content']['sender_key']) + + def handle_room_key_event(self, event, sender_key): + """Handle a m.room_key event. + + Args: + event (dict): m.room_key event. + """ + content = event['content'] + if content['algorithm'] != self._megolm_algorithm: + logger.info('Ignoring unsupported algorithm %s in m.room_key event.', + content['algorithm']) + return + user_id = event['sender'] + device_id = event['sender_device'] + + new = self.megolm_add_inbound_session(content['room_id'], sender_key, + content['session_id'], + content['session_key']) + if new: + logger.info('Created a new Megolm inbound session with device %s of ' + 'user %s.', device_id, user_id) + else: + logger.info('Inbound Megolm session with device %s of user %s ' + 'already exists or is invalid.', device_id, user_id) + + def megolm_add_inbound_session(self, room_id, sender_key, session_id, session_key): + """Create a new Megolm inbound session if necessary. + + Args: + room_id (str): The room corresponding to the session. + sender_key (str): The curve25519 key of the sender's device. + session_id (str): The id of the session. + session_key (str): The key of the session. + + Returns: + ``True`` if a new session was created, ``False`` if it already existed or if + the parameters were invalid. + """ + sessions = self.megolm_inbound_sessions[room_id][sender_key] + if session_id in sessions: + return False + try: + session = olm.InboundGroupSession(session_key) + except olm.OlmGroupSessionError: + return False + if session.id != session_id: + logger.warning('Session ID mismatch in m.room_key event. Expected %s from ' + 'event property, got %s.', session_id, session.id) + return False + sessions[session_id] = session + return True + + def megolm_decrypt_event(self, event): + """Decrypt a Megolm m.room.encrypted event. + + The event is decrypted in-place, meaning its content and type properties are + overwritten by those of the decrypted event. + + Args: + event (dict): The event to decrypt. + """ + content = event['content'] + device_id = content['device_id'] + user_id = event['sender'] + if 'algorithm' not in content: + # Assume that this is a redacted event + return + if content['algorithm'] != self._megolm_algorithm: + raise RuntimeError('Incorrect algorithm "{}" value in event sent by device ' + '{} of user {}.'.format(content['algorithm'], device_id, + user_id)) + + sender_key = content['sender_key'] + room_id = event['room_id'] + session_id = content['session_id'] + sessions = self.megolm_inbound_sessions[room_id][sender_key] + try: + session = sessions[session_id] + except KeyError: + raise RuntimeError("Unable to decrypt event sent by device {} of user {}: " + "The sender's device has not sent us the keys for this " + "message.".format(device_id, user_id)) + + try: + decrypted_event, message_index = session.decrypt(content['ciphertext']) + except olm.group_session.OlmGroupSessionError as e: + raise RuntimeError('Unable to decrypt event sent by device {} of user {} ' + 'with matching megolm session: {}.'.format(device_id, + user_id, e)) + + try: + properties = self.megolm_index_record[session.id][message_index] + except KeyError: + self.megolm_index_record[session.id][message_index] = { + 'origin_server_ts': event['origin_server_ts'], + 'event_id': event['event_id'] + } + else: + if properties['origin_server_ts'] == event['origin_server_ts'] and \ + properties['event_id'] == event['event_id']: + raise RuntimeError('Detected a replay attack from device {} of user {} ' + 'on decrypted event: {}.'.format(device_id, user_id, + decrypted_event)) + + decrypted_event = json.loads(decrypted_event) + + event['type'] = decrypted_event['type'] + event['content'] = decrypted_event['content'] + def sign_json(self, json): """Signs a JSON object. From a28b29dc59b7e8eb86021e999275257c87cf2b37 Mon Sep 17 00:00:00 2001 From: Valentin Deniaud Date: Mon, 4 Jun 2018 14:10:14 +0200 Subject: [PATCH 17/66] automatically decrypt encrypted events --- matrix_client/room.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/matrix_client/room.py b/matrix_client/room.py index 120f65d2..75b891e3 100644 --- a/matrix_client/room.py +++ b/matrix_client/room.py @@ -13,6 +13,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import logging import re from uuid import uuid4 @@ -20,6 +21,8 @@ from .user import User from .errors import MatrixRequestError, MatrixNoEncryptionError +logger = logging.getLogger(__name__) + class Room(object): """Call room-specific functions after joining a room from the client. @@ -332,6 +335,12 @@ def add_state_listener(self, callback, event_type=None): ) def _put_event(self, event): + if self.encrypted and self.client._encryption: + if event['type'] == 'm.room.encrypted': + try: + self.client.olm_device.megolm_decrypt_event(event) + except RuntimeError as e: + logger.warning(e) self.events.append(event) if len(self.events) > self.event_history_limit: self.events.pop(0) From a377517425307bbb00c52842ece0463d35d3dbda Mon Sep 17 00:00:00 2001 From: Valentin Deniaud Date: Thu, 28 Jun 2018 15:51:33 +0200 Subject: [PATCH 18/66] add megolm inbound tests --- matrix_client/crypto/olm_device.py | 4 +- test/crypto/olm_device_test.py | 155 ++++++++++++++++++++++++++++- test/response_examples.py | 17 ++++ 3 files changed, 173 insertions(+), 3 deletions(-) diff --git a/matrix_client/crypto/olm_device.py b/matrix_client/crypto/olm_device.py index 1b89504b..5ec98bba 100644 --- a/matrix_client/crypto/olm_device.py +++ b/matrix_client/crypto/olm_device.py @@ -663,8 +663,8 @@ def megolm_decrypt_event(self, event): 'event_id': event['event_id'] } else: - if properties['origin_server_ts'] == event['origin_server_ts'] and \ - properties['event_id'] == event['event_id']: + if properties['origin_server_ts'] != event['origin_server_ts'] or \ + properties['event_id'] != event['event_id']: raise RuntimeError('Detected a replay attack from device {} of user {} ' 'on decrypted event: {}.'.format(device_id, user_id, decrypted_event)) diff --git a/test/crypto/olm_device_test.py b/test/crypto/olm_device_test.py index c7aaa51b..e823ae7f 100644 --- a/test/crypto/olm_device_test.py +++ b/test/crypto/olm_device_test.py @@ -19,7 +19,8 @@ from matrix_client.crypto.olm_device import OlmDevice from matrix_client.crypto.megolm_outbound_session import MegolmOutboundSession from test.response_examples import (example_key_upload_response, - example_claim_keys_response) + example_claim_keys_response, + example_room_key_event) HOSTNAME = 'http://example.com' @@ -460,6 +461,23 @@ def test_megolm_start_session(self): session = self.device.megolm_outbound_sessions[self.room_id] assert self.alice_device_id in session.devices + # Check that we can decrypt our own messages + plaintext = { + 'type': 'test', + 'content': {'test': 'test'}, + } + encrypted_event = self.device.megolm_build_encrypted_event(self.room, plaintext) + event = { + 'sender': self.alice, + 'room_id': self.room_id, + 'content': encrypted_event, + 'type': 'm.room.encrypted', + 'origin_server_ts': 1, + 'event_id': 1 + } + self.device.megolm_decrypt_event(event) + assert event['content'] == plaintext['content'] + @responses.activate def test_megolm_share_session_with_new_devices(self): to_device_url = HOSTNAME + MATRIX_V2_API_PATH + '/sendToDevice/m.room.encrypted/1' @@ -518,6 +536,141 @@ def test_send_encrypted_message(self): self.device.send_encrypted_message(self.room, {'test': 'test'}) + def test_megolm_add_inbound_session(self): + session = MegolmOutboundSession() + self.device.megolm_inbound_sessions.clear() + + assert not self.device.megolm_add_inbound_session( + self.room_id, self.alice_curve_key, session.id, 'wrong') + assert self.device.megolm_add_inbound_session( + self.room_id, self.alice_curve_key, session.id, session.session_key) + assert session.id in \ + self.device.megolm_inbound_sessions[self.room_id][self.alice_curve_key] + assert not self.device.megolm_add_inbound_session( + self.room_id, self.alice_curve_key, session.id, session.session_key) + assert not self.device.megolm_add_inbound_session( + self.room_id, self.alice_curve_key, 'wrong', session.session_key) + + def test_handle_room_key_event(self): + self.device.megolm_inbound_sessions.clear() + + self.device.handle_room_key_event(example_room_key_event, self.alice_curve_key) + assert self.room_id in self.device.megolm_inbound_sessions + + self.device.handle_room_key_event(example_room_key_event, self.alice_curve_key) + + event = deepcopy(example_room_key_event) + event['content']['algorithm'] = 'wrong' + self.device.handle_room_key_event(event, self.alice_curve_key) + + event = deepcopy(example_room_key_event) + event['content']['session_id'] = 'wrong' + self.device.handle_room_key_event(event, self.alice_curve_key) + + def test_olm_handle_encrypted_event(self): + self.device.olm_sessions.clear() + alice_device = OlmDevice(self.device.api, self.alice, self.alice_device_id) + alice_device.device_keys[self.user_id][self.device_id] = self.device.identity_keys + self.device.device_keys[self.alice][self.alice_device_id] = \ + alice_device.identity_keys + + # Artificially start an Olm session from Alice + self.device.olm_account.generate_one_time_keys(1) + otk = next(iter(self.device.olm_account.one_time_keys['curve25519'].values())) + self.device.olm_account.mark_keys_as_published() + sender_key = self.device.identity_keys['curve25519'] + session = olm.OutboundSession(alice_device.olm_account, sender_key, otk) + alice_device.olm_sessions[sender_key] = [session] + + content = example_room_key_event['content'] + encrypted_event = alice_device.olm_build_encrypted_event( + 'm.room_key', content, self.user_id, self.device_id) + event = { + 'type': 'm.room.encrypted', + 'content': encrypted_event, + 'sender': self.alice + } + + self.device.olm_handle_encrypted_event(event) + + # Decrypting the same event twice will trigger an error + self.device.olm_handle_encrypted_event(event) + + encrypted_event = alice_device.olm_build_encrypted_event( + 'm.other', content, self.user_id, self.device_id) + event = { + 'type': 'm.room.encrypted', + 'content': encrypted_event, + 'sender': self.alice + } + self.device.olm_handle_encrypted_event(event) + + # Simulate redacted event + event['content'].pop('algorithm') + self.device.olm_handle_encrypted_event(event) + + def test_megolm_decrypt_event(self): + out_session = MegolmOutboundSession() + + plaintext = { + 'content': {"test": "test"}, + 'type': 'm.text', + } + ciphertext = out_session.encrypt(json.dumps(plaintext)) + + content = { + 'ciphertext': ciphertext, + 'session_id': out_session.id, + 'sender_key': self.alice_curve_key, + 'algorithm': 'm.megolm.v1.aes-sha2', + 'device_id': self.alice_device_id, + } + + event = { + 'sender': self.alice, + 'room_id': self.room_id, + 'content': content, + 'type': 'm.room.encrypted', + 'origin_server_ts': 1, + 'event_id': 1 + } + + with pytest.raises(RuntimeError): + self.device.megolm_decrypt_event(event) + + in_session = olm.InboundGroupSession(out_session.session_key) + sessions = self.device.megolm_inbound_sessions[self.room_id] + sessions[self.alice_curve_key][in_session.id] = in_session + + # Unknown message index + with pytest.raises(RuntimeError): + self.device.megolm_decrypt_event(event) + + ciphertext = out_session.encrypt(json.dumps(plaintext)) + event['content']['ciphertext'] = ciphertext + self.device.megolm_decrypt_event(event) + assert event['content'] == plaintext['content'] + + # No replay attack + event['content'] = content + self.device.megolm_decrypt_event(event) + assert event['content'] == plaintext['content'] + + # Replay attack + event['content'] = content + event['event_id'] = 2 + with pytest.raises(RuntimeError): + self.device.megolm_decrypt_event(event) + + event['content']['algorithm'] = 'wrong' + with pytest.raises(RuntimeError): + self.device.megolm_decrypt_event(event) + + event['content'].pop('algorithm') + event['type'] = 'encrypted' + self.device.megolm_decrypt_event(event) + assert event['type'] == 'encrypted' + def test_megolm_outbound_session(): session = MegolmOutboundSession() diff --git a/test/response_examples.py b/test/response_examples.py index d6a2ca49..e577c464 100644 --- a/test/response_examples.py +++ b/test/response_examples.py @@ -236,3 +236,20 @@ } } } + +example_room_key_event = { + "sender": "@alice:example.com", + "sender_device": "JLAFKJWSCS", + "content": { + "algorithm": "m.megolm.v1.aes-sha2", + "room_id": "!test:example.com", + "session_id": "AVCXMm6LZ+J/vyCcomXmE48mbD1IyKbUBUd3UOW0wHE", + "session_key": ( + "AgAAAAAJS98WXiCc90wJ23H1ucZ+XFCv8pN8C5p/XojdA6l7PWlFwAV1fQXe7afrQMRL9BxeeF8M" + "uNnpvGX0hGOWcW0e2LU3EzQ0j8+jhxrPkQHUOJ8387CjRSA9UTBDmw3y8xquy3cXvuGE5DSpFUU7" + "J7Xh+Dli8XRaRDCbmPmMtSdPMwFQlzJui2fif78gnKJl5hOPJmw9SMim1AVHd1DltMBx4vB/3Kse" + "G413GWJkw9T+G6y51bsNEKsSU23lnJz32u5XwgNY9qdFKxGA6WL1wZZS6/iGW4gfTU/Jk89aGSA8" + "Aw") + }, + "type": "m.room_key" +} From ca08e8689ee7a70fe7185c7de35ca3cf79e5ada8 Mon Sep 17 00:00:00 2001 From: Valentin Deniaud Date: Thu, 14 Jun 2018 17:43:51 +0200 Subject: [PATCH 19/66] persist olm account --- matrix_client/crypto/crypto_store.py | 94 ++++++++++++++++++++++++++++ matrix_client/crypto/olm_device.py | 24 ++++++- setup.py | 2 +- 3 files changed, 116 insertions(+), 4 deletions(-) create mode 100644 matrix_client/crypto/crypto_store.py diff --git a/matrix_client/crypto/crypto_store.py b/matrix_client/crypto/crypto_store.py new file mode 100644 index 00000000..46df3f25 --- /dev/null +++ b/matrix_client/crypto/crypto_store.py @@ -0,0 +1,94 @@ +import logging +import os +import sqlite3 + +import olm +from appdirs import user_data_dir + +logger = logging.getLogger(__name__) + + +class CryptoStore(object): + """Manages persistent storage for an OlmDevice. + + Args: + device_id (str): The device id of the OlmDevice. + db_name (str): Optional. The name of the database file to use. Will be created + if necessary. + db_path (str): Optional. The path where to store the database file. Defaults to + the system default application data directory. + app_name (str): Optional. The application name, which will be used to determine + where the database is located. Ignored if db_path is supplied. + pickle_key (str): Optional. A key to encrypt the database contents. + """ + + def __init__(self, + device_id, + db_name='crypto.db', + db_path=None, + app_name='matrix-python-sdk', + pickle_key='DEFAULT_KEY'): + self.device_id = device_id + data_dir = db_path or user_data_dir(app_name, '') + try: + os.makedirs(data_dir) + except OSError: + pass + self.conn = sqlite3.connect(os.path.join(data_dir, db_name)) + self.pickle_key = pickle_key + self.create_tables_if_needed() + + def create_tables_if_needed(self): + """Ensures all the tables exist.""" + c = self.conn.cursor() + c.execute('CREATE TABLE IF NOT EXISTS accounts (device_id TEXT PRIMARY KEY,' + 'account BLOB)') + c.close() + self.conn.commit() + + def save_olm_account(self, account): + """Saves an Olm account. + + Args: + account (olm.Account): The account object to save. + """ + account_data = account.pickle(self.pickle_key) + c = self.conn.cursor() + c.execute('INSERT OR IGNORE INTO accounts (device_id, account) VALUES (?,?)', + (self.device_id, account_data)) + c.execute('UPDATE accounts SET account=? WHERE device_id=?', + (account_data, self.device_id)) + c.close() + self.conn.commit() + + def get_olm_account(self): + """Gets the Olm account. + + Returns: + olm.Account object, or None if it wasn't found for the current device_id. + """ + c = self.conn.cursor() + c.execute( + 'SELECT account FROM accounts WHERE device_id=?', (self.device_id,)) + try: + account_data = c.fetchone()[0] + # sqlite gives us unicode in Python2, we want bytes + account_data = bytes(account_data) + except TypeError: + return None + finally: + c.close() + return olm.Account.from_pickle(account_data, self.pickle_key) + + def remove_olm_account(self): + """Removes the Olm account. + + NOTE: Doing so will remove any saved information associated with the account + (keys, sessions...) + """ + c = self.conn.cursor() + c.execute('DELETE FROM accounts WHERE device_id=?', (self.device_id,)) + c.close() + + def close(self): + self.conn.close() diff --git a/matrix_client/crypto/olm_device.py b/matrix_client/crypto/olm_device.py index 5ec98bba..c40b423d 100644 --- a/matrix_client/crypto/olm_device.py +++ b/matrix_client/crypto/olm_device.py @@ -9,6 +9,7 @@ from matrix_client.crypto.one_time_keys import OneTimeKeysManager from matrix_client.crypto.device_list import DeviceList from matrix_client.crypto.megolm_outbound_session import MegolmOutboundSession +from matrix_client.crypto.crypto_store import CryptoStore logger = logging.getLogger(__name__) @@ -32,6 +33,11 @@ class OlmDevice(object): replenishment is triggered. Must be between ``0`` and ``1``. For example, ``0.1`` means that new one-time keys will be uploaded when there is less than 10% of the maximum number of one-time keys on the server. + Store (class): Optional. Custom storage class. It should implement the same + methods as :class:`~matrix_client.crypto.crypto_store.CryptoStore`. + store_conf (dict): Optional. Configuration parameters for keys storage. Refer to + :func:`~matrix_client.crypto.crypto_store.CryptoStore` for supported options, + since it will be passed to this class. """ _olm_algorithm = 'm.olm.v1.curve25519-aes-sha2' @@ -43,7 +49,9 @@ def __init__(self, user_id, device_id, signed_keys_proportion=1, - keys_threshold=0.1): + keys_threshold=0.1, + Store=CryptoStore, + store_conf=None): if not 0 <= signed_keys_proportion <= 1: raise ValueError('signed_keys_proportion must be between 0 and 1.') if not 0 <= keys_threshold <= 1: @@ -52,8 +60,15 @@ def __init__(self, check_user_id(user_id) self.user_id = user_id self.device_id = device_id - self.olm_account = olm.Account() - logger.info('Initialised Olm Device.') + conf = store_conf or {} + self.db = Store(self.device_id, **conf) + self.olm_account = self.db.get_olm_account() + if self.olm_account: + logger.info('Loaded Olm account from database for device %s.', device_id) + else: + self.olm_account = olm.Account() + self.db.save_olm_account(self.olm_account) + logger.info('Created new Olm account for device %s.', device_id) self.identity_keys = self.olm_account.identity_keys # Try to maintain half the number of one-time keys libolm can hold uploaded # on the HS. This is because some keys will be claimed by peers but not @@ -126,12 +141,14 @@ def upload_one_time_keys(self, force_update=False): ret = self.api.upload_keys(one_time_keys=one_time_keys) self.one_time_keys_manager.server_counts = ret['one_time_key_counts'] self.olm_account.mark_keys_as_published() + self.db.save_olm_account(self.olm_account) keys_uploaded = {} if unsigned_keys_to_upload: keys_uploaded['curve25519'] = unsigned_keys_to_upload if signed_keys_to_upload: keys_uploaded['signed_curve25519'] = signed_keys_to_upload + logger.info('Uploaded new one-time keys: %s.', keys_uploaded) return keys_uploaded @@ -371,6 +388,7 @@ def _olm_decrypt(self, olm_message, sender_key): raise RuntimeError('Error decrypting pre-key message with new session: ' '{}.'.format(e)) self.olm_account.remove_one_time_keys(session) + self.db.save_olm_account(self.olm_account) sessions.append(session) return json.loads(event) diff --git a/setup.py b/setup.py index 6049f397..f555d4c0 100644 --- a/setup.py +++ b/setup.py @@ -54,7 +54,7 @@ def exec_file(names): 'test': ['pytest', 'responses'], 'doc': ['Sphinx==1.4.6', 'sphinx-rtd-theme==0.1.9', 'sphinxcontrib-napoleon==0.5.3'], 'format': ['flake8'], - 'e2e': ['python-olm==dev', 'canonicaljson'] + 'e2e': ['python-olm==dev', 'canonicaljson', 'appdirs'] }, dependency_links=[ 'git+https://github.com/poljar/python-olm.git@4752eb22f005cb9f6143857008572e6d83252841#egg=python-olm-dev' From 16ed29426afb6b0ba55c9962d8a8b3e8de2f4024 Mon Sep 17 00:00:00 2001 From: Valentin Deniaud Date: Fri, 29 Jun 2018 15:31:58 +0200 Subject: [PATCH 20/66] build doc for crypto.crypto_store --- docs/source/matrix_client.rst | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/docs/source/matrix_client.rst b/docs/source/matrix_client.rst index 353eb5a5..e008928f 100644 --- a/docs/source/matrix_client.rst +++ b/docs/source/matrix_client.rst @@ -66,3 +66,8 @@ matrix_client.crypto :members: :undoc-members: :show-inheritance: + +.. automodule:: matrix_client.crypto.crypto_store + :members: + :undoc-members: + :show-inheritance: From 55e8e09255244d26906164607200257a7605d6f8 Mon Sep 17 00:00:00 2001 From: Valentin Deniaud Date: Fri, 29 Jun 2018 21:03:56 +0200 Subject: [PATCH 21/66] add account persistence tests --- test/crypto/crypto_store_test.py | 49 ++++++++++++++++++++++++++++++++ 1 file changed, 49 insertions(+) create mode 100644 test/crypto/crypto_store_test.py diff --git a/test/crypto/crypto_store_test.py b/test/crypto/crypto_store_test.py new file mode 100644 index 00000000..a0565e7c --- /dev/null +++ b/test/crypto/crypto_store_test.py @@ -0,0 +1,49 @@ +import pytest +olm = pytest.importorskip("olm") # noqa + +import os +from tempfile import mkdtemp + +from matrix_client.crypto.crypto_store import CryptoStore +from matrix_client.crypto.olm_device import OlmDevice + + +class TestCryptoStore(object): + + # Initialise a store and test some init code + device_id = 'AUIETSRN' + user_id = '@user:matrix.org' + db_name = 'test.db' + db_path = mkdtemp() + store_conf = { + 'db_name': db_name, + 'db_path': db_path + } + store = CryptoStore(device_id, db_path=db_path, db_name=db_name) + db_filepath = os.path.join(db_path, db_name) + assert os.path.exists(db_filepath) + store.close() + store = CryptoStore(device_id, db_path=db_path, db_name='test.db') + + @pytest.fixture(autouse=True, scope='class') + def cleanup(self): + yield + os.remove(self.db_filepath) + + def test_olm_account_persistence(self): + account = olm.Account() + identity_keys = account.identity_keys + self.store.remove_olm_account() + + # Try to load inexisting account + saved_account = self.store.get_olm_account() + assert saved_account is None + + # Save and load + self.store.save_olm_account(account) + saved_account = self.store.get_olm_account() + assert saved_account.identity_keys == identity_keys + + # Load the account from an OlmDevice + device = OlmDevice(None, self.user_id, self.device_id, store_conf=self.store_conf) + assert device.olm_account.identity_keys == account.identity_keys From 9e0ad6ff8c218277815120ce8a8e9708699d7427 Mon Sep 17 00:00:00 2001 From: Valentin Deniaud Date: Fri, 29 Jun 2018 21:45:23 +0200 Subject: [PATCH 22/66] prevent tests from using database --- test/client_test.py | 3 +++ test/crypto/__init__.py | 0 test/crypto/device_list_test.py | 2 +- test/crypto/dummy_olm_device.py | 21 +++++++++++++++++++++ test/crypto/olm_device_test.py | 2 +- 5 files changed, 26 insertions(+), 2 deletions(-) create mode 100644 test/crypto/__init__.py create mode 100644 test/crypto/dummy_olm_device.py diff --git a/test/client_test.py b/test/client_test.py index 5fa84762..7cde5f74 100644 --- a/test/client_test.py +++ b/test/client_test.py @@ -1,16 +1,19 @@ import pytest import responses import json +import matrix_client.client from copy import deepcopy from matrix_client.client import MatrixClient, Room, User, CACHE from matrix_client.api import MATRIX_V2_API_PATH from . import response_examples +from .crypto.dummy_olm_device import OlmDevice try: from urllib import quote except ImportError: from urllib.parse import quote HOSTNAME = "http://example.com" +matrix_client.client.OlmDevice = OlmDevice def test_create_client(): diff --git a/test/crypto/__init__.py b/test/crypto/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/test/crypto/device_list_test.py b/test/crypto/device_list_test.py index 69df1ed3..71d321c9 100644 --- a/test/crypto/device_list_test.py +++ b/test/crypto/device_list_test.py @@ -11,9 +11,9 @@ from matrix_client.client import MatrixClient from matrix_client.room import User from matrix_client.errors import MatrixRequestError -from matrix_client.crypto.olm_device import OlmDevice from matrix_client.crypto.device_list import (_OutdatedUsersSet as OutdatedUsersSet, _UpdateDeviceList as UpdateDeviceList) +from test.crypto.dummy_olm_device import OlmDevice from test.response_examples import example_key_query_response HOSTNAME = 'http://example.com' diff --git a/test/crypto/dummy_olm_device.py b/test/crypto/dummy_olm_device.py new file mode 100644 index 00000000..83c50c6c --- /dev/null +++ b/test/crypto/dummy_olm_device.py @@ -0,0 +1,21 @@ +"""Tests can import OlmDevice from here, and know it won't try to use a database.""" + +from matrix_client.crypto.crypto_store import CryptoStore +from matrix_client.crypto.olm_device import OlmDevice as BaseOlmDevice + + +class DummyStore(CryptoStore): + def __init__(*args, **kw): pass + + def nop(*args, **kw): pass + + def __getattribute__(self, name): + if name in dir(CryptoStore): + return object.__getattribute__(self, 'nop') + raise AttributeError + + +class OlmDevice(BaseOlmDevice): + + def __init__(self, *args, **kw): + super(OlmDevice, self).__init__(*args, Store=DummyStore, **kw) diff --git a/test/crypto/olm_device_test.py b/test/crypto/olm_device_test.py index e823ae7f..54cceea7 100644 --- a/test/crypto/olm_device_test.py +++ b/test/crypto/olm_device_test.py @@ -16,7 +16,7 @@ from matrix_client.api import MATRIX_V2_API_PATH from matrix_client.client import MatrixClient from matrix_client.user import User -from matrix_client.crypto.olm_device import OlmDevice +from test.crypto.dummy_olm_device import OlmDevice from matrix_client.crypto.megolm_outbound_session import MegolmOutboundSession from test.response_examples import (example_key_upload_response, example_claim_keys_response, From 05823ea9d761b4dfd959fd17fecb31ccc4cfff51 Mon Sep 17 00:00:00 2001 From: Valentin Deniaud Date: Thu, 14 Jun 2018 17:44:01 +0200 Subject: [PATCH 23/66] persist olm sessions --- matrix_client/crypto/crypto_store.py | 65 ++++++++++++++++++++++++++++ matrix_client/crypto/olm_device.py | 23 ++++++++-- 2 files changed, 85 insertions(+), 3 deletions(-) diff --git a/matrix_client/crypto/crypto_store.py b/matrix_client/crypto/crypto_store.py index 46df3f25..9f1784e5 100644 --- a/matrix_client/crypto/crypto_store.py +++ b/matrix_client/crypto/crypto_store.py @@ -41,8 +41,13 @@ def __init__(self, def create_tables_if_needed(self): """Ensures all the tables exist.""" c = self.conn.cursor() + c.execute('PRAGMA foreign_keys = ON') c.execute('CREATE TABLE IF NOT EXISTS accounts (device_id TEXT PRIMARY KEY,' 'account BLOB)') + c.execute('CREATE TABLE IF NOT EXISTS olm_sessions (device_id TEXT,' + 'session_id TEXT PRIMARY KEY, curve_key TEXT, session BLOB,' + 'FOREIGN KEY(device_id) REFERENCES accounts(device_id) ' + 'ON DELETE CASCADE)') c.close() self.conn.commit() @@ -90,5 +95,65 @@ def remove_olm_account(self): c.execute('DELETE FROM accounts WHERE device_id=?', (self.device_id,)) c.close() + def save_olm_session(self, curve_key, session): + self.save_olm_sessions({curve_key: [session]}) + + def save_olm_sessions(self, sessions): + """Saves Olm sessions. + + Args: + sessions (defaultdict(list)): A map from curve25519 keys to a list of + olm.Session objects. + """ + c = self.conn.cursor() + rows = [(self.device_id, s.id, key, s.pickle(self.pickle_key)) + for key in sessions for s in sessions[key]] + c.executemany('REPLACE INTO olm_sessions VALUES (?,?,?,?)', rows) + c.close() + self.conn.commit() + + def load_olm_sessions(self, sessions): + """Loads all saved Olm sessions. + + Args: + sessions (defaultdict(list)): A map from curve25519 keys to a list of + olm.Session objects, which will be populated. + """ + c = self.conn.cursor() + rows = c.execute('SELECT curve_key, session FROM olm_sessions WHERE device_id=?', + (self.device_id,)) + for row in rows: + session = olm.Session.from_pickle(bytes(row[1]), self.pickle_key) + sessions[row[0]].append(session) + c.close() + + def get_olm_sessions(self, curve_key, sessions_dict=None): + """Get the Olm sessions corresponding to a device. + + Args: + curve_key (str): The curve25519 key of the device. + sessions_dict (defaultdict(list)): Optional. A map from curve25519 keys to a + list of olm.Session objects, to which the session list will be added. + + Returns: + A list of olm.Session objects, or None if none were found. + + NOTE: + When overriding this, be careful to append the retrieved sessions to the + list of sessions already present and not to overwrite its reference. + """ + c = self.conn.cursor() + rows = c.execute( + 'SELECT session FROM olm_sessions WHERE device_id=? AND curve_key=?', + (self.device_id, curve_key) + ) + sessions = [olm.Session.from_pickle(bytes(row[0]), self.pickle_key) + for row in rows] + if sessions_dict is not None: + sessions_dict[curve_key].extend(sessions) + c.close() + # For consistency with other get_ methods, do not return an empty list + return sessions or None + def close(self): self.conn.close() diff --git a/matrix_client/crypto/olm_device.py b/matrix_client/crypto/olm_device.py index c40b423d..efe4d5a8 100644 --- a/matrix_client/crypto/olm_device.py +++ b/matrix_client/crypto/olm_device.py @@ -38,6 +38,9 @@ class OlmDevice(object): store_conf (dict): Optional. Configuration parameters for keys storage. Refer to :func:`~matrix_client.crypto.crypto_store.CryptoStore` for supported options, since it will be passed to this class. + load_all (bool): Optional. If True, all content of the database for the current + device will be loaded at once. This will increase runtime performance but + also launch time and memory usage. """ _olm_algorithm = 'm.olm.v1.curve25519-aes-sha2' @@ -51,7 +54,8 @@ def __init__(self, signed_keys_proportion=1, keys_threshold=0.1, Store=CryptoStore, - store_conf=None): + store_conf=None, + load_all=False): if not 0 <= signed_keys_proportion <= 1: raise ValueError('signed_keys_proportion must be between 0 and 1.') if not 0 <= keys_threshold <= 1: @@ -60,10 +64,13 @@ def __init__(self, check_user_id(user_id) self.user_id = user_id self.device_id = device_id + self.olm_sessions = defaultdict(list) conf = store_conf or {} self.db = Store(self.device_id, **conf) self.olm_account = self.db.get_olm_account() if self.olm_account: + if load_all: + self.db.load_olm_sessions(self.olm_sessions) logger.info('Loaded Olm account from database for device %s.', device_id) else: self.olm_account = olm.Account() @@ -80,7 +87,6 @@ def __init__(self, keys_threshold) self.device_keys = defaultdict(dict) self.device_list = DeviceList(self, api, self.device_keys) - self.olm_sessions = defaultdict(list) self.megolm_outbound_sessions = {} self.megolm_inbound_sessions = defaultdict(lambda: defaultdict(dict)) self.megolm_index_record = defaultdict(dict) @@ -195,6 +201,7 @@ def olm_start_sessions(self, user_devices): missing[user_id] = missing_devices logger.warning('Failed to claim the keys of %s.', missing) + new_sessions = defaultdict(list) for user_id in user_devices: for device_id, one_time_key in keys.get(user_id, {}).items(): try: @@ -214,12 +221,14 @@ def olm_start_sessions(self, user_devices): key_object['key']) sessions = self.olm_sessions[device_keys['curve25519']] sessions.append(session) + new_sessions[device_keys['curve25519']].append(session) logger.info('Established Olm session %s with device %s of user ' '%s.', device_id, session.id, user_id) else: logger.warning('Signature verification for one-time key of device %s ' 'of user %s failed, could not start olm session.', device_id, user_id) + self.db.save_olm_sessions(new_sessions) def olm_build_encrypted_event(self, event_type, content, user_id, device_id): """Encrypt an event using Olm. @@ -264,6 +273,7 @@ def olm_build_encrypted_event(self, event_type, content, user_id, device_id): raise RuntimeError('No session for this device, could not encrypt.') encrypted_message = session.encrypt(json.dumps(payload)) + self.db.save_olm_session(identity_key, session) ciphertext_payload = { identity_key: { 'type': encrypted_message.message_type, @@ -348,11 +358,15 @@ def _olm_decrypt(self, olm_message, sender_key): """ sessions = self.olm_sessions[sender_key] + if not sessions: + # `sessions` should get populated by this method + self.db.get_olm_sessions(sender_key, self.olm_sessions) # Try to decrypt message body using one of the known sessions for that device for session in sessions: try: event = session.decrypt(olm_message) + self.db.save_olm_session(sender_key, session) logger.info('Success decrypting Olm event using existing session %s.', session.id) break @@ -389,6 +403,7 @@ def _olm_decrypt(self, olm_message, sender_key): '{}.'.format(e)) self.olm_account.remove_one_time_keys(session) self.db.save_olm_account(self.olm_account) + self.db.save_olm_session(sender_key, session) sessions.append(session) return json.loads(event) @@ -408,7 +423,9 @@ def olm_ensure_sessions(self, user_devices): # to establish a session with a device, but this attempt was # unsuccessful. We do not retry to establish a session. if curve_key not in self.olm_sessions: - user_devices_no_session[user_id].append(device_id) + sessions = self.db.get_olm_sessions(curve_key, self.olm_sessions) + if not sessions: + user_devices_no_session[user_id].append(device_id) if user_devices_no_session: self.olm_start_sessions(user_devices_no_session) From 33a06d803214ce4ecd4d07eaa74532666f5c2cb7 Mon Sep 17 00:00:00 2001 From: Valentin Deniaud Date: Sat, 30 Jun 2018 15:12:26 +0200 Subject: [PATCH 24/66] add olm persistence tests --- test/crypto/crypto_store_test.py | 57 ++++++++++++++++++++++++++++++++ 1 file changed, 57 insertions(+) diff --git a/test/crypto/crypto_store_test.py b/test/crypto/crypto_store_test.py index a0565e7c..a371b4b0 100644 --- a/test/crypto/crypto_store_test.py +++ b/test/crypto/crypto_store_test.py @@ -2,6 +2,7 @@ olm = pytest.importorskip("olm") # noqa import os +from collections import defaultdict from tempfile import mkdtemp from matrix_client.crypto.crypto_store import CryptoStore @@ -30,6 +31,14 @@ def cleanup(self): yield os.remove(self.db_filepath) + @pytest.fixture() + def account(self): + account = self.store.get_olm_account() + if account is None: + account = olm.Account() + self.store.save_olm_account(account) + return account + def test_olm_account_persistence(self): account = olm.Account() identity_keys = account.identity_keys @@ -47,3 +56,51 @@ def test_olm_account_persistence(self): # Load the account from an OlmDevice device = OlmDevice(None, self.user_id, self.device_id, store_conf=self.store_conf) assert device.olm_account.identity_keys == account.identity_keys + + def test_olm_sessions_persistence(self, account): + curve_key = account.identity_keys['curve25519'] + session = olm.OutboundSession(account, curve_key, curve_key) + sessions = defaultdict(list) + + self.store.load_olm_sessions(sessions) + assert not sessions + assert not self.store.get_olm_sessions(curve_key) + + self.store.save_olm_session(curve_key, session) + self.store.load_olm_sessions(sessions) + assert sessions[curve_key][0].id == session.id + + saved_sessions = self.store.get_olm_sessions(curve_key) + assert saved_sessions[0].id == session.id + + sessions.clear() + saved_sessions = self.store.get_olm_sessions(curve_key, sessions) + assert sessions[curve_key][0].id == session.id + + # Replace the session when its internal state has changed + pickle = session.pickle() + session.encrypt('test') + self.store.save_olm_session(curve_key, session) + saved_sessions = self.store.get_olm_sessions(curve_key) + assert saved_sessions[0].pickle != pickle + + # Load all sessions at once + device = OlmDevice( + None, self.user_id, self.device_id, store_conf=self.store_conf, load_all=True) + assert device.olm_sessions[curve_key][0].id == session.id + + # Load sessions dynamically + device = OlmDevice(None, self.user_id, self.device_id, store_conf=self.store_conf) + assert not device.olm_sessions + with pytest.raises(AttributeError): + device._olm_decrypt(None, curve_key) + assert device.olm_sessions[curve_key][0].id == session.id + + device.olm_sessions.clear() + device.device_keys[self.user_id][self.device_id] = {'curve25519': curve_key} + device.olm_ensure_sessions({self.user_id: [self.device_id]}) + assert device.olm_sessions[curve_key][0].id == session.id + + # Test cascade deletion + self.store.remove_olm_account() + assert not self.store.get_olm_sessions(curve_key) From f8487bd7b7f84e75c38fd44f1a2d4cbd891410e0 Mon Sep 17 00:00:00 2001 From: Valentin Deniaud Date: Thu, 14 Jun 2018 18:25:19 +0200 Subject: [PATCH 25/66] persist megolm inbound sessions --- matrix_client/crypto/crypto_store.py | 70 ++++++++++++++++++++++++++++ matrix_client/crypto/olm_device.py | 18 +++++-- 2 files changed, 83 insertions(+), 5 deletions(-) diff --git a/matrix_client/crypto/crypto_store.py b/matrix_client/crypto/crypto_store.py index 9f1784e5..1823d9de 100644 --- a/matrix_client/crypto/crypto_store.py +++ b/matrix_client/crypto/crypto_store.py @@ -48,6 +48,11 @@ def create_tables_if_needed(self): 'session_id TEXT PRIMARY KEY, curve_key TEXT, session BLOB,' 'FOREIGN KEY(device_id) REFERENCES accounts(device_id) ' 'ON DELETE CASCADE)') + c.execute('CREATE TABLE IF NOT EXISTS megolm_inbound_sessions ' + '(device_id TEXT, session_id TEXT PRIMARY KEY, room_id TEXT,' + 'curve_key TEXT, session BLOB,' + 'FOREIGN KEY(device_id) REFERENCES accounts(device_id) ' + 'ON DELETE CASCADE)') c.close() self.conn.commit() @@ -155,5 +160,70 @@ def get_olm_sessions(self, curve_key, sessions_dict=None): # For consistency with other get_ methods, do not return an empty list return sessions or None + def save_inbound_session(self, room_id, curve_key, session): + """Saves a Megolm inbound session. + + Args: + room_id (str): The room corresponding to the session. + curve_key (str): The curve25519 key of the device. + session (olm.InboundGroupSession): The session to save. + """ + c = self.conn.cursor() + c.execute('REPLACE INTO megolm_inbound_sessions VALUES (?,?,?,?,?)', + (self.device_id, session.id, room_id, curve_key, + session.pickle(self.pickle_key))) + c.close() + self.conn.commit() + + def load_inbound_sessions(self, sessions): + """Loads all saved inbound Megolm sessions. + + Args: + sessions (defaultdict(defaultdict(dict))): An object which will get + populated with the sessions. The format is + ``{: {: {: + }}}``. + """ + c = self.conn.cursor() + rows = c.execute( + 'SELECT room_id, curve_key, session FROM megolm_inbound_sessions WHERE ' + 'device_id=?', (self.device_id,) + ) + for row in rows: + session = olm.InboundGroupSession.from_pickle(bytes(row[2]), self.pickle_key) + sessions[row[0]][row[1]][session.id] = session + c.close() + + def get_inbound_session(self, room_id, curve_key, session_id, sessions=None): + """Gets a saved inbound Megolm session. + + Args: + room_id (str): The room corresponding to the session. + curve_key (str): The curve25519 key of the device. + session_id (str): The id of the session. + sessions (dict): Optional. A map from session id to olm.InboundGroupSession + object, to which the session will be added. + + Returns: + olm.InboundGroupSession object, or None if the session was not found. + """ + c = self.conn.cursor() + c.execute( + 'SELECT session FROM megolm_inbound_sessions WHERE device_id=? AND room_id=? ' + 'AND curve_key=? AND session_id=?', + (self.device_id, room_id, curve_key, session_id) + ) + try: + session_data = c.fetchone()[0] + session_data = bytes(session_data) + except TypeError: + return None + finally: + c.close() + session = olm.InboundGroupSession.from_pickle(session_data, self.pickle_key) + if sessions is not None: + sessions[session.id] = session + return session + def close(self): self.conn.close() diff --git a/matrix_client/crypto/olm_device.py b/matrix_client/crypto/olm_device.py index efe4d5a8..dc43a47e 100644 --- a/matrix_client/crypto/olm_device.py +++ b/matrix_client/crypto/olm_device.py @@ -64,13 +64,15 @@ def __init__(self, check_user_id(user_id) self.user_id = user_id self.device_id = device_id - self.olm_sessions = defaultdict(list) conf = store_conf or {} self.db = Store(self.device_id, **conf) + self.olm_sessions = defaultdict(list) + self.megolm_inbound_sessions = defaultdict(lambda: defaultdict(dict)) self.olm_account = self.db.get_olm_account() if self.olm_account: if load_all: self.db.load_olm_sessions(self.olm_sessions) + self.db.load_inbound_sessions(self.megolm_inbound_sessions) logger.info('Loaded Olm account from database for device %s.', device_id) else: self.olm_account = olm.Account() @@ -88,7 +90,6 @@ def __init__(self, self.device_keys = defaultdict(dict) self.device_list = DeviceList(self, api, self.device_keys) self.megolm_outbound_sessions = {} - self.megolm_inbound_sessions = defaultdict(lambda: defaultdict(dict)) self.megolm_index_record = defaultdict(dict) def upload_identity_keys(self): @@ -641,6 +642,9 @@ def megolm_add_inbound_session(self, room_id, sender_key, session_id, session_ke sessions = self.megolm_inbound_sessions[room_id][sender_key] if session_id in sessions: return False + # Load the session if it exists + if self.db.get_inbound_session(room_id, sender_key, session_id, sessions): + return False try: session = olm.InboundGroupSession(session_key) except olm.OlmGroupSessionError: @@ -649,6 +653,7 @@ def megolm_add_inbound_session(self, room_id, sender_key, session_id, session_ke logger.warning('Session ID mismatch in m.room_key event. Expected %s from ' 'event property, got %s.', session_id, session.id) return False + self.db.save_inbound_session(room_id, sender_key, session) sessions[session_id] = session return True @@ -679,9 +684,12 @@ def megolm_decrypt_event(self, event): try: session = sessions[session_id] except KeyError: - raise RuntimeError("Unable to decrypt event sent by device {} of user {}: " - "The sender's device has not sent us the keys for this " - "message.".format(device_id, user_id)) + session = self.db.get_inbound_session( + room_id, sender_key, session_id, sessions) + if not session: + raise RuntimeError("Unable to decrypt event sent by device {} of user " + "{}: The sender's device has not sent us the keys for " + "this message.".format(device_id, user_id)) try: decrypted_event, message_index = session.decrypt(content['ciphertext']) From 4bc247327a06868f86338b8a7a0de795f102a219 Mon Sep 17 00:00:00 2001 From: Valentin Deniaud Date: Sun, 1 Jul 2018 13:55:20 +0200 Subject: [PATCH 26/66] add megolm inbound persistence tests --- test/crypto/crypto_store_test.py | 84 +++++++++++++++++++++++++++++--- 1 file changed, 76 insertions(+), 8 deletions(-) diff --git a/test/crypto/crypto_store_test.py b/test/crypto/crypto_store_test.py index a371b4b0..3ae2b0f1 100644 --- a/test/crypto/crypto_store_test.py +++ b/test/crypto/crypto_store_test.py @@ -14,6 +14,7 @@ class TestCryptoStore(object): # Initialise a store and test some init code device_id = 'AUIETSRN' user_id = '@user:matrix.org' + room_id = '!test:example.com' db_name = 'test.db' db_path = mkdtemp() store_conf = { @@ -39,6 +40,14 @@ def account(self): self.store.save_olm_account(account) return account + @pytest.fixture() + def curve_key(self, account): + return account.identity_keys['curve25519'] + + @pytest.fixture() + def device(self): + return OlmDevice(None, self.user_id, self.device_id, store_conf=self.store_conf) + def test_olm_account_persistence(self): account = olm.Account() identity_keys = account.identity_keys @@ -57,8 +66,7 @@ def test_olm_account_persistence(self): device = OlmDevice(None, self.user_id, self.device_id, store_conf=self.store_conf) assert device.olm_account.identity_keys == account.identity_keys - def test_olm_sessions_persistence(self, account): - curve_key = account.identity_keys['curve25519'] + def test_olm_sessions_persistence(self, account, curve_key, device): session = olm.OutboundSession(account, curve_key, curve_key) sessions = defaultdict(list) @@ -84,13 +92,7 @@ def test_olm_sessions_persistence(self, account): saved_sessions = self.store.get_olm_sessions(curve_key) assert saved_sessions[0].pickle != pickle - # Load all sessions at once - device = OlmDevice( - None, self.user_id, self.device_id, store_conf=self.store_conf, load_all=True) - assert device.olm_sessions[curve_key][0].id == session.id - # Load sessions dynamically - device = OlmDevice(None, self.user_id, self.device_id, store_conf=self.store_conf) assert not device.olm_sessions with pytest.raises(AttributeError): device._olm_decrypt(None, curve_key) @@ -104,3 +106,69 @@ def test_olm_sessions_persistence(self, account): # Test cascade deletion self.store.remove_olm_account() assert not self.store.get_olm_sessions(curve_key) + + def test_megolm_inbound_persistence(self, curve_key, device): + out_session = olm.OutboundGroupSession() + session = olm.InboundGroupSession(out_session.session_key) + sessions = defaultdict(lambda: defaultdict(dict)) + + self.store.load_inbound_sessions(sessions) + assert not sessions + assert not self.store.get_inbound_session(self.room_id, curve_key, session.id) + + self.store.save_inbound_session(self.room_id, curve_key, session) + self.store.load_inbound_sessions(sessions) + assert sessions[self.room_id][curve_key][session.id].id == session.id + + saved_session = self.store.get_inbound_session(self.room_id, curve_key, + session.id) + assert saved_session.id == session.id + + sessions = {} + saved_session = self.store.get_inbound_session(self.room_id, curve_key, + session.id, sessions) + assert sessions[session.id].id == session.id + + assert not device.megolm_inbound_sessions + created = device.megolm_add_inbound_session( + self.room_id, curve_key, session.id, out_session.session_key) + assert not created + assert device.megolm_inbound_sessions[self.room_id][curve_key][session.id].id == \ + session.id + + device.megolm_inbound_sessions.clear() + content = { + 'sender_key': curve_key, + 'session_id': session.id, + 'algorithm': device._megolm_algorithm, + 'device_id': '' + } + event = { + 'sender': '', + 'room_id': self.room_id, + 'content': content + } + with pytest.raises(KeyError): + device.megolm_decrypt_event(event) + assert device.megolm_inbound_sessions[self.room_id][curve_key][session.id].id == \ + session.id + + self.store.remove_olm_account() + assert not self.store.get_inbound_session(self.room_id, curve_key, session.id) + + def test_load_all(self, account, curve_key): + curve_key = account.identity_keys['curve25519'] + session = olm.OutboundSession(account, curve_key, curve_key) + out_session = olm.OutboundGroupSession() + in_session = olm.InboundGroupSession(out_session.session_key) + + self.store.save_inbound_session(self.room_id, curve_key, in_session) + self.store.save_olm_session(curve_key, session) + + device = OlmDevice( + None, self.user_id, self.device_id, store_conf=self.store_conf, load_all=True) + + assert session.id in {s.id for s in device.olm_sessions[curve_key]} + saved_in_session = \ + device.megolm_inbound_sessions[self.room_id][curve_key][in_session.id] + assert saved_in_session.id == in_session.id From 93f408a6b26ca9a426f1d2a117be8dc0a5e36c27 Mon Sep 17 00:00:00 2001 From: Valentin Deniaud Date: Thu, 14 Jun 2018 23:17:09 +0200 Subject: [PATCH 27/66] persist megolm outbound sessions --- matrix_client/crypto/crypto_store.py | 132 +++++++++++++++++- .../crypto/megolm_outbound_session.py | 15 ++ matrix_client/crypto/olm_device.py | 14 +- 3 files changed, 158 insertions(+), 3 deletions(-) diff --git a/matrix_client/crypto/crypto_store.py b/matrix_client/crypto/crypto_store.py index 1823d9de..294413b0 100644 --- a/matrix_client/crypto/crypto_store.py +++ b/matrix_client/crypto/crypto_store.py @@ -1,10 +1,13 @@ import logging import os import sqlite3 +from datetime import timedelta import olm from appdirs import user_data_dir +from matrix_client.crypto.megolm_outbound_session import MegolmOutboundSession + logger = logging.getLogger(__name__) @@ -34,7 +37,8 @@ def __init__(self, os.makedirs(data_dir) except OSError: pass - self.conn = sqlite3.connect(os.path.join(data_dir, db_name)) + self.conn = sqlite3.connect(os.path.join(data_dir, db_name), + detect_types=sqlite3.PARSE_DECLTYPES) self.pickle_key = pickle_key self.create_tables_if_needed() @@ -53,6 +57,18 @@ def create_tables_if_needed(self): 'curve_key TEXT, session BLOB,' 'FOREIGN KEY(device_id) REFERENCES accounts(device_id) ' 'ON DELETE CASCADE)') + c.execute('CREATE TABLE IF NOT EXISTS megolm_outbound_sessions ' + '(device_id TEXT, room_id TEXT PRIMARY KEY, session BLOB,' + 'max_age_s FLOAT, max_messages INTEGER, creation_time TIMESTAMP,' + 'message_count INTEGER,' + 'FOREIGN KEY(device_id) REFERENCES accounts(device_id) ' + 'ON DELETE CASCADE)') + c.execute('CREATE TABLE IF NOT EXISTS megolm_outbound_devices ' + '(device_id TEXT, room_id TEXT, user_device_id TEXT,' + 'UNIQUE(device_id, room_id, user_device_id),' + 'FOREIGN KEY(room_id) REFERENCES megolm_outbound_sessions(room_id) ' + 'ON DELETE CASCADE,' + 'FOREIGN KEY(device_id) REFERENCES accounts(device_id))') c.close() self.conn.commit() @@ -225,5 +241,119 @@ def get_inbound_session(self, room_id, curve_key, session_id, sessions=None): sessions[session.id] = session return session + def save_outbound_session(self, room_id, session): + """Saves a Megolm outbound session. + + Args: + room_id (str): The room corresponding to the session. + session (MegolmOutboundSession): The session to save. + """ + c = self.conn.cursor() + pickle = session.pickle(self.pickle_key) + c.execute( + 'INSERT OR IGNORE INTO megolm_outbound_sessions VALUES (?,?,?,?,?,?,?)', + (self.device_id, room_id, pickle, session.max_age.total_seconds(), + session.max_messages, session.creation_time, session.message_count) + ) + c.execute('UPDATE megolm_outbound_sessions SET session=? WHERE device_id=? AND ' + 'room_id=?', (pickle, self.device_id, room_id)) + c.close() + self.conn.commit() + + def load_outbound_sessions(self, sessions): + """Loads all saved outbound Megolm sessions. + + Also loads the devices each are shared with. + + Args: + sessions (dict): A map from room_id to a :class:`.MegolmOutboundSession` + object, which will be populated. + """ + c = self.conn.cursor() + rows = c.execute( + 'SELECT room_id, session, max_age_s, max_messages, creation_time,' + 'message_count FROM megolm_outbound_sessions WHERE device_id=?', + (self.device_id,) + ) + for row in rows.fetchall(): + device_ids = c.execute( + 'SELECT user_device_id FROM megolm_outbound_devices WHERE device_id=? ' + 'AND room_id=?', (self.device_id, row[0]) + ) + devices = {device_id[0] for device_id in device_ids} + max_age_s = row[2] + max_age = timedelta(seconds=max_age_s) + session = MegolmOutboundSession.from_pickle( + bytes(row[1]), devices, max_age, row[3], row[4], row[5], self.pickle_key) + sessions[row[0]] = session + c.close() + + def get_outbound_session(self, room_id, sessions=None): + """Gets a saved outbound Megolm session. + + Also loads the devices it is shared with. + + Args: + room_id (str): The room corresponding to the session. + sessions (dict): Optional. A map from room_id to a + :class:`.MegolmOutboundSession` object, to which the session will be + added. + + Returns: + :class:`.MegolmOutboundSession` object, or ``None`` if the session was + not found. + """ + c = self.conn.cursor() + c.execute( + 'SELECT session, max_age_s, max_messages, creation_time, message_count ' + 'FROM megolm_outbound_sessions WHERE device_id=? AND room_id=?', + (self.device_id, room_id) + ) + try: + row = c.fetchone() + session_data = bytes(row[0]) + except TypeError: + c.close() + return None + device_ids = c.execute( + 'SELECT user_device_id FROM megolm_outbound_devices WHERE device_id=? ' + 'AND room_id=?', (self.device_id, room_id) + ) + devices = {device_id[0] for device_id in device_ids} + c.close() + max_age_s = row[1] + max_age = timedelta(seconds=max_age_s) + session = MegolmOutboundSession.from_pickle( + session_data, devices, max_age, row[2], row[3], row[4], self.pickle_key) + if sessions is not None: + sessions[room_id] = session + return session + + def remove_outbound_session(self, room_id): + """Removes a saved outbound Megolm session. + + Args: + room_id (str): The room corresponding to the session. + """ + c = self.conn.cursor() + c.execute('DELETE FROM megolm_outbound_sessions WHERE device_id=? AND room_id=?', + (self.device_id, room_id)) + c.close() + self.conn.commit() + + def save_megolm_outbound_devices(self, room_id, device_ids): + """Saves devices an outbound Megolm session is shared with. + + Args: + room_id (str): The room corresponding to the session. + device_ids (iterable): A list of device ids. + """ + c = self.conn.cursor() + rows = [(self.device_id, room_id, device_id) for device_id in device_ids] + c.executemany( + 'INSERT OR IGNORE INTO megolm_outbound_devices VALUES (?,?,?)', rows) + c.close() + self.conn.commit() + def close(self): self.conn.close() diff --git a/matrix_client/crypto/megolm_outbound_session.py b/matrix_client/crypto/megolm_outbound_session.py index cff87f6a..d21b049f 100644 --- a/matrix_client/crypto/megolm_outbound_session.py +++ b/matrix_client/crypto/megolm_outbound_session.py @@ -16,6 +16,10 @@ class MegolmOutboundSession(OutboundGroupSession): max_messages (int): Optional. The maximum number of messages that should be sent. A new message in considered sent each time there is a call to ``encrypt``. Default to 100 if not present. + + Attributes: + creation_time (datetime.datetime): Creation time of the session. + message_count (int): Number of messages encrypted using the session. """ def __init__(self, max_age=None, max_messages=None): @@ -58,3 +62,14 @@ def should_rotate(self): def encrypt(self, plaintext): self.message_count += 1 return super(MegolmOutboundSession, self).encrypt(plaintext) + + @classmethod + def from_pickle(cls, pickle, devices, max_age, max_messages, creation_time, + message_count, passphrase=''): + session = super(MegolmOutboundSession, cls).from_pickle(pickle, passphrase) + session.devices = devices + session.max_age = max_age + session.max_messages = max_messages + session.creation_time = creation_time + session.message_count = message_count + return session diff --git a/matrix_client/crypto/olm_device.py b/matrix_client/crypto/olm_device.py index dc43a47e..0be39fe5 100644 --- a/matrix_client/crypto/olm_device.py +++ b/matrix_client/crypto/olm_device.py @@ -68,11 +68,13 @@ def __init__(self, self.db = Store(self.device_id, **conf) self.olm_sessions = defaultdict(list) self.megolm_inbound_sessions = defaultdict(lambda: defaultdict(dict)) + self.megolm_outbound_sessions = {} self.olm_account = self.db.get_olm_account() if self.olm_account: if load_all: self.db.load_olm_sessions(self.olm_sessions) self.db.load_inbound_sessions(self.megolm_inbound_sessions) + self.db.load_outbound_sessions(self.megolm_outbound_sessions) logger.info('Loaded Olm account from database for device %s.', device_id) else: self.olm_account = olm.Account() @@ -89,7 +91,6 @@ def __init__(self, keys_threshold) self.device_keys = defaultdict(dict) self.device_list = DeviceList(self, api, self.device_keys) - self.megolm_outbound_sessions = {} self.megolm_index_record = defaultdict(dict) def upload_identity_keys(self): @@ -449,6 +450,8 @@ def megolm_start_session(self, room): user_devices = {user.user_id: list(self.device_keys[user.user_id]) for user in users} self.device_list.get_room_device_keys(room) + self.db.remove_outbound_session(room.room_id) + self.db.save_outbound_session(room.room_id, session) self.megolm_share_session(room.room_id, user_devices, session) # Store a corresponding inbound session, so that we can decrypt our own messages self.megolm_add_inbound_session(room.room_id, self.identity_keys['curve25519'], @@ -492,6 +495,7 @@ def megolm_share_session(self, room_id, user_devices, session): new_devices.add(device_id) self.api.send_to_device('m.room.encrypted', messages) session.add_devices(new_devices) + self.db.save_megolm_outbound_devices(room_id, new_devices) def megolm_share_session_with_new_devices(self, room, session): """Share a megolm session with new devices in a room. @@ -529,7 +533,11 @@ def megolm_build_encrypted_event(self, room, event): room_id = room.room_id session = self.megolm_outbound_sessions.get(room_id) - if not session or session.should_rotate(): + if not session: + session = self.db.get_outbound_session(room_id, self.megolm_outbound_sessions) + if not session: + session = self.megolm_start_session(room) + if session.should_rotate(): session = self.megolm_start_session(room) else: self.megolm_share_session_with_new_devices(room, session) @@ -541,6 +549,7 @@ def megolm_build_encrypted_event(self, room, event): } encrypted_payload = session.encrypt(json.dumps(payload)) + self.db.save_outbound_session(room_id, session) encrypted_event = { 'algorithm': self._megolm_algorithm, @@ -561,6 +570,7 @@ def megolm_remove_outbound_session(self, room_id): """ try: self.megolm_outbound_sessions.pop(room_id) + self.db.remove_outbound_session(room_id) logger.info('Removed Meglom outbound session in %s.', room_id) except KeyError: pass From 0ed22c15fbf2283bb2377e258420b24f5051423f Mon Sep 17 00:00:00 2001 From: Valentin Deniaud Date: Mon, 2 Jul 2018 22:23:07 +0200 Subject: [PATCH 28/66] add outbound sessions persistence tests --- test/crypto/crypto_store_test.py | 59 +++++++++++++++++++++++++++++++- 1 file changed, 58 insertions(+), 1 deletion(-) diff --git a/test/crypto/crypto_store_test.py b/test/crypto/crypto_store_test.py index 3ae2b0f1..538932d0 100644 --- a/test/crypto/crypto_store_test.py +++ b/test/crypto/crypto_store_test.py @@ -7,6 +7,8 @@ from matrix_client.crypto.crypto_store import CryptoStore from matrix_client.crypto.olm_device import OlmDevice +from matrix_client.crypto.megolm_outbound_session import MegolmOutboundSession +from matrix_client.room import Room class TestCryptoStore(object): @@ -15,6 +17,7 @@ class TestCryptoStore(object): device_id = 'AUIETSRN' user_id = '@user:matrix.org' room_id = '!test:example.com' + room = Room(None, room_id) db_name = 'test.db' db_path = mkdtemp() store_conf = { @@ -156,14 +159,65 @@ def test_megolm_inbound_persistence(self, curve_key, device): self.store.remove_olm_account() assert not self.store.get_inbound_session(self.room_id, curve_key, session.id) + @pytest.mark.usefixtures('account') + def test_megolm_outbound_persistence(self, device): + session = MegolmOutboundSession(max_messages=2, max_age=100000) + session.message_count = 1 + session.add_device(self.device_id) + sessions = {} + + self.store.load_outbound_sessions(sessions) + assert not sessions + assert not self.store.get_outbound_session(self.room_id) + + self.store.save_outbound_session(self.room_id, session) + self.store.save_megolm_outbound_devices(self.room_id, {self.device_id}) + self.store.load_outbound_sessions(sessions) + assert sessions[self.room_id].id == session.id + assert sessions[self.room_id].devices == session.devices + assert sessions[self.room_id].creation_time == session.creation_time + assert sessions[self.room_id].max_messages == session.max_messages + assert sessions[self.room_id].message_count == session.message_count + assert sessions[self.room_id].max_age == session.max_age + + saved_session = self.store.get_outbound_session(self.room_id) + assert saved_session.id == session.id + assert saved_session.devices == session.devices + assert saved_session.creation_time == session.creation_time + assert saved_session.max_messages == session.max_messages + assert saved_session.message_count == session.message_count + assert saved_session.max_age == session.max_age + + sessions.clear() + saved_session = self.store.get_outbound_session(self.room_id, sessions) + assert sessions[self.room_id].id == session.id + + self.store.remove_outbound_session(self.room_id) + assert not self.store.get_outbound_session(self.room_id) + + self.store.save_outbound_session(self.room_id, session) + saved_session = self.store.get_outbound_session(self.room_id) + # Verify the saved devices have been erased with the session + assert not saved_session.devices + + with pytest.raises(AttributeError): + device.megolm_build_encrypted_event(self.room, {}) + assert device.megolm_outbound_sessions[self.room_id].id == session.id + + self.store.remove_olm_account() + assert not self.store.get_outbound_session(self.room_id) + def test_load_all(self, account, curve_key): curve_key = account.identity_keys['curve25519'] session = olm.OutboundSession(account, curve_key, curve_key) - out_session = olm.OutboundGroupSession() + out_session = MegolmOutboundSession() + out_session.add_device(self.device_id) in_session = olm.InboundGroupSession(out_session.session_key) self.store.save_inbound_session(self.room_id, curve_key, in_session) self.store.save_olm_session(curve_key, session) + self.store.save_outbound_session(self.room_id, out_session) + self.store.save_megolm_outbound_devices(self.room_id, {self.device_id}) device = OlmDevice( None, self.user_id, self.device_id, store_conf=self.store_conf, load_all=True) @@ -172,3 +226,6 @@ def test_load_all(self, account, curve_key): saved_in_session = \ device.megolm_inbound_sessions[self.room_id][curve_key][in_session.id] assert saved_in_session.id == in_session.id + saved_out_session = device.megolm_outbound_sessions[self.room_id] + assert saved_out_session.id == out_session.id + assert saved_out_session.devices == out_session.devices From 46405db0bb7663d0e3cf4ee53857e9b9f5661b0a Mon Sep 17 00:00:00 2001 From: Valentin Deniaud Date: Wed, 4 Jul 2018 15:48:58 +0200 Subject: [PATCH 29/66] add device keys persistence Since device tracking is done in a separate thread, we need to be careful not to use the same connection object between threads (in fact the problem existed since the first persistence commit when using MatrixClient.start_listener_thread). --- matrix_client/client.py | 5 + matrix_client/crypto/crypto_store.py | 167 ++++++++++++++++++++++++++- matrix_client/crypto/device_list.py | 46 +++++++- matrix_client/crypto/olm_device.py | 7 +- 4 files changed, 214 insertions(+), 11 deletions(-) diff --git a/matrix_client/client.py b/matrix_client/client.py index 6cc829b6..8f80048e 100644 --- a/matrix_client/client.py +++ b/matrix_client/client.py @@ -133,6 +133,7 @@ def __init__(self, base_url, token=None, user_id=None, self._encryption = encryption self.encryption_conf = encryption_conf or {} self.olm_device = None + self.first_sync = True if isinstance(cache_level, CACHE): self._cache_level = cache_level else: @@ -592,6 +593,10 @@ def _sync(self, timeout_ms=30000): self.sync_token = response["next_batch"] + if self._encryption and self.first_sync: + self.first_sync = False + self.olm_device.device_list.update_after_restart(self.sync_token) + for presence_update in response['presence']['events']: for callback in self.presence_listeners.values(): callback(presence_update) diff --git a/matrix_client/crypto/crypto_store.py b/matrix_client/crypto/crypto_store.py index 294413b0..b19d5092 100644 --- a/matrix_client/crypto/crypto_store.py +++ b/matrix_client/crypto/crypto_store.py @@ -1,7 +1,9 @@ import logging import os import sqlite3 +from collections import defaultdict from datetime import timedelta +from threading import current_thread import olm from appdirs import user_data_dir @@ -37,11 +39,16 @@ def __init__(self, os.makedirs(data_dir) except OSError: pass - self.conn = sqlite3.connect(os.path.join(data_dir, db_name), - detect_types=sqlite3.PARSE_DECLTYPES) + self.db_filepath = os.path.join(data_dir, db_name) + + # Map from a thread id to a connection object + self._conn = defaultdict(self.instanciate_connection) self.pickle_key = pickle_key self.create_tables_if_needed() + def instanciate_connection(self): + return sqlite3.connect(self.db_filepath, detect_types=sqlite3.PARSE_DECLTYPES) + def create_tables_if_needed(self): """Ensures all the tables exist.""" c = self.conn.cursor() @@ -69,6 +76,19 @@ def create_tables_if_needed(self): 'FOREIGN KEY(room_id) REFERENCES megolm_outbound_sessions(room_id) ' 'ON DELETE CASCADE,' 'FOREIGN KEY(device_id) REFERENCES accounts(device_id))') + c.execute('CREATE TABLE IF NOT EXISTS device_keys ' + '(device_id TEXT, user_id TEXT, user_device_id TEXT PRIMARY KEY,' + 'ed_key TEXT, curve_key TEXT,' + 'FOREIGN KEY(device_id) REFERENCES accounts(device_id) ' + 'ON DELETE CASCADE)') + c.execute('CREATE TABLE IF NOT EXISTS tracked_users ' + '(device_id TEXT, user_id TEXT, UNIQUE(device_id, user_id),' + 'FOREIGN KEY(device_id) REFERENCES accounts(device_id) ' + 'ON DELETE CASCADE)') + c.execute('CREATE TABLE IF NOT EXISTS sync_tokens ' + '(device_id TEXT PRIMARY KEY, token TEXT,' + 'FOREIGN KEY(device_id) REFERENCES accounts(device_id) ' + 'ON DELETE CASCADE)') c.close() self.conn.commit() @@ -355,5 +375,148 @@ def save_megolm_outbound_devices(self, room_id, device_ids): c.close() self.conn.commit() + def save_device_keys(self, device_keys): + """Saves device keys. + + Args: + device_keys (defaultdict(dict)): The format is ``{: {: + {'curve25519': , 'ed25519': }``. + """ + c = self.conn.cursor() + rows = [] + for user_id, devices_dict in device_keys.items(): + for device_id, keys_dict in devices_dict.items(): + rows.append((self.device_id, user_id, device_id, keys_dict['ed25519'], + keys_dict['curve25519'])) + c.executemany('REPLACE INTO device_keys VALUES (?,?,?,?,?)', rows) + c.close() + self.conn.commit() + + def load_device_keys(self, device_keys): + """Loads all saved device keys. + + Args: + device_keys (defaultdict(dict)): An object which will get populated with + the keys. The format is ``{: {: + {'curve25519': , 'ed25519': }``. + """ + c = self.conn.cursor() + rows = c.execute( + 'SELECT user_id, user_device_id, ed_key, curve_key FROM device_keys ' + 'WHERE device_id=?', (self.device_id,) + ) + for row in rows: + device_keys[row[0]][row[1]] = { + 'ed25519': row[2], + 'curve25519': row[3] + } + c.close() + + def get_device_keys(self, user_devices, device_keys=None): + """Gets the devices keys of the specified devices. + + Args: + user_devices (dict): A map from user ids to a list of device ids. + device_keys (defaultdict(dict)): Optional. Will be updated with + the retrieved keys. + + Returns: + A defaultdict(dict) containing the keys. + """ + c = self.conn.cursor() + rows = [] + for user_id in user_devices: + if not user_devices[user_id]: + c.execute( + 'SELECT user_id, user_device_id, ed_key, curve_key FROM device_keys ' + 'WHERE device_id=? AND user_id=?', (self.device_id, user_id) + ) + rows.extend(c.fetchall()) + else: + for device_id in user_devices[user_id]: + c.execute( + 'SELECT user_id, user_device_id, ed_key, curve_key FROM ' + 'device_keys WHERE device_id=? AND user_id=? AND ' + 'user_device_id=?', (self.device_id, user_id, device_id) + ) + rows.extend(c.fetchall()) + c.close() + result = defaultdict(dict) + for row in rows: + result[row[0]][row[1]] = { + 'ed25519': row[2], + 'curve25519': row[3] + } + if device_keys is not None and result: + device_keys.update(result) + return result + + def save_tracked_users(self, user_ids): + """Saves tracked users. + + Args: + user_ids (iterable): The user ids to save. + """ + c = self.conn.cursor() + rows = [(self.device_id, user_id) for user_id in user_ids] + c.executemany('INSERT OR IGNORE INTO tracked_users VALUES (?,?)', rows) + c.close() + self.conn.commit() + + def remove_tracked_users(self, user_ids): + """Removes tracked users. + + Args: + user_ids (iterable): The user ids to remove. + """ + c = self.conn.cursor() + rows = [(user_id,) for user_id in user_ids] + c.executemany('DELETE FROM tracked_users WHERE user_id=?', rows) + c.close() + self.conn.commit() + + def load_tracked_users(self, tracked_users): + """Loads all tracked users. + + Args: + tracked_users (set): Will be populated with user ids. + """ + c = self.conn.cursor() + rows = c.execute( + 'SELECT user_id FROM tracked_users WHERE device_id=?', (self.device_id,)) + tracked_users.update(t[0] for t in rows) + c.close() + return tracked_users + + def save_sync_token(self, sync_token): + """Saves a sync token. + + Args: + sync_token (str): The token to save. + """ + c = self.conn.cursor() + c.execute('REPLACE INTO sync_tokens VALUES (?,?)', (self.device_id, sync_token)) + c.close() + self.conn.commit() + + def get_sync_token(self): + """Gets the saved sync token. + + Returns: + A string corresponding to the token, or None if there wasn't any. + """ + c = self.conn.cursor() + c.execute('SELECT token FROM sync_tokens WHERE device_id=?', (self.device_id,)) + try: + return c.fetchone()[0] + except TypeError: + return None + finally: + c.close() + def close(self): self.conn.close() + + @property + def conn(self): + return self._conn[current_thread().ident] diff --git a/matrix_client/crypto/device_list.py b/matrix_client/crypto/device_list.py index c0bf406a..c3872f83 100644 --- a/matrix_client/crypto/device_list.py +++ b/matrix_client/crypto/device_list.py @@ -19,7 +19,7 @@ class DeviceList: device_keys (defaultdict(dict)): A map from user to device to keys. """ - def __init__(self, olm_device, api, device_keys): + def __init__(self, olm_device, api, device_keys, db): self.olm_device = olm_device self.api = api self.device_keys = device_keys @@ -35,9 +35,11 @@ def __init__(self, olm_device, api, device_keys): # Allows to wake up the thread when there are new users to update, and to # synchronise shared data. self.thread_condition = Condition() + self.db = db + self.db.load_tracked_users(self.tracked_user_ids) self.update_thread = _UpdateDeviceList( self.thread_condition, self.outdated_user_ids, self._download_device_keys, - self.tracked_user_ids + self.tracked_user_ids, db ) self.update_thread.start() @@ -53,7 +55,11 @@ def get_room_device_keys(self, room, blocking=True): downloaded before returning. """ logger.info('Fetching all missing keys in room %s.', room.room_id) - user_ids = {u.user_id for u in room.get_joined_members()} - self.tracked_user_ids + members = {m.user_id for m in room.get_joined_members()} + missing_members = {m: [] for m in members if not self.device_keys[m]} + if missing_members: + self.db.get_device_keys(missing_members, self.device_keys) + user_ids = members - self.tracked_user_ids if not user_ids: logger.info('Already had all the keys in room %s.', room.room_id) if blocking: @@ -102,6 +108,19 @@ def track_users(self, user_ids): if user_ids: self._add_outdated_users(user_ids) + def update_after_restart(self, to_token): + from_token = self.db.get_sync_token() + if not from_token: + # First launch. Persist this token in case we would not have the occasion to + # save one this session. + self.db.save_sync_token(to_token) + return + resp = self.api.key_changes(from_token, to_token) + if resp.get('left'): + self.stop_tracking_users(resp['left']) + if resp.get('changed'): + self.update_user_device_keys(resp['changed']) + def stop_tracking_users(self, user_ids): """Stop tracking users. @@ -113,6 +132,7 @@ def stop_tracking_users(self, user_ids): with self.thread_condition: self.tracked_user_ids.difference_update(user_ids) self.outdated_user_ids.difference_update(user_ids) + self.db.remove_tracked_users(user_ids) logger.info('Stopped tracking users: %s.', user_ids) def update_user_device_keys(self, user_ids, since_token=None): @@ -239,7 +259,7 @@ def sync_token(self, token): class _UpdateDeviceList(Thread): - def __init__(self, cond, user_ids, download_method, tracked_user_ids): + def __init__(self, cond, user_ids, download_method, tracked_user_ids, db): # We wait on this condition when there is nothing to do. Outside code should use # it to notify us when they add data to be processed in outdated_user_ids so that # we can wake up and process it. @@ -250,6 +270,10 @@ def __init__(self, cond, user_ids, download_method, tracked_user_ids): # Cleared when we start a download, and set when we have finished it. This can be # used by outside code in order to know if we are in the middle of a download, and # allows to wait for it to complete by waiting on this event. + self.db = db + # Cleared when we start a download, and set when we have finished it. This can be + # used by outside code in order to know if we are in the middle of a download, and + # allows to wait for it to complete by waiting on this event. self.event = Event() # Used internally to terminate gracefully on program exit. self._should_terminate = Event() @@ -270,18 +294,28 @@ def run(self): to_download = self.outdated_user_ids.copy() self.outdated_user_ids.clear() self.event.clear() - self.tracked_user_ids.update(to_download) + new_user_ids = to_download.difference(self.tracked_user_ids) + if new_user_ids: + self.tracked_user_ids.update(new_user_ids) payload = {user_id: [] for user_id in to_download} logger.info('Downloading device keys for users: %s.', to_download) try: - self.download(payload, self.outdated_user_ids.sync_token) + changed = self.download(payload, self.outdated_user_ids.sync_token) self.event.set() to_download.mark_as_processed() + if changed: + self.db.save_device_keys(changed) + if new_user_ids: + self.db.save_tracked_users(new_user_ids) + if self.outdated_user_ids.sync_token: + # FIXME this should be next_batch instead of since + self.db.save_sync_token(self.outdated_user_ids.sync_token) except (MatrixHttpLibError, MatrixRequestError) as e: logger.warning('Network error when fetching device keys (will retry): %s', e) with self.cond: self.outdated_user_ids.update(to_download) + self.tracked_user_ids.difference_update(new_user_ids) def join(self, timeout=None): # If we are joined, this means that the main program is terminating. diff --git a/matrix_client/crypto/olm_device.py b/matrix_client/crypto/olm_device.py index 0be39fe5..47e00c54 100644 --- a/matrix_client/crypto/olm_device.py +++ b/matrix_client/crypto/olm_device.py @@ -69,12 +69,14 @@ def __init__(self, self.olm_sessions = defaultdict(list) self.megolm_inbound_sessions = defaultdict(lambda: defaultdict(dict)) self.megolm_outbound_sessions = {} + self.device_keys = defaultdict(dict) self.olm_account = self.db.get_olm_account() if self.olm_account: if load_all: self.db.load_olm_sessions(self.olm_sessions) self.db.load_inbound_sessions(self.megolm_inbound_sessions) self.db.load_outbound_sessions(self.megolm_outbound_sessions) + self.db.load_device_keys(self.device_keys) logger.info('Loaded Olm account from database for device %s.', device_id) else: self.olm_account = olm.Account() @@ -89,8 +91,7 @@ def __init__(self, self.one_time_keys_manager = OneTimeKeysManager(target_keys_number, signed_keys_proportion, keys_threshold) - self.device_keys = defaultdict(dict) - self.device_list = DeviceList(self, api, self.device_keys) + self.device_list = DeviceList(self, api, self.device_keys, self.db) self.megolm_index_record = defaultdict(dict) def upload_identity_keys(self): @@ -447,9 +448,9 @@ def megolm_start_session(self, room): session.id, room.room_id) users = room.get_joined_members() + self.device_list.get_room_device_keys(room) user_devices = {user.user_id: list(self.device_keys[user.user_id]) for user in users} - self.device_list.get_room_device_keys(room) self.db.remove_outbound_session(room.room_id) self.db.save_outbound_session(room.room_id, session) self.megolm_share_session(room.room_id, user_devices, session) From 64134d7fb0c32d8e48b5a3ced3a4f5e2559ed6fb Mon Sep 17 00:00:00 2001 From: Valentin Deniaud Date: Wed, 4 Jul 2018 15:48:49 +0200 Subject: [PATCH 30/66] add device keys persistence tests --- test/client_test.py | 1 + test/crypto/crypto_store_test.py | 86 +++++++++++++++++++++++++++++++- test/crypto/device_list_test.py | 34 +++++++++++-- 3 files changed, 116 insertions(+), 5 deletions(-) diff --git a/test/client_test.py b/test/client_test.py index 7cde5f74..472ad195 100644 --- a/test/client_test.py +++ b/test/client_test.py @@ -571,6 +571,7 @@ def test_detect_encryption_state(): @responses.activate def test_one_time_keys_sync(): client = MatrixClient(HOSTNAME, encryption=True) + client.first_sync = False sync_url = HOSTNAME + MATRIX_V2_API_PATH + "/sync" sync_response = deepcopy(response_examples.example_sync) payload = {'dummy': 1} diff --git a/test/crypto/crypto_store_test.py b/test/crypto/crypto_store_test.py index 538932d0..83a07b6e 100644 --- a/test/crypto/crypto_store_test.py +++ b/test/crypto/crypto_store_test.py @@ -9,6 +9,7 @@ from matrix_client.crypto.olm_device import OlmDevice from matrix_client.crypto.megolm_outbound_session import MegolmOutboundSession from matrix_client.room import Room +from matrix_client.user import User class TestCryptoStore(object): @@ -18,6 +19,8 @@ class TestCryptoStore(object): user_id = '@user:matrix.org' room_id = '!test:example.com' room = Room(None, room_id) + user = User(None, user_id, '') + room._members[user_id] = user db_name = 'test.db' db_path = mkdtemp() store_conf = { @@ -200,24 +203,104 @@ def test_megolm_outbound_persistence(self, device): # Verify the saved devices have been erased with the session assert not saved_session.devices - with pytest.raises(AttributeError): + with pytest.raises(KeyError): device.megolm_build_encrypted_event(self.room, {}) assert device.megolm_outbound_sessions[self.room_id].id == session.id self.store.remove_olm_account() assert not self.store.get_outbound_session(self.room_id) + @pytest.mark.usefixtures('account') + def test_device_keys_persistence(self, device): + user_devices = {self.user_id: [self.device_id]} + keys = { + 'curve25519': 'curve', + 'ed25519': 'ed' + } + device_keys = defaultdict(dict) + + self.store.load_device_keys(device_keys) + assert not device_keys + assert not self.store.get_device_keys(user_devices, device_keys) + assert not device_keys + + device_keys_to_save = {self.user_id: {self.device_id: keys}} + self.store.save_device_keys(device_keys_to_save) + self.store.load_device_keys(device_keys) + assert device_keys == device_keys_to_save + + device_keys.clear() + assert self.store.get_device_keys(user_devices) == device_keys_to_save + assert self.store.get_device_keys(user_devices, device_keys) + assert device_keys == device_keys_to_save + + # Test [] wildcard + assert self.store.get_device_keys({self.user_id: []}) == device_keys_to_save + + device.device_list.tracked_user_ids = {self.user_id} + device.device_list.get_room_device_keys(self.room) + assert device.device_keys == device_keys_to_save + + # Test multiples [] + device_keys.clear() + user_id = 'test' + device_id = 'test' + device_keys_to_save[user_id] = {device_id: keys} + self.store.save_device_keys(device_keys_to_save) + user_devices[user_id] = [] + user_devices[self.user_id] = [] + assert self.store.get_device_keys(user_devices) == device_keys_to_save + + self.store.remove_olm_account() + assert not self.store.get_device_keys(user_devices) + + @pytest.mark.usefixtures('account') + def test_tracked_users_persistence(self): + tracked_user_ids = set() + tracked_user_ids_to_save = {self.user_id} + + self.store.load_tracked_users(tracked_user_ids) + assert not tracked_user_ids + + self.store.save_tracked_users(tracked_user_ids_to_save) + self.store.load_tracked_users(tracked_user_ids) + assert tracked_user_ids == tracked_user_ids_to_save + + self.store.remove_tracked_users({self.user_id}) + tracked_user_ids.clear() + self.store.load_tracked_users(tracked_user_ids) + assert not tracked_user_ids + + @pytest.mark.usefixtures('account') + def test_sync_token_persistence(self): + sync_token = 'test' + + assert not self.store.get_sync_token() + + self.store.save_sync_token(sync_token) + assert self.store.get_sync_token() == sync_token + + sync_token = 'new' + self.store.save_sync_token(sync_token) + assert self.store.get_sync_token() == sync_token + def test_load_all(self, account, curve_key): curve_key = account.identity_keys['curve25519'] session = olm.OutboundSession(account, curve_key, curve_key) out_session = MegolmOutboundSession() out_session.add_device(self.device_id) in_session = olm.InboundGroupSession(out_session.session_key) + keys = { + 'curve25519': 'curve', + 'ed25519': 'ed' + } + device_keys_to_save = {self.user_id: {self.device_id: keys}} self.store.save_inbound_session(self.room_id, curve_key, in_session) self.store.save_olm_session(curve_key, session) self.store.save_outbound_session(self.room_id, out_session) self.store.save_megolm_outbound_devices(self.room_id, {self.device_id}) + self.store.save_device_keys(device_keys_to_save) device = OlmDevice( None, self.user_id, self.device_id, store_conf=self.store_conf, load_all=True) @@ -229,3 +312,4 @@ def test_load_all(self, account, curve_key): saved_out_session = device.megolm_outbound_sessions[self.room_id] assert saved_out_session.id == out_session.id assert saved_out_session.devices == out_session.devices + assert device.device_keys == device_keys_to_save diff --git a/test/crypto/device_list_test.py b/test/crypto/device_list_test.py index 71d321c9..28222fab 100644 --- a/test/crypto/device_list_test.py +++ b/test/crypto/device_list_test.py @@ -13,7 +13,7 @@ from matrix_client.errors import MatrixRequestError from matrix_client.crypto.device_list import (_OutdatedUsersSet as OutdatedUsersSet, _UpdateDeviceList as UpdateDeviceList) -from test.crypto.dummy_olm_device import OlmDevice +from test.crypto.dummy_olm_device import OlmDevice, DummyStore from test.response_examples import example_key_query_response HOSTNAME = 'http://example.com' @@ -131,7 +131,8 @@ def test_update_thread(self): def dummy_download(user_devices, since_token=None): assert user_devices == {self.user_id: []} return - thread = UpdateDeviceList(Condition(), outdated_users, dummy_download, set()) + thread = UpdateDeviceList(Condition(), outdated_users, dummy_download, set(), + DummyStore()) thread.start() event.wait() @@ -151,7 +152,7 @@ def error_on_first_download(user_devices, since_token=None): return error_on_first_download.c = 0 thread = UpdateDeviceList( - Condition(), outdated_users, error_on_first_download, set()) + Condition(), outdated_users, error_on_first_download, set(), DummyStore()) thread.start() thread.event.wait() assert error_on_first_download.c == 2 @@ -160,7 +161,7 @@ def error_on_first_download(user_devices, since_token=None): # Cover a missing branch thread = UpdateDeviceList( - Condition(), outdated_users, error_on_first_download, set()) + Condition(), outdated_users, error_on_first_download, set(), DummyStore()) thread._should_terminate.set() thread.start() thread.join() @@ -243,6 +244,31 @@ def test_update_user_device_keys(self): self.device_list.update_thread.event.wait() assert len(responses.calls) == 1 + @responses.activate + def test_update_after_restart(self): + keys_changes_url = HOSTNAME + MATRIX_V2_API_PATH + '/keys/changes' + + class DB(DummyStore): + + def __getattribute__(self, name): + if name == 'get_sync_token': + return lambda: 'test' + return super(DB, self).__getattribute__(name) + db = self.device_list.db + + # First launch, no sync token + self.device_list.update_after_restart('test') + + self.device_list.db = DB() + responses.add(responses.GET, keys_changes_url, json={}) + self.device_list.update_after_restart('test') + + resp = {'left': 'test', 'changed': self.user_id} + responses.replace(responses.GET, keys_changes_url, json=resp) + self.device_list.tracked_user_ids.clear() + self.device_list.update_after_restart('test') + self.device_list.db = db + def test_outdated_users_set(): s = OutdatedUsersSet() From 1804f7a1657eaaa87b746fa34685805997fe5ed5 Mon Sep 17 00:00:00 2001 From: Valentin Deniaud Date: Wed, 4 Jul 2018 17:09:04 +0200 Subject: [PATCH 31/66] ignore optional dependency appdirs when building doc --- docs/source/conf.py | 2 +- matrix_client/crypto/crypto_store.py | 22 +++++++--------------- 2 files changed, 8 insertions(+), 16 deletions(-) diff --git a/docs/source/conf.py b/docs/source/conf.py index e3a76c0d..8e63b535 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -96,4 +96,4 @@ 'Miscellaneous'), ] -autodoc_mock_imports = ["olm", "canonicaljson"] +autodoc_mock_imports = ["olm", "canonicaljson", "appdirs"] diff --git a/matrix_client/crypto/crypto_store.py b/matrix_client/crypto/crypto_store.py index b19d5092..ba2abbd3 100644 --- a/matrix_client/crypto/crypto_store.py +++ b/matrix_client/crypto/crypto_store.py @@ -222,8 +222,7 @@ def load_inbound_sessions(self, sessions): """ c = self.conn.cursor() rows = c.execute( - 'SELECT room_id, curve_key, session FROM megolm_inbound_sessions WHERE ' - 'device_id=?', (self.device_id,) + 'SELECT * FROM megolm_inbound_sessions WHERE device_id=?', (self.device_id,) ) for row in rows: session = olm.InboundGroupSession.from_pickle(bytes(row[2]), self.pickle_key) @@ -291,10 +290,7 @@ def load_outbound_sessions(self, sessions): """ c = self.conn.cursor() rows = c.execute( - 'SELECT room_id, session, max_age_s, max_messages, creation_time,' - 'message_count FROM megolm_outbound_sessions WHERE device_id=?', - (self.device_id,) - ) + 'SELECT * FROM megolm_outbound_sessions WHERE device_id=?', (self.device_id,)) for row in rows.fetchall(): device_ids = c.execute( 'SELECT user_device_id FROM megolm_outbound_devices WHERE device_id=? ' @@ -325,8 +321,7 @@ def get_outbound_session(self, room_id, sessions=None): """ c = self.conn.cursor() c.execute( - 'SELECT session, max_age_s, max_messages, creation_time, message_count ' - 'FROM megolm_outbound_sessions WHERE device_id=? AND room_id=?', + 'SELECT * FROM megolm_outbound_sessions WHERE device_id=? AND room_id=?', (self.device_id, room_id) ) try: @@ -402,9 +397,7 @@ def load_device_keys(self, device_keys): """ c = self.conn.cursor() rows = c.execute( - 'SELECT user_id, user_device_id, ed_key, curve_key FROM device_keys ' - 'WHERE device_id=?', (self.device_id,) - ) + 'SELECT * FROM device_keys WHERE device_id=?', (self.device_id,)) for row in rows: device_keys[row[0]][row[1]] = { 'ed25519': row[2], @@ -428,15 +421,14 @@ def get_device_keys(self, user_devices, device_keys=None): for user_id in user_devices: if not user_devices[user_id]: c.execute( - 'SELECT user_id, user_device_id, ed_key, curve_key FROM device_keys ' - 'WHERE device_id=? AND user_id=?', (self.device_id, user_id) + 'SELECT * FROM device_keys WHERE device_id=? AND user_id=?', + (self.device_id, user_id) ) rows.extend(c.fetchall()) else: for device_id in user_devices[user_id]: c.execute( - 'SELECT user_id, user_device_id, ed_key, curve_key FROM ' - 'device_keys WHERE device_id=? AND user_id=? AND ' + 'SELECT * FROM device_keys WHERE device_id=? AND user_id=? AND ' 'user_device_id=?', (self.device_id, user_id, device_id) ) rows.extend(c.fetchall()) From 41e35d12a66468fc75eeaa8d0b3ab2a523633d74 Mon Sep 17 00:00:00 2001 From: Valentin Deniaud Date: Sat, 21 Jul 2018 16:40:24 +0200 Subject: [PATCH 32/66] general improvement to CryptoStore Nicer sqlite practices and better docstrings. --- matrix_client/crypto/crypto_store.py | 158 ++++++++++++++------------- 1 file changed, 85 insertions(+), 73 deletions(-) diff --git a/matrix_client/crypto/crypto_store.py b/matrix_client/crypto/crypto_store.py index ba2abbd3..349f875e 100644 --- a/matrix_client/crypto/crypto_store.py +++ b/matrix_client/crypto/crypto_store.py @@ -47,48 +47,51 @@ def __init__(self, self.create_tables_if_needed() def instanciate_connection(self): - return sqlite3.connect(self.db_filepath, detect_types=sqlite3.PARSE_DECLTYPES) + con = sqlite3.connect(self.db_filepath, detect_types=sqlite3.PARSE_DECLTYPES) + con.row_factory = sqlite3.Row + return con def create_tables_if_needed(self): """Ensures all the tables exist.""" c = self.conn.cursor() - c.execute('PRAGMA foreign_keys = ON') - c.execute('CREATE TABLE IF NOT EXISTS accounts (device_id TEXT PRIMARY KEY,' - 'account BLOB)') - c.execute('CREATE TABLE IF NOT EXISTS olm_sessions (device_id TEXT,' - 'session_id TEXT PRIMARY KEY, curve_key TEXT, session BLOB,' - 'FOREIGN KEY(device_id) REFERENCES accounts(device_id) ' - 'ON DELETE CASCADE)') - c.execute('CREATE TABLE IF NOT EXISTS megolm_inbound_sessions ' - '(device_id TEXT, session_id TEXT PRIMARY KEY, room_id TEXT,' - 'curve_key TEXT, session BLOB,' - 'FOREIGN KEY(device_id) REFERENCES accounts(device_id) ' - 'ON DELETE CASCADE)') - c.execute('CREATE TABLE IF NOT EXISTS megolm_outbound_sessions ' - '(device_id TEXT, room_id TEXT PRIMARY KEY, session BLOB,' - 'max_age_s FLOAT, max_messages INTEGER, creation_time TIMESTAMP,' - 'message_count INTEGER,' - 'FOREIGN KEY(device_id) REFERENCES accounts(device_id) ' - 'ON DELETE CASCADE)') - c.execute('CREATE TABLE IF NOT EXISTS megolm_outbound_devices ' - '(device_id TEXT, room_id TEXT, user_device_id TEXT,' - 'UNIQUE(device_id, room_id, user_device_id),' - 'FOREIGN KEY(room_id) REFERENCES megolm_outbound_sessions(room_id) ' - 'ON DELETE CASCADE,' - 'FOREIGN KEY(device_id) REFERENCES accounts(device_id))') - c.execute('CREATE TABLE IF NOT EXISTS device_keys ' - '(device_id TEXT, user_id TEXT, user_device_id TEXT PRIMARY KEY,' - 'ed_key TEXT, curve_key TEXT,' - 'FOREIGN KEY(device_id) REFERENCES accounts(device_id) ' - 'ON DELETE CASCADE)') - c.execute('CREATE TABLE IF NOT EXISTS tracked_users ' - '(device_id TEXT, user_id TEXT, UNIQUE(device_id, user_id),' - 'FOREIGN KEY(device_id) REFERENCES accounts(device_id) ' - 'ON DELETE CASCADE)') - c.execute('CREATE TABLE IF NOT EXISTS sync_tokens ' - '(device_id TEXT PRIMARY KEY, token TEXT,' - 'FOREIGN KEY(device_id) REFERENCES accounts(device_id) ' - 'ON DELETE CASCADE)') + c.executescript(""" +PRAGMA foreign_keys = ON; +CREATE TABLE IF NOT EXISTS accounts (device_id TEXT PRIMARY KEY NOT NULL, account BLOB); +CREATE TABLE IF NOT EXISTS olm_sessions( + device_id TEXT, session_id TEXT PRIMARY KEY, curve_key TEXT, session BLOB, + FOREIGN KEY(device_id) REFERENCES accounts(device_id) ON DELETE CASCADE +); +CREATE TABLE IF NOT EXISTS megolm_inbound_sessions( + device_id TEXT, session_id TEXT PRIMARY KEY, room_id TEXT, curve_key TEXT, + session BLOB, + FOREIGN KEY(device_id) REFERENCES accounts(device_id) ON DELETE CASCADE +); +CREATE TABLE IF NOT EXISTS megolm_outbound_sessions( + device_id TEXT, room_id TEXT PRIMARY KEY, session BLOB, max_age_s FLOAT, + max_messages INTEGER, creation_time TIMESTAMP, message_count INTEGER, + FOREIGN KEY(device_id) REFERENCES accounts(device_id) ON DELETE CASCADE +); +CREATE TABLE IF NOT EXISTS megolm_outbound_devices( + device_id TEXT, room_id TEXT, user_device_id TEXT, + UNIQUE(device_id, room_id, user_device_id), + FOREIGN KEY(room_id) REFERENCES megolm_outbound_sessions(room_id) ON DELETE CASCADE, + FOREIGN KEY(device_id) REFERENCES accounts(device_id) +); +CREATE TABLE IF NOT EXISTS device_keys( + device_id TEXT, user_id TEXT, user_device_id TEXT PRIMARY KEY, ed_key TEXT, + curve_key TEXT, + FOREIGN KEY(device_id) REFERENCES accounts(device_id) ON DELETE CASCADE +); +CREATE TABLE IF NOT EXISTS tracked_users( + device_id TEXT, user_id TEXT, + UNIQUE(device_id, user_id), + FOREIGN KEY(device_id) REFERENCES accounts(device_id) ON DELETE CASCADE +); +CREATE TABLE IF NOT EXISTS sync_tokens( + device_id TEXT PRIMARY KEY, token TEXT, + FOREIGN KEY(device_id) REFERENCES accounts(device_id) ON DELETE CASCADE +); + """) c.close() self.conn.commit() @@ -111,13 +114,14 @@ def get_olm_account(self): """Gets the Olm account. Returns: - olm.Account object, or None if it wasn't found for the current device_id. + ``olm.Account`` object, or ``None`` if it wasn't found for the current + device_id. """ c = self.conn.cursor() c.execute( 'SELECT account FROM accounts WHERE device_id=?', (self.device_id,)) try: - account_data = c.fetchone()[0] + account_data = c.fetchone()['account'] # sqlite gives us unicode in Python2, we want bytes account_data = bytes(account_data) except TypeError: @@ -144,7 +148,7 @@ def save_olm_sessions(self, sessions): Args: sessions (defaultdict(list)): A map from curve25519 keys to a list of - olm.Session objects. + ``olm.Session`` objects. """ c = self.conn.cursor() rows = [(self.device_id, s.id, key, s.pickle(self.pickle_key)) @@ -158,14 +162,14 @@ def load_olm_sessions(self, sessions): Args: sessions (defaultdict(list)): A map from curve25519 keys to a list of - olm.Session objects, which will be populated. + ``olm.Session`` objects, which will be populated. """ c = self.conn.cursor() rows = c.execute('SELECT curve_key, session FROM olm_sessions WHERE device_id=?', (self.device_id,)) for row in rows: - session = olm.Session.from_pickle(bytes(row[1]), self.pickle_key) - sessions[row[0]].append(session) + session = olm.Session.from_pickle(bytes(row['session']), self.pickle_key) + sessions[row['curve_key']].append(session) c.close() def get_olm_sessions(self, curve_key, sessions_dict=None): @@ -174,10 +178,10 @@ def get_olm_sessions(self, curve_key, sessions_dict=None): Args: curve_key (str): The curve25519 key of the device. sessions_dict (defaultdict(list)): Optional. A map from curve25519 keys to a - list of olm.Session objects, to which the session list will be added. + list of ``olm.Session`` objects, to which the session list will be added. Returns: - A list of olm.Session objects, or None if none were found. + A list of ``olm.Session`` objects, or ``None`` if none were found. NOTE: When overriding this, be careful to append the retrieved sessions to the @@ -188,7 +192,7 @@ def get_olm_sessions(self, curve_key, sessions_dict=None): 'SELECT session FROM olm_sessions WHERE device_id=? AND curve_key=?', (self.device_id, curve_key) ) - sessions = [olm.Session.from_pickle(bytes(row[0]), self.pickle_key) + sessions = [olm.Session.from_pickle(bytes(row['session']), self.pickle_key) for row in rows] if sessions_dict is not None: sessions_dict[curve_key].extend(sessions) @@ -225,8 +229,9 @@ def load_inbound_sessions(self, sessions): 'SELECT * FROM megolm_inbound_sessions WHERE device_id=?', (self.device_id,) ) for row in rows: - session = olm.InboundGroupSession.from_pickle(bytes(row[2]), self.pickle_key) - sessions[row[0]][row[1]][session.id] = session + session = olm.InboundGroupSession.from_pickle( + bytes(row['session']), self.pickle_key) + sessions[row['room_id']][row['curve_key']][session.id] = session c.close() def get_inbound_session(self, room_id, curve_key, session_id, sessions=None): @@ -236,11 +241,11 @@ def get_inbound_session(self, room_id, curve_key, session_id, sessions=None): room_id (str): The room corresponding to the session. curve_key (str): The curve25519 key of the device. session_id (str): The id of the session. - sessions (dict): Optional. A map from session id to olm.InboundGroupSession - object, to which the session will be added. + sessions (dict): Optional. A map from session id to + ``olm.InboundGroupSession`` object, to which the session will be added. Returns: - olm.InboundGroupSession object, or None if the session was not found. + ``olm.InboundGroupSession`` object, or ``None`` if the session was not found. """ c = self.conn.cursor() c.execute( @@ -249,7 +254,7 @@ def get_inbound_session(self, room_id, curve_key, session_id, sessions=None): (self.device_id, room_id, curve_key, session_id) ) try: - session_data = c.fetchone()[0] + session_data = c.fetchone()['session'] session_data = bytes(session_data) except TypeError: return None @@ -285,8 +290,8 @@ def load_outbound_sessions(self, sessions): Also loads the devices each are shared with. Args: - sessions (dict): A map from room_id to a :class:`.MegolmOutboundSession` - object, which will be populated. + sessions (dict): A map from room_id to a ``MegolmOutboundSession`` object, + which will be populated. """ c = self.conn.cursor() rows = c.execute( @@ -294,14 +299,16 @@ def load_outbound_sessions(self, sessions): for row in rows.fetchall(): device_ids = c.execute( 'SELECT user_device_id FROM megolm_outbound_devices WHERE device_id=? ' - 'AND room_id=?', (self.device_id, row[0]) + 'AND room_id=?', (self.device_id, row['room_id']) ) devices = {device_id[0] for device_id in device_ids} - max_age_s = row[2] + max_age_s = row['max_age_s'] max_age = timedelta(seconds=max_age_s) session = MegolmOutboundSession.from_pickle( - bytes(row[1]), devices, max_age, row[3], row[4], row[5], self.pickle_key) - sessions[row[0]] = session + bytes(row['session']), devices, max_age, row['max_messages'], + row['creation_time'], row['message_count'], self.pickle_key + ) + sessions[row['room_id']] = session c.close() def get_outbound_session(self, room_id, sessions=None): @@ -326,7 +333,7 @@ def get_outbound_session(self, room_id, sessions=None): ) try: row = c.fetchone() - session_data = bytes(row[0]) + session_data = bytes(row['session']) except TypeError: c.close() return None @@ -336,10 +343,12 @@ def get_outbound_session(self, room_id, sessions=None): ) devices = {device_id[0] for device_id in device_ids} c.close() - max_age_s = row[1] + max_age_s = row['max_age_s'] max_age = timedelta(seconds=max_age_s) session = MegolmOutboundSession.from_pickle( - session_data, devices, max_age, row[2], row[3], row[4], self.pickle_key) + session_data, devices, max_age, row['max_messages'], row['creation_time'], + row['message_count'], self.pickle_key + ) if sessions is not None: sessions[room_id] = session return session @@ -399,9 +408,9 @@ def load_device_keys(self, device_keys): rows = c.execute( 'SELECT * FROM device_keys WHERE device_id=?', (self.device_id,)) for row in rows: - device_keys[row[0]][row[1]] = { - 'ed25519': row[2], - 'curve25519': row[3] + device_keys[row['user_id']][row['user_device_id']] = { + 'ed25519': row['ed_key'], + 'curve25519': row['curve_key'] } c.close() @@ -410,11 +419,14 @@ def get_device_keys(self, user_devices, device_keys=None): Args: user_devices (dict): A map from user ids to a list of device ids. + If no device ids are given for a user, all will be retrieved. device_keys (defaultdict(dict)): Optional. Will be updated with - the retrieved keys. + the retrieved keys. The format is ``{: {: + {'curve25519': , 'ed25519': }``. Returns: - A defaultdict(dict) containing the keys. + A ``defaultdict(dict)`` containing the keys, the format is the same as the + ``device_keys`` argument. """ c = self.conn.cursor() rows = [] @@ -435,9 +447,9 @@ def get_device_keys(self, user_devices, device_keys=None): c.close() result = defaultdict(dict) for row in rows: - result[row[0]][row[1]] = { - 'ed25519': row[2], - 'curve25519': row[3] + result[row['user_id']][row['user_device_id']] = { + 'ed25519': row['ed_key'], + 'curve25519': row['curve_key'] } if device_keys is not None and result: device_keys.update(result) @@ -476,7 +488,7 @@ def load_tracked_users(self, tracked_users): c = self.conn.cursor() rows = c.execute( 'SELECT user_id FROM tracked_users WHERE device_id=?', (self.device_id,)) - tracked_users.update(t[0] for t in rows) + tracked_users.update(row['user_id'] for row in rows) c.close() return tracked_users @@ -495,12 +507,12 @@ def get_sync_token(self): """Gets the saved sync token. Returns: - A string corresponding to the token, or None if there wasn't any. + A string corresponding to the token, or ``None`` if there wasn't any. """ c = self.conn.cursor() c.execute('SELECT token FROM sync_tokens WHERE device_id=?', (self.device_id,)) try: - return c.fetchone()[0] + return c.fetchone()['token'] except TypeError: return None finally: From 2d7b271152fbf631830d06545db0381d611aea57 Mon Sep 17 00:00:00 2001 From: Valentin Deniaud Date: Wed, 25 Jul 2018 17:59:49 +0200 Subject: [PATCH 33/66] better primary keys in crypto store --- matrix_client/crypto/crypto_store.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/matrix_client/crypto/crypto_store.py b/matrix_client/crypto/crypto_store.py index 349f875e..db5c954b 100644 --- a/matrix_client/crypto/crypto_store.py +++ b/matrix_client/crypto/crypto_store.py @@ -67,24 +67,25 @@ def create_tables_if_needed(self): FOREIGN KEY(device_id) REFERENCES accounts(device_id) ON DELETE CASCADE ); CREATE TABLE IF NOT EXISTS megolm_outbound_sessions( - device_id TEXT, room_id TEXT PRIMARY KEY, session BLOB, max_age_s FLOAT, + device_id TEXT, room_id TEXT, session BLOB, max_age_s FLOAT, max_messages INTEGER, creation_time TIMESTAMP, message_count INTEGER, + PRIMARY KEY(device_id, room_id), FOREIGN KEY(device_id) REFERENCES accounts(device_id) ON DELETE CASCADE ); CREATE TABLE IF NOT EXISTS megolm_outbound_devices( device_id TEXT, room_id TEXT, user_device_id TEXT, - UNIQUE(device_id, room_id, user_device_id), - FOREIGN KEY(room_id) REFERENCES megolm_outbound_sessions(room_id) ON DELETE CASCADE, - FOREIGN KEY(device_id) REFERENCES accounts(device_id) + PRIMARY KEY(device_id, room_id, user_device_id), + FOREIGN KEY(device_id, room_id) REFERENCES + megolm_outbound_sessions(device_id, room_id) ON DELETE CASCADE ); CREATE TABLE IF NOT EXISTS device_keys( - device_id TEXT, user_id TEXT, user_device_id TEXT PRIMARY KEY, ed_key TEXT, - curve_key TEXT, + device_id TEXT, user_id TEXT, user_device_id TEXT, ed_key TEXT, + curve_key TEXT, PRIMARY KEY(device_id, user_id, user_device_id), FOREIGN KEY(device_id) REFERENCES accounts(device_id) ON DELETE CASCADE ); CREATE TABLE IF NOT EXISTS tracked_users( device_id TEXT, user_id TEXT, - UNIQUE(device_id, user_id), + PRIMARY KEY(device_id, user_id), FOREIGN KEY(device_id) REFERENCES accounts(device_id) ON DELETE CASCADE ); CREATE TABLE IF NOT EXISTS sync_tokens( From cddd13a8a5f52f2483e12fff6dc5645564a4cbd3 Mon Sep 17 00:00:00 2001 From: Valentin Deniaud Date: Fri, 3 Aug 2018 15:33:26 +0200 Subject: [PATCH 34/66] restore device ID from user ID in CryptoStore A side-effect is that this removes the ability to store E2E data for different devices of the same user in the DB. It shouldn't be much of a problem as it is easy to use multiple DB files for different instances of MatrixClient. Signed-off-by: Valentin Deniaud --- matrix_client/client.py | 37 ++++++++++++++++--- matrix_client/crypto/crypto_store.py | 55 +++++++++++++++++++++++----- matrix_client/crypto/olm_device.py | 15 ++++++-- test/crypto/crypto_store_test.py | 28 +++++++++++++- 4 files changed, 115 insertions(+), 20 deletions(-) diff --git a/matrix_client/client.py b/matrix_client/client.py index 8f80048e..b2537343 100644 --- a/matrix_client/client.py +++ b/matrix_client/client.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from .api import MatrixHttpApi +from .checks import check_user_id from .errors import MatrixRequestError, MatrixUnexpectedResponse from .room import Room from .user import User @@ -63,6 +64,9 @@ class MatrixClient(object): encryption_conf (dict): Optional. Configuration parameters for encryption. Refer to :func:`~matrix_client.crypto.olm_device.OlmDevice` for supported options, since it will be passed to this class. + restore_device_id (bool): Optional. Only valid when encryption is enabled. When + turned on, the device ID corresponding to the user ID will be retrieved from + the encryption database, if it exists. Returns: `MatrixClient` @@ -111,7 +115,8 @@ def global_callback(incoming_event): def __init__(self, base_url, token=None, user_id=None, valid_cert_check=True, sync_filter_limit=20, - cache_level=CACHE.ALL, encryption=False, encryption_conf=None): + cache_level=CACHE.ALL, encryption=False, encryption_conf=None, + restore_device_id=False): if user_id: warn( "user_id is deprecated. " @@ -121,6 +126,9 @@ def __init__(self, base_url, token=None, user_id=None, if encryption and not ENCRYPTION_SUPPORT: raise ValueError("Failed to enable encryption. Please make sure the olm " "library is available.") + if restore_device_id and not encryption: + raise ValueError("restore_device_id only makes sense when encryption is " + "enabled.") self.api = MatrixHttpApi(base_url, token) self.api.validate_certificate(valid_cert_check) @@ -134,6 +142,7 @@ def __init__(self, base_url, token=None, user_id=None, self.encryption_conf = encryption_conf or {} self.olm_device = None self.first_sync = True + self.restore_device_id = restore_device_id if isinstance(cache_level, CACHE): self._cache_level = cache_level else: @@ -266,8 +275,11 @@ def login(self, username, password, limit=10, sync=True, device_id=None): limit (int): Deprecated. How many messages to return when syncing. This will be replaced by a filter API in a later release. sync (bool): Optional. Whether to initiate a /sync request after logging in. - device_id (str): Optional. ID of the client device. The server will - auto-generate a device_id if this is not specified. + device_id (str): Optional. ID of the client device. If it is not specified, + the server will auto-generate one, or it may be retrieved + from database if ``restore_device_id`` is ``True``. If it is specified, + and ``restore_device_id`` is ``True``, the eventual encryption keys stored + along with a previous device ID of the current user are discarded. Returns: str: Access token @@ -275,6 +287,20 @@ def login(self, username, password, limit=10, sync=True, device_id=None): Raises: MatrixRequestError """ + if not device_id and self.restore_device_id: + try: + check_user_id(username) + except ValueError: + raise ValueError("When using restore_device_id, a full user ID " + "must be supplied when logging in.") + try: + self.olm_device = OlmDevice( + self.api, username, **self.encryption_conf) + device_id = self.olm_device.device_id + logger.info('Device ID was sucessfully retrieved from database.') + except ValueError: + pass + response = self.api.login( "m.login.password", user=username, password=password, device_id=device_id ) @@ -285,8 +311,9 @@ def login(self, username, password, limit=10, sync=True, device_id=None): self.device_id = response["device_id"] if self._encryption: - self.olm_device = OlmDevice( - self.api, self.user_id, self.device_id, **self.encryption_conf) + if not self.olm_device: + self.olm_device = OlmDevice( + self.api, self.user_id, self.device_id, **self.encryption_conf) self.olm_device.upload_identity_keys() self.olm_device.upload_one_time_keys() diff --git a/matrix_client/crypto/crypto_store.py b/matrix_client/crypto/crypto_store.py index db5c954b..5fdc4dc7 100644 --- a/matrix_client/crypto/crypto_store.py +++ b/matrix_client/crypto/crypto_store.py @@ -17,7 +17,9 @@ class CryptoStore(object): """Manages persistent storage for an OlmDevice. Args: - device_id (str): The device id of the OlmDevice. + user_id (str): The user ID of the OlmDevice. + device_id (str): Optional. The device ID of the OlmDevice. Will be retrieved using + ``user_id`` if not present. db_name (str): Optional. The name of the database file to use. Will be created if necessary. db_path (str): Optional. The path where to store the database file. Defaults to @@ -28,11 +30,13 @@ class CryptoStore(object): """ def __init__(self, - device_id, + user_id, + device_id=None, db_name='crypto.db', db_path=None, app_name='matrix-python-sdk', pickle_key='DEFAULT_KEY'): + self.user_id = user_id self.device_id = device_id data_dir = db_path or user_data_dir(app_name, '') try: @@ -56,7 +60,9 @@ def create_tables_if_needed(self): c = self.conn.cursor() c.executescript(""" PRAGMA foreign_keys = ON; -CREATE TABLE IF NOT EXISTS accounts (device_id TEXT PRIMARY KEY NOT NULL, account BLOB); +CREATE TABLE IF NOT EXISTS accounts( + device_id TEXT NOT NULL UNIQUE, account BLOB, user_id TEXT PRIMARY KEY NOT NULL +); CREATE TABLE IF NOT EXISTS olm_sessions( device_id TEXT, session_id TEXT PRIMARY KEY, curve_key TEXT, session BLOB, FOREIGN KEY(device_id) REFERENCES accounts(device_id) ON DELETE CASCADE @@ -104,25 +110,56 @@ def save_olm_account(self, account): """ account_data = account.pickle(self.pickle_key) c = self.conn.cursor() - c.execute('INSERT OR IGNORE INTO accounts (device_id, account) VALUES (?,?)', - (self.device_id, account_data)) + c.execute( + 'INSERT OR IGNORE INTO accounts (device_id, account, user_id) VALUES (?,?,?)', + (self.device_id, account_data, self.user_id) + ) c.execute('UPDATE accounts SET account=? WHERE device_id=?', (account_data, self.device_id)) c.close() self.conn.commit() + def replace_olm_account(self, account): + """Replace an Olm account. + + Instead of updating it as done with :meth:`save_olm_account`, this saves the + new account and discards all data associated with the previous one. + + Args: + account (olm.Account): The account object to save. + """ + account_data = account.pickle(self.pickle_key) + c = self.conn.cursor() + c.execute('REPLACE INTO accounts (device_id, account, user_id) VALUES (?,?,?)', + (self.device_id, account_data, self.user_id)) + c.close() + self.conn.commit() + def get_olm_account(self): """Gets the Olm account. Returns: ``olm.Account`` object, or ``None`` if it wasn't found for the current device_id. + + Raises: + ``ValueError`` if ``device_id`` was ``None`` and couldn't be retrieved. """ c = self.conn.cursor() - c.execute( - 'SELECT account FROM accounts WHERE device_id=?', (self.device_id,)) + if self.device_id: + c.execute( + 'SELECT account, device_id FROM accounts WHERE user_id=? AND device_id=?', + (self.user_id, self.device_id) + ) + else: + c.execute('SELECT account, device_id FROM accounts WHERE user_id=?', + (self.user_id,)) + row = c.fetchone() + if not row and not self.device_id: + raise ValueError('Failed to retrieve device_id.') try: - account_data = c.fetchone()['account'] + self.device_id = row['device_id'] + account_data = row['account'] # sqlite gives us unicode in Python2, we want bytes account_data = bytes(account_data) except TypeError: @@ -138,7 +175,7 @@ def remove_olm_account(self): (keys, sessions...) """ c = self.conn.cursor() - c.execute('DELETE FROM accounts WHERE device_id=?', (self.device_id,)) + c.execute('DELETE FROM accounts WHERE user_id=?', (self.user_id,)) c.close() def save_olm_session(self, curve_key, session): diff --git a/matrix_client/crypto/olm_device.py b/matrix_client/crypto/olm_device.py index 47e00c54..e1aacfba 100644 --- a/matrix_client/crypto/olm_device.py +++ b/matrix_client/crypto/olm_device.py @@ -22,7 +22,8 @@ class OlmDevice(object): Args: api (MatrixHttpApi): The api object used to make requests. user_id (str): Matrix user ID. Must match the one used when logging in. - device_id (str): Must match the one used when logging in. + device_id (str): Optional. Must match the one used when logging in. If absent, + attempt to retrieve it from database using ``user_id``. signed_keys_proportion (float): Optional. The proportion of signed one-time keys we should maintain on the HS compared to unsigned keys. The maximum value of ``1`` means only signed keys will be uploaded, while the minimum value of @@ -41,6 +42,10 @@ class OlmDevice(object): load_all (bool): Optional. If True, all content of the database for the current device will be loaded at once. This will increase runtime performance but also launch time and memory usage. + + Raises: + ``ValueError`` if ``device_id`` was not given and couldn't be retrieved + from database. """ _olm_algorithm = 'm.olm.v1.curve25519-aes-sha2' @@ -50,7 +55,7 @@ class OlmDevice(object): def __init__(self, api, user_id, - device_id, + device_id=None, signed_keys_proportion=1, keys_threshold=0.1, Store=CryptoStore, @@ -65,12 +70,14 @@ def __init__(self, self.user_id = user_id self.device_id = device_id conf = store_conf or {} - self.db = Store(self.device_id, **conf) + self.db = Store(user_id, device_id=device_id, **conf) self.olm_sessions = defaultdict(list) self.megolm_inbound_sessions = defaultdict(lambda: defaultdict(dict)) self.megolm_outbound_sessions = {} self.device_keys = defaultdict(dict) self.olm_account = self.db.get_olm_account() + if not device_id: + self.device_id = self.db.device_id if self.olm_account: if load_all: self.db.load_olm_sessions(self.olm_sessions) @@ -80,7 +87,7 @@ def __init__(self, logger.info('Loaded Olm account from database for device %s.', device_id) else: self.olm_account = olm.Account() - self.db.save_olm_account(self.olm_account) + self.db.replace_olm_account(self.olm_account) logger.info('Created new Olm account for device %s.', device_id) self.identity_keys = self.olm_account.identity_keys # Try to maintain half the number of one-time keys libolm can hold uploaded diff --git a/test/crypto/crypto_store_test.py b/test/crypto/crypto_store_test.py index 83a07b6e..d9abda4d 100644 --- a/test/crypto/crypto_store_test.py +++ b/test/crypto/crypto_store_test.py @@ -27,11 +27,13 @@ class TestCryptoStore(object): 'db_name': db_name, 'db_path': db_path } - store = CryptoStore(device_id, db_path=db_path, db_name=db_name) + store = CryptoStore( + user_id, device_id=device_id, db_path=db_path, db_name=db_name) db_filepath = os.path.join(db_path, db_name) assert os.path.exists(db_filepath) store.close() - store = CryptoStore(device_id, db_path=db_path, db_name='test.db') + store = CryptoStore( + user_id, device_id=device_id, db_path=db_path, db_name=db_name) @pytest.fixture(autouse=True, scope='class') def cleanup(self): @@ -63,15 +65,37 @@ def test_olm_account_persistence(self): saved_account = self.store.get_olm_account() assert saved_account is None + # Try to load inexisting account without device_id + self.store.device_id = None + with pytest.raises(ValueError): + self.store.get_olm_account() + self.store.device_id = self.device_id + # Save and load self.store.save_olm_account(account) saved_account = self.store.get_olm_account() assert saved_account.identity_keys == identity_keys + # Save and load without device_id + self.store.save_olm_account(account) + self.store.device_id = None + saved_account = self.store.get_olm_account() + assert saved_account.identity_keys == identity_keys + assert self.store.device_id == self.device_id + + # Replace the account, causing foreign keys to be deleted + self.store.save_sync_token('test') + self.store.replace_olm_account(account) + assert self.store.get_sync_token() is None + # Load the account from an OlmDevice device = OlmDevice(None, self.user_id, self.device_id, store_conf=self.store_conf) assert device.olm_account.identity_keys == account.identity_keys + # Load the account from an OlmDevice, without device_id + device = OlmDevice(None, self.user_id, store_conf=self.store_conf) + assert device.device_id == self.device_id + def test_olm_sessions_persistence(self, account, curve_key, device): session = olm.OutboundSession(account, curve_key, curve_key) sessions = defaultdict(list) From bb10897f778dadbd7ae368dc421a1bf2e45c96ae Mon Sep 17 00:00:00 2001 From: Valentin Deniaud Date: Tue, 7 Aug 2018 18:09:08 +0200 Subject: [PATCH 35/66] turn on SQLite secure_delete --- matrix_client/crypto/crypto_store.py | 1 + 1 file changed, 1 insertion(+) diff --git a/matrix_client/crypto/crypto_store.py b/matrix_client/crypto/crypto_store.py index 5fdc4dc7..293b64ff 100644 --- a/matrix_client/crypto/crypto_store.py +++ b/matrix_client/crypto/crypto_store.py @@ -59,6 +59,7 @@ def create_tables_if_needed(self): """Ensures all the tables exist.""" c = self.conn.cursor() c.executescript(""" +PRAGMA secure_delete = ON; PRAGMA foreign_keys = ON; CREATE TABLE IF NOT EXISTS accounts( device_id TEXT NOT NULL UNIQUE, account BLOB, user_id TEXT PRIMARY KEY NOT NULL From 9567772a7ba3faa382390ae4216fe1ca8076c643 Mon Sep 17 00:00:00 2001 From: Valentin Deniaud Date: Fri, 6 Jul 2018 17:01:46 +0200 Subject: [PATCH 36/66] add m.file missing required key --- matrix_client/api.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/matrix_client/api.py b/matrix_client/api.py index 102d3e5c..5e381ca4 100644 --- a/matrix_client/api.py +++ b/matrix_client/api.py @@ -329,7 +329,7 @@ def redact_event(self, room_id, event_id, reason=None, txn_id=None, timestamp=No # content_type can be a image,audio or video # extra information should be supplied, see # https://matrix.org/docs/spec/r0.0.1/client_server.html - def send_content(self, room_id, item_url, item_name, msg_type, + def send_content(self, room_id, item_url, item_name, msg_type, filename=None, extra_information=None, timestamp=None): if extra_information is None: extra_information = {} @@ -340,6 +340,8 @@ def send_content(self, room_id, item_url, item_name, msg_type, "body": item_name, "info": extra_information } + if msg_type == "m.file": + content_pack["filename"] = filename or item_name return self.send_message_event(room_id, "m.room.message", content_pack, timestamp=timestamp) From 77d36f4a870b1fae8e8d30720de81ce1acfcc38c Mon Sep 17 00:00:00 2001 From: Valentin Deniaud Date: Fri, 6 Jul 2018 20:30:48 +0200 Subject: [PATCH 37/66] add encrypted attachments dependencies --- docs/source/conf.py | 3 ++- setup.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/docs/source/conf.py b/docs/source/conf.py index 8e63b535..b7bc465a 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -96,4 +96,5 @@ 'Miscellaneous'), ] -autodoc_mock_imports = ["olm", "canonicaljson", "appdirs"] +autodoc_mock_imports = ["olm", "canonicaljson", "appdirs", "unpaddedbase64", "Crypto", + "Crypto.Cipher", "Crypto.Hash", "Crypto.Util"] diff --git a/setup.py b/setup.py index f555d4c0..6b809d0b 100644 --- a/setup.py +++ b/setup.py @@ -54,7 +54,8 @@ def exec_file(names): 'test': ['pytest', 'responses'], 'doc': ['Sphinx==1.4.6', 'sphinx-rtd-theme==0.1.9', 'sphinxcontrib-napoleon==0.5.3'], 'format': ['flake8'], - 'e2e': ['python-olm==dev', 'canonicaljson', 'appdirs'] + 'e2e': ['python-olm==dev', 'canonicaljson', 'appdirs', 'unpaddedbase64', + 'pycrypto'] }, dependency_links=[ 'git+https://github.com/poljar/python-olm.git@4752eb22f005cb9f6143857008572e6d83252841#egg=python-olm-dev' From db8f3788b931342d6191ff7aef8b1cd710b49fdb Mon Sep 17 00:00:00 2001 From: Valentin Deniaud Date: Fri, 6 Jul 2018 20:48:24 +0200 Subject: [PATCH 38/66] encrypted attachments support --- docs/source/matrix_client.rst | 5 ++ matrix_client/crypto/encrypt_attachments.py | 80 +++++++++++++++++++++ test/crypto/encrypted_attachments_test.py | 15 ++++ 3 files changed, 100 insertions(+) create mode 100644 matrix_client/crypto/encrypt_attachments.py create mode 100644 test/crypto/encrypted_attachments_test.py diff --git a/docs/source/matrix_client.rst b/docs/source/matrix_client.rst index e008928f..a3dd2e47 100644 --- a/docs/source/matrix_client.rst +++ b/docs/source/matrix_client.rst @@ -71,3 +71,8 @@ matrix_client.crypto :members: :undoc-members: :show-inheritance: + +.. automodule:: matrix_client.crypto.encrypt_attachments + :members: + :undoc-members: + :show-inheritance: diff --git a/matrix_client/crypto/encrypt_attachments.py b/matrix_client/crypto/encrypt_attachments.py new file mode 100644 index 00000000..aa6ed62c --- /dev/null +++ b/matrix_client/crypto/encrypt_attachments.py @@ -0,0 +1,80 @@ +import unpaddedbase64 +from Crypto.Cipher import AES +from Crypto.Util import Counter +from Crypto import Random +from Crypto.Hash import SHA256 + + +def encrypt_attachment(plaintext): + """Encrypt a plaintext in order to send it as an encrypted attachment. + + Args: + plaintext (bytes): The data to encrypt. + + Returns: + A tuple of the ciphertext bytes and a dict containing the info needed + to decrypt data. The keys are: + + | key: AES-CTR JWK key object. + | iv: Base64 encoded 16 byte AES-CTR IV. + | hashes.sha256: Base64 encoded SHA-256 hash of the ciphertext. + """ + # 8 bytes IV + iv = Random.new().read(8) + # 8 bytes counter, prefixed by the IV + ctr = Counter.new(64, prefix=iv, initial_value=0) + key = Random.new().read(32) + cipher = AES.new(key, AES.MODE_CTR, counter=ctr) + ciphertext = cipher.encrypt(plaintext) + h = SHA256.new() + h.update(ciphertext) + digest = h.digest() + json_web_key = { + 'kty': 'oct', + 'alg': 'A256CTR', + 'ext': True, + 'k': unpaddedbase64.encode_base64(key, urlsafe=True), + 'key_ops': ['encrypt', 'decrypt'] + } + keys = { + 'v': 'v2', + 'key': json_web_key, + # Send IV concatenated with counter + 'iv': unpaddedbase64.encode_base64(iv + b'\x00' * 8), + 'hashes': { + 'sha256': unpaddedbase64.encode_base64(digest), + } + } + return ciphertext, keys + + +def decrypt_attachment(ciphertext, info): + """Decrypt an encrypted attachment. + + Args: + ciphertext (bytes): The data to decrypt. + info (dict): The information needed to decrypt the attachment. + + | key: AES-CTR JWK key object. + | iv: Base64 encoded 16 byte AES-CTR IV. + | hashes.sha256: Base64 encoded SHA-256 hash of the ciphertext. + + Returns: + The plaintext bytes. + + Raises: + RuntimeError if the integrity check fails. + """ + expected_hash = unpaddedbase64.decode_base64(info['hashes']['sha256']) + h = SHA256.new() + h.update(ciphertext) + if h.digest() != expected_hash: + raise RuntimeError('Mismatched SHA-256 digest.') + + key = unpaddedbase64.decode_base64(info['key']['k']) + # Drop last 8 bytes, which are 0 + iv = unpaddedbase64.decode_base64(info['iv'])[:8] + ctr = Counter.new(64, prefix=iv, initial_value=0) + cipher = AES.new(key, AES.MODE_CTR, counter=ctr) + + return cipher.decrypt(ciphertext) diff --git a/test/crypto/encrypted_attachments_test.py b/test/crypto/encrypted_attachments_test.py new file mode 100644 index 00000000..ee13cf9f --- /dev/null +++ b/test/crypto/encrypted_attachments_test.py @@ -0,0 +1,15 @@ +import pytest +pytest.importorskip('olm') # noqa + +from matrix_client.crypto.encrypt_attachments import (encrypt_attachment, + decrypt_attachment) + + +def test_encrypt_decrypt(): + message = b'test' + ciphertext, info = encrypt_attachment(message) + assert decrypt_attachment(ciphertext, info) == message + + ciphertext += b'\x00' + with pytest.raises(RuntimeError): + decrypt_attachment(ciphertext, info) From b236ea21afc236d2def8a7b977c2f3a08369d256 Mon Sep 17 00:00:00 2001 From: Valentin Deniaud Date: Fri, 6 Jul 2018 20:26:23 +0200 Subject: [PATCH 39/66] plug-in encrypted attachments --- matrix_client/api.py | 8 ++++++-- matrix_client/room.py | 5 +++-- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/matrix_client/api.py b/matrix_client/api.py index 5e381ca4..d89dd66c 100644 --- a/matrix_client/api.py +++ b/matrix_client/api.py @@ -330,18 +330,22 @@ def redact_event(self, room_id, event_id, reason=None, txn_id=None, timestamp=No # extra information should be supplied, see # https://matrix.org/docs/spec/r0.0.1/client_server.html def send_content(self, room_id, item_url, item_name, msg_type, filename=None, - extra_information=None, timestamp=None): + extra_information=None, timestamp=None, encryption_info=None): if extra_information is None: extra_information = {} content_pack = { - "url": item_url, "msgtype": msg_type, "body": item_name, "info": extra_information } if msg_type == "m.file": content_pack["filename"] = filename or item_name + if encryption_info: + encryption_info['url'] = item_url + content_pack['file'] = encryption_info + else: + content_pack['url'] = item_url return self.send_message_event(room_id, "m.room.message", content_pack, timestamp=timestamp) diff --git a/matrix_client/room.py b/matrix_client/room.py index 75b891e3..c2a45623 100644 --- a/matrix_client/room.py +++ b/matrix_client/room.py @@ -159,7 +159,7 @@ def send_emote(self, text): else: return self.client.api.send_emote(self.room_id, text) - def send_file(self, url, name, **fileinfo): + def send_file(self, url, name, encryption_info=None, **fileinfo): """Send a pre-uploaded file to the room. See http://matrix.org/docs/spec/r0.2.0/client_server.html#m-file for @@ -173,7 +173,8 @@ def send_file(self, url, name, **fileinfo): return self.client.api.send_content( self.room_id, url, name, "m.file", - extra_information=fileinfo + extra_information=fileinfo, + encryption_info=encryption_info ) def send_notice(self, text): From 8d4f2848ba2f7e2f33a940ffb3b7a2bf34025c3d Mon Sep 17 00:00:00 2001 From: Valentin Deniaud Date: Thu, 16 Aug 2018 14:48:50 +0200 Subject: [PATCH 40/66] automatically send encrypted m.file messages --- matrix_client/api.py | 10 ++++++++-- matrix_client/room.py | 17 +++++++++++------ 2 files changed, 19 insertions(+), 8 deletions(-) diff --git a/matrix_client/api.py b/matrix_client/api.py index d89dd66c..479aa6fe 100644 --- a/matrix_client/api.py +++ b/matrix_client/api.py @@ -331,6 +331,13 @@ def redact_event(self, room_id, event_id, reason=None, txn_id=None, timestamp=No # https://matrix.org/docs/spec/r0.0.1/client_server.html def send_content(self, room_id, item_url, item_name, msg_type, filename=None, extra_information=None, timestamp=None, encryption_info=None): + content_pack = self.get_content_body(item_url, item_name, msg_type, filename, + extra_information, encryption_info) + return self.send_message_event(room_id, "m.room.message", content_pack, + timestamp=timestamp) + + def get_content_body(self, item_url, item_name, msg_type, filename=None, + extra_information=None, encryption_info=None): if extra_information is None: extra_information = {} @@ -346,8 +353,7 @@ def send_content(self, room_id, item_url, item_name, msg_type, filename=None, content_pack['file'] = encryption_info else: content_pack['url'] = item_url - return self.send_message_event(room_id, "m.room.message", content_pack, - timestamp=timestamp) + return content_pack def get_location_body(self, geo_uri, name, thumb_url=None, thumb_info=None): content_pack = { diff --git a/matrix_client/room.py b/matrix_client/room.py index c2a45623..d4a1ff40 100644 --- a/matrix_client/room.py +++ b/matrix_client/room.py @@ -170,12 +170,17 @@ def send_file(self, url, name, encryption_info=None, **fileinfo): name (str): The filename of the image. fileinfo (): Extra information about the file """ - - return self.client.api.send_content( - self.room_id, url, name, "m.file", - extra_information=fileinfo, - encryption_info=encryption_info - ) + if self.encrypted and self.client._encryption: + content = self.client.api.get_content_body( + url, name, "m.file", extra_information=fileinfo, + encryption_info=encryption_info + ) + return self.send_encrypted(content) + else: + return self.client.api.send_content( + self.room_id, url, name, "m.file", + extra_information=fileinfo + ) def send_notice(self, text): """Send a notice (from bot) message to the room.""" From 656005bf510b5d203b0467cd547ba26760b4c44c Mon Sep 17 00:00:00 2001 From: Valentin Deniaud Date: Sun, 15 Jul 2018 12:38:08 +0200 Subject: [PATCH 41/66] add Device class --- matrix_client/device.py | 54 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 54 insertions(+) create mode 100644 matrix_client/device.py diff --git a/matrix_client/device.py b/matrix_client/device.py new file mode 100644 index 00000000..4fde19ea --- /dev/null +++ b/matrix_client/device.py @@ -0,0 +1,54 @@ +from .errors import MatrixRequestError + + +class Device(object): + + def __init__(self, + api, + device_id, + display_name=None, + last_seen_ip=None, + last_seen_ts=None, + verified=False, + blacklisted=False, + ignored=False, + ed25519_key=None, + curve25519_key=None): + self.api = api + self.device_id = device_id + self.display_name = display_name + self.last_seen_ts = last_seen_ts + self.last_seen_ip = last_seen_ip + self.verified = verified + self.blacklisted = blacklisted + self.ignored = ignored + self._ed25519 = ed25519_key + self._curve25519 = curve25519_key + + def get_info(self): + """Gets information on the device. + + The ``display_name``, ``last_seen_ip`` and ``last_seen_ts`` attribute will + get updated, if these were available. + + Returns: + True if successful, False if the device was not found. + """ + try: + info = self.api.get_device(self.device_id) + except MatrixRequestError as e: + if e.code == 404: + return False + raise + self.display_name = info.get('display_name') + self.last_seen_ip = info.get('last_seen_ip') + self.last_seen_ts = info.get('last_seen_ts') + return True + + @property + def ed25519(self): + return self._ed25519 + + @property + def curve25519(self): + return self._curve25519 From 6bce7b4c86489811a076ef7caa5a53c8f8cb14d3 Mon Sep 17 00:00:00 2001 From: Valentin Deniaud Date: Sun, 22 Jul 2018 17:22:52 +0200 Subject: [PATCH 42/66] add devices attribute to User --- matrix_client/user.py | 45 ++++++++++++++++++++++++++++++++++++------- test/user_test.py | 2 +- 2 files changed, 39 insertions(+), 8 deletions(-) diff --git a/matrix_client/user.py b/matrix_client/user.py index e56a89ef..2dfd0299 100644 --- a/matrix_client/user.py +++ b/matrix_client/user.py @@ -15,17 +15,19 @@ from warnings import warn from .checks import check_user_id +from .device import Device class User(object): """ The User class can be used to call user specific functions. """ - def __init__(self, api, user_id, displayname=None): + def __init__(self, client, user_id, displayname=None): check_user_id(user_id) self.user_id = user_id self.displayname = displayname - self.api = api + self.client = client + self._devices = {} def get_display_name(self, room=None): """Get this user's display name. @@ -43,7 +45,7 @@ def get_display_name(self, room=None): except KeyError: return self.user_id if not self.displayname: - self.displayname = self.api.get_display_name(self.user_id) + self.displayname = self.client.api.get_display_name(self.user_id) return self.displayname or self.user_id def get_friendly_name(self): @@ -59,13 +61,13 @@ def set_display_name(self, display_name): display_name (str): Display Name """ self.displayname = display_name - return self.api.set_display_name(self.user_id, display_name) + return self.client.api.set_display_name(self.user_id, display_name) def get_avatar_url(self): - mxcurl = self.api.get_avatar_url(self.user_id) + mxcurl = self.client.api.get_avatar_url(self.user_id) url = None if mxcurl is not None: - url = self.api.get_download_url(mxcurl) + url = self.client.api.get_download_url(mxcurl) return url def set_avatar_url(self, avatar_url): @@ -74,4 +76,33 @@ def set_avatar_url(self, avatar_url): Args: avatar_url (str): mxc url from previously uploaded """ - return self.api.set_avatar_url(self.user_id, avatar_url) + return self.client.api.set_avatar_url(self.user_id, avatar_url) + + @property + def devices(self): + # If this user is joined in an encrypted room with us, we may already have an + # up-to-date list of their devices. + if self.client._encryption and \ + self.user_id in self.client.olm_device.device_list.tracked_user_ids: + + if self.user_id not in self.client.device_keys: + self.client.db.get_device_keys( + self.client.api, {self.user_id: []}, self.client.device_keys + ) + self._devices = self.client.device_keys[self.user_id] + else: + devices = self.client.api.query_keys({self.user_id: []})["device_keys"] + for device_id in devices: + if device_id not in self._devices: + # Do not add the keys even if they are in the payload, because + # we are not able to verify them right know. This means that device + # verification will only become available once we share an encrypted + # room with this user. + self._devices[device_id] = Device(self.client.api, device_id) + + for device in self._devices: + device.get_info() + + # Returning a copy prevents adding/removing devices while allowing to verify or + # blacklist them. + return self._devices.copy() diff --git a/test/user_test.py b/test/user_test.py index db5bae82..beae6b3f 100644 --- a/test/user_test.py +++ b/test/user_test.py @@ -15,7 +15,7 @@ class TestUser: @pytest.fixture() def user(self): - return User(self.cli.api, self.user_id) + return User(self.cli, self.user_id) @pytest.fixture() def room(self): From 0c766557ca77405b7739e0a6078cbcd6e9f1638f Mon Sep 17 00:00:00 2001 From: Valentin Deniaud Date: Sun, 15 Jul 2018 12:45:32 +0200 Subject: [PATCH 43/66] make OlmDevice subclass Device --- matrix_client/crypto/olm_device.py | 35 +++++++++++++++++------------- test/crypto/device_list_test.py | 2 +- test/crypto/olm_device_test.py | 2 +- 3 files changed, 22 insertions(+), 17 deletions(-) diff --git a/matrix_client/crypto/olm_device.py b/matrix_client/crypto/olm_device.py index e1aacfba..4696cff9 100644 --- a/matrix_client/crypto/olm_device.py +++ b/matrix_client/crypto/olm_device.py @@ -6,6 +6,7 @@ from canonicaljson import encode_canonical_json from matrix_client.checks import check_user_id +from matrix_client.device import Device from matrix_client.crypto.one_time_keys import OneTimeKeysManager from matrix_client.crypto.device_list import DeviceList from matrix_client.crypto.megolm_outbound_session import MegolmOutboundSession @@ -14,7 +15,7 @@ logger = logging.getLogger(__name__) -class OlmDevice(object): +class OlmDevice(Device): """Manages the Olm cryptographic functions. Has a unique Olm account which holds identity keys. @@ -68,7 +69,6 @@ def __init__(self, self.api = api check_user_id(user_id) self.user_id = user_id - self.device_id = device_id conf = store_conf or {} self.db = Store(user_id, device_id=device_id, **conf) self.olm_sessions = defaultdict(list) @@ -77,7 +77,7 @@ def __init__(self, self.device_keys = defaultdict(dict) self.olm_account = self.db.get_olm_account() if not device_id: - self.device_id = self.db.device_id + device_id = self.db.device_id if self.olm_account: if load_all: self.db.load_olm_sessions(self.olm_sessions) @@ -89,7 +89,6 @@ def __init__(self, self.olm_account = olm.Account() self.db.replace_olm_account(self.olm_account) logger.info('Created new Olm account for device %s.', device_id) - self.identity_keys = self.olm_account.identity_keys # Try to maintain half the number of one-time keys libolm can hold uploaded # on the HS. This is because some keys will be claimed by peers but not # used instantly, and we want them to stay in libolm, until the limit is reached @@ -100,6 +99,11 @@ def __init__(self, keys_threshold) self.device_list = DeviceList(self, api, self.device_keys, self.db) self.megolm_index_record = defaultdict(dict) + keys = self.olm_account.identity_keys + super(OlmDevice, self).__init__(self.api, + device_id, + ed25519_key=keys['ed25519'], + curve25519_key=keys['curve25519']) def upload_identity_keys(self): """Uploads this device's identity keys to HS. @@ -110,8 +114,10 @@ def upload_identity_keys(self): 'user_id': self.user_id, 'device_id': self.device_id, 'algorithms': self._algorithms, - 'keys': {'{}:{}'.format(alg, self.device_id): key - for alg, key in self.identity_keys.items()} + 'keys': { + 'curve25519:{}'.format(self.device_id): self.curve25519, + 'ed25519:{}'.format(self.device_id): self.ed25519 + } } self.sign_json(device_keys) ret = self.api.upload_keys(device_keys=device_keys) @@ -268,7 +274,7 @@ def olm_build_encrypted_event(self, event_type, content, user_id, device_id): 'sender': self.user_id, 'sender_device': self.device_id, 'keys': { - 'ed25519': self.identity_keys['ed25519'] + 'ed25519': self.ed25519 }, 'recipient': user_id, 'recipient_keys': { @@ -293,7 +299,7 @@ def olm_build_encrypted_event(self, event_type, content, user_id, device_id): event = { 'algorithm': self._olm_algorithm, - 'sender_key': self.identity_keys['curve25519'], + 'sender_key': self.curve25519, 'ciphertext': ciphertext_payload } return event @@ -320,7 +326,7 @@ def olm_decrypt_event(self, content, user_id): ciphertext = content['ciphertext'] try: - payload = ciphertext[self.identity_keys['curve25519']] + payload = ciphertext[self.curve25519] except KeyError: raise RuntimeError('This message was not encrypted for us.') @@ -343,10 +349,10 @@ def olm_decrypt_event(self, content, user_id): .format(decrypted_event['recipient'], self.user_id, decrypted_event) ) our_key = decrypted_event['recipient_keys']['ed25519'] - if our_key != self.identity_keys['ed25519']: + if our_key != self.ed25519: raise RuntimeError( 'Found key {} instead of ours own ed25519 key {} in Olm plaintext {}.' - .format(our_key, self.identity_keys['ed25519'], decrypted_event) + .format(our_key, self.ed25519, decrypted_event) ) return decrypted_event @@ -462,9 +468,8 @@ def megolm_start_session(self, room): self.db.save_outbound_session(room.room_id, session) self.megolm_share_session(room.room_id, user_devices, session) # Store a corresponding inbound session, so that we can decrypt our own messages - self.megolm_add_inbound_session(room.room_id, self.identity_keys['curve25519'], - session.id, - session.session_key) + self.megolm_add_inbound_session( + room.room_id, self.curve25519, session.id, session.session_key) return session def megolm_share_session(self, room_id, user_devices, session): @@ -561,7 +566,7 @@ def megolm_build_encrypted_event(self, room, event): encrypted_event = { 'algorithm': self._megolm_algorithm, - 'sender_key': self.identity_keys['curve25519'], + 'sender_key': self.curve25519, 'ciphertext': encrypted_payload, 'session_id': session.id, 'device_id': self.device_id diff --git a/test/crypto/device_list_test.py b/test/crypto/device_list_test.py index 28222fab..940ba963 100644 --- a/test/crypto/device_list_test.py +++ b/test/crypto/device_list_test.py @@ -57,7 +57,7 @@ def test_download_device_keys(self): alice_device = OlmDevice(self.cli.api, self.alice, 'JLAFKJWSCS') resp = deepcopy(example_key_query_response) resp['device_keys'][self.alice]['JLAFKJWSCS']['keys']['ed25519:JLAFKJWSCS'] = \ - alice_device.identity_keys['ed25519'] + alice_device.ed25519 resp['device_keys'][self.alice]['JLAFKJWSCS'] = \ alice_device.sign_json(resp['device_keys'][self.alice]['JLAFKJWSCS']) responses.add(responses.POST, self.query_url, json=resp) diff --git a/test/crypto/olm_device_test.py b/test/crypto/olm_device_test.py index 54cceea7..2cb6a957 100644 --- a/test/crypto/olm_device_test.py +++ b/test/crypto/olm_device_test.py @@ -578,7 +578,7 @@ def test_olm_handle_encrypted_event(self): self.device.olm_account.generate_one_time_keys(1) otk = next(iter(self.device.olm_account.one_time_keys['curve25519'].values())) self.device.olm_account.mark_keys_as_published() - sender_key = self.device.identity_keys['curve25519'] + sender_key = self.device.curve25519 session = olm.OutboundSession(alice_device.olm_account, sender_key, otk) alice_device.olm_sessions[sender_key] = [session] From f2b25e886d4c9f7d9db48d30e6495f7d50c677de Mon Sep 17 00:00:00 2001 From: Valentin Deniaud Date: Sun, 15 Jul 2018 16:26:53 +0200 Subject: [PATCH 44/66] better device keys handling --- matrix_client/crypto/crypto_store.py | 34 +++++++++--------- matrix_client/crypto/device_list.py | 25 +++++++------ matrix_client/crypto/olm_device.py | 27 +++++++------- test/crypto/crypto_store_test.py | 50 +++++++++++++------------- test/crypto/device_list_test.py | 20 ++++++----- test/crypto/olm_device_test.py | 53 ++++++++++++---------------- 6 files changed, 103 insertions(+), 106 deletions(-) diff --git a/matrix_client/crypto/crypto_store.py b/matrix_client/crypto/crypto_store.py index 293b64ff..d6d143be 100644 --- a/matrix_client/crypto/crypto_store.py +++ b/matrix_client/crypto/crypto_store.py @@ -9,6 +9,7 @@ from appdirs import user_data_dir from matrix_client.crypto.megolm_outbound_session import MegolmOutboundSession +from matrix_client.device import Device logger = logging.getLogger(__name__) @@ -423,37 +424,36 @@ def save_device_keys(self, device_keys): Args: device_keys (defaultdict(dict)): The format is ``{: {: - {'curve25519': , 'ed25519': }``. + Device``. """ c = self.conn.cursor() rows = [] for user_id, devices_dict in device_keys.items(): - for device_id, keys_dict in devices_dict.items(): - rows.append((self.device_id, user_id, device_id, keys_dict['ed25519'], - keys_dict['curve25519'])) + for device_id, device in devices_dict.items(): + rows.append((self.device_id, user_id, device_id, device.ed25519, + device.curve25519)) c.executemany('REPLACE INTO device_keys VALUES (?,?,?,?,?)', rows) c.close() self.conn.commit() - def load_device_keys(self, device_keys): + def load_device_keys(self, api, device_keys): """Loads all saved device keys. Args: device_keys (defaultdict(dict)): An object which will get populated with - the keys. The format is ``{: {: - {'curve25519': , 'ed25519': }``. + the keys. The format is ``{: {: Device}}``. """ c = self.conn.cursor() rows = c.execute( 'SELECT * FROM device_keys WHERE device_id=?', (self.device_id,)) for row in rows: - device_keys[row['user_id']][row['user_device_id']] = { - 'ed25519': row['ed_key'], - 'curve25519': row['curve_key'] - } + device = Device(api, row['user_device_id'], + ed25519_key=row['ed_key'], + curve25519_key=row['curve_key']) + device_keys[row['user_id']][row['user_device_id']] = device c.close() - def get_device_keys(self, user_devices, device_keys=None): + def get_device_keys(self, api, user_devices, device_keys=None): """Gets the devices keys of the specified devices. Args: @@ -461,7 +461,7 @@ def get_device_keys(self, user_devices, device_keys=None): If no device ids are given for a user, all will be retrieved. device_keys (defaultdict(dict)): Optional. Will be updated with the retrieved keys. The format is ``{: {: - {'curve25519': , 'ed25519': }``. + Device}}``. Returns: A ``defaultdict(dict)`` containing the keys, the format is the same as the @@ -486,10 +486,10 @@ def get_device_keys(self, user_devices, device_keys=None): c.close() result = defaultdict(dict) for row in rows: - result[row['user_id']][row['user_device_id']] = { - 'ed25519': row['ed_key'], - 'curve25519': row['curve_key'] - } + device = Device(api, row['user_device_id'], + ed25519_key=row['ed_key'], + curve25519_key=row['curve_key']) + result[row['user_id']][row['user_device_id']] = device if device_keys is not None and result: device_keys.update(result) return result diff --git a/matrix_client/crypto/device_list.py b/matrix_client/crypto/device_list.py index c3872f83..37d36b6a 100644 --- a/matrix_client/crypto/device_list.py +++ b/matrix_client/crypto/device_list.py @@ -2,6 +2,7 @@ from collections import defaultdict from threading import Thread, Condition, Event, Lock +from matrix_client.device import Device from matrix_client.errors import MatrixHttpLibError, MatrixRequestError logger = logging.getLogger(__name__) @@ -16,7 +17,7 @@ class DeviceList: Args: olm_device (OlmDevice): Will be used to get additional info, such as device id. api (MatrixHttpApi): The api object used to make requests. - device_keys (defaultdict(dict)): A map from user to device to keys. + device_keys (defaultdict(dict)): A map from user to device id to Device. """ def __init__(self, olm_device, api, device_keys, db): @@ -58,7 +59,7 @@ def get_room_device_keys(self, room, blocking=True): members = {m.user_id for m in room.get_joined_members()} missing_members = {m: [] for m in members if not self.device_keys[m]} if missing_members: - self.db.get_device_keys(missing_members, self.device_keys) + self.db.get_device_keys(self.api, missing_members, self.device_keys) user_ids = members - self.tracked_user_ids if not user_ids: logger.info('Already had all the keys in room %s.', room.room_id) @@ -199,18 +200,22 @@ def _download_device_keys(self, user_devices, since_token=None): logger.warning('Signature verification failed for device %s of ' 'user %s.', device_id, user_id) continue - keys = self.device_keys[user_id].setdefault(device_id, {}) - if keys: - if keys['ed25519'] != signing_key: + devices = self.device_keys[user_id] + try: + device = devices[device_id] + except KeyError: + devices[device_id] = Device(self.api, device_id, + curve25519_key=curve_key, + ed25519_key=signing_key) + else: + if device.ed25519 != signing_key: logger.warning('Ed25519 key has changed for device %s of ' 'user %s.', device_id, user_id) continue - if keys['curve25519'] == curve_key: + if device.curve25519 == curve_key: continue - else: - keys['ed25519'] = signing_key - keys['curve25519'] = curve_key - changed[user_id][device_id] = keys + device._curve25519 = curve_key + changed[user_id][device_id] = devices[device_id] logger.info('Successfully downloaded keys for devices: %s.', {user_id: list(changed[user_id]) for user_id in changed}) diff --git a/matrix_client/crypto/olm_device.py b/matrix_client/crypto/olm_device.py index 4696cff9..d5a8564b 100644 --- a/matrix_client/crypto/olm_device.py +++ b/matrix_client/crypto/olm_device.py @@ -83,7 +83,7 @@ def __init__(self, self.db.load_olm_sessions(self.olm_sessions) self.db.load_inbound_sessions(self.megolm_inbound_sessions) self.db.load_outbound_sessions(self.megolm_outbound_sessions) - self.db.load_device_keys(self.device_keys) + self.db.load_device_keys(self.api, self.device_keys) logger.info('Loaded Olm account from database for device %s.', device_id) else: self.olm_account = olm.Account() @@ -221,23 +221,23 @@ def olm_start_sessions(self, user_devices): for user_id in user_devices: for device_id, one_time_key in keys.get(user_id, {}).items(): try: - device_keys = self.device_keys[user_id][device_id] + device = self.device_keys[user_id][device_id] except KeyError: logger.warning('Key for device %s of user %s not found, could not ' 'start Olm session.', device_id, user_id) continue key_object = next(iter(one_time_key.values())) verified = self.verify_json(key_object, - device_keys['ed25519'], + device.ed25519, user_id, device_id) if verified: session = olm.OutboundSession(self.olm_account, - device_keys['curve25519'], + device.curve25519, key_object['key']) - sessions = self.olm_sessions[device_keys['curve25519']] + sessions = self.olm_sessions[device.curve25519] sessions.append(session) - new_sessions[device_keys['curve25519']].append(session) + new_sessions[device.curve25519].append(session) logger.info('Established Olm session %s with device %s of user ' '%s.', device_id, session.id, user_id) else: @@ -261,13 +261,10 @@ def olm_build_encrypted_event(self, event_type, content, user_id, device_id): The Olm encrypted event, as JSON. """ try: - keys = self.device_keys[user_id][device_id] + device = self.device_keys[user_id][device_id] except KeyError: raise RuntimeError('Device is unknown, could not encrypt.') - signing_key = keys['ed25519'] - identity_key = keys['curve25519'] - payload = { 'type': event_type, 'content': content, @@ -278,20 +275,20 @@ def olm_build_encrypted_event(self, event_type, content, user_id, device_id): }, 'recipient': user_id, 'recipient_keys': { - 'ed25519': signing_key + 'ed25519': device.ed25519 } } - sessions = self.olm_sessions[identity_key] + sessions = self.olm_sessions[device.curve25519] if sessions: session = sorted(sessions, key=lambda s: s.id)[0] else: raise RuntimeError('No session for this device, could not encrypt.') encrypted_message = session.encrypt(json.dumps(payload)) - self.db.save_olm_session(identity_key, session) + self.db.save_olm_session(device.curve25519, session) ciphertext_payload = { - identity_key: { + device.curve25519: { 'type': encrypted_message.message_type, 'body': encrypted_message.ciphertext } @@ -433,7 +430,7 @@ def olm_ensure_sessions(self, user_devices): user_devices_no_session = defaultdict(list) for user_id in user_devices: for device_id in user_devices[user_id]: - curve_key = self.device_keys[user_id][device_id]['curve25519'] + curve_key = self.device_keys[user_id][device_id].curve25519 # Check if we have a list of sessions for this device, which can be # empty. Implicitely, an empty list will indicate that we already tried # to establish a session with a device, but this attempt was diff --git a/test/crypto/crypto_store_test.py b/test/crypto/crypto_store_test.py index d9abda4d..8da0a62f 100644 --- a/test/crypto/crypto_store_test.py +++ b/test/crypto/crypto_store_test.py @@ -129,7 +129,7 @@ def test_olm_sessions_persistence(self, account, curve_key, device): assert device.olm_sessions[curve_key][0].id == session.id device.olm_sessions.clear() - device.device_keys[self.user_id][self.device_id] = {'curve25519': curve_key} + device.device_keys[self.user_id][self.device_id] = device device.olm_ensure_sessions({self.user_id: [self.device_id]}) assert device.olm_sessions[curve_key][0].id == session.id @@ -237,46 +237,49 @@ def test_megolm_outbound_persistence(self, device): @pytest.mark.usefixtures('account') def test_device_keys_persistence(self, device): user_devices = {self.user_id: [self.device_id]} - keys = { - 'curve25519': 'curve', - 'ed25519': 'ed' - } device_keys = defaultdict(dict) - self.store.load_device_keys(device_keys) + self.store.load_device_keys(None, device_keys) assert not device_keys - assert not self.store.get_device_keys(user_devices, device_keys) + assert not self.store.get_device_keys(None, user_devices, device_keys) assert not device_keys - device_keys_to_save = {self.user_id: {self.device_id: keys}} + device_keys_to_save = {self.user_id: {self.device_id: device}} self.store.save_device_keys(device_keys_to_save) - self.store.load_device_keys(device_keys) - assert device_keys == device_keys_to_save + self.store.load_device_keys(None, device_keys) + assert device_keys[self.user_id][self.device_id].curve25519 == \ + device.curve25519 device_keys.clear() - assert self.store.get_device_keys(user_devices) == device_keys_to_save - assert self.store.get_device_keys(user_devices, device_keys) - assert device_keys == device_keys_to_save + devices = self.store.get_device_keys(None, user_devices)[self.user_id] + assert devices[self.device_id].curve25519 == device.curve25519 + assert self.store.get_device_keys(None, user_devices, device_keys) + assert device_keys[self.user_id][self.device_id].curve25519 == \ + device.curve25519 # Test [] wildcard - assert self.store.get_device_keys({self.user_id: []}) == device_keys_to_save + devices = self.store.get_device_keys(None, {self.user_id: []})[self.user_id] + assert devices[self.device_id].curve25519 == device.curve25519 device.device_list.tracked_user_ids = {self.user_id} device.device_list.get_room_device_keys(self.room) - assert device.device_keys == device_keys_to_save + assert device_keys[self.user_id][self.device_id].curve25519 == \ + device.curve25519 # Test multiples [] device_keys.clear() user_id = 'test' device_id = 'test' - device_keys_to_save[user_id] = {device_id: keys} + device_keys_to_save[user_id] = {device_id: device} self.store.save_device_keys(device_keys_to_save) user_devices[user_id] = [] user_devices[self.user_id] = [] - assert self.store.get_device_keys(user_devices) == device_keys_to_save + device_keys = self.store.get_device_keys(None, user_devices) + assert device_keys[self.user_id][self.device_id].curve25519 == device.curve25519 + assert device_keys[user_id][device_id].curve25519 == device.curve25519 self.store.remove_olm_account() - assert not self.store.get_device_keys(user_devices) + assert not self.store.get_device_keys(None, user_devices) @pytest.mark.usefixtures('account') def test_tracked_users_persistence(self): @@ -308,17 +311,13 @@ def test_sync_token_persistence(self): self.store.save_sync_token(sync_token) assert self.store.get_sync_token() == sync_token - def test_load_all(self, account, curve_key): + def test_load_all(self, account, curve_key, device): curve_key = account.identity_keys['curve25519'] session = olm.OutboundSession(account, curve_key, curve_key) out_session = MegolmOutboundSession() out_session.add_device(self.device_id) in_session = olm.InboundGroupSession(out_session.session_key) - keys = { - 'curve25519': 'curve', - 'ed25519': 'ed' - } - device_keys_to_save = {self.user_id: {self.device_id: keys}} + device_keys_to_save = {self.user_id: {self.device_id: device}} self.store.save_inbound_session(self.room_id, curve_key, in_session) self.store.save_olm_session(curve_key, session) @@ -336,4 +335,5 @@ def test_load_all(self, account, curve_key): saved_out_session = device.megolm_outbound_sessions[self.room_id] assert saved_out_session.id == out_session.id assert saved_out_session.devices == out_session.devices - assert device.device_keys == device_keys_to_save + assert device.device_keys[self.user_id][self.device_id].curve25519 == \ + device.curve25519 diff --git a/test/crypto/device_list_test.py b/test/crypto/device_list_test.py index 940ba963..014575f9 100644 --- a/test/crypto/device_list_test.py +++ b/test/crypto/device_list_test.py @@ -9,6 +9,7 @@ from matrix_client.api import MATRIX_V2_API_PATH from matrix_client.client import MatrixClient +from matrix_client.device import Device from matrix_client.room import User from matrix_client.errors import MatrixRequestError from matrix_client.crypto.device_list import (_OutdatedUsersSet as OutdatedUsersSet, @@ -94,24 +95,27 @@ def test_download_device_keys(self): assert download_device_keys(user_devices) req = json.loads(responses.calls[0].request.body) assert req['device_keys'] == {self.alice: [], bob: [], self.user_id: []} + device = Device(self.cli.api, 'JLAFKJWSCS', + curve25519_key='3C5BFWi2Y8MaVvjM8M22DBmh24PmgR0nPvJOIArzgyI', + ed25519_key='VzJIYXQ85u19z2ZpEeLLVu8hUKTCE0VXYUn4IY4iFcA') expected_device_keys = { self.alice: { - 'JLAFKJWSCS': { - 'curve25519': '3C5BFWi2Y8MaVvjM8M22DBmh24PmgR0nPvJOIArzgyI', - 'ed25519': 'VzJIYXQ85u19z2ZpEeLLVu8hUKTCE0VXYUn4IY4iFcA' - } + 'JLAFKJWSCS': device } } - assert self.device.device_keys == expected_device_keys + assert self.device.device_keys[self.alice]['JLAFKJWSCS'].curve25519 == \ + device.curve25519 # Different curve25519, key should get updated assert download_device_keys(user_devices) - expected_device_keys[self.alice]['JLAFKJWSCS']['curve25519'] = new_id_key - assert self.device.device_keys == expected_device_keys + expected_device_keys[self.alice]['JLAFKJWSCS']._curve25519 = new_id_key + assert self.device.device_keys[self.alice]['JLAFKJWSCS'].curve25519 == \ + device.curve25519 # Different ed25519, key should not get updated assert not download_device_keys(user_devices) - assert self.device.device_keys == expected_device_keys + assert self.device.device_keys[self.alice]['JLAFKJWSCS'].ed25519 == \ + device.ed25519 self.device.device_keys.clear() # All the remaining responses are wrong and we should not add the key diff --git a/test/crypto/olm_device_test.py b/test/crypto/olm_device_test.py index 2cb6a957..12c45c0b 100644 --- a/test/crypto/olm_device_test.py +++ b/test/crypto/olm_device_test.py @@ -16,6 +16,7 @@ from matrix_client.api import MATRIX_V2_API_PATH from matrix_client.client import MatrixClient from matrix_client.user import User +from matrix_client.device import Device from test.crypto.dummy_olm_device import OlmDevice from matrix_client.crypto.megolm_outbound_session import MegolmOutboundSession from test.response_examples import (example_key_upload_response, @@ -35,10 +36,9 @@ class TestOlmDevice: alice = '@alice:example.com' alice_device_id = 'JLAFKJWSCS' alice_curve_key = 'mmFRSHuJVq3aTudx3KB3w5ZvSFQhgEcy8d+m+vkEfUQ' - alice_identity_keys = { - 'curve25519': alice_curve_key, - 'ed25519': '4VjV3OhFUxWFAcO5YOaQVmTIn29JdRmtNh9iAxoyhkc' - } + alice_ed_key = '4VjV3OhFUxWFAcO5YOaQVmTIn29JdRmtNh9iAxoyhkc' + alice_device = Device(cli.api, alice_device_id, curve25519_key=alice_curve_key, + ed25519_key=alice_ed_key) alice_olm_session = olm.OutboundSession( device.olm_account, alice_curve_key, alice_curve_key) room = cli._mkroom(room_id) @@ -246,8 +246,7 @@ def test_olm_start_sessions(self): # Cover logging part olm_device.logger.setLevel(logging.WARNING) # Now should be good - self.device.device_keys[self.alice][self.alice_device_id] = \ - self.alice_identity_keys + self.device.device_keys[self.alice][self.alice_device_id] = self.alice_device self.device.olm_start_sessions(user_devices) assert self.device.olm_sessions[self.alice_curve_key] @@ -280,8 +279,7 @@ def test_olm_build_encrypted_event(self): 'm.text', event_content, self.alice, self.alice_device_id) # We don't have a session with Alice - self.device.device_keys[self.alice][self.alice_device_id] = \ - self.alice_identity_keys + self.device.device_keys[self.alice][self.alice_device_id] = self.alice_device with pytest.raises(RuntimeError): self.device.olm_build_encrypted_event( 'm.text', event_content, self.alice, self.alice_device_id) @@ -353,15 +351,14 @@ def test_olm_decrypt_event(self): self.device.device_keys.clear() self.device.olm_sessions.clear() alice_device = OlmDevice(self.device.api, self.alice, self.alice_device_id) - alice_device.device_keys[self.user_id][self.device_id] = self.device.identity_keys - self.device.device_keys[self.alice][self.alice_device_id] = \ - alice_device.identity_keys + alice_device.device_keys[self.user_id][self.device_id] = self.device + self.device.device_keys[self.alice][self.alice_device_id] = alice_device # Artificially start an Olm session from Alice self.device.olm_account.generate_one_time_keys(1) otk = next(iter(self.device.olm_account.one_time_keys['curve25519'].values())) self.device.olm_account.mark_keys_as_published() - sender_key = self.device.identity_keys['curve25519'] + sender_key = self.device.curve25519 session = olm.OutboundSession(alice_device.olm_account, sender_key, otk) alice_device.olm_sessions[sender_key] = [session] @@ -405,19 +402,18 @@ def test_olm_decrypt_event(self): encrypted_event = alice_device.olm_build_encrypted_event( 'example_type', {'content': 'test'}, self.user_id, self.device_id) - backup = self.device.identity_keys['ed25519'] - self.device.identity_keys['ed25519'] = 'wrong' + backup = self.device.ed25519 + self.device._ed25519 = 'wrong' with pytest.raises(RuntimeError): self.device.olm_decrypt_event(encrypted_event, self.alice) - self.device.identity_keys['ed25519'] = backup + self.device._ed25519 = backup @responses.activate def test_olm_ensure_sessions(self): claim_url = HOSTNAME + MATRIX_V2_API_PATH + '/keys/claim' responses.add(responses.POST, claim_url, json=example_claim_keys_response) self.device.olm_sessions.clear() - self.device.device_keys[self.alice][self.alice_device_id] = \ - self.alice_identity_keys + self.device.device_keys[self.alice][self.alice_device_id] = self.alice_device user_devices = {self.alice: [self.alice_device_id]} self.device.olm_ensure_sessions(user_devices) @@ -434,9 +430,9 @@ def test_megolm_share_session(self): to_device_url = HOSTNAME + MATRIX_V2_API_PATH + '/sendToDevice/m.room.encrypted/1' responses.add(responses.PUT, to_device_url, json={}) self.device.olm_sessions.clear() - self.device.device_keys[self.alice][self.alice_device_id] = \ - self.alice_identity_keys - self.device.device_keys['dummy']['dummy'] = {'curve25519': 'a', 'ed25519': 'a'} + self.device.device_keys[self.alice][self.alice_device_id] = self.alice_device + self.device.device_keys['dummy']['dummy'] = \ + Device(self.cli.api, 'dummy', curve25519_key='a', ed25519_key='a') user_devices = {self.alice: [self.alice_device_id], 'dummy': ['dummy']} session = MegolmOutboundSession() @@ -452,8 +448,7 @@ def test_megolm_share_session(self): def test_megolm_start_session(self): to_device_url = HOSTNAME + MATRIX_V2_API_PATH + '/sendToDevice/m.room.encrypted/1' responses.add(responses.PUT, to_device_url, json={}) - self.device.device_keys[self.alice][self.alice_device_id] = \ - self.alice_identity_keys + self.device.device_keys[self.alice][self.alice_device_id] = self.alice_device self.device.device_list.tracked_user_ids.add(self.alice) self.device.olm_sessions[self.alice_curve_key] = [self.alice_olm_session] @@ -482,8 +477,7 @@ def test_megolm_start_session(self): def test_megolm_share_session_with_new_devices(self): to_device_url = HOSTNAME + MATRIX_V2_API_PATH + '/sendToDevice/m.room.encrypted/1' responses.add(responses.PUT, to_device_url, json={}) - self.device.device_keys[self.alice][self.alice_device_id] = \ - self.alice_identity_keys + self.device.device_keys[self.alice][self.alice_device_id] = self.alice_device self.device.olm_sessions[self.alice_curve_key] = [self.alice_olm_session] session = MegolmOutboundSession() self.device.megolm_outbound_sessions[self.room_id] = session @@ -500,8 +494,7 @@ def test_megolm_build_encrypted_event(self): to_device_url = HOSTNAME + MATRIX_V2_API_PATH + '/sendToDevice/m.room.encrypted/1' responses.add(responses.PUT, to_device_url, json={}) self.device.megolm_outbound_sessions.clear() - self.device.device_keys[self.alice][self.alice_device_id] = \ - self.alice_identity_keys + self.device.device_keys[self.alice][self.alice_device_id] = self.alice_device self.device.device_list.tracked_user_ids.add(self.alice) self.device.olm_sessions[self.alice_curve_key] = [self.alice_olm_session] event = {'type': 'm.room.message', 'content': {'body': 'test'}} @@ -527,8 +520,7 @@ def test_send_encrypted_message(self): message_url = HOSTNAME + MATRIX_V2_API_PATH + \ '/rooms/{}/send/m.room.encrypted/1'.format(quote(self.room.room_id)) responses.add(responses.PUT, message_url, json={}) - self.device.device_keys[self.alice][self.alice_device_id] = \ - self.alice_identity_keys + self.device.device_keys[self.alice][self.alice_device_id] = self.alice_device self.device.olm_sessions[self.alice_curve_key] = [self.alice_olm_session] session = MegolmOutboundSession() session.add_device(self.alice_device_id) @@ -570,9 +562,8 @@ def test_handle_room_key_event(self): def test_olm_handle_encrypted_event(self): self.device.olm_sessions.clear() alice_device = OlmDevice(self.device.api, self.alice, self.alice_device_id) - alice_device.device_keys[self.user_id][self.device_id] = self.device.identity_keys - self.device.device_keys[self.alice][self.alice_device_id] = \ - alice_device.identity_keys + alice_device.device_keys[self.user_id][self.device_id] = self.device + self.device.device_keys[self.alice][self.alice_device_id] = alice_device # Artificially start an Olm session from Alice self.device.olm_account.generate_one_time_keys(1) From b27dfa84c8db7819be2ad8fb4cf7b5d60a74bacf Mon Sep 17 00:00:00 2001 From: Valentin Deniaud Date: Sat, 28 Jul 2018 17:43:09 +0200 Subject: [PATCH 45/66] remember signing key when establishing Megolm Inbound Session --- docs/source/matrix_client.rst | 2 +- matrix_client/crypto/crypto_store.py | 31 ++++++++++--------- matrix_client/crypto/olm_device.py | 13 +++++--- ...megolm_outbound_session.py => sessions.py} | 20 +++++++++++- test/crypto/crypto_store_test.py | 16 ++++++---- test/crypto/olm_device_test.py | 18 +++++++---- test/response_examples.py | 5 ++- 7 files changed, 70 insertions(+), 35 deletions(-) rename matrix_client/crypto/{megolm_outbound_session.py => sessions.py} (79%) diff --git a/docs/source/matrix_client.rst b/docs/source/matrix_client.rst index a3dd2e47..79f0560c 100644 --- a/docs/source/matrix_client.rst +++ b/docs/source/matrix_client.rst @@ -62,7 +62,7 @@ matrix_client.crypto :undoc-members: :show-inheritance: -.. automodule:: matrix_client.crypto.megolm_outbound_session +.. automodule:: matrix_client.crypto.sessions :members: :undoc-members: :show-inheritance: diff --git a/matrix_client/crypto/crypto_store.py b/matrix_client/crypto/crypto_store.py index d6d143be..bcdb5461 100644 --- a/matrix_client/crypto/crypto_store.py +++ b/matrix_client/crypto/crypto_store.py @@ -8,7 +8,7 @@ import olm from appdirs import user_data_dir -from matrix_client.crypto.megolm_outbound_session import MegolmOutboundSession +from matrix_client.crypto.sessions import MegolmOutboundSession, MegolmInboundSession from matrix_client.device import Device logger = logging.getLogger(__name__) @@ -71,7 +71,7 @@ def create_tables_if_needed(self): ); CREATE TABLE IF NOT EXISTS megolm_inbound_sessions( device_id TEXT, session_id TEXT PRIMARY KEY, room_id TEXT, curve_key TEXT, - session BLOB, + ed_key TEXT, session BLOB, FOREIGN KEY(device_id) REFERENCES accounts(device_id) ON DELETE CASCADE ); CREATE TABLE IF NOT EXISTS megolm_outbound_sessions( @@ -246,11 +246,11 @@ def save_inbound_session(self, room_id, curve_key, session): Args: room_id (str): The room corresponding to the session. curve_key (str): The curve25519 key of the device. - session (olm.InboundGroupSession): The session to save. + session (MegolmInboundSession): The session to save. """ c = self.conn.cursor() - c.execute('REPLACE INTO megolm_inbound_sessions VALUES (?,?,?,?,?)', - (self.device_id, session.id, room_id, curve_key, + c.execute('REPLACE INTO megolm_inbound_sessions VALUES (?,?,?,?,?,?)', + (self.device_id, session.id, room_id, curve_key, session.ed25519, session.pickle(self.pickle_key))) c.close() self.conn.commit() @@ -262,15 +262,15 @@ def load_inbound_sessions(self, sessions): sessions (defaultdict(defaultdict(dict))): An object which will get populated with the sessions. The format is ``{: {: {: - }}}``. + }}}``. """ c = self.conn.cursor() rows = c.execute( 'SELECT * FROM megolm_inbound_sessions WHERE device_id=?', (self.device_id,) ) for row in rows: - session = olm.InboundGroupSession.from_pickle( - bytes(row['session']), self.pickle_key) + session = MegolmInboundSession.from_pickle( + bytes(row['session']), row['ed_key'], self.pickle_key) sessions[row['room_id']][row['curve_key']][session.id] = session c.close() @@ -282,25 +282,26 @@ def get_inbound_session(self, room_id, curve_key, session_id, sessions=None): curve_key (str): The curve25519 key of the device. session_id (str): The id of the session. sessions (dict): Optional. A map from session id to - ``olm.InboundGroupSession`` object, to which the session will be added. + ``MegolmInboundSession`` object, to which the session will be added. Returns: - ``olm.InboundGroupSession`` object, or ``None`` if the session was not found. + ``MegolmInboundSession`` object, or ``None`` if the session was not found. """ c = self.conn.cursor() c.execute( - 'SELECT session FROM megolm_inbound_sessions WHERE device_id=? AND room_id=? ' - 'AND curve_key=? AND session_id=?', + 'SELECT session, ed_key FROM megolm_inbound_sessions WHERE device_id=? AND ' + 'room_id=? AND curve_key=? AND session_id=?', (self.device_id, room_id, curve_key, session_id) ) try: - session_data = c.fetchone()['session'] - session_data = bytes(session_data) + row = c.fetchone() + session_data = bytes(row['session']) except TypeError: return None finally: c.close() - session = olm.InboundGroupSession.from_pickle(session_data, self.pickle_key) + session = MegolmInboundSession.from_pickle(session_data, row['ed_key'], + self.pickle_key) if sessions is not None: sessions[session.id] = session return session diff --git a/matrix_client/crypto/olm_device.py b/matrix_client/crypto/olm_device.py index d5a8564b..f5b87143 100644 --- a/matrix_client/crypto/olm_device.py +++ b/matrix_client/crypto/olm_device.py @@ -9,7 +9,7 @@ from matrix_client.device import Device from matrix_client.crypto.one_time_keys import OneTimeKeysManager from matrix_client.crypto.device_list import DeviceList -from matrix_client.crypto.megolm_outbound_session import MegolmOutboundSession +from matrix_client.crypto.sessions import MegolmOutboundSession, MegolmInboundSession from matrix_client.crypto.crypto_store import CryptoStore logger = logging.getLogger(__name__) @@ -466,7 +466,7 @@ def megolm_start_session(self, room): self.megolm_share_session(room.room_id, user_devices, session) # Store a corresponding inbound session, so that we can decrypt our own messages self.megolm_add_inbound_session( - room.room_id, self.curve25519, session.id, session.session_key) + room.room_id, self.curve25519, self.ed25519, session.id, session.session_key) return session def megolm_share_session(self, room_id, user_devices, session): @@ -628,6 +628,7 @@ def handle_room_key_event(self, event, sender_key): Args: event (dict): m.room_key event. """ + signing_key = event['keys']['ed25519'] content = event['content'] if content['algorithm'] != self._megolm_algorithm: logger.info('Ignoring unsupported algorithm %s in m.room_key event.', @@ -637,7 +638,7 @@ def handle_room_key_event(self, event, sender_key): device_id = event['sender_device'] new = self.megolm_add_inbound_session(content['room_id'], sender_key, - content['session_id'], + signing_key, content['session_id'], content['session_key']) if new: logger.info('Created a new Megolm inbound session with device %s of ' @@ -646,7 +647,8 @@ def handle_room_key_event(self, event, sender_key): logger.info('Inbound Megolm session with device %s of user %s ' 'already exists or is invalid.', device_id, user_id) - def megolm_add_inbound_session(self, room_id, sender_key, session_id, session_key): + def megolm_add_inbound_session(self, room_id, sender_key, signing_key, session_id, + session_key): """Create a new Megolm inbound session if necessary. Args: @@ -654,6 +656,7 @@ def megolm_add_inbound_session(self, room_id, sender_key, session_id, session_ke sender_key (str): The curve25519 key of the sender's device. session_id (str): The id of the session. session_key (str): The key of the session. + signing_key (str): The ed25519 key of the event which established the session. Returns: ``True`` if a new session was created, ``False`` if it already existed or if @@ -666,7 +669,7 @@ def megolm_add_inbound_session(self, room_id, sender_key, session_id, session_ke if self.db.get_inbound_session(room_id, sender_key, session_id, sessions): return False try: - session = olm.InboundGroupSession(session_key) + session = MegolmInboundSession(session_key, signing_key) except olm.OlmGroupSessionError: return False if session.id != session_id: diff --git a/matrix_client/crypto/megolm_outbound_session.py b/matrix_client/crypto/sessions.py similarity index 79% rename from matrix_client/crypto/megolm_outbound_session.py rename to matrix_client/crypto/sessions.py index d21b049f..fe9d024d 100644 --- a/matrix_client/crypto/megolm_outbound_session.py +++ b/matrix_client/crypto/sessions.py @@ -1,6 +1,6 @@ from datetime import datetime, timedelta -from olm import OutboundGroupSession +from olm import OutboundGroupSession, InboundGroupSession class MegolmOutboundSession(OutboundGroupSession): @@ -73,3 +73,21 @@ def from_pickle(cls, pickle, devices, max_age, max_messages, creation_time, session.creation_time = creation_time session.message_count = message_count return session + + +class MegolmInboundSession(InboundGroupSession): + + """Olm session with memory of the ed25519 key of the user it was established with.""" + + def __init__(self, session_key, signing_key): + self.ed25519 = signing_key + super(MegolmInboundSession, self).__init__(session_key) + + def __new__(cls, *args): + return super(MegolmInboundSession, cls).__new__(cls) + + @classmethod + def from_pickle(cls, pickle, signing_key, passphrase=''): + session = super(MegolmInboundSession, cls).from_pickle(pickle, passphrase) + session.ed25519 = signing_key + return session diff --git a/test/crypto/crypto_store_test.py b/test/crypto/crypto_store_test.py index 8da0a62f..7a8b8687 100644 --- a/test/crypto/crypto_store_test.py +++ b/test/crypto/crypto_store_test.py @@ -7,7 +7,7 @@ from matrix_client.crypto.crypto_store import CryptoStore from matrix_client.crypto.olm_device import OlmDevice -from matrix_client.crypto.megolm_outbound_session import MegolmOutboundSession +from matrix_client.crypto.sessions import MegolmOutboundSession, MegolmInboundSession from matrix_client.room import Room from matrix_client.user import User @@ -52,6 +52,10 @@ def account(self): def curve_key(self, account): return account.identity_keys['curve25519'] + @pytest.fixture() + def ed_key(self, account): + return account.identity_keys['ed25519'] + @pytest.fixture() def device(self): return OlmDevice(None, self.user_id, self.device_id, store_conf=self.store_conf) @@ -137,9 +141,9 @@ def test_olm_sessions_persistence(self, account, curve_key, device): self.store.remove_olm_account() assert not self.store.get_olm_sessions(curve_key) - def test_megolm_inbound_persistence(self, curve_key, device): + def test_megolm_inbound_persistence(self, curve_key, ed_key, device): out_session = olm.OutboundGroupSession() - session = olm.InboundGroupSession(out_session.session_key) + session = MegolmInboundSession(out_session.session_key, ed_key) sessions = defaultdict(lambda: defaultdict(dict)) self.store.load_inbound_sessions(sessions) @@ -161,7 +165,7 @@ def test_megolm_inbound_persistence(self, curve_key, device): assert not device.megolm_inbound_sessions created = device.megolm_add_inbound_session( - self.room_id, curve_key, session.id, out_session.session_key) + self.room_id, curve_key, ed_key, session.id, out_session.session_key) assert not created assert device.megolm_inbound_sessions[self.room_id][curve_key][session.id].id == \ session.id @@ -311,12 +315,12 @@ def test_sync_token_persistence(self): self.store.save_sync_token(sync_token) assert self.store.get_sync_token() == sync_token - def test_load_all(self, account, curve_key, device): + def test_load_all(self, account, curve_key, ed_key, device): curve_key = account.identity_keys['curve25519'] session = olm.OutboundSession(account, curve_key, curve_key) out_session = MegolmOutboundSession() out_session.add_device(self.device_id) - in_session = olm.InboundGroupSession(out_session.session_key) + in_session = MegolmInboundSession(out_session.session_key, ed_key) device_keys_to_save = {self.user_id: {self.device_id: device}} self.store.save_inbound_session(self.room_id, curve_key, in_session) diff --git a/test/crypto/olm_device_test.py b/test/crypto/olm_device_test.py index 12c45c0b..bd57ff89 100644 --- a/test/crypto/olm_device_test.py +++ b/test/crypto/olm_device_test.py @@ -18,7 +18,7 @@ from matrix_client.user import User from matrix_client.device import Device from test.crypto.dummy_olm_device import OlmDevice -from matrix_client.crypto.megolm_outbound_session import MegolmOutboundSession +from matrix_client.crypto.sessions import MegolmOutboundSession, MegolmInboundSession from test.response_examples import (example_key_upload_response, example_claim_keys_response, example_room_key_event) @@ -533,15 +533,21 @@ def test_megolm_add_inbound_session(self): self.device.megolm_inbound_sessions.clear() assert not self.device.megolm_add_inbound_session( - self.room_id, self.alice_curve_key, session.id, 'wrong') + self.room_id, self.alice_curve_key, self.alice_ed_key, session.id, 'wrong') assert self.device.megolm_add_inbound_session( - self.room_id, self.alice_curve_key, session.id, session.session_key) + self.room_id, self.alice_curve_key, self.alice_ed_key, session.id, + session.session_key + ) assert session.id in \ self.device.megolm_inbound_sessions[self.room_id][self.alice_curve_key] assert not self.device.megolm_add_inbound_session( - self.room_id, self.alice_curve_key, session.id, session.session_key) + self.room_id, self.alice_curve_key, self.alice_ed_key, session.id, + session.session_key + ) assert not self.device.megolm_add_inbound_session( - self.room_id, self.alice_curve_key, 'wrong', session.session_key) + self.room_id, self.alice_curve_key, self.alice_ed_key, 'wrong', + session.session_key + ) def test_handle_room_key_event(self): self.device.megolm_inbound_sessions.clear() @@ -629,7 +635,7 @@ def test_megolm_decrypt_event(self): with pytest.raises(RuntimeError): self.device.megolm_decrypt_event(event) - in_session = olm.InboundGroupSession(out_session.session_key) + in_session = MegolmInboundSession(out_session.session_key, self.alice_ed_key) sessions = self.device.megolm_inbound_sessions[self.room_id] sessions[self.alice_curve_key][in_session.id] = in_session diff --git a/test/response_examples.py b/test/response_examples.py index e577c464..9b5e9456 100644 --- a/test/response_examples.py +++ b/test/response_examples.py @@ -251,5 +251,8 @@ "G413GWJkw9T+G6y51bsNEKsSU23lnJz32u5XwgNY9qdFKxGA6WL1wZZS6/iGW4gfTU/Jk89aGSA8" "Aw") }, - "type": "m.room_key" + "type": "m.room_key", + "keys": { + "ed25519": "4VjV3OhFUxWFAcO5YOaQVmTIn29JdRmtNh9iAxoyhkc", + } } From 6d31ccec23669a090c9a45e2f3a0db8ce30d0f92 Mon Sep 17 00:00:00 2001 From: Valentin Deniaud Date: Sat, 28 Jul 2018 17:49:27 +0200 Subject: [PATCH 46/66] configure device verification --- matrix_client/client.py | 9 +++++++-- matrix_client/room.py | 3 ++- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/matrix_client/client.py b/matrix_client/client.py index b2537343..597d1ef6 100644 --- a/matrix_client/client.py +++ b/matrix_client/client.py @@ -67,6 +67,10 @@ class MatrixClient(object): restore_device_id (bool): Optional. Only valid when encryption is enabled. When turned on, the device ID corresponding to the user ID will be retrieved from the encryption database, if it exists. + verify_devices (bool): Optional. When enabled, sending a message will fail when + there are unknown devices in an encrypted room. A client will have to + inspect those, and resend its message. Note that this can be configured later + on a per room basis. Returns: `MatrixClient` @@ -116,7 +120,7 @@ def global_callback(incoming_event): def __init__(self, base_url, token=None, user_id=None, valid_cert_check=True, sync_filter_limit=20, cache_level=CACHE.ALL, encryption=False, encryption_conf=None, - restore_device_id=False): + restore_device_id=False, verify_devices=False): if user_id: warn( "user_id is deprecated. " @@ -143,6 +147,7 @@ def __init__(self, base_url, token=None, user_id=None, self.olm_device = None self.first_sync = True self.restore_device_id = restore_device_id + self.verify_devices = verify_devices if isinstance(cache_level, CACHE): self._cache_level = cache_level else: @@ -594,7 +599,7 @@ def upload(self, content, content_type, filename=None): ) def _mkroom(self, room_id): - room = Room(self, room_id) + room = Room(self, room_id, verify_devices=self.verify_devices) if self._encryption: try: event = self.api.get_state_event(room_id, "m.room.encryption") diff --git a/matrix_client/room.py b/matrix_client/room.py index d4a1ff40..f04a7de1 100644 --- a/matrix_client/room.py +++ b/matrix_client/room.py @@ -31,7 +31,7 @@ class Room(object): NOTE: This does not verify the room with the Home Server. """ - def __init__(self, client, room_id): + def __init__(self, client, room_id, verify_devices=False): check_room_id(room_id) self.room_id = room_id @@ -55,6 +55,7 @@ def __init__(self, client, room_id): self.encrypted = False self.rotation_period_msgs = None self.rotation_period_ms = None + self.verify_devices = verify_devices def set_user_profile(self, displayname=None, From c1c8aca297fc0cb7ae7e611befbf086d35beb3a2 Mon Sep 17 00:00:00 2001 From: Valentin Deniaud Date: Sat, 28 Jul 2018 15:53:07 +0200 Subject: [PATCH 47/66] persist verification info --- matrix_client/crypto/crypto_store.py | 29 +++++++++++++++++----------- test/crypto/crypto_store_test.py | 3 +++ 2 files changed, 21 insertions(+), 11 deletions(-) diff --git a/matrix_client/crypto/crypto_store.py b/matrix_client/crypto/crypto_store.py index bcdb5461..bd16fbb1 100644 --- a/matrix_client/crypto/crypto_store.py +++ b/matrix_client/crypto/crypto_store.py @@ -88,7 +88,8 @@ def create_tables_if_needed(self): ); CREATE TABLE IF NOT EXISTS device_keys( device_id TEXT, user_id TEXT, user_device_id TEXT, ed_key TEXT, - curve_key TEXT, PRIMARY KEY(device_id, user_id, user_device_id), + curve_key TEXT, verified INTEGER, blacklisted INTEGER, ignored INTEGER, + PRIMARY KEY(device_id, user_id, user_device_id), FOREIGN KEY(device_id) REFERENCES accounts(device_id) ON DELETE CASCADE ); CREATE TABLE IF NOT EXISTS tracked_users( @@ -432,8 +433,9 @@ def save_device_keys(self, device_keys): for user_id, devices_dict in device_keys.items(): for device_id, device in devices_dict.items(): rows.append((self.device_id, user_id, device_id, device.ed25519, - device.curve25519)) - c.executemany('REPLACE INTO device_keys VALUES (?,?,?,?,?)', rows) + device.curve25519, device.verified, device.blacklisted, + device.ignored)) + c.executemany('REPLACE INTO device_keys VALUES (?,?,?,?,?,?,?,?)', rows) c.close() self.conn.commit() @@ -448,10 +450,8 @@ def load_device_keys(self, api, device_keys): rows = c.execute( 'SELECT * FROM device_keys WHERE device_id=?', (self.device_id,)) for row in rows: - device = Device(api, row['user_device_id'], - ed25519_key=row['ed_key'], - curve25519_key=row['curve_key']) - device_keys[row['user_id']][row['user_device_id']] = device + device_keys[row['user_id']][row['user_device_id']] = \ + self._device_from_row(row, api) c.close() def get_device_keys(self, api, user_devices, device_keys=None): @@ -487,14 +487,21 @@ def get_device_keys(self, api, user_devices, device_keys=None): c.close() result = defaultdict(dict) for row in rows: - device = Device(api, row['user_device_id'], - ed25519_key=row['ed_key'], - curve25519_key=row['curve_key']) - result[row['user_id']][row['user_device_id']] = device + result[row['user_id']][row['user_device_id']] = \ + self._device_from_row(row, api) + if device_keys is not None and result: device_keys.update(result) return result + @staticmethod + def _device_from_row(row, api): + return Device( + api, row['user_device_id'], ed25519_key=row['ed_key'], + curve25519_key=row['curve_key'], verified=row['verified'], + blacklisted=row['blacklisted'], ignored=row['ignored'] + ) + def save_tracked_users(self, user_ids): """Saves tracked users. diff --git a/test/crypto/crypto_store_test.py b/test/crypto/crypto_store_test.py index 7a8b8687..d1105080 100644 --- a/test/crypto/crypto_store_test.py +++ b/test/crypto/crypto_store_test.py @@ -242,6 +242,7 @@ def test_megolm_outbound_persistence(self, device): def test_device_keys_persistence(self, device): user_devices = {self.user_id: [self.device_id]} device_keys = defaultdict(dict) + device.verified = True self.store.load_device_keys(None, device_keys) assert not device_keys @@ -253,6 +254,7 @@ def test_device_keys_persistence(self, device): self.store.load_device_keys(None, device_keys) assert device_keys[self.user_id][self.device_id].curve25519 == \ device.curve25519 + assert device_keys[self.user_id][self.device_id].verified device_keys.clear() devices = self.store.get_device_keys(None, user_devices)[self.user_id] @@ -260,6 +262,7 @@ def test_device_keys_persistence(self, device): assert self.store.get_device_keys(None, user_devices, device_keys) assert device_keys[self.user_id][self.device_id].curve25519 == \ device.curve25519 + assert device_keys[self.user_id][self.device_id].verified # Test [] wildcard devices = self.store.get_device_keys(None, {self.user_id: []})[self.user_id] From e4f417fd9e01eefd5f8221c02586500e3ad6231b Mon Sep 17 00:00:00 2001 From: Valentin Deniaud Date: Sat, 28 Jul 2018 17:51:42 +0200 Subject: [PATCH 48/66] alert of unknown devices when sending encrypted messages --- matrix_client/crypto/olm_device.py | 100 ++++++++++++++++++++++------- matrix_client/errors.py | 15 +++++ test/crypto/crypto_store_test.py | 5 +- test/crypto/olm_device_test.py | 38 ++++++++++- 4 files changed, 131 insertions(+), 27 deletions(-) diff --git a/matrix_client/crypto/olm_device.py b/matrix_client/crypto/olm_device.py index f5b87143..8dc04506 100644 --- a/matrix_client/crypto/olm_device.py +++ b/matrix_client/crypto/olm_device.py @@ -7,6 +7,7 @@ from matrix_client.checks import check_user_id from matrix_client.device import Device +from matrix_client.errors import E2EUnknownDevices from matrix_client.crypto.one_time_keys import OneTimeKeysManager from matrix_client.crypto.device_list import DeviceList from matrix_client.crypto.sessions import MegolmOutboundSession, MegolmInboundSession @@ -442,11 +443,13 @@ def olm_ensure_sessions(self, user_devices): if user_devices_no_session: self.olm_start_sessions(user_devices_no_session) - def megolm_start_session(self, room): + def megolm_start_session(self, room, user_devices): """Start a megolm session in a room, and share it with its members. Args: room (Room): The room to use. + user_devices (dict): Map from user id to a list of device ids. The session + will be shared with those devices. Returns: The newly created session. @@ -457,10 +460,6 @@ def megolm_start_session(self, room): logger.info('Starting a new Meglom outbound session %s in %s.', session.id, room.room_id) - users = room.get_joined_members() - self.device_list.get_room_device_keys(room) - user_devices = {user.user_id: list(self.device_keys[user.user_id]) - for user in users} self.db.remove_outbound_session(room.room_id) self.db.save_outbound_session(room.room_id, session) self.megolm_share_session(room.room_id, user_devices, session) @@ -507,25 +506,66 @@ def megolm_share_session(self, room_id, user_devices, session): session.add_devices(new_devices) self.db.save_megolm_outbound_devices(room_id, new_devices) - def megolm_share_session_with_new_devices(self, room, session): + def megolm_share_session_with_new_devices(self, room, user_devices, session): """Share a megolm session with new devices in a room. Args: room (Room): The room corresponding to the session. session (MegolmOutboundSession): The session to share. + user_devices (dict): Map from user id to a list of device ids. The session + will be shared with those devices if not already. """ - user_devices = {} - users = room.get_joined_members() - for user in users: - user_id = user.user_id + new_user_devices = {} + for user_id in user_devices: missing_devices = list(set(self.device_keys[user_id].keys()) - self.megolm_outbound_sessions[room.room_id].devices) if missing_devices: - user_devices[user_id] = missing_devices - if user_devices: - logger.info('Sharing existing Megolm outbound session %s with new devices: ' - '%s', session.id, user_devices) - self.megolm_share_session(room.room_id, user_devices, session) + new_user_devices[user_id] = missing_devices + + if new_user_devices: + logger.info('Sharing existing Megolm outbound session %s with new ' + 'devices: %s', session.id, new_user_devices) + self.megolm_share_session(room.room_id, new_user_devices, session) + + def megolm_get_recipients(self, room, session=None): + """Get the devices who should be able to decrypt a Megolm event in a room. + + This implements device verification checks. + + Args: + room (Room): The room to use. + session (MegolmOutboundSession): Optional. If a device the session had + been shared with has been blacklisted, remove the session. + + Returns: + A two element tuple containing a map from user id to a list of device ids, + and a boolean indicating whether the session has been removed. + + Raises: + E2EUnknownDevices if there are never seen before devices in the room. + """ + users = room.get_joined_members() + + user_devices = defaultdict(list) + unknown_devices = defaultdict(list) + removed_session = False + for user in users: + for device_id, device in self.device_keys[user.user_id].items(): + if device.blacklisted: + if session and device.device_id in session.devices: + self.megolm_remove_outbound_session(room.room_id) + removed_session = True + else: + if not room.verify_devices or device.ignored or device.verified: + user_devices[user.user_id].append(device_id) + else: + unknown_devices[user.user_id].append(device) + if unknown_devices and room.verify_devices: + logger.warning('Room %s contains unknown devices which have not been ' + 'verified.', room.room_id) + raise E2EUnknownDevices(unknown_devices) + + return user_devices, removed_session def megolm_build_encrypted_event(self, room, event): """Build an encrypted Megolm payload from a plaintext event. @@ -539,18 +579,32 @@ def megolm_build_encrypted_event(self, room, event): Returns: The encrypted event, as a dict. + + Raises: + E2EUnknownDevices if there are never seen before devices in the room. """ room_id = room.room_id - session = self.megolm_outbound_sessions.get(room_id) - if not session: + try: + session = self.megolm_outbound_sessions[room_id] + except KeyError: session = self.db.get_outbound_session(room_id, self.megolm_outbound_sessions) - if not session: - session = self.megolm_start_session(room) - if session.should_rotate(): - session = self.megolm_start_session(room) + # We have to fetch device keys if there is no session. If there is one, we are + # already tracking the device list of users in the room, so it shouldn't be + # needed. + # However, there is the edge case where a device is blacklisted, and then the + # client is shutdown. When we load the session, if we do not fetch the keys + # (which triggers loading the devices from db), we would miss that a device + # had been blacklisted and we would keep using the session instead of rotating + # it as expected. Hence we also fetch device keys after a session is loaded. + self.device_list.get_room_device_keys(room) + + user_devices, removed_session = self.megolm_get_recipients(room, session) + + if not session or removed_session or session.should_rotate(): + session = self.megolm_start_session(room, user_devices) else: - self.megolm_share_session_with_new_devices(room, session) + self.megolm_share_session_with_new_devices(room, user_devices, session) payload = { 'type': event['type'], @@ -594,6 +648,8 @@ def send_encrypted_message(self, room, content): Raises: MatrixRequestError if there was an error sending the event. + E2EUnknownDevices if there are never seen before devices in the room. + The event will not be sent. """ event = {'content': content, 'room_id': room.room_id, 'type': 'm.room.message'} encrypted_event = self.megolm_build_encrypted_event(room, event) diff --git a/matrix_client/errors.py b/matrix_client/errors.py index 98bd5eb0..91154bb8 100644 --- a/matrix_client/errors.py +++ b/matrix_client/errors.py @@ -53,3 +53,18 @@ class MatrixNoEncryptionError(MatrixError): def __init__(self, content=""): super(MatrixNoEncryptionError, self).__init__(content) + + +class E2EUnknownDevices(Exception): + """The room contained unknown devices when sending a message. + + Args: + user_devices (dict): A map from user_id to a list of Device objects, + containing the unknown devices for that user. + """ + + def __init__(self, user_devices): + super(Exception, self).__init__( + "The room contains unknown devices which have not been verified. They can " + "be inspected via the 'user_devices' attribute of this exception.") + self.user_devices = user_devices diff --git a/test/crypto/crypto_store_test.py b/test/crypto/crypto_store_test.py index d1105080..7567b96f 100644 --- a/test/crypto/crypto_store_test.py +++ b/test/crypto/crypto_store_test.py @@ -231,8 +231,9 @@ def test_megolm_outbound_persistence(self, device): # Verify the saved devices have been erased with the session assert not saved_session.devices - with pytest.raises(KeyError): - device.megolm_build_encrypted_event(self.room, {}) + room = Room(None, self.room_id) + with pytest.raises(AttributeError): + device.megolm_build_encrypted_event(room, {}) assert device.megolm_outbound_sessions[self.room_id].id == session.id self.store.remove_olm_account() diff --git a/test/crypto/olm_device_test.py b/test/crypto/olm_device_test.py index bd57ff89..432703ca 100644 --- a/test/crypto/olm_device_test.py +++ b/test/crypto/olm_device_test.py @@ -17,6 +17,7 @@ from matrix_client.client import MatrixClient from matrix_client.user import User from matrix_client.device import Device +from matrix_client.errors import E2EUnknownDevices from test.crypto.dummy_olm_device import OlmDevice from matrix_client.crypto.sessions import MegolmOutboundSession, MegolmInboundSession from test.response_examples import (example_key_upload_response, @@ -451,8 +452,9 @@ def test_megolm_start_session(self): self.device.device_keys[self.alice][self.alice_device_id] = self.alice_device self.device.device_list.tracked_user_ids.add(self.alice) self.device.olm_sessions[self.alice_curve_key] = [self.alice_olm_session] + user_devices = {self.alice: [self.alice_device_id]} - self.device.megolm_start_session(self.room) + self.device.megolm_start_session(self.room, user_devices) session = self.device.megolm_outbound_sessions[self.room_id] assert self.alice_device_id in session.devices @@ -481,14 +483,44 @@ def test_megolm_share_session_with_new_devices(self): self.device.olm_sessions[self.alice_curve_key] = [self.alice_olm_session] session = MegolmOutboundSession() self.device.megolm_outbound_sessions[self.room_id] = session + user_devices = {self.alice: [self.alice_device_id]} - self.device.megolm_share_session_with_new_devices(self.room, session) + self.device.megolm_share_session_with_new_devices( + self.room, user_devices, session) assert self.alice_device_id in session.devices assert len(responses.calls) == 1 - self.device.megolm_share_session_with_new_devices(self.room, session) + self.device.megolm_share_session_with_new_devices( + self.room, user_devices, session) assert len(responses.calls) == 1 + def test_megolm_get_recipients(self): + self.device.device_keys[self.alice][self.alice_device_id] = self.alice_device + + user_devices, _ = self.device.megolm_get_recipients(self.room) + assert user_devices == {self.alice: [self.alice_device_id]} + + self.device.megolm_outbound_sessions.clear() + session = MegolmOutboundSession() + self.device.megolm_outbound_sessions[self.room_id] = session + + user_devices, removed = self.device.megolm_get_recipients(self.room, session) + assert user_devices == {self.alice: [self.alice_device_id]} and not removed + + self.alice_device.blacklisted = True + _, removed = self.device.megolm_get_recipients(self.room, session) + assert not removed + session.add_device(self.alice_device_id) + _, removed = self.device.megolm_get_recipients(self.room, session) + assert removed and self.room_id not in self.device.megolm_outbound_sessions + self.alice_device.blacklisted = False + + self.room.verify_devices = True + with pytest.raises(E2EUnknownDevices) as e: + self.device.megolm_get_recipients(self.room) + assert e.value.user_devices == {self.alice: [self.alice_device]} + self.room.verify_devices = False + @responses.activate def test_megolm_build_encrypted_event(self): to_device_url = HOSTNAME + MATRIX_V2_API_PATH + '/sendToDevice/m.room.encrypted/1' From b5e1b7a87a1482a39424e64b54f79a88e447292f Mon Sep 17 00:00:00 2001 From: Valentin Deniaud Date: Sat, 28 Jul 2018 17:52:26 +0200 Subject: [PATCH 49/66] add device verification checks --- docs/source/matrix_client.rst | 5 +++ matrix_client/crypto/olm_device.py | 37 ++++++++++++++++--- matrix_client/crypto/verified_event.py | 2 ++ matrix_client/room.py | 2 +- test/crypto/olm_device_test.py | 49 +++++++++++++++++++++++++- 5 files changed, 88 insertions(+), 7 deletions(-) create mode 100644 matrix_client/crypto/verified_event.py diff --git a/docs/source/matrix_client.rst b/docs/source/matrix_client.rst index 79f0560c..16a6aedf 100644 --- a/docs/source/matrix_client.rst +++ b/docs/source/matrix_client.rst @@ -76,3 +76,8 @@ matrix_client.crypto :members: :undoc-members: :show-inheritance: + +.. automodule:: matrix_client.crypto.verified_event + :members: + :undoc-members: + :show-inheritance: diff --git a/matrix_client/crypto/olm_device.py b/matrix_client/crypto/olm_device.py index 8dc04506..ee91825f 100644 --- a/matrix_client/crypto/olm_device.py +++ b/matrix_client/crypto/olm_device.py @@ -12,6 +12,7 @@ from matrix_client.crypto.device_list import DeviceList from matrix_client.crypto.sessions import MegolmOutboundSession, MegolmInboundSession from matrix_client.crypto.crypto_store import CryptoStore +from matrix_client.crypto.verified_event import VerifiedEvent logger = logging.getLogger(__name__) @@ -334,8 +335,10 @@ def olm_decrypt_event(self, content, user_id): else: encrypted_message = olm.OlmMessage(payload['body']) - decrypted_event = self._olm_decrypt(encrypted_message, content['sender_key']) + sender_key = content['sender_key'] + decrypted_event = self._olm_decrypt(encrypted_message, sender_key) + signing_key = decrypted_event['keys']['ed25519'] if decrypted_event['sender'] != user_id: raise RuntimeError( 'Found user {} instead of sender {} in Olm plaintext {}.' @@ -352,6 +355,16 @@ def olm_decrypt_event(self, content, user_id): 'Found key {} instead of ours own ed25519 key {} in Olm plaintext {}.' .format(our_key, self.ed25519, decrypted_event) ) + try: + device = self.device_keys[user_id][decrypted_event['sender_device']] + except KeyError: + pass + else: + if device.verified: + if device.curve25519 != sender_key or device.ed25519 != signing_key: + raise RuntimeError( + 'Device keys mismatch between payload and /keys/query data.' + ) return decrypted_event @@ -739,11 +752,12 @@ def megolm_add_inbound_session(self, room_id, sender_key, signing_key, session_i def megolm_decrypt_event(self, event): """Decrypt a Megolm m.room.encrypted event. - The event is decrypted in-place, meaning its content and type properties are - overwritten by those of the decrypted event. - Args: - event (dict): The event to decrypt. + event (dict): The event to decrypt. It may be modified in the process. + + Returns: + The decrypted event, as a normal ``dict`` if unverified, or as a + :class:`.VerifiedEvent` if verified. """ content = event['content'] device_id = content['device_id'] @@ -777,6 +791,17 @@ def megolm_decrypt_event(self, event): 'with matching megolm session: {}.'.format(device_id, user_id, e)) + try: + device = self.device_keys[user_id][device_id] + except KeyError: + pass + else: + if device.verified: + if device.ed25519 != session.ed25519 or device.curve25519 != sender_key: + raise RuntimeError('Device keys mismatch in event sent by device {}.' + .format(device.device_id)) + event = VerifiedEvent(event) + try: properties = self.megolm_index_record[session.id][message_index] except KeyError: @@ -796,6 +821,8 @@ def megolm_decrypt_event(self, event): event['type'] = decrypted_event['type'] event['content'] = decrypted_event['content'] + return event + def sign_json(self, json): """Signs a JSON object. diff --git a/matrix_client/crypto/verified_event.py b/matrix_client/crypto/verified_event.py new file mode 100644 index 00000000..e156452b --- /dev/null +++ b/matrix_client/crypto/verified_event.py @@ -0,0 +1,2 @@ +class VerifiedEvent(dict): + pass diff --git a/matrix_client/room.py b/matrix_client/room.py index f04a7de1..8d8ee39d 100644 --- a/matrix_client/room.py +++ b/matrix_client/room.py @@ -345,7 +345,7 @@ def _put_event(self, event): if self.encrypted and self.client._encryption: if event['type'] == 'm.room.encrypted': try: - self.client.olm_device.megolm_decrypt_event(event) + event = self.client.olm_device.megolm_decrypt_event(event) except RuntimeError as e: logger.warning(e) self.events.append(event) diff --git a/test/crypto/olm_device_test.py b/test/crypto/olm_device_test.py index 432703ca..123485ed 100644 --- a/test/crypto/olm_device_test.py +++ b/test/crypto/olm_device_test.py @@ -13,6 +13,7 @@ import responses from matrix_client.crypto import olm_device +from matrix_client.crypto.verified_event import VerifiedEvent from matrix_client.api import MATRIX_V2_API_PATH from matrix_client.client import MatrixClient from matrix_client.user import User @@ -369,6 +370,27 @@ def test_olm_decrypt_event(self): # Now we can test self.device.olm_decrypt_event(encrypted_event, self.alice) + # Device verification + alice_device.verified = True + encrypted_event = alice_device.olm_build_encrypted_event( + 'example_type', {'content': 'test'}, self.user_id, self.device_id) + self.device.olm_decrypt_event(encrypted_event, self.alice) + + # The signing_key is wrong + encrypted_event = alice_device.olm_build_encrypted_event( + 'example_type', {'content': 'test'}, self.user_id, self.device_id) + self.device.device_keys[self.alice][self.alice_device_id]._ed25519 = 'wrong' + with pytest.raises(RuntimeError): + self.device.olm_decrypt_event(encrypted_event, self.alice) + + # We do not have the keys + encrypted_event = alice_device.olm_build_encrypted_event( + 'example_type', {'content': 'test'}, self.user_id, self.device_id) + self.device.device_keys[self.alice].clear() + self.device.olm_decrypt_event(encrypted_event, self.alice) + self.device.device_keys[self.alice][self.alice_device_id] = alice_device + alice_device.verified = False + # Type 1 Olm payload alice_device.olm_decrypt_event( self.device.olm_build_encrypted_event( @@ -667,7 +689,8 @@ def test_megolm_decrypt_event(self): with pytest.raises(RuntimeError): self.device.megolm_decrypt_event(event) - in_session = MegolmInboundSession(out_session.session_key, self.alice_ed_key) + session_key = out_session.session_key + in_session = MegolmInboundSession(session_key, self.alice_ed_key) sessions = self.device.megolm_inbound_sessions[self.room_id] sessions[self.alice_curve_key][in_session.id] = in_session @@ -690,6 +713,30 @@ def test_megolm_decrypt_event(self): event['event_id'] = 2 with pytest.raises(RuntimeError): self.device.megolm_decrypt_event(event) + event['event_id'] = 1 + + # Device verification + self.device.device_keys[self.alice][self.alice_device_id] = self.alice_device + event['content'] = content + # Unverified + self.device.megolm_decrypt_event(event) + assert event['content'] == plaintext['content'] + assert isinstance(event, dict) + + event['content'] = content + # Verified + self.alice_device.verified = True + decrypted_event = self.device.megolm_decrypt_event(event) + assert decrypted_event['content'] == plaintext['content'] + assert isinstance(decrypted_event, VerifiedEvent) + + in_session = MegolmInboundSession(session_key, self.alice_curve_key) + sessions = self.device.megolm_inbound_sessions[self.room_id] + sessions[self.alice_curve_key][in_session.id] = in_session + # Wrong signing key + with pytest.raises(RuntimeError): + self.device.megolm_decrypt_event(event) + self.alice_device.verified = False event['content']['algorithm'] = 'wrong' with pytest.raises(RuntimeError): From 0492e9521f84541e2288e50f0c3424e6eb2b82b0 Mon Sep 17 00:00:00 2001 From: Valentin Deniaud Date: Fri, 3 Aug 2018 17:00:58 +0200 Subject: [PATCH 50/66] persist device upon verification --- matrix_client/crypto/crypto_store.py | 12 +++----- matrix_client/crypto/device_list.py | 5 +-- matrix_client/crypto/olm_device.py | 5 ++- matrix_client/device.py | 46 ++++++++++++++++++++++++++-- test/crypto/crypto_store_test.py | 2 +- test/crypto/device_list_test.py | 2 +- test/crypto/olm_device_test.py | 8 ++--- 7 files changed, 59 insertions(+), 21 deletions(-) diff --git a/matrix_client/crypto/crypto_store.py b/matrix_client/crypto/crypto_store.py index bd16fbb1..bd9c7c85 100644 --- a/matrix_client/crypto/crypto_store.py +++ b/matrix_client/crypto/crypto_store.py @@ -494,13 +494,11 @@ def get_device_keys(self, api, user_devices, device_keys=None): device_keys.update(result) return result - @staticmethod - def _device_from_row(row, api): - return Device( - api, row['user_device_id'], ed25519_key=row['ed_key'], - curve25519_key=row['curve_key'], verified=row['verified'], - blacklisted=row['blacklisted'], ignored=row['ignored'] - ) + def _device_from_row(self, row, api): + return Device(api, row['user_id'], row['user_device_id'], database=self, + ed25519_key=row['ed_key'], curve25519_key=row['curve_key'], + verified=row['verified'], blacklisted=row['blacklisted'], + ignored=row['ignored']) def save_tracked_users(self, user_ids): """Saves tracked users. diff --git a/matrix_client/crypto/device_list.py b/matrix_client/crypto/device_list.py index 37d36b6a..8952f468 100644 --- a/matrix_client/crypto/device_list.py +++ b/matrix_client/crypto/device_list.py @@ -204,9 +204,10 @@ def _download_device_keys(self, user_devices, since_token=None): try: device = devices[device_id] except KeyError: - devices[device_id] = Device(self.api, device_id, + devices[device_id] = Device(self.api, user_id, device_id, curve25519_key=curve_key, - ed25519_key=signing_key) + ed25519_key=signing_key, + database=self.db) else: if device.ed25519 != signing_key: logger.warning('Ed25519 key has changed for device %s of ' diff --git a/matrix_client/crypto/olm_device.py b/matrix_client/crypto/olm_device.py index ee91825f..9bf9863d 100644 --- a/matrix_client/crypto/olm_device.py +++ b/matrix_client/crypto/olm_device.py @@ -102,9 +102,8 @@ def __init__(self, self.device_list = DeviceList(self, api, self.device_keys, self.db) self.megolm_index_record = defaultdict(dict) keys = self.olm_account.identity_keys - super(OlmDevice, self).__init__(self.api, - device_id, - ed25519_key=keys['ed25519'], + super(OlmDevice, self).__init__(self.api, self.user_id, device_id, + database=self.db, ed25519_key=keys['ed25519'], curve25519_key=keys['curve25519']) def upload_identity_keys(self): diff --git a/matrix_client/device.py b/matrix_client/device.py index 4fde19ea..bd8f1d4a 100644 --- a/matrix_client/device.py +++ b/matrix_client/device.py @@ -5,7 +5,9 @@ class Device(object): def __init__(self, api, + user_id, device_id, + database=None, display_name=None, last_seen_ip=None, last_seen_ts=None, @@ -15,13 +17,15 @@ def __init__(self, ed25519_key=None, curve25519_key=None): self.api = api + self.user_id = user_id self.device_id = device_id + self.database = database self.display_name = display_name self.last_seen_ts = last_seen_ts self.last_seen_ip = last_seen_ip - self.verified = verified - self.blacklisted = blacklisted - self.ignored = ignored + self._verified = verified + self._blacklisted = blacklisted + self._ignored = ignored self._ed25519 = ed25519_key self._curve25519 = curve25519_key @@ -45,6 +49,15 @@ def get_info(self): self.last_seen_ts = info.get('last_seen_ts') return True + def save_to_db(func): + def save(self, boolean): + if not self.ed25519: + raise ValueError('Changing this property is not allowed when the device ' + 'keys are unknown.') + func(self, boolean) + self.database.save_device_keys({self.user_id: {self.device_id: self}}) + return save + @property def ed25519(self): return self._ed25519 @@ -52,3 +65,30 @@ def ed25519(self): @property def curve25519(self): return self._curve25519 + + @property + def verified(self): + return self._verified + + @verified.setter + @save_to_db + def verified(self, boolean): + self._verified = boolean + + @property + def ignored(self): + return self._ignored + + @ignored.setter + @save_to_db + def ignored(self, boolean): + self._ignored = boolean + + @property + def blacklisted(self): + return self._blacklisted + + @blacklisted.setter + @save_to_db + def blacklisted(self, boolean): + self._blacklisted = boolean diff --git a/test/crypto/crypto_store_test.py b/test/crypto/crypto_store_test.py index 7567b96f..2834e30e 100644 --- a/test/crypto/crypto_store_test.py +++ b/test/crypto/crypto_store_test.py @@ -243,7 +243,7 @@ def test_megolm_outbound_persistence(self, device): def test_device_keys_persistence(self, device): user_devices = {self.user_id: [self.device_id]} device_keys = defaultdict(dict) - device.verified = True + device._verified = True self.store.load_device_keys(None, device_keys) assert not device_keys diff --git a/test/crypto/device_list_test.py b/test/crypto/device_list_test.py index 014575f9..bd0bb6df 100644 --- a/test/crypto/device_list_test.py +++ b/test/crypto/device_list_test.py @@ -95,7 +95,7 @@ def test_download_device_keys(self): assert download_device_keys(user_devices) req = json.loads(responses.calls[0].request.body) assert req['device_keys'] == {self.alice: [], bob: [], self.user_id: []} - device = Device(self.cli.api, 'JLAFKJWSCS', + device = Device(self.cli.api, self.alice, 'JLAFKJWSCS', database=DummyStore, curve25519_key='3C5BFWi2Y8MaVvjM8M22DBmh24PmgR0nPvJOIArzgyI', ed25519_key='VzJIYXQ85u19z2ZpEeLLVu8hUKTCE0VXYUn4IY4iFcA') expected_device_keys = { diff --git a/test/crypto/olm_device_test.py b/test/crypto/olm_device_test.py index 123485ed..a315549d 100644 --- a/test/crypto/olm_device_test.py +++ b/test/crypto/olm_device_test.py @@ -19,7 +19,7 @@ from matrix_client.user import User from matrix_client.device import Device from matrix_client.errors import E2EUnknownDevices -from test.crypto.dummy_olm_device import OlmDevice +from test.crypto.dummy_olm_device import OlmDevice, DummyStore from matrix_client.crypto.sessions import MegolmOutboundSession, MegolmInboundSession from test.response_examples import (example_key_upload_response, example_claim_keys_response, @@ -39,8 +39,8 @@ class TestOlmDevice: alice_device_id = 'JLAFKJWSCS' alice_curve_key = 'mmFRSHuJVq3aTudx3KB3w5ZvSFQhgEcy8d+m+vkEfUQ' alice_ed_key = '4VjV3OhFUxWFAcO5YOaQVmTIn29JdRmtNh9iAxoyhkc' - alice_device = Device(cli.api, alice_device_id, curve25519_key=alice_curve_key, - ed25519_key=alice_ed_key) + alice_device = Device(cli.api, alice, alice_device_id, database=DummyStore(), + curve25519_key=alice_curve_key, ed25519_key=alice_ed_key) alice_olm_session = olm.OutboundSession( device.olm_account, alice_curve_key, alice_curve_key) room = cli._mkroom(room_id) @@ -455,7 +455,7 @@ def test_megolm_share_session(self): self.device.olm_sessions.clear() self.device.device_keys[self.alice][self.alice_device_id] = self.alice_device self.device.device_keys['dummy']['dummy'] = \ - Device(self.cli.api, 'dummy', curve25519_key='a', ed25519_key='a') + Device(self.cli.api, 'dummy', 'dummy', curve25519_key='a', ed25519_key='a') user_devices = {self.alice: [self.alice_device_id], 'dummy': ['dummy']} session = MegolmOutboundSession() From ba0950a696f75ca072d43341f124afb12d3faefa Mon Sep 17 00:00:00 2001 From: Valentin Deniaud Date: Tue, 7 Aug 2018 16:27:19 +0200 Subject: [PATCH 51/66] add Device class docstring --- matrix_client/device.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/matrix_client/device.py b/matrix_client/device.py index bd8f1d4a..2d0aff63 100644 --- a/matrix_client/device.py +++ b/matrix_client/device.py @@ -2,6 +2,23 @@ class Device(object): + """Represents a Matrix device, belonging to a user. + + Args: + api (MatrixHttpApi): The api object used to make requests. + user_id (str): User ID of this device's owner. + device_id (str): The device ID. + display_name (str): Optional. The display name of this device, if any. + last_seen_ip (str): Optional. The IP address where this device was last seen. + last_seen_ts (int): Optional. The timestamp (in milliseconds since the unix + epoch) when this device was last seen. + verified, blacklisted, ignored (bool): Optional. Device verification info. + ed25519_key (str): Optional. The Ed25519 fingerprint key of this device. The + corresponding attribute ``ed25519`` cannot be changed after initialisation. + curve25519_key (str): Optional. The Curve25519 fingerprint key of this device. The + corresponding attribute ``curve25519`` cannot be changed after initialisation. + database (CryptoStore): Optional. Allows to save device verification info. + """ def __init__(self, api, From 93d7b0e24b4d7516e25cd2e2ac2902de64037e45 Mon Sep 17 00:00:00 2001 From: Valentin Deniaud Date: Tue, 7 Aug 2018 16:33:55 +0200 Subject: [PATCH 52/66] add Device class tests --- test/crypto/crypto_store_test.py | 12 +++++++++ test/device_test.py | 46 ++++++++++++++++++++++++++++++++ 2 files changed, 58 insertions(+) create mode 100644 test/device_test.py diff --git a/test/crypto/crypto_store_test.py b/test/crypto/crypto_store_test.py index 2834e30e..42bc859d 100644 --- a/test/crypto/crypto_store_test.py +++ b/test/crypto/crypto_store_test.py @@ -265,6 +265,13 @@ def test_device_keys_persistence(self, device): device.curve25519 assert device_keys[self.user_id][self.device_id].verified + # Test device verification persistence + device.verified = False + device.ignored = True + devices = self.store.get_device_keys(None, user_devices)[self.user_id] + assert not devices[self.device_id].verified + assert devices[self.device_id].ignored + # Test [] wildcard devices = self.store.get_device_keys(None, {self.user_id: []})[self.user_id] assert devices[self.device_id].curve25519 == device.curve25519 @@ -286,6 +293,11 @@ def test_device_keys_persistence(self, device): assert device_keys[self.user_id][self.device_id].curve25519 == device.curve25519 assert device_keys[user_id][device_id].curve25519 == device.curve25519 + # Try to verify a device that has no keys + device._ed25519 = None + with pytest.raises(ValueError): + device.verified = False + self.store.remove_olm_account() assert not self.store.get_device_keys(None, user_devices) diff --git a/test/device_test.py b/test/device_test.py new file mode 100644 index 00000000..ab48f3bc --- /dev/null +++ b/test/device_test.py @@ -0,0 +1,46 @@ +import pytest +import responses + +from matrix_client.api import MATRIX_V2_API_PATH +from matrix_client.client import MatrixClient +from matrix_client.errors import MatrixRequestError +from matrix_client.device import Device + +HOSTNAME = 'http://localhost' + + +class TestDevice(object): + + cli = MatrixClient(HOSTNAME) + user_id = '@test:localhost' + device_id = 'AUIETRSN' + + @pytest.fixture() + def device(self): + return Device(self.cli.api, self.user_id, self.device_id) + + @responses.activate + def test_get_info(self, device): + device_url = HOSTNAME + MATRIX_V2_API_PATH + '/devices/' + self.device_id + display_name = 'android' + last_seen_ip = '1.2.3.4' + last_seen_ts = 1474491775024 + resp = { + "device_id": self.device_id, + "display_name": display_name, + "last_seen_ip": last_seen_ip, + "last_seen_ts": last_seen_ts + } + responses.add(responses.GET, device_url, json=resp) + + assert device.get_info() + assert device.display_name == display_name + assert device.last_seen_ip == last_seen_ip + assert device.last_seen_ts == last_seen_ts + + responses.replace(responses.GET, device_url, status=404) + assert not device.get_info() + + responses.replace(responses.GET, device_url, status=500) + with pytest.raises(MatrixRequestError): + device.get_info() From cdffc9f63075eebc92245fc957b9c2afcd03c0e7 Mon Sep 17 00:00:00 2001 From: Valentin Deniaud Date: Thu, 9 Aug 2018 18:11:16 +0200 Subject: [PATCH 53/66] add get_fingerprint method to client --- matrix_client/client.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/matrix_client/client.py b/matrix_client/client.py index 597d1ef6..b1a71f5e 100644 --- a/matrix_client/client.py +++ b/matrix_client/client.py @@ -722,3 +722,12 @@ def remove_room_alias(self, room_alias): return True except MatrixRequestError: return False + + def get_fingerprint(self): + """Get the fingerprint of the current device. + + This is used when verifying devices. + """ + if not self._encryption: + raise ValueError("Encryption is not enabled, this device has no fingerprint.") + return self.olm_device.ed25519 From 0fdfdd636fd8af0d848050487ccbb12f05119cf6 Mon Sep 17 00:00:00 2001 From: Valentin Deniaud Date: Sat, 15 Sep 2018 00:02:18 +0200 Subject: [PATCH 54/66] fixup! add devices attribute to User --- matrix_client/client.py | 2 +- matrix_client/room.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/matrix_client/client.py b/matrix_client/client.py index b1a71f5e..4edf4ca4 100644 --- a/matrix_client/client.py +++ b/matrix_client/client.py @@ -705,7 +705,7 @@ def get_user(self, user_id): """ warn("get_user is deprecated. Directly instantiate a User instead.", DeprecationWarning) - return User(self.api, user_id) + return User(self, user_id) # TODO: move to Room class def remove_room_alias(self, room_alias): diff --git a/matrix_client/room.py b/matrix_client/room.py index 8d8ee39d..5ec57d7a 100644 --- a/matrix_client/room.py +++ b/matrix_client/room.py @@ -545,7 +545,7 @@ def _add_member(self, user_id, displayname=None): if user_id in self.client.users: self._members[user_id] = self.client.users[user_id] return - self._members[user_id] = User(self.client.api, user_id, displayname) + self._members[user_id] = User(self.client, user_id, displayname) self.client.users[user_id] = self._members[user_id] def backfill_previous_messages(self, reverse=False, limit=10): From d11a0e7786e0e39155a5e1ebb75161c849647c57 Mon Sep 17 00:00:00 2001 From: Valentin Deniaud Date: Thu, 2 Aug 2018 15:45:11 +0200 Subject: [PATCH 55/66] refactor crypto tests Make those tests more maintanable by getting rid of unwanted side effects and useless cleanups by using pytest.fixture, along with some readability improvements. --- test/crypto/device_list_test.py | 171 ++++++------ test/crypto/olm_device_test.py | 474 +++++++++++++++----------------- 2 files changed, 307 insertions(+), 338 deletions(-) diff --git a/test/crypto/device_list_test.py b/test/crypto/device_list_test.py index bd0bb6df..1d53c45b 100644 --- a/test/crypto/device_list_test.py +++ b/test/crypto/device_list_test.py @@ -26,15 +26,20 @@ class TestDeviceList: alice = '@alice:example.com' room_id = '!test:example.com' device_id = 'AUIETSRN' - device = OlmDevice(cli.api, user_id, device_id) - device_list = device.device_list - signing_key = device.olm_account.identity_keys['ed25519'] query_url = HOSTNAME + MATRIX_V2_API_PATH + '/keys/query' + @pytest.fixture() + def device(self): + return OlmDevice(self.cli.api, self.user_id, self.device_id) + + @pytest.fixture() + def device_list(self, device): + return device.device_list + @responses.activate - def test_download_device_keys(self): + def test_download_device_keys(self, device, device_list): # The method we want to test - download_device_keys = self.device_list._download_device_keys + download_device_keys = device_list._download_device_keys bob = '@bob:example.com' eve = '@eve:example.com' user_devices = {self.alice: [], bob: [], self.user_id: []} @@ -44,6 +49,22 @@ def test_download_device_keys(self): resp = example_key_query_response responses.add(responses.POST, self.query_url, json=resp) + assert download_device_keys(user_devices) + req = json.loads(responses.calls[0].request.body) + assert req['device_keys'] == {self.alice: [], bob: [], self.user_id: []} + alice_device = Device( + self.cli.api, self.alice, 'JLAFKJWSCS', database=DummyStore, + curve25519_key='3C5BFWi2Y8MaVvjM8M22DBmh24PmgR0nPvJOIArzgyI', + ed25519_key='VzJIYXQ85u19z2ZpEeLLVu8hUKTCE0VXYUn4IY4iFcA' + ) + expected_device_keys = { + self.alice: { + 'JLAFKJWSCS': device + } + } + assert device.device_keys[self.alice]['JLAFKJWSCS'].curve25519 == \ + alice_device.curve25519 + # Still correct, but Alice's identity key has changed resp = deepcopy(example_key_query_response) new_id_key = 'ijxGZqwB/UvMtKABdaCdrI0OtQI6NhHBYiknoCkdWng' @@ -52,22 +73,33 @@ def test_download_device_keys(self): payload['signatures'][self.alice]['ed25519:JLAFKJWSCS'] = \ ('D9oLtYefMIr4StiHTIzn3+bhtPCfrZNDU9jsUbMu3MicfZLl4d8WlYn3TPmbwDi8XMGcT' 'nNnqfdi/tYUPvKfCA') - responses.add(responses.POST, self.query_url, json=resp) + responses.replace(responses.POST, self.query_url, json=resp) + + # The Curve25519 key should get updated + assert download_device_keys(user_devices) + expected_device_keys[self.alice]['JLAFKJWSCS']._curve25519 = new_id_key + assert device.device_keys[self.alice]['JLAFKJWSCS'].curve25519 == \ + device.curve25519 # Still correct, but Alice's signing key has changed - alice_device = OlmDevice(self.cli.api, self.alice, 'JLAFKJWSCS') + alice_olm_device = OlmDevice(self.cli.api, self.alice, 'JLAFKJWSCS') resp = deepcopy(example_key_query_response) resp['device_keys'][self.alice]['JLAFKJWSCS']['keys']['ed25519:JLAFKJWSCS'] = \ - alice_device.ed25519 + alice_olm_device.ed25519 resp['device_keys'][self.alice]['JLAFKJWSCS'] = \ - alice_device.sign_json(resp['device_keys'][self.alice]['JLAFKJWSCS']) - responses.add(responses.POST, self.query_url, json=resp) + alice_olm_device.sign_json(resp['device_keys'][self.alice]['JLAFKJWSCS']) + responses.replace(responses.POST, self.query_url, json=resp) + + # The Ed25519 key should not get updated + assert not download_device_keys(user_devices) + assert device.device_keys[self.alice]['JLAFKJWSCS'].ed25519 == \ + alice_device.ed25519 # Response containing an unknown user resp = deepcopy(example_key_query_response) user_device = resp['device_keys'].pop(self.alice) resp['device_keys'][eve] = user_device - responses.add(responses.POST, self.query_url, json=resp) + responses.replace(responses.POST, self.query_url, json=resp) # Response with an invalid signature resp = deepcopy(example_key_query_response) @@ -90,38 +122,12 @@ def test_download_device_keys(self): # And one more by adding ourself resp['device_keys'][self.user_id] = {self.device_id: 'dummy'} responses.add(responses.POST, self.query_url, json=resp) + device.device_keys.clear() - self.device.device_keys.clear() - assert download_device_keys(user_devices) - req = json.loads(responses.calls[0].request.body) - assert req['device_keys'] == {self.alice: [], bob: [], self.user_id: []} - device = Device(self.cli.api, self.alice, 'JLAFKJWSCS', database=DummyStore, - curve25519_key='3C5BFWi2Y8MaVvjM8M22DBmh24PmgR0nPvJOIArzgyI', - ed25519_key='VzJIYXQ85u19z2ZpEeLLVu8hUKTCE0VXYUn4IY4iFcA') - expected_device_keys = { - self.alice: { - 'JLAFKJWSCS': device - } - } - assert self.device.device_keys[self.alice]['JLAFKJWSCS'].curve25519 == \ - device.curve25519 - - # Different curve25519, key should get updated - assert download_device_keys(user_devices) - expected_device_keys[self.alice]['JLAFKJWSCS']._curve25519 = new_id_key - assert self.device.device_keys[self.alice]['JLAFKJWSCS'].curve25519 == \ - device.curve25519 - - # Different ed25519, key should not get updated - assert not download_device_keys(user_devices) - assert self.device.device_keys[self.alice]['JLAFKJWSCS'].ed25519 == \ - device.ed25519 - - self.device.device_keys.clear() # All the remaining responses are wrong and we should not add the key for _ in range(4): assert not download_device_keys(user_devices) - assert self.device.device_keys == {} + assert device.device_keys == {} assert len(responses.calls) == 7 @@ -172,84 +178,79 @@ def error_on_first_download(user_devices, since_token=None): assert not thread.is_alive() @responses.activate - def test_get_room_device_keys(self): - self.device_list.tracked_user_ids.clear() + def test_get_room_device_keys(self, device_list): room = self.cli._mkroom(self.room_id) room._members[self.alice] = User(self.cli.api, self.alice) responses.add(responses.POST, self.query_url, json=example_key_query_response) # Blocking - self.device_list.get_room_device_keys(room) - assert self.device_list.tracked_user_ids == {self.alice} - assert self.device_list.device_keys[self.alice]['JLAFKJWSCS'] + device_list.get_room_device_keys(room) + assert device_list.tracked_user_ids == {self.alice} + assert device_list.device_keys[self.alice]['JLAFKJWSCS'] # Same, but we already track the user - self.device_list.get_room_device_keys(room) + device_list.get_room_device_keys(room) # Non-blocking - self.device_list.tracked_user_ids.clear() + device_list.tracked_user_ids.clear() # We have to block for testing purposes, though - self.device_list.update_thread.event.clear() - self.device_list.get_room_device_keys(room, blocking=False) - self.device_list.update_thread.event.wait() + device_list.update_thread.event.clear() + device_list.get_room_device_keys(room, blocking=False) + device_list.update_thread.event.wait() # Same, but we already track the user - self.device_list.get_room_device_keys(room, blocking=False) + device_list.get_room_device_keys(room, blocking=False) @responses.activate - def test_track_users(self): - self.device_list.tracked_user_ids.clear() + def test_track_users(self, device_list): responses.add(responses.POST, self.query_url, json=example_key_query_response) - self.device_list.update_thread.event.clear() - self.device_list.track_users({self.alice}) - self.device_list.update_thread.event.wait() - assert self.device_list.tracked_user_ids == {self.alice} + device_list.update_thread.event.clear() + device_list.track_users({self.alice}) + device_list.update_thread.event.wait() + assert device_list.tracked_user_ids == {self.alice} assert len(responses.calls) == 1 # Same, but we are already tracking Alice - self.device_list.track_users({self.alice}) + device_list.track_users({self.alice}) assert len(responses.calls) == 1 - def test_stop_tracking_users(self): - self.device_list.tracked_user_ids.clear() - self.device_list.tracked_user_ids.add(self.alice) - self.device_list.outdated_user_ids.clear() - self.device_list.outdated_user_ids.add(self.alice) + def test_stop_tracking_users(self, device_list): + device_list.tracked_user_ids.add(self.alice) + device_list.outdated_user_ids.add(self.alice) - self.device_list.stop_tracking_users({self.alice}) + device_list.stop_tracking_users({self.alice}) - assert not self.device_list.tracked_user_ids - assert not self.device_list.outdated_user_ids + assert not device_list.tracked_user_ids + assert not device_list.outdated_user_ids - def test_pending_users(self): + def test_pending_users(self, device_list): # Say Alice is already tracked to avoid triggering dowload process - self.device_list.tracked_user_ids.add(self.alice) + device_list.tracked_user_ids.add(self.alice) - self.device_list.track_user_no_download(self.alice) - assert self.alice in self.device_list.pending_outdated_user_ids + device_list.track_user_no_download(self.alice) + assert self.alice in device_list.pending_outdated_user_ids - self.device_list.track_pending_users() - assert self.alice not in self.device_list.pending_outdated_user_ids + device_list.track_pending_users() + assert self.alice not in device_list.pending_outdated_user_ids @responses.activate - def test_update_user_device_keys(self): - self.device_list.tracked_user_ids.clear() + def test_update_user_device_keys(self, device_list): responses.add(responses.POST, self.query_url, json=example_key_query_response) - self.device_list.update_user_device_keys({self.alice}) + device_list.update_user_device_keys({self.alice}) assert len(responses.calls) == 0 - self.device_list.tracked_user_ids.add(self.alice) + device_list.tracked_user_ids.add(self.alice) - self.device_list.update_thread.event.clear() - self.device_list.update_user_device_keys({self.alice}, since_token='dummy') - self.device_list.update_thread.event.wait() + device_list.update_thread.event.clear() + device_list.update_user_device_keys({self.alice}, since_token='dummy') + device_list.update_thread.event.wait() assert len(responses.calls) == 1 @responses.activate - def test_update_after_restart(self): + def test_update_after_restart(self, device_list): keys_changes_url = HOSTNAME + MATRIX_V2_API_PATH + '/keys/changes' class DB(DummyStore): @@ -258,20 +259,18 @@ def __getattribute__(self, name): if name == 'get_sync_token': return lambda: 'test' return super(DB, self).__getattribute__(name) - db = self.device_list.db # First launch, no sync token - self.device_list.update_after_restart('test') + device_list.update_after_restart('test') - self.device_list.db = DB() + device_list.db = DB() responses.add(responses.GET, keys_changes_url, json={}) - self.device_list.update_after_restart('test') + device_list.update_after_restart('test') resp = {'left': 'test', 'changed': self.user_id} responses.replace(responses.GET, keys_changes_url, json=resp) - self.device_list.tracked_user_ids.clear() - self.device_list.update_after_restart('test') - self.device_list.db = db + device_list.tracked_user_ids.clear() + device_list.update_after_restart('test') def test_outdated_users_set(): diff --git a/test/crypto/olm_device_test.py b/test/crypto/olm_device_test.py index a315549d..d21dd8be 100644 --- a/test/crypto/olm_device_test.py +++ b/test/crypto/olm_device_test.py @@ -33,22 +33,53 @@ class TestOlmDevice: user_id = '@user:matrix.org' room_id = '!test:example.com' device_id = 'QBUAZIFURK' - device = OlmDevice(cli.api, user_id, device_id) - signing_key = device.olm_account.identity_keys['ed25519'] alice = '@alice:example.com' alice_device_id = 'JLAFKJWSCS' alice_curve_key = 'mmFRSHuJVq3aTudx3KB3w5ZvSFQhgEcy8d+m+vkEfUQ' alice_ed_key = '4VjV3OhFUxWFAcO5YOaQVmTIn29JdRmtNh9iAxoyhkc' alice_device = Device(cli.api, alice, alice_device_id, database=DummyStore(), curve25519_key=alice_curve_key, ed25519_key=alice_ed_key) - alice_olm_session = olm.OutboundSession( - device.olm_account, alice_curve_key, alice_curve_key) room = cli._mkroom(room_id) room._members[alice] = User(cli.api, alice) - # allow to_device api call to work well with responses - device.api._make_txn_id = lambda: 1 - def test_sign_json(self): + upload_url = HOSTNAME + MATRIX_V2_API_PATH + '/keys/upload' + claim_url = HOSTNAME + MATRIX_V2_API_PATH + '/keys/claim' + to_device_url = HOSTNAME + MATRIX_V2_API_PATH + '/sendToDevice/m.room.encrypted/1' + + @pytest.fixture() + def device(self): + device = OlmDevice(self.cli.api, self.user_id, self.device_id) + # allow to_device api call to work well with responses + device.api._make_txn_id = lambda: 1 + return device + + @pytest.fixture() + def signing_key(self, device): + return device.olm_account.identity_keys['ed25519'] + + @pytest.fixture() + def olm_session_with_alice(self, device): + session = olm.OutboundSession(device.olm_account, self.alice_curve_key, + self.alice_curve_key) + device.device_keys[self.alice][self.alice_device_id] = self.alice_device + device.olm_sessions[self.alice_curve_key] = [session] + + @pytest.fixture() + def alice_olm_device(self, device): + """Establish an Olm session from Alice to us, and return Alice's Olm device.""" + alice_device = OlmDevice(device.api, self.alice, self.alice_device_id) + alice_device.device_keys[self.user_id][self.device_id] = device + device.device_keys[self.alice][self.alice_device_id] = alice_device + + device.olm_account.generate_one_time_keys(1) + otk = next(iter(device.olm_account.one_time_keys['curve25519'].values())) + device.olm_account.mark_keys_as_published() + sender_key = device.curve25519 + session = olm.OutboundSession(alice_device.olm_account, sender_key, otk) + alice_device.olm_sessions[sender_key] = [session] + return alice_device + + def test_sign_json(self, device): example_payload = { "name": "example.org", "unsigned": { @@ -57,14 +88,14 @@ def test_sign_json(self): } saved_payload = deepcopy(example_payload) - signed_payload = self.device.sign_json(example_payload) + signed_payload = device.sign_json(example_payload) signature = signed_payload.pop('signatures') # We should not have modified the payload besides the signatures key assert example_payload == saved_payload - key_id = 'ed25519:' + self.device_id + key_id = 'ed25519:' + device.device_id assert signature[self.user_id][key_id] - def test_verify_json(self): + def test_verify_json(self, device): example_payload = { "test": "test", "unsigned": { @@ -80,54 +111,53 @@ def test_verify_json(self): saved_payload = deepcopy(example_payload) signing_key = "WQF5z9b4DV1DANI5HUMJfhTIDvJs1jkoGTLY6AQdjF0" - assert self.device.verify_json(example_payload, signing_key, self.user_id, - self.device_id) + assert device.verify_json(example_payload, signing_key, self.user_id, + device.device_id) # We should not have modified the payload assert example_payload == saved_payload # Try to verify an object that has been tampered with example_payload['test'] = 'test1' - assert not self.device.verify_json(example_payload, signing_key, self.user_id, - self.device_id) + assert not device.verify_json(example_payload, signing_key, self.user_id, + device.device_id) # Try to verify invalid payloads example_payload['signatures'].pop(self.user_id) - assert not self.device.verify_json(example_payload, signing_key, self.user_id, - self.device_id) + assert not device.verify_json(example_payload, signing_key, self.user_id, + device.device_id) example_payload.pop('signatures') - assert not self.device.verify_json(example_payload, signing_key, self.user_id, - self.device_id) + assert not device.verify_json(example_payload, signing_key, self.user_id, + device.device_id) - def test_sign_verify(self): + def test_sign_verify(self, device, signing_key): example_payload = { "name": "example.org", } - signed_payload = self.device.sign_json(example_payload) - assert self.device.verify_json(signed_payload, self.signing_key, self.user_id, - self.device_id) + signed_payload = device.sign_json(example_payload) + assert device.verify_json(signed_payload, signing_key, self.user_id, + device.device_id) @responses.activate - def test_upload_identity_keys(self): - upload_url = HOSTNAME + MATRIX_V2_API_PATH + '/keys/upload' - self.device.one_time_keys_manager.server_counts = {} + def test_upload_identity_keys(self, device, signing_key): + device.one_time_keys_manager.server_counts = {} resp = deepcopy(example_key_upload_response) - responses.add(responses.POST, upload_url, json=resp) + responses.add(responses.POST, self.upload_url, json=resp) - assert self.device.upload_identity_keys() is None - assert self.device.one_time_keys_manager.server_counts == \ + assert device.upload_identity_keys() is None + assert device.one_time_keys_manager.server_counts == \ resp['one_time_key_counts'] req_device_keys = json.loads(responses.calls[0].request.body)['device_keys'] assert req_device_keys['user_id'] == self.user_id assert req_device_keys['device_id'] == self.device_id - assert req_device_keys['algorithms'] == self.device._algorithms + assert req_device_keys['algorithms'] == device._algorithms assert 'keys' in req_device_keys assert 'signatures' in req_device_keys - assert self.device.verify_json(req_device_keys, self.signing_key, self.user_id, - self.device_id) + assert device.verify_json(req_device_keys, signing_key, self.user_id, + self.device_id) @pytest.mark.parametrize('proportion', [-1, 2]) def test_upload_identity_keys_invalid(self, proportion): @@ -140,11 +170,10 @@ def test_upload_identity_keys_invalid(self, proportion): @responses.activate @pytest.mark.parametrize('proportion', [0, 1, 0.5, 0.33]) def test_upload_one_time_keys(self, proportion): - upload_url = HOSTNAME + MATRIX_V2_API_PATH + '/keys/upload' resp = deepcopy(example_key_upload_response) counts = resp['one_time_key_counts'] counts['curve25519'] = counts['signed_curve25519'] = 10 - responses.add(responses.POST, upload_url, json=resp) + responses.add(responses.POST, self.upload_url, json=resp) device = OlmDevice( self.cli.api, self.user_id, self.device_id, signed_keys_proportion=proportion) @@ -177,41 +206,37 @@ def test_upload_one_time_keys(self, proportion): device.device_id) @responses.activate - def test_upload_one_time_keys_enough(self): - upload_url = HOSTNAME + MATRIX_V2_API_PATH + '/keys/upload' - self.device.one_time_keys_manager.server_counts = {} - limit = self.device.olm_account.max_one_time_keys // 2 + def test_upload_one_time_keys_enough(self, device): + device.one_time_keys_manager.server_counts = {} + limit = device.olm_account.max_one_time_keys // 2 resp = {'one_time_key_counts': {'signed_curve25519': limit}} - responses.add(responses.POST, upload_url, json=resp) + responses.add(responses.POST, self.upload_url, json=resp) - assert not self.device.upload_one_time_keys() + assert not device.upload_one_time_keys() @responses.activate - def test_upload_one_time_keys_force_update(self): - upload_url = HOSTNAME + MATRIX_V2_API_PATH + '/keys/upload' - self.device.one_time_keys_manager.server_counts = {'curve25519': 10} + def test_upload_one_time_keys_force_update(self, device): + device.one_time_keys_manager.server_counts = {'curve25519': 10} resp = deepcopy(example_key_upload_response) - responses.add(responses.POST, upload_url, json=resp) + responses.add(responses.POST, self.upload_url, json=resp) - self.device.upload_one_time_keys() + device.upload_one_time_keys() assert len(responses.calls) == 1 - self.device.upload_one_time_keys(force_update=True) + device.upload_one_time_keys(force_update=True) assert len(responses.calls) == 3 @responses.activate @pytest.mark.parametrize('count,should_upload', [(0, True), (25, False), (4, True)]) - def test_update_one_time_key_counts(self, count, should_upload): - upload_url = HOSTNAME + MATRIX_V2_API_PATH + '/keys/upload' - responses.add(responses.POST, upload_url, json={'one_time_key_counts': {}}) - self.device.one_time_keys_manager.target_counts['signed_curve25519'] = 50 - self.device.one_time_keys_manager.server_counts.clear() + def test_update_one_time_key_counts(self, device, count, should_upload): + responses.add(responses.POST, self.upload_url, json={'one_time_key_counts': {}}) + device.one_time_keys_manager.target_counts['signed_curve25519'] = 50 count_dict = {} if count: count_dict['signed_curve25519'] = count - self.device.update_one_time_key_counts(count_dict) + device.update_one_time_key_counts(count_dict) if should_upload: if count: @@ -233,107 +258,100 @@ def test_invalid_keys_threshold(self, threshold): keys_threshold=threshold) @responses.activate - def test_olm_start_sessions(self): - claim_url = HOSTNAME + MATRIX_V2_API_PATH + '/keys/claim' - responses.add(responses.POST, claim_url, json=example_claim_keys_response) - self.device.olm_sessions.clear() - self.device.device_keys.clear() + def test_olm_start_sessions(self, device): + responses.add(responses.POST, self.claim_url, json=example_claim_keys_response) user_devices = {self.alice: {self.alice_device_id}} # We don't have alice's keys - self.device.olm_start_sessions(user_devices) - assert not self.device.olm_sessions[self.alice_curve_key] + device.olm_start_sessions(user_devices) + assert not device.olm_sessions[self.alice_curve_key] # Cover logging part olm_device.logger.setLevel(logging.WARNING) # Now should be good - self.device.device_keys[self.alice][self.alice_device_id] = self.alice_device - self.device.olm_start_sessions(user_devices) - assert self.device.olm_sessions[self.alice_curve_key] + device.device_keys[self.alice][self.alice_device_id] = self.alice_device + device.olm_start_sessions(user_devices) + assert device.olm_sessions[self.alice_curve_key] # With failures and wrong signature - self.device.olm_sessions.clear() + device.olm_sessions.clear() payload = deepcopy(example_claim_keys_response) payload['failures'] = {'dummy': 1} key = payload['one_time_keys'][self.alice][self.alice_device_id] key['signed_curve25519:AAAAAQ']['test'] = 1 - responses.replace(responses.POST, claim_url, json=payload) + responses.replace(responses.POST, self.claim_url, json=payload) - self.device.olm_start_sessions(user_devices) - assert not self.device.olm_sessions[self.alice_curve_key] + device.olm_start_sessions(user_devices) + assert not device.olm_sessions[self.alice_curve_key] # Missing requested user and devices user_devices[self.alice].add('test') user_devices['test'] = 'test' - self.device.olm_start_sessions(user_devices) + device.olm_start_sessions(user_devices) @responses.activate - def test_olm_build_encrypted_event(self): - self.device.device_keys.clear() - self.device.olm_sessions.clear() + def test_olm_build_encrypted_event(self, device): event_content = {'dummy': 'example'} # We don't have Alice's keys with pytest.raises(RuntimeError): - self.device.olm_build_encrypted_event( + device.olm_build_encrypted_event( 'm.text', event_content, self.alice, self.alice_device_id) # We don't have a session with Alice - self.device.device_keys[self.alice][self.alice_device_id] = self.alice_device + device.device_keys[self.alice][self.alice_device_id] = self.alice_device with pytest.raises(RuntimeError): - self.device.olm_build_encrypted_event( + device.olm_build_encrypted_event( 'm.text', event_content, self.alice, self.alice_device_id) - claim_url = HOSTNAME + MATRIX_V2_API_PATH + '/keys/claim' - responses.add(responses.POST, claim_url, json=example_claim_keys_response) + responses.add(responses.POST, self.claim_url, json=example_claim_keys_response) user_devices = {self.alice: {self.alice_device_id}} - self.device.olm_start_sessions(user_devices) - assert self.device.olm_build_encrypted_event( + device.olm_start_sessions(user_devices) + assert device.olm_build_encrypted_event( 'm.text', event_content, self.alice, self.alice_device_id) - def test_olm_decrypt(self): - self.device.olm_sessions.clear() + def test_olm_decrypt(self, device): # Since this method doesn't care about high-level event formatting, we will # generate things at low level - our_account = self.device.olm_account + our_account = device.olm_account # Alice needs to start a session with us alice = olm.Account() sender_key = alice.identity_keys['curve25519'] our_account.generate_one_time_keys(1) otk = next(iter(our_account.one_time_keys['curve25519'].values())) - self.device.olm_account.mark_keys_as_published() + device.olm_account.mark_keys_as_published() session = olm.OutboundSession(alice, our_account.identity_keys['curve25519'], otk) plaintext = {"test": "test"} message = session.encrypt(json.dumps(plaintext)) - assert self.device._olm_decrypt(message, sender_key) == plaintext + assert device._olm_decrypt(message, sender_key) == plaintext # New pre-key message, but the session exists this time message = session.encrypt(json.dumps(plaintext)) - assert self.device._olm_decrypt(message, sender_key) == plaintext + assert device._olm_decrypt(message, sender_key) == plaintext # Try to decrypt the same message twice with pytest.raises(RuntimeError): - self.device._olm_decrypt(message, sender_key) + device._olm_decrypt(message, sender_key) # Answer Alice in order to have a type 1 message - message = self.device.olm_sessions[sender_key][0].encrypt(json.dumps(plaintext)) + message = device.olm_sessions[sender_key][0].encrypt(json.dumps(plaintext)) session.decrypt(message) message = session.encrypt(json.dumps(plaintext)) - assert self.device._olm_decrypt(message, sender_key) == plaintext + assert device._olm_decrypt(message, sender_key) == plaintext # Try to decrypt the same message type 1 twice with pytest.raises(RuntimeError): - self.device._olm_decrypt(message, sender_key) + device._olm_decrypt(message, sender_key) # Try to decrypt a message from a session that reused a one-time key otk_reused_session = olm.OutboundSession( alice, our_account.identity_keys['curve25519'], otk) message = otk_reused_session.encrypt(json.dumps(plaintext)) with pytest.raises(RuntimeError): - self.device._olm_decrypt(message, sender_key) + device._olm_decrypt(message, sender_key) # Try to decrypt an invalid type 0 message our_account.generate_one_time_keys(1) @@ -341,126 +359,105 @@ def test_olm_decrypt(self): wrong_session = olm.OutboundSession(alice, sender_key, otk) message = wrong_session.encrypt(json.dumps(plaintext)) with pytest.raises(RuntimeError): - self.device._olm_decrypt(message, sender_key) + device._olm_decrypt(message, sender_key) # Try to decrypt a type 1 message for which we have no sessions message = session.encrypt(json.dumps(plaintext)) - self.device.olm_sessions.clear() + device.olm_sessions.clear() with pytest.raises(RuntimeError): - self.device._olm_decrypt(message, sender_key) - - def test_olm_decrypt_event(self): - self.device.device_keys.clear() - self.device.olm_sessions.clear() - alice_device = OlmDevice(self.device.api, self.alice, self.alice_device_id) - alice_device.device_keys[self.user_id][self.device_id] = self.device - self.device.device_keys[self.alice][self.alice_device_id] = alice_device - - # Artificially start an Olm session from Alice - self.device.olm_account.generate_one_time_keys(1) - otk = next(iter(self.device.olm_account.one_time_keys['curve25519'].values())) - self.device.olm_account.mark_keys_as_published() - sender_key = self.device.curve25519 - session = olm.OutboundSession(alice_device.olm_account, sender_key, otk) - alice_device.olm_sessions[sender_key] = [session] + device._olm_decrypt(message, sender_key) - encrypted_event = alice_device.olm_build_encrypted_event( + def test_olm_decrypt_event(self, device, alice_olm_device): + encrypted_event = alice_olm_device.olm_build_encrypted_event( 'example_type', {'content': 'test'}, self.user_id, self.device_id) # Now we can test - self.device.olm_decrypt_event(encrypted_event, self.alice) + device.olm_decrypt_event(encrypted_event, self.alice) # Device verification - alice_device.verified = True - encrypted_event = alice_device.olm_build_encrypted_event( + alice_olm_device.verified = True + encrypted_event = alice_olm_device.olm_build_encrypted_event( 'example_type', {'content': 'test'}, self.user_id, self.device_id) - self.device.olm_decrypt_event(encrypted_event, self.alice) + device.olm_decrypt_event(encrypted_event, self.alice) # The signing_key is wrong - encrypted_event = alice_device.olm_build_encrypted_event( + encrypted_event = alice_olm_device.olm_build_encrypted_event( 'example_type', {'content': 'test'}, self.user_id, self.device_id) - self.device.device_keys[self.alice][self.alice_device_id]._ed25519 = 'wrong' + device.device_keys[self.alice][self.alice_device_id]._ed25519 = 'wrong' with pytest.raises(RuntimeError): - self.device.olm_decrypt_event(encrypted_event, self.alice) + device.olm_decrypt_event(encrypted_event, self.alice) # We do not have the keys - encrypted_event = alice_device.olm_build_encrypted_event( + encrypted_event = alice_olm_device.olm_build_encrypted_event( 'example_type', {'content': 'test'}, self.user_id, self.device_id) - self.device.device_keys[self.alice].clear() - self.device.olm_decrypt_event(encrypted_event, self.alice) - self.device.device_keys[self.alice][self.alice_device_id] = alice_device - alice_device.verified = False + device.device_keys[self.alice].clear() + device.olm_decrypt_event(encrypted_event, self.alice) + device.device_keys[self.alice][self.alice_device_id] = alice_olm_device + alice_olm_device.verified = False # Type 1 Olm payload - alice_device.olm_decrypt_event( - self.device.olm_build_encrypted_event( + alice_olm_device.olm_decrypt_event( + device.olm_build_encrypted_event( 'example_type', {'content': 'test'}, self.alice, self.alice_device_id ), self.user_id) - encrypted_event = alice_device.olm_build_encrypted_event( + encrypted_event = alice_olm_device.olm_build_encrypted_event( 'example_type', {'content': 'test'}, self.user_id, self.device_id) - self.device.olm_decrypt_event(encrypted_event, self.alice) + device.olm_decrypt_event(encrypted_event, self.alice) - encrypted_event = alice_device.olm_build_encrypted_event( + encrypted_event = alice_olm_device.olm_build_encrypted_event( 'example_type', {'content': 'test'}, self.user_id, self.device_id) with pytest.raises(RuntimeError): - self.device.olm_decrypt_event(encrypted_event, 'wrong') + device.olm_decrypt_event(encrypted_event, 'wrong') wrong_event = deepcopy(encrypted_event) wrong_event['algorithm'] = 'wrong' with pytest.raises(RuntimeError): - self.device.olm_decrypt_event(wrong_event, self.alice) + device.olm_decrypt_event(wrong_event, self.alice) wrong_event = deepcopy(encrypted_event) wrong_event['ciphertext'] = {} with pytest.raises(RuntimeError): - self.device.olm_decrypt_event(wrong_event, self.alice) + device.olm_decrypt_event(wrong_event, self.alice) - encrypted_event = alice_device.olm_build_encrypted_event( + encrypted_event = alice_olm_device.olm_build_encrypted_event( 'example_type', {'content': 'test'}, self.user_id, self.device_id) - self.device.user_id = 'wrong' + device.user_id = 'wrong' with pytest.raises(RuntimeError): - self.device.olm_decrypt_event(encrypted_event, self.alice) - self.device.user_id = self.user_id + device.olm_decrypt_event(encrypted_event, self.alice) + device.user_id = self.user_id - encrypted_event = alice_device.olm_build_encrypted_event( + encrypted_event = alice_olm_device.olm_build_encrypted_event( 'example_type', {'content': 'test'}, self.user_id, self.device_id) - backup = self.device.ed25519 - self.device._ed25519 = 'wrong' + device._ed25519 = 'wrong' with pytest.raises(RuntimeError): - self.device.olm_decrypt_event(encrypted_event, self.alice) - self.device._ed25519 = backup + device.olm_decrypt_event(encrypted_event, self.alice) @responses.activate - def test_olm_ensure_sessions(self): - claim_url = HOSTNAME + MATRIX_V2_API_PATH + '/keys/claim' - responses.add(responses.POST, claim_url, json=example_claim_keys_response) - self.device.olm_sessions.clear() - self.device.device_keys[self.alice][self.alice_device_id] = self.alice_device + def test_olm_ensure_sessions(self, device): + responses.add(responses.POST, self.claim_url, json=example_claim_keys_response) + device.device_keys[self.alice][self.alice_device_id] = self.alice_device user_devices = {self.alice: [self.alice_device_id]} - self.device.olm_ensure_sessions(user_devices) - assert self.device.olm_sessions[self.alice_curve_key] + device.olm_ensure_sessions(user_devices) + assert device.olm_sessions[self.alice_curve_key] assert len(responses.calls) == 1 - self.device.olm_ensure_sessions(user_devices) + device.olm_ensure_sessions(user_devices) assert len(responses.calls) == 1 @responses.activate - def test_megolm_share_session(self): - claim_url = HOSTNAME + MATRIX_V2_API_PATH + '/keys/claim' - responses.add(responses.POST, claim_url, json=example_claim_keys_response) - to_device_url = HOSTNAME + MATRIX_V2_API_PATH + '/sendToDevice/m.room.encrypted/1' - responses.add(responses.PUT, to_device_url, json={}) - self.device.olm_sessions.clear() - self.device.device_keys[self.alice][self.alice_device_id] = self.alice_device - self.device.device_keys['dummy']['dummy'] = \ + def test_megolm_share_session(self, device): + responses.add(responses.POST, self.claim_url, json=example_claim_keys_response) + responses.add(responses.PUT, self.to_device_url, json={}) + device.device_keys[self.alice][self.alice_device_id] = self.alice_device + device.device_keys['dummy']['dummy'] = \ Device(self.cli.api, 'dummy', 'dummy', curve25519_key='a', ed25519_key='a') user_devices = {self.alice: [self.alice_device_id], 'dummy': ['dummy']} session = MegolmOutboundSession() # Sharing with Alice should succeed, but dummy will fail - self.device.megolm_share_session(self.room_id, user_devices, session) + device.megolm_share_session(self.room_id, user_devices, session) assert session.devices == {self.alice_device_id, 'dummy'} req = json.loads(responses.calls[1].request.body)['messages'] @@ -468,16 +465,14 @@ def test_megolm_share_session(self): assert 'dummy' not in req @responses.activate - def test_megolm_start_session(self): - to_device_url = HOSTNAME + MATRIX_V2_API_PATH + '/sendToDevice/m.room.encrypted/1' - responses.add(responses.PUT, to_device_url, json={}) - self.device.device_keys[self.alice][self.alice_device_id] = self.alice_device - self.device.device_list.tracked_user_ids.add(self.alice) - self.device.olm_sessions[self.alice_curve_key] = [self.alice_olm_session] + @pytest.mark.usefixtures('olm_session_with_alice') + def test_megolm_start_session(self, device): + responses.add(responses.PUT, self.to_device_url, json={}) + device.device_list.tracked_user_ids.add(self.alice) user_devices = {self.alice: [self.alice_device_id]} - self.device.megolm_start_session(self.room, user_devices) - session = self.device.megolm_outbound_sessions[self.room_id] + device.megolm_start_session(self.room, user_devices) + session = device.megolm_outbound_sessions[self.room_id] assert self.alice_device_id in session.devices # Check that we can decrypt our own messages @@ -485,7 +480,7 @@ def test_megolm_start_session(self): 'type': 'test', 'content': {'test': 'test'}, } - encrypted_event = self.device.megolm_build_encrypted_event(self.room, plaintext) + encrypted_event = device.megolm_build_encrypted_event(self.room, plaintext) event = { 'sender': self.alice, 'room_id': self.room_id, @@ -494,147 +489,122 @@ def test_megolm_start_session(self): 'origin_server_ts': 1, 'event_id': 1 } - self.device.megolm_decrypt_event(event) + device.megolm_decrypt_event(event) assert event['content'] == plaintext['content'] @responses.activate - def test_megolm_share_session_with_new_devices(self): - to_device_url = HOSTNAME + MATRIX_V2_API_PATH + '/sendToDevice/m.room.encrypted/1' - responses.add(responses.PUT, to_device_url, json={}) - self.device.device_keys[self.alice][self.alice_device_id] = self.alice_device - self.device.olm_sessions[self.alice_curve_key] = [self.alice_olm_session] + @pytest.mark.usefixtures('olm_session_with_alice') + def test_megolm_share_session_with_new_devices(self, device): + responses.add(responses.PUT, self.to_device_url, json={}) session = MegolmOutboundSession() - self.device.megolm_outbound_sessions[self.room_id] = session + device.megolm_outbound_sessions[self.room_id] = session user_devices = {self.alice: [self.alice_device_id]} - self.device.megolm_share_session_with_new_devices( - self.room, user_devices, session) + device.megolm_share_session_with_new_devices(self.room, user_devices, session) assert self.alice_device_id in session.devices assert len(responses.calls) == 1 - self.device.megolm_share_session_with_new_devices( - self.room, user_devices, session) + device.megolm_share_session_with_new_devices(self.room, user_devices, session) assert len(responses.calls) == 1 - def test_megolm_get_recipients(self): - self.device.device_keys[self.alice][self.alice_device_id] = self.alice_device + def test_megolm_get_recipients(self, device): + device.device_keys[self.alice][self.alice_device_id] = self.alice_device - user_devices, _ = self.device.megolm_get_recipients(self.room) + user_devices, _ = device.megolm_get_recipients(self.room) assert user_devices == {self.alice: [self.alice_device_id]} - self.device.megolm_outbound_sessions.clear() session = MegolmOutboundSession() - self.device.megolm_outbound_sessions[self.room_id] = session + device.megolm_outbound_sessions[self.room_id] = session - user_devices, removed = self.device.megolm_get_recipients(self.room, session) + user_devices, removed = device.megolm_get_recipients(self.room, session) assert user_devices == {self.alice: [self.alice_device_id]} and not removed self.alice_device.blacklisted = True - _, removed = self.device.megolm_get_recipients(self.room, session) + _, removed = device.megolm_get_recipients(self.room, session) assert not removed session.add_device(self.alice_device_id) - _, removed = self.device.megolm_get_recipients(self.room, session) - assert removed and self.room_id not in self.device.megolm_outbound_sessions + _, removed = device.megolm_get_recipients(self.room, session) + assert removed and self.room_id not in device.megolm_outbound_sessions self.alice_device.blacklisted = False self.room.verify_devices = True with pytest.raises(E2EUnknownDevices) as e: - self.device.megolm_get_recipients(self.room) + device.megolm_get_recipients(self.room) assert e.value.user_devices == {self.alice: [self.alice_device]} self.room.verify_devices = False @responses.activate - def test_megolm_build_encrypted_event(self): - to_device_url = HOSTNAME + MATRIX_V2_API_PATH + '/sendToDevice/m.room.encrypted/1' - responses.add(responses.PUT, to_device_url, json={}) - self.device.megolm_outbound_sessions.clear() - self.device.device_keys[self.alice][self.alice_device_id] = self.alice_device - self.device.device_list.tracked_user_ids.add(self.alice) - self.device.olm_sessions[self.alice_curve_key] = [self.alice_olm_session] + @pytest.mark.usefixtures('olm_session_with_alice') + def test_megolm_build_encrypted_event(self, device): + responses.add(responses.PUT, self.to_device_url, json={}) + device.device_list.tracked_user_ids.add(self.alice) event = {'type': 'm.room.message', 'content': {'body': 'test'}} self.room.rotation_period_msgs = 1 - self.device.megolm_build_encrypted_event(self.room, event) + device.megolm_build_encrypted_event(self.room, event) - self.device.megolm_build_encrypted_event(self.room, event) + device.megolm_build_encrypted_event(self.room, event) - session = self.device.megolm_outbound_sessions[self.room_id] + session = device.megolm_outbound_sessions[self.room_id] session.encrypt('test') - self.device.megolm_build_encrypted_event(self.room, event) - assert self.device.megolm_outbound_sessions[self.room_id].id != session.id + device.megolm_build_encrypted_event(self.room, event) + assert device.megolm_outbound_sessions[self.room_id].id != session.id - def test_megolm_remove_outbound_session(self): + def test_megolm_remove_outbound_session(self, device): session = MegolmOutboundSession() - self.device.megolm_outbound_sessions[self.room_id] = session - self.device.megolm_remove_outbound_session(self.room_id) - self.device.megolm_remove_outbound_session(self.room_id) + device.megolm_outbound_sessions[self.room_id] = session + device.megolm_remove_outbound_session(self.room_id) + device.megolm_remove_outbound_session(self.room_id) @responses.activate - def test_send_encrypted_message(self): + @pytest.mark.usefixtures('olm_session_with_alice') + def test_send_encrypted_message(self, device): message_url = HOSTNAME + MATRIX_V2_API_PATH + \ '/rooms/{}/send/m.room.encrypted/1'.format(quote(self.room.room_id)) responses.add(responses.PUT, message_url, json={}) - self.device.device_keys[self.alice][self.alice_device_id] = self.alice_device - self.device.olm_sessions[self.alice_curve_key] = [self.alice_olm_session] session = MegolmOutboundSession() session.add_device(self.alice_device_id) - self.device.megolm_outbound_sessions[self.room_id] = session + device.megolm_outbound_sessions[self.room_id] = session - self.device.send_encrypted_message(self.room, {'test': 'test'}) + device.send_encrypted_message(self.room, {'test': 'test'}) - def test_megolm_add_inbound_session(self): + def test_megolm_add_inbound_session(self, device): session = MegolmOutboundSession() - self.device.megolm_inbound_sessions.clear() - assert not self.device.megolm_add_inbound_session( + assert not device.megolm_add_inbound_session( self.room_id, self.alice_curve_key, self.alice_ed_key, session.id, 'wrong') - assert self.device.megolm_add_inbound_session( + assert device.megolm_add_inbound_session( self.room_id, self.alice_curve_key, self.alice_ed_key, session.id, session.session_key ) assert session.id in \ - self.device.megolm_inbound_sessions[self.room_id][self.alice_curve_key] - assert not self.device.megolm_add_inbound_session( + device.megolm_inbound_sessions[self.room_id][self.alice_curve_key] + assert not device.megolm_add_inbound_session( self.room_id, self.alice_curve_key, self.alice_ed_key, session.id, session.session_key ) - assert not self.device.megolm_add_inbound_session( + assert not device.megolm_add_inbound_session( self.room_id, self.alice_curve_key, self.alice_ed_key, 'wrong', session.session_key ) - def test_handle_room_key_event(self): - self.device.megolm_inbound_sessions.clear() - - self.device.handle_room_key_event(example_room_key_event, self.alice_curve_key) - assert self.room_id in self.device.megolm_inbound_sessions + def test_handle_room_key_event(self, device): + device.handle_room_key_event(example_room_key_event, self.alice_curve_key) + assert self.room_id in device.megolm_inbound_sessions - self.device.handle_room_key_event(example_room_key_event, self.alice_curve_key) + device.handle_room_key_event(example_room_key_event, self.alice_curve_key) event = deepcopy(example_room_key_event) event['content']['algorithm'] = 'wrong' - self.device.handle_room_key_event(event, self.alice_curve_key) + device.handle_room_key_event(event, self.alice_curve_key) event = deepcopy(example_room_key_event) event['content']['session_id'] = 'wrong' - self.device.handle_room_key_event(event, self.alice_curve_key) - - def test_olm_handle_encrypted_event(self): - self.device.olm_sessions.clear() - alice_device = OlmDevice(self.device.api, self.alice, self.alice_device_id) - alice_device.device_keys[self.user_id][self.device_id] = self.device - self.device.device_keys[self.alice][self.alice_device_id] = alice_device - - # Artificially start an Olm session from Alice - self.device.olm_account.generate_one_time_keys(1) - otk = next(iter(self.device.olm_account.one_time_keys['curve25519'].values())) - self.device.olm_account.mark_keys_as_published() - sender_key = self.device.curve25519 - session = olm.OutboundSession(alice_device.olm_account, sender_key, otk) - alice_device.olm_sessions[sender_key] = [session] + device.handle_room_key_event(event, self.alice_curve_key) + def test_olm_handle_encrypted_event(self, device, alice_olm_device): content = example_room_key_event['content'] - encrypted_event = alice_device.olm_build_encrypted_event( + encrypted_event = alice_olm_device.olm_build_encrypted_event( 'm.room_key', content, self.user_id, self.device_id) event = { 'type': 'm.room.encrypted', @@ -642,25 +612,25 @@ def test_olm_handle_encrypted_event(self): 'sender': self.alice } - self.device.olm_handle_encrypted_event(event) + device.olm_handle_encrypted_event(event) # Decrypting the same event twice will trigger an error - self.device.olm_handle_encrypted_event(event) + device.olm_handle_encrypted_event(event) - encrypted_event = alice_device.olm_build_encrypted_event( + encrypted_event = alice_olm_device.olm_build_encrypted_event( 'm.other', content, self.user_id, self.device_id) event = { 'type': 'm.room.encrypted', 'content': encrypted_event, 'sender': self.alice } - self.device.olm_handle_encrypted_event(event) + device.olm_handle_encrypted_event(event) # Simulate redacted event event['content'].pop('algorithm') - self.device.olm_handle_encrypted_event(event) + device.olm_handle_encrypted_event(event) - def test_megolm_decrypt_event(self): + def test_megolm_decrypt_event(self, device): out_session = MegolmOutboundSession() plaintext = { @@ -687,64 +657,64 @@ def test_megolm_decrypt_event(self): } with pytest.raises(RuntimeError): - self.device.megolm_decrypt_event(event) + device.megolm_decrypt_event(event) session_key = out_session.session_key in_session = MegolmInboundSession(session_key, self.alice_ed_key) - sessions = self.device.megolm_inbound_sessions[self.room_id] + sessions = device.megolm_inbound_sessions[self.room_id] sessions[self.alice_curve_key][in_session.id] = in_session # Unknown message index with pytest.raises(RuntimeError): - self.device.megolm_decrypt_event(event) + device.megolm_decrypt_event(event) ciphertext = out_session.encrypt(json.dumps(plaintext)) event['content']['ciphertext'] = ciphertext - self.device.megolm_decrypt_event(event) + device.megolm_decrypt_event(event) assert event['content'] == plaintext['content'] # No replay attack event['content'] = content - self.device.megolm_decrypt_event(event) + device.megolm_decrypt_event(event) assert event['content'] == plaintext['content'] # Replay attack event['content'] = content event['event_id'] = 2 with pytest.raises(RuntimeError): - self.device.megolm_decrypt_event(event) + device.megolm_decrypt_event(event) event['event_id'] = 1 # Device verification - self.device.device_keys[self.alice][self.alice_device_id] = self.alice_device + device.device_keys[self.alice][self.alice_device_id] = self.alice_device event['content'] = content # Unverified - self.device.megolm_decrypt_event(event) + device.megolm_decrypt_event(event) assert event['content'] == plaintext['content'] assert isinstance(event, dict) event['content'] = content # Verified self.alice_device.verified = True - decrypted_event = self.device.megolm_decrypt_event(event) + decrypted_event = device.megolm_decrypt_event(event) assert decrypted_event['content'] == plaintext['content'] assert isinstance(decrypted_event, VerifiedEvent) in_session = MegolmInboundSession(session_key, self.alice_curve_key) - sessions = self.device.megolm_inbound_sessions[self.room_id] + sessions = device.megolm_inbound_sessions[self.room_id] sessions[self.alice_curve_key][in_session.id] = in_session # Wrong signing key with pytest.raises(RuntimeError): - self.device.megolm_decrypt_event(event) + device.megolm_decrypt_event(event) self.alice_device.verified = False event['content']['algorithm'] = 'wrong' with pytest.raises(RuntimeError): - self.device.megolm_decrypt_event(event) + device.megolm_decrypt_event(event) event['content'].pop('algorithm') event['type'] = 'encrypted' - self.device.megolm_decrypt_event(event) + device.megolm_decrypt_event(event) assert event['type'] == 'encrypted' From 5606b924c202a8a01d9a6b7f4f2332b4a31d4ce8 Mon Sep 17 00:00:00 2001 From: Valentin Deniaud Date: Thu, 2 Aug 2018 17:50:57 +0200 Subject: [PATCH 56/66] fail to enable encryption on limited cache_level Encryption shouldn't be supported on limited cache_level. Even if it may work a bit, it causes a lot of data to be cached. --- matrix_client/client.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/matrix_client/client.py b/matrix_client/client.py index 4edf4ca4..6b6abcbe 100644 --- a/matrix_client/client.py +++ b/matrix_client/client.py @@ -133,6 +133,9 @@ def __init__(self, base_url, token=None, user_id=None, if restore_device_id and not encryption: raise ValueError("restore_device_id only makes sense when encryption is " "enabled.") + if encryption and cache_level != CACHE.ALL: + raise ValueError("Encryption is unvailable on cache_level other than " + "CACHE.ALL.") self.api = MatrixHttpApi(base_url, token) self.api.validate_certificate(valid_cert_check) From 0564d21081303d7326da7293f4a10c560fa8a438 Mon Sep 17 00:00:00 2001 From: Valentin Deniaud Date: Wed, 25 Jul 2018 16:43:22 +0200 Subject: [PATCH 57/66] add key sharing functionality --- docs/source/matrix_client.rst | 5 + matrix_client/client.py | 54 +++++++- matrix_client/crypto/key_sharing.py | 195 ++++++++++++++++++++++++++++ matrix_client/crypto/olm_device.py | 60 ++++++--- matrix_client/crypto/sessions.py | 11 +- matrix_client/errors.py | 6 + matrix_client/room.py | 5 +- test/client_test.py | 8 +- test/crypto/olm_device_test.py | 20 +-- 9 files changed, 328 insertions(+), 36 deletions(-) create mode 100644 matrix_client/crypto/key_sharing.py diff --git a/docs/source/matrix_client.rst b/docs/source/matrix_client.rst index 16a6aedf..44909b15 100644 --- a/docs/source/matrix_client.rst +++ b/docs/source/matrix_client.rst @@ -81,3 +81,8 @@ matrix_client.crypto :members: :undoc-members: :show-inheritance: + +.. automodule:: matrix_client.crypto.key_sharing + :members: + :undoc-members: + :show-inheritance: diff --git a/matrix_client/client.py b/matrix_client/client.py index 6b6abcbe..03d5faba 100644 --- a/matrix_client/client.py +++ b/matrix_client/client.py @@ -491,6 +491,51 @@ def add_leave_listener(self, callback): """ self.left_listeners.append(callback) + def add_key_request_listener(self, callback): + """Add a listener that will send a callback when a device requests keys. + + NOTE: + This can only be used after logging in. + + NOTE: + Only one listener can exist, and calling this method a second time will + discard the first one. + + Args: + callback (func(dict, func(list))): Callback called when key requests arrive. + It is given a map from device ID to :class:`.Device` object, which + corresponds to the devices requesting keys. This map should be used to + verify devices if relevant. The callback then needs to call the function + it was given as second argument with a list of the device IDs whose key + requests should be answered. Key requests from other devices will be + discarded. + """ + self.olm_device.key_sharing_manager.key_request_callback = callback + + def add_key_forward_listener(self, callback): + """Add a listener that will send a callback when we receive a key. + + When a listener exists, keys are requested automatically each time we are unable + to decrypt a Megolm event due to missing keys. + A client could maintain a map from the ``session_id`` property of a + ``m.room.encrypted`` event to a list of corresponding events, and use this + method to be notified when it can try to decrypt them again. + + NOTE: + This can only be used after logging in. Since keys are not requested when a + listener doesn't exist, a client wanting to requests keys on start-up should + login with ``sync=False``, then add a listener, and then sync. + + NOTE: + Only one listener can exist, and calling this method a second time will + discard the first one. + + Args: + callback (func(string)): Callback called when a forwarded key arrive. + It is given a Megolm session ID. + """ + self.olm_device.key_sharing_manager.key_forward_callback = callback + def listen_for_events(self, timeout_ms=30000): """ This function just calls _sync() @@ -648,8 +693,13 @@ def _sync(self, timeout_ms=30000): if 'to_device' in response: for event in response['to_device']['events']: - if event['type'] == 'm.room.encrypted' and self._encryption: - self.olm_device.olm_handle_encrypted_event(event) + if self._encryption: + if event['type'] == 'm.room.encrypted': + self.olm_device.olm_handle_encrypted_event(event) + elif event['type'] == 'm.room_key_request': + self.olm_device.key_sharing_manager.handle_key_request(event) + if self._encryption: + self.olm_device.key_sharing_manager.trigger_key_requests_callback() if self._encryption and 'device_one_time_keys_count' in response: self.olm_device.update_one_time_key_counts( diff --git a/matrix_client/crypto/key_sharing.py b/matrix_client/crypto/key_sharing.py new file mode 100644 index 00000000..252ebbe7 --- /dev/null +++ b/matrix_client/crypto/key_sharing.py @@ -0,0 +1,195 @@ +import logging +from collections import defaultdict + +logger = logging.getLogger(__name__) + + +class KeySharingManager(object): + + def __init__(self, api, user_id, device_id, olm_device): + self.api = api + self.user_id = user_id + self.device_id = device_id + self.olm_device = olm_device + self.queued_key_requests = defaultdict(dict) + self.outgoing_key_requests = set() + self.key_request_callback = None + self.key_forward_callback = None + + def handle_forwarded_room_key_event(self, event, sender, sender_key): + """Handle a ``m.forwarded_room_key`` event. + + The key it contains will be used only if it was requested previously, and comes + from a device owned by the current user. A cancelation will be sent. Otherwise, it + will be discarded, and no cancelation will be sent. + + Args: + event (dict): A ``m.forwarded_room_key`` event. + sender_key (str): The Curve25519 key of the event's sender. + """ + if sender != self.user_id: + logger.info('Ignoring m.forwarded_room_key event sent by %s.', sender) + return + content = event['content'] + if content['algorithm'] != self.olm_device._megolm_algorithm: + logger.info('Ignoring unsupported algorithm %s in m.forwarded_room_key ' + 'event from device %s.', content['algorithm'], sender_key) + return + + session_id = content['session_id'] + if session_id not in self.outgoing_key_requests: + logger.info('Ignoring session key we have not requested from device %s.', + sender_key) + return + + room_id = content['room_id'] + session_sender_key = content['sender_key'] + signing_key = content['sender_claimed_ed25519_key'] + chain = content['forwarding_curve25519_key_chain'] + chain.append(session_sender_key) + try: + self.olm_device.megolm_add_inbound_session( + room_id, session_sender_key, signing_key, session_id, + content['session_key'], forwarding_chain=chain, export_format=True + ) + except ValueError as e: + logger.warning('Error in forwarded room key payload for session %s: %s', + session_id, e) + return + payload = { + 'action': 'cancel_request', + 'request_id': session_id, + 'requesting_device_id': self.device_id + } + self.api.send_to_device('m.room_key_request', {self.user_id: {'*': payload}}) + self.outgoing_key_requests.discard(session_id) + if self.key_forward_callback: + self.key_forward_callback(session_id) + + def handle_key_request(self, event): + """Handle a ``m.room_key_request`` event. + + Args: + event (dict): m.room_key_request event. + """ + if event['sender'] != self.user_id: + logger.info("Ignoring m.room_key_request event from %s.", event['sender']) + return + + content = event['content'] + device_id = content['requesting_device_id'] + if device_id == self.device_id: + return + try: + self.olm_device.device_keys[self.user_id][device_id] + except KeyError: + logger.info("Ignoring m.room_key_request event from device %s, which " + "we don't own.", device_id) + return + + # Build a queue of key requests as we don't want to tell client of each requests, + # knowing that the canceling event might be coming right up next. + request_id = content['request_id'] + if content['action'] == 'request': + body = content['body'] + if body['algorithm'] != self.olm_device._megolm_algorithm: + return + if request_id not in self.queued_key_requests[device_id]: + self.queued_key_requests[device_id][request_id] = body + elif content['action'] == 'cancel_request': + # This doesn't remove request_id from the dict, so we will never + # add an event with this request ID again. + self.queued_key_requests[device_id][request_id].clear() + + def trigger_key_requests_callback(self): + if not self.key_request_callback: + return + devices = {} + for device_id in self.queued_key_requests: + device = self.olm_device.device_keys[self.user_id][device_id] + devices[device_id] = device + if devices: + self.key_request_callback(devices, self.process_key_requests) + + def process_key_requests(self, device_ids): + """Share the key requested by the given device_ids. + + This empties the key request queue we keep upon completion, meaning that any + request from a device not present in ``device_ids`` will be discarded. + + Args: + device_ids (iterable): The device IDs who should see their request answered, + if possible. + """ + logger.info('Sharing requested sessions with devices %s.', device_ids) + + # TODO: improve this as in the case of a new device which request keys + # on start-up, we may not have the time to fetch its keys. + self.olm_device.olm_ensure_sessions({self.user_id: device_ids}) + for device_id in device_ids: + if not self.queued_key_requests[device_id]: + continue + for event in self.queued_key_requests[device_id].values(): + session_id = event['session_id'] + room_id = event['room_id'] + sender_key = event['sender_key'] + sessions = self.olm_device.megolm_inbound_sessions[room_id][sender_key] + try: + session = sessions[session_id] + except KeyError: + session = self.olm_device.db.get_inbound_session(room_id, sender_key, + session_id) + if not session: + continue + payload = self.build_forwarded_room_key_event(room_id, sender_key, + session) + event = self.olm_device.olm_build_encrypted_event( + 'm.forwarded_room_key', payload, self.user_id, device_id) + self.api.send_to_device( + 'm.room.encrypted', {self.user_id: {device_id: event}}) + self.queued_key_requests.clear() + + def build_forwarded_room_key_event(self, room_id, sender_key, session): + payload = { + 'algorithm': self.olm_device._megolm_algorithm, + 'room_id': room_id, + 'sender_key': sender_key, + 'sender_claimed_ed25519_key': session.ed25519, + 'session_id': session.id, + 'session_key': session.export_session(session.first_known_index), + 'forwarding_curve25519_key_chain': session.forwarding_chain, + } + return payload + + def request_missing_key(self, encrypted_event, force=False): + """Request the key used to encrypt the event from our devices. + + Args: + encrypted_event (dict): A ``m.room.encrypted`` Megolm event. + force (bool): Optional. If ``True``, send a request even if one has already + been sent. + """ + # If no callback is registered in ordered to handle forwarded keys, it is + # useless to request them. + if not self.key_forward_callback: + return + content = encrypted_event['content'] + session_id = content['session_id'] + if session_id in self.outgoing_key_requests and not force: + logger.info('Already have an outgoing key request for session %s.', + session_id) + return + logger.info('Requesting keys for session %s.', session_id) + payload = { + 'action': 'request', + 'body': { + 'algorithm': content['algorithm'], + 'session_id': session_id, + 'room_id': encrypted_event['room_id'], + 'sender_key': content['sender_key'] + }, + 'request_id': session_id, + 'requesting_device_id': self.device_id + } + self.api.send_to_device('m.room_key_request', {self.user_id: {'*': payload}}) + self.outgoing_key_requests.add(session_id) diff --git a/matrix_client/crypto/olm_device.py b/matrix_client/crypto/olm_device.py index 9bf9863d..c7940b2c 100644 --- a/matrix_client/crypto/olm_device.py +++ b/matrix_client/crypto/olm_device.py @@ -7,12 +7,13 @@ from matrix_client.checks import check_user_id from matrix_client.device import Device -from matrix_client.errors import E2EUnknownDevices +from matrix_client.errors import E2EUnknownDevices, UnableToDecryptError from matrix_client.crypto.one_time_keys import OneTimeKeysManager from matrix_client.crypto.device_list import DeviceList from matrix_client.crypto.sessions import MegolmOutboundSession, MegolmInboundSession from matrix_client.crypto.crypto_store import CryptoStore from matrix_client.crypto.verified_event import VerifiedEvent +from matrix_client.crypto.key_sharing import KeySharingManager logger = logging.getLogger(__name__) @@ -101,6 +102,7 @@ def __init__(self, keys_threshold) self.device_list = DeviceList(self, api, self.device_keys, self.db) self.megolm_index_record = defaultdict(dict) + self.key_sharing_manager = KeySharingManager(api, user_id, device_id, self) keys = self.olm_account.identity_keys super(OlmDevice, self).__init__(self.api, self.user_id, device_id, database=self.db, ed25519_key=keys['ed25519'], @@ -680,15 +682,20 @@ def olm_handle_encrypted_event(self, encrypted_event): if 'algorithm' not in content or content['algorithm'] != self._olm_algorithm: return + sender = encrypted_event['sender'] try: - event = self.olm_decrypt_event(content, encrypted_event['sender']) + event = self.olm_decrypt_event(content, sender) except RuntimeError as e: - logger.warning('Failed to decrypt m.room_key event sent by user %s: %s', + logger.warning('Failed to decrypt toDevice Olm event sent by user %s: %s', encrypted_event['sender'], e) return + sender_key = encrypted_event['content']['sender_key'] if event['type'] == 'm.room_key': - self.handle_room_key_event(event, encrypted_event['content']['sender_key']) + self.handle_room_key_event(event, sender_key) + elif event['type'] == 'm.forwarded_room_key': + self.key_sharing_manager.handle_forwarded_room_key_event(event, sender, + sender_key) def handle_room_key_event(self, event, sender_key): """Handle a m.room_key event. @@ -705,9 +712,13 @@ def handle_room_key_event(self, event, sender_key): user_id = event['sender'] device_id = event['sender_device'] - new = self.megolm_add_inbound_session(content['room_id'], sender_key, - signing_key, content['session_id'], - content['session_key']) + try: + new = self.megolm_add_inbound_session(content['room_id'], sender_key, + signing_key, content['session_id'], + content['session_key']) + except ValueError as e: + logger.warning(e) + return if new: logger.info('Created a new Megolm inbound session with device %s of ' 'user %s.', device_id, user_id) @@ -716,7 +727,8 @@ def handle_room_key_event(self, event, sender_key): 'already exists or is invalid.', device_id, user_id) def megolm_add_inbound_session(self, room_id, sender_key, signing_key, session_id, - session_key): + session_key, forwarding_chain=None, + export_format=False): """Create a new Megolm inbound session if necessary. Args: @@ -727,8 +739,10 @@ def megolm_add_inbound_session(self, room_id, sender_key, signing_key, session_i signing_key (str): The ed25519 key of the event which established the session. Returns: - ``True`` if a new session was created, ``False`` if it already existed or if - the parameters were invalid. + ``True`` if a new session was created, ``False`` if it already existed. + + Raises: + ValueError if one of the parameters were invalid. """ sessions = self.megolm_inbound_sessions[room_id][sender_key] if session_id in sessions: @@ -737,13 +751,16 @@ def megolm_add_inbound_session(self, room_id, sender_key, signing_key, session_i if self.db.get_inbound_session(room_id, sender_key, session_id, sessions): return False try: - session = MegolmInboundSession(session_key, signing_key) - except olm.OlmGroupSessionError: - return False + if export_format: + session = MegolmInboundSession.import_session(session_key, signing_key, + forwarding_chain) + else: + session = MegolmInboundSession(session_key, signing_key) + except olm.OlmGroupSessionError as e: + raise ValueError('olmlib error when trying to add the session: {}.'.format(e)) if session.id != session_id: - logger.warning('Session ID mismatch in m.room_key event. Expected %s from ' - 'event property, got %s.', session_id, session.id) - return False + raise ValueError('Session ID mismatch in m.room_key event. Expected {} from ' + 'event property, got {}.'.format(session_id, session.id)) self.db.save_inbound_session(room_id, sender_key, session) sessions[session_id] = session return True @@ -779,9 +796,11 @@ def megolm_decrypt_event(self, event): session = self.db.get_inbound_session( room_id, sender_key, session_id, sessions) if not session: - raise RuntimeError("Unable to decrypt event sent by device {} of user " - "{}: The sender's device has not sent us the keys for " - "this message.".format(device_id, user_id)) + raise UnableToDecryptError( + "Unable to decrypt event sent by device {} of user {}: The sender's " + "device has not sent us the keys for this message." + .format(device_id, user_id) + ) try: decrypted_event, message_index = session.decrypt(content['ciphertext']) @@ -795,7 +814,8 @@ def megolm_decrypt_event(self, event): except KeyError: pass else: - if device.verified: + # Do not mark events decrypted using a forwarded key as verified + if device.verified and not session.forwarding_chain: if device.ed25519 != session.ed25519 or device.curve25519 != sender_key: raise RuntimeError('Device keys mismatch in event sent by device {}.' .format(device.device_id)) diff --git a/matrix_client/crypto/sessions.py b/matrix_client/crypto/sessions.py index fe9d024d..fe90259d 100644 --- a/matrix_client/crypto/sessions.py +++ b/matrix_client/crypto/sessions.py @@ -81,13 +81,22 @@ class MegolmInboundSession(InboundGroupSession): def __init__(self, session_key, signing_key): self.ed25519 = signing_key + self.forwarding_chain = None super(MegolmInboundSession, self).__init__(session_key) def __new__(cls, *args): return super(MegolmInboundSession, cls).__new__(cls) @classmethod - def from_pickle(cls, pickle, signing_key, passphrase=''): + def from_pickle(cls, pickle, signing_key, passphrase='', forwarding_chain=None): session = super(MegolmInboundSession, cls).from_pickle(pickle, passphrase) session.ed25519 = signing_key + session.forwarding_chain = forwarding_chain + return session + + @classmethod + def import_session(cls, session_key, signing_key, forwarding_chain=None): + session = super(MegolmInboundSession, cls).import_session(session_key) + session.ed25519 = signing_key + session.forwarding_chain = forwarding_chain return session diff --git a/matrix_client/errors.py b/matrix_client/errors.py index 91154bb8..e7cd2c68 100644 --- a/matrix_client/errors.py +++ b/matrix_client/errors.py @@ -68,3 +68,9 @@ def __init__(self, user_devices): "The room contains unknown devices which have not been verified. They can " "be inspected via the 'user_devices' attribute of this exception.") self.user_devices = user_devices + + +class UnableToDecryptError(Exception): + """An encrypted message couldn't be decrypted due to missing keys.""" + + pass diff --git a/matrix_client/room.py b/matrix_client/room.py index 5ec57d7a..e9f2b200 100644 --- a/matrix_client/room.py +++ b/matrix_client/room.py @@ -19,7 +19,7 @@ from .checks import check_room_id from .user import User -from .errors import MatrixRequestError, MatrixNoEncryptionError +from .errors import MatrixRequestError, MatrixNoEncryptionError, UnableToDecryptError logger = logging.getLogger(__name__) @@ -348,6 +348,9 @@ def _put_event(self, event): event = self.client.olm_device.megolm_decrypt_event(event) except RuntimeError as e: logger.warning(e) + except UnableToDecryptError as e: + logger.warning(e) + self.client.olm_device.key_sharing_manager.request_missing_key(event) self.events.append(event) if len(self.events) > self.event_history_limit: self.events.pop(0) diff --git a/test/client_test.py b/test/client_test.py index 472ad195..e3adc408 100644 --- a/test/client_test.py +++ b/test/client_test.py @@ -578,15 +578,15 @@ def test_one_time_keys_sync(): sync_response["device_one_time_keys_count"] = payload sync_response['rooms']['join'] = {} - class DummyDevice: + class DummyDevice(OlmDevice): def update_one_time_key_counts(self, payload): - self.payload = payload + self.test_payload = payload - device = DummyDevice() + device = DummyDevice(None, '@test:localhost', 'test') client.olm_device = device responses.add(responses.GET, sync_url, json=sync_response) client._sync() - assert device.payload == payload + assert device.test_payload == payload diff --git a/test/crypto/olm_device_test.py b/test/crypto/olm_device_test.py index d21dd8be..ffe918fc 100644 --- a/test/crypto/olm_device_test.py +++ b/test/crypto/olm_device_test.py @@ -18,7 +18,7 @@ from matrix_client.client import MatrixClient from matrix_client.user import User from matrix_client.device import Device -from matrix_client.errors import E2EUnknownDevices +from matrix_client.errors import E2EUnknownDevices, UnableToDecryptError from test.crypto.dummy_olm_device import OlmDevice, DummyStore from matrix_client.crypto.sessions import MegolmOutboundSession, MegolmInboundSession from test.response_examples import (example_key_upload_response, @@ -571,8 +571,11 @@ def test_send_encrypted_message(self, device): def test_megolm_add_inbound_session(self, device): session = MegolmOutboundSession() - assert not device.megolm_add_inbound_session( - self.room_id, self.alice_curve_key, self.alice_ed_key, session.id, 'wrong') + with pytest.raises(ValueError): + device.megolm_add_inbound_session( + self.room_id, self.alice_curve_key, self.alice_ed_key, session.id, + 'wrong' + ) assert device.megolm_add_inbound_session( self.room_id, self.alice_curve_key, self.alice_ed_key, session.id, session.session_key @@ -583,10 +586,11 @@ def test_megolm_add_inbound_session(self, device): self.room_id, self.alice_curve_key, self.alice_ed_key, session.id, session.session_key ) - assert not device.megolm_add_inbound_session( - self.room_id, self.alice_curve_key, self.alice_ed_key, 'wrong', - session.session_key - ) + with pytest.raises(ValueError): + device.megolm_add_inbound_session( + self.room_id, self.alice_curve_key, self.alice_ed_key, 'wrong', + session.session_key + ) def test_handle_room_key_event(self, device): device.handle_room_key_event(example_room_key_event, self.alice_curve_key) @@ -656,7 +660,7 @@ def test_megolm_decrypt_event(self, device): 'event_id': 1 } - with pytest.raises(RuntimeError): + with pytest.raises(UnableToDecryptError): device.megolm_decrypt_event(event) session_key = out_session.session_key From e3e579f89a8d6a32c8032f292e7d9be58c76aad2 Mon Sep 17 00:00:00 2001 From: Valentin Deniaud Date: Thu, 9 Aug 2018 10:03:44 +0200 Subject: [PATCH 58/66] persist forwarded chain --- matrix_client/crypto/crypto_store.py | 17 +++++++++++++++++ matrix_client/crypto/sessions.py | 6 +++--- test/crypto/crypto_store_test.py | 2 ++ 3 files changed, 22 insertions(+), 3 deletions(-) diff --git a/matrix_client/crypto/crypto_store.py b/matrix_client/crypto/crypto_store.py index bd9c7c85..90bec001 100644 --- a/matrix_client/crypto/crypto_store.py +++ b/matrix_client/crypto/crypto_store.py @@ -74,6 +74,11 @@ def create_tables_if_needed(self): ed_key TEXT, session BLOB, FOREIGN KEY(device_id) REFERENCES accounts(device_id) ON DELETE CASCADE ); +CREATE TABLE IF NOT EXISTS forwarded_chains( + device_id TEXT, session_id TEXT, curve_key TEXT, + PRIMARY KEY(device_id, session_id, curve_key), + FOREIGN KEY(device_id) REFERENCES accounts(device_id) ON DELETE CASCADE +); CREATE TABLE IF NOT EXISTS megolm_outbound_sessions( device_id TEXT, room_id TEXT, session BLOB, max_age_s FLOAT, max_messages INTEGER, creation_time TIMESTAMP, message_count INTEGER, @@ -253,6 +258,9 @@ def save_inbound_session(self, room_id, curve_key, session): c.execute('REPLACE INTO megolm_inbound_sessions VALUES (?,?,?,?,?,?)', (self.device_id, session.id, room_id, curve_key, session.ed25519, session.pickle(self.pickle_key))) + rows = [(self.device_id, session.id, curve_key) + for curve_key in session.forwarding_chain] + c.executemany('INSERT OR IGNORE INTO forwarded_chains VALUES(?,?,?)', rows) c.close() self.conn.commit() @@ -273,6 +281,7 @@ def load_inbound_sessions(self, sessions): session = MegolmInboundSession.from_pickle( bytes(row['session']), row['ed_key'], self.pickle_key) sessions[row['room_id']][row['curve_key']][session.id] = session + self._load_forwarding_chain(session) c.close() def get_inbound_session(self, room_id, curve_key, session_id, sessions=None): @@ -303,10 +312,18 @@ def get_inbound_session(self, room_id, curve_key, session_id, sessions=None): c.close() session = MegolmInboundSession.from_pickle(session_data, row['ed_key'], self.pickle_key) + self._load_forwarding_chain(session) if sessions is not None: sessions[session.id] = session return session + def _load_forwarding_chain(self, session): + c = self.conn.cursor() + c.execute('SELECT curve_key FROM forwarded_chains WHERE device_id=? ' + 'AND session_id=?', (self.device_id, session.id)) + session.forwarding_chain = [row['curve_key'] for row in c] + c.close() + def save_outbound_session(self, room_id, session): """Saves a Megolm outbound session. diff --git a/matrix_client/crypto/sessions.py b/matrix_client/crypto/sessions.py index fe90259d..2e7da2a9 100644 --- a/matrix_client/crypto/sessions.py +++ b/matrix_client/crypto/sessions.py @@ -81,7 +81,7 @@ class MegolmInboundSession(InboundGroupSession): def __init__(self, session_key, signing_key): self.ed25519 = signing_key - self.forwarding_chain = None + self.forwarding_chain = [] super(MegolmInboundSession, self).__init__(session_key) def __new__(cls, *args): @@ -91,12 +91,12 @@ def __new__(cls, *args): def from_pickle(cls, pickle, signing_key, passphrase='', forwarding_chain=None): session = super(MegolmInboundSession, cls).from_pickle(pickle, passphrase) session.ed25519 = signing_key - session.forwarding_chain = forwarding_chain + session.forwarding_chain = forwarding_chain or [] return session @classmethod def import_session(cls, session_key, signing_key, forwarding_chain=None): session = super(MegolmInboundSession, cls).import_session(session_key) session.ed25519 = signing_key - session.forwarding_chain = forwarding_chain + session.forwarding_chain = forwarding_chain or [] return session diff --git a/test/crypto/crypto_store_test.py b/test/crypto/crypto_store_test.py index 42bc859d..d224f4d0 100644 --- a/test/crypto/crypto_store_test.py +++ b/test/crypto/crypto_store_test.py @@ -144,6 +144,7 @@ def test_olm_sessions_persistence(self, account, curve_key, device): def test_megolm_inbound_persistence(self, curve_key, ed_key, device): out_session = olm.OutboundGroupSession() session = MegolmInboundSession(out_session.session_key, ed_key) + session.forwarding_chain.append(curve_key) sessions = defaultdict(lambda: defaultdict(dict)) self.store.load_inbound_sessions(sessions) @@ -157,6 +158,7 @@ def test_megolm_inbound_persistence(self, curve_key, ed_key, device): saved_session = self.store.get_inbound_session(self.room_id, curve_key, session.id) assert saved_session.id == session.id + assert saved_session.forwarding_chain == [curve_key] sessions = {} saved_session = self.store.get_inbound_session(self.room_id, curve_key, From e76c5ec9db48eacf07034052624a1c9dd9413b0b Mon Sep 17 00:00:00 2001 From: Valentin Deniaud Date: Tue, 7 Aug 2018 10:16:00 +0200 Subject: [PATCH 59/66] persist outgoing key requests --- matrix_client/crypto/crypto_store.py | 40 ++++++++++++++++++++++++++++ matrix_client/crypto/key_sharing.py | 6 ++++- matrix_client/crypto/olm_device.py | 3 ++- test/crypto/crypto_store_test.py | 17 ++++++++++++ 4 files changed, 64 insertions(+), 2 deletions(-) diff --git a/matrix_client/crypto/crypto_store.py b/matrix_client/crypto/crypto_store.py index 90bec001..ae9ac5ed 100644 --- a/matrix_client/crypto/crypto_store.py +++ b/matrix_client/crypto/crypto_store.py @@ -105,6 +105,10 @@ def create_tables_if_needed(self): CREATE TABLE IF NOT EXISTS sync_tokens( device_id TEXT PRIMARY KEY, token TEXT, FOREIGN KEY(device_id) REFERENCES accounts(device_id) ON DELETE CASCADE +); +CREATE TABLE IF NOT EXISTS outgoing_key_requests( + device_id TEXT PRIMARY KEY, session_id TEXT, + FOREIGN KEY(device_id) REFERENCES accounts(device_id) ON DELETE CASCADE ); """) c.close() @@ -580,6 +584,42 @@ def get_sync_token(self): finally: c.close() + def add_outgoing_key_request(self, session_id): + """Saves a key request. + + Args: + session_id (str): The requested session. + """ + c = self.conn.cursor() + c.execute('INSERT OR IGNORE INTO outgoing_key_requests VALUES (?,?)', + (self.device_id, session_id)) + c.close() + self.conn.commit() + + def remove_outgoing_key_request(self, session_id): + """Removes a key request. + + Args: + session_id (str): The requested session. + """ + c = self.conn.cursor() + c.execute('DELETE FROM outgoing_key_requests WHERE device_id=? and session_id=?', + (self.device_id, session_id)) + c.close() + + def load_outgoing_key_requests(self, session_ids): + """Load key requests. + + Args: + session_ids (set): Will be populated with session IDs. + """ + c = self.conn.cursor() + c.execute('SELECT session_id FROM outgoing_key_requests WHERE device_id=?', + (self.device_id,)) + for row in c: + session_ids.add(row['session_id']) + c.close() + def close(self): self.conn.close() diff --git a/matrix_client/crypto/key_sharing.py b/matrix_client/crypto/key_sharing.py index 252ebbe7..8f6c26cf 100644 --- a/matrix_client/crypto/key_sharing.py +++ b/matrix_client/crypto/key_sharing.py @@ -6,13 +6,15 @@ class KeySharingManager(object): - def __init__(self, api, user_id, device_id, olm_device): + def __init__(self, api, db, user_id, device_id, olm_device): self.api = api + self.db = db self.user_id = user_id self.device_id = device_id self.olm_device = olm_device self.queued_key_requests = defaultdict(dict) self.outgoing_key_requests = set() + self.db.load_outgoing_key_requests(self.outgoing_key_requests) self.key_request_callback = None self.key_forward_callback = None @@ -63,6 +65,7 @@ def handle_forwarded_room_key_event(self, event, sender, sender_key): } self.api.send_to_device('m.room_key_request', {self.user_id: {'*': payload}}) self.outgoing_key_requests.discard(session_id) + self.db.remove_outgoing_key_request(session_id) if self.key_forward_callback: self.key_forward_callback(session_id) @@ -193,3 +196,4 @@ def request_missing_key(self, encrypted_event, force=False): } self.api.send_to_device('m.room_key_request', {self.user_id: {'*': payload}}) self.outgoing_key_requests.add(session_id) + self.db.add_outgoing_key_request(session_id) diff --git a/matrix_client/crypto/olm_device.py b/matrix_client/crypto/olm_device.py index c7940b2c..2f007912 100644 --- a/matrix_client/crypto/olm_device.py +++ b/matrix_client/crypto/olm_device.py @@ -102,7 +102,8 @@ def __init__(self, keys_threshold) self.device_list = DeviceList(self, api, self.device_keys, self.db) self.megolm_index_record = defaultdict(dict) - self.key_sharing_manager = KeySharingManager(api, user_id, device_id, self) + self.key_sharing_manager = KeySharingManager(api, self.db, user_id, device_id, + self) keys = self.olm_account.identity_keys super(OlmDevice, self).__init__(self.api, self.user_id, device_id, database=self.db, ed25519_key=keys['ed25519'], diff --git a/test/crypto/crypto_store_test.py b/test/crypto/crypto_store_test.py index d224f4d0..1fc2b089 100644 --- a/test/crypto/crypto_store_test.py +++ b/test/crypto/crypto_store_test.py @@ -333,6 +333,23 @@ def test_sync_token_persistence(self): self.store.save_sync_token(sync_token) assert self.store.get_sync_token() == sync_token + @pytest.mark.usefixtures('account') + def test_key_requests(self): + session_id = 'test' + session_ids = set() + + self.store.load_outgoing_key_requests(session_ids) + assert not session_ids + + self.store.add_outgoing_key_request(session_id) + self.store.load_outgoing_key_requests(session_ids) + assert session_id in session_ids + + session_ids.clear() + self.store.remove_outgoing_key_request(session_id) + self.store.load_outgoing_key_requests(session_ids) + assert not session_ids + def test_load_all(self, account, curve_key, ed_key, device): curve_key = account.identity_keys['curve25519'] session = olm.OutboundSession(account, curve_key, curve_key) From ae8d11d24e9064bdf51d915328ce6b2566896c63 Mon Sep 17 00:00:00 2001 From: Valentin Deniaud Date: Thu, 9 Aug 2018 15:41:58 +0200 Subject: [PATCH 60/66] add key sharing tests --- test/crypto/key_sharing_test.py | 243 ++++++++++++++++++++++++++++++++ test/crypto/olm_device_test.py | 11 ++ test/response_examples.py | 42 ++++++ 3 files changed, 296 insertions(+) create mode 100644 test/crypto/key_sharing_test.py diff --git a/test/crypto/key_sharing_test.py b/test/crypto/key_sharing_test.py new file mode 100644 index 00000000..5667a4df --- /dev/null +++ b/test/crypto/key_sharing_test.py @@ -0,0 +1,243 @@ +import pytest +olm = pytest.importorskip("olm") # noqa + +from copy import deepcopy + +import responses + +from matrix_client.api import MATRIX_V2_API_PATH +from matrix_client.client import MatrixClient +from matrix_client.crypto.key_sharing import KeySharingManager +from matrix_client.device import Device +from test.crypto.dummy_olm_device import OlmDevice, DummyStore +from test.response_examples import (example_forwarded_room_key_event, + example_room_key_request_event, + example_room_key_cancel_event) + +HOSTNAME = 'http://example.com' + + +class TestKeySharing: + cli = MatrixClient(HOSTNAME) + user_id = '@user:matrix.org' + room_id = '!test:example.com' + device_id = 'QBUAZIFURK' + other_device_id = 'JLAFKJWSCS' + other_curve_key = 'mmFRSHuJVq3aTudx3KB3w5ZvSFQhgEcy8d+m+vkEfUQ' + other_device = Device(None, user_id, other_device_id, curve25519_key=other_curve_key) + request_url = HOSTNAME + MATRIX_V2_API_PATH + '/sendToDevice/m.room_key_request/1' + forward_url = HOSTNAME + MATRIX_V2_API_PATH + '/sendToDevice/m.room.encrypted/1' + + @pytest.fixture() + def olm_device(self): + device = OlmDevice(self.cli.api, self.user_id, self.device_id) + device.api._make_txn_id = lambda: 1 + return device + + @pytest.fixture() + def manager(self, olm_device): + return KeySharingManager(self.cli.api, DummyStore(), self.user_id, self.device_id, + olm_device) + + @pytest.fixture() + def olm_session_with_other_device(self, olm_device): + session = olm.OutboundSession(olm_device.olm_account, self.other_curve_key, + self.other_curve_key) + olm_device.device_keys[self.user_id][self.other_device_id] = self.other_device + olm_device.olm_sessions[self.other_curve_key] = [session] + + @responses.activate + def test_handle_forwarded_room_key(self, olm_device, manager): + responses.add(responses.PUT, self.request_url, json={}) + content = example_forwarded_room_key_event['content'] + sender_key = 'test' + room_id = content['room_id'] + session_sender_key = content['sender_key'] + session_id = content['session_id'] + + # Not requested + manager.handle_forwarded_room_key_event(example_forwarded_room_key_event, + self.user_id, sender_key) + assert not olm_device.megolm_inbound_sessions + + manager.outgoing_key_requests.add(session_id) + manager.handle_forwarded_room_key_event(example_forwarded_room_key_event, + self.user_id, sender_key) + sessions = olm_device.megolm_inbound_sessions[room_id][session_sender_key] + assert sessions[session_id].id == session_id + assert not manager.outgoing_key_requests + + manager.outgoing_key_requests.add(session_id) + # With callback + + def callback(arg_session_id): + assert arg_session_id == session_id + + manager.key_forward_callback = callback + manager.handle_forwarded_room_key_event(example_forwarded_room_key_event, + self.user_id, sender_key) + assert sessions[session_id].id == session_id + assert not manager.outgoing_key_requests + + manager.outgoing_key_requests.add(session_id) + olm_device.megolm_inbound_sessions.clear() + # Wrong payload + event = deepcopy(example_forwarded_room_key_event) + event['content']['session_key'] = 'wrong' + manager.handle_forwarded_room_key_event(event, self.user_id, sender_key) + sessions = olm_device.megolm_inbound_sessions[room_id][session_sender_key] + assert not sessions + assert manager.outgoing_key_requests + + # Wrong algorithm + event = deepcopy(example_forwarded_room_key_event) + event['content']['algorithm'] = 'wrong' + manager.handle_forwarded_room_key_event(event, self.user_id, sender_key) + assert not sessions + assert manager.outgoing_key_requests + + # Wrong sender + manager.handle_forwarded_room_key_event(example_forwarded_room_key_event, + 'wrong', sender_key) + assert not sessions + assert manager.outgoing_key_requests + + def test_handle_key_request(self, manager, olm_device): + event = deepcopy(example_room_key_request_event) + content = event['content'] + device_id = content['requesting_device_id'] + request_id = content['request_id'] + + # Request from another user + event['sender'] = 'wrong' + manager.handle_key_request(event) + assert not manager.queued_key_requests + + # Useless request from us + event['sender'] = self.user_id + content['requesting_device_id'] = self.device_id + manager.handle_key_request(event) + assert not manager.queued_key_requests + + # Request from unknown device + content['requesting_device_id'] = 'unknown' + manager.handle_key_request(event) + assert not manager.queued_key_requests + + # Valid request + olm_device.device_keys[self.user_id][device_id] = None + content['requesting_device_id'] = device_id + valid_event = deepcopy(event) + manager.handle_key_request(valid_event) + assert request_id in manager.queued_key_requests[device_id] + + # Duplicate request + manager.handle_key_request(event) + + # Cancel request + cancel_event = deepcopy(example_room_key_cancel_event) + cancel_event['sender'] = self.user_id + manager.handle_key_request(cancel_event) + assert not manager.queued_key_requests[device_id][request_id] + + # Request after cancelation + manager.handle_key_request(event) + assert not manager.queued_key_requests[device_id][request_id] + + # Unknown algorithm + content['body']['algorithm'] = 'unknown' + manager.handle_key_request(event) + assert not manager.queued_key_requests[device_id][request_id] + + # Unknown action + content['action'] = 'unknown' + manager.handle_key_request(event) + assert not manager.queued_key_requests[device_id][request_id] + + def test_trigger_key_requests_callback(self, manager, olm_device): + # No callback + manager.trigger_key_requests_callback() + + def callback(devices, method): + assert devices[device_id] == device + assert method == manager.process_key_requests + + manager.key_request_callback = callback + + # No requests + manager.trigger_key_requests_callback() + + # Request + device_id = 'test' + device = Device(None, self.user_id, self.device_id) + olm_device.device_keys[self.user_id][device_id] = device + + manager.queued_key_requests[device_id] = None + manager.trigger_key_requests_callback() + + @responses.activate + @pytest.mark.usefixtures('olm_session_with_other_device') + def test_process_key_requests(self, manager, olm_device): + device_ids = [self.other_device_id] + + # No requests + manager.process_key_requests(device_ids) + + # No session + event = deepcopy(example_room_key_request_event) + content = event['content'] + body = content['body'] + request_id = content['request_id'] + manager.queued_key_requests[self.other_device_id][request_id] = body + manager.process_key_requests(device_ids) + + responses.add(responses.PUT, self.forward_url, json={}) + room_id = body['room_id'] + sender_key = body['sender_key'] + session_id = body['session_id'] + olm_device.megolm_add_inbound_session( + room_id, sender_key, 'ed25519', session_id, + example_forwarded_room_key_event['content']['session_key'], + export_format=True + ) + manager.queued_key_requests[self.other_device_id][request_id] = body + manager.process_key_requests(device_ids) + + manager.queued_key_requests[self.other_device_id][request_id] = body + # Retrieved from db + session = olm_device.megolm_inbound_sessions[room_id][sender_key][session_id] + olm_device.megolm_inbound_sessions.clear() + + class DB(DummyStore): + + def __getattribute__(self, name): + if name == 'get_inbound_session': + return lambda *x: session + return super(DB, self).__getattribute__(name) + + olm_device.db = DB() + manager.process_key_requests(device_ids) + + @responses.activate + def test_request_missing_key(self, manager): + responses.add(responses.PUT, self.request_url, json={}) + encrypted_event = { + 'room_id': 'test', + 'content': { + 'session_id': 'test', + 'algorithm': 'test', + 'sender_key': 'test' + } + } + # No callback + manager.request_missing_key(encrypted_event) + assert not responses.calls + + manager.key_forward_callback = lambda: None + # Good + manager.request_missing_key(encrypted_event) + assert len(responses.calls) == 1 + + # Already requested + manager.request_missing_key(encrypted_event) + assert len(responses.calls) == 1 diff --git a/test/crypto/olm_device_test.py b/test/crypto/olm_device_test.py index ffe918fc..e91a27b4 100644 --- a/test/crypto/olm_device_test.py +++ b/test/crypto/olm_device_test.py @@ -621,6 +621,17 @@ def test_olm_handle_encrypted_event(self, device, alice_olm_device): # Decrypting the same event twice will trigger an error device.olm_handle_encrypted_event(event) + # Forwarded key event + encrypted_event = alice_olm_device.olm_build_encrypted_event( + 'm.forwarded_room_key', content, self.user_id, self.device_id) + event = { + 'type': 'm.room.encrypted', + 'content': encrypted_event, + 'sender': self.alice + } + device.olm_handle_encrypted_event(event) + + # Unhandled event encrypted_event = alice_olm_device.olm_build_encrypted_event( 'm.other', content, self.user_id, self.device_id) event = { diff --git a/test/response_examples.py b/test/response_examples.py index 9b5e9456..eff4c45d 100644 --- a/test/response_examples.py +++ b/test/response_examples.py @@ -256,3 +256,45 @@ "ed25519": "4VjV3OhFUxWFAcO5YOaQVmTIn29JdRmtNh9iAxoyhkc", } } + +example_forwarded_room_key_event = { + "content": { + "algorithm": "m.megolm.v1.aes-sha2", + "forwarding_curve25519_key_chain": [ + "hPQNcabIABgGnx3/ACv/jmMmiQHoeFfuLB17tzWp6Hw" + ], + "room_id": "!Cuyf34gef24t:localhost", + "sender_claimed_ed25519_key": "aj40p+aw64yPIdsxoog8jhPu9i7l7NcFRecuOQblE3Y", + "sender_key": "RF3s+E7RkTQTGF2d8Deol0FkQvgII2aJDf3/Jp5mxVU", + "session_id": "iR4Q8LUXrtjwse7U80iALTZjcezHm0fI1UvXloTV0xs", + "session_key": + ("AQAAAADk1Ouk7LX6RuCsuHvtkD3/yvEDx4q4oXaK3sfPh03lUxNM3mXx6OHOH8kGFANHEVXQYr0OdYh" + "UeFM6xNSididZ5jiFpSQ0rIftSl+z4RlmFZPbt3XkvS2/8Q0mDr70g4rSYMkqxdWQy9Vi2lj0sWfQNl" + "QR92G0RwGsPNZdzYsBJokeEPC1F67Y8LHu1PNIgC02Y3Hsx5tHyNVL15aE1dMb") + }, + "type": "m.room_key" +} + +example_room_key_request_event = { + "content": { + "action": "request", + "body": { + "algorithm": "m.megolm.v1.aes-sha2", + "room_id": "!Cuyf34gef24t:localhost", + "sender_key": "RF3s+E7RkTQTGF2d8Deol0FkQvgII2aJDf3/Jp5mxVU", + "session_id": "iR4Q8LUXrtjwse7U80iALTZjcezHm0fI1UvXloTV0xs" + }, + "request_id": "1495474790150.19", + "requesting_device_id": "RJYKSTBOIE" + }, + "type": "m.room_key_request" +} + +example_room_key_cancel_event = { + "content": { + "action": "cancel_request", + "request_id": "1495474790150.19", + "requesting_device_id": "RJYKSTBOIE" + }, + "type": "m.room_key_request" +} From 4cdb52609f04d13d107da8d0daea22ff8f73e060 Mon Sep 17 00:00:00 2001 From: Valentin Deniaud Date: Wed, 8 Aug 2018 21:24:47 +0200 Subject: [PATCH 61/66] add key export functions --- matrix_client/crypto/key_export.py | 111 +++++++++++++++++++++++++++++ 1 file changed, 111 insertions(+) create mode 100644 matrix_client/crypto/key_export.py diff --git a/matrix_client/crypto/key_export.py b/matrix_client/crypto/key_export.py new file mode 100644 index 00000000..a7160cd1 --- /dev/null +++ b/matrix_client/crypto/key_export.py @@ -0,0 +1,111 @@ +from Crypto import Random +from Crypto.Cipher import AES +from Crypto.Hash import HMAC, SHA512, SHA256 +from Crypto.Protocol.KDF import PBKDF2 +from Crypto.Util import Counter +from unpaddedbase64 import decode_base64, encode_base64 + +HEADER = '-----BEGIN MEGOLM SESSION DATA-----' +FOOTER = '-----END MEGOLM SESSION DATA-----' + + +def encrypt_and_save(data, outfile, passphrase, count=100000): + """Encrypt keys data and write it to file. + + Args: + data (bytes): The data to encrypt. + outfile (str): The file the encrypted data will be written to. + passphrase (str): The encryption passphrase. + count (int): The round count used when deriving a key from the passphrase. + + Raises: + FileNotFoundError if the path to the file did not exist. + """ + encrypted_data = encrypt(data, passphrase, count=count) + with open(outfile, 'w') as f: + f.write(HEADER) + f.write(encrypted_data) + f.write(FOOTER) + + +def decrypt_and_read(infile, passphrase): + """Decrypt keys data from file. + + Args: + infile (str): The file the encrypted data will be written to. + passphrase (str): The encryption passphrase. + + Returns: + The decrypted data, as bytes. + + Raises: + ValueError if something went wrong during decryption. + FileNotFoundError if the file was not found. + """ + with open(infile, 'r') as f: + encrypted_data = f.read() + encrypted_data = encrypted_data.replace('\n', '') + + if not encrypted_data.startswith(HEADER) or not encrypted_data.endswith(FOOTER): + raise ValueError('Wrong file format.') + + encrypted_data = encrypted_data[len(HEADER):-len(FOOTER)] + return decrypt(encrypted_data, passphrase) + + +def prf(passphrase, salt): + """HMAC-SHA-512 pseudorandom function.""" + return HMAC.new(passphrase, salt, SHA512).digest() + + +def encrypt(data, passphrase, count=100000): + # 128 bits salt + salt = Random.new().read(16) + # 512 bits derived key + derived_key = PBKDF2(passphrase, salt, 64, count, prf) + aes_key = derived_key[:32] + hmac_key = derived_key[32:64] + + # 128 bits IV, which will be the initial value initial + iv = int.from_bytes(Random.new().read(16), byteorder='big') + # Set bit 63 to 0, as specified + iv &= ~(1 << 63) + ctr = Counter.new(128, initial_value=iv) + cipher = AES.new(aes_key, AES.MODE_CTR, counter=ctr) + encrypted_data = cipher.encrypt(data) + + payload = b''.join(( + bytes([1]), # Version + salt, + int.to_bytes(iv, length=16, byteorder='big'), + int.to_bytes(count, length=4, byteorder='big'), # 32 bits big-endian round count + encrypted_data + )) + + hmac = HMAC.new(hmac_key, payload, SHA256).digest() + return encode_base64(payload + hmac) + + +def decrypt(encrypted_payload, passphrase): + encrypted_payload = decode_base64(encrypted_payload) + + version = encrypted_payload[0] + if version != 1: + raise ValueError('Unsupported export format version.') + salt = encrypted_payload[1:17] + iv = int.from_bytes(encrypted_payload[17:33], byteorder='big') + count = int.from_bytes(encrypted_payload[33:37], byteorder='big') + encrypted_data = encrypted_payload[37:-32] + expected_hmac = encrypted_payload[-32:] + + derived_key = PBKDF2(passphrase, salt, 64, count, prf) + aes_key = derived_key[:32] + hmac_key = derived_key[32:64] + + hmac = HMAC.new(hmac_key, encrypted_payload[:-32], SHA256).digest() + if hmac != expected_hmac: + raise ValueError('HMAC check failed for encrypted payload.') + + ctr = Counter.new(128, initial_value=iv) + cipher = AES.new(aes_key, AES.MODE_CTR, counter=ctr) + return cipher.decrypt(encrypted_data) From b1b888f31ca5d2e896558686a9499b59d9ab5eb8 Mon Sep 17 00:00:00 2001 From: Valentin Deniaud Date: Wed, 8 Aug 2018 21:25:19 +0200 Subject: [PATCH 62/66] export keys from OlmDevice --- matrix_client/crypto/olm_device.py | 62 ++++++++++++++++++++++++++++++ 1 file changed, 62 insertions(+) diff --git a/matrix_client/crypto/olm_device.py b/matrix_client/crypto/olm_device.py index 2f007912..8e23b632 100644 --- a/matrix_client/crypto/olm_device.py +++ b/matrix_client/crypto/olm_device.py @@ -14,6 +14,7 @@ from matrix_client.crypto.crypto_store import CryptoStore from matrix_client.crypto.verified_event import VerifiedEvent from matrix_client.crypto.key_sharing import KeySharingManager +from matrix_client.crypto.key_export import encrypt_and_save, decrypt_and_read logger = logging.getLogger(__name__) @@ -912,3 +913,64 @@ def verify_json(self, json, user_key, user_id, device_id): json['unsigned'] = unsigned return success + + def export_keys(self, outfile, passphrase, count=10000): + """Export all the Megolm decryption keys of this device. + + The keys will be encrypted using the passphrase. + + NOTE: + This does not save other information such as the private identity keys + of the device. + + Args: + outfile (str): The file to write the keys to. + passphrase (str): The encryption passphrase. + count (int): Optional. Round count for the underlying key derivation. + It is not recommended to specify it unless absolutely sure of the + consequences. + """ + session_list = [] + self.db.load_inbound_sessions(self.megolm_inbound_sessions) + for room_id in self.megolm_inbound_sessions: + for sender_key, sessions in self.megolm_inbound_sessions[room_id].items(): + for session in sessions.values(): + payload = { + 'algorithm': self._megolm_algorithm, + 'sender_key': sender_key, + 'sender_claimed_keys': { + 'ed25519': session.ed25519 + }, + 'forwarding_curve25519_key_chain': session.forwarding_chain, + 'room_id': room_id, + 'session_id': session.id, + 'session_key': session.export_session(session.first_known_index) + } + session_list.append(payload) + data = json.dumps({'sessions': session_list}).encode() + encrypt_and_save(data, outfile, passphrase, count=count) + logger.info('Success exporting keys to %s.', outfile) + + def import_keys(self, infile, passphrase): + """Import Megolm decryption keys. + + The keys will be added to the current instance as well as written to database. + + Args: + infile (str): The file containing the keys. + passphrase (str): The decryption passphrase. + """ + data = decrypt_and_read(infile, passphrase) + session_list = json.loads(data)['sessions'] + for session in session_list: + if session['algorithm'] != self._megolm_algorithm: + logger.warning('Ignoring session with unsupported algorithm.') + continue + # This could be improved by writing everything to db at once at the end + self.megolm_add_inbound_session( + session['room_id'], session['sender_key'], + session['sender_claimed_keys']['ed25519'], session['session_id'], + session['session_key'], session['forwarding_curve25519_key_chain'], + export_format=True + ) + logger.info('Success importing keys from %s.', infile) From 027ae79bd86f288e1ab0dfb2c92026bfd0106ddb Mon Sep 17 00:00:00 2001 From: Valentin Deniaud Date: Wed, 8 Aug 2018 21:25:45 +0200 Subject: [PATCH 63/66] add key export tests --- test/crypto/key_export_test.py | 74 ++++++++++++++++++++++++++++++++++ 1 file changed, 74 insertions(+) create mode 100644 test/crypto/key_export_test.py diff --git a/test/crypto/key_export_test.py b/test/crypto/key_export_test.py new file mode 100644 index 00000000..0c035b19 --- /dev/null +++ b/test/crypto/key_export_test.py @@ -0,0 +1,74 @@ +import pytest +olm = pytest.importorskip("olm") # noqa + +import os +from tempfile import mkstemp + +from unpaddedbase64 import decode_base64, encode_base64 + +from matrix_client.crypto.key_export import (encrypt, encrypt_and_save, decrypt, + decrypt_and_read) +from matrix_client.crypto.sessions import MegolmInboundSession +from test.crypto.dummy_olm_device import OlmDevice + + +def test_encrypt_decrypt(): + plaintext = b'test' + passphrase = 'pass' + # Set a ridiculously low round count for this test to be fast + ciphertext = encrypt(plaintext, passphrase, count=1) + + assert decrypt(ciphertext, passphrase) == plaintext + + ciphertext_bytes = decode_base64(ciphertext) + + # Wrong hmac + ciphertext = encode_base64(ciphertext_bytes[:-32] + b'A' * 32) + with pytest.raises(ValueError): + decrypt(ciphertext, passphrase) + + # Wrong version + ciphertext = encode_base64(bytes([42]) + ciphertext_bytes[1:]) + with pytest.raises(ValueError): + decrypt(ciphertext, passphrase) + + +def test_encrypt_decrypt_and_save(): + plaintext = b'test' + passphrase = 'pass' + try: + filename = mkstemp()[1] + encrypt_and_save(plaintext, filename, passphrase, count=1) + assert decrypt_and_read(filename, passphrase) == plaintext + + # Bad header + with open(filename, 'w') as f: + f.write('wrong') + with pytest.raises(ValueError): + decrypt_and_read(filename, passphrase) + finally: + os.remove(filename) + + +def test_import_export(): + passphrase = 'pass' + device = OlmDevice(None, '@test:localhost', 'AUIETSRN') + out = olm.OutboundGroupSession() + session = MegolmInboundSession(out.session_key, 'signing_key') + device.megolm_inbound_sessions['room']['sender_key'][session.id] = session + + try: + filename = mkstemp()[1] + device.export_keys(filename, passphrase, count=1) + other_device = OlmDevice(None, '@test:localhost', 'AUIETSRN') + other_device.import_keys(filename, passphrase) + sessions = other_device.megolm_inbound_sessions['room']['sender_key'] + assert sessions[session.id].id == session.id + + # Unknown algorithn + other_device = OlmDevice(None, '@test:localhost', 'AUIETSRN') + other_device._megolm_algorithm = 'wrong' + other_device.import_keys(filename, passphrase) + assert not other_device.megolm_inbound_sessions + finally: + os.remove(filename) From c9f4938c18911756a67f9b31530833fc15669791 Mon Sep 17 00:00:00 2001 From: Valentin Deniaud Date: Thu, 9 Aug 2018 18:42:42 +0200 Subject: [PATCH 64/66] add import/export methods to client --- matrix_client/client.py | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/matrix_client/client.py b/matrix_client/client.py index 03d5faba..48ec24ee 100644 --- a/matrix_client/client.py +++ b/matrix_client/client.py @@ -784,3 +784,33 @@ def get_fingerprint(self): if not self._encryption: raise ValueError("Encryption is not enabled, this device has no fingerprint.") return self.olm_device.ed25519 + + def export_keys(self, outfile, passphrase): + """Export all the Megolm decryption keys of this device. + + The keys will be encrypted using the passphrase. + + NOTE: + This does not save other information such as the private identity keys + of the device. + + Args: + outfile (str): The file to write the keys to. + passphrase (str): The encryption passphrase. + """ + if not self._encryption: + raise ValueError("Encryption is not enabled, there are no keys to export.") + self.olm_device.export_keys(outfile, passphrase) + + def import_keys(self, infile, passphrase): + """Import Megolm decryption keys. + + The keys will be added to the current instance as well as written to database. + + Args: + infile (str): The file containing the keys. + passphrase (str): The decryption passphrase. + """ + if not self._encryption: + raise ValueError("Encryption is not enabled, cannot import keys.") + self.olm_device.import_keys(infile, passphrase) From d30d5cb3ac46be1f99947e70d8885307c046122d Mon Sep 17 00:00:00 2001 From: Valentin Deniaud Date: Fri, 10 Aug 2018 12:52:56 +0200 Subject: [PATCH 65/66] add E2E friendly documentation --- E2E_overview.rst | 183 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 183 insertions(+) create mode 100644 E2E_overview.rst diff --git a/E2E_overview.rst b/E2E_overview.rst new file mode 100644 index 00000000..f9f52dde --- /dev/null +++ b/E2E_overview.rst @@ -0,0 +1,183 @@ +Overview of end-to-end encryption in matrix-python-sdk +------------------------------------------------------ + +This SDK supports end-to-end encryption as specified in Matrix. The following is an +overview of the main available features. + +Encryption is mostly automatic, and users are not expected to read past the `basic +usage`_ section. + +.. contents:: + +Installation +============ + +Encryption requires `libolm`__, the official Matrix library that provides the necessary +cryptographic primitives. It is available in many Linux distribution repositories, and can +also be easily compiled from source. + +__ https://matrix.org/git/olm + +Encryption also comes with several optional dependencies, listed under the ``e2e`` group +in ``setup.py``. + +Using pip, these can be installed by running ``pip install .[e2e]`` at the root of the +repository. + +Encryption heavily rely on an underlying database, in order to work seamlessly across +restarts. This is implemented using SQLite and the sqlite3 module of the standard Python +library. Users do not not have to worry about this, and the database location is platform +dependent (and is displayed on start-up via an info log line). For advanced usage, see +`overriding the crypto store`_. + + +Basic usage +=========== + +Encryption support is disabled by default. Enabling it is done when instantiating +``MatrixClient``, as follow: + +.. code:: python + + client = MatrixClient(HOSTNAME, encryption=True) + +.. note:: + + When enabling encryption in an already existing project, you will notice that a lot of + logging messages appear. Most of those can be safely ignored. For instance, warning + messages on first sync simply mean that the client is unable to decrypt old messages + it didn't receive the keys for, as there are anterior to the encryption enabling. + +Device IDs +~~~~~~~~~~ + +When using encryption, a user **should** reuse device IDs, as they are associated with +a fingerprint key that should not change across restart, in most cases. The complete +rationale is explained `here`__. + +__ https://matrix.org/docs/guides/e2e_implementation.html#devices + +A user can keep track of device IDs by specifying them at login, or can delegate it to the +SDK, as follow: + +.. code:: python + + client = MatrixClient(HOSTNAME, encryption=True, restore_device_id=True) + +On first launch, the client will store the device ID returned by the homeserver in the +same database used to store encryption keys. On subsequent launches, the device ID will be +retrieved from the user ID at login. + +.. note:: + + When logging in with ``restore_device_id`` turned on, you must supply a full user ID (eg ``@test:matrix.org``), not just a username (eg ``test``). + +When using this, the need to reset the device ID automatically associated with a user ID +may arise. This can be done by explicitly specifying a device ID at login, or simply by +removing the database (consider using ``shred`` over ``rm``). Both of these methods will +delete all the encryption data associated with the previous device, as none can be safely +reused as-is with a new one. Hence, before doing this, a user might want to `export +encryption keys`_. + +.. note:: + + Refer to ``samples/e2e_overview.py`` for more example code. + + +Advanced usage +============== + +Several options are available in order to customize some behaviors, or to enable +additional features. These are abundantly documented via docstrings, and the following +subsections aim at showing some examples. + +Device verification +~~~~~~~~~~~~~~~~~~~ + +A major feature of end-to-end encryption is to make sure that the sender of a message is +the actual sender, and not an usurper. + +In order to allow other users to verify the current device, its fingerprint should be +displayed. This is done by calling ``client.get_fingerprint()``. + +Device verification is disabled by default. It can be enabled globally by passing +``verify_devices=True`` when instantiating ``MatrixClient``, or on a per-room basis by +doing ``room.verify_devices = True``. + +Once device verification is enabled in a room, sending messages to it will raise +``E2EUnknownDevices`` if there are some never seen before devices. A user should inspect +the ``user_devices`` attribute of this exception, and for each devices it contains, do +either: + + - ``device.verified = True`` if the device can be verified. New checks will be enabled + to ensure that every subsequent messages received from this device actually come from + it. + - ``device.blacklisted = True`` if decryption keys should never be shared with this + device. + - ``device.ignored = True`` if the device cannot be verified, and keys should be + sent to it anyway. + +Those verifications are persisted in database. + +.. note:: + + This section is incomplete (doesn't explain how to verify an event). + +Key sharing +~~~~~~~~~~~ + +A feature of the protocol is to be able to request and receive encryption keys from other +users. The SDK implements only the sharing of keys with devices of the current Matrix +user. + +Key sharing is disabled by default. A user has to implement non-trivial logic in order to +use it. + +The automatic request of keys can be enabled by adding a listener using +``MatrixClient.add_key_forward_listener(callback)``. The callback should be used to be +notified when a new key arrives, and it is advised to carefully read the docstring of this +method. A client only wanting to silently request and receive keys can add a callback +which does nothing. + +In order to reply to key requests, ``MatrixClient.add_key_request_listener(callback)`` +should be used. Refer to the docstring for more info. + +Encrypted attachments +~~~~~~~~~~~~~~~~~~~~~ + +.. TODO waiting for more convenient upload/download process + +Export encryption keys +~~~~~~~~~~~~~~~~~~~~~~ + +A user may want to import or export the encryption keys used in rooms, in order to be able +to decrypt messages on a new device. This can be done by using the ``export_keys`` and +``import_keys`` methods of ``MatrixClient``. + +Overriding the crypto store +~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +In order to use another storage method, the SQLite storage can be replaced by subclassing +the class ``CryptoStore`` and carefully reimplementing all the methods, which are +thoroughly documented for this purpose. The new class can then be used as follow: + +.. code:: python + + client = MatrixClient(HOSTNAME, encryption=True, encryption_conf={'Store': NewClass}) + +Changing the database file location +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +This feature is especially useful when wanting to run several instances of +``MatrixClient`` in multiple *processes* (threads should work fine). The SQLite database +cannot be shared between processes (at least not without proper locking, which would have +to be implemented). Then the easiest way is to have one database per process. + +The ``CryptoStore`` class can be passed attributes ``db_path`` and ``db_name``. +Then, configuring the database to be stored as ``/foo/bar.db`` is done as follow: + +.. code:: python + + store_conf = {'db_path': '/foo/', 'db_name': 'bar.db'} + client = MatrixClient(HOSTNAME, encryption=True, + encryption_conf={'store_conf': store_conf}) From 150b0079e198592bca20a3d716f428ab626679c8 Mon Sep 17 00:00:00 2001 From: Valentin Deniaud Date: Fri, 10 Aug 2018 12:53:17 +0200 Subject: [PATCH 66/66] add E2E sample code --- samples/e2e_overview.py | 69 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 69 insertions(+) create mode 100755 samples/e2e_overview.py diff --git a/samples/e2e_overview.py b/samples/e2e_overview.py new file mode 100755 index 00000000..d0827a7c --- /dev/null +++ b/samples/e2e_overview.py @@ -0,0 +1,69 @@ +#!/usr/bin/env python3 +""" +This sample contains example code of how encryption can be used. +It requires a room to be already created. +""" +import sys + +from samples_common import get_user_details, get_input +from matrix_client.client import MatrixClient +from matrix_client.errors import E2EUnknownDevices + + +host, user_id, password = get_user_details(sys.argv) +try: + room_id = sys.argv[4] +except: + room_id = get_input("Room ID: ") + +## Basic usage + +client = MatrixClient(host, encryption=True, restore_device_id=True) +client.login(username=user_id, password=password) +device_id = client.device_id +room = client.join_room(room_id) + +if not room.encrypted: + room.send_text("Unencrypted!") + encrypted = room.enable_encryption() + if encrypted: + room.send_text("Encrypted!") + else: + room.send_text("Still unencrypted, insufficient power levels?") + +room.send_text("My autogenerated device ID is {}.".format(device_id)) + +client.logout() + +client = MatrixClient(host, encryption=True, restore_device_id=True) +client.login(username=user_id, password=password) +assert client.device_id == device_id + +room = client.join_room(room_id) +room.send_text("My device ID is still {}!".format(device_id)) + +## Advanced usage + +# Device verification + +room.verify_devices = True + +try: + room.send_text("Do I know everyone?") +except E2EUnknownDevices as e: + # We don't know anyone, but send anyway + for user_id, devices in e.user_devices.items(): + for device in devices: + device.ignored = True + # Out-of-band verification should allow to do device.verified = True instead + +room.send_text("Now I know everyone, kind of.") + + +# Key sharing + +# Print every keys which arrive to us +client.add_key_forward_listener(lambda x: print(x)) + +# Share keys everytime we receive a request from another of our devices (do not do this) +client.add_key_request_listener(lambda x, f: f(x))