diff --git a/docs/source/matrix_client.rst b/docs/source/matrix_client.rst index af1ebc41..353eb5a5 100644 --- a/docs/source/matrix_client.rst +++ b/docs/source/matrix_client.rst @@ -56,3 +56,13 @@ matrix_client.crypto :members: :undoc-members: :show-inheritance: + +.. automodule:: matrix_client.crypto.device_list + :members: + :undoc-members: + :show-inheritance: + +.. automodule:: matrix_client.crypto.megolm_outbound_session + :members: + :undoc-members: + :show-inheritance: diff --git a/matrix_client/api.py b/matrix_client/api.py index 8bd44ecc..102d3e5c 100644 --- a/matrix_client/api.py +++ b/matrix_client/api.py @@ -343,6 +343,18 @@ def send_content(self, room_id, item_url, item_name, msg_type, return self.send_message_event(room_id, "m.room.message", content_pack, timestamp=timestamp) + def get_location_body(self, geo_uri, name, thumb_url=None, thumb_info=None): + content_pack = { + "geo_uri": geo_uri, + "msgtype": "m.location", + "body": name, + } + if thumb_url: + content_pack["thumbnail_url"] = thumb_url + if thumb_info: + content_pack["thumbnail_info"] = thumb_info + return content_pack + # http://matrix.org/docs/spec/client_server/r0.2.0.html#m-location def send_location(self, room_id, geo_uri, name, thumb_url=None, thumb_info=None, timestamp=None): @@ -356,15 +368,8 @@ def send_location(self, room_id, geo_uri, name, thumb_url=None, thumb_info=None, thumb_info (dict): Metadata about the thumbnail, type ImageInfo. timestamp (int): Set origin_server_ts (For application services only) """ - content_pack = { - "geo_uri": geo_uri, - "msgtype": "m.location", - "body": name, - } - if thumb_url: - content_pack["thumbnail_url"] = thumb_url - if thumb_info: - content_pack["thumbnail_info"] = thumb_info + content_pack = self.get_location_body( + geo_uri, name, thumb_url, thumb_info) return self.send_message_event(room_id, "m.room.message", content_pack, timestamp=timestamp) @@ -405,12 +410,11 @@ def send_notice(self, room_id, text_content, timestamp=None): text_content (str): The m.notice body to send. timestamp (int): Set origin_server_ts (For application services only) """ - body = { - "msgtype": "m.notice", - "body": text_content - } - return self.send_message_event(room_id, "m.room.message", body, - timestamp=timestamp) + return self.send_message_event( + room_id, "m.room.message", + self.get_notice_body(text_content), + timestamp=timestamp + ) def get_room_messages(self, room_id, token, direction, limit=10, to=None): """Perform GET /rooms/{roomId}/messages. @@ -675,6 +679,12 @@ def get_emote_body(self, text): "body": text } + def get_notice_body(self, text): + return { + "msgtype": "m.notice", + "body": text + } + def get_filter(self, user_id, filter_id): return self._send("GET", "/user/{userId}/filter/{filterId}" .format(userId=user_id, filterId=filter_id)) diff --git a/matrix_client/client.py b/matrix_client/client.py index 703b825c..0c2a8af6 100644 --- a/matrix_client/client.py +++ b/matrix_client/client.py @@ -581,6 +581,15 @@ def _mkroom(self, room_id): # TODO better handling of the blocking I/O caused by update_one_time_key_counts def _sync(self, timeout_ms=30000): response = self.api.sync(self.sync_token, timeout_ms, filter=self.sync_filter) + + if self._encryption and 'device_lists' in response: + if response['device_lists'].get('changed'): + self.olm_device.device_list.update_user_device_keys( + response['device_lists']['changed'], self.sync_token) + if response['device_lists'].get('left'): + self.olm_device.device_list.stop_tracking_users( + response['device_lists']['left']) + self.sync_token = response["next_batch"] for presence_update in response['presence']['events']: @@ -627,6 +636,10 @@ def _sync(self, timeout_ms=30000): ): listener['callback'](event) + if self._encryption and room.encrypted: + # Track the new users in the room + self.olm_device.device_list.track_pending_users() + for event in sync_room['ephemeral']['events']: event['room_id'] = room_id room._put_ephemeral_event(event) diff --git a/matrix_client/crypto/device_list.py b/matrix_client/crypto/device_list.py new file mode 100644 index 00000000..c0bf406a --- /dev/null +++ b/matrix_client/crypto/device_list.py @@ -0,0 +1,292 @@ +import logging +from collections import defaultdict +from threading import Thread, Condition, Event, Lock + +from matrix_client.errors import MatrixHttpLibError, MatrixRequestError + +logger = logging.getLogger(__name__) + + +class DeviceList: + """Allows to maintain a list of devices up-to-date for an OlmDevice. + + Offers blocking and non-blocking methods to fetch device keys when appropriate. + NOTE: Spawns a thread that will last until program termination. + + Args: + olm_device (OlmDevice): Will be used to get additional info, such as device id. + api (MatrixHttpApi): The api object used to make requests. + device_keys (defaultdict(dict)): A map from user to device to keys. + """ + + def __init__(self, olm_device, api, device_keys): + self.olm_device = olm_device + self.api = api + self.device_keys = device_keys + # Stores the ids of users who need updating + self.outdated_user_ids = _OutdatedUsersSet() + # Stores the ids of users to fetch the device keys of eventually + self.pending_outdated_user_ids = set() + self.pending_users_lock = Lock() + # Stores the ids of users we are currently tracking. We can assume the device + # keys of these users are up-to-date as long as no downloading is in progress. + # We should track every user we share an encrypted room with. + self.tracked_user_ids = set() + # Allows to wake up the thread when there are new users to update, and to + # synchronise shared data. + self.thread_condition = Condition() + self.update_thread = _UpdateDeviceList( + self.thread_condition, self.outdated_user_ids, self._download_device_keys, + self.tracked_user_ids + ) + self.update_thread.start() + + def get_room_device_keys(self, room, blocking=True): + """Gets the keys of all devices present in the room. + + Makes sure not to download keys of users we are already tracking. + The users we were not yet tracking will get tracked automatically. + + Args: + room (Room): The room to use. + blocking (bool): Optional. Whether to wait for the keys to have been + downloaded before returning. + """ + logger.info('Fetching all missing keys in room %s.', room.room_id) + user_ids = {u.user_id for u in room.get_joined_members()} - self.tracked_user_ids + if not user_ids: + logger.info('Already had all the keys in room %s.', room.room_id) + if blocking: + # Wait on an eventual download to finish + self.update_thread.event.wait() + return + with self.thread_condition: + self.outdated_user_ids.update(user_ids) + if blocking: + # Will ensure the user_ids we just added are processed + event = Event() + self.outdated_user_ids.events.add(event) + self.thread_condition.notify() + if blocking: + event.wait() + + def track_user_no_download(self, user_id): + """Add user to be tracked, but do not track it instantly. + + This should be used to avoid making calls to :func:`track_users` with only one + user repeatedly. Instead, using this should allow to passively queue the users, + and tracking can be triggered when there are sufficiently, using + :func:`track_pending_users`. + + Args: + user_id (str): A user id. + """ + with self.pending_users_lock: + self.pending_outdated_user_ids.add(user_id) + + def track_pending_users(self): + """Triggers the tracking of the user added with :func:`track_user_no_download`.""" + with self.pending_users_lock: + self.track_users(self.pending_outdated_user_ids) + self.pending_outdated_user_ids.clear() + + def track_users(self, user_ids): + """Add users to be tracked, and download their device keys. + + NOTE: this is non-blocking and will return before the keys are downloaded. + + Args: + user_ids (iterable): Any iterable containing user ids. + """ + user_ids = user_ids.difference(self.tracked_user_ids) + if user_ids: + self._add_outdated_users(user_ids) + + def stop_tracking_users(self, user_ids): + """Stop tracking users. + + NOTE: Keys will not be deleted. + + Args: + user_ids (iterable): Any iterable containing user ids. + """ + with self.thread_condition: + self.tracked_user_ids.difference_update(user_ids) + self.outdated_user_ids.difference_update(user_ids) + logger.info('Stopped tracking users: %s.', user_ids) + + def update_user_device_keys(self, user_ids, since_token=None): + """Triggers an update for users we already track. + + Args: + user_ids (iterable): Any iterable containing user ids. + since_token (str): Optional. Since token of a sync request, if triggering + the update as a result of that sync request. + """ + user_ids = self.tracked_user_ids.intersection(user_ids) + if not user_ids: + return + logger.info('Updating the device lists of users: %s, using token %s', + user_ids, since_token) + self._add_outdated_users(user_ids, since_token=since_token) + + def _add_outdated_users(self, user_ids, since_token=None): + """Stop tracking users. Keys will not be deleted. + + Args: + user_ids (iterable): Any iterable containing user ids. + since_token (str): Optional. Since token of a sync request. + """ + with self.thread_condition: + self.outdated_user_ids.update(user_ids) + if since_token: + self.outdated_user_ids.sync_token = since_token + self.thread_condition.notify() + + def _download_device_keys(self, user_devices, since_token=None): + """Download and store device keys, if they pass security checks. + + Args: + user_devices (dict): Format is ``user_id: [device_ids]``. + since_token (str): Optional. Since token of a sync request. + """ + changed = defaultdict(dict) + resp = self.api.query_keys(user_devices, token=since_token) + if resp.get('failures'): + logger.warning('Failed to download keys from the following unreachable ' + 'homeservers %s.', resp['failures']) + device_keys = resp['device_keys'] + for user_id in user_devices: + # The response might not contain every user_ids we requested + for device_id, payload in device_keys.get(user_id, {}).items(): + if device_id == self.olm_device.device_id: + continue + if payload['user_id'] != user_id or payload['device_id'] != device_id: + logger.warning('Mismatch in keys payload of device %s (%s) of user ' + '%s (%s).', payload['device_id'], device_id, + payload['user_id'], user_id) + continue + try: + signing_key = payload['keys']['ed25519:{}'.format(device_id)] + curve_key = payload['keys']['curve25519:{}'.format(device_id)] + except KeyError as e: + logger.warning('Invalid identity keys payload from device %s of' + 'user %s: %s.', device_id, user_id, e) + continue + verified = self.olm_device.verify_json( + payload, signing_key, user_id, device_id) + if not verified: + logger.warning('Signature verification failed for device %s of ' + 'user %s.', device_id, user_id) + continue + keys = self.device_keys[user_id].setdefault(device_id, {}) + if keys: + if keys['ed25519'] != signing_key: + logger.warning('Ed25519 key has changed for device %s of ' + 'user %s.', device_id, user_id) + continue + if keys['curve25519'] == curve_key: + continue + else: + keys['ed25519'] = signing_key + keys['curve25519'] = curve_key + changed[user_id][device_id] = keys + + logger.info('Successfully downloaded keys for devices: %s.', + {user_id: list(changed[user_id]) for user_id in changed}) + return changed + + +class _OutdatedUsersSet(set): + """Allows to know if elements in a set have been processed. + + This is done by adding elements along with an Event object. Then, functions + processing the set should set the events when they are done. + """ + + def __init__(self, iterable=()): + self.events = set() + self._sync_token = None + super(_OutdatedUsersSet, self).__init__(iterable) + + def mark_as_processed(self): + for event in self.events: + event.set() + + def copy(self): + new_set = _OutdatedUsersSet(self) + new_set.events = self.events.copy() + return new_set + + def clear(self): + self.events.clear() + super(_OutdatedUsersSet, self).clear() + + def update(self, iterable): + super(_OutdatedUsersSet, self).update(iterable) + if isinstance(iterable, _OutdatedUsersSet): + self.events.update(iterable.events) + + @property + def sync_token(self): + return self._sync_token + + @sync_token.setter + def sync_token(self, token): + if not self._sync_token or token > self._sync_token: + self._sync_token = token + + +class _UpdateDeviceList(Thread): + + def __init__(self, cond, user_ids, download_method, tracked_user_ids): + # We wait on this condition when there is nothing to do. Outside code should use + # it to notify us when they add data to be processed in outdated_user_ids so that + # we can wake up and process it. + self.cond = cond + self.outdated_user_ids = user_ids + self.download = download_method + self.tracked_user_ids = tracked_user_ids + # Cleared when we start a download, and set when we have finished it. This can be + # used by outside code in order to know if we are in the middle of a download, and + # allows to wait for it to complete by waiting on this event. + self.event = Event() + # Used internally to terminate gracefully on program exit. + self._should_terminate = Event() + super(_UpdateDeviceList, self).__init__() + + def run(self): + while True and not self._should_terminate.is_set(): + with self.cond: + while not self.outdated_user_ids: + # Avoid any deadlocks + self.outdated_user_ids.mark_as_processed() + self.event.set() + logger.debug('Update thread is going to sleep...') + self.cond.wait() + logger.debug('Update thread woke up!') + if self._should_terminate.is_set(): + return + to_download = self.outdated_user_ids.copy() + self.outdated_user_ids.clear() + self.event.clear() + self.tracked_user_ids.update(to_download) + payload = {user_id: [] for user_id in to_download} + logger.info('Downloading device keys for users: %s.', to_download) + try: + self.download(payload, self.outdated_user_ids.sync_token) + self.event.set() + to_download.mark_as_processed() + except (MatrixHttpLibError, MatrixRequestError) as e: + logger.warning('Network error when fetching device keys (will retry): %s', + e) + with self.cond: + self.outdated_user_ids.update(to_download) + + def join(self, timeout=None): + # If we are joined, this means that the main program is terminating. + # We should terminate too. + self._should_terminate.set() + with self.cond: + self.cond.notify() + super(_UpdateDeviceList, self).join(timeout=timeout) diff --git a/matrix_client/crypto/megolm_outbound_session.py b/matrix_client/crypto/megolm_outbound_session.py new file mode 100644 index 00000000..cff87f6a --- /dev/null +++ b/matrix_client/crypto/megolm_outbound_session.py @@ -0,0 +1,60 @@ +from datetime import datetime, timedelta + +from olm import OutboundGroupSession + + +class MegolmOutboundSession(OutboundGroupSession): + + """Outbound group session aware of the users it is shared with. + + Also remembers the time it was created and the number of messages it has encrypted, + in order to know if it needs to be rotated. + + Args: + max_age (datetime.timedelta): Optional. The maximum time the session should + exist. Default to one week if not present. + max_messages (int): Optional. The maximum number of messages that should be sent. + A new message in considered sent each time there is a call to ``encrypt``. + Default to 100 if not present. + """ + + def __init__(self, max_age=None, max_messages=None): + self.devices = set() + if max_age: + self.max_age = timedelta(milliseconds=max_age) + else: + self.max_age = timedelta(days=7) + self.max_messages = max_messages or 100 + self.creation_time = datetime.now() + self.message_count = 0 + super(MegolmOutboundSession, self).__init__() + + def __new__(cls, **kwargs): + return super(MegolmOutboundSession, cls).__new__(cls) + + def add_device(self, device_id): + """Adds a device the session is shared with.""" + self.devices.add(device_id) + + def add_devices(self, device_ids): + """Adds devices the session is shared with. + + Args: + device_ids (iterable): An iterable of device ids, preferably a set. + """ + self.devices.update(device_ids) + + def should_rotate(self): + """Wether the session should be rotated. + + Returns: + True if it should, False if not. + """ + if self.message_count >= self.max_messages or \ + datetime.now() - self.creation_time >= self.max_age: + return True + return False + + def encrypt(self, plaintext): + self.message_count += 1 + return super(MegolmOutboundSession, self).encrypt(plaintext) diff --git a/matrix_client/crypto/olm_device.py b/matrix_client/crypto/olm_device.py index 514965db..8b2a344d 100644 --- a/matrix_client/crypto/olm_device.py +++ b/matrix_client/crypto/olm_device.py @@ -1,10 +1,14 @@ +import json import logging +from collections import defaultdict import olm from canonicaljson import encode_canonical_json from matrix_client.checks import check_user_id from matrix_client.crypto.one_time_keys import OneTimeKeysManager +from matrix_client.crypto.device_list import DeviceList +from matrix_client.crypto.megolm_outbound_session import MegolmOutboundSession logger = logging.getLogger(__name__) @@ -59,6 +63,10 @@ def __init__(self, self.one_time_keys_manager = OneTimeKeysManager(target_keys_number, signed_keys_proportion, keys_threshold) + self.device_keys = defaultdict(dict) + self.device_list = DeviceList(self, api, self.device_keys) + self.olm_sessions = defaultdict(list) + self.megolm_outbound_sessions = {} def upload_identity_keys(self): """Uploads this device's identity keys to HS. @@ -136,6 +144,400 @@ def update_one_time_key_counts(self, counts): logger.info('Uploading new one-time keys.') self.upload_one_time_keys() + def olm_start_sessions(self, user_devices): + """Start olm sessions with the given devices. + + NOTE: those device keys must already be known. + + Args: + user_devices (dict): A map from user_id to an iterable of device_ids. + The format is ``: []``. + """ + logger.info('Trying to establish Olm sessions with devices: %s.', + dict(user_devices)) + payload = defaultdict(dict) + for user_id in user_devices: + for device_id in user_devices[user_id]: + payload[user_id][device_id] = 'signed_curve25519' + + resp = self.api.claim_keys(payload) + if resp.get('failures'): + logger.warning('Failed to claim one-time keys from the following unreachable ' + 'homeservers: %s.', resp['failures']) + keys = resp['one_time_keys'] + if logger.level >= logging.WARNING: + missing = {} + for user_id, device_ids in user_devices.items(): + if user_id not in keys: + missing[user_id] = device_ids + else: + missing_devices = set(device_ids) - set(keys[user_id]) + if missing_devices: + missing[user_id] = missing_devices + logger.warning('Failed to claim the keys of %s.', missing) + + for user_id in user_devices: + for device_id, one_time_key in keys.get(user_id, {}).items(): + try: + device_keys = self.device_keys[user_id][device_id] + except KeyError: + logger.warning('Key for device %s of user %s not found, could not ' + 'start Olm session.', device_id, user_id) + continue + key_object = next(iter(one_time_key.values())) + verified = self.verify_json(key_object, + device_keys['ed25519'], + user_id, + device_id) + if verified: + session = olm.OutboundSession(self.olm_account, + device_keys['curve25519'], + key_object['key']) + sessions = self.olm_sessions[device_keys['curve25519']] + sessions.append(session) + logger.info('Established Olm session %s with device %s of user ' + '%s.', device_id, session.id, user_id) + else: + logger.warning('Signature verification for one-time key of device %s ' + 'of user %s failed, could not start olm session.', + device_id, user_id) + + def olm_build_encrypted_event(self, event_type, content, user_id, device_id): + """Encrypt an event using Olm. + + NOTE: a session with this device must already be established. + + Args: + event_type (str): The event type, will be encrypted. + content (dict): The event content, will be encrypted. + user_id (str): The intended recipient of the event. + device_id (str): The device to encrypt to. + + Returns: + The Olm encrypted event, as JSON. + """ + try: + keys = self.device_keys[user_id][device_id] + except KeyError: + raise RuntimeError('Device is unknown, could not encrypt.') + + signing_key = keys['ed25519'] + identity_key = keys['curve25519'] + + payload = { + 'type': event_type, + 'content': content, + 'sender': self.user_id, + 'sender_device': self.device_id, + 'keys': { + 'ed25519': self.identity_keys['ed25519'] + }, + 'recipient': user_id, + 'recipient_keys': { + 'ed25519': signing_key + } + } + + sessions = self.olm_sessions[identity_key] + if sessions: + session = sorted(sessions, key=lambda s: s.id)[0] + else: + raise RuntimeError('No session for this device, could not encrypt.') + + encrypted_message = session.encrypt(json.dumps(payload)) + ciphertext_payload = { + identity_key: { + 'type': encrypted_message.message_type, + 'body': encrypted_message.ciphertext + } + } + + event = { + 'algorithm': self._olm_algorithm, + 'sender_key': self.identity_keys['curve25519'], + 'ciphertext': ciphertext_payload + } + return event + + def olm_decrypt_event(self, content, user_id): + """Decrypt an Olm encrypted event, and check its properties. + + Args: + event (dict): The content property of a m.room.encrypted event. + user_id (str): The sender of the event. + + Retuns: + The decrypted event held by the initial event. + + Raises: + RuntimeError: Error in the decryption process. Nothing can be done. The text + of the exception indicates what went wrong, and should be logged or + displayed to the user. + KeyError: The event is missing a required field. + """ + if content['algorithm'] != self._olm_algorithm: + raise RuntimeError('Event was not encrypted with {}.' + .format(self._olm_algorithm)) + + ciphertext = content['ciphertext'] + try: + payload = ciphertext[self.identity_keys['curve25519']] + except KeyError: + raise RuntimeError('This message was not encrypted for us.') + + msg_type = payload['type'] + if msg_type == 0: + encrypted_message = olm.OlmPreKeyMessage(payload['body']) + else: + encrypted_message = olm.OlmMessage(payload['body']) + + decrypted_event = self._olm_decrypt(encrypted_message, content['sender_key']) + + if decrypted_event['sender'] != user_id: + raise RuntimeError( + 'Found user {} instead of sender {} in Olm plaintext {}.' + .format(decrypted_event['sender'], user_id, decrypted_event) + ) + if decrypted_event['recipient'] != self.user_id: + raise RuntimeError( + 'Found user {} instead of us ({}) in Olm plaintext {}.' + .format(decrypted_event['recipient'], self.user_id, decrypted_event) + ) + our_key = decrypted_event['recipient_keys']['ed25519'] + if our_key != self.identity_keys['ed25519']: + raise RuntimeError( + 'Found key {} instead of ours own ed25519 key {} in Olm plaintext {}.' + .format(our_key, self.identity_keys['ed25519'], decrypted_event) + ) + + return decrypted_event + + def _olm_decrypt(self, olm_message, sender_key): + """Decrypt an Olm encrypted event. + + NOTE: This does no implement any security check. + + Try to decrypt using existing sessions. If it fails, start an new one when + possible. + + Args: + olm_message (OlmMessage): Olm encrypted payload. + sender_key (str): The sender's curve25519 identity key. + + Returns: + The decrypted event held by the initial payload, as JSON. + """ + + sessions = self.olm_sessions[sender_key] + + # Try to decrypt message body using one of the known sessions for that device + for session in sessions: + try: + event = session.decrypt(olm_message) + logger.info('Success decrypting Olm event using existing session %s.', + session.id) + break + except olm.session.OlmSessionError as e: + if olm_message.message_type == 0: + if session.matches(olm_message, sender_key): + # We had a matching session for a pre-key message, but it didn't + # work. This means something is wrong, so we fail now. + raise RuntimeError('Error decrypting pre-key message with ' + 'existing Olm session {}, reason: {}.' + .format(session.id, e)) + # Simply keep trying otherwise + else: + if olm_message.message_type > 0: + # Not a pre-key message, we should have had a matching session + if sessions: + raise RuntimeError('Error decrypting with existing sessions.') + raise RuntimeError('No existing sessions.') + + # We have a pre-key message without any matching session, in this case + # we should try to create one. + try: + session = olm.session.InboundSession( + self.olm_account, olm_message, sender_key) + except olm.session.OlmSessionError as e: + raise RuntimeError('Error decrypting pre-key message when trying to ' + 'establish a new session: {}.'.format(e)) + + logger.info('Created new Olm session %s.', session.id) + try: + event = session.decrypt(olm_message) + except olm.session.OlmSessionError as e: + raise RuntimeError('Error decrypting pre-key message with new session: ' + '{}.'.format(e)) + self.olm_account.remove_one_time_keys(session) + sessions.append(session) + + return json.loads(event) + + def olm_ensure_sessions(self, user_devices): + """Start Olm sessions with the given devices if one doesn't exist already. + + Args: + user_devices (dict): A map from user ids to a list of device ids. + """ + user_devices_no_session = defaultdict(list) + for user_id in user_devices: + for device_id in user_devices[user_id]: + curve_key = self.device_keys[user_id][device_id]['curve25519'] + # Check if we have a list of sessions for this device, which can be + # empty. Implicitely, an empty list will indicate that we already tried + # to establish a session with a device, but this attempt was + # unsuccessful. We do not retry to establish a session. + if curve_key not in self.olm_sessions: + user_devices_no_session[user_id].append(device_id) + if user_devices_no_session: + self.olm_start_sessions(user_devices_no_session) + + def megolm_start_session(self, room): + """Start a megolm session in a room, and share it with its members. + + Args: + room (Room): The room to use. + + Returns: + The newly created session. + """ + session = MegolmOutboundSession(max_age=room.rotation_period_ms, + max_messages=room.rotation_period_msgs) + self.megolm_outbound_sessions[room.room_id] = session + logger.info('Starting a new Meglom outbound session %s in %s.', + session.id, room.room_id) + + users = room.get_joined_members() + user_devices = {user.user_id: list(self.device_keys[user.user_id]) + for user in users} + self.device_list.get_room_device_keys(room) + self.megolm_share_session(room.room_id, user_devices, session) + return session + + def megolm_share_session(self, room_id, user_devices, session): + """Share an already existing outbound megolm session with the specified devices. + + Args: + room_id (str): The room corresponding to the session. + user_devices (dict): A map from user ids to a list of device ids. + session (MegolmOutboundSession): The session object. + """ + + logger.info('Attempting to share Megolm session %s in %s with %s.', + session.id, room_id, user_devices) + self.olm_ensure_sessions(user_devices) + + event = { + 'algorithm': self._megolm_algorithm, + 'room_id': room_id, + 'session_id': session.id, + 'session_key': session.session_key + } + + messages = defaultdict(dict) + new_devices = set() + for user_id in user_devices: + for device_id in user_devices[user_id]: + try: + messages[user_id][device_id] = self.olm_build_encrypted_event( + 'm.room_key', event, user_id, device_id + ) + except RuntimeError as e: + logger.warning('Could not share megolm session %s with device %s of ' + 'user %s: %s', session.id, + device_id, user_id, e) + # We will not retry to share session with failed devices + new_devices.add(device_id) + self.api.send_to_device('m.room.encrypted', messages) + session.add_devices(new_devices) + + def megolm_share_session_with_new_devices(self, room, session): + """Share a megolm session with new devices in a room. + + Args: + room (Room): The room corresponding to the session. + session (MegolmOutboundSession): The session to share. + """ + user_devices = {} + users = room.get_joined_members() + for user in users: + user_id = user.user_id + missing_devices = list(set(self.device_keys[user_id].keys()) - + self.megolm_outbound_sessions[room.room_id].devices) + if missing_devices: + user_devices[user_id] = missing_devices + if user_devices: + logger.info('Sharing existing Megolm outbound session %s with new devices: ' + '%s', session.id, user_devices) + self.megolm_share_session(room.room_id, user_devices, session) + + def megolm_build_encrypted_event(self, room, event): + """Build an encrypted Megolm payload from a plaintext event. + + If no session exists in the room, a new one will be initiated. Also takes care + of rotating the session periodically. + + Args: + room (Room): The room the event will be sent in. + event (dict): Matrix event. + + Returns: + The encrypted event, as a dict. + """ + room_id = room.room_id + + session = self.megolm_outbound_sessions.get(room_id) + if not session or session.should_rotate(): + session = self.megolm_start_session(room) + else: + self.megolm_share_session_with_new_devices(room, session) + + payload = { + 'type': event['type'], + 'content': event['content'], + 'room_id': room_id + } + + encrypted_payload = session.encrypt(json.dumps(payload)) + + encrypted_event = { + 'algorithm': self._megolm_algorithm, + 'sender_key': self.identity_keys['curve25519'], + 'ciphertext': encrypted_payload, + 'session_id': session.id, + 'device_id': self.device_id + } + return encrypted_event + + def megolm_remove_outbound_session(self, room_id): + """Remove an existing Megolm outbound session in a room. + + If there is no such session, nothing will happen. + + Args: + room_id (str): The room to use. + """ + try: + self.megolm_outbound_sessions.pop(room_id) + logger.info('Removed Meglom outbound session in %s.', room_id) + except KeyError: + pass + + def send_encrypted_message(self, room, content): + """Send a m.room.encrypted event in a room. + + Args: + room (Room): The room to use. + content (dict): The content of the event, will be encrypted. + + Raises: + MatrixRequestError if there was an error sending the event. + """ + event = {'content': content, 'room_id': room.room_id, 'type': 'm.room.message'} + encrypted_event = self.megolm_build_encrypted_event(room, event) + return self.api.send_message_event( + room.room_id, 'm.room.encrypted', encrypted_event) + def sign_json(self, json): """Signs a JSON object. diff --git a/matrix_client/errors.py b/matrix_client/errors.py index e9dc8fe3..98bd5eb0 100644 --- a/matrix_client/errors.py +++ b/matrix_client/errors.py @@ -46,3 +46,10 @@ def __init__(self, original_exception, method, endpoint): original_exception) ) self.original_exception = original_exception + + +class MatrixNoEncryptionError(MatrixError): + """Encryption was not available.""" + + def __init__(self, content=""): + super(MatrixNoEncryptionError, self).__init__(content) diff --git a/matrix_client/room.py b/matrix_client/room.py index 488f335b..120f65d2 100644 --- a/matrix_client/room.py +++ b/matrix_client/room.py @@ -18,7 +18,7 @@ from .checks import check_room_id from .user import User -from .errors import MatrixRequestError +from .errors import MatrixRequestError, MatrixNoEncryptionError class Room(object): @@ -50,6 +50,8 @@ def __init__(self, client, room_id): # user_id: displayname } self.encrypted = False + self.rotation_period_msgs = None + self.rotation_period_ms = None def set_user_profile(self, displayname=None, @@ -102,7 +104,10 @@ def display_name(self): def send_text(self, text): """Send a plain text message to the room.""" - return self.client.api.send_message(self.room_id, text) + if self.encrypted and self.client._encryption: + return self.send_encrypted(self.client.api.get_text_body(text)) + else: + return self.client.api.send_message(self.room_id, text) def get_html_content(self, html, body=None, msgtype="m.text"): return { @@ -119,8 +124,12 @@ def send_html(self, html, body=None, msgtype="m.text"): html (str): The html formatted message to be sent. body (str): The unformatted body of the message to be sent. """ - return self.client.api.send_message_event( - self.room_id, "m.room.message", self.get_html_content(html, body, msgtype)) + content = self.get_html_content(html, body, msgtype) + if self.encrypted and self.client._encryption: + return self.send_encrypted(content) + else: + return self.client.api.send_message_event( + self.room_id, "m.room.message", content) def set_account_data(self, type, account_data): return self.client.api.set_room_account_data( @@ -142,7 +151,10 @@ def add_tag(self, tag, order=None, content=None): def send_emote(self, text): """Send an emote (/me style) message to the room.""" - return self.client.api.send_emote(self.room_id, text) + if self.encrypted and self.client._encryption: + return self.send_encrypted(self.client.api.get_emote_body(text)) + else: + return self.client.api.send_emote(self.room_id, text) def send_file(self, url, name, **fileinfo): """Send a pre-uploaded file to the room. @@ -163,7 +175,10 @@ def send_file(self, url, name, **fileinfo): def send_notice(self, text): """Send a notice (from bot) message to the room.""" - return self.client.api.send_notice(self.room_id, text) + if self.encrypted and self.client._encryption: + return self.send_encrypted(self.client.api.get_notice_body(text)) + else: + return self.client.api.send_notice(self.room_id, text) # See http://matrix.org/docs/spec/r0.0.1/client_server.html#m-image for the # imageinfo args. @@ -195,8 +210,13 @@ def send_location(self, geo_uri, name, thumb_url=None, **thumb_info): thumb_url (str): URL to the thumbnail of the location. thumb_info (): Metadata about the thumbnail, type ImageInfo. """ - return self.client.api.send_location(self.room_id, geo_uri, name, - thumb_url, thumb_info) + if self.encrypted and self.client._encryption: + content = self.client.api.get_location_body( + geo_uri, name, thumb_url, thumb_info) + return self.send_encrypted(content) + else: + return self.client.api.send_location(self.room_id, geo_uri, name, + thumb_url, thumb_info) def send_video(self, url, name, **videoinfo): """Send a pre-uploaded video to the room. @@ -226,6 +246,22 @@ def send_audio(self, url, name, **audioinfo): return self.client.api.send_content(self.room_id, url, name, "m.audio", extra_information=audioinfo) + def send_encrypted(self, content): + """Send an arbitrary encrypted message to the room. + + Args: + content (dict): The content of a m.room.message event. + + Raises: + ``MatrixNoEncryptionError`` if encryption is not enabled in client, or if + the room is unencrypted. + """ + if not self.client._encryption: + raise MatrixNoEncryptionError('Encryption is not enabled in client.') + if not self.encrypted: + raise MatrixNoEncryptionError('Encryption is not enabled in the room.') + return self.client.olm_device.send_encrypted_message(self, content) + def redact_message(self, event_id, reason=None): """Redacts the message with specified event_id for the given reason. @@ -618,18 +654,35 @@ def set_guest_access(self, allow_guests): except MatrixRequestError: return False - def enable_encryption(self): + def enable_encryption(self, rotation_period_ms=None, rotation_period_msgs=None): """Enables encryption in the room. NOTE: Once enabled, encryption cannot be disabled. + Args: + rotation_period_ms (int): Optional. Lifetime of Megolm sessions in the room. + rotation_period_msgs (int): Optional. Number of messages that can be encrypted + by a Megolm session before it is rotated. + Returns: - True if successful, False if not + ``True`` if successful, ``False`` if not. + If the room was already encrypted, True is returned, but the + ``rotation_period_ms`` and ``rotation_period_ms`` parameters have no effect. """ + if self.encrypted: + return True + + content = {"algorithm": "m.megolm.v1.aes-sha2"} + if rotation_period_ms: + content["rotation_period_ms"] = rotation_period_ms + if rotation_period_msgs: + content["rotation_period_msgs"] = rotation_period_msgs + try: - self.send_state_event("m.room.encryption", - {"algorithm": "m.megolm.v1.aes-sha2"}) + self.send_state_event("m.room.encryption", content) self.encrypted = True + self.rotation_period_ms = rotation_period_ms + self.rotation_period_msgs = rotation_period_msgs return True except MatrixRequestError: return False @@ -658,13 +711,26 @@ def _process_state_event(self, state_event): elif etype == "m.room.encryption": if econtent.get("algorithm") == "m.megolm.v1.aes-sha2": self.encrypted = True + if not self.rotation_period_ms: + self.rotation_period_ms = econtent.get("rotation_period_ms") + if not self.rotation_period_msgs: + self.rotation_period_msgs = econtent.get("rotation_period_msgs") elif etype == "m.room.member" and clevel == clevel.ALL: # tracking room members can be large e.g. #matrix:matrix.org + user_id = state_event["state_key"] if econtent["membership"] == "join": - user_id = state_event["state_key"] self._add_member(user_id, econtent.get("displayname")) + if self.client._encryption and self.encrypted: + # Track the device list of this user + self.client.olm_device.device_list.track_user_no_download(user_id) elif econtent["membership"] in ("leave", "kick", "invite"): - self._members.pop(state_event["state_key"], None) + self._members.pop(user_id, None) + if econtent["membership"] != "invite": + if self.client._encryption and self.encrypted: + # Invalidate any outbound session we have in the room when + # someone leaves + self.client.olm_device.megolm_remove_outbound_session( + self.room_id) for listener in self.state_listeners: if ( diff --git a/test/client_test.py b/test/client_test.py index c5884924..5fa84762 100644 --- a/test/client_test.py +++ b/test/client_test.py @@ -155,13 +155,30 @@ def test_state_event(): # test encryption room.encrypted = False ev["type"] = "m.room.encryption" - ev["content"] = {"algorithm": "m.megolm.v1.aes-sha2"} + ev["content"] = { + "algorithm": "m.megolm.v1.aes-sha2", + "rotation_period_msgs": 50, + "rotation_period_ms": 100000, + } room._process_state_event(ev) assert room.encrypted + assert room.rotation_period_ms == 100000 + assert room.rotation_period_msgs == 50 # encrypted flag must not be cleared on configuration change ev["content"] = {"algorithm": None} room._process_state_event(ev) assert room.encrypted + assert room.rotation_period_ms == 100000 + assert room.rotation_period_msgs == 50 + # nor should the session parameters be changed + ev["content"] = { + "algorithm": "m.megolm.v1.aes-sha2", + "rotation_period_msgs": 5, + "rotation_period_ms": 10000, + } + room._process_state_event(ev) + assert room.rotation_period_ms == 100000 + assert room.rotation_period_msgs == 50 def test_get_user(): @@ -519,6 +536,15 @@ def test_enable_encryption_in_room(): assert room.enable_encryption() assert room.encrypted + room = client._mkroom(room_id) + assert room.enable_encryption(rotation_period_msgs=1, rotation_period_ms=1) + assert room.rotation_period_msgs == 1 + assert room.rotation_period_ms == 1 + + assert room.enable_encryption(rotation_period_msgs=2) + # The room was already encrypted, we should not have changed its attribute + assert room.rotation_period_msgs == 1 + @responses.activate def test_detect_encryption_state(): diff --git a/test/crypto/device_list_test.py b/test/crypto/device_list_test.py new file mode 100644 index 00000000..69df1ed3 --- /dev/null +++ b/test/crypto/device_list_test.py @@ -0,0 +1,282 @@ +import pytest +pytest.importorskip("olm") # noqa + +import json +from copy import deepcopy +from threading import Event, Condition + +import responses + +from matrix_client.api import MATRIX_V2_API_PATH +from matrix_client.client import MatrixClient +from matrix_client.room import User +from matrix_client.errors import MatrixRequestError +from matrix_client.crypto.olm_device import OlmDevice +from matrix_client.crypto.device_list import (_OutdatedUsersSet as OutdatedUsersSet, + _UpdateDeviceList as UpdateDeviceList) +from test.response_examples import example_key_query_response + +HOSTNAME = 'http://example.com' + + +class TestDeviceList: + cli = MatrixClient(HOSTNAME) + user_id = '@test:example.com' + alice = '@alice:example.com' + room_id = '!test:example.com' + device_id = 'AUIETSRN' + device = OlmDevice(cli.api, user_id, device_id) + device_list = device.device_list + signing_key = device.olm_account.identity_keys['ed25519'] + query_url = HOSTNAME + MATRIX_V2_API_PATH + '/keys/query' + + @responses.activate + def test_download_device_keys(self): + # The method we want to test + download_device_keys = self.device_list._download_device_keys + bob = '@bob:example.com' + eve = '@eve:example.com' + user_devices = {self.alice: [], bob: [], self.user_id: []} + + # This response is correct for Alice's keys, but lacks Bob's + # There are no failures + resp = example_key_query_response + responses.add(responses.POST, self.query_url, json=resp) + + # Still correct, but Alice's identity key has changed + resp = deepcopy(example_key_query_response) + new_id_key = 'ijxGZqwB/UvMtKABdaCdrI0OtQI6NhHBYiknoCkdWng' + payload = resp['device_keys'][self.alice]['JLAFKJWSCS'] + payload['keys']['curve25519:JLAFKJWSCS'] = new_id_key + payload['signatures'][self.alice]['ed25519:JLAFKJWSCS'] = \ + ('D9oLtYefMIr4StiHTIzn3+bhtPCfrZNDU9jsUbMu3MicfZLl4d8WlYn3TPmbwDi8XMGcT' + 'nNnqfdi/tYUPvKfCA') + responses.add(responses.POST, self.query_url, json=resp) + + # Still correct, but Alice's signing key has changed + alice_device = OlmDevice(self.cli.api, self.alice, 'JLAFKJWSCS') + resp = deepcopy(example_key_query_response) + resp['device_keys'][self.alice]['JLAFKJWSCS']['keys']['ed25519:JLAFKJWSCS'] = \ + alice_device.identity_keys['ed25519'] + resp['device_keys'][self.alice]['JLAFKJWSCS'] = \ + alice_device.sign_json(resp['device_keys'][self.alice]['JLAFKJWSCS']) + responses.add(responses.POST, self.query_url, json=resp) + + # Response containing an unknown user + resp = deepcopy(example_key_query_response) + user_device = resp['device_keys'].pop(self.alice) + resp['device_keys'][eve] = user_device + responses.add(responses.POST, self.query_url, json=resp) + + # Response with an invalid signature + resp = deepcopy(example_key_query_response) + resp['device_keys'][self.alice]['JLAFKJWSCS']['test'] = 1 + responses.add(responses.POST, self.query_url, json=resp) + + # Response with a requested user and valid signature, but with a mismatch + resp = deepcopy(example_key_query_response) + user_device = resp['device_keys'].pop(self.alice) + resp['device_keys'][bob] = user_device + responses.add(responses.POST, self.query_url, json=resp) + + # Response with an invalid keys field + resp = deepcopy(example_key_query_response) + keys_field = resp['device_keys'][self.alice]['JLAFKJWSCS']['keys'] + key = keys_field.pop("ed25519:JLAFKJWSCS") + keys_field["ed25519:wrong"] = key + # Cover a missing branch by adding failures + resp["failures"]["other.com"] = {} + # And one more by adding ourself + resp['device_keys'][self.user_id] = {self.device_id: 'dummy'} + responses.add(responses.POST, self.query_url, json=resp) + + self.device.device_keys.clear() + assert download_device_keys(user_devices) + req = json.loads(responses.calls[0].request.body) + assert req['device_keys'] == {self.alice: [], bob: [], self.user_id: []} + expected_device_keys = { + self.alice: { + 'JLAFKJWSCS': { + 'curve25519': '3C5BFWi2Y8MaVvjM8M22DBmh24PmgR0nPvJOIArzgyI', + 'ed25519': 'VzJIYXQ85u19z2ZpEeLLVu8hUKTCE0VXYUn4IY4iFcA' + } + } + } + assert self.device.device_keys == expected_device_keys + + # Different curve25519, key should get updated + assert download_device_keys(user_devices) + expected_device_keys[self.alice]['JLAFKJWSCS']['curve25519'] = new_id_key + assert self.device.device_keys == expected_device_keys + + # Different ed25519, key should not get updated + assert not download_device_keys(user_devices) + assert self.device.device_keys == expected_device_keys + + self.device.device_keys.clear() + # All the remaining responses are wrong and we should not add the key + for _ in range(4): + assert not download_device_keys(user_devices) + assert self.device.device_keys == {} + + assert len(responses.calls) == 7 + + @responses.activate + def test_update_thread(self): + # Normal run + event = Event() + outdated_users = OutdatedUsersSet({self.user_id}) + outdated_users.events.add(event) + + def dummy_download(user_devices, since_token=None): + assert user_devices == {self.user_id: []} + return + thread = UpdateDeviceList(Condition(), outdated_users, dummy_download, set()) + + thread.start() + event.wait() + assert not thread.outdated_user_ids + assert thread.event.is_set() + assert thread.tracked_user_ids == {self.user_id} + thread.join() + assert not thread.is_alive() + + # Error run + outdated_users = OutdatedUsersSet({self.user_id}) + + def error_on_first_download(user_devices, since_token=None): + error_on_first_download.c += 1 + if error_on_first_download.c == 1: + raise MatrixRequestError + return + error_on_first_download.c = 0 + thread = UpdateDeviceList( + Condition(), outdated_users, error_on_first_download, set()) + thread.start() + thread.event.wait() + assert error_on_first_download.c == 2 + assert not thread.outdated_user_ids + thread.join() + + # Cover a missing branch + thread = UpdateDeviceList( + Condition(), outdated_users, error_on_first_download, set()) + thread._should_terminate.set() + thread.start() + thread.join() + assert not thread.is_alive() + + @responses.activate + def test_get_room_device_keys(self): + self.device_list.tracked_user_ids.clear() + room = self.cli._mkroom(self.room_id) + room._members[self.alice] = User(self.cli.api, self.alice) + + responses.add(responses.POST, self.query_url, json=example_key_query_response) + + # Blocking + self.device_list.get_room_device_keys(room) + assert self.device_list.tracked_user_ids == {self.alice} + assert self.device_list.device_keys[self.alice]['JLAFKJWSCS'] + + # Same, but we already track the user + self.device_list.get_room_device_keys(room) + + # Non-blocking + self.device_list.tracked_user_ids.clear() + # We have to block for testing purposes, though + self.device_list.update_thread.event.clear() + self.device_list.get_room_device_keys(room, blocking=False) + self.device_list.update_thread.event.wait() + + # Same, but we already track the user + self.device_list.get_room_device_keys(room, blocking=False) + + @responses.activate + def test_track_users(self): + self.device_list.tracked_user_ids.clear() + responses.add(responses.POST, self.query_url, json=example_key_query_response) + + self.device_list.update_thread.event.clear() + self.device_list.track_users({self.alice}) + self.device_list.update_thread.event.wait() + assert self.device_list.tracked_user_ids == {self.alice} + assert len(responses.calls) == 1 + + # Same, but we are already tracking Alice + self.device_list.track_users({self.alice}) + assert len(responses.calls) == 1 + + def test_stop_tracking_users(self): + self.device_list.tracked_user_ids.clear() + self.device_list.tracked_user_ids.add(self.alice) + self.device_list.outdated_user_ids.clear() + self.device_list.outdated_user_ids.add(self.alice) + + self.device_list.stop_tracking_users({self.alice}) + + assert not self.device_list.tracked_user_ids + assert not self.device_list.outdated_user_ids + + def test_pending_users(self): + # Say Alice is already tracked to avoid triggering dowload process + self.device_list.tracked_user_ids.add(self.alice) + + self.device_list.track_user_no_download(self.alice) + assert self.alice in self.device_list.pending_outdated_user_ids + + self.device_list.track_pending_users() + assert self.alice not in self.device_list.pending_outdated_user_ids + + @responses.activate + def test_update_user_device_keys(self): + self.device_list.tracked_user_ids.clear() + responses.add(responses.POST, self.query_url, json=example_key_query_response) + + self.device_list.update_user_device_keys({self.alice}) + assert len(responses.calls) == 0 + + self.device_list.tracked_user_ids.add(self.alice) + + self.device_list.update_thread.event.clear() + self.device_list.update_user_device_keys({self.alice}, since_token='dummy') + self.device_list.update_thread.event.wait() + assert len(responses.calls) == 1 + + +def test_outdated_users_set(): + s = OutdatedUsersSet() + assert not s + + s = OutdatedUsersSet({1}) + event = Event() + s.events.add(event) + assert s == {1} + + # Make a manual copy of s + t = OutdatedUsersSet() + t.add(1) + t.events.add(event) + assert t == s and t.events == s.events + + u = s.copy() + event2 = Event() + u.add(2) + u.events.add(event2) + # Check that modifying u didn't change s + assert t == s and t.events == s.events + + s.update(u) + assert s == {1, 2} and s.events == {event, event2} + + s.mark_as_processed() + assert event.is_set() + + new = 's72594_4483_1935' + s.sync_token = new + old = 's72594_4483_1934' + s.sync_token = old + assert s.sync_token == new + + s.clear() + assert not s and not s.events diff --git a/test/crypto/olm_device_test.py b/test/crypto/olm_device_test.py index 84f56df6..c7aaa51b 100644 --- a/test/crypto/olm_device_test.py +++ b/test/crypto/olm_device_test.py @@ -1,15 +1,25 @@ import pytest -pytest.importorskip("olm") # noqa +olm = pytest.importorskip("olm") # noqa import json +import logging from copy import deepcopy +from datetime import timedelta, datetime +try: + from urllib import quote +except ImportError: + from urllib.parse import quote import responses +from matrix_client.crypto import olm_device from matrix_client.api import MATRIX_V2_API_PATH from matrix_client.client import MatrixClient +from matrix_client.user import User from matrix_client.crypto.olm_device import OlmDevice -from test.response_examples import example_key_upload_response +from matrix_client.crypto.megolm_outbound_session import MegolmOutboundSession +from test.response_examples import (example_key_upload_response, + example_claim_keys_response) HOSTNAME = 'http://example.com' @@ -17,9 +27,23 @@ class TestOlmDevice: cli = MatrixClient(HOSTNAME) user_id = '@user:matrix.org' + room_id = '!test:example.com' device_id = 'QBUAZIFURK' device = OlmDevice(cli.api, user_id, device_id) signing_key = device.olm_account.identity_keys['ed25519'] + alice = '@alice:example.com' + alice_device_id = 'JLAFKJWSCS' + alice_curve_key = 'mmFRSHuJVq3aTudx3KB3w5ZvSFQhgEcy8d+m+vkEfUQ' + alice_identity_keys = { + 'curve25519': alice_curve_key, + 'ed25519': '4VjV3OhFUxWFAcO5YOaQVmTIn29JdRmtNh9iAxoyhkc' + } + alice_olm_session = olm.OutboundSession( + device.olm_account, alice_curve_key, alice_curve_key) + room = cli._mkroom(room_id) + room._members[alice] = User(cli.api, alice) + # allow to_device api call to work well with responses + device.api._make_txn_id = lambda: 1 def test_sign_json(self): example_payload = { @@ -204,3 +228,320 @@ def test_invalid_keys_threshold(self, threshold): self.user_id, self.device_id, keys_threshold=threshold) + + @responses.activate + def test_olm_start_sessions(self): + claim_url = HOSTNAME + MATRIX_V2_API_PATH + '/keys/claim' + responses.add(responses.POST, claim_url, json=example_claim_keys_response) + self.device.olm_sessions.clear() + self.device.device_keys.clear() + + user_devices = {self.alice: {self.alice_device_id}} + + # We don't have alice's keys + self.device.olm_start_sessions(user_devices) + assert not self.device.olm_sessions[self.alice_curve_key] + + # Cover logging part + olm_device.logger.setLevel(logging.WARNING) + # Now should be good + self.device.device_keys[self.alice][self.alice_device_id] = \ + self.alice_identity_keys + self.device.olm_start_sessions(user_devices) + assert self.device.olm_sessions[self.alice_curve_key] + + # With failures and wrong signature + self.device.olm_sessions.clear() + payload = deepcopy(example_claim_keys_response) + payload['failures'] = {'dummy': 1} + key = payload['one_time_keys'][self.alice][self.alice_device_id] + key['signed_curve25519:AAAAAQ']['test'] = 1 + responses.replace(responses.POST, claim_url, json=payload) + + self.device.olm_start_sessions(user_devices) + assert not self.device.olm_sessions[self.alice_curve_key] + + # Missing requested user and devices + user_devices[self.alice].add('test') + user_devices['test'] = 'test' + + self.device.olm_start_sessions(user_devices) + + @responses.activate + def test_olm_build_encrypted_event(self): + self.device.device_keys.clear() + self.device.olm_sessions.clear() + event_content = {'dummy': 'example'} + + # We don't have Alice's keys + with pytest.raises(RuntimeError): + self.device.olm_build_encrypted_event( + 'm.text', event_content, self.alice, self.alice_device_id) + + # We don't have a session with Alice + self.device.device_keys[self.alice][self.alice_device_id] = \ + self.alice_identity_keys + with pytest.raises(RuntimeError): + self.device.olm_build_encrypted_event( + 'm.text', event_content, self.alice, self.alice_device_id) + + claim_url = HOSTNAME + MATRIX_V2_API_PATH + '/keys/claim' + responses.add(responses.POST, claim_url, json=example_claim_keys_response) + user_devices = {self.alice: {self.alice_device_id}} + self.device.olm_start_sessions(user_devices) + assert self.device.olm_build_encrypted_event( + 'm.text', event_content, self.alice, self.alice_device_id) + + def test_olm_decrypt(self): + self.device.olm_sessions.clear() + # Since this method doesn't care about high-level event formatting, we will + # generate things at low level + our_account = self.device.olm_account + # Alice needs to start a session with us + alice = olm.Account() + sender_key = alice.identity_keys['curve25519'] + our_account.generate_one_time_keys(1) + otk = next(iter(our_account.one_time_keys['curve25519'].values())) + self.device.olm_account.mark_keys_as_published() + session = olm.OutboundSession(alice, our_account.identity_keys['curve25519'], otk) + + plaintext = {"test": "test"} + message = session.encrypt(json.dumps(plaintext)) + assert self.device._olm_decrypt(message, sender_key) == plaintext + + # New pre-key message, but the session exists this time + message = session.encrypt(json.dumps(plaintext)) + assert self.device._olm_decrypt(message, sender_key) == plaintext + + # Try to decrypt the same message twice + with pytest.raises(RuntimeError): + self.device._olm_decrypt(message, sender_key) + + # Answer Alice in order to have a type 1 message + message = self.device.olm_sessions[sender_key][0].encrypt(json.dumps(plaintext)) + session.decrypt(message) + message = session.encrypt(json.dumps(plaintext)) + assert self.device._olm_decrypt(message, sender_key) == plaintext + + # Try to decrypt the same message type 1 twice + with pytest.raises(RuntimeError): + self.device._olm_decrypt(message, sender_key) + + # Try to decrypt a message from a session that reused a one-time key + otk_reused_session = olm.OutboundSession( + alice, our_account.identity_keys['curve25519'], otk) + message = otk_reused_session.encrypt(json.dumps(plaintext)) + with pytest.raises(RuntimeError): + self.device._olm_decrypt(message, sender_key) + + # Try to decrypt an invalid type 0 message + our_account.generate_one_time_keys(1) + otk = next(iter(our_account.one_time_keys['curve25519'].values())) + wrong_session = olm.OutboundSession(alice, sender_key, otk) + message = wrong_session.encrypt(json.dumps(plaintext)) + with pytest.raises(RuntimeError): + self.device._olm_decrypt(message, sender_key) + + # Try to decrypt a type 1 message for which we have no sessions + message = session.encrypt(json.dumps(plaintext)) + self.device.olm_sessions.clear() + with pytest.raises(RuntimeError): + self.device._olm_decrypt(message, sender_key) + + def test_olm_decrypt_event(self): + self.device.device_keys.clear() + self.device.olm_sessions.clear() + alice_device = OlmDevice(self.device.api, self.alice, self.alice_device_id) + alice_device.device_keys[self.user_id][self.device_id] = self.device.identity_keys + self.device.device_keys[self.alice][self.alice_device_id] = \ + alice_device.identity_keys + + # Artificially start an Olm session from Alice + self.device.olm_account.generate_one_time_keys(1) + otk = next(iter(self.device.olm_account.one_time_keys['curve25519'].values())) + self.device.olm_account.mark_keys_as_published() + sender_key = self.device.identity_keys['curve25519'] + session = olm.OutboundSession(alice_device.olm_account, sender_key, otk) + alice_device.olm_sessions[sender_key] = [session] + + encrypted_event = alice_device.olm_build_encrypted_event( + 'example_type', {'content': 'test'}, self.user_id, self.device_id) + + # Now we can test + self.device.olm_decrypt_event(encrypted_event, self.alice) + + # Type 1 Olm payload + alice_device.olm_decrypt_event( + self.device.olm_build_encrypted_event( + 'example_type', {'content': 'test'}, self.alice, self.alice_device_id + ), + self.user_id) + encrypted_event = alice_device.olm_build_encrypted_event( + 'example_type', {'content': 'test'}, self.user_id, self.device_id) + self.device.olm_decrypt_event(encrypted_event, self.alice) + + encrypted_event = alice_device.olm_build_encrypted_event( + 'example_type', {'content': 'test'}, self.user_id, self.device_id) + with pytest.raises(RuntimeError): + self.device.olm_decrypt_event(encrypted_event, 'wrong') + + wrong_event = deepcopy(encrypted_event) + wrong_event['algorithm'] = 'wrong' + with pytest.raises(RuntimeError): + self.device.olm_decrypt_event(wrong_event, self.alice) + + wrong_event = deepcopy(encrypted_event) + wrong_event['ciphertext'] = {} + with pytest.raises(RuntimeError): + self.device.olm_decrypt_event(wrong_event, self.alice) + + encrypted_event = alice_device.olm_build_encrypted_event( + 'example_type', {'content': 'test'}, self.user_id, self.device_id) + self.device.user_id = 'wrong' + with pytest.raises(RuntimeError): + self.device.olm_decrypt_event(encrypted_event, self.alice) + self.device.user_id = self.user_id + + encrypted_event = alice_device.olm_build_encrypted_event( + 'example_type', {'content': 'test'}, self.user_id, self.device_id) + backup = self.device.identity_keys['ed25519'] + self.device.identity_keys['ed25519'] = 'wrong' + with pytest.raises(RuntimeError): + self.device.olm_decrypt_event(encrypted_event, self.alice) + self.device.identity_keys['ed25519'] = backup + + @responses.activate + def test_olm_ensure_sessions(self): + claim_url = HOSTNAME + MATRIX_V2_API_PATH + '/keys/claim' + responses.add(responses.POST, claim_url, json=example_claim_keys_response) + self.device.olm_sessions.clear() + self.device.device_keys[self.alice][self.alice_device_id] = \ + self.alice_identity_keys + user_devices = {self.alice: [self.alice_device_id]} + + self.device.olm_ensure_sessions(user_devices) + assert self.device.olm_sessions[self.alice_curve_key] + assert len(responses.calls) == 1 + + self.device.olm_ensure_sessions(user_devices) + assert len(responses.calls) == 1 + + @responses.activate + def test_megolm_share_session(self): + claim_url = HOSTNAME + MATRIX_V2_API_PATH + '/keys/claim' + responses.add(responses.POST, claim_url, json=example_claim_keys_response) + to_device_url = HOSTNAME + MATRIX_V2_API_PATH + '/sendToDevice/m.room.encrypted/1' + responses.add(responses.PUT, to_device_url, json={}) + self.device.olm_sessions.clear() + self.device.device_keys[self.alice][self.alice_device_id] = \ + self.alice_identity_keys + self.device.device_keys['dummy']['dummy'] = {'curve25519': 'a', 'ed25519': 'a'} + user_devices = {self.alice: [self.alice_device_id], 'dummy': ['dummy']} + session = MegolmOutboundSession() + + # Sharing with Alice should succeed, but dummy will fail + self.device.megolm_share_session(self.room_id, user_devices, session) + assert session.devices == {self.alice_device_id, 'dummy'} + + req = json.loads(responses.calls[1].request.body)['messages'] + assert self.alice in req + assert 'dummy' not in req + + @responses.activate + def test_megolm_start_session(self): + to_device_url = HOSTNAME + MATRIX_V2_API_PATH + '/sendToDevice/m.room.encrypted/1' + responses.add(responses.PUT, to_device_url, json={}) + self.device.device_keys[self.alice][self.alice_device_id] = \ + self.alice_identity_keys + self.device.device_list.tracked_user_ids.add(self.alice) + self.device.olm_sessions[self.alice_curve_key] = [self.alice_olm_session] + + self.device.megolm_start_session(self.room) + session = self.device.megolm_outbound_sessions[self.room_id] + assert self.alice_device_id in session.devices + + @responses.activate + def test_megolm_share_session_with_new_devices(self): + to_device_url = HOSTNAME + MATRIX_V2_API_PATH + '/sendToDevice/m.room.encrypted/1' + responses.add(responses.PUT, to_device_url, json={}) + self.device.device_keys[self.alice][self.alice_device_id] = \ + self.alice_identity_keys + self.device.olm_sessions[self.alice_curve_key] = [self.alice_olm_session] + session = MegolmOutboundSession() + self.device.megolm_outbound_sessions[self.room_id] = session + + self.device.megolm_share_session_with_new_devices(self.room, session) + assert self.alice_device_id in session.devices + assert len(responses.calls) == 1 + + self.device.megolm_share_session_with_new_devices(self.room, session) + assert len(responses.calls) == 1 + + @responses.activate + def test_megolm_build_encrypted_event(self): + to_device_url = HOSTNAME + MATRIX_V2_API_PATH + '/sendToDevice/m.room.encrypted/1' + responses.add(responses.PUT, to_device_url, json={}) + self.device.megolm_outbound_sessions.clear() + self.device.device_keys[self.alice][self.alice_device_id] = \ + self.alice_identity_keys + self.device.device_list.tracked_user_ids.add(self.alice) + self.device.olm_sessions[self.alice_curve_key] = [self.alice_olm_session] + event = {'type': 'm.room.message', 'content': {'body': 'test'}} + + self.room.rotation_period_msgs = 1 + self.device.megolm_build_encrypted_event(self.room, event) + + self.device.megolm_build_encrypted_event(self.room, event) + + session = self.device.megolm_outbound_sessions[self.room_id] + session.encrypt('test') + self.device.megolm_build_encrypted_event(self.room, event) + assert self.device.megolm_outbound_sessions[self.room_id].id != session.id + + def test_megolm_remove_outbound_session(self): + session = MegolmOutboundSession() + self.device.megolm_outbound_sessions[self.room_id] = session + self.device.megolm_remove_outbound_session(self.room_id) + self.device.megolm_remove_outbound_session(self.room_id) + + @responses.activate + def test_send_encrypted_message(self): + message_url = HOSTNAME + MATRIX_V2_API_PATH + \ + '/rooms/{}/send/m.room.encrypted/1'.format(quote(self.room.room_id)) + responses.add(responses.PUT, message_url, json={}) + self.device.device_keys[self.alice][self.alice_device_id] = \ + self.alice_identity_keys + self.device.olm_sessions[self.alice_curve_key] = [self.alice_olm_session] + session = MegolmOutboundSession() + session.add_device(self.alice_device_id) + self.device.megolm_outbound_sessions[self.room_id] = session + + self.device.send_encrypted_message(self.room, {'test': 'test'}) + + +def test_megolm_outbound_session(): + session = MegolmOutboundSession() + assert session.max_messages == 100 + assert session.max_age == timedelta(days=7) + + session = MegolmOutboundSession(max_messages=1, max_age=100000) + assert session.max_messages == 1 + assert session.max_age == timedelta(milliseconds=100000) + + assert not session.devices + + session.add_device('test') + assert 'test' in session.devices + + session.add_devices({'test2', 'test3'}) + assert 'test2' in session.devices and 'test3' in session.devices + + assert not session.should_rotate() + + session.encrypt('message') + assert session.should_rotate() + + session.max_messages = 2 + assert not session.should_rotate() + session.creation_time = datetime.now() - timedelta(milliseconds=100000) + assert session.should_rotate() diff --git a/test/response_examples.py b/test/response_examples.py index 2d45aa86..d6a2ca49 100644 --- a/test/response_examples.py +++ b/test/response_examples.py @@ -184,3 +184,55 @@ "og:image:width": 48, "og:title": "Matrix Blog Post" } + +example_key_query_response = { + "failures": {}, + "device_keys": { + "@alice:example.com": { + "JLAFKJWSCS": { + "user_id": "@alice:example.com", + "device_id": "JLAFKJWSCS", + "algorithms": [ + "m.olm.curve25519-aes-sha256", + "m.megolm.v1.aes-sha" + ], + "keys": { + "curve25519:JLAFKJWSCS": ("3C5BFWi2Y8MaVvjM8M22DBmh24PmgR0nPvJOIAr" + "zgyI"), + "ed25519:JLAFKJWSCS": "VzJIYXQ85u19z2ZpEeLLVu8hUKTCE0VXYUn4IY4iFcA" + }, + "signatures": { + "@alice:example.com": { + "ed25519:JLAFKJWSCS": + ("wux6Dhjtk7GYPMW54hnx0doVH0NvuUAFBleL5OW99jhbjIutufglAgrYAcu8" + "ueacgNyeSumvtzVIPZXgbB2BCg") + } + }, + "unsigned": { + "device_display_name": "Alice'smobilephone" + } + } + } + } +} + +example_claim_keys_response = { + "failures": {}, + "one_time_keys": { + "@alice:example.com": { + "JLAFKJWSCS": { + 'signed_curve25519:AAAAAQ': { + 'key': '9UOzQjF2j2Xf8mBIiMgruuCkuWtD0ea9kvx63mO92Ws', + 'signatures': { + '@alice:example.com': { + 'ed25519:JLAFKJWSCS': ( + '6O+VYxN7mVcr/j66YdHASRrpW4ydC/0FcYmEWVAGIFzU4+yjzxxinhQD' + 'l7InhhdGuXeQlk4/w/CyU76TY6wdBA' + ) + } + } + } + } + } + } +}