Skip to content

Commit

Permalink
Fix more cases where incorrect encryption keys were not detected (#447)
Browse files Browse the repository at this point in the history
  • Loading branch information
bdraco committed Jun 24, 2023
1 parent c1752bc commit eaa5e29
Show file tree
Hide file tree
Showing 5 changed files with 162 additions and 22 deletions.
66 changes: 50 additions & 16 deletions aioesphomeapi/_frame_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,11 @@
from typing import Callable, Optional, Union, cast

import async_timeout
from cryptography.exceptions import InvalidTag
from noise.connection import NoiseConnection # type: ignore

from .core import (
APIConnectionError,
BadNameAPIError,
HandshakeAPIError,
InvalidEncryptionKeyAPIError,
Expand Down Expand Up @@ -219,29 +221,38 @@ def __init__(
) -> None:
"""Initialize the API frame helper."""
super().__init__(on_pkt, on_error)
self._ready_event = asyncio.Event()
self._ready_future = asyncio.get_event_loop().create_future()
self._noise_psk = noise_psk
self._expected_name = expected_name
self._state = NoiseConnectionState.HELLO
self._setup_proto()

def _set_ready_future_exception(self, exc: Exception) -> None:
if not self._ready_future.done():
self._ready_future.set_exception(exc)

def close(self) -> None:
"""Close the connection."""
# Make sure we set the ready event if its not already set
# so that we don't block forever on the ready event if we
# are waiting for the handshake to complete.
self._ready_event.set()
self._set_ready_future_exception(APIConnectionError("Connection closed"))
self._state = NoiseConnectionState.CLOSED
super().close()

def _handle_error_and_close(self, exc: Exception) -> None:
self._set_ready_future_exception(exc)
super()._handle_error_and_close(exc)

def _write_frame(self, frame: bytes) -> None:
"""Write a packet to the socket, the caller should not have the lock.
The entire packet must be written in a single call to write
to avoid locking.
"""
_LOGGER.debug("Sending frame %s", frame.hex())
assert self._transport is not None, "Transport is not set"
if _LOGGER.isEnabledFor(logging.DEBUG):
_LOGGER.debug("Sending frame: [%s]", frame.hex())

try:
header = bytes(
Expand All @@ -260,7 +271,7 @@ async def perform_handshake(self) -> None:
self._send_hello()
try:
async with async_timeout.timeout(60.0):
await self._ready_event.wait()
await self._ready_future
except asyncio.TimeoutError as err:
raise HandshakeAPIError("Timeout during handshake") from err

Expand All @@ -273,8 +284,10 @@ def data_received(self, data: bytes) -> None:
self._handle_error_and_close(
ProtocolAPIError(f"Marker byte invalid: {header[0]}")
)
return
msg_size = (header[1] << 8) | header[2]
frame = self._read_exactly(msg_size)

if frame is None:
return

Expand All @@ -292,16 +305,18 @@ def _send_hello(self) -> None:
def _handle_hello(self, server_hello: bytearray) -> None:
"""Perform the handshake with the server, the caller is responsible for having the lock."""
if not server_hello:
raise HandshakeAPIError("ServerHello is empty")
self._handle_error_and_close(HandshakeAPIError("ServerHello is empty"))
return

# First byte of server hello is the protocol the server chose
# for this session. Currently only 0x01 (Noise_NNpsk0_25519_ChaChaPoly_SHA256)
# exists.
chosen_proto = server_hello[0]
if chosen_proto != 0x01:
raise HandshakeAPIError(
f"Unknown protocol selected by client {chosen_proto}"
self._handle_error_and_close(
HandshakeAPIError(f"Unknown protocol selected by client {chosen_proto}")
)
return

# Check name matches expected name (for noise sessions, this is done
# during hello phase before a connection is set up)
Expand All @@ -311,9 +326,12 @@ def _handle_hello(self, server_hello: bytearray) -> None:
# server name found, this extension was added in 2022.2
server_name = server_hello[1:server_name_i].decode()
if self._expected_name is not None and self._expected_name != server_name:
raise BadNameAPIError(
f"Server sent a different name '{server_name}'", server_name
self._handle_error_and_close(
BadNameAPIError(
f"Server sent a different name '{server_name}'", server_name
)
)
return

self._state = NoiseConnectionState.HANDSHAKE
self._send_handshake()
Expand All @@ -335,12 +353,24 @@ def _handle_handshake(self, msg: bytearray) -> None:
if msg[0] != 0:
explanation = msg[1:].decode()
if explanation == "Handshake MAC failure":
raise InvalidEncryptionKeyAPIError("Invalid encryption key")
raise HandshakeAPIError(f"Handshake failure: {explanation}")
self._proto.read_message(msg[1:])
self._handle_error_and_close(
InvalidEncryptionKeyAPIError("Invalid encryption key")
)
return
self._handle_error_and_close(
HandshakeAPIError(f"Handshake failure: {explanation}")
)
return
try:
self._proto.read_message(msg[1:])
except InvalidTag as invalid_tag_exc:
ex = InvalidEncryptionKeyAPIError("Invalid encryption key")
ex.__cause__ = invalid_tag_exc
self._handle_error_and_close(ex)
return
_LOGGER.debug("Handshake complete")
self._state = NoiseConnectionState.READY
self._ready_event.set()
self._ready_future.set_result(None)

def write_packet(self, type_: int, data: bytes) -> None:
"""Write a packet to the socket."""
Expand All @@ -367,13 +397,17 @@ def _handle_frame(self, frame: bytearray) -> None:
assert self._proto is not None
msg = self._proto.decrypt(bytes(frame))
if len(msg) < 4:
raise ProtocolAPIError(f"Bad packet frame: {msg}")
self._handle_error_and_close(ProtocolAPIError(f"Bad packet frame: {msg}"))
return
pkt_type = (msg[0] << 8) | msg[1]
data_len = (msg[2] << 8) | msg[3]
if data_len + 4 > len(msg):
raise ProtocolAPIError(f"Bad data len: {data_len} vs {len(msg)}")
self._handle_error_and_close(
ProtocolAPIError(f"Bad data len: {data_len} vs {len(msg)}")
)
return
data = msg[4 : 4 + data_len]
return self._on_pkt(pkt_type, data)
self._on_pkt(pkt_type, data)

def _handle_closed( # pylint: disable=unused-argument
self, frame: bytearray
Expand Down
10 changes: 8 additions & 2 deletions aioesphomeapi/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
MESSAGE_TYPE_TO_PROTO,
APIConnectionError,
BadNameAPIError,
ConnectionNotEstablishedAPIError,
HandshakeAPIError,
InvalidAuthAPIError,
PingFailedAPIError,
Expand Down Expand Up @@ -432,7 +433,7 @@ async def _do_connect() -> None:
self._cleanup()
raise self._fatal_exception or APIConnectionError("Connection cancelled")
except Exception: # pylint: disable=broad-except
# Always clean up the connection if an error occured during connect
# Always clean up the connection if an error occurred during connect
self._connection_state = ConnectionState.CLOSED
self._cleanup()
raise
Expand Down Expand Up @@ -493,7 +494,12 @@ def is_authenticated(self) -> bool:
def send_message(self, msg: message.Message) -> None:
"""Send a protobuf message to the remote."""
if not self._is_socket_open:
raise APIConnectionError(
if in_do_connect.get(False):
# If we are in the do_connect task, we can't raise an error
# because it would obscure the original exception (ie encrypt error).
_LOGGER.debug("%s: Connection isn't established yet", self.log_name)
return
raise ConnectionNotEstablishedAPIError(
f"Connection isn't established yet ({self._connection_state})"
)

Expand Down
4 changes: 4 additions & 0 deletions aioesphomeapi/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,10 @@ class HandshakeAPIError(APIConnectionError):
pass


class ConnectionNotEstablishedAPIError(APIConnectionError):
pass


class BadNameAPIError(APIConnectionError):
"""Raised when a name received from the remote but does not much the expected name."""

Expand Down
25 changes: 22 additions & 3 deletions aioesphomeapi/reconnect_logic.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,17 @@
import zeroconf

from .client import APIClient
from .core import APIConnectionError
from .core import (
APIConnectionError,
InvalidAuthAPIError,
InvalidEncryptionKeyAPIError,
RequiresEncryptionAPIError,
)

_LOGGER = logging.getLogger(__name__)

EXPECTED_DISCONNECT_COOLDOWN = 3.0
MAXIMUM_BACKOFF_TRIES = 100


class ReconnectLogic(zeroconf.RecordUpdateListener):
Expand Down Expand Up @@ -103,13 +109,26 @@ async def _try_connect(self) -> bool:
level = logging.WARNING if self._tries == 0 else logging.DEBUG
_LOGGER.log(
level,
"Can't connect to ESPHome API for %s: %s",
"Can't connect to ESPHome API for %s: %s (%s)",
self._log_name,
err,
type(err).__name__,
# Print stacktrace if unhandled (not APIConnectionError)
exc_info=not isinstance(err, APIConnectionError),
)
self._tries += 1
if isinstance(
err,
(
RequiresEncryptionAPIError,
InvalidEncryptionKeyAPIError,
InvalidAuthAPIError,
),
):
# If we get an encryption or password error,
# backoff for the maximum amount of time
self._tries = MAXIMUM_BACKOFF_TRIES
else:
self._tries += 1
return False
_LOGGER.info("Successfully connected to %s", self._log_name)
self._connected = True
Expand Down
79 changes: 78 additions & 1 deletion tests/test__frame_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@

import pytest

from aioesphomeapi._frame_helper import APIPlaintextFrameHelper
from aioesphomeapi._frame_helper import APINoiseFrameHelper, APIPlaintextFrameHelper
from aioesphomeapi.core import BadNameAPIError, InvalidEncryptionKeyAPIError
from aioesphomeapi.util import varuint_to_bytes

PREAMBLE = b"\x00"
Expand Down Expand Up @@ -63,3 +64,79 @@ def _on_error(exc: Exception):

assert type_ == pkt_type
assert data == pkt_data


@pytest.mark.asyncio
async def test_noise_frame_helper_incorrect_key():
"""Test that the noise frame helper raises InvalidEncryptionKeyAPIError on bad key."""
outgoing_packets = [
"010000", # hello packet
"010031001ed7f7bb0b74085418258ed5928931bc36ade7cf06937fcff089044d4ab142643f1b2c9935bb77696f23d930836737a4",
]
incoming_packets = [
"01000d01736572766963657465737400",
"0100160148616e647368616b65204d4143206661696c757265",
]
packets = []

def _packet(type_: int, data: bytes):
packets.append((type_, data))

def _on_error(exc: Exception):
raise exc

helper = APINoiseFrameHelper(
on_pkt=_packet,
on_error=_on_error,
noise_psk="QRTIErOb/fcE9Ukd/5qA3RGYMn0Y+p06U58SCtOXvPc=",
expected_name="servicetest",
)
helper._transport = MagicMock()

for pkt in outgoing_packets:
helper._write_frame(bytes.fromhex(pkt))

with pytest.raises(InvalidEncryptionKeyAPIError):
for pkt in incoming_packets:
helper.data_received(bytes.fromhex(pkt))

with pytest.raises(InvalidEncryptionKeyAPIError):
await helper.perform_handshake()


@pytest.mark.asyncio
async def test_noise_incorrect_name():
"""Test we raise on bad name."""
outgoing_packets = [
"010000", # hello packet
"010031001ed7f7bb0b74085418258ed5928931bc36ade7cf06937fcff089044d4ab142643f1b2c9935bb77696f23d930836737a4",
]
incoming_packets = [
"01000d01736572766963657465737400",
"0100160148616e647368616b65204d4143206661696c757265",
]
packets = []

def _packet(type_: int, data: bytes):
packets.append((type_, data))

def _on_error(exc: Exception):
raise exc

helper = APINoiseFrameHelper(
on_pkt=_packet,
on_error=_on_error,
noise_psk="QRTIErOb/fcE9Ukd/5qA3RGYMn0Y+p06U58SCtOXvPc=",
expected_name="wrongname",
)
helper._transport = MagicMock()

for pkt in outgoing_packets:
helper._write_frame(bytes.fromhex(pkt))

with pytest.raises(BadNameAPIError):
for pkt in incoming_packets:
helper.data_received(bytes.fromhex(pkt))

with pytest.raises(BadNameAPIError):
await helper.perform_handshake()

0 comments on commit eaa5e29

Please sign in to comment.