diff --git a/docs/source/conf.py b/docs/source/conf.py index e3a76c0d..b7bc465a 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -96,4 +96,5 @@ 'Miscellaneous'), ] -autodoc_mock_imports = ["olm", "canonicaljson"] +autodoc_mock_imports = ["olm", "canonicaljson", "appdirs", "unpaddedbase64", "Crypto", + "Crypto.Cipher", "Crypto.Hash", "Crypto.Util"] diff --git a/docs/source/matrix_client.rst b/docs/source/matrix_client.rst index af1ebc41..16a6aedf 100644 --- a/docs/source/matrix_client.rst +++ b/docs/source/matrix_client.rst @@ -56,3 +56,28 @@ matrix_client.crypto :members: :undoc-members: :show-inheritance: + +.. automodule:: matrix_client.crypto.device_list + :members: + :undoc-members: + :show-inheritance: + +.. automodule:: matrix_client.crypto.sessions + :members: + :undoc-members: + :show-inheritance: + +.. automodule:: matrix_client.crypto.crypto_store + :members: + :undoc-members: + :show-inheritance: + +.. automodule:: matrix_client.crypto.encrypt_attachments + :members: + :undoc-members: + :show-inheritance: + +.. automodule:: matrix_client.crypto.verified_event + :members: + :undoc-members: + :show-inheritance: diff --git a/matrix_client/api.py b/matrix_client/api.py index 8bd44ecc..479aa6fe 100644 --- a/matrix_client/api.py +++ b/matrix_client/api.py @@ -329,19 +329,43 @@ def redact_event(self, room_id, event_id, reason=None, txn_id=None, timestamp=No # content_type can be a image,audio or video # extra information should be supplied, see # https://matrix.org/docs/spec/r0.0.1/client_server.html - def send_content(self, room_id, item_url, item_name, msg_type, - extra_information=None, timestamp=None): + def send_content(self, room_id, item_url, item_name, msg_type, filename=None, + extra_information=None, timestamp=None, encryption_info=None): + content_pack = self.get_content_body(item_url, item_name, msg_type, filename, + extra_information, encryption_info) + return self.send_message_event(room_id, "m.room.message", content_pack, + timestamp=timestamp) + + def get_content_body(self, item_url, item_name, msg_type, filename=None, + extra_information=None, encryption_info=None): if extra_information is None: extra_information = {} content_pack = { - "url": item_url, "msgtype": msg_type, "body": item_name, "info": extra_information } - return self.send_message_event(room_id, "m.room.message", content_pack, - timestamp=timestamp) + if msg_type == "m.file": + content_pack["filename"] = filename or item_name + if encryption_info: + encryption_info['url'] = item_url + content_pack['file'] = encryption_info + else: + content_pack['url'] = item_url + return content_pack + + def get_location_body(self, geo_uri, name, thumb_url=None, thumb_info=None): + content_pack = { + "geo_uri": geo_uri, + "msgtype": "m.location", + "body": name, + } + if thumb_url: + content_pack["thumbnail_url"] = thumb_url + if thumb_info: + content_pack["thumbnail_info"] = thumb_info + return content_pack # http://matrix.org/docs/spec/client_server/r0.2.0.html#m-location def send_location(self, room_id, geo_uri, name, thumb_url=None, thumb_info=None, @@ -356,15 +380,8 @@ def send_location(self, room_id, geo_uri, name, thumb_url=None, thumb_info=None, thumb_info (dict): Metadata about the thumbnail, type ImageInfo. timestamp (int): Set origin_server_ts (For application services only) """ - content_pack = { - "geo_uri": geo_uri, - "msgtype": "m.location", - "body": name, - } - if thumb_url: - content_pack["thumbnail_url"] = thumb_url - if thumb_info: - content_pack["thumbnail_info"] = thumb_info + content_pack = self.get_location_body( + geo_uri, name, thumb_url, thumb_info) return self.send_message_event(room_id, "m.room.message", content_pack, timestamp=timestamp) @@ -405,12 +422,11 @@ def send_notice(self, room_id, text_content, timestamp=None): text_content (str): The m.notice body to send. timestamp (int): Set origin_server_ts (For application services only) """ - body = { - "msgtype": "m.notice", - "body": text_content - } - return self.send_message_event(room_id, "m.room.message", body, - timestamp=timestamp) + return self.send_message_event( + room_id, "m.room.message", + self.get_notice_body(text_content), + timestamp=timestamp + ) def get_room_messages(self, room_id, token, direction, limit=10, to=None): """Perform GET /rooms/{roomId}/messages. @@ -675,6 +691,12 @@ def get_emote_body(self, text): "body": text } + def get_notice_body(self, text): + return { + "msgtype": "m.notice", + "body": text + } + def get_filter(self, user_id, filter_id): return self._send("GET", "/user/{userId}/filter/{filterId}" .format(userId=user_id, filterId=filter_id)) diff --git a/matrix_client/client.py b/matrix_client/client.py index 703b825c..4edf4ca4 100644 --- a/matrix_client/client.py +++ b/matrix_client/client.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from .api import MatrixHttpApi +from .checks import check_user_id from .errors import MatrixRequestError, MatrixUnexpectedResponse from .room import Room from .user import User @@ -63,6 +64,13 @@ class MatrixClient(object): encryption_conf (dict): Optional. Configuration parameters for encryption. Refer to :func:`~matrix_client.crypto.olm_device.OlmDevice` for supported options, since it will be passed to this class. + restore_device_id (bool): Optional. Only valid when encryption is enabled. When + turned on, the device ID corresponding to the user ID will be retrieved from + the encryption database, if it exists. + verify_devices (bool): Optional. When enabled, sending a message will fail when + there are unknown devices in an encrypted room. A client will have to + inspect those, and resend its message. Note that this can be configured later + on a per room basis. Returns: `MatrixClient` @@ -111,7 +119,8 @@ def global_callback(incoming_event): def __init__(self, base_url, token=None, user_id=None, valid_cert_check=True, sync_filter_limit=20, - cache_level=CACHE.ALL, encryption=False, encryption_conf=None): + cache_level=CACHE.ALL, encryption=False, encryption_conf=None, + restore_device_id=False, verify_devices=False): if user_id: warn( "user_id is deprecated. " @@ -121,6 +130,9 @@ def __init__(self, base_url, token=None, user_id=None, if encryption and not ENCRYPTION_SUPPORT: raise ValueError("Failed to enable encryption. Please make sure the olm " "library is available.") + if restore_device_id and not encryption: + raise ValueError("restore_device_id only makes sense when encryption is " + "enabled.") self.api = MatrixHttpApi(base_url, token) self.api.validate_certificate(valid_cert_check) @@ -133,6 +145,9 @@ def __init__(self, base_url, token=None, user_id=None, self._encryption = encryption self.encryption_conf = encryption_conf or {} self.olm_device = None + self.first_sync = True + self.restore_device_id = restore_device_id + self.verify_devices = verify_devices if isinstance(cache_level, CACHE): self._cache_level = cache_level else: @@ -265,8 +280,11 @@ def login(self, username, password, limit=10, sync=True, device_id=None): limit (int): Deprecated. How many messages to return when syncing. This will be replaced by a filter API in a later release. sync (bool): Optional. Whether to initiate a /sync request after logging in. - device_id (str): Optional. ID of the client device. The server will - auto-generate a device_id if this is not specified. + device_id (str): Optional. ID of the client device. If it is not specified, + the server will auto-generate one, or it may be retrieved + from database if ``restore_device_id`` is ``True``. If it is specified, + and ``restore_device_id`` is ``True``, the eventual encryption keys stored + along with a previous device ID of the current user are discarded. Returns: str: Access token @@ -274,6 +292,20 @@ def login(self, username, password, limit=10, sync=True, device_id=None): Raises: MatrixRequestError """ + if not device_id and self.restore_device_id: + try: + check_user_id(username) + except ValueError: + raise ValueError("When using restore_device_id, a full user ID " + "must be supplied when logging in.") + try: + self.olm_device = OlmDevice( + self.api, username, **self.encryption_conf) + device_id = self.olm_device.device_id + logger.info('Device ID was sucessfully retrieved from database.') + except ValueError: + pass + response = self.api.login( "m.login.password", user=username, password=password, device_id=device_id ) @@ -284,8 +316,9 @@ def login(self, username, password, limit=10, sync=True, device_id=None): self.device_id = response["device_id"] if self._encryption: - self.olm_device = OlmDevice( - self.api, self.user_id, self.device_id, **self.encryption_conf) + if not self.olm_device: + self.olm_device = OlmDevice( + self.api, self.user_id, self.device_id, **self.encryption_conf) self.olm_device.upload_identity_keys() self.olm_device.upload_one_time_keys() @@ -566,7 +599,7 @@ def upload(self, content, content_type, filename=None): ) def _mkroom(self, room_id): - room = Room(self, room_id) + room = Room(self, room_id, verify_devices=self.verify_devices) if self._encryption: try: event = self.api.get_state_event(room_id, "m.room.encryption") @@ -581,8 +614,21 @@ def _mkroom(self, room_id): # TODO better handling of the blocking I/O caused by update_one_time_key_counts def _sync(self, timeout_ms=30000): response = self.api.sync(self.sync_token, timeout_ms, filter=self.sync_filter) + + if self._encryption and 'device_lists' in response: + if response['device_lists'].get('changed'): + self.olm_device.device_list.update_user_device_keys( + response['device_lists']['changed'], self.sync_token) + if response['device_lists'].get('left'): + self.olm_device.device_list.stop_tracking_users( + response['device_lists']['left']) + self.sync_token = response["next_batch"] + if self._encryption and self.first_sync: + self.first_sync = False + self.olm_device.device_list.update_after_restart(self.sync_token) + for presence_update in response['presence']['events']: for callback in self.presence_listeners.values(): callback(presence_update) @@ -597,6 +643,11 @@ def _sync(self, timeout_ms=30000): if room_id in self.rooms: del self.rooms[room_id] + if 'to_device' in response: + for event in response['to_device']['events']: + if event['type'] == 'm.room.encrypted' and self._encryption: + self.olm_device.olm_handle_encrypted_event(event) + if self._encryption and 'device_one_time_keys_count' in response: self.olm_device.update_one_time_key_counts( response['device_one_time_keys_count']) @@ -627,6 +678,10 @@ def _sync(self, timeout_ms=30000): ): listener['callback'](event) + if self._encryption and room.encrypted: + # Track the new users in the room + self.olm_device.device_list.track_pending_users() + for event in sync_room['ephemeral']['events']: event['room_id'] = room_id room._put_ephemeral_event(event) @@ -650,7 +705,7 @@ def get_user(self, user_id): """ warn("get_user is deprecated. Directly instantiate a User instead.", DeprecationWarning) - return User(self.api, user_id) + return User(self, user_id) # TODO: move to Room class def remove_room_alias(self, room_alias): @@ -667,3 +722,12 @@ def remove_room_alias(self, room_alias): return True except MatrixRequestError: return False + + def get_fingerprint(self): + """Get the fingerprint of the current device. + + This is used when verifying devices. + """ + if not self._encryption: + raise ValueError("Encryption is not enabled, this device has no fingerprint.") + return self.olm_device.ed25519 diff --git a/matrix_client/crypto/crypto_store.py b/matrix_client/crypto/crypto_store.py new file mode 100644 index 00000000..bd9c7c85 --- /dev/null +++ b/matrix_client/crypto/crypto_store.py @@ -0,0 +1,571 @@ +import logging +import os +import sqlite3 +from collections import defaultdict +from datetime import timedelta +from threading import current_thread + +import olm +from appdirs import user_data_dir + +from matrix_client.crypto.sessions import MegolmOutboundSession, MegolmInboundSession +from matrix_client.device import Device + +logger = logging.getLogger(__name__) + + +class CryptoStore(object): + """Manages persistent storage for an OlmDevice. + + Args: + user_id (str): The user ID of the OlmDevice. + device_id (str): Optional. The device ID of the OlmDevice. Will be retrieved using + ``user_id`` if not present. + db_name (str): Optional. The name of the database file to use. Will be created + if necessary. + db_path (str): Optional. The path where to store the database file. Defaults to + the system default application data directory. + app_name (str): Optional. The application name, which will be used to determine + where the database is located. Ignored if db_path is supplied. + pickle_key (str): Optional. A key to encrypt the database contents. + """ + + def __init__(self, + user_id, + device_id=None, + db_name='crypto.db', + db_path=None, + app_name='matrix-python-sdk', + pickle_key='DEFAULT_KEY'): + self.user_id = user_id + self.device_id = device_id + data_dir = db_path or user_data_dir(app_name, '') + try: + os.makedirs(data_dir) + except OSError: + pass + self.db_filepath = os.path.join(data_dir, db_name) + + # Map from a thread id to a connection object + self._conn = defaultdict(self.instanciate_connection) + self.pickle_key = pickle_key + self.create_tables_if_needed() + + def instanciate_connection(self): + con = sqlite3.connect(self.db_filepath, detect_types=sqlite3.PARSE_DECLTYPES) + con.row_factory = sqlite3.Row + return con + + def create_tables_if_needed(self): + """Ensures all the tables exist.""" + c = self.conn.cursor() + c.executescript(""" +PRAGMA secure_delete = ON; +PRAGMA foreign_keys = ON; +CREATE TABLE IF NOT EXISTS accounts( + device_id TEXT NOT NULL UNIQUE, account BLOB, user_id TEXT PRIMARY KEY NOT NULL +); +CREATE TABLE IF NOT EXISTS olm_sessions( + device_id TEXT, session_id TEXT PRIMARY KEY, curve_key TEXT, session BLOB, + FOREIGN KEY(device_id) REFERENCES accounts(device_id) ON DELETE CASCADE +); +CREATE TABLE IF NOT EXISTS megolm_inbound_sessions( + device_id TEXT, session_id TEXT PRIMARY KEY, room_id TEXT, curve_key TEXT, + ed_key TEXT, session BLOB, + FOREIGN KEY(device_id) REFERENCES accounts(device_id) ON DELETE CASCADE +); +CREATE TABLE IF NOT EXISTS megolm_outbound_sessions( + device_id TEXT, room_id TEXT, session BLOB, max_age_s FLOAT, + max_messages INTEGER, creation_time TIMESTAMP, message_count INTEGER, + PRIMARY KEY(device_id, room_id), + FOREIGN KEY(device_id) REFERENCES accounts(device_id) ON DELETE CASCADE +); +CREATE TABLE IF NOT EXISTS megolm_outbound_devices( + device_id TEXT, room_id TEXT, user_device_id TEXT, + PRIMARY KEY(device_id, room_id, user_device_id), + FOREIGN KEY(device_id, room_id) REFERENCES + megolm_outbound_sessions(device_id, room_id) ON DELETE CASCADE +); +CREATE TABLE IF NOT EXISTS device_keys( + device_id TEXT, user_id TEXT, user_device_id TEXT, ed_key TEXT, + curve_key TEXT, verified INTEGER, blacklisted INTEGER, ignored INTEGER, + PRIMARY KEY(device_id, user_id, user_device_id), + FOREIGN KEY(device_id) REFERENCES accounts(device_id) ON DELETE CASCADE +); +CREATE TABLE IF NOT EXISTS tracked_users( + device_id TEXT, user_id TEXT, + PRIMARY KEY(device_id, user_id), + FOREIGN KEY(device_id) REFERENCES accounts(device_id) ON DELETE CASCADE +); +CREATE TABLE IF NOT EXISTS sync_tokens( + device_id TEXT PRIMARY KEY, token TEXT, + FOREIGN KEY(device_id) REFERENCES accounts(device_id) ON DELETE CASCADE +); + """) + c.close() + self.conn.commit() + + def save_olm_account(self, account): + """Saves an Olm account. + + Args: + account (olm.Account): The account object to save. + """ + account_data = account.pickle(self.pickle_key) + c = self.conn.cursor() + c.execute( + 'INSERT OR IGNORE INTO accounts (device_id, account, user_id) VALUES (?,?,?)', + (self.device_id, account_data, self.user_id) + ) + c.execute('UPDATE accounts SET account=? WHERE device_id=?', + (account_data, self.device_id)) + c.close() + self.conn.commit() + + def replace_olm_account(self, account): + """Replace an Olm account. + + Instead of updating it as done with :meth:`save_olm_account`, this saves the + new account and discards all data associated with the previous one. + + Args: + account (olm.Account): The account object to save. + """ + account_data = account.pickle(self.pickle_key) + c = self.conn.cursor() + c.execute('REPLACE INTO accounts (device_id, account, user_id) VALUES (?,?,?)', + (self.device_id, account_data, self.user_id)) + c.close() + self.conn.commit() + + def get_olm_account(self): + """Gets the Olm account. + + Returns: + ``olm.Account`` object, or ``None`` if it wasn't found for the current + device_id. + + Raises: + ``ValueError`` if ``device_id`` was ``None`` and couldn't be retrieved. + """ + c = self.conn.cursor() + if self.device_id: + c.execute( + 'SELECT account, device_id FROM accounts WHERE user_id=? AND device_id=?', + (self.user_id, self.device_id) + ) + else: + c.execute('SELECT account, device_id FROM accounts WHERE user_id=?', + (self.user_id,)) + row = c.fetchone() + if not row and not self.device_id: + raise ValueError('Failed to retrieve device_id.') + try: + self.device_id = row['device_id'] + account_data = row['account'] + # sqlite gives us unicode in Python2, we want bytes + account_data = bytes(account_data) + except TypeError: + return None + finally: + c.close() + return olm.Account.from_pickle(account_data, self.pickle_key) + + def remove_olm_account(self): + """Removes the Olm account. + + NOTE: Doing so will remove any saved information associated with the account + (keys, sessions...) + """ + c = self.conn.cursor() + c.execute('DELETE FROM accounts WHERE user_id=?', (self.user_id,)) + c.close() + + def save_olm_session(self, curve_key, session): + self.save_olm_sessions({curve_key: [session]}) + + def save_olm_sessions(self, sessions): + """Saves Olm sessions. + + Args: + sessions (defaultdict(list)): A map from curve25519 keys to a list of + ``olm.Session`` objects. + """ + c = self.conn.cursor() + rows = [(self.device_id, s.id, key, s.pickle(self.pickle_key)) + for key in sessions for s in sessions[key]] + c.executemany('REPLACE INTO olm_sessions VALUES (?,?,?,?)', rows) + c.close() + self.conn.commit() + + def load_olm_sessions(self, sessions): + """Loads all saved Olm sessions. + + Args: + sessions (defaultdict(list)): A map from curve25519 keys to a list of + ``olm.Session`` objects, which will be populated. + """ + c = self.conn.cursor() + rows = c.execute('SELECT curve_key, session FROM olm_sessions WHERE device_id=?', + (self.device_id,)) + for row in rows: + session = olm.Session.from_pickle(bytes(row['session']), self.pickle_key) + sessions[row['curve_key']].append(session) + c.close() + + def get_olm_sessions(self, curve_key, sessions_dict=None): + """Get the Olm sessions corresponding to a device. + + Args: + curve_key (str): The curve25519 key of the device. + sessions_dict (defaultdict(list)): Optional. A map from curve25519 keys to a + list of ``olm.Session`` objects, to which the session list will be added. + + Returns: + A list of ``olm.Session`` objects, or ``None`` if none were found. + + NOTE: + When overriding this, be careful to append the retrieved sessions to the + list of sessions already present and not to overwrite its reference. + """ + c = self.conn.cursor() + rows = c.execute( + 'SELECT session FROM olm_sessions WHERE device_id=? AND curve_key=?', + (self.device_id, curve_key) + ) + sessions = [olm.Session.from_pickle(bytes(row['session']), self.pickle_key) + for row in rows] + if sessions_dict is not None: + sessions_dict[curve_key].extend(sessions) + c.close() + # For consistency with other get_ methods, do not return an empty list + return sessions or None + + def save_inbound_session(self, room_id, curve_key, session): + """Saves a Megolm inbound session. + + Args: + room_id (str): The room corresponding to the session. + curve_key (str): The curve25519 key of the device. + session (MegolmInboundSession): The session to save. + """ + c = self.conn.cursor() + c.execute('REPLACE INTO megolm_inbound_sessions VALUES (?,?,?,?,?,?)', + (self.device_id, session.id, room_id, curve_key, session.ed25519, + session.pickle(self.pickle_key))) + c.close() + self.conn.commit() + + def load_inbound_sessions(self, sessions): + """Loads all saved inbound Megolm sessions. + + Args: + sessions (defaultdict(defaultdict(dict))): An object which will get + populated with the sessions. The format is + ``{: {: {: + }}}``. + """ + c = self.conn.cursor() + rows = c.execute( + 'SELECT * FROM megolm_inbound_sessions WHERE device_id=?', (self.device_id,) + ) + for row in rows: + session = MegolmInboundSession.from_pickle( + bytes(row['session']), row['ed_key'], self.pickle_key) + sessions[row['room_id']][row['curve_key']][session.id] = session + c.close() + + def get_inbound_session(self, room_id, curve_key, session_id, sessions=None): + """Gets a saved inbound Megolm session. + + Args: + room_id (str): The room corresponding to the session. + curve_key (str): The curve25519 key of the device. + session_id (str): The id of the session. + sessions (dict): Optional. A map from session id to + ``MegolmInboundSession`` object, to which the session will be added. + + Returns: + ``MegolmInboundSession`` object, or ``None`` if the session was not found. + """ + c = self.conn.cursor() + c.execute( + 'SELECT session, ed_key FROM megolm_inbound_sessions WHERE device_id=? AND ' + 'room_id=? AND curve_key=? AND session_id=?', + (self.device_id, room_id, curve_key, session_id) + ) + try: + row = c.fetchone() + session_data = bytes(row['session']) + except TypeError: + return None + finally: + c.close() + session = MegolmInboundSession.from_pickle(session_data, row['ed_key'], + self.pickle_key) + if sessions is not None: + sessions[session.id] = session + return session + + def save_outbound_session(self, room_id, session): + """Saves a Megolm outbound session. + + Args: + room_id (str): The room corresponding to the session. + session (MegolmOutboundSession): The session to save. + """ + c = self.conn.cursor() + pickle = session.pickle(self.pickle_key) + c.execute( + 'INSERT OR IGNORE INTO megolm_outbound_sessions VALUES (?,?,?,?,?,?,?)', + (self.device_id, room_id, pickle, session.max_age.total_seconds(), + session.max_messages, session.creation_time, session.message_count) + ) + c.execute('UPDATE megolm_outbound_sessions SET session=? WHERE device_id=? AND ' + 'room_id=?', (pickle, self.device_id, room_id)) + c.close() + self.conn.commit() + + def load_outbound_sessions(self, sessions): + """Loads all saved outbound Megolm sessions. + + Also loads the devices each are shared with. + + Args: + sessions (dict): A map from room_id to a ``MegolmOutboundSession`` object, + which will be populated. + """ + c = self.conn.cursor() + rows = c.execute( + 'SELECT * FROM megolm_outbound_sessions WHERE device_id=?', (self.device_id,)) + for row in rows.fetchall(): + device_ids = c.execute( + 'SELECT user_device_id FROM megolm_outbound_devices WHERE device_id=? ' + 'AND room_id=?', (self.device_id, row['room_id']) + ) + devices = {device_id[0] for device_id in device_ids} + max_age_s = row['max_age_s'] + max_age = timedelta(seconds=max_age_s) + session = MegolmOutboundSession.from_pickle( + bytes(row['session']), devices, max_age, row['max_messages'], + row['creation_time'], row['message_count'], self.pickle_key + ) + sessions[row['room_id']] = session + c.close() + + def get_outbound_session(self, room_id, sessions=None): + """Gets a saved outbound Megolm session. + + Also loads the devices it is shared with. + + Args: + room_id (str): The room corresponding to the session. + sessions (dict): Optional. A map from room_id to a + :class:`.MegolmOutboundSession` object, to which the session will be + added. + + Returns: + :class:`.MegolmOutboundSession` object, or ``None`` if the session was + not found. + """ + c = self.conn.cursor() + c.execute( + 'SELECT * FROM megolm_outbound_sessions WHERE device_id=? AND room_id=?', + (self.device_id, room_id) + ) + try: + row = c.fetchone() + session_data = bytes(row['session']) + except TypeError: + c.close() + return None + device_ids = c.execute( + 'SELECT user_device_id FROM megolm_outbound_devices WHERE device_id=? ' + 'AND room_id=?', (self.device_id, room_id) + ) + devices = {device_id[0] for device_id in device_ids} + c.close() + max_age_s = row['max_age_s'] + max_age = timedelta(seconds=max_age_s) + session = MegolmOutboundSession.from_pickle( + session_data, devices, max_age, row['max_messages'], row['creation_time'], + row['message_count'], self.pickle_key + ) + if sessions is not None: + sessions[room_id] = session + return session + + def remove_outbound_session(self, room_id): + """Removes a saved outbound Megolm session. + + Args: + room_id (str): The room corresponding to the session. + """ + c = self.conn.cursor() + c.execute('DELETE FROM megolm_outbound_sessions WHERE device_id=? AND room_id=?', + (self.device_id, room_id)) + c.close() + self.conn.commit() + + def save_megolm_outbound_devices(self, room_id, device_ids): + """Saves devices an outbound Megolm session is shared with. + + Args: + room_id (str): The room corresponding to the session. + device_ids (iterable): A list of device ids. + """ + c = self.conn.cursor() + rows = [(self.device_id, room_id, device_id) for device_id in device_ids] + c.executemany( + 'INSERT OR IGNORE INTO megolm_outbound_devices VALUES (?,?,?)', rows) + c.close() + self.conn.commit() + + def save_device_keys(self, device_keys): + """Saves device keys. + + Args: + device_keys (defaultdict(dict)): The format is ``{: {: + Device``. + """ + c = self.conn.cursor() + rows = [] + for user_id, devices_dict in device_keys.items(): + for device_id, device in devices_dict.items(): + rows.append((self.device_id, user_id, device_id, device.ed25519, + device.curve25519, device.verified, device.blacklisted, + device.ignored)) + c.executemany('REPLACE INTO device_keys VALUES (?,?,?,?,?,?,?,?)', rows) + c.close() + self.conn.commit() + + def load_device_keys(self, api, device_keys): + """Loads all saved device keys. + + Args: + device_keys (defaultdict(dict)): An object which will get populated with + the keys. The format is ``{: {: Device}}``. + """ + c = self.conn.cursor() + rows = c.execute( + 'SELECT * FROM device_keys WHERE device_id=?', (self.device_id,)) + for row in rows: + device_keys[row['user_id']][row['user_device_id']] = \ + self._device_from_row(row, api) + c.close() + + def get_device_keys(self, api, user_devices, device_keys=None): + """Gets the devices keys of the specified devices. + + Args: + user_devices (dict): A map from user ids to a list of device ids. + If no device ids are given for a user, all will be retrieved. + device_keys (defaultdict(dict)): Optional. Will be updated with + the retrieved keys. The format is ``{: {: + Device}}``. + + Returns: + A ``defaultdict(dict)`` containing the keys, the format is the same as the + ``device_keys`` argument. + """ + c = self.conn.cursor() + rows = [] + for user_id in user_devices: + if not user_devices[user_id]: + c.execute( + 'SELECT * FROM device_keys WHERE device_id=? AND user_id=?', + (self.device_id, user_id) + ) + rows.extend(c.fetchall()) + else: + for device_id in user_devices[user_id]: + c.execute( + 'SELECT * FROM device_keys WHERE device_id=? AND user_id=? AND ' + 'user_device_id=?', (self.device_id, user_id, device_id) + ) + rows.extend(c.fetchall()) + c.close() + result = defaultdict(dict) + for row in rows: + result[row['user_id']][row['user_device_id']] = \ + self._device_from_row(row, api) + + if device_keys is not None and result: + device_keys.update(result) + return result + + def _device_from_row(self, row, api): + return Device(api, row['user_id'], row['user_device_id'], database=self, + ed25519_key=row['ed_key'], curve25519_key=row['curve_key'], + verified=row['verified'], blacklisted=row['blacklisted'], + ignored=row['ignored']) + + def save_tracked_users(self, user_ids): + """Saves tracked users. + + Args: + user_ids (iterable): The user ids to save. + """ + c = self.conn.cursor() + rows = [(self.device_id, user_id) for user_id in user_ids] + c.executemany('INSERT OR IGNORE INTO tracked_users VALUES (?,?)', rows) + c.close() + self.conn.commit() + + def remove_tracked_users(self, user_ids): + """Removes tracked users. + + Args: + user_ids (iterable): The user ids to remove. + """ + c = self.conn.cursor() + rows = [(user_id,) for user_id in user_ids] + c.executemany('DELETE FROM tracked_users WHERE user_id=?', rows) + c.close() + self.conn.commit() + + def load_tracked_users(self, tracked_users): + """Loads all tracked users. + + Args: + tracked_users (set): Will be populated with user ids. + """ + c = self.conn.cursor() + rows = c.execute( + 'SELECT user_id FROM tracked_users WHERE device_id=?', (self.device_id,)) + tracked_users.update(row['user_id'] for row in rows) + c.close() + return tracked_users + + def save_sync_token(self, sync_token): + """Saves a sync token. + + Args: + sync_token (str): The token to save. + """ + c = self.conn.cursor() + c.execute('REPLACE INTO sync_tokens VALUES (?,?)', (self.device_id, sync_token)) + c.close() + self.conn.commit() + + def get_sync_token(self): + """Gets the saved sync token. + + Returns: + A string corresponding to the token, or ``None`` if there wasn't any. + """ + c = self.conn.cursor() + c.execute('SELECT token FROM sync_tokens WHERE device_id=?', (self.device_id,)) + try: + return c.fetchone()['token'] + except TypeError: + return None + finally: + c.close() + + def close(self): + self.conn.close() + + @property + def conn(self): + return self._conn[current_thread().ident] diff --git a/matrix_client/crypto/device_list.py b/matrix_client/crypto/device_list.py new file mode 100644 index 00000000..8952f468 --- /dev/null +++ b/matrix_client/crypto/device_list.py @@ -0,0 +1,332 @@ +import logging +from collections import defaultdict +from threading import Thread, Condition, Event, Lock + +from matrix_client.device import Device +from matrix_client.errors import MatrixHttpLibError, MatrixRequestError + +logger = logging.getLogger(__name__) + + +class DeviceList: + """Allows to maintain a list of devices up-to-date for an OlmDevice. + + Offers blocking and non-blocking methods to fetch device keys when appropriate. + NOTE: Spawns a thread that will last until program termination. + + Args: + olm_device (OlmDevice): Will be used to get additional info, such as device id. + api (MatrixHttpApi): The api object used to make requests. + device_keys (defaultdict(dict)): A map from user to device id to Device. + """ + + def __init__(self, olm_device, api, device_keys, db): + self.olm_device = olm_device + self.api = api + self.device_keys = device_keys + # Stores the ids of users who need updating + self.outdated_user_ids = _OutdatedUsersSet() + # Stores the ids of users to fetch the device keys of eventually + self.pending_outdated_user_ids = set() + self.pending_users_lock = Lock() + # Stores the ids of users we are currently tracking. We can assume the device + # keys of these users are up-to-date as long as no downloading is in progress. + # We should track every user we share an encrypted room with. + self.tracked_user_ids = set() + # Allows to wake up the thread when there are new users to update, and to + # synchronise shared data. + self.thread_condition = Condition() + self.db = db + self.db.load_tracked_users(self.tracked_user_ids) + self.update_thread = _UpdateDeviceList( + self.thread_condition, self.outdated_user_ids, self._download_device_keys, + self.tracked_user_ids, db + ) + self.update_thread.start() + + def get_room_device_keys(self, room, blocking=True): + """Gets the keys of all devices present in the room. + + Makes sure not to download keys of users we are already tracking. + The users we were not yet tracking will get tracked automatically. + + Args: + room (Room): The room to use. + blocking (bool): Optional. Whether to wait for the keys to have been + downloaded before returning. + """ + logger.info('Fetching all missing keys in room %s.', room.room_id) + members = {m.user_id for m in room.get_joined_members()} + missing_members = {m: [] for m in members if not self.device_keys[m]} + if missing_members: + self.db.get_device_keys(self.api, missing_members, self.device_keys) + user_ids = members - self.tracked_user_ids + if not user_ids: + logger.info('Already had all the keys in room %s.', room.room_id) + if blocking: + # Wait on an eventual download to finish + self.update_thread.event.wait() + return + with self.thread_condition: + self.outdated_user_ids.update(user_ids) + if blocking: + # Will ensure the user_ids we just added are processed + event = Event() + self.outdated_user_ids.events.add(event) + self.thread_condition.notify() + if blocking: + event.wait() + + def track_user_no_download(self, user_id): + """Add user to be tracked, but do not track it instantly. + + This should be used to avoid making calls to :func:`track_users` with only one + user repeatedly. Instead, using this should allow to passively queue the users, + and tracking can be triggered when there are sufficiently, using + :func:`track_pending_users`. + + Args: + user_id (str): A user id. + """ + with self.pending_users_lock: + self.pending_outdated_user_ids.add(user_id) + + def track_pending_users(self): + """Triggers the tracking of the user added with :func:`track_user_no_download`.""" + with self.pending_users_lock: + self.track_users(self.pending_outdated_user_ids) + self.pending_outdated_user_ids.clear() + + def track_users(self, user_ids): + """Add users to be tracked, and download their device keys. + + NOTE: this is non-blocking and will return before the keys are downloaded. + + Args: + user_ids (iterable): Any iterable containing user ids. + """ + user_ids = user_ids.difference(self.tracked_user_ids) + if user_ids: + self._add_outdated_users(user_ids) + + def update_after_restart(self, to_token): + from_token = self.db.get_sync_token() + if not from_token: + # First launch. Persist this token in case we would not have the occasion to + # save one this session. + self.db.save_sync_token(to_token) + return + resp = self.api.key_changes(from_token, to_token) + if resp.get('left'): + self.stop_tracking_users(resp['left']) + if resp.get('changed'): + self.update_user_device_keys(resp['changed']) + + def stop_tracking_users(self, user_ids): + """Stop tracking users. + + NOTE: Keys will not be deleted. + + Args: + user_ids (iterable): Any iterable containing user ids. + """ + with self.thread_condition: + self.tracked_user_ids.difference_update(user_ids) + self.outdated_user_ids.difference_update(user_ids) + self.db.remove_tracked_users(user_ids) + logger.info('Stopped tracking users: %s.', user_ids) + + def update_user_device_keys(self, user_ids, since_token=None): + """Triggers an update for users we already track. + + Args: + user_ids (iterable): Any iterable containing user ids. + since_token (str): Optional. Since token of a sync request, if triggering + the update as a result of that sync request. + """ + user_ids = self.tracked_user_ids.intersection(user_ids) + if not user_ids: + return + logger.info('Updating the device lists of users: %s, using token %s', + user_ids, since_token) + self._add_outdated_users(user_ids, since_token=since_token) + + def _add_outdated_users(self, user_ids, since_token=None): + """Stop tracking users. Keys will not be deleted. + + Args: + user_ids (iterable): Any iterable containing user ids. + since_token (str): Optional. Since token of a sync request. + """ + with self.thread_condition: + self.outdated_user_ids.update(user_ids) + if since_token: + self.outdated_user_ids.sync_token = since_token + self.thread_condition.notify() + + def _download_device_keys(self, user_devices, since_token=None): + """Download and store device keys, if they pass security checks. + + Args: + user_devices (dict): Format is ``user_id: [device_ids]``. + since_token (str): Optional. Since token of a sync request. + """ + changed = defaultdict(dict) + resp = self.api.query_keys(user_devices, token=since_token) + if resp.get('failures'): + logger.warning('Failed to download keys from the following unreachable ' + 'homeservers %s.', resp['failures']) + device_keys = resp['device_keys'] + for user_id in user_devices: + # The response might not contain every user_ids we requested + for device_id, payload in device_keys.get(user_id, {}).items(): + if device_id == self.olm_device.device_id: + continue + if payload['user_id'] != user_id or payload['device_id'] != device_id: + logger.warning('Mismatch in keys payload of device %s (%s) of user ' + '%s (%s).', payload['device_id'], device_id, + payload['user_id'], user_id) + continue + try: + signing_key = payload['keys']['ed25519:{}'.format(device_id)] + curve_key = payload['keys']['curve25519:{}'.format(device_id)] + except KeyError as e: + logger.warning('Invalid identity keys payload from device %s of' + 'user %s: %s.', device_id, user_id, e) + continue + verified = self.olm_device.verify_json( + payload, signing_key, user_id, device_id) + if not verified: + logger.warning('Signature verification failed for device %s of ' + 'user %s.', device_id, user_id) + continue + devices = self.device_keys[user_id] + try: + device = devices[device_id] + except KeyError: + devices[device_id] = Device(self.api, user_id, device_id, + curve25519_key=curve_key, + ed25519_key=signing_key, + database=self.db) + else: + if device.ed25519 != signing_key: + logger.warning('Ed25519 key has changed for device %s of ' + 'user %s.', device_id, user_id) + continue + if device.curve25519 == curve_key: + continue + device._curve25519 = curve_key + changed[user_id][device_id] = devices[device_id] + + logger.info('Successfully downloaded keys for devices: %s.', + {user_id: list(changed[user_id]) for user_id in changed}) + return changed + + +class _OutdatedUsersSet(set): + """Allows to know if elements in a set have been processed. + + This is done by adding elements along with an Event object. Then, functions + processing the set should set the events when they are done. + """ + + def __init__(self, iterable=()): + self.events = set() + self._sync_token = None + super(_OutdatedUsersSet, self).__init__(iterable) + + def mark_as_processed(self): + for event in self.events: + event.set() + + def copy(self): + new_set = _OutdatedUsersSet(self) + new_set.events = self.events.copy() + return new_set + + def clear(self): + self.events.clear() + super(_OutdatedUsersSet, self).clear() + + def update(self, iterable): + super(_OutdatedUsersSet, self).update(iterable) + if isinstance(iterable, _OutdatedUsersSet): + self.events.update(iterable.events) + + @property + def sync_token(self): + return self._sync_token + + @sync_token.setter + def sync_token(self, token): + if not self._sync_token or token > self._sync_token: + self._sync_token = token + + +class _UpdateDeviceList(Thread): + + def __init__(self, cond, user_ids, download_method, tracked_user_ids, db): + # We wait on this condition when there is nothing to do. Outside code should use + # it to notify us when they add data to be processed in outdated_user_ids so that + # we can wake up and process it. + self.cond = cond + self.outdated_user_ids = user_ids + self.download = download_method + self.tracked_user_ids = tracked_user_ids + # Cleared when we start a download, and set when we have finished it. This can be + # used by outside code in order to know if we are in the middle of a download, and + # allows to wait for it to complete by waiting on this event. + self.db = db + # Cleared when we start a download, and set when we have finished it. This can be + # used by outside code in order to know if we are in the middle of a download, and + # allows to wait for it to complete by waiting on this event. + self.event = Event() + # Used internally to terminate gracefully on program exit. + self._should_terminate = Event() + super(_UpdateDeviceList, self).__init__() + + def run(self): + while True and not self._should_terminate.is_set(): + with self.cond: + while not self.outdated_user_ids: + # Avoid any deadlocks + self.outdated_user_ids.mark_as_processed() + self.event.set() + logger.debug('Update thread is going to sleep...') + self.cond.wait() + logger.debug('Update thread woke up!') + if self._should_terminate.is_set(): + return + to_download = self.outdated_user_ids.copy() + self.outdated_user_ids.clear() + self.event.clear() + new_user_ids = to_download.difference(self.tracked_user_ids) + if new_user_ids: + self.tracked_user_ids.update(new_user_ids) + payload = {user_id: [] for user_id in to_download} + logger.info('Downloading device keys for users: %s.', to_download) + try: + changed = self.download(payload, self.outdated_user_ids.sync_token) + self.event.set() + to_download.mark_as_processed() + if changed: + self.db.save_device_keys(changed) + if new_user_ids: + self.db.save_tracked_users(new_user_ids) + if self.outdated_user_ids.sync_token: + # FIXME this should be next_batch instead of since + self.db.save_sync_token(self.outdated_user_ids.sync_token) + except (MatrixHttpLibError, MatrixRequestError) as e: + logger.warning('Network error when fetching device keys (will retry): %s', + e) + with self.cond: + self.outdated_user_ids.update(to_download) + self.tracked_user_ids.difference_update(new_user_ids) + + def join(self, timeout=None): + # If we are joined, this means that the main program is terminating. + # We should terminate too. + self._should_terminate.set() + with self.cond: + self.cond.notify() + super(_UpdateDeviceList, self).join(timeout=timeout) diff --git a/matrix_client/crypto/encrypt_attachments.py b/matrix_client/crypto/encrypt_attachments.py new file mode 100644 index 00000000..aa6ed62c --- /dev/null +++ b/matrix_client/crypto/encrypt_attachments.py @@ -0,0 +1,80 @@ +import unpaddedbase64 +from Crypto.Cipher import AES +from Crypto.Util import Counter +from Crypto import Random +from Crypto.Hash import SHA256 + + +def encrypt_attachment(plaintext): + """Encrypt a plaintext in order to send it as an encrypted attachment. + + Args: + plaintext (bytes): The data to encrypt. + + Returns: + A tuple of the ciphertext bytes and a dict containing the info needed + to decrypt data. The keys are: + + | key: AES-CTR JWK key object. + | iv: Base64 encoded 16 byte AES-CTR IV. + | hashes.sha256: Base64 encoded SHA-256 hash of the ciphertext. + """ + # 8 bytes IV + iv = Random.new().read(8) + # 8 bytes counter, prefixed by the IV + ctr = Counter.new(64, prefix=iv, initial_value=0) + key = Random.new().read(32) + cipher = AES.new(key, AES.MODE_CTR, counter=ctr) + ciphertext = cipher.encrypt(plaintext) + h = SHA256.new() + h.update(ciphertext) + digest = h.digest() + json_web_key = { + 'kty': 'oct', + 'alg': 'A256CTR', + 'ext': True, + 'k': unpaddedbase64.encode_base64(key, urlsafe=True), + 'key_ops': ['encrypt', 'decrypt'] + } + keys = { + 'v': 'v2', + 'key': json_web_key, + # Send IV concatenated with counter + 'iv': unpaddedbase64.encode_base64(iv + b'\x00' * 8), + 'hashes': { + 'sha256': unpaddedbase64.encode_base64(digest), + } + } + return ciphertext, keys + + +def decrypt_attachment(ciphertext, info): + """Decrypt an encrypted attachment. + + Args: + ciphertext (bytes): The data to decrypt. + info (dict): The information needed to decrypt the attachment. + + | key: AES-CTR JWK key object. + | iv: Base64 encoded 16 byte AES-CTR IV. + | hashes.sha256: Base64 encoded SHA-256 hash of the ciphertext. + + Returns: + The plaintext bytes. + + Raises: + RuntimeError if the integrity check fails. + """ + expected_hash = unpaddedbase64.decode_base64(info['hashes']['sha256']) + h = SHA256.new() + h.update(ciphertext) + if h.digest() != expected_hash: + raise RuntimeError('Mismatched SHA-256 digest.') + + key = unpaddedbase64.decode_base64(info['key']['k']) + # Drop last 8 bytes, which are 0 + iv = unpaddedbase64.decode_base64(info['iv'])[:8] + ctr = Counter.new(64, prefix=iv, initial_value=0) + cipher = AES.new(key, AES.MODE_CTR, counter=ctr) + + return cipher.decrypt(ciphertext) diff --git a/matrix_client/crypto/olm_device.py b/matrix_client/crypto/olm_device.py index 514965db..9bf9863d 100644 --- a/matrix_client/crypto/olm_device.py +++ b/matrix_client/crypto/olm_device.py @@ -1,15 +1,23 @@ +import json import logging +from collections import defaultdict import olm from canonicaljson import encode_canonical_json from matrix_client.checks import check_user_id +from matrix_client.device import Device +from matrix_client.errors import E2EUnknownDevices from matrix_client.crypto.one_time_keys import OneTimeKeysManager +from matrix_client.crypto.device_list import DeviceList +from matrix_client.crypto.sessions import MegolmOutboundSession, MegolmInboundSession +from matrix_client.crypto.crypto_store import CryptoStore +from matrix_client.crypto.verified_event import VerifiedEvent logger = logging.getLogger(__name__) -class OlmDevice(object): +class OlmDevice(Device): """Manages the Olm cryptographic functions. Has a unique Olm account which holds identity keys. @@ -17,7 +25,8 @@ class OlmDevice(object): Args: api (MatrixHttpApi): The api object used to make requests. user_id (str): Matrix user ID. Must match the one used when logging in. - device_id (str): Must match the one used when logging in. + device_id (str): Optional. Must match the one used when logging in. If absent, + attempt to retrieve it from database using ``user_id``. signed_keys_proportion (float): Optional. The proportion of signed one-time keys we should maintain on the HS compared to unsigned keys. The maximum value of ``1`` means only signed keys will be uploaded, while the minimum value of @@ -28,6 +37,18 @@ class OlmDevice(object): replenishment is triggered. Must be between ``0`` and ``1``. For example, ``0.1`` means that new one-time keys will be uploaded when there is less than 10% of the maximum number of one-time keys on the server. + Store (class): Optional. Custom storage class. It should implement the same + methods as :class:`~matrix_client.crypto.crypto_store.CryptoStore`. + store_conf (dict): Optional. Configuration parameters for keys storage. Refer to + :func:`~matrix_client.crypto.crypto_store.CryptoStore` for supported options, + since it will be passed to this class. + load_all (bool): Optional. If True, all content of the database for the current + device will be loaded at once. This will increase runtime performance but + also launch time and memory usage. + + Raises: + ``ValueError`` if ``device_id`` was not given and couldn't be retrieved + from database. """ _olm_algorithm = 'm.olm.v1.curve25519-aes-sha2' @@ -37,9 +58,12 @@ class OlmDevice(object): def __init__(self, api, user_id, - device_id, + device_id=None, signed_keys_proportion=1, - keys_threshold=0.1): + keys_threshold=0.1, + Store=CryptoStore, + store_conf=None, + load_all=False): if not 0 <= signed_keys_proportion <= 1: raise ValueError('signed_keys_proportion must be between 0 and 1.') if not 0 <= keys_threshold <= 1: @@ -47,10 +71,26 @@ def __init__(self, self.api = api check_user_id(user_id) self.user_id = user_id - self.device_id = device_id - self.olm_account = olm.Account() - logger.info('Initialised Olm Device.') - self.identity_keys = self.olm_account.identity_keys + conf = store_conf or {} + self.db = Store(user_id, device_id=device_id, **conf) + self.olm_sessions = defaultdict(list) + self.megolm_inbound_sessions = defaultdict(lambda: defaultdict(dict)) + self.megolm_outbound_sessions = {} + self.device_keys = defaultdict(dict) + self.olm_account = self.db.get_olm_account() + if not device_id: + device_id = self.db.device_id + if self.olm_account: + if load_all: + self.db.load_olm_sessions(self.olm_sessions) + self.db.load_inbound_sessions(self.megolm_inbound_sessions) + self.db.load_outbound_sessions(self.megolm_outbound_sessions) + self.db.load_device_keys(self.api, self.device_keys) + logger.info('Loaded Olm account from database for device %s.', device_id) + else: + self.olm_account = olm.Account() + self.db.replace_olm_account(self.olm_account) + logger.info('Created new Olm account for device %s.', device_id) # Try to maintain half the number of one-time keys libolm can hold uploaded # on the HS. This is because some keys will be claimed by peers but not # used instantly, and we want them to stay in libolm, until the limit is reached @@ -59,6 +99,12 @@ def __init__(self, self.one_time_keys_manager = OneTimeKeysManager(target_keys_number, signed_keys_proportion, keys_threshold) + self.device_list = DeviceList(self, api, self.device_keys, self.db) + self.megolm_index_record = defaultdict(dict) + keys = self.olm_account.identity_keys + super(OlmDevice, self).__init__(self.api, self.user_id, device_id, + database=self.db, ed25519_key=keys['ed25519'], + curve25519_key=keys['curve25519']) def upload_identity_keys(self): """Uploads this device's identity keys to HS. @@ -69,8 +115,10 @@ def upload_identity_keys(self): 'user_id': self.user_id, 'device_id': self.device_id, 'algorithms': self._algorithms, - 'keys': {'{}:{}'.format(alg, self.device_id): key - for alg, key in self.identity_keys.items()} + 'keys': { + 'curve25519:{}'.format(self.device_id): self.curve25519, + 'ed25519:{}'.format(self.device_id): self.ed25519 + } } self.sign_json(device_keys) ret = self.api.upload_keys(device_keys=device_keys) @@ -116,12 +164,14 @@ def upload_one_time_keys(self, force_update=False): ret = self.api.upload_keys(one_time_keys=one_time_keys) self.one_time_keys_manager.server_counts = ret['one_time_key_counts'] self.olm_account.mark_keys_as_published() + self.db.save_olm_account(self.olm_account) keys_uploaded = {} if unsigned_keys_to_upload: keys_uploaded['curve25519'] = unsigned_keys_to_upload if signed_keys_to_upload: keys_uploaded['signed_curve25519'] = signed_keys_to_upload + logger.info('Uploaded new one-time keys: %s.', keys_uploaded) return keys_uploaded @@ -136,6 +186,642 @@ def update_one_time_key_counts(self, counts): logger.info('Uploading new one-time keys.') self.upload_one_time_keys() + def olm_start_sessions(self, user_devices): + """Start olm sessions with the given devices. + + NOTE: those device keys must already be known. + + Args: + user_devices (dict): A map from user_id to an iterable of device_ids. + The format is ``: []``. + """ + logger.info('Trying to establish Olm sessions with devices: %s.', + dict(user_devices)) + payload = defaultdict(dict) + for user_id in user_devices: + for device_id in user_devices[user_id]: + payload[user_id][device_id] = 'signed_curve25519' + + resp = self.api.claim_keys(payload) + if resp.get('failures'): + logger.warning('Failed to claim one-time keys from the following unreachable ' + 'homeservers: %s.', resp['failures']) + keys = resp['one_time_keys'] + if logger.level >= logging.WARNING: + missing = {} + for user_id, device_ids in user_devices.items(): + if user_id not in keys: + missing[user_id] = device_ids + else: + missing_devices = set(device_ids) - set(keys[user_id]) + if missing_devices: + missing[user_id] = missing_devices + logger.warning('Failed to claim the keys of %s.', missing) + + new_sessions = defaultdict(list) + for user_id in user_devices: + for device_id, one_time_key in keys.get(user_id, {}).items(): + try: + device = self.device_keys[user_id][device_id] + except KeyError: + logger.warning('Key for device %s of user %s not found, could not ' + 'start Olm session.', device_id, user_id) + continue + key_object = next(iter(one_time_key.values())) + verified = self.verify_json(key_object, + device.ed25519, + user_id, + device_id) + if verified: + session = olm.OutboundSession(self.olm_account, + device.curve25519, + key_object['key']) + sessions = self.olm_sessions[device.curve25519] + sessions.append(session) + new_sessions[device.curve25519].append(session) + logger.info('Established Olm session %s with device %s of user ' + '%s.', device_id, session.id, user_id) + else: + logger.warning('Signature verification for one-time key of device %s ' + 'of user %s failed, could not start olm session.', + device_id, user_id) + self.db.save_olm_sessions(new_sessions) + + def olm_build_encrypted_event(self, event_type, content, user_id, device_id): + """Encrypt an event using Olm. + + NOTE: a session with this device must already be established. + + Args: + event_type (str): The event type, will be encrypted. + content (dict): The event content, will be encrypted. + user_id (str): The intended recipient of the event. + device_id (str): The device to encrypt to. + + Returns: + The Olm encrypted event, as JSON. + """ + try: + device = self.device_keys[user_id][device_id] + except KeyError: + raise RuntimeError('Device is unknown, could not encrypt.') + + payload = { + 'type': event_type, + 'content': content, + 'sender': self.user_id, + 'sender_device': self.device_id, + 'keys': { + 'ed25519': self.ed25519 + }, + 'recipient': user_id, + 'recipient_keys': { + 'ed25519': device.ed25519 + } + } + + sessions = self.olm_sessions[device.curve25519] + if sessions: + session = sorted(sessions, key=lambda s: s.id)[0] + else: + raise RuntimeError('No session for this device, could not encrypt.') + + encrypted_message = session.encrypt(json.dumps(payload)) + self.db.save_olm_session(device.curve25519, session) + ciphertext_payload = { + device.curve25519: { + 'type': encrypted_message.message_type, + 'body': encrypted_message.ciphertext + } + } + + event = { + 'algorithm': self._olm_algorithm, + 'sender_key': self.curve25519, + 'ciphertext': ciphertext_payload + } + return event + + def olm_decrypt_event(self, content, user_id): + """Decrypt an Olm encrypted event, and check its properties. + + Args: + event (dict): The content property of a m.room.encrypted event. + user_id (str): The sender of the event. + + Retuns: + The decrypted event held by the initial event. + + Raises: + RuntimeError: Error in the decryption process. Nothing can be done. The text + of the exception indicates what went wrong, and should be logged or + displayed to the user. + KeyError: The event is missing a required field. + """ + if content['algorithm'] != self._olm_algorithm: + raise RuntimeError('Event was not encrypted with {}.' + .format(self._olm_algorithm)) + + ciphertext = content['ciphertext'] + try: + payload = ciphertext[self.curve25519] + except KeyError: + raise RuntimeError('This message was not encrypted for us.') + + msg_type = payload['type'] + if msg_type == 0: + encrypted_message = olm.OlmPreKeyMessage(payload['body']) + else: + encrypted_message = olm.OlmMessage(payload['body']) + + sender_key = content['sender_key'] + decrypted_event = self._olm_decrypt(encrypted_message, sender_key) + + signing_key = decrypted_event['keys']['ed25519'] + if decrypted_event['sender'] != user_id: + raise RuntimeError( + 'Found user {} instead of sender {} in Olm plaintext {}.' + .format(decrypted_event['sender'], user_id, decrypted_event) + ) + if decrypted_event['recipient'] != self.user_id: + raise RuntimeError( + 'Found user {} instead of us ({}) in Olm plaintext {}.' + .format(decrypted_event['recipient'], self.user_id, decrypted_event) + ) + our_key = decrypted_event['recipient_keys']['ed25519'] + if our_key != self.ed25519: + raise RuntimeError( + 'Found key {} instead of ours own ed25519 key {} in Olm plaintext {}.' + .format(our_key, self.ed25519, decrypted_event) + ) + try: + device = self.device_keys[user_id][decrypted_event['sender_device']] + except KeyError: + pass + else: + if device.verified: + if device.curve25519 != sender_key or device.ed25519 != signing_key: + raise RuntimeError( + 'Device keys mismatch between payload and /keys/query data.' + ) + + return decrypted_event + + def _olm_decrypt(self, olm_message, sender_key): + """Decrypt an Olm encrypted event. + + NOTE: This does no implement any security check. + + Try to decrypt using existing sessions. If it fails, start an new one when + possible. + + Args: + olm_message (OlmMessage): Olm encrypted payload. + sender_key (str): The sender's curve25519 identity key. + + Returns: + The decrypted event held by the initial payload, as JSON. + """ + + sessions = self.olm_sessions[sender_key] + if not sessions: + # `sessions` should get populated by this method + self.db.get_olm_sessions(sender_key, self.olm_sessions) + + # Try to decrypt message body using one of the known sessions for that device + for session in sessions: + try: + event = session.decrypt(olm_message) + self.db.save_olm_session(sender_key, session) + logger.info('Success decrypting Olm event using existing session %s.', + session.id) + break + except olm.session.OlmSessionError as e: + if olm_message.message_type == 0: + if session.matches(olm_message, sender_key): + # We had a matching session for a pre-key message, but it didn't + # work. This means something is wrong, so we fail now. + raise RuntimeError('Error decrypting pre-key message with ' + 'existing Olm session {}, reason: {}.' + .format(session.id, e)) + # Simply keep trying otherwise + else: + if olm_message.message_type > 0: + # Not a pre-key message, we should have had a matching session + if sessions: + raise RuntimeError('Error decrypting with existing sessions.') + raise RuntimeError('No existing sessions.') + + # We have a pre-key message without any matching session, in this case + # we should try to create one. + try: + session = olm.session.InboundSession( + self.olm_account, olm_message, sender_key) + except olm.session.OlmSessionError as e: + raise RuntimeError('Error decrypting pre-key message when trying to ' + 'establish a new session: {}.'.format(e)) + + logger.info('Created new Olm session %s.', session.id) + try: + event = session.decrypt(olm_message) + except olm.session.OlmSessionError as e: + raise RuntimeError('Error decrypting pre-key message with new session: ' + '{}.'.format(e)) + self.olm_account.remove_one_time_keys(session) + self.db.save_olm_account(self.olm_account) + self.db.save_olm_session(sender_key, session) + sessions.append(session) + + return json.loads(event) + + def olm_ensure_sessions(self, user_devices): + """Start Olm sessions with the given devices if one doesn't exist already. + + Args: + user_devices (dict): A map from user ids to a list of device ids. + """ + user_devices_no_session = defaultdict(list) + for user_id in user_devices: + for device_id in user_devices[user_id]: + curve_key = self.device_keys[user_id][device_id].curve25519 + # Check if we have a list of sessions for this device, which can be + # empty. Implicitely, an empty list will indicate that we already tried + # to establish a session with a device, but this attempt was + # unsuccessful. We do not retry to establish a session. + if curve_key not in self.olm_sessions: + sessions = self.db.get_olm_sessions(curve_key, self.olm_sessions) + if not sessions: + user_devices_no_session[user_id].append(device_id) + if user_devices_no_session: + self.olm_start_sessions(user_devices_no_session) + + def megolm_start_session(self, room, user_devices): + """Start a megolm session in a room, and share it with its members. + + Args: + room (Room): The room to use. + user_devices (dict): Map from user id to a list of device ids. The session + will be shared with those devices. + + Returns: + The newly created session. + """ + session = MegolmOutboundSession(max_age=room.rotation_period_ms, + max_messages=room.rotation_period_msgs) + self.megolm_outbound_sessions[room.room_id] = session + logger.info('Starting a new Meglom outbound session %s in %s.', + session.id, room.room_id) + + self.db.remove_outbound_session(room.room_id) + self.db.save_outbound_session(room.room_id, session) + self.megolm_share_session(room.room_id, user_devices, session) + # Store a corresponding inbound session, so that we can decrypt our own messages + self.megolm_add_inbound_session( + room.room_id, self.curve25519, self.ed25519, session.id, session.session_key) + return session + + def megolm_share_session(self, room_id, user_devices, session): + """Share an already existing outbound megolm session with the specified devices. + + Args: + room_id (str): The room corresponding to the session. + user_devices (dict): A map from user ids to a list of device ids. + session (MegolmOutboundSession): The session object. + """ + + logger.info('Attempting to share Megolm session %s in %s with %s.', + session.id, room_id, user_devices) + self.olm_ensure_sessions(user_devices) + + event = { + 'algorithm': self._megolm_algorithm, + 'room_id': room_id, + 'session_id': session.id, + 'session_key': session.session_key + } + + messages = defaultdict(dict) + new_devices = set() + for user_id in user_devices: + for device_id in user_devices[user_id]: + try: + messages[user_id][device_id] = self.olm_build_encrypted_event( + 'm.room_key', event, user_id, device_id + ) + except RuntimeError as e: + logger.warning('Could not share megolm session %s with device %s of ' + 'user %s: %s', session.id, + device_id, user_id, e) + # We will not retry to share session with failed devices + new_devices.add(device_id) + self.api.send_to_device('m.room.encrypted', messages) + session.add_devices(new_devices) + self.db.save_megolm_outbound_devices(room_id, new_devices) + + def megolm_share_session_with_new_devices(self, room, user_devices, session): + """Share a megolm session with new devices in a room. + + Args: + room (Room): The room corresponding to the session. + session (MegolmOutboundSession): The session to share. + user_devices (dict): Map from user id to a list of device ids. The session + will be shared with those devices if not already. + """ + new_user_devices = {} + for user_id in user_devices: + missing_devices = list(set(self.device_keys[user_id].keys()) - + self.megolm_outbound_sessions[room.room_id].devices) + if missing_devices: + new_user_devices[user_id] = missing_devices + + if new_user_devices: + logger.info('Sharing existing Megolm outbound session %s with new ' + 'devices: %s', session.id, new_user_devices) + self.megolm_share_session(room.room_id, new_user_devices, session) + + def megolm_get_recipients(self, room, session=None): + """Get the devices who should be able to decrypt a Megolm event in a room. + + This implements device verification checks. + + Args: + room (Room): The room to use. + session (MegolmOutboundSession): Optional. If a device the session had + been shared with has been blacklisted, remove the session. + + Returns: + A two element tuple containing a map from user id to a list of device ids, + and a boolean indicating whether the session has been removed. + + Raises: + E2EUnknownDevices if there are never seen before devices in the room. + """ + users = room.get_joined_members() + + user_devices = defaultdict(list) + unknown_devices = defaultdict(list) + removed_session = False + for user in users: + for device_id, device in self.device_keys[user.user_id].items(): + if device.blacklisted: + if session and device.device_id in session.devices: + self.megolm_remove_outbound_session(room.room_id) + removed_session = True + else: + if not room.verify_devices or device.ignored or device.verified: + user_devices[user.user_id].append(device_id) + else: + unknown_devices[user.user_id].append(device) + if unknown_devices and room.verify_devices: + logger.warning('Room %s contains unknown devices which have not been ' + 'verified.', room.room_id) + raise E2EUnknownDevices(unknown_devices) + + return user_devices, removed_session + + def megolm_build_encrypted_event(self, room, event): + """Build an encrypted Megolm payload from a plaintext event. + + If no session exists in the room, a new one will be initiated. Also takes care + of rotating the session periodically. + + Args: + room (Room): The room the event will be sent in. + event (dict): Matrix event. + + Returns: + The encrypted event, as a dict. + + Raises: + E2EUnknownDevices if there are never seen before devices in the room. + """ + room_id = room.room_id + + try: + session = self.megolm_outbound_sessions[room_id] + except KeyError: + session = self.db.get_outbound_session(room_id, self.megolm_outbound_sessions) + # We have to fetch device keys if there is no session. If there is one, we are + # already tracking the device list of users in the room, so it shouldn't be + # needed. + # However, there is the edge case where a device is blacklisted, and then the + # client is shutdown. When we load the session, if we do not fetch the keys + # (which triggers loading the devices from db), we would miss that a device + # had been blacklisted and we would keep using the session instead of rotating + # it as expected. Hence we also fetch device keys after a session is loaded. + self.device_list.get_room_device_keys(room) + + user_devices, removed_session = self.megolm_get_recipients(room, session) + + if not session or removed_session or session.should_rotate(): + session = self.megolm_start_session(room, user_devices) + else: + self.megolm_share_session_with_new_devices(room, user_devices, session) + + payload = { + 'type': event['type'], + 'content': event['content'], + 'room_id': room_id + } + + encrypted_payload = session.encrypt(json.dumps(payload)) + self.db.save_outbound_session(room_id, session) + + encrypted_event = { + 'algorithm': self._megolm_algorithm, + 'sender_key': self.curve25519, + 'ciphertext': encrypted_payload, + 'session_id': session.id, + 'device_id': self.device_id + } + return encrypted_event + + def megolm_remove_outbound_session(self, room_id): + """Remove an existing Megolm outbound session in a room. + + If there is no such session, nothing will happen. + + Args: + room_id (str): The room to use. + """ + try: + self.megolm_outbound_sessions.pop(room_id) + self.db.remove_outbound_session(room_id) + logger.info('Removed Meglom outbound session in %s.', room_id) + except KeyError: + pass + + def send_encrypted_message(self, room, content): + """Send a m.room.encrypted event in a room. + + Args: + room (Room): The room to use. + content (dict): The content of the event, will be encrypted. + + Raises: + MatrixRequestError if there was an error sending the event. + E2EUnknownDevices if there are never seen before devices in the room. + The event will not be sent. + """ + event = {'content': content, 'room_id': room.room_id, 'type': 'm.room.message'} + encrypted_event = self.megolm_build_encrypted_event(room, event) + return self.api.send_message_event( + room.room_id, 'm.room.encrypted', encrypted_event) + + def olm_handle_encrypted_event(self, encrypted_event): + """Decrypt and process an Olm m.room.encrypted event. + + Once decrypted, the event is processed according to its type. + + Args: + encrypted_event (dict): m.room.encrypted event. + """ + content = encrypted_event['content'] + if 'algorithm' not in content or content['algorithm'] != self._olm_algorithm: + return + + try: + event = self.olm_decrypt_event(content, encrypted_event['sender']) + except RuntimeError as e: + logger.warning('Failed to decrypt m.room_key event sent by user %s: %s', + encrypted_event['sender'], e) + return + + if event['type'] == 'm.room_key': + self.handle_room_key_event(event, encrypted_event['content']['sender_key']) + + def handle_room_key_event(self, event, sender_key): + """Handle a m.room_key event. + + Args: + event (dict): m.room_key event. + """ + signing_key = event['keys']['ed25519'] + content = event['content'] + if content['algorithm'] != self._megolm_algorithm: + logger.info('Ignoring unsupported algorithm %s in m.room_key event.', + content['algorithm']) + return + user_id = event['sender'] + device_id = event['sender_device'] + + new = self.megolm_add_inbound_session(content['room_id'], sender_key, + signing_key, content['session_id'], + content['session_key']) + if new: + logger.info('Created a new Megolm inbound session with device %s of ' + 'user %s.', device_id, user_id) + else: + logger.info('Inbound Megolm session with device %s of user %s ' + 'already exists or is invalid.', device_id, user_id) + + def megolm_add_inbound_session(self, room_id, sender_key, signing_key, session_id, + session_key): + """Create a new Megolm inbound session if necessary. + + Args: + room_id (str): The room corresponding to the session. + sender_key (str): The curve25519 key of the sender's device. + session_id (str): The id of the session. + session_key (str): The key of the session. + signing_key (str): The ed25519 key of the event which established the session. + + Returns: + ``True`` if a new session was created, ``False`` if it already existed or if + the parameters were invalid. + """ + sessions = self.megolm_inbound_sessions[room_id][sender_key] + if session_id in sessions: + return False + # Load the session if it exists + if self.db.get_inbound_session(room_id, sender_key, session_id, sessions): + return False + try: + session = MegolmInboundSession(session_key, signing_key) + except olm.OlmGroupSessionError: + return False + if session.id != session_id: + logger.warning('Session ID mismatch in m.room_key event. Expected %s from ' + 'event property, got %s.', session_id, session.id) + return False + self.db.save_inbound_session(room_id, sender_key, session) + sessions[session_id] = session + return True + + def megolm_decrypt_event(self, event): + """Decrypt a Megolm m.room.encrypted event. + + Args: + event (dict): The event to decrypt. It may be modified in the process. + + Returns: + The decrypted event, as a normal ``dict`` if unverified, or as a + :class:`.VerifiedEvent` if verified. + """ + content = event['content'] + device_id = content['device_id'] + user_id = event['sender'] + if 'algorithm' not in content: + # Assume that this is a redacted event + return + if content['algorithm'] != self._megolm_algorithm: + raise RuntimeError('Incorrect algorithm "{}" value in event sent by device ' + '{} of user {}.'.format(content['algorithm'], device_id, + user_id)) + + sender_key = content['sender_key'] + room_id = event['room_id'] + session_id = content['session_id'] + sessions = self.megolm_inbound_sessions[room_id][sender_key] + try: + session = sessions[session_id] + except KeyError: + session = self.db.get_inbound_session( + room_id, sender_key, session_id, sessions) + if not session: + raise RuntimeError("Unable to decrypt event sent by device {} of user " + "{}: The sender's device has not sent us the keys for " + "this message.".format(device_id, user_id)) + + try: + decrypted_event, message_index = session.decrypt(content['ciphertext']) + except olm.group_session.OlmGroupSessionError as e: + raise RuntimeError('Unable to decrypt event sent by device {} of user {} ' + 'with matching megolm session: {}.'.format(device_id, + user_id, e)) + + try: + device = self.device_keys[user_id][device_id] + except KeyError: + pass + else: + if device.verified: + if device.ed25519 != session.ed25519 or device.curve25519 != sender_key: + raise RuntimeError('Device keys mismatch in event sent by device {}.' + .format(device.device_id)) + event = VerifiedEvent(event) + + try: + properties = self.megolm_index_record[session.id][message_index] + except KeyError: + self.megolm_index_record[session.id][message_index] = { + 'origin_server_ts': event['origin_server_ts'], + 'event_id': event['event_id'] + } + else: + if properties['origin_server_ts'] != event['origin_server_ts'] or \ + properties['event_id'] != event['event_id']: + raise RuntimeError('Detected a replay attack from device {} of user {} ' + 'on decrypted event: {}.'.format(device_id, user_id, + decrypted_event)) + + decrypted_event = json.loads(decrypted_event) + + event['type'] = decrypted_event['type'] + event['content'] = decrypted_event['content'] + + return event + def sign_json(self, json): """Signs a JSON object. diff --git a/matrix_client/crypto/sessions.py b/matrix_client/crypto/sessions.py new file mode 100644 index 00000000..fe9d024d --- /dev/null +++ b/matrix_client/crypto/sessions.py @@ -0,0 +1,93 @@ +from datetime import datetime, timedelta + +from olm import OutboundGroupSession, InboundGroupSession + + +class MegolmOutboundSession(OutboundGroupSession): + + """Outbound group session aware of the users it is shared with. + + Also remembers the time it was created and the number of messages it has encrypted, + in order to know if it needs to be rotated. + + Args: + max_age (datetime.timedelta): Optional. The maximum time the session should + exist. Default to one week if not present. + max_messages (int): Optional. The maximum number of messages that should be sent. + A new message in considered sent each time there is a call to ``encrypt``. + Default to 100 if not present. + + Attributes: + creation_time (datetime.datetime): Creation time of the session. + message_count (int): Number of messages encrypted using the session. + """ + + def __init__(self, max_age=None, max_messages=None): + self.devices = set() + if max_age: + self.max_age = timedelta(milliseconds=max_age) + else: + self.max_age = timedelta(days=7) + self.max_messages = max_messages or 100 + self.creation_time = datetime.now() + self.message_count = 0 + super(MegolmOutboundSession, self).__init__() + + def __new__(cls, **kwargs): + return super(MegolmOutboundSession, cls).__new__(cls) + + def add_device(self, device_id): + """Adds a device the session is shared with.""" + self.devices.add(device_id) + + def add_devices(self, device_ids): + """Adds devices the session is shared with. + + Args: + device_ids (iterable): An iterable of device ids, preferably a set. + """ + self.devices.update(device_ids) + + def should_rotate(self): + """Wether the session should be rotated. + + Returns: + True if it should, False if not. + """ + if self.message_count >= self.max_messages or \ + datetime.now() - self.creation_time >= self.max_age: + return True + return False + + def encrypt(self, plaintext): + self.message_count += 1 + return super(MegolmOutboundSession, self).encrypt(plaintext) + + @classmethod + def from_pickle(cls, pickle, devices, max_age, max_messages, creation_time, + message_count, passphrase=''): + session = super(MegolmOutboundSession, cls).from_pickle(pickle, passphrase) + session.devices = devices + session.max_age = max_age + session.max_messages = max_messages + session.creation_time = creation_time + session.message_count = message_count + return session + + +class MegolmInboundSession(InboundGroupSession): + + """Olm session with memory of the ed25519 key of the user it was established with.""" + + def __init__(self, session_key, signing_key): + self.ed25519 = signing_key + super(MegolmInboundSession, self).__init__(session_key) + + def __new__(cls, *args): + return super(MegolmInboundSession, cls).__new__(cls) + + @classmethod + def from_pickle(cls, pickle, signing_key, passphrase=''): + session = super(MegolmInboundSession, cls).from_pickle(pickle, passphrase) + session.ed25519 = signing_key + return session diff --git a/matrix_client/crypto/verified_event.py b/matrix_client/crypto/verified_event.py new file mode 100644 index 00000000..e156452b --- /dev/null +++ b/matrix_client/crypto/verified_event.py @@ -0,0 +1,2 @@ +class VerifiedEvent(dict): + pass diff --git a/matrix_client/device.py b/matrix_client/device.py new file mode 100644 index 00000000..2d0aff63 --- /dev/null +++ b/matrix_client/device.py @@ -0,0 +1,111 @@ +from .errors import MatrixRequestError + + +class Device(object): + """Represents a Matrix device, belonging to a user. + + Args: + api (MatrixHttpApi): The api object used to make requests. + user_id (str): User ID of this device's owner. + device_id (str): The device ID. + display_name (str): Optional. The display name of this device, if any. + last_seen_ip (str): Optional. The IP address where this device was last seen. + last_seen_ts (int): Optional. The timestamp (in milliseconds since the unix + epoch) when this device was last seen. + verified, blacklisted, ignored (bool): Optional. Device verification info. + ed25519_key (str): Optional. The Ed25519 fingerprint key of this device. The + corresponding attribute ``ed25519`` cannot be changed after initialisation. + curve25519_key (str): Optional. The Curve25519 fingerprint key of this device. The + corresponding attribute ``curve25519`` cannot be changed after initialisation. + database (CryptoStore): Optional. Allows to save device verification info. + """ + + def __init__(self, + api, + user_id, + device_id, + database=None, + display_name=None, + last_seen_ip=None, + last_seen_ts=None, + verified=False, + blacklisted=False, + ignored=False, + ed25519_key=None, + curve25519_key=None): + self.api = api + self.user_id = user_id + self.device_id = device_id + self.database = database + self.display_name = display_name + self.last_seen_ts = last_seen_ts + self.last_seen_ip = last_seen_ip + self._verified = verified + self._blacklisted = blacklisted + self._ignored = ignored + self._ed25519 = ed25519_key + self._curve25519 = curve25519_key + + def get_info(self): + """Gets information on the device. + + The ``display_name``, ``last_seen_ip`` and ``last_seen_ts`` attribute will + get updated, if these were available. + + Returns: + True if successful, False if the device was not found. + """ + try: + info = self.api.get_device(self.device_id) + except MatrixRequestError as e: + if e.code == 404: + return False + raise + self.display_name = info.get('display_name') + self.last_seen_ip = info.get('last_seen_ip') + self.last_seen_ts = info.get('last_seen_ts') + return True + + def save_to_db(func): + def save(self, boolean): + if not self.ed25519: + raise ValueError('Changing this property is not allowed when the device ' + 'keys are unknown.') + func(self, boolean) + self.database.save_device_keys({self.user_id: {self.device_id: self}}) + return save + + @property + def ed25519(self): + return self._ed25519 + + @property + def curve25519(self): + return self._curve25519 + + @property + def verified(self): + return self._verified + + @verified.setter + @save_to_db + def verified(self, boolean): + self._verified = boolean + + @property + def ignored(self): + return self._ignored + + @ignored.setter + @save_to_db + def ignored(self, boolean): + self._ignored = boolean + + @property + def blacklisted(self): + return self._blacklisted + + @blacklisted.setter + @save_to_db + def blacklisted(self, boolean): + self._blacklisted = boolean diff --git a/matrix_client/errors.py b/matrix_client/errors.py index e9dc8fe3..91154bb8 100644 --- a/matrix_client/errors.py +++ b/matrix_client/errors.py @@ -46,3 +46,25 @@ def __init__(self, original_exception, method, endpoint): original_exception) ) self.original_exception = original_exception + + +class MatrixNoEncryptionError(MatrixError): + """Encryption was not available.""" + + def __init__(self, content=""): + super(MatrixNoEncryptionError, self).__init__(content) + + +class E2EUnknownDevices(Exception): + """The room contained unknown devices when sending a message. + + Args: + user_devices (dict): A map from user_id to a list of Device objects, + containing the unknown devices for that user. + """ + + def __init__(self, user_devices): + super(Exception, self).__init__( + "The room contains unknown devices which have not been verified. They can " + "be inspected via the 'user_devices' attribute of this exception.") + self.user_devices = user_devices diff --git a/matrix_client/room.py b/matrix_client/room.py index 488f335b..5ec57d7a 100644 --- a/matrix_client/room.py +++ b/matrix_client/room.py @@ -13,12 +13,15 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import logging import re from uuid import uuid4 from .checks import check_room_id from .user import User -from .errors import MatrixRequestError +from .errors import MatrixRequestError, MatrixNoEncryptionError + +logger = logging.getLogger(__name__) class Room(object): @@ -28,7 +31,7 @@ class Room(object): NOTE: This does not verify the room with the Home Server. """ - def __init__(self, client, room_id): + def __init__(self, client, room_id, verify_devices=False): check_room_id(room_id) self.room_id = room_id @@ -50,6 +53,9 @@ def __init__(self, client, room_id): # user_id: displayname } self.encrypted = False + self.rotation_period_msgs = None + self.rotation_period_ms = None + self.verify_devices = verify_devices def set_user_profile(self, displayname=None, @@ -102,7 +108,10 @@ def display_name(self): def send_text(self, text): """Send a plain text message to the room.""" - return self.client.api.send_message(self.room_id, text) + if self.encrypted and self.client._encryption: + return self.send_encrypted(self.client.api.get_text_body(text)) + else: + return self.client.api.send_message(self.room_id, text) def get_html_content(self, html, body=None, msgtype="m.text"): return { @@ -119,8 +128,12 @@ def send_html(self, html, body=None, msgtype="m.text"): html (str): The html formatted message to be sent. body (str): The unformatted body of the message to be sent. """ - return self.client.api.send_message_event( - self.room_id, "m.room.message", self.get_html_content(html, body, msgtype)) + content = self.get_html_content(html, body, msgtype) + if self.encrypted and self.client._encryption: + return self.send_encrypted(content) + else: + return self.client.api.send_message_event( + self.room_id, "m.room.message", content) def set_account_data(self, type, account_data): return self.client.api.set_room_account_data( @@ -142,9 +155,12 @@ def add_tag(self, tag, order=None, content=None): def send_emote(self, text): """Send an emote (/me style) message to the room.""" - return self.client.api.send_emote(self.room_id, text) + if self.encrypted and self.client._encryption: + return self.send_encrypted(self.client.api.get_emote_body(text)) + else: + return self.client.api.send_emote(self.room_id, text) - def send_file(self, url, name, **fileinfo): + def send_file(self, url, name, encryption_info=None, **fileinfo): """Send a pre-uploaded file to the room. See http://matrix.org/docs/spec/r0.2.0/client_server.html#m-file for @@ -155,15 +171,24 @@ def send_file(self, url, name, **fileinfo): name (str): The filename of the image. fileinfo (): Extra information about the file """ - - return self.client.api.send_content( - self.room_id, url, name, "m.file", - extra_information=fileinfo - ) + if self.encrypted and self.client._encryption: + content = self.client.api.get_content_body( + url, name, "m.file", extra_information=fileinfo, + encryption_info=encryption_info + ) + return self.send_encrypted(content) + else: + return self.client.api.send_content( + self.room_id, url, name, "m.file", + extra_information=fileinfo + ) def send_notice(self, text): """Send a notice (from bot) message to the room.""" - return self.client.api.send_notice(self.room_id, text) + if self.encrypted and self.client._encryption: + return self.send_encrypted(self.client.api.get_notice_body(text)) + else: + return self.client.api.send_notice(self.room_id, text) # See http://matrix.org/docs/spec/r0.0.1/client_server.html#m-image for the # imageinfo args. @@ -195,8 +220,13 @@ def send_location(self, geo_uri, name, thumb_url=None, **thumb_info): thumb_url (str): URL to the thumbnail of the location. thumb_info (): Metadata about the thumbnail, type ImageInfo. """ - return self.client.api.send_location(self.room_id, geo_uri, name, - thumb_url, thumb_info) + if self.encrypted and self.client._encryption: + content = self.client.api.get_location_body( + geo_uri, name, thumb_url, thumb_info) + return self.send_encrypted(content) + else: + return self.client.api.send_location(self.room_id, geo_uri, name, + thumb_url, thumb_info) def send_video(self, url, name, **videoinfo): """Send a pre-uploaded video to the room. @@ -226,6 +256,22 @@ def send_audio(self, url, name, **audioinfo): return self.client.api.send_content(self.room_id, url, name, "m.audio", extra_information=audioinfo) + def send_encrypted(self, content): + """Send an arbitrary encrypted message to the room. + + Args: + content (dict): The content of a m.room.message event. + + Raises: + ``MatrixNoEncryptionError`` if encryption is not enabled in client, or if + the room is unencrypted. + """ + if not self.client._encryption: + raise MatrixNoEncryptionError('Encryption is not enabled in client.') + if not self.encrypted: + raise MatrixNoEncryptionError('Encryption is not enabled in the room.') + return self.client.olm_device.send_encrypted_message(self, content) + def redact_message(self, event_id, reason=None): """Redacts the message with specified event_id for the given reason. @@ -296,6 +342,12 @@ def add_state_listener(self, callback, event_type=None): ) def _put_event(self, event): + if self.encrypted and self.client._encryption: + if event['type'] == 'm.room.encrypted': + try: + event = self.client.olm_device.megolm_decrypt_event(event) + except RuntimeError as e: + logger.warning(e) self.events.append(event) if len(self.events) > self.event_history_limit: self.events.pop(0) @@ -493,7 +545,7 @@ def _add_member(self, user_id, displayname=None): if user_id in self.client.users: self._members[user_id] = self.client.users[user_id] return - self._members[user_id] = User(self.client.api, user_id, displayname) + self._members[user_id] = User(self.client, user_id, displayname) self.client.users[user_id] = self._members[user_id] def backfill_previous_messages(self, reverse=False, limit=10): @@ -618,18 +670,35 @@ def set_guest_access(self, allow_guests): except MatrixRequestError: return False - def enable_encryption(self): + def enable_encryption(self, rotation_period_ms=None, rotation_period_msgs=None): """Enables encryption in the room. NOTE: Once enabled, encryption cannot be disabled. + Args: + rotation_period_ms (int): Optional. Lifetime of Megolm sessions in the room. + rotation_period_msgs (int): Optional. Number of messages that can be encrypted + by a Megolm session before it is rotated. + Returns: - True if successful, False if not + ``True`` if successful, ``False`` if not. + If the room was already encrypted, True is returned, but the + ``rotation_period_ms`` and ``rotation_period_ms`` parameters have no effect. """ + if self.encrypted: + return True + + content = {"algorithm": "m.megolm.v1.aes-sha2"} + if rotation_period_ms: + content["rotation_period_ms"] = rotation_period_ms + if rotation_period_msgs: + content["rotation_period_msgs"] = rotation_period_msgs + try: - self.send_state_event("m.room.encryption", - {"algorithm": "m.megolm.v1.aes-sha2"}) + self.send_state_event("m.room.encryption", content) self.encrypted = True + self.rotation_period_ms = rotation_period_ms + self.rotation_period_msgs = rotation_period_msgs return True except MatrixRequestError: return False @@ -658,13 +727,26 @@ def _process_state_event(self, state_event): elif etype == "m.room.encryption": if econtent.get("algorithm") == "m.megolm.v1.aes-sha2": self.encrypted = True + if not self.rotation_period_ms: + self.rotation_period_ms = econtent.get("rotation_period_ms") + if not self.rotation_period_msgs: + self.rotation_period_msgs = econtent.get("rotation_period_msgs") elif etype == "m.room.member" and clevel == clevel.ALL: # tracking room members can be large e.g. #matrix:matrix.org + user_id = state_event["state_key"] if econtent["membership"] == "join": - user_id = state_event["state_key"] self._add_member(user_id, econtent.get("displayname")) + if self.client._encryption and self.encrypted: + # Track the device list of this user + self.client.olm_device.device_list.track_user_no_download(user_id) elif econtent["membership"] in ("leave", "kick", "invite"): - self._members.pop(state_event["state_key"], None) + self._members.pop(user_id, None) + if econtent["membership"] != "invite": + if self.client._encryption and self.encrypted: + # Invalidate any outbound session we have in the room when + # someone leaves + self.client.olm_device.megolm_remove_outbound_session( + self.room_id) for listener in self.state_listeners: if ( diff --git a/matrix_client/user.py b/matrix_client/user.py index e56a89ef..2dfd0299 100644 --- a/matrix_client/user.py +++ b/matrix_client/user.py @@ -15,17 +15,19 @@ from warnings import warn from .checks import check_user_id +from .device import Device class User(object): """ The User class can be used to call user specific functions. """ - def __init__(self, api, user_id, displayname=None): + def __init__(self, client, user_id, displayname=None): check_user_id(user_id) self.user_id = user_id self.displayname = displayname - self.api = api + self.client = client + self._devices = {} def get_display_name(self, room=None): """Get this user's display name. @@ -43,7 +45,7 @@ def get_display_name(self, room=None): except KeyError: return self.user_id if not self.displayname: - self.displayname = self.api.get_display_name(self.user_id) + self.displayname = self.client.api.get_display_name(self.user_id) return self.displayname or self.user_id def get_friendly_name(self): @@ -59,13 +61,13 @@ def set_display_name(self, display_name): display_name (str): Display Name """ self.displayname = display_name - return self.api.set_display_name(self.user_id, display_name) + return self.client.api.set_display_name(self.user_id, display_name) def get_avatar_url(self): - mxcurl = self.api.get_avatar_url(self.user_id) + mxcurl = self.client.api.get_avatar_url(self.user_id) url = None if mxcurl is not None: - url = self.api.get_download_url(mxcurl) + url = self.client.api.get_download_url(mxcurl) return url def set_avatar_url(self, avatar_url): @@ -74,4 +76,33 @@ def set_avatar_url(self, avatar_url): Args: avatar_url (str): mxc url from previously uploaded """ - return self.api.set_avatar_url(self.user_id, avatar_url) + return self.client.api.set_avatar_url(self.user_id, avatar_url) + + @property + def devices(self): + # If this user is joined in an encrypted room with us, we may already have an + # up-to-date list of their devices. + if self.client._encryption and \ + self.user_id in self.client.olm_device.device_list.tracked_user_ids: + + if self.user_id not in self.client.device_keys: + self.client.db.get_device_keys( + self.client.api, {self.user_id: []}, self.client.device_keys + ) + self._devices = self.client.device_keys[self.user_id] + else: + devices = self.client.api.query_keys({self.user_id: []})["device_keys"] + for device_id in devices: + if device_id not in self._devices: + # Do not add the keys even if they are in the payload, because + # we are not able to verify them right know. This means that device + # verification will only become available once we share an encrypted + # room with this user. + self._devices[device_id] = Device(self.client.api, device_id) + + for device in self._devices: + device.get_info() + + # Returning a copy prevents adding/removing devices while allowing to verify or + # blacklist them. + return self._devices.copy() diff --git a/setup.py b/setup.py index 6049f397..6b809d0b 100644 --- a/setup.py +++ b/setup.py @@ -54,7 +54,8 @@ def exec_file(names): 'test': ['pytest', 'responses'], 'doc': ['Sphinx==1.4.6', 'sphinx-rtd-theme==0.1.9', 'sphinxcontrib-napoleon==0.5.3'], 'format': ['flake8'], - 'e2e': ['python-olm==dev', 'canonicaljson'] + 'e2e': ['python-olm==dev', 'canonicaljson', 'appdirs', 'unpaddedbase64', + 'pycrypto'] }, dependency_links=[ 'git+https://github.com/poljar/python-olm.git@4752eb22f005cb9f6143857008572e6d83252841#egg=python-olm-dev' diff --git a/test/client_test.py b/test/client_test.py index c5884924..472ad195 100644 --- a/test/client_test.py +++ b/test/client_test.py @@ -1,16 +1,19 @@ import pytest import responses import json +import matrix_client.client from copy import deepcopy from matrix_client.client import MatrixClient, Room, User, CACHE from matrix_client.api import MATRIX_V2_API_PATH from . import response_examples +from .crypto.dummy_olm_device import OlmDevice try: from urllib import quote except ImportError: from urllib.parse import quote HOSTNAME = "http://example.com" +matrix_client.client.OlmDevice = OlmDevice def test_create_client(): @@ -155,13 +158,30 @@ def test_state_event(): # test encryption room.encrypted = False ev["type"] = "m.room.encryption" - ev["content"] = {"algorithm": "m.megolm.v1.aes-sha2"} + ev["content"] = { + "algorithm": "m.megolm.v1.aes-sha2", + "rotation_period_msgs": 50, + "rotation_period_ms": 100000, + } room._process_state_event(ev) assert room.encrypted + assert room.rotation_period_ms == 100000 + assert room.rotation_period_msgs == 50 # encrypted flag must not be cleared on configuration change ev["content"] = {"algorithm": None} room._process_state_event(ev) assert room.encrypted + assert room.rotation_period_ms == 100000 + assert room.rotation_period_msgs == 50 + # nor should the session parameters be changed + ev["content"] = { + "algorithm": "m.megolm.v1.aes-sha2", + "rotation_period_msgs": 5, + "rotation_period_ms": 10000, + } + room._process_state_event(ev) + assert room.rotation_period_ms == 100000 + assert room.rotation_period_msgs == 50 def test_get_user(): @@ -519,6 +539,15 @@ def test_enable_encryption_in_room(): assert room.enable_encryption() assert room.encrypted + room = client._mkroom(room_id) + assert room.enable_encryption(rotation_period_msgs=1, rotation_period_ms=1) + assert room.rotation_period_msgs == 1 + assert room.rotation_period_ms == 1 + + assert room.enable_encryption(rotation_period_msgs=2) + # The room was already encrypted, we should not have changed its attribute + assert room.rotation_period_msgs == 1 + @responses.activate def test_detect_encryption_state(): @@ -542,6 +571,7 @@ def test_detect_encryption_state(): @responses.activate def test_one_time_keys_sync(): client = MatrixClient(HOSTNAME, encryption=True) + client.first_sync = False sync_url = HOSTNAME + MATRIX_V2_API_PATH + "/sync" sync_response = deepcopy(response_examples.example_sync) payload = {'dummy': 1} diff --git a/test/crypto/__init__.py b/test/crypto/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/test/crypto/crypto_store_test.py b/test/crypto/crypto_store_test.py new file mode 100644 index 00000000..42bc859d --- /dev/null +++ b/test/crypto/crypto_store_test.py @@ -0,0 +1,359 @@ +import pytest +olm = pytest.importorskip("olm") # noqa + +import os +from collections import defaultdict +from tempfile import mkdtemp + +from matrix_client.crypto.crypto_store import CryptoStore +from matrix_client.crypto.olm_device import OlmDevice +from matrix_client.crypto.sessions import MegolmOutboundSession, MegolmInboundSession +from matrix_client.room import Room +from matrix_client.user import User + + +class TestCryptoStore(object): + + # Initialise a store and test some init code + device_id = 'AUIETSRN' + user_id = '@user:matrix.org' + room_id = '!test:example.com' + room = Room(None, room_id) + user = User(None, user_id, '') + room._members[user_id] = user + db_name = 'test.db' + db_path = mkdtemp() + store_conf = { + 'db_name': db_name, + 'db_path': db_path + } + store = CryptoStore( + user_id, device_id=device_id, db_path=db_path, db_name=db_name) + db_filepath = os.path.join(db_path, db_name) + assert os.path.exists(db_filepath) + store.close() + store = CryptoStore( + user_id, device_id=device_id, db_path=db_path, db_name=db_name) + + @pytest.fixture(autouse=True, scope='class') + def cleanup(self): + yield + os.remove(self.db_filepath) + + @pytest.fixture() + def account(self): + account = self.store.get_olm_account() + if account is None: + account = olm.Account() + self.store.save_olm_account(account) + return account + + @pytest.fixture() + def curve_key(self, account): + return account.identity_keys['curve25519'] + + @pytest.fixture() + def ed_key(self, account): + return account.identity_keys['ed25519'] + + @pytest.fixture() + def device(self): + return OlmDevice(None, self.user_id, self.device_id, store_conf=self.store_conf) + + def test_olm_account_persistence(self): + account = olm.Account() + identity_keys = account.identity_keys + self.store.remove_olm_account() + + # Try to load inexisting account + saved_account = self.store.get_olm_account() + assert saved_account is None + + # Try to load inexisting account without device_id + self.store.device_id = None + with pytest.raises(ValueError): + self.store.get_olm_account() + self.store.device_id = self.device_id + + # Save and load + self.store.save_olm_account(account) + saved_account = self.store.get_olm_account() + assert saved_account.identity_keys == identity_keys + + # Save and load without device_id + self.store.save_olm_account(account) + self.store.device_id = None + saved_account = self.store.get_olm_account() + assert saved_account.identity_keys == identity_keys + assert self.store.device_id == self.device_id + + # Replace the account, causing foreign keys to be deleted + self.store.save_sync_token('test') + self.store.replace_olm_account(account) + assert self.store.get_sync_token() is None + + # Load the account from an OlmDevice + device = OlmDevice(None, self.user_id, self.device_id, store_conf=self.store_conf) + assert device.olm_account.identity_keys == account.identity_keys + + # Load the account from an OlmDevice, without device_id + device = OlmDevice(None, self.user_id, store_conf=self.store_conf) + assert device.device_id == self.device_id + + def test_olm_sessions_persistence(self, account, curve_key, device): + session = olm.OutboundSession(account, curve_key, curve_key) + sessions = defaultdict(list) + + self.store.load_olm_sessions(sessions) + assert not sessions + assert not self.store.get_olm_sessions(curve_key) + + self.store.save_olm_session(curve_key, session) + self.store.load_olm_sessions(sessions) + assert sessions[curve_key][0].id == session.id + + saved_sessions = self.store.get_olm_sessions(curve_key) + assert saved_sessions[0].id == session.id + + sessions.clear() + saved_sessions = self.store.get_olm_sessions(curve_key, sessions) + assert sessions[curve_key][0].id == session.id + + # Replace the session when its internal state has changed + pickle = session.pickle() + session.encrypt('test') + self.store.save_olm_session(curve_key, session) + saved_sessions = self.store.get_olm_sessions(curve_key) + assert saved_sessions[0].pickle != pickle + + # Load sessions dynamically + assert not device.olm_sessions + with pytest.raises(AttributeError): + device._olm_decrypt(None, curve_key) + assert device.olm_sessions[curve_key][0].id == session.id + + device.olm_sessions.clear() + device.device_keys[self.user_id][self.device_id] = device + device.olm_ensure_sessions({self.user_id: [self.device_id]}) + assert device.olm_sessions[curve_key][0].id == session.id + + # Test cascade deletion + self.store.remove_olm_account() + assert not self.store.get_olm_sessions(curve_key) + + def test_megolm_inbound_persistence(self, curve_key, ed_key, device): + out_session = olm.OutboundGroupSession() + session = MegolmInboundSession(out_session.session_key, ed_key) + sessions = defaultdict(lambda: defaultdict(dict)) + + self.store.load_inbound_sessions(sessions) + assert not sessions + assert not self.store.get_inbound_session(self.room_id, curve_key, session.id) + + self.store.save_inbound_session(self.room_id, curve_key, session) + self.store.load_inbound_sessions(sessions) + assert sessions[self.room_id][curve_key][session.id].id == session.id + + saved_session = self.store.get_inbound_session(self.room_id, curve_key, + session.id) + assert saved_session.id == session.id + + sessions = {} + saved_session = self.store.get_inbound_session(self.room_id, curve_key, + session.id, sessions) + assert sessions[session.id].id == session.id + + assert not device.megolm_inbound_sessions + created = device.megolm_add_inbound_session( + self.room_id, curve_key, ed_key, session.id, out_session.session_key) + assert not created + assert device.megolm_inbound_sessions[self.room_id][curve_key][session.id].id == \ + session.id + + device.megolm_inbound_sessions.clear() + content = { + 'sender_key': curve_key, + 'session_id': session.id, + 'algorithm': device._megolm_algorithm, + 'device_id': '' + } + event = { + 'sender': '', + 'room_id': self.room_id, + 'content': content + } + with pytest.raises(KeyError): + device.megolm_decrypt_event(event) + assert device.megolm_inbound_sessions[self.room_id][curve_key][session.id].id == \ + session.id + + self.store.remove_olm_account() + assert not self.store.get_inbound_session(self.room_id, curve_key, session.id) + + @pytest.mark.usefixtures('account') + def test_megolm_outbound_persistence(self, device): + session = MegolmOutboundSession(max_messages=2, max_age=100000) + session.message_count = 1 + session.add_device(self.device_id) + sessions = {} + + self.store.load_outbound_sessions(sessions) + assert not sessions + assert not self.store.get_outbound_session(self.room_id) + + self.store.save_outbound_session(self.room_id, session) + self.store.save_megolm_outbound_devices(self.room_id, {self.device_id}) + self.store.load_outbound_sessions(sessions) + assert sessions[self.room_id].id == session.id + assert sessions[self.room_id].devices == session.devices + assert sessions[self.room_id].creation_time == session.creation_time + assert sessions[self.room_id].max_messages == session.max_messages + assert sessions[self.room_id].message_count == session.message_count + assert sessions[self.room_id].max_age == session.max_age + + saved_session = self.store.get_outbound_session(self.room_id) + assert saved_session.id == session.id + assert saved_session.devices == session.devices + assert saved_session.creation_time == session.creation_time + assert saved_session.max_messages == session.max_messages + assert saved_session.message_count == session.message_count + assert saved_session.max_age == session.max_age + + sessions.clear() + saved_session = self.store.get_outbound_session(self.room_id, sessions) + assert sessions[self.room_id].id == session.id + + self.store.remove_outbound_session(self.room_id) + assert not self.store.get_outbound_session(self.room_id) + + self.store.save_outbound_session(self.room_id, session) + saved_session = self.store.get_outbound_session(self.room_id) + # Verify the saved devices have been erased with the session + assert not saved_session.devices + + room = Room(None, self.room_id) + with pytest.raises(AttributeError): + device.megolm_build_encrypted_event(room, {}) + assert device.megolm_outbound_sessions[self.room_id].id == session.id + + self.store.remove_olm_account() + assert not self.store.get_outbound_session(self.room_id) + + @pytest.mark.usefixtures('account') + def test_device_keys_persistence(self, device): + user_devices = {self.user_id: [self.device_id]} + device_keys = defaultdict(dict) + device._verified = True + + self.store.load_device_keys(None, device_keys) + assert not device_keys + assert not self.store.get_device_keys(None, user_devices, device_keys) + assert not device_keys + + device_keys_to_save = {self.user_id: {self.device_id: device}} + self.store.save_device_keys(device_keys_to_save) + self.store.load_device_keys(None, device_keys) + assert device_keys[self.user_id][self.device_id].curve25519 == \ + device.curve25519 + assert device_keys[self.user_id][self.device_id].verified + + device_keys.clear() + devices = self.store.get_device_keys(None, user_devices)[self.user_id] + assert devices[self.device_id].curve25519 == device.curve25519 + assert self.store.get_device_keys(None, user_devices, device_keys) + assert device_keys[self.user_id][self.device_id].curve25519 == \ + device.curve25519 + assert device_keys[self.user_id][self.device_id].verified + + # Test device verification persistence + device.verified = False + device.ignored = True + devices = self.store.get_device_keys(None, user_devices)[self.user_id] + assert not devices[self.device_id].verified + assert devices[self.device_id].ignored + + # Test [] wildcard + devices = self.store.get_device_keys(None, {self.user_id: []})[self.user_id] + assert devices[self.device_id].curve25519 == device.curve25519 + + device.device_list.tracked_user_ids = {self.user_id} + device.device_list.get_room_device_keys(self.room) + assert device_keys[self.user_id][self.device_id].curve25519 == \ + device.curve25519 + + # Test multiples [] + device_keys.clear() + user_id = 'test' + device_id = 'test' + device_keys_to_save[user_id] = {device_id: device} + self.store.save_device_keys(device_keys_to_save) + user_devices[user_id] = [] + user_devices[self.user_id] = [] + device_keys = self.store.get_device_keys(None, user_devices) + assert device_keys[self.user_id][self.device_id].curve25519 == device.curve25519 + assert device_keys[user_id][device_id].curve25519 == device.curve25519 + + # Try to verify a device that has no keys + device._ed25519 = None + with pytest.raises(ValueError): + device.verified = False + + self.store.remove_olm_account() + assert not self.store.get_device_keys(None, user_devices) + + @pytest.mark.usefixtures('account') + def test_tracked_users_persistence(self): + tracked_user_ids = set() + tracked_user_ids_to_save = {self.user_id} + + self.store.load_tracked_users(tracked_user_ids) + assert not tracked_user_ids + + self.store.save_tracked_users(tracked_user_ids_to_save) + self.store.load_tracked_users(tracked_user_ids) + assert tracked_user_ids == tracked_user_ids_to_save + + self.store.remove_tracked_users({self.user_id}) + tracked_user_ids.clear() + self.store.load_tracked_users(tracked_user_ids) + assert not tracked_user_ids + + @pytest.mark.usefixtures('account') + def test_sync_token_persistence(self): + sync_token = 'test' + + assert not self.store.get_sync_token() + + self.store.save_sync_token(sync_token) + assert self.store.get_sync_token() == sync_token + + sync_token = 'new' + self.store.save_sync_token(sync_token) + assert self.store.get_sync_token() == sync_token + + def test_load_all(self, account, curve_key, ed_key, device): + curve_key = account.identity_keys['curve25519'] + session = olm.OutboundSession(account, curve_key, curve_key) + out_session = MegolmOutboundSession() + out_session.add_device(self.device_id) + in_session = MegolmInboundSession(out_session.session_key, ed_key) + device_keys_to_save = {self.user_id: {self.device_id: device}} + + self.store.save_inbound_session(self.room_id, curve_key, in_session) + self.store.save_olm_session(curve_key, session) + self.store.save_outbound_session(self.room_id, out_session) + self.store.save_megolm_outbound_devices(self.room_id, {self.device_id}) + self.store.save_device_keys(device_keys_to_save) + + device = OlmDevice( + None, self.user_id, self.device_id, store_conf=self.store_conf, load_all=True) + + assert session.id in {s.id for s in device.olm_sessions[curve_key]} + saved_in_session = \ + device.megolm_inbound_sessions[self.room_id][curve_key][in_session.id] + assert saved_in_session.id == in_session.id + saved_out_session = device.megolm_outbound_sessions[self.room_id] + assert saved_out_session.id == out_session.id + assert saved_out_session.devices == out_session.devices + assert device.device_keys[self.user_id][self.device_id].curve25519 == \ + device.curve25519 diff --git a/test/crypto/device_list_test.py b/test/crypto/device_list_test.py new file mode 100644 index 00000000..bd0bb6df --- /dev/null +++ b/test/crypto/device_list_test.py @@ -0,0 +1,312 @@ +import pytest +pytest.importorskip("olm") # noqa + +import json +from copy import deepcopy +from threading import Event, Condition + +import responses + +from matrix_client.api import MATRIX_V2_API_PATH +from matrix_client.client import MatrixClient +from matrix_client.device import Device +from matrix_client.room import User +from matrix_client.errors import MatrixRequestError +from matrix_client.crypto.device_list import (_OutdatedUsersSet as OutdatedUsersSet, + _UpdateDeviceList as UpdateDeviceList) +from test.crypto.dummy_olm_device import OlmDevice, DummyStore +from test.response_examples import example_key_query_response + +HOSTNAME = 'http://example.com' + + +class TestDeviceList: + cli = MatrixClient(HOSTNAME) + user_id = '@test:example.com' + alice = '@alice:example.com' + room_id = '!test:example.com' + device_id = 'AUIETSRN' + device = OlmDevice(cli.api, user_id, device_id) + device_list = device.device_list + signing_key = device.olm_account.identity_keys['ed25519'] + query_url = HOSTNAME + MATRIX_V2_API_PATH + '/keys/query' + + @responses.activate + def test_download_device_keys(self): + # The method we want to test + download_device_keys = self.device_list._download_device_keys + bob = '@bob:example.com' + eve = '@eve:example.com' + user_devices = {self.alice: [], bob: [], self.user_id: []} + + # This response is correct for Alice's keys, but lacks Bob's + # There are no failures + resp = example_key_query_response + responses.add(responses.POST, self.query_url, json=resp) + + # Still correct, but Alice's identity key has changed + resp = deepcopy(example_key_query_response) + new_id_key = 'ijxGZqwB/UvMtKABdaCdrI0OtQI6NhHBYiknoCkdWng' + payload = resp['device_keys'][self.alice]['JLAFKJWSCS'] + payload['keys']['curve25519:JLAFKJWSCS'] = new_id_key + payload['signatures'][self.alice]['ed25519:JLAFKJWSCS'] = \ + ('D9oLtYefMIr4StiHTIzn3+bhtPCfrZNDU9jsUbMu3MicfZLl4d8WlYn3TPmbwDi8XMGcT' + 'nNnqfdi/tYUPvKfCA') + responses.add(responses.POST, self.query_url, json=resp) + + # Still correct, but Alice's signing key has changed + alice_device = OlmDevice(self.cli.api, self.alice, 'JLAFKJWSCS') + resp = deepcopy(example_key_query_response) + resp['device_keys'][self.alice]['JLAFKJWSCS']['keys']['ed25519:JLAFKJWSCS'] = \ + alice_device.ed25519 + resp['device_keys'][self.alice]['JLAFKJWSCS'] = \ + alice_device.sign_json(resp['device_keys'][self.alice]['JLAFKJWSCS']) + responses.add(responses.POST, self.query_url, json=resp) + + # Response containing an unknown user + resp = deepcopy(example_key_query_response) + user_device = resp['device_keys'].pop(self.alice) + resp['device_keys'][eve] = user_device + responses.add(responses.POST, self.query_url, json=resp) + + # Response with an invalid signature + resp = deepcopy(example_key_query_response) + resp['device_keys'][self.alice]['JLAFKJWSCS']['test'] = 1 + responses.add(responses.POST, self.query_url, json=resp) + + # Response with a requested user and valid signature, but with a mismatch + resp = deepcopy(example_key_query_response) + user_device = resp['device_keys'].pop(self.alice) + resp['device_keys'][bob] = user_device + responses.add(responses.POST, self.query_url, json=resp) + + # Response with an invalid keys field + resp = deepcopy(example_key_query_response) + keys_field = resp['device_keys'][self.alice]['JLAFKJWSCS']['keys'] + key = keys_field.pop("ed25519:JLAFKJWSCS") + keys_field["ed25519:wrong"] = key + # Cover a missing branch by adding failures + resp["failures"]["other.com"] = {} + # And one more by adding ourself + resp['device_keys'][self.user_id] = {self.device_id: 'dummy'} + responses.add(responses.POST, self.query_url, json=resp) + + self.device.device_keys.clear() + assert download_device_keys(user_devices) + req = json.loads(responses.calls[0].request.body) + assert req['device_keys'] == {self.alice: [], bob: [], self.user_id: []} + device = Device(self.cli.api, self.alice, 'JLAFKJWSCS', database=DummyStore, + curve25519_key='3C5BFWi2Y8MaVvjM8M22DBmh24PmgR0nPvJOIArzgyI', + ed25519_key='VzJIYXQ85u19z2ZpEeLLVu8hUKTCE0VXYUn4IY4iFcA') + expected_device_keys = { + self.alice: { + 'JLAFKJWSCS': device + } + } + assert self.device.device_keys[self.alice]['JLAFKJWSCS'].curve25519 == \ + device.curve25519 + + # Different curve25519, key should get updated + assert download_device_keys(user_devices) + expected_device_keys[self.alice]['JLAFKJWSCS']._curve25519 = new_id_key + assert self.device.device_keys[self.alice]['JLAFKJWSCS'].curve25519 == \ + device.curve25519 + + # Different ed25519, key should not get updated + assert not download_device_keys(user_devices) + assert self.device.device_keys[self.alice]['JLAFKJWSCS'].ed25519 == \ + device.ed25519 + + self.device.device_keys.clear() + # All the remaining responses are wrong and we should not add the key + for _ in range(4): + assert not download_device_keys(user_devices) + assert self.device.device_keys == {} + + assert len(responses.calls) == 7 + + @responses.activate + def test_update_thread(self): + # Normal run + event = Event() + outdated_users = OutdatedUsersSet({self.user_id}) + outdated_users.events.add(event) + + def dummy_download(user_devices, since_token=None): + assert user_devices == {self.user_id: []} + return + thread = UpdateDeviceList(Condition(), outdated_users, dummy_download, set(), + DummyStore()) + + thread.start() + event.wait() + assert not thread.outdated_user_ids + assert thread.event.is_set() + assert thread.tracked_user_ids == {self.user_id} + thread.join() + assert not thread.is_alive() + + # Error run + outdated_users = OutdatedUsersSet({self.user_id}) + + def error_on_first_download(user_devices, since_token=None): + error_on_first_download.c += 1 + if error_on_first_download.c == 1: + raise MatrixRequestError + return + error_on_first_download.c = 0 + thread = UpdateDeviceList( + Condition(), outdated_users, error_on_first_download, set(), DummyStore()) + thread.start() + thread.event.wait() + assert error_on_first_download.c == 2 + assert not thread.outdated_user_ids + thread.join() + + # Cover a missing branch + thread = UpdateDeviceList( + Condition(), outdated_users, error_on_first_download, set(), DummyStore()) + thread._should_terminate.set() + thread.start() + thread.join() + assert not thread.is_alive() + + @responses.activate + def test_get_room_device_keys(self): + self.device_list.tracked_user_ids.clear() + room = self.cli._mkroom(self.room_id) + room._members[self.alice] = User(self.cli.api, self.alice) + + responses.add(responses.POST, self.query_url, json=example_key_query_response) + + # Blocking + self.device_list.get_room_device_keys(room) + assert self.device_list.tracked_user_ids == {self.alice} + assert self.device_list.device_keys[self.alice]['JLAFKJWSCS'] + + # Same, but we already track the user + self.device_list.get_room_device_keys(room) + + # Non-blocking + self.device_list.tracked_user_ids.clear() + # We have to block for testing purposes, though + self.device_list.update_thread.event.clear() + self.device_list.get_room_device_keys(room, blocking=False) + self.device_list.update_thread.event.wait() + + # Same, but we already track the user + self.device_list.get_room_device_keys(room, blocking=False) + + @responses.activate + def test_track_users(self): + self.device_list.tracked_user_ids.clear() + responses.add(responses.POST, self.query_url, json=example_key_query_response) + + self.device_list.update_thread.event.clear() + self.device_list.track_users({self.alice}) + self.device_list.update_thread.event.wait() + assert self.device_list.tracked_user_ids == {self.alice} + assert len(responses.calls) == 1 + + # Same, but we are already tracking Alice + self.device_list.track_users({self.alice}) + assert len(responses.calls) == 1 + + def test_stop_tracking_users(self): + self.device_list.tracked_user_ids.clear() + self.device_list.tracked_user_ids.add(self.alice) + self.device_list.outdated_user_ids.clear() + self.device_list.outdated_user_ids.add(self.alice) + + self.device_list.stop_tracking_users({self.alice}) + + assert not self.device_list.tracked_user_ids + assert not self.device_list.outdated_user_ids + + def test_pending_users(self): + # Say Alice is already tracked to avoid triggering dowload process + self.device_list.tracked_user_ids.add(self.alice) + + self.device_list.track_user_no_download(self.alice) + assert self.alice in self.device_list.pending_outdated_user_ids + + self.device_list.track_pending_users() + assert self.alice not in self.device_list.pending_outdated_user_ids + + @responses.activate + def test_update_user_device_keys(self): + self.device_list.tracked_user_ids.clear() + responses.add(responses.POST, self.query_url, json=example_key_query_response) + + self.device_list.update_user_device_keys({self.alice}) + assert len(responses.calls) == 0 + + self.device_list.tracked_user_ids.add(self.alice) + + self.device_list.update_thread.event.clear() + self.device_list.update_user_device_keys({self.alice}, since_token='dummy') + self.device_list.update_thread.event.wait() + assert len(responses.calls) == 1 + + @responses.activate + def test_update_after_restart(self): + keys_changes_url = HOSTNAME + MATRIX_V2_API_PATH + '/keys/changes' + + class DB(DummyStore): + + def __getattribute__(self, name): + if name == 'get_sync_token': + return lambda: 'test' + return super(DB, self).__getattribute__(name) + db = self.device_list.db + + # First launch, no sync token + self.device_list.update_after_restart('test') + + self.device_list.db = DB() + responses.add(responses.GET, keys_changes_url, json={}) + self.device_list.update_after_restart('test') + + resp = {'left': 'test', 'changed': self.user_id} + responses.replace(responses.GET, keys_changes_url, json=resp) + self.device_list.tracked_user_ids.clear() + self.device_list.update_after_restart('test') + self.device_list.db = db + + +def test_outdated_users_set(): + s = OutdatedUsersSet() + assert not s + + s = OutdatedUsersSet({1}) + event = Event() + s.events.add(event) + assert s == {1} + + # Make a manual copy of s + t = OutdatedUsersSet() + t.add(1) + t.events.add(event) + assert t == s and t.events == s.events + + u = s.copy() + event2 = Event() + u.add(2) + u.events.add(event2) + # Check that modifying u didn't change s + assert t == s and t.events == s.events + + s.update(u) + assert s == {1, 2} and s.events == {event, event2} + + s.mark_as_processed() + assert event.is_set() + + new = 's72594_4483_1935' + s.sync_token = new + old = 's72594_4483_1934' + s.sync_token = old + assert s.sync_token == new + + s.clear() + assert not s and not s.events diff --git a/test/crypto/dummy_olm_device.py b/test/crypto/dummy_olm_device.py new file mode 100644 index 00000000..83c50c6c --- /dev/null +++ b/test/crypto/dummy_olm_device.py @@ -0,0 +1,21 @@ +"""Tests can import OlmDevice from here, and know it won't try to use a database.""" + +from matrix_client.crypto.crypto_store import CryptoStore +from matrix_client.crypto.olm_device import OlmDevice as BaseOlmDevice + + +class DummyStore(CryptoStore): + def __init__(*args, **kw): pass + + def nop(*args, **kw): pass + + def __getattribute__(self, name): + if name in dir(CryptoStore): + return object.__getattribute__(self, 'nop') + raise AttributeError + + +class OlmDevice(BaseOlmDevice): + + def __init__(self, *args, **kw): + super(OlmDevice, self).__init__(*args, Store=DummyStore, **kw) diff --git a/test/crypto/encrypted_attachments_test.py b/test/crypto/encrypted_attachments_test.py new file mode 100644 index 00000000..ee13cf9f --- /dev/null +++ b/test/crypto/encrypted_attachments_test.py @@ -0,0 +1,15 @@ +import pytest +pytest.importorskip('olm') # noqa + +from matrix_client.crypto.encrypt_attachments import (encrypt_attachment, + decrypt_attachment) + + +def test_encrypt_decrypt(): + message = b'test' + ciphertext, info = encrypt_attachment(message) + assert decrypt_attachment(ciphertext, info) == message + + ciphertext += b'\x00' + with pytest.raises(RuntimeError): + decrypt_attachment(ciphertext, info) diff --git a/test/crypto/olm_device_test.py b/test/crypto/olm_device_test.py index 84f56df6..a315549d 100644 --- a/test/crypto/olm_device_test.py +++ b/test/crypto/olm_device_test.py @@ -1,15 +1,29 @@ import pytest -pytest.importorskip("olm") # noqa +olm = pytest.importorskip("olm") # noqa import json +import logging from copy import deepcopy +from datetime import timedelta, datetime +try: + from urllib import quote +except ImportError: + from urllib.parse import quote import responses +from matrix_client.crypto import olm_device +from matrix_client.crypto.verified_event import VerifiedEvent from matrix_client.api import MATRIX_V2_API_PATH from matrix_client.client import MatrixClient -from matrix_client.crypto.olm_device import OlmDevice -from test.response_examples import example_key_upload_response +from matrix_client.user import User +from matrix_client.device import Device +from matrix_client.errors import E2EUnknownDevices +from test.crypto.dummy_olm_device import OlmDevice, DummyStore +from matrix_client.crypto.sessions import MegolmOutboundSession, MegolmInboundSession +from test.response_examples import (example_key_upload_response, + example_claim_keys_response, + example_room_key_event) HOSTNAME = 'http://example.com' @@ -17,9 +31,22 @@ class TestOlmDevice: cli = MatrixClient(HOSTNAME) user_id = '@user:matrix.org' + room_id = '!test:example.com' device_id = 'QBUAZIFURK' device = OlmDevice(cli.api, user_id, device_id) signing_key = device.olm_account.identity_keys['ed25519'] + alice = '@alice:example.com' + alice_device_id = 'JLAFKJWSCS' + alice_curve_key = 'mmFRSHuJVq3aTudx3KB3w5ZvSFQhgEcy8d+m+vkEfUQ' + alice_ed_key = '4VjV3OhFUxWFAcO5YOaQVmTIn29JdRmtNh9iAxoyhkc' + alice_device = Device(cli.api, alice, alice_device_id, database=DummyStore(), + curve25519_key=alice_curve_key, ed25519_key=alice_ed_key) + alice_olm_session = olm.OutboundSession( + device.olm_account, alice_curve_key, alice_curve_key) + room = cli._mkroom(room_id) + room._members[alice] = User(cli.api, alice) + # allow to_device api call to work well with responses + device.api._make_txn_id = lambda: 1 def test_sign_json(self): example_payload = { @@ -204,3 +231,546 @@ def test_invalid_keys_threshold(self, threshold): self.user_id, self.device_id, keys_threshold=threshold) + + @responses.activate + def test_olm_start_sessions(self): + claim_url = HOSTNAME + MATRIX_V2_API_PATH + '/keys/claim' + responses.add(responses.POST, claim_url, json=example_claim_keys_response) + self.device.olm_sessions.clear() + self.device.device_keys.clear() + + user_devices = {self.alice: {self.alice_device_id}} + + # We don't have alice's keys + self.device.olm_start_sessions(user_devices) + assert not self.device.olm_sessions[self.alice_curve_key] + + # Cover logging part + olm_device.logger.setLevel(logging.WARNING) + # Now should be good + self.device.device_keys[self.alice][self.alice_device_id] = self.alice_device + self.device.olm_start_sessions(user_devices) + assert self.device.olm_sessions[self.alice_curve_key] + + # With failures and wrong signature + self.device.olm_sessions.clear() + payload = deepcopy(example_claim_keys_response) + payload['failures'] = {'dummy': 1} + key = payload['one_time_keys'][self.alice][self.alice_device_id] + key['signed_curve25519:AAAAAQ']['test'] = 1 + responses.replace(responses.POST, claim_url, json=payload) + + self.device.olm_start_sessions(user_devices) + assert not self.device.olm_sessions[self.alice_curve_key] + + # Missing requested user and devices + user_devices[self.alice].add('test') + user_devices['test'] = 'test' + + self.device.olm_start_sessions(user_devices) + + @responses.activate + def test_olm_build_encrypted_event(self): + self.device.device_keys.clear() + self.device.olm_sessions.clear() + event_content = {'dummy': 'example'} + + # We don't have Alice's keys + with pytest.raises(RuntimeError): + self.device.olm_build_encrypted_event( + 'm.text', event_content, self.alice, self.alice_device_id) + + # We don't have a session with Alice + self.device.device_keys[self.alice][self.alice_device_id] = self.alice_device + with pytest.raises(RuntimeError): + self.device.olm_build_encrypted_event( + 'm.text', event_content, self.alice, self.alice_device_id) + + claim_url = HOSTNAME + MATRIX_V2_API_PATH + '/keys/claim' + responses.add(responses.POST, claim_url, json=example_claim_keys_response) + user_devices = {self.alice: {self.alice_device_id}} + self.device.olm_start_sessions(user_devices) + assert self.device.olm_build_encrypted_event( + 'm.text', event_content, self.alice, self.alice_device_id) + + def test_olm_decrypt(self): + self.device.olm_sessions.clear() + # Since this method doesn't care about high-level event formatting, we will + # generate things at low level + our_account = self.device.olm_account + # Alice needs to start a session with us + alice = olm.Account() + sender_key = alice.identity_keys['curve25519'] + our_account.generate_one_time_keys(1) + otk = next(iter(our_account.one_time_keys['curve25519'].values())) + self.device.olm_account.mark_keys_as_published() + session = olm.OutboundSession(alice, our_account.identity_keys['curve25519'], otk) + + plaintext = {"test": "test"} + message = session.encrypt(json.dumps(plaintext)) + assert self.device._olm_decrypt(message, sender_key) == plaintext + + # New pre-key message, but the session exists this time + message = session.encrypt(json.dumps(plaintext)) + assert self.device._olm_decrypt(message, sender_key) == plaintext + + # Try to decrypt the same message twice + with pytest.raises(RuntimeError): + self.device._olm_decrypt(message, sender_key) + + # Answer Alice in order to have a type 1 message + message = self.device.olm_sessions[sender_key][0].encrypt(json.dumps(plaintext)) + session.decrypt(message) + message = session.encrypt(json.dumps(plaintext)) + assert self.device._olm_decrypt(message, sender_key) == plaintext + + # Try to decrypt the same message type 1 twice + with pytest.raises(RuntimeError): + self.device._olm_decrypt(message, sender_key) + + # Try to decrypt a message from a session that reused a one-time key + otk_reused_session = olm.OutboundSession( + alice, our_account.identity_keys['curve25519'], otk) + message = otk_reused_session.encrypt(json.dumps(plaintext)) + with pytest.raises(RuntimeError): + self.device._olm_decrypt(message, sender_key) + + # Try to decrypt an invalid type 0 message + our_account.generate_one_time_keys(1) + otk = next(iter(our_account.one_time_keys['curve25519'].values())) + wrong_session = olm.OutboundSession(alice, sender_key, otk) + message = wrong_session.encrypt(json.dumps(plaintext)) + with pytest.raises(RuntimeError): + self.device._olm_decrypt(message, sender_key) + + # Try to decrypt a type 1 message for which we have no sessions + message = session.encrypt(json.dumps(plaintext)) + self.device.olm_sessions.clear() + with pytest.raises(RuntimeError): + self.device._olm_decrypt(message, sender_key) + + def test_olm_decrypt_event(self): + self.device.device_keys.clear() + self.device.olm_sessions.clear() + alice_device = OlmDevice(self.device.api, self.alice, self.alice_device_id) + alice_device.device_keys[self.user_id][self.device_id] = self.device + self.device.device_keys[self.alice][self.alice_device_id] = alice_device + + # Artificially start an Olm session from Alice + self.device.olm_account.generate_one_time_keys(1) + otk = next(iter(self.device.olm_account.one_time_keys['curve25519'].values())) + self.device.olm_account.mark_keys_as_published() + sender_key = self.device.curve25519 + session = olm.OutboundSession(alice_device.olm_account, sender_key, otk) + alice_device.olm_sessions[sender_key] = [session] + + encrypted_event = alice_device.olm_build_encrypted_event( + 'example_type', {'content': 'test'}, self.user_id, self.device_id) + + # Now we can test + self.device.olm_decrypt_event(encrypted_event, self.alice) + + # Device verification + alice_device.verified = True + encrypted_event = alice_device.olm_build_encrypted_event( + 'example_type', {'content': 'test'}, self.user_id, self.device_id) + self.device.olm_decrypt_event(encrypted_event, self.alice) + + # The signing_key is wrong + encrypted_event = alice_device.olm_build_encrypted_event( + 'example_type', {'content': 'test'}, self.user_id, self.device_id) + self.device.device_keys[self.alice][self.alice_device_id]._ed25519 = 'wrong' + with pytest.raises(RuntimeError): + self.device.olm_decrypt_event(encrypted_event, self.alice) + + # We do not have the keys + encrypted_event = alice_device.olm_build_encrypted_event( + 'example_type', {'content': 'test'}, self.user_id, self.device_id) + self.device.device_keys[self.alice].clear() + self.device.olm_decrypt_event(encrypted_event, self.alice) + self.device.device_keys[self.alice][self.alice_device_id] = alice_device + alice_device.verified = False + + # Type 1 Olm payload + alice_device.olm_decrypt_event( + self.device.olm_build_encrypted_event( + 'example_type', {'content': 'test'}, self.alice, self.alice_device_id + ), + self.user_id) + encrypted_event = alice_device.olm_build_encrypted_event( + 'example_type', {'content': 'test'}, self.user_id, self.device_id) + self.device.olm_decrypt_event(encrypted_event, self.alice) + + encrypted_event = alice_device.olm_build_encrypted_event( + 'example_type', {'content': 'test'}, self.user_id, self.device_id) + with pytest.raises(RuntimeError): + self.device.olm_decrypt_event(encrypted_event, 'wrong') + + wrong_event = deepcopy(encrypted_event) + wrong_event['algorithm'] = 'wrong' + with pytest.raises(RuntimeError): + self.device.olm_decrypt_event(wrong_event, self.alice) + + wrong_event = deepcopy(encrypted_event) + wrong_event['ciphertext'] = {} + with pytest.raises(RuntimeError): + self.device.olm_decrypt_event(wrong_event, self.alice) + + encrypted_event = alice_device.olm_build_encrypted_event( + 'example_type', {'content': 'test'}, self.user_id, self.device_id) + self.device.user_id = 'wrong' + with pytest.raises(RuntimeError): + self.device.olm_decrypt_event(encrypted_event, self.alice) + self.device.user_id = self.user_id + + encrypted_event = alice_device.olm_build_encrypted_event( + 'example_type', {'content': 'test'}, self.user_id, self.device_id) + backup = self.device.ed25519 + self.device._ed25519 = 'wrong' + with pytest.raises(RuntimeError): + self.device.olm_decrypt_event(encrypted_event, self.alice) + self.device._ed25519 = backup + + @responses.activate + def test_olm_ensure_sessions(self): + claim_url = HOSTNAME + MATRIX_V2_API_PATH + '/keys/claim' + responses.add(responses.POST, claim_url, json=example_claim_keys_response) + self.device.olm_sessions.clear() + self.device.device_keys[self.alice][self.alice_device_id] = self.alice_device + user_devices = {self.alice: [self.alice_device_id]} + + self.device.olm_ensure_sessions(user_devices) + assert self.device.olm_sessions[self.alice_curve_key] + assert len(responses.calls) == 1 + + self.device.olm_ensure_sessions(user_devices) + assert len(responses.calls) == 1 + + @responses.activate + def test_megolm_share_session(self): + claim_url = HOSTNAME + MATRIX_V2_API_PATH + '/keys/claim' + responses.add(responses.POST, claim_url, json=example_claim_keys_response) + to_device_url = HOSTNAME + MATRIX_V2_API_PATH + '/sendToDevice/m.room.encrypted/1' + responses.add(responses.PUT, to_device_url, json={}) + self.device.olm_sessions.clear() + self.device.device_keys[self.alice][self.alice_device_id] = self.alice_device + self.device.device_keys['dummy']['dummy'] = \ + Device(self.cli.api, 'dummy', 'dummy', curve25519_key='a', ed25519_key='a') + user_devices = {self.alice: [self.alice_device_id], 'dummy': ['dummy']} + session = MegolmOutboundSession() + + # Sharing with Alice should succeed, but dummy will fail + self.device.megolm_share_session(self.room_id, user_devices, session) + assert session.devices == {self.alice_device_id, 'dummy'} + + req = json.loads(responses.calls[1].request.body)['messages'] + assert self.alice in req + assert 'dummy' not in req + + @responses.activate + def test_megolm_start_session(self): + to_device_url = HOSTNAME + MATRIX_V2_API_PATH + '/sendToDevice/m.room.encrypted/1' + responses.add(responses.PUT, to_device_url, json={}) + self.device.device_keys[self.alice][self.alice_device_id] = self.alice_device + self.device.device_list.tracked_user_ids.add(self.alice) + self.device.olm_sessions[self.alice_curve_key] = [self.alice_olm_session] + user_devices = {self.alice: [self.alice_device_id]} + + self.device.megolm_start_session(self.room, user_devices) + session = self.device.megolm_outbound_sessions[self.room_id] + assert self.alice_device_id in session.devices + + # Check that we can decrypt our own messages + plaintext = { + 'type': 'test', + 'content': {'test': 'test'}, + } + encrypted_event = self.device.megolm_build_encrypted_event(self.room, plaintext) + event = { + 'sender': self.alice, + 'room_id': self.room_id, + 'content': encrypted_event, + 'type': 'm.room.encrypted', + 'origin_server_ts': 1, + 'event_id': 1 + } + self.device.megolm_decrypt_event(event) + assert event['content'] == plaintext['content'] + + @responses.activate + def test_megolm_share_session_with_new_devices(self): + to_device_url = HOSTNAME + MATRIX_V2_API_PATH + '/sendToDevice/m.room.encrypted/1' + responses.add(responses.PUT, to_device_url, json={}) + self.device.device_keys[self.alice][self.alice_device_id] = self.alice_device + self.device.olm_sessions[self.alice_curve_key] = [self.alice_olm_session] + session = MegolmOutboundSession() + self.device.megolm_outbound_sessions[self.room_id] = session + user_devices = {self.alice: [self.alice_device_id]} + + self.device.megolm_share_session_with_new_devices( + self.room, user_devices, session) + assert self.alice_device_id in session.devices + assert len(responses.calls) == 1 + + self.device.megolm_share_session_with_new_devices( + self.room, user_devices, session) + assert len(responses.calls) == 1 + + def test_megolm_get_recipients(self): + self.device.device_keys[self.alice][self.alice_device_id] = self.alice_device + + user_devices, _ = self.device.megolm_get_recipients(self.room) + assert user_devices == {self.alice: [self.alice_device_id]} + + self.device.megolm_outbound_sessions.clear() + session = MegolmOutboundSession() + self.device.megolm_outbound_sessions[self.room_id] = session + + user_devices, removed = self.device.megolm_get_recipients(self.room, session) + assert user_devices == {self.alice: [self.alice_device_id]} and not removed + + self.alice_device.blacklisted = True + _, removed = self.device.megolm_get_recipients(self.room, session) + assert not removed + session.add_device(self.alice_device_id) + _, removed = self.device.megolm_get_recipients(self.room, session) + assert removed and self.room_id not in self.device.megolm_outbound_sessions + self.alice_device.blacklisted = False + + self.room.verify_devices = True + with pytest.raises(E2EUnknownDevices) as e: + self.device.megolm_get_recipients(self.room) + assert e.value.user_devices == {self.alice: [self.alice_device]} + self.room.verify_devices = False + + @responses.activate + def test_megolm_build_encrypted_event(self): + to_device_url = HOSTNAME + MATRIX_V2_API_PATH + '/sendToDevice/m.room.encrypted/1' + responses.add(responses.PUT, to_device_url, json={}) + self.device.megolm_outbound_sessions.clear() + self.device.device_keys[self.alice][self.alice_device_id] = self.alice_device + self.device.device_list.tracked_user_ids.add(self.alice) + self.device.olm_sessions[self.alice_curve_key] = [self.alice_olm_session] + event = {'type': 'm.room.message', 'content': {'body': 'test'}} + + self.room.rotation_period_msgs = 1 + self.device.megolm_build_encrypted_event(self.room, event) + + self.device.megolm_build_encrypted_event(self.room, event) + + session = self.device.megolm_outbound_sessions[self.room_id] + session.encrypt('test') + self.device.megolm_build_encrypted_event(self.room, event) + assert self.device.megolm_outbound_sessions[self.room_id].id != session.id + + def test_megolm_remove_outbound_session(self): + session = MegolmOutboundSession() + self.device.megolm_outbound_sessions[self.room_id] = session + self.device.megolm_remove_outbound_session(self.room_id) + self.device.megolm_remove_outbound_session(self.room_id) + + @responses.activate + def test_send_encrypted_message(self): + message_url = HOSTNAME + MATRIX_V2_API_PATH + \ + '/rooms/{}/send/m.room.encrypted/1'.format(quote(self.room.room_id)) + responses.add(responses.PUT, message_url, json={}) + self.device.device_keys[self.alice][self.alice_device_id] = self.alice_device + self.device.olm_sessions[self.alice_curve_key] = [self.alice_olm_session] + session = MegolmOutboundSession() + session.add_device(self.alice_device_id) + self.device.megolm_outbound_sessions[self.room_id] = session + + self.device.send_encrypted_message(self.room, {'test': 'test'}) + + def test_megolm_add_inbound_session(self): + session = MegolmOutboundSession() + self.device.megolm_inbound_sessions.clear() + + assert not self.device.megolm_add_inbound_session( + self.room_id, self.alice_curve_key, self.alice_ed_key, session.id, 'wrong') + assert self.device.megolm_add_inbound_session( + self.room_id, self.alice_curve_key, self.alice_ed_key, session.id, + session.session_key + ) + assert session.id in \ + self.device.megolm_inbound_sessions[self.room_id][self.alice_curve_key] + assert not self.device.megolm_add_inbound_session( + self.room_id, self.alice_curve_key, self.alice_ed_key, session.id, + session.session_key + ) + assert not self.device.megolm_add_inbound_session( + self.room_id, self.alice_curve_key, self.alice_ed_key, 'wrong', + session.session_key + ) + + def test_handle_room_key_event(self): + self.device.megolm_inbound_sessions.clear() + + self.device.handle_room_key_event(example_room_key_event, self.alice_curve_key) + assert self.room_id in self.device.megolm_inbound_sessions + + self.device.handle_room_key_event(example_room_key_event, self.alice_curve_key) + + event = deepcopy(example_room_key_event) + event['content']['algorithm'] = 'wrong' + self.device.handle_room_key_event(event, self.alice_curve_key) + + event = deepcopy(example_room_key_event) + event['content']['session_id'] = 'wrong' + self.device.handle_room_key_event(event, self.alice_curve_key) + + def test_olm_handle_encrypted_event(self): + self.device.olm_sessions.clear() + alice_device = OlmDevice(self.device.api, self.alice, self.alice_device_id) + alice_device.device_keys[self.user_id][self.device_id] = self.device + self.device.device_keys[self.alice][self.alice_device_id] = alice_device + + # Artificially start an Olm session from Alice + self.device.olm_account.generate_one_time_keys(1) + otk = next(iter(self.device.olm_account.one_time_keys['curve25519'].values())) + self.device.olm_account.mark_keys_as_published() + sender_key = self.device.curve25519 + session = olm.OutboundSession(alice_device.olm_account, sender_key, otk) + alice_device.olm_sessions[sender_key] = [session] + + content = example_room_key_event['content'] + encrypted_event = alice_device.olm_build_encrypted_event( + 'm.room_key', content, self.user_id, self.device_id) + event = { + 'type': 'm.room.encrypted', + 'content': encrypted_event, + 'sender': self.alice + } + + self.device.olm_handle_encrypted_event(event) + + # Decrypting the same event twice will trigger an error + self.device.olm_handle_encrypted_event(event) + + encrypted_event = alice_device.olm_build_encrypted_event( + 'm.other', content, self.user_id, self.device_id) + event = { + 'type': 'm.room.encrypted', + 'content': encrypted_event, + 'sender': self.alice + } + self.device.olm_handle_encrypted_event(event) + + # Simulate redacted event + event['content'].pop('algorithm') + self.device.olm_handle_encrypted_event(event) + + def test_megolm_decrypt_event(self): + out_session = MegolmOutboundSession() + + plaintext = { + 'content': {"test": "test"}, + 'type': 'm.text', + } + ciphertext = out_session.encrypt(json.dumps(plaintext)) + + content = { + 'ciphertext': ciphertext, + 'session_id': out_session.id, + 'sender_key': self.alice_curve_key, + 'algorithm': 'm.megolm.v1.aes-sha2', + 'device_id': self.alice_device_id, + } + + event = { + 'sender': self.alice, + 'room_id': self.room_id, + 'content': content, + 'type': 'm.room.encrypted', + 'origin_server_ts': 1, + 'event_id': 1 + } + + with pytest.raises(RuntimeError): + self.device.megolm_decrypt_event(event) + + session_key = out_session.session_key + in_session = MegolmInboundSession(session_key, self.alice_ed_key) + sessions = self.device.megolm_inbound_sessions[self.room_id] + sessions[self.alice_curve_key][in_session.id] = in_session + + # Unknown message index + with pytest.raises(RuntimeError): + self.device.megolm_decrypt_event(event) + + ciphertext = out_session.encrypt(json.dumps(plaintext)) + event['content']['ciphertext'] = ciphertext + self.device.megolm_decrypt_event(event) + assert event['content'] == plaintext['content'] + + # No replay attack + event['content'] = content + self.device.megolm_decrypt_event(event) + assert event['content'] == plaintext['content'] + + # Replay attack + event['content'] = content + event['event_id'] = 2 + with pytest.raises(RuntimeError): + self.device.megolm_decrypt_event(event) + event['event_id'] = 1 + + # Device verification + self.device.device_keys[self.alice][self.alice_device_id] = self.alice_device + event['content'] = content + # Unverified + self.device.megolm_decrypt_event(event) + assert event['content'] == plaintext['content'] + assert isinstance(event, dict) + + event['content'] = content + # Verified + self.alice_device.verified = True + decrypted_event = self.device.megolm_decrypt_event(event) + assert decrypted_event['content'] == plaintext['content'] + assert isinstance(decrypted_event, VerifiedEvent) + + in_session = MegolmInboundSession(session_key, self.alice_curve_key) + sessions = self.device.megolm_inbound_sessions[self.room_id] + sessions[self.alice_curve_key][in_session.id] = in_session + # Wrong signing key + with pytest.raises(RuntimeError): + self.device.megolm_decrypt_event(event) + self.alice_device.verified = False + + event['content']['algorithm'] = 'wrong' + with pytest.raises(RuntimeError): + self.device.megolm_decrypt_event(event) + + event['content'].pop('algorithm') + event['type'] = 'encrypted' + self.device.megolm_decrypt_event(event) + assert event['type'] == 'encrypted' + + +def test_megolm_outbound_session(): + session = MegolmOutboundSession() + assert session.max_messages == 100 + assert session.max_age == timedelta(days=7) + + session = MegolmOutboundSession(max_messages=1, max_age=100000) + assert session.max_messages == 1 + assert session.max_age == timedelta(milliseconds=100000) + + assert not session.devices + + session.add_device('test') + assert 'test' in session.devices + + session.add_devices({'test2', 'test3'}) + assert 'test2' in session.devices and 'test3' in session.devices + + assert not session.should_rotate() + + session.encrypt('message') + assert session.should_rotate() + + session.max_messages = 2 + assert not session.should_rotate() + session.creation_time = datetime.now() - timedelta(milliseconds=100000) + assert session.should_rotate() diff --git a/test/device_test.py b/test/device_test.py new file mode 100644 index 00000000..ab48f3bc --- /dev/null +++ b/test/device_test.py @@ -0,0 +1,46 @@ +import pytest +import responses + +from matrix_client.api import MATRIX_V2_API_PATH +from matrix_client.client import MatrixClient +from matrix_client.errors import MatrixRequestError +from matrix_client.device import Device + +HOSTNAME = 'http://localhost' + + +class TestDevice(object): + + cli = MatrixClient(HOSTNAME) + user_id = '@test:localhost' + device_id = 'AUIETRSN' + + @pytest.fixture() + def device(self): + return Device(self.cli.api, self.user_id, self.device_id) + + @responses.activate + def test_get_info(self, device): + device_url = HOSTNAME + MATRIX_V2_API_PATH + '/devices/' + self.device_id + display_name = 'android' + last_seen_ip = '1.2.3.4' + last_seen_ts = 1474491775024 + resp = { + "device_id": self.device_id, + "display_name": display_name, + "last_seen_ip": last_seen_ip, + "last_seen_ts": last_seen_ts + } + responses.add(responses.GET, device_url, json=resp) + + assert device.get_info() + assert device.display_name == display_name + assert device.last_seen_ip == last_seen_ip + assert device.last_seen_ts == last_seen_ts + + responses.replace(responses.GET, device_url, status=404) + assert not device.get_info() + + responses.replace(responses.GET, device_url, status=500) + with pytest.raises(MatrixRequestError): + device.get_info() diff --git a/test/response_examples.py b/test/response_examples.py index 2d45aa86..9b5e9456 100644 --- a/test/response_examples.py +++ b/test/response_examples.py @@ -184,3 +184,75 @@ "og:image:width": 48, "og:title": "Matrix Blog Post" } + +example_key_query_response = { + "failures": {}, + "device_keys": { + "@alice:example.com": { + "JLAFKJWSCS": { + "user_id": "@alice:example.com", + "device_id": "JLAFKJWSCS", + "algorithms": [ + "m.olm.curve25519-aes-sha256", + "m.megolm.v1.aes-sha" + ], + "keys": { + "curve25519:JLAFKJWSCS": ("3C5BFWi2Y8MaVvjM8M22DBmh24PmgR0nPvJOIAr" + "zgyI"), + "ed25519:JLAFKJWSCS": "VzJIYXQ85u19z2ZpEeLLVu8hUKTCE0VXYUn4IY4iFcA" + }, + "signatures": { + "@alice:example.com": { + "ed25519:JLAFKJWSCS": + ("wux6Dhjtk7GYPMW54hnx0doVH0NvuUAFBleL5OW99jhbjIutufglAgrYAcu8" + "ueacgNyeSumvtzVIPZXgbB2BCg") + } + }, + "unsigned": { + "device_display_name": "Alice'smobilephone" + } + } + } + } +} + +example_claim_keys_response = { + "failures": {}, + "one_time_keys": { + "@alice:example.com": { + "JLAFKJWSCS": { + 'signed_curve25519:AAAAAQ': { + 'key': '9UOzQjF2j2Xf8mBIiMgruuCkuWtD0ea9kvx63mO92Ws', + 'signatures': { + '@alice:example.com': { + 'ed25519:JLAFKJWSCS': ( + '6O+VYxN7mVcr/j66YdHASRrpW4ydC/0FcYmEWVAGIFzU4+yjzxxinhQD' + 'l7InhhdGuXeQlk4/w/CyU76TY6wdBA' + ) + } + } + } + } + } + } +} + +example_room_key_event = { + "sender": "@alice:example.com", + "sender_device": "JLAFKJWSCS", + "content": { + "algorithm": "m.megolm.v1.aes-sha2", + "room_id": "!test:example.com", + "session_id": "AVCXMm6LZ+J/vyCcomXmE48mbD1IyKbUBUd3UOW0wHE", + "session_key": ( + "AgAAAAAJS98WXiCc90wJ23H1ucZ+XFCv8pN8C5p/XojdA6l7PWlFwAV1fQXe7afrQMRL9BxeeF8M" + "uNnpvGX0hGOWcW0e2LU3EzQ0j8+jhxrPkQHUOJ8387CjRSA9UTBDmw3y8xquy3cXvuGE5DSpFUU7" + "J7Xh+Dli8XRaRDCbmPmMtSdPMwFQlzJui2fif78gnKJl5hOPJmw9SMim1AVHd1DltMBx4vB/3Kse" + "G413GWJkw9T+G6y51bsNEKsSU23lnJz32u5XwgNY9qdFKxGA6WL1wZZS6/iGW4gfTU/Jk89aGSA8" + "Aw") + }, + "type": "m.room_key", + "keys": { + "ed25519": "4VjV3OhFUxWFAcO5YOaQVmTIn29JdRmtNh9iAxoyhkc", + } +} diff --git a/test/user_test.py b/test/user_test.py index db5bae82..beae6b3f 100644 --- a/test/user_test.py +++ b/test/user_test.py @@ -15,7 +15,7 @@ class TestUser: @pytest.fixture() def user(self): - return User(self.cli.api, self.user_id) + return User(self.cli, self.user_id) @pytest.fixture() def room(self):