diff --git a/src/flameconnect/b2c_login.py b/src/flameconnect/b2c_login.py index ddc66c0..0359b72 100644 --- a/src/flameconnect/b2c_login.py +++ b/src/flameconnect/b2c_login.py @@ -62,17 +62,13 @@ def _parse_login_page(html: str, page_url: str) -> dict[str, str]: # Extract CSRF token from var SETTINGS = {..., "csrf":"...", ...} csrf_match = re.search(r'"csrf"\s*:\s*"([^"]+)"', html) if not csrf_match: - raise AuthenticationError( - "Could not find CSRF token in B2C login page" - ) + raise AuthenticationError("Could not find CSRF token in B2C login page") csrf = csrf_match.group(1) # Extract transId from SETTINGS tx_match = re.search(r'"transId"\s*:\s*"([^"]+)"', html) if not tx_match: - raise AuthenticationError( - "Could not find transId in B2C login page" - ) + raise AuthenticationError("Could not find transId in B2C login page") tx = tx_match.group(1) p = _B2C_POLICY @@ -84,10 +80,7 @@ def _parse_login_page(html: str, page_url: str) -> dict[str, str]: qs = f"tx={tx}&p={p}" post_url = f"{origin}{base}SelfAsserted?{qs}" - confirmed_url = ( - f"{origin}{base}" - f"api/CombinedSigninAndSignup/confirmed" - ) + confirmed_url = f"{origin}{base}api/CombinedSigninAndSignup/confirmed" return { "csrf": csrf, @@ -109,10 +102,8 @@ def _build_cookie_header( browsers send them). This function formats cookies in the plain ``name=value; name2=value2`` style that B2C requires. """ - filtered = cookie_jar.filter_cookies(url) - return "; ".join( - f"{m.key}={m.value}" for m in filtered.values() - ) + filtered = cookie_jar.filter_cookies(yarl.URL(url)) + return "; ".join(f"{m.key}={m.value}" for m in filtered.values()) def _log_request( @@ -130,10 +121,7 @@ def _log_request( if headers: _LOGGER.debug(">>> headers: %s", headers) if data: - safe = { - k: ("***" if k == "password" else v) - for k, v in data.items() - } + safe = {k: ("***" if k == "password" else v) for k, v in data.items()} _LOGGER.debug(">>> body: %s", safe) @@ -143,7 +131,9 @@ def _log_response( ) -> None: """Log an incoming HTTP response at DEBUG level.""" _LOGGER.debug( - "<<< %s %s", resp.status, resp.url, + "<<< %s %s", + resp.status, + resp.url, ) _LOGGER.debug("<<< headers: %s", dict(resp.headers)) if body is not None: @@ -153,9 +143,7 @@ def _log_response( _LOGGER.debug("<<< body: %s", preview) -async def b2c_login_with_credentials( - auth_uri: str, email: str, password: str -) -> str: +async def b2c_login_with_credentials(auth_uri: str, email: str, password: str) -> str: """Submit credentials directly to Azure AD B2C and return the redirect URL. Performs the same HTTP flow a browser would: @@ -216,12 +204,8 @@ async def b2c_login_with_credentials( "X-Requested-With": "XMLHttpRequest", "Referer": auth_uri, "Origin": origin, - "Accept": ( - "application/json, text/javascript, */*; q=0.01" - ), - "Content-Type": ( - "application/x-www-form-urlencoded; charset=UTF-8" - ), + "Accept": ("application/json, text/javascript, */*; q=0.01"), + "Content-Type": ("application/x-www-form-urlencoded; charset=UTF-8"), } # Build an unquoted Cookie header — aiohttp's cookie jar @@ -255,17 +239,11 @@ async def b2c_login_with_credentials( _log_response(resp, body) if resp.status != 200: raise AuthenticationError( - f"Credential submission returned HTTP" - f" {resp.status}" + f"Credential submission returned HTTP {resp.status}" ) # Check for error in the JSON-like response - if ( - '"status":"400"' in body - or '"status": "400"' in body - ): - raise AuthenticationError( - "Invalid email or password" - ) + if '"status":"400"' in body or '"status": "400"' in body: + raise AuthenticationError("Invalid email or password") # Merge cookies set by the POST response (e.g. # updated x-ms-cpim-cache and x-ms-cpim-trans) # into the cookie header for the confirmed GET. @@ -276,16 +254,12 @@ async def b2c_login_with_credentials( if "=" in part: n, v = part.split("=", 1) cookies[n] = v - for raw_sc in resp.headers.getall( - "Set-Cookie", [] - ): + for raw_sc in resp.headers.getall("Set-Cookie", []): sc_pair = raw_sc.split(";", 1)[0] if "=" in sc_pair: n, v = sc_pair.split("=", 1) cookies[n] = v - cookie_header = "; ".join( - f"{n}={v}" for n, v in cookies.items() - ) + cookie_header = "; ".join(f"{n}={v}" for n, v in cookies.items()) # Step 4: GET the confirmed endpoint — follows redirects # until we hit the custom-scheme redirect @@ -300,9 +274,7 @@ async def b2c_login_with_credentials( ) # Follow redirects manually to catch custom-scheme one - next_url: str = ( - fields["confirmed_url"] + "?" + confirmed_qs - ) + next_url: str = fields["confirmed_url"] + "?" + confirmed_qs confirmed_headers = { "Cookie": cookie_header, } @@ -316,53 +288,39 @@ async def b2c_login_with_credentials( resp_body = await resp.text() _log_response(resp, resp_body) if resp.status in (301, 302, 303, 307, 308): - location = resp.headers.get( - "Location", "" - ) + location = resp.headers.get("Location", "") if not location: raise AuthenticationError( - "Redirect without Location" - " header" + "Redirect without Location header" ) # Custom-scheme redirect (msal{id}://auth) - if ( - location.startswith("msal") - and "://auth" in location - ): + if location.startswith("msal") and "://auth" in location: _LOGGER.debug( - "Captured custom-scheme" - " redirect: %s", + "Captured custom-scheme redirect: %s", location[:120] + "...", ) return location # Resolve relative URLs if not location.startswith("http"): - location = urljoin( - next_url, location - ) + location = urljoin(next_url, location) next_url = location continue if resp.status == 200: redirect_match = re.search( - r'(msal[a-f0-9-]+://auth' + r"(msal[a-f0-9-]+://auth" r'\?[^\s"\'<]+)', resp_body, ) if redirect_match: return redirect_match.group(1) raise AuthenticationError( - "Reached 200 response without" - " finding redirect URL" + "Reached 200 response without finding redirect URL" ) raise AuthenticationError( - "Unexpected HTTP" - f" {resp.status} during" - " redirect chain" + f"Unexpected HTTP {resp.status} during redirect chain" ) - raise AuthenticationError( - "Too many redirects during B2C login" - ) + raise AuthenticationError("Too many redirects during B2C login") except AuthenticationError: raise except aiohttp.ClientError as exc: diff --git a/src/flameconnect/cli.py b/src/flameconnect/cli.py index 8ff10d2..62863a4 100644 --- a/src/flameconnect/cli.py +++ b/src/flameconnect/cli.py @@ -211,12 +211,8 @@ def _display_mode( ) -> None: """Display Mode parameter.""" unit = temp_unit.unit if temp_unit else TempUnit.CELSIUS - unit_suffix = ( - "C" if unit == TempUnit.CELSIUS else "F" - ) if temp_unit else "" - display_temp = _convert_temp( - param.target_temperature, unit - ) + unit_suffix = ("C" if unit == TempUnit.CELSIUS else "F") if temp_unit else "" + display_temp = _convert_temp(param.target_temperature, unit) print("\n [321] Mode") print(f" {'─' * 40}") mode = _enum_name(_FIRE_MODE_NAMES, param.mode) @@ -253,12 +249,8 @@ def _display_heat( ) -> None: """Display HeatSettings parameter.""" unit = temp_unit.unit if temp_unit else TempUnit.CELSIUS - unit_suffix = ( - "C" if unit == TempUnit.CELSIUS else "F" - ) if temp_unit else "" - display_temp = _convert_temp( - param.setpoint_temperature, unit - ) + unit_suffix = ("C" if unit == TempUnit.CELSIUS else "F") if temp_unit else "" + display_temp = _convert_temp(param.setpoint_temperature, unit) print("\n [323] Heat Settings") print(f" {'─' * 40}") status = _enum_name(_HEAT_STATUS_NAMES, param.heat_status) @@ -566,9 +558,7 @@ async def _set_brightness(client: FlameConnectClient, fire_id: str, value: str) print(f"Brightness set to {value}.") -async def _set_pulsating( - client: FlameConnectClient, fire_id: str, value: str -) -> None: +async def _set_pulsating(client: FlameConnectClient, fire_id: str, value: str) -> None: """Set pulsating effect on or off.""" if value not in _PULSATING_LOOKUP: valid = ", ".join(_PULSATING_LOOKUP) @@ -690,9 +680,7 @@ async def _set_timer(client: FlameConnectClient, fire_id: str, value: str) -> No print("Timer disabled.") -async def _set_temp_unit( - client: FlameConnectClient, fire_id: str, value: str -) -> None: +async def _set_temp_unit(client: FlameConnectClient, fire_id: str, value: str) -> None: """Set the temperature display unit.""" if value not in _TEMP_UNIT_LOOKUP: valid = ", ".join(_TEMP_UNIT_LOOKUP) @@ -704,7 +692,6 @@ async def _set_temp_unit( print(f"Temperature unit set to {value}.") - async def _set_flame_effect( client: FlameConnectClient, fire_id: str, value: str ) -> None: @@ -822,6 +809,7 @@ async def _set_ambient_sensor( await client.write_parameters(fire_id, [new_param]) print(f"Ambient sensor set to {value}.") + async def cmd_tui(*, verbose: bool = False) -> None: """Launch the TUI, showing install message if missing.""" try: diff --git a/src/flameconnect/models.py b/src/flameconnect/models.py index 132106d..3000e5d 100644 --- a/src/flameconnect/models.py +++ b/src/flameconnect/models.py @@ -141,7 +141,6 @@ class RGBWColor: white: int - NAMED_COLORS: dict[str, RGBWColor] = { "dark-red": RGBWColor(red=180, green=0, blue=0, white=0), "light-red": RGBWColor(red=255, green=0, blue=0, white=80), diff --git a/src/flameconnect/protocol.py b/src/flameconnect/protocol.py index ad956eb..b5ebf74 100644 --- a/src/flameconnect/protocol.py +++ b/src/flameconnect/protocol.py @@ -92,7 +92,11 @@ def _decode_mode(raw: bytes) -> ModeParam: _check_length(raw, 6, "Mode") mode = FireMode(raw[3]) target_temperature = _decode_temperature(raw, 4) - _LOGGER.debug("Decoded Mode: mode=%s target_temperature=%.1f", mode, target_temperature) + _LOGGER.debug( + "Decoded Mode: mode=%s target_temperature=%.1f", + mode, + target_temperature, + ) return ModeParam(mode=mode, target_temperature=target_temperature) @@ -318,7 +322,8 @@ def _encode_heat_settings(param: HeatParam) -> bytes: param.setpoint_temperature, param.boost_duration, ) - wire_boost = max(0, param.boost_duration - 1) # model is 1-indexed, wire is 0-indexed + # model is 1-indexed, wire is 0-indexed + wire_boost = max(0, param.boost_duration - 1) payload = ( bytes([param.heat_status, param.heat_mode]) + _encode_temperature(param.setpoint_temperature) diff --git a/src/flameconnect/tui/app.py b/src/flameconnect/tui/app.py index 0b18e75..b773fa8 100644 --- a/src/flameconnect/tui/app.py +++ b/src/flameconnect/tui/app.py @@ -8,7 +8,7 @@ from dataclasses import replace from functools import partial from pathlib import Path -from typing import TYPE_CHECKING, Awaitable +from typing import TYPE_CHECKING from textual.app import App, ComposeResult from textual.binding import Binding @@ -23,6 +23,7 @@ if TYPE_CHECKING: import asyncio + from collections.abc import Awaitable from flameconnect.client import FlameConnectClient from flameconnect.models import ( @@ -250,9 +251,7 @@ def _on_dismiss(result: str | None) -> None: if future.done(): return if result is None: - future.set_exception( - AuthenticationError("Authentication cancelled") - ) + future.set_exception(AuthenticationError("Authentication cancelled")) else: future.set_result(result) @@ -379,9 +378,7 @@ def _on_speed_selected(speed: int | None) -> None: if speed is not None and speed != current_speed: self.call_later(self._apply_flame_speed, speed) - self.push_screen( - FlameSpeedScreen(current_speed), callback=_on_speed_selected - ) + self.push_screen(FlameSpeedScreen(current_speed), callback=_on_speed_selected) def _apply_flame_speed(self, speed: int) -> None: """Write the selected flame speed to the fireplace.""" @@ -419,9 +416,7 @@ def action_toggle_brightness(self) -> None: if not isinstance(current, FlameEffectParam): return new_brightness = ( - Brightness.LOW - if current.brightness == Brightness.HIGH - else Brightness.HIGH + Brightness.LOW if current.brightness == Brightness.HIGH else Brightness.HIGH ) new_param = replace(current, brightness=new_brightness) label = "Low" if new_brightness == Brightness.LOW else "High" @@ -500,9 +495,7 @@ def action_toggle_media_light(self) -> None: if not isinstance(current, FlameEffectParam): return new_val = ( - LightStatus.OFF - if current.media_light == LightStatus.ON - else LightStatus.ON + LightStatus.OFF if current.media_light == LightStatus.ON else LightStatus.ON ) new_param = replace(current, media_light=new_val) label = "On" if new_val == LightStatus.ON else "Off" @@ -588,9 +581,7 @@ def _on_color_selected(color: FlameColor | None) -> None: if color is not None and color != current_color: self.call_later(self._apply_flame_color, color) - self.push_screen( - FlameColorScreen(current_color), callback=_on_color_selected - ) + self.push_screen(FlameColorScreen(current_color), callback=_on_color_selected) def _apply_flame_color(self, color: FlameColor) -> None: """Write the selected flame color to the fireplace.""" @@ -636,9 +627,7 @@ def _on_theme_selected(theme: MediaTheme | None) -> None: if theme is not None and theme != current_theme: self.call_later(self._apply_media_theme, theme) - self.push_screen( - MediaThemeScreen(current_theme), callback=_on_theme_selected - ) + self.push_screen(MediaThemeScreen(current_theme), callback=_on_theme_selected) def _apply_media_theme(self, theme: MediaTheme) -> None: """Write the selected media theme to the fireplace.""" @@ -674,9 +663,7 @@ async def _worker() -> None: await s.refresh_state() refreshed_params = s.current_parameters refreshed = refreshed_params.get(FlameEffectParam) - _LOGGER.debug( - "Media theme change: after_refresh=%s", refreshed - ) + _LOGGER.debug("Media theme change: after_refresh=%s", refreshed) except Exception as exc: _LOGGER.exception("Media theme change failed") s = self.screen @@ -799,16 +786,12 @@ def action_toggle_heat(self) -> None: if not isinstance(current, HeatParam): return new_val = ( - HeatStatus.OFF - if current.heat_status == HeatStatus.ON - else HeatStatus.ON + HeatStatus.OFF if current.heat_status == HeatStatus.ON else HeatStatus.ON ) new_param = replace(current, heat_status=new_val) label = "On" if new_val == HeatStatus.ON else "Off" self._run_command( - self.client.write_parameters( - self.fire_id, [new_param] - ), + self.client.write_parameters(self.fire_id, [new_param]), f"Setting heat to {label}...", "Heat toggle failed", ) @@ -841,9 +824,7 @@ def _on_selected( callback=_on_selected, ) - def _apply_heat_mode( - self, mode: HeatMode, boost_minutes: int | None - ) -> None: + def _apply_heat_mode(self, mode: HeatMode, boost_minutes: int | None) -> None: """Write the selected heat mode to the fireplace.""" from flameconnect.models import HeatMode, HeatParam @@ -858,9 +839,7 @@ def _apply_heat_mode( if not isinstance(current, HeatParam): return if mode == HeatMode.BOOST and boost_minutes is not None: - new_param = replace( - current, heat_mode=mode, boost_duration=boost_minutes - ) + new_param = replace(current, heat_mode=mode, boost_duration=boost_minutes) else: new_param = replace(current, heat_mode=mode) mode_label = _display_name(mode) @@ -978,6 +957,7 @@ def action_toggle_timer(self) -> None: "Timer toggle failed", ) else: + def _on_timer_dismiss(duration: int | None) -> None: if duration is not None: self.call_later(self._apply_timer, duration) @@ -1090,12 +1070,13 @@ async def _tui_auth_prompt(auth_uri: str, redirect_uri: str) -> str: import curses as _curses try: - _curses.setupterm(fd=sys.__stderr__.fileno()) - rmcup = _curses.tigetstr("rmcup") or b"" - cnorm = _curses.tigetstr("cnorm") or b"" - clear = _curses.tigetstr("clear") or b"" - sys.__stderr__.buffer.write(rmcup + cnorm + clear) - sys.__stderr__.flush() + if sys.__stderr__ is not None: + _curses.setupterm(fd=sys.__stderr__.fileno()) + rmcup = _curses.tigetstr("rmcup") or b"" + cnorm = _curses.tigetstr("cnorm") or b"" + clear = _curses.tigetstr("clear") or b"" + sys.__stderr__.buffer.write(rmcup + cnorm + clear) + sys.__stderr__.flush() except Exception: # noqa: BLE001 pass diff --git a/src/flameconnect/tui/auth_screen.py b/src/flameconnect/tui/auth_screen.py index 150977f..3634928 100644 --- a/src/flameconnect/tui/auth_screen.py +++ b/src/flameconnect/tui/auth_screen.py @@ -4,9 +4,12 @@ import logging import webbrowser +from typing import TYPE_CHECKING -from textual.app import ComposeResult from textual.containers import Vertical + +if TYPE_CHECKING: + from textual.app import ComposeResult from textual.screen import ModalScreen from textual.widgets import Button, Input, Label, Static diff --git a/src/flameconnect/tui/color_screen.py b/src/flameconnect/tui/color_screen.py index ebd3721..5101067 100644 --- a/src/flameconnect/tui/color_screen.py +++ b/src/flameconnect/tui/color_screen.py @@ -117,19 +117,27 @@ def compose(self) -> ComposeResult: with Horizontal(id="rgbw-inputs"): yield Static("R:") yield Input( - str(cur.red), id="input-r", type="integer", + str(cur.red), + id="input-r", + type="integer", ) yield Static("G:") yield Input( - str(cur.green), id="input-g", type="integer", + str(cur.green), + id="input-g", + type="integer", ) yield Static("B:") yield Input( - str(cur.blue), id="input-b", type="integer", + str(cur.blue), + id="input-b", + type="integer", ) yield Static("W:") yield Input( - str(cur.white), id="input-w", type="integer", + str(cur.white), + id="input-w", + type="integer", ) with Horizontal(id="rgbw-actions"): yield Button("Set", id="set-rgbw", variant="primary") diff --git a/src/flameconnect/tui/fire_select_screen.py b/src/flameconnect/tui/fire_select_screen.py index e273f76..5904bcf 100644 --- a/src/flameconnect/tui/fire_select_screen.py +++ b/src/flameconnect/tui/fire_select_screen.py @@ -13,6 +13,7 @@ if TYPE_CHECKING: from textual.app import ComposeResult + from textual.widgets._button import ButtonVariant _CSS = """ FireSelectScreen { @@ -84,7 +85,7 @@ def compose(self) -> ComposeResult: f" \u2014 {fire.brand} {fire.product_model}" f" ({conn})" ) - variant = ( + variant: ButtonVariant = ( "primary" if fire.fire_id == self._current_fire_id else "default" diff --git a/src/flameconnect/tui/flame_speed_screen.py b/src/flameconnect/tui/flame_speed_screen.py index 508bd02..f26862b 100644 --- a/src/flameconnect/tui/flame_speed_screen.py +++ b/src/flameconnect/tui/flame_speed_screen.py @@ -12,6 +12,7 @@ if TYPE_CHECKING: from textual.app import ComposeResult + from textual.widgets._button import ButtonVariant _CSS = """ FlameSpeedScreen { @@ -70,7 +71,9 @@ def compose(self) -> ComposeResult: ) with Horizontal(id="flame-speed-buttons"): for i in range(1, 6): - variant = "primary" if i == self._current_speed else "default" + variant: ButtonVariant = ( + "primary" if i == self._current_speed else "default" + ) yield Button(str(i), id=f"speed-{i}", variant=variant) def on_button_pressed(self, event: Button.Pressed) -> None: diff --git a/src/flameconnect/tui/heat_mode_screen.py b/src/flameconnect/tui/heat_mode_screen.py index ab696ae..ff4e8a0 100644 --- a/src/flameconnect/tui/heat_mode_screen.py +++ b/src/flameconnect/tui/heat_mode_screen.py @@ -13,6 +13,7 @@ if TYPE_CHECKING: from textual.app import ComposeResult + from textual.widgets._button import ButtonVariant _CSS = """ HeatModeScreen { @@ -92,16 +93,12 @@ def compose(self) -> ComposeResult: with Horizontal(id="heat-mode-buttons"): for mode in (HeatMode.NORMAL, HeatMode.ECO, HeatMode.BOOST): label = mode.name.replace("_", " ").title() - variant = ( + variant: ButtonVariant = ( "primary" if mode == self._current_mode else "default" ) - yield Button( - label, id=f"mode-{mode.name.lower()}", variant=variant - ) + yield Button(label, id=f"mode-{mode.name.lower()}", variant=variant) with Vertical(id="boost-input-container"): - yield Static( - "Boost duration (1-20 min):", id="boost-label" - ) + yield Static("Boost duration (1-20 min):", id="boost-label") yield Input( placeholder="minutes", id="boost-duration", diff --git a/src/flameconnect/tui/screens.py b/src/flameconnect/tui/screens.py index 8dc87c0..453f342 100644 --- a/src/flameconnect/tui/screens.py +++ b/src/flameconnect/tui/screens.py @@ -111,9 +111,7 @@ def emit(self, record: logging.LogRecord) -> None: try: ts = datetime.fromtimestamp(record.created).strftime("%H:%M:%S") msg = self.format(record) - open_tag, close_tag = _LEVEL_MARKUP.get( - record.levelno, ("", "") - ) + open_tag, close_tag = _LEVEL_MARKUP.get(record.levelno, ("", "")) self._rich_log.write( f"[dim]{ts}[/dim] {open_tag}{msg}{close_tag}", shrink=False, @@ -167,7 +165,8 @@ def on_mount(self) -> None: """Install log handler and do initial data load.""" rich_log = self.query_one("#messages-panel", RichLog) self._log_handler = _TuiLogHandler(rich_log) - self._log_handler.setFormatter(logging.Formatter("%(levelname)s %(name)s: %(message)s")) + fmt = "%(levelname)s %(name)s: %(message)s" + self._log_handler.setFormatter(logging.Formatter(fmt)) fc_logger = logging.getLogger("flameconnect") fc_logger.addHandler(self._log_handler) @@ -184,10 +183,7 @@ def _apply_compact_mode(self) -> None: except OSError: w, h = self.app.size.width, self.app.size.height - compact = ( - w < _COMPACT_THRESHOLD_WIDTH - or h < _COMPACT_THRESHOLD_HEIGHT - ) + compact = w < _COMPACT_THRESHOLD_WIDTH or h < _COMPACT_THRESHOLD_HEIGHT self.set_class(compact, "compact") # Toggle .compact class on each widget — same-element CSS @@ -297,7 +293,9 @@ def _log_param_changes( new_val = getattr(new_param, field.name) if old_val != new_val: label = field.name.replace("_", " ").title() - self.log_message(f"[bold]{name}[/bold] {label}: {old_val} → {new_val}") + self.log_message( + f"[bold]{name}[/bold] {label}: {old_val} → {new_val}" + ) @property def current_parameters(self) -> dict[type, Parameter]: diff --git a/src/flameconnect/tui/temperature_screen.py b/src/flameconnect/tui/temperature_screen.py index e3dbc2c..1a5b363 100644 --- a/src/flameconnect/tui/temperature_screen.py +++ b/src/flameconnect/tui/temperature_screen.py @@ -32,6 +32,7 @@ def _convert_to_celsius( """Convert a Fahrenheit value back to Celsius (rounded to 1 dp).""" return round((fahrenheit - 32) * 5 / 9, 1) + _CSS = """ TemperatureScreen { align: center middle; @@ -97,9 +98,7 @@ def compose(self) -> ComposeResult: unit_str = "\u00b0C" if self._unit == TempUnit.CELSIUS else "\u00b0F" celsius = self._unit == TempUnit.CELSIUS range_str = "5.0 \u2013 35.0" if celsius else "40.0 \u2013 95.0" - display_temp = _convert_temp( - self._current_temp, self._unit - ) + display_temp = _convert_temp(self._current_temp, self._unit) with Vertical(id="temp-dialog"): yield Static( f"Set Temperature (current: {display_temp}{unit_str})", diff --git a/src/flameconnect/tui/widgets.py b/src/flameconnect/tui/widgets.py index bc526ea..d854cdb 100644 --- a/src/flameconnect/tui/widgets.py +++ b/src/flameconnect/tui/widgets.py @@ -68,9 +68,7 @@ class _ClickableValue(Static): _ClickableValue.clickable:hover { background: $surface-lighten-2; } """ - def __init__( - self, content: str, action: str | None = None, **kwargs: Any - ) -> None: + def __init__(self, content: str, action: str | None = None, **kwargs: Any) -> None: super().__init__(content, **kwargs) self._action = action if action: @@ -105,9 +103,7 @@ def __init__( def compose(self) -> ComposeResult: """Compose the label and value children.""" yield Static(self._label, classes="param-label") - yield _ClickableValue( - self._value, action=self._action - ) + yield _ClickableValue(self._value, action=self._action) def _display_name(value: IntEnum) -> str: @@ -117,10 +113,7 @@ def _display_name(value: IntEnum) -> str: def _format_rgbw(color: RGBWColor) -> str: """Format an RGBW color value for display.""" - return ( - f"R:{color.red} G:{color.green} " - f"B:{color.blue} W:{color.white}" - ) + return f"R:{color.red} G:{color.green} B:{color.blue} W:{color.white}" _MODE_DISPLAY: dict[FireMode, str] = { @@ -136,9 +129,7 @@ def _temp_suffix(temp_unit: TempUnitParam | None) -> str: return "C" if temp_unit.unit == TempUnit.CELSIUS else "F" -def _convert_temp( - celsius: float, unit: TempUnit -) -> float: +def _convert_temp(celsius: float, unit: TempUnit) -> float: """Convert a Celsius temperature for display. Returns the value unchanged when *unit* is CELSIUS, or @@ -157,16 +148,10 @@ def _format_mode( Returns a list of (label, value, action) tuples. """ - mode_label = _MODE_DISPLAY.get( - param.mode, _display_name(param.mode) - ) + mode_label = _MODE_DISPLAY.get(param.mode, _display_name(param.mode)) suffix = _temp_suffix(temp_unit) - unit = ( - temp_unit.unit if temp_unit else TempUnit.CELSIUS - ) - display_temp = _convert_temp( - param.target_temperature, unit - ) + unit = temp_unit.unit if temp_unit else TempUnit.CELSIUS + display_temp = _convert_temp(param.target_temperature, unit) return [ ( "[bold]Mode:[/bold] ", @@ -258,17 +243,11 @@ def _format_heat( from flameconnect.models import HeatMode boost_value = ( - f"{param.boost_duration}min" - if param.heat_mode == HeatMode.BOOST - else "Off" + f"{param.boost_duration}min" if param.heat_mode == HeatMode.BOOST else "Off" ) suffix = _temp_suffix(temp_unit) - unit = ( - temp_unit.unit if temp_unit else TempUnit.CELSIUS - ) - display_temp = _convert_temp( - param.setpoint_temperature, unit - ) + unit = temp_unit.unit if temp_unit else TempUnit.CELSIUS + display_temp = _convert_temp(param.setpoint_temperature, unit) return [ ( "[bold]Heat:[/bold] ", @@ -316,20 +295,10 @@ def _format_timer( from flameconnect.models import TimerStatus - value = ( - f"{_display_name(param.timer_status)}" - f" Duration: {param.duration}min" - ) - if ( - param.timer_status == TimerStatus.ENABLED - and param.duration > 0 - ): - off_time = datetime.now() + timedelta( - minutes=param.duration - ) - value += ( - f" Off at {off_time.strftime('%H:%M')}" - ) + value = f"{_display_name(param.timer_status)} Duration: {param.duration}min" + if param.timer_status == TimerStatus.ENABLED and param.duration > 0: + off_time = datetime.now() + timedelta(minutes=param.duration) + value += f" Off at {off_time.strftime('%H:%M')}" return [("[bold]Timer:[/bold] ", value, "toggle_timer")] @@ -417,8 +386,7 @@ def _format_sound( return [ ( "[bold]Sound:[/bold] ", - f"Volume {param.volume}" - f" File: {param.sound_file}", + f"Volume {param.volume} File: {param.sound_file}", None, ), ] @@ -474,50 +442,28 @@ def format_parameters( break # Collect formatted tuples keyed by type. - formatted: dict[ - type, list[tuple[str, str, str | None]] - ] = {} + formatted: dict[type, list[tuple[str, str, str | None]]] = {} for param in params: if isinstance(param, ModeParam): - formatted[ModeParam] = _format_mode( - param, temp_unit - ) + formatted[ModeParam] = _format_mode(param, temp_unit) elif isinstance(param, HeatParam): - formatted[HeatParam] = _format_heat( - param, temp_unit - ) + formatted[HeatParam] = _format_heat(param, temp_unit) elif isinstance(param, HeatModeParam): - formatted[HeatModeParam] = ( - _format_heat_mode(param) - ) + formatted[HeatModeParam] = _format_heat_mode(param) elif isinstance(param, FlameEffectParam): - formatted[FlameEffectParam] = ( - _format_flame_effect(param) - ) + formatted[FlameEffectParam] = _format_flame_effect(param) elif isinstance(param, TimerParam): - formatted[TimerParam] = ( - _format_timer(param) - ) + formatted[TimerParam] = _format_timer(param) elif isinstance(param, SoftwareVersionParam): - formatted[SoftwareVersionParam] = ( - _format_software_version(param) - ) + formatted[SoftwareVersionParam] = _format_software_version(param) elif isinstance(param, ErrorParam): - formatted[ErrorParam] = ( - _format_error(param) - ) + formatted[ErrorParam] = _format_error(param) elif isinstance(param, TempUnitParam): - formatted[TempUnitParam] = ( - _format_temp_unit(param) - ) + formatted[TempUnitParam] = _format_temp_unit(param) elif isinstance(param, SoundParam): - formatted[SoundParam] = ( - _format_sound(param) - ) + formatted[SoundParam] = _format_sound(param) elif isinstance(param, LogEffectParam): - formatted[LogEffectParam] = ( - _format_log_effect(param) - ) + formatted[LogEffectParam] = _format_log_effect(param) # Desired display order (ErrorParam last). display_order: list[type] = [ @@ -538,9 +484,7 @@ def format_parameters( result.extend(formatted[t]) if not result: - result.append( - ("[dim]No parameters available[/dim]", "", None) - ) + result.append(("[dim]No parameters available[/dim]", "", None)) return result @@ -618,50 +562,115 @@ def _rgbw_to_style(color: RGBWColor) -> str: # atoms: (text, trailing_gap_weight) _FLAME_DEFS: list[tuple[float, int, list[tuple[str, int]]]] = [ # Row 0: Sparse tips - (0.95, 0, [ - ("( )", 3), (",", 5), (")", 5), - (",", 5), ("( )", 3), (",", 5), ("( )", 0), - ]), + ( + 0.95, + 0, + [ + ("( )", 3), + (",", 5), + (")", 5), + (",", 5), + ("( )", 3), + (",", 5), + ("( )", 0), + ], + ), # Row 1: Forming columns - (0.95, 0, [ - ("( \\ )", 2), ("( | )", 2), ("( )", 1), - ("( \\)", 2), ("( | )", 2), - ("( | )", 1), ("( )", 0), - ]), + ( + 0.95, + 0, + [ + ("( \\ )", 2), + ("( | )", 2), + ("( )", 1), + ("( \\)", 2), + ("( | )", 2), + ("( | )", 1), + ("( )", 0), + ], + ), # Row 2: Growing - (0.95, 1, [ - ("( \\ \\ )", 2), ("( | | )", 2), ("( \\ )", 1), - ("( \\ \\)", 2), ("( / | )", 2), - ("( / | )", 1), ("( )", 0), - ]), + ( + 0.95, + 1, + [ + ("( \\ \\ )", 2), + ("( | | )", 2), + ("( \\ )", 1), + ("( \\ \\)", 2), + ("( / | )", 2), + ("( / | )", 1), + ("( )", 0), + ], + ), # Row 3: Full width - (0.95, 1, [ - ("( \\\\ \\)", 2), ("( || |)", 2), ("( \\\\ )", 1), - ("( \\\\ )", 2), ("( /| |)", 2), - ("( /| | )", 1), ("( )", 0), - ]), + ( + 0.95, + 1, + [ + ("( \\\\ \\)", 2), + ("( || |)", 2), + ("( \\\\ )", 1), + ("( \\\\ )", 2), + ("( /| |)", 2), + ("( /| | )", 1), + ("( )", 0), + ], + ), # Row 4: Dense mid - (0.95, 1, [ - ("( \\\\ )", 2), ("( || )", 2), ("( \\\\ )", 1), - ("( || )", 2), ("( /| )", 2), - ("( /| |)", 1), ("( )", 0), - ]), + ( + 0.95, + 1, + [ + ("( \\\\ )", 2), + ("( || )", 2), + ("( \\\\ )", 1), + ("( || )", 2), + ("( /| )", 2), + ("( /| |)", 1), + ("( )", 0), + ], + ), # Row 5: Narrowing - (0.95, 2, [ - ("( \\\\)", 2), ("( //)", 2), ("( \\\\ )", 1), - ("( || )", 2), ("( // )", 2), - ("( // )", 1), ("( )", 0), - ]), + ( + 0.95, + 2, + [ + ("( \\\\)", 2), + ("( //)", 2), + ("( \\\\ )", 1), + ("( || )", 2), + ("( // )", 2), + ("( // )", 1), + ("( )", 0), + ], + ), # Row 6: Base - (0.95, 2, [ - ("(\\)", 2), ("(/)", 2), ("(\\)(|)", 3), - ("(/)", 2), ("(/)", 2), ("()", 0), - ]), + ( + 0.95, + 2, + [ + ("(\\)", 2), + ("(/)", 2), + ("(\\)(|)", 3), + ("(/)", 2), + ("(/)", 2), + ("()", 0), + ], + ), # Row 7: Base - (0.95, 2, [ - ("(\\)", 2), ("(/)", 2), ("(\\|/)", 4), - ("(/)", 2), ("(/)", 2), ("()", 0), - ]), + ( + 0.95, + 2, + [ + ("(\\)", 2), + ("(/)", 2), + ("(\\|/)", 4), + ("(/)", 2), + ("(/)", 2), + ("()", 0), + ], + ), ] @@ -729,8 +738,8 @@ def _build_fire_art( zone (reducing the flame row budget to keep total height constant). """ - ow = w - 2 # fill width between outer frame borders - iw = w - 4 # content width between inner frame borders + ow = w - 2 # fill width between outer frame borders + iw = w - 4 # content width between inner frame borders # Rotate palette for animation palette = _rotate_palette(flame_palette, anim_frame) @@ -745,9 +754,7 @@ def _nl() -> None: # Reserve rows for heat indicators (reduce flame budget) heat_row_count = _HEAT_ROWS if heat_on else 0 - flame_rows_effective = max( - flame_rows - heat_row_count, _MIN_FLAME_ROWS - ) + flame_rows_effective = max(flame_rows - heat_row_count, _MIN_FLAME_ROWS) num_defs = len(_FLAME_DEFS) if flame_rows_effective >= num_defs: @@ -755,16 +762,14 @@ def _nl() -> None: defs_to_render = _FLAME_DEFS else: blank_above = 0 - defs_to_render = _FLAME_DEFS[ - num_defs - flame_rows_effective: - ] + defs_to_render = _FLAME_DEFS[num_defs - flame_rows_effective :] # -- heat indicator rows (above frame) -- if heat_on: # Alternate two wave patterns for visual variety wave_chars = [ - "\u2248" * ow, # (approx-equal signs) - "~" * ow, # (tildes) + "\u2248" * ow, # (approx-equal signs) + "~" * ow, # (tildes) ] actual_heat_rows = flame_rows - flame_rows_effective for i in range(actual_heat_rows): @@ -872,9 +877,7 @@ def update_state( if heat_param is not None: from flameconnect.models import HeatStatus - self._heat_on = ( - heat_param.heat_status == HeatStatus.ON - ) + self._heat_on = heat_param.heat_status == HeatStatus.ON else: self._heat_on = False @@ -898,20 +901,13 @@ def update_state( if fire_on: speed_changed = new_speed != self._flame_speed self._flame_speed = new_speed - if ( - self._anim_timer is None - or speed_changed - ): + if self._anim_timer is None or speed_changed: # Cancel existing timer if any if self._anim_timer is not None: self._anim_timer.stop() self._anim_timer = None - interval = _FLAME_SPEED_INTERVALS.get( - self._flame_speed, 0.3 - ) - self._anim_timer = self.set_interval( - interval, self._advance_frame - ) + interval = _FLAME_SPEED_INTERVALS.get(self._flame_speed, 0.3) + self._anim_timer = self.set_interval(interval, self._advance_frame) else: # Fire is off -- stop animation if self._anim_timer is not None: @@ -944,16 +940,10 @@ def render(self) -> _Text: # LED and media styling applies when power is on, # regardless of flame effect state. if power_on and flame_effect is not None: - palette = _FLAME_PALETTES.get( - flame_effect.flame_color, _DEFAULT_PALETTE - ) + palette = _FLAME_PALETTES.get(flame_effect.flame_color, _DEFAULT_PALETTE) if flame_effect.light_status == LightStatus.ON: - led_style = _rgbw_to_style( - flame_effect.overhead_color - ) - media_style = _rgbw_to_style( - flame_effect.media_color - ) + led_style = _rgbw_to_style(flame_effect.overhead_color) + media_style = _rgbw_to_style(flame_effect.media_color) # Flames are only visible when power is on AND # flame effect is ON (or not yet received). @@ -966,7 +956,8 @@ def render(self) -> _Text: fire_on = False return _build_fire_art( - w, h, + w, + h, fire_on=fire_on, flame_palette=palette, led_style=led_style, @@ -983,9 +974,7 @@ def compose(self) -> ComposeResult: """Initial composition -- a loading placeholder.""" yield Static("[dim]Loading...[/dim]") - def update_parameters( - self, params: list[Parameter] - ) -> None: + def update_parameters(self, params: list[Parameter]) -> None: """Update the panel with new parameter data. Clears existing children and mounts new @@ -997,9 +986,7 @@ def update_parameters( fields = format_parameters(params) widgets: list[ClickableParam] = [] for label, value, action in fields: - widgets.append( - ClickableParam(label, value, action=action) - ) + widgets.append(ClickableParam(label, value, action=action)) self.query("*").remove() self.mount(*widgets) diff --git a/tests/test_auth.py b/tests/test_auth.py index 2a773d7..28565b2 100644 --- a/tests/test_auth.py +++ b/tests/test_auth.py @@ -66,9 +66,7 @@ async def test_silent_acquisition(self, mock_msal, tmp_path): mock_msal.SerializableTokenCache.return_value = mock_cache mock_app = MagicMock() - mock_app.get_accounts.return_value = [ - {"username": "user@example.com"} - ] + mock_app.get_accounts.return_value = [{"username": "user@example.com"}] mock_app.acquire_token_silent.return_value = { "access_token": "cached-token-789" } @@ -81,9 +79,7 @@ async def test_silent_acquisition(self, mock_msal, tmp_path): mock_app.acquire_token_silent.assert_called_once() @patch("flameconnect.auth.msal") - async def test_no_accounts_triggers_interactive( - self, mock_msal, tmp_path - ): + async def test_no_accounts_triggers_interactive(self, mock_msal, tmp_path): """When no cached accounts exist, interactive flow is started.""" cache_path = tmp_path / "token_cache.json" @@ -101,23 +97,17 @@ async def test_no_accounts_triggers_interactive( } mock_msal.PublicClientApplication.return_value = mock_app - async def fake_prompt( - auth_uri: str, redirect_uri: str - ) -> str: + async def fake_prompt(auth_uri: str, redirect_uri: str) -> str: return "https://redirect?code=abc123" - auth = MsalAuth( - cache_path=cache_path, prompt_callback=fake_prompt - ) + auth = MsalAuth(cache_path=cache_path, prompt_callback=fake_prompt) token = await auth.get_token() assert token == "interactive-token-abc" @patch("flameconnect.auth.msal") - async def test_cache_saved_on_state_change( - self, mock_msal, tmp_path - ): + async def test_cache_saved_on_state_change(self, mock_msal, tmp_path): """Cache is written to disk when has_state_changed is True.""" cache_path = tmp_path / "token_cache.json" @@ -127,12 +117,8 @@ async def test_cache_saved_on_state_change( mock_msal.SerializableTokenCache.return_value = mock_cache mock_app = MagicMock() - mock_app.get_accounts.return_value = [ - {"username": "user@example.com"} - ] - mock_app.acquire_token_silent.return_value = { - "access_token": "refreshed-token" - } + mock_app.get_accounts.return_value = [{"username": "user@example.com"}] + mock_app.acquire_token_silent.return_value = {"access_token": "refreshed-token"} mock_msal.PublicClientApplication.return_value = mock_app auth = MsalAuth(cache_path=cache_path) @@ -151,31 +137,22 @@ class TestParseRedirectUrl: """Test the URL parsing helper.""" def test_parses_code(self): - result = MsalAuth._parse_redirect_url( - "https://redirect?code=abc123&state=xyz" - ) + result = MsalAuth._parse_redirect_url("https://redirect?code=abc123&state=xyz") assert result["code"] == "abc123" assert result["state"] == "xyz" def test_ellipsis_raises(self): with pytest.raises(AuthenticationError, match="ellipsis"): - MsalAuth._parse_redirect_url( - "https://redirect?code=abc\u2026def" - ) + MsalAuth._parse_redirect_url("https://redirect?code=abc\u2026def") def test_no_code_raises(self): - with pytest.raises( - AuthenticationError, match="No authorization code" - ): - MsalAuth._parse_redirect_url( - "https://redirect?state=xyz" - ) + with pytest.raises(AuthenticationError, match="No authorization code"): + MsalAuth._parse_redirect_url("https://redirect?state=xyz") def test_error_in_url_raises(self): with pytest.raises(AuthenticationError, match="Auth error"): MsalAuth._parse_redirect_url( - "https://redirect?error=access_denied" - "&error_description=User+cancelled" + "https://redirect?error=access_denied&error_description=User+cancelled" ) def test_fragment_url_parsing(self): @@ -188,12 +165,8 @@ def test_fragment_url_parsing(self): def test_error_without_description_raises(self): """URL with error but no error_description still raises.""" - with pytest.raises( - AuthenticationError, match="Auth error.*server_error" - ): - MsalAuth._parse_redirect_url( - "https://redirect?error=server_error" - ) + with pytest.raises(AuthenticationError, match="Auth error.*server_error"): + MsalAuth._parse_redirect_url("https://redirect?error=server_error") # --------------------------------------------------------------------------- @@ -205,9 +178,7 @@ class TestMsalAuthEdgeCases: """Test error paths in MsalAuth.get_token / _interactive_flow.""" @patch("flameconnect.auth.msal") - async def test_existing_cache_is_loaded( - self, mock_msal, tmp_path - ): + async def test_existing_cache_is_loaded(self, mock_msal, tmp_path): """When the cache file exists, deserialize is called.""" cache_path = tmp_path / "token_cache.json" cache_path.write_text('{"cached": "data"}') @@ -217,26 +188,18 @@ async def test_existing_cache_is_loaded( mock_msal.SerializableTokenCache.return_value = mock_cache mock_app = MagicMock() - mock_app.get_accounts.return_value = [ - {"username": "u@ex.com"} - ] - mock_app.acquire_token_silent.return_value = { - "access_token": "loaded-token" - } + mock_app.get_accounts.return_value = [{"username": "u@ex.com"}] + mock_app.acquire_token_silent.return_value = {"access_token": "loaded-token"} mock_msal.PublicClientApplication.return_value = mock_app auth = MsalAuth(cache_path=cache_path) token = await auth.get_token() assert token == "loaded-token" - mock_cache.deserialize.assert_called_once_with( - '{"cached": "data"}' - ) + mock_cache.deserialize.assert_called_once_with('{"cached": "data"}') @patch("flameconnect.auth.msal") - async def test_no_auth_uri_in_flow_raises( - self, mock_msal, tmp_path - ): + async def test_no_auth_uri_in_flow_raises(self, mock_msal, tmp_path): """initiate_auth_code_flow with no auth_uri raises.""" cache_path = tmp_path / "token_cache.json" @@ -251,24 +214,16 @@ async def test_no_auth_uri_in_flow_raises( } mock_msal.PublicClientApplication.return_value = mock_app - async def fake_prompt( - auth_uri: str, redirect_uri: str - ) -> str: + async def fake_prompt(auth_uri: str, redirect_uri: str) -> str: return "https://redirect?code=abc123" - auth = MsalAuth( - cache_path=cache_path, prompt_callback=fake_prompt - ) + auth = MsalAuth(cache_path=cache_path, prompt_callback=fake_prompt) - with pytest.raises( - AuthenticationError, match="Failed to initiate" - ): + with pytest.raises(AuthenticationError, match="Failed to initiate"): await auth.get_token() @patch("flameconnect.auth.msal") - async def test_no_prompt_callback_raises( - self, mock_msal, tmp_path - ): + async def test_no_prompt_callback_raises(self, mock_msal, tmp_path): """No prompt_callback provided raises on interactive login.""" cache_path = tmp_path / "token_cache.json" @@ -285,15 +240,11 @@ async def test_no_prompt_callback_raises( auth = MsalAuth(cache_path=cache_path) - with pytest.raises( - AuthenticationError, match="no prompt_callback" - ): + with pytest.raises(AuthenticationError, match="no prompt_callback"): await auth.get_token() @patch("flameconnect.auth.msal") - async def test_token_exchange_failure_raises( - self, mock_msal, tmp_path - ): + async def test_token_exchange_failure_raises(self, mock_msal, tmp_path): """acquire_token_by_auth_code_flow with no access_token.""" cache_path = tmp_path / "token_cache.json" @@ -312,18 +263,12 @@ async def test_token_exchange_failure_raises( } mock_msal.PublicClientApplication.return_value = mock_app - async def fake_prompt( - auth_uri: str, redirect_uri: str - ) -> str: + async def fake_prompt(auth_uri: str, redirect_uri: str) -> str: return "https://redirect?code=expired-code" - auth = MsalAuth( - cache_path=cache_path, prompt_callback=fake_prompt - ) + auth = MsalAuth(cache_path=cache_path, prompt_callback=fake_prompt) - with pytest.raises( - AuthenticationError, match="Token exchange failed" - ): + with pytest.raises(AuthenticationError, match="Token exchange failed"): await auth.get_token() @@ -336,9 +281,7 @@ class TestBuildAppArgs: """Verify _build_app passes correct args to MSAL.""" @patch("flameconnect.auth.msal") - async def test_build_app_passes_client_id( - self, mock_msal, tmp_path - ): + async def test_build_app_passes_client_id(self, mock_msal, tmp_path): """PublicClientApplication receives CLIENT_ID as first arg.""" cache_path = tmp_path / "token_cache.json" @@ -348,9 +291,7 @@ async def test_build_app_passes_client_id( mock_app = MagicMock() mock_app.get_accounts.return_value = [{"username": "u"}] - mock_app.acquire_token_silent.return_value = { - "access_token": "t" - } + mock_app.acquire_token_silent.return_value = {"access_token": "t"} mock_msal.PublicClientApplication.return_value = mock_app auth = MsalAuth(cache_path=cache_path) @@ -364,9 +305,7 @@ async def test_build_app_passes_client_id( ) @patch("flameconnect.auth.msal") - async def test_build_app_cache_passed_as_token_cache( - self, mock_msal, tmp_path - ): + async def test_build_app_cache_passed_as_token_cache(self, mock_msal, tmp_path): """The SerializableTokenCache instance is used as token_cache.""" cache_path = tmp_path / "token_cache.json" @@ -376,9 +315,7 @@ async def test_build_app_cache_passed_as_token_cache( mock_app = MagicMock() mock_app.get_accounts.return_value = [{"username": "u"}] - mock_app.acquire_token_silent.return_value = { - "access_token": "t" - } + mock_app.acquire_token_silent.return_value = {"access_token": "t"} mock_msal.PublicClientApplication.return_value = mock_app auth = MsalAuth(cache_path=cache_path) @@ -399,9 +336,7 @@ class TestSilentAcquisitionArgs: """Verify acquire_token_silent receives correct args.""" @patch("flameconnect.auth.msal") - async def test_silent_uses_scopes_and_first_account( - self, mock_msal, tmp_path - ): + async def test_silent_uses_scopes_and_first_account(self, mock_msal, tmp_path): """acquire_token_silent gets SCOPES and accounts[0].""" cache_path = tmp_path / "token_cache.json" @@ -412,22 +347,16 @@ async def test_silent_uses_scopes_and_first_account( acct = {"username": "user@example.com"} mock_app = MagicMock() mock_app.get_accounts.return_value = [acct] - mock_app.acquire_token_silent.return_value = { - "access_token": "tok" - } + mock_app.acquire_token_silent.return_value = {"access_token": "tok"} mock_msal.PublicClientApplication.return_value = mock_app auth = MsalAuth(cache_path=cache_path) await auth.get_token() - mock_app.acquire_token_silent.assert_called_once_with( - SCOPES, account=acct - ) + mock_app.acquire_token_silent.assert_called_once_with(SCOPES, account=acct) @patch("flameconnect.auth.msal") - async def test_silent_none_result_falls_through( - self, mock_msal, tmp_path - ): + async def test_silent_none_result_falls_through(self, mock_msal, tmp_path): """When acquire_token_silent returns None, fall to interactive.""" cache_path = tmp_path / "token_cache.json" @@ -446,23 +375,17 @@ async def test_silent_none_result_falls_through( } mock_msal.PublicClientApplication.return_value = mock_app - async def fake_prompt( - auth_uri: str, redirect_uri: str - ) -> str: + async def fake_prompt(auth_uri: str, redirect_uri: str) -> str: return "https://redirect?code=abc123" - auth = MsalAuth( - cache_path=cache_path, prompt_callback=fake_prompt - ) + auth = MsalAuth(cache_path=cache_path, prompt_callback=fake_prompt) token = await auth.get_token() assert token == "interactive-tok" mock_app.initiate_auth_code_flow.assert_called_once() @patch("flameconnect.auth.msal") - async def test_silent_result_without_access_token_key( - self, mock_msal, tmp_path - ): + async def test_silent_result_without_access_token_key(self, mock_msal, tmp_path): """Result dict without 'access_token' falls to interactive.""" cache_path = tmp_path / "token_cache.json" @@ -473,9 +396,7 @@ async def test_silent_result_without_access_token_key( mock_app = MagicMock() mock_app.get_accounts.return_value = [{"username": "u"}] # Result is truthy but has no access_token key - mock_app.acquire_token_silent.return_value = { - "error": "interaction_required" - } + mock_app.acquire_token_silent.return_value = {"error": "interaction_required"} mock_app.initiate_auth_code_flow.return_value = { "auth_uri": "https://example.com/auth", } @@ -484,14 +405,10 @@ async def test_silent_result_without_access_token_key( } mock_msal.PublicClientApplication.return_value = mock_app - async def fake_prompt( - auth_uri: str, redirect_uri: str - ) -> str: + async def fake_prompt(auth_uri: str, redirect_uri: str) -> str: return "https://redirect?code=abc123" - auth = MsalAuth( - cache_path=cache_path, prompt_callback=fake_prompt - ) + auth = MsalAuth(cache_path=cache_path, prompt_callback=fake_prompt) token = await auth.get_token() assert token == "fallback-tok" @@ -533,18 +450,12 @@ async def test_initiate_flow_receives_scopes_and_redirect( self, mock_msal, tmp_path ): """initiate_auth_code_flow gets SCOPES and _REDIRECT_URI.""" - cache_path, _, mock_app, _ = self._setup_interactive_mocks( - mock_msal, tmp_path - ) + cache_path, _, mock_app, _ = self._setup_interactive_mocks(mock_msal, tmp_path) - async def fake_prompt( - auth_uri: str, redirect_uri: str - ) -> str: + async def fake_prompt(auth_uri: str, redirect_uri: str) -> str: return "https://redirect?code=c1" - auth = MsalAuth( - cache_path=cache_path, prompt_callback=fake_prompt - ) + auth = MsalAuth(cache_path=cache_path, prompt_callback=fake_prompt) await auth.get_token() mock_app.initiate_auth_code_flow.assert_called_once_with( @@ -557,32 +468,22 @@ async def test_prompt_callback_receives_auth_and_redirect_uri( self, mock_msal, tmp_path ): """prompt_callback gets auth_uri from flow and _REDIRECT_URI.""" - cache_path, _, mock_app, _ = self._setup_interactive_mocks( - mock_msal, tmp_path - ) + cache_path, _, mock_app, _ = self._setup_interactive_mocks(mock_msal, tmp_path) received_args: list[tuple[str, str]] = [] - async def capturing_prompt( - auth_uri: str, redirect_uri: str - ) -> str: + async def capturing_prompt(auth_uri: str, redirect_uri: str) -> str: received_args.append((auth_uri, redirect_uri)) return "https://redirect?code=c1" - auth = MsalAuth( - cache_path=cache_path, prompt_callback=capturing_prompt - ) + auth = MsalAuth(cache_path=cache_path, prompt_callback=capturing_prompt) await auth.get_token() assert len(received_args) == 1 - assert received_args[0][0] == ( - "https://example.com/auth?foo=bar" - ) + assert received_args[0][0] == ("https://example.com/auth?foo=bar") assert received_args[0][1] == _REDIRECT_URI @patch("flameconnect.auth.msal") - async def test_acquire_token_by_auth_code_flow_args( - self, mock_msal, tmp_path - ): + async def test_acquire_token_by_auth_code_flow_args(self, mock_msal, tmp_path): """acquire_token_by_auth_code_flow gets flow and parsed URL.""" ( cache_path, @@ -591,41 +492,29 @@ async def test_acquire_token_by_auth_code_flow_args( flow_dict, ) = self._setup_interactive_mocks(mock_msal, tmp_path) - async def fake_prompt( - auth_uri: str, redirect_uri: str - ) -> str: + async def fake_prompt(auth_uri: str, redirect_uri: str) -> str: return "https://redirect?code=c1&state=s1" - auth = MsalAuth( - cache_path=cache_path, prompt_callback=fake_prompt - ) + auth = MsalAuth(cache_path=cache_path, prompt_callback=fake_prompt) await auth.get_token() - args = ( - mock_app.acquire_token_by_auth_code_flow.call_args - ) + args = mock_app.acquire_token_by_auth_code_flow.call_args # First positional arg is the flow dict assert args[0][0] is flow_dict # Second positional arg is parsed redirect URL dict assert args[0][1] == {"code": "c1", "state": "s1"} @patch("flameconnect.auth.msal") - async def test_interactive_flow_saves_cache( - self, mock_msal, tmp_path - ): + async def test_interactive_flow_saves_cache(self, mock_msal, tmp_path): """After successful interactive flow, cache is saved.""" - cache_path, mock_cache, mock_app, _ = ( - self._setup_interactive_mocks(mock_msal, tmp_path) + cache_path, mock_cache, mock_app, _ = self._setup_interactive_mocks( + mock_msal, tmp_path ) - async def fake_prompt( - auth_uri: str, redirect_uri: str - ) -> str: + async def fake_prompt(auth_uri: str, redirect_uri: str) -> str: return "https://redirect?code=c1" - auth = MsalAuth( - cache_path=cache_path, prompt_callback=fake_prompt - ) + auth = MsalAuth(cache_path=cache_path, prompt_callback=fake_prompt) token = await auth.get_token() assert token == "new-token-xyz" @@ -633,29 +522,19 @@ async def fake_prompt( assert cache_path.exists() @patch("flameconnect.auth.msal") - async def test_interactive_flow_strips_redirect_response( - self, mock_msal, tmp_path - ): + async def test_interactive_flow_strips_redirect_response(self, mock_msal, tmp_path): """redirect_response is stripped before parsing.""" - cache_path, _, mock_app, _ = self._setup_interactive_mocks( - mock_msal, tmp_path - ) + cache_path, _, mock_app, _ = self._setup_interactive_mocks(mock_msal, tmp_path) - async def fake_prompt( - auth_uri: str, redirect_uri: str - ) -> str: + async def fake_prompt(auth_uri: str, redirect_uri: str) -> str: return " https://redirect?code=c1 \n" - auth = MsalAuth( - cache_path=cache_path, prompt_callback=fake_prompt - ) + auth = MsalAuth(cache_path=cache_path, prompt_callback=fake_prompt) token = await auth.get_token() assert token == "new-token-xyz" @patch("flameconnect.auth.msal") - async def test_auth_uri_from_flow_used_in_prompt( - self, mock_msal, tmp_path - ): + async def test_auth_uri_from_flow_used_in_prompt(self, mock_msal, tmp_path): """The auth_uri extracted from flow dict is passed to prompt.""" cache_path = tmp_path / "token_cache.json" @@ -669,16 +548,12 @@ async def test_auth_uri_from_flow_used_in_prompt( mock_app.initiate_auth_code_flow.return_value = { "auth_uri": specific_uri, } - mock_app.acquire_token_by_auth_code_flow.return_value = { - "access_token": "t" - } + mock_app.acquire_token_by_auth_code_flow.return_value = {"access_token": "t"} mock_msal.PublicClientApplication.return_value = mock_app captured_uri = None - async def capturing_prompt( - auth_uri: str, redirect_uri: str - ) -> str: + async def capturing_prompt(auth_uri: str, redirect_uri: str) -> str: nonlocal captured_uri captured_uri = auth_uri return "https://redirect?code=c1" @@ -722,14 +597,10 @@ async def test_error_and_description_in_exception_message( } mock_msal.PublicClientApplication.return_value = mock_app - async def fake_prompt( - auth_uri: str, redirect_uri: str - ) -> str: + async def fake_prompt(auth_uri: str, redirect_uri: str) -> str: return "https://redirect?code=expired-code" - auth = MsalAuth( - cache_path=cache_path, prompt_callback=fake_prompt - ) + auth = MsalAuth(cache_path=cache_path, prompt_callback=fake_prompt) with pytest.raises(AuthenticationError) as exc_info: await auth.get_token() @@ -739,9 +610,7 @@ async def fake_prompt( assert "Code expired" in msg @patch("flameconnect.auth.msal") - async def test_exchange_error_defaults_when_missing( - self, mock_msal, tmp_path - ): + async def test_exchange_error_defaults_when_missing(self, mock_msal, tmp_path): """Default error='unknown' and description='N/A' used.""" cache_path = tmp_path / "token_cache.json" @@ -760,14 +629,10 @@ async def test_exchange_error_defaults_when_missing( } mock_msal.PublicClientApplication.return_value = mock_app - async def fake_prompt( - auth_uri: str, redirect_uri: str - ) -> str: + async def fake_prompt(auth_uri: str, redirect_uri: str) -> str: return "https://redirect?code=bad-code" - auth = MsalAuth( - cache_path=cache_path, prompt_callback=fake_prompt - ) + auth = MsalAuth(cache_path=cache_path, prompt_callback=fake_prompt) with pytest.raises(AuthenticationError) as exc_info: await auth.get_token() @@ -808,9 +673,7 @@ class TestSaveCacheNotCalledWhenUnchanged: """Verify cache is NOT saved when has_state_changed is False.""" @patch("flameconnect.auth.msal") - async def test_cache_not_written_when_unchanged( - self, mock_msal, tmp_path - ): + async def test_cache_not_written_when_unchanged(self, mock_msal, tmp_path): """When has_state_changed is False, file is NOT written.""" cache_path = tmp_path / "token_cache.json" @@ -820,9 +683,7 @@ async def test_cache_not_written_when_unchanged( mock_app = MagicMock() mock_app.get_accounts.return_value = [{"username": "u"}] - mock_app.acquire_token_silent.return_value = { - "access_token": "t" - } + mock_app.acquire_token_silent.return_value = {"access_token": "t"} mock_msal.PublicClientApplication.return_value = mock_app auth = MsalAuth(cache_path=cache_path) @@ -841,9 +702,7 @@ class TestMsalAuthLogMessages: """Verify exact log messages emitted by MsalAuth.""" @patch("flameconnect.auth.msal") - async def test_silent_acquisition_log_messages( - self, mock_msal, tmp_path, caplog - ): + async def test_silent_acquisition_log_messages(self, mock_msal, tmp_path, caplog): """Verify all log messages for silent token acquisition.""" cache_path = tmp_path / "token_cache.json" @@ -853,9 +712,7 @@ async def test_silent_acquisition_log_messages( mock_app = MagicMock() mock_app.get_accounts.return_value = [{"username": "u"}] - mock_app.acquire_token_silent.return_value = { - "access_token": "tok" - } + mock_app.acquire_token_silent.return_value = {"access_token": "tok"} mock_msal.PublicClientApplication.return_value = mock_app auth = MsalAuth(cache_path=cache_path) @@ -863,19 +720,11 @@ async def test_silent_acquisition_log_messages( await auth.get_token() messages = [r.message for r in caplog.records] - assert any( - m == "Attempting silent token acquisition" - for m in messages - ) - assert any( - m.startswith("Token acquired silently") - for m in messages - ) + assert any(m == "Attempting silent token acquisition" for m in messages) + assert any(m.startswith("Token acquired silently") for m in messages) @patch("flameconnect.auth.msal") - async def test_interactive_flow_log_messages( - self, mock_msal, tmp_path, caplog - ): + async def test_interactive_flow_log_messages(self, mock_msal, tmp_path, caplog): """Verify all log messages for interactive flow.""" cache_path = tmp_path / "token_cache.json" @@ -889,39 +738,23 @@ async def test_interactive_flow_log_messages( mock_app.initiate_auth_code_flow.return_value = { "auth_uri": "https://example.com/auth", } - mock_app.acquire_token_by_auth_code_flow.return_value = { - "access_token": "tok" - } + mock_app.acquire_token_by_auth_code_flow.return_value = {"access_token": "tok"} mock_msal.PublicClientApplication.return_value = mock_app - async def fake_prompt( - auth_uri: str, redirect_uri: str - ) -> str: + async def fake_prompt(auth_uri: str, redirect_uri: str) -> str: return "https://redirect?code=c1" - auth = MsalAuth( - cache_path=cache_path, prompt_callback=fake_prompt - ) + auth = MsalAuth(cache_path=cache_path, prompt_callback=fake_prompt) with caplog.at_level(logging.DEBUG): await auth.get_token() messages = [r.message for r in caplog.records] - assert any( - m.startswith("No cached token") for m in messages - ) - assert any( - m.startswith("Exchanging authorization code") - for m in messages - ) - assert any( - m.startswith("Authentication successful") - for m in messages - ) + assert any(m.startswith("No cached token") for m in messages) + assert any(m.startswith("Exchanging authorization code") for m in messages) + assert any(m.startswith("Authentication successful") for m in messages) @patch("flameconnect.auth.msal") - async def test_save_cache_log_message( - self, mock_msal, tmp_path, caplog - ): + async def test_save_cache_log_message(self, mock_msal, tmp_path, caplog): """Verify log message when cache is saved.""" cache_path = tmp_path / "token_cache.json" @@ -932,9 +765,7 @@ async def test_save_cache_log_message( mock_app = MagicMock() mock_app.get_accounts.return_value = [{"username": "u"}] - mock_app.acquire_token_silent.return_value = { - "access_token": "tok" - } + mock_app.acquire_token_silent.return_value = {"access_token": "tok"} mock_msal.PublicClientApplication.return_value = mock_app auth = MsalAuth(cache_path=cache_path) @@ -942,14 +773,9 @@ async def test_save_cache_log_message( await auth.get_token() messages = [r.message for r in caplog.records] - assert any( - m.startswith("Token cache saved to") - for m in messages - ) + assert any(m.startswith("Token cache saved to") for m in messages) # Verify the cache path is included in the log message - assert any( - str(cache_path) in m for m in messages - ) + assert any(str(cache_path) in m for m in messages) # --------------------------------------------------------------------------- diff --git a/tests/test_b2c_login.py b/tests/test_b2c_login.py index 7610f39..059f8d9 100644 --- a/tests/test_b2c_login.py +++ b/tests/test_b2c_login.py @@ -9,6 +9,7 @@ import aiohttp import pytest +import yarl from multidict import CIMultiDict from flameconnect.b2c_login import ( @@ -59,10 +60,7 @@ ) _CLIENT_ID = "1af761dc-085a-411f-9cb9-53e5e2115bd2" -REDIRECT_URL = ( - f"msal{_CLIENT_ID}://auth" - "?code=test-auth-code-123&state=test-state" -) +REDIRECT_URL = f"msal{_CLIENT_ID}://auth?code=test-auth-code-123&state=test-state" AUTH_URI = "https://example.com/authorize" @@ -77,9 +75,7 @@ class TestExtractBasePath: def test_short_path_returns_root(self): """URL with fewer than 2 path segments returns '/'.""" - assert _extract_base_path( - "https://example.com/single" - ) == "/" + assert _extract_base_path("https://example.com/single") == "/" def test_no_path_returns_root(self): """URL with no path returns '/'.""" @@ -93,9 +89,7 @@ def test_normal_b2c_url(self): "oauth2/v2.0/authorize?params" ) result = _extract_base_path(url) - expected = ( - f"/tenant.onmicrosoft.com/{_POLICY}/" - ) + expected = f"/tenant.onmicrosoft.com/{_POLICY}/" assert result == expected def test_exactly_two_segments(self): @@ -128,56 +122,35 @@ class TestParseLoginPage: """Test HTML parsing of the B2C login page.""" def test_extracts_csrf_token(self): - result = _parse_login_page( - SAMPLE_B2C_HTML, SAMPLE_PAGE_URL - ) + result = _parse_login_page(SAMPLE_B2C_HTML, SAMPLE_PAGE_URL) assert result["csrf"] == "dGVzdC1jc3JmLXRva2Vu" def test_extracts_transaction_id(self): - result = _parse_login_page( - SAMPLE_B2C_HTML, SAMPLE_PAGE_URL - ) - tx = ( - "StateProperties=" - "eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiJ9" - ) + result = _parse_login_page(SAMPLE_B2C_HTML, SAMPLE_PAGE_URL) + tx = "StateProperties=eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiJ9" assert result["tx"] == tx def test_uses_hardcoded_policy(self): - result = _parse_login_page( - SAMPLE_B2C_HTML, SAMPLE_PAGE_URL - ) + result = _parse_login_page(SAMPLE_B2C_HTML, SAMPLE_PAGE_URL) assert result["p"] == _POLICY def test_builds_post_url(self): - result = _parse_login_page( - SAMPLE_B2C_HTML, SAMPLE_PAGE_URL - ) - expected_prefix = ( - f"{_HOST}/{_TENANT}/{_POLICY}/SelfAsserted?" - ) + result = _parse_login_page(SAMPLE_B2C_HTML, SAMPLE_PAGE_URL) + expected_prefix = f"{_HOST}/{_TENANT}/{_POLICY}/SelfAsserted?" assert result["post_url"].startswith(expected_prefix) assert "tx=StateProperties" in result["post_url"] assert f"p={_POLICY}" in result["post_url"] def test_builds_confirmed_url(self): - result = _parse_login_page( - SAMPLE_B2C_HTML, SAMPLE_PAGE_URL - ) - expected = ( - f"{_HOST}/{_TENANT}/{_POLICY}/" - "api/CombinedSigninAndSignup/confirmed" - ) + result = _parse_login_page(SAMPLE_B2C_HTML, SAMPLE_PAGE_URL) + expected = f"{_HOST}/{_TENANT}/{_POLICY}/api/CombinedSigninAndSignup/confirmed" assert result["confirmed_url"] == expected def test_missing_csrf_raises_exact_msg(self): html = "No settings here" with pytest.raises( AuthenticationError, - match=( - "^Could not find CSRF token" - " in B2C login page$" - ), + match=("^Could not find CSRF token in B2C login page$"), ): _parse_login_page(html, SAMPLE_PAGE_URL) @@ -185,10 +158,7 @@ def test_missing_trans_id_raises_exact_msg(self): html = '' with pytest.raises( AuthenticationError, - match=( - "^Could not find transId" - " in B2C login page$" - ), + match=("^Could not find transId in B2C login page$"), ): _parse_login_page(html, SAMPLE_PAGE_URL) @@ -214,20 +184,14 @@ def test_formats_cookies_unquoted(self): jar = MagicMock(spec=aiohttp.CookieJar) jar.filter_cookies.return_value = {"a": m1, "b": m2} - result = _build_cookie_header( - jar, "https://example.com/path" - ) - jar.filter_cookies.assert_called_once_with( - "https://example.com/path" - ) + result = _build_cookie_header(jar, "https://example.com/path") + jar.filter_cookies.assert_called_once_with(yarl.URL("https://example.com/path")) assert result == "session=abc+123; token=xyz=456" def test_empty_jar(self): jar = MagicMock(spec=aiohttp.CookieJar) jar.filter_cookies.return_value = {} - result = _build_cookie_header( - jar, "https://example.com" - ) + result = _build_cookie_header(jar, "https://example.com") assert result == "" def test_single_cookie(self): @@ -236,9 +200,7 @@ def test_single_cookie(self): m1.value = "y" jar = MagicMock(spec=aiohttp.CookieJar) jar.filter_cookies.return_value = {"a": m1} - result = _build_cookie_header( - jar, "https://example.com" - ) + result = _build_cookie_header(jar, "https://example.com") assert result == "x=y" @@ -309,9 +271,7 @@ def test_no_params_no_extra_log(self, caplog): class TestLogResponse: """Test _log_response with captured log records.""" - def _make_resp( - self, status=200, url="https://example.com" - ): + def _make_resp(self, status=200, url="https://example.com"): resp = MagicMock() resp.status = status resp.url = url @@ -452,9 +412,7 @@ def _patch_sessions( def _patch_session(session: MagicMock): """Patch aiohttp.ClientSession and CookieJar (compat).""" with ( - patch( - f"{_MOD}.ClientSession", return_value=session - ), + patch(f"{_MOD}.ClientSession", return_value=session), patch(f"{_MOD}.CookieJar"), ): yield @@ -470,18 +428,14 @@ async def test_successful_login(self): text=SAMPLE_B2C_HTML, url=SAMPLE_PAGE_URL, ) - post_resp = _make_mock_response( - status=200, text='{"status":"200"}' - ) + post_resp = _make_mock_response(status=200, text='{"status":"200"}') confirmed_resp = _make_mock_response( status=302, headers={"Location": REDIRECT_URL}, ) session = _make_mock_session( - get=MagicMock( - side_effect=[login_resp, confirmed_resp] - ), + get=MagicMock(side_effect=[login_resp, confirmed_resp]), post=MagicMock(return_value=post_resp), ) @@ -510,13 +464,14 @@ async def test_bad_credentials_raises(self): post=MagicMock(return_value=post_resp), ) - with _patch_session(session), pytest.raises( - AuthenticationError, - match="Invalid email or password", + with ( + _patch_session(session), + pytest.raises( + AuthenticationError, + match="Invalid email or password", + ), ): - await b2c_login_with_credentials( - AUTH_URI, "bad@test.com", "wrong" - ) + await b2c_login_with_credentials(AUTH_URI, "bad@test.com", "wrong") async def test_bad_credentials_spaced_json(self): """Status 400 with spaces around colon also caught.""" @@ -535,30 +490,30 @@ async def test_bad_credentials_spaced_json(self): post=MagicMock(return_value=post_resp), ) - with _patch_session(session), pytest.raises( - AuthenticationError, - match="Invalid email or password", + with ( + _patch_session(session), + pytest.raises( + AuthenticationError, + match="Invalid email or password", + ), ): - await b2c_login_with_credentials( - AUTH_URI, "bad@test.com", "wrong" - ) + await b2c_login_with_credentials(AUTH_URI, "bad@test.com", "wrong") async def test_login_page_http_error_raises(self): """Non-200 login page raises AuthenticationError.""" - login_resp = _make_mock_response( - status=500, text="Server Error" - ) + login_resp = _make_mock_response(status=500, text="Server Error") session = _make_mock_session( get=MagicMock(return_value=login_resp), ) - with _patch_session(session), pytest.raises( - AuthenticationError, - match="B2C login page returned HTTP 500", + with ( + _patch_session(session), + pytest.raises( + AuthenticationError, + match="B2C login page returned HTTP 500", + ), ): - await b2c_login_with_credentials( - AUTH_URI, "user@test.com", "pass" - ) + await b2c_login_with_credentials(AUTH_URI, "user@test.com", "pass") async def test_parse_failure_raises(self): """Unparseable HTML raises AuthenticationError.""" @@ -571,12 +526,11 @@ async def test_parse_failure_raises(self): get=MagicMock(return_value=login_resp), ) - with _patch_session(session), pytest.raises( - AuthenticationError, match="CSRF token" + with ( + _patch_session(session), + pytest.raises(AuthenticationError, match="CSRF token"), ): - await b2c_login_with_credentials( - AUTH_URI, "user@test.com", "pass" - ) + await b2c_login_with_credentials(AUTH_URI, "user@test.com", "pass") async def test_multi_hop_redirect(self): """Intermediate HTTP redirects before custom scheme.""" @@ -585,14 +539,10 @@ async def test_multi_hop_redirect(self): text=SAMPLE_B2C_HTML, url=SAMPLE_PAGE_URL, ) - post_resp = _make_mock_response( - status=200, text='{"status":"200"}' - ) + post_resp = _make_mock_response(status=200, text='{"status":"200"}') intermediate_resp = _make_mock_response( status=302, - headers={ - "Location": "https://example.com/hop" - }, + headers={"Location": "https://example.com/hop"}, ) final_resp = _make_mock_response( status=302, @@ -611,9 +561,7 @@ async def test_multi_hop_redirect(self): ) with _patch_session(session): - result = await b2c_login_with_credentials( - AUTH_URI, "user@test.com", "pass" - ) + result = await b2c_login_with_credentials(AUTH_URI, "user@test.com", "pass") assert result == REDIRECT_URL @@ -624,22 +572,21 @@ async def test_credential_post_http_error_raises(self): text=SAMPLE_B2C_HTML, url=SAMPLE_PAGE_URL, ) - post_resp = _make_mock_response( - status=500, text="Server Error" - ) + post_resp = _make_mock_response(status=500, text="Server Error") session = _make_mock_session( get=MagicMock(return_value=login_resp), post=MagicMock(return_value=post_resp), ) - with _patch_session(session), pytest.raises( - AuthenticationError, - match="Credential submission returned HTTP 500", + with ( + _patch_session(session), + pytest.raises( + AuthenticationError, + match="Credential submission returned HTTP 500", + ), ): - await b2c_login_with_credentials( - AUTH_URI, "user@test.com", "pass" - ) + await b2c_login_with_credentials(AUTH_URI, "user@test.com", "pass") async def test_redirect_url_in_page_body(self): """Redirect URL in response body is captured.""" @@ -648,30 +595,18 @@ async def test_redirect_url_in_page_body(self): text=SAMPLE_B2C_HTML, url=SAMPLE_PAGE_URL, ) - post_resp = _make_mock_response( - status=200, text='{"status":"200"}' - ) + post_resp = _make_mock_response(status=200, text='{"status":"200"}') - body = ( - "" - ) - confirmed_resp = _make_mock_response( - status=200, text=body - ) + body = f'' + confirmed_resp = _make_mock_response(status=200, text=body) session = _make_mock_session( - get=MagicMock( - side_effect=[login_resp, confirmed_resp] - ), + get=MagicMock(side_effect=[login_resp, confirmed_resp]), post=MagicMock(return_value=post_resp), ) with _patch_session(session): - result = await b2c_login_with_credentials( - AUTH_URI, "user@test.com", "pass" - ) + result = await b2c_login_with_credentials(AUTH_URI, "user@test.com", "pass") assert "code=test-auth-code-123" in result @@ -682,29 +617,22 @@ async def test_redirect_without_location_raises(self): text=SAMPLE_B2C_HTML, url=SAMPLE_PAGE_URL, ) - post_resp = _make_mock_response( - status=200, text='{"status":"200"}' - ) - no_loc = _make_mock_response( - status=302, headers={} - ) + post_resp = _make_mock_response(status=200, text='{"status":"200"}') + no_loc = _make_mock_response(status=302, headers={}) session = _make_mock_session( - get=MagicMock( - side_effect=[login_resp, no_loc] - ), + get=MagicMock(side_effect=[login_resp, no_loc]), post=MagicMock(return_value=post_resp), ) - with _patch_session(session), pytest.raises( - AuthenticationError, - match=( - "Redirect without Location header" + with ( + _patch_session(session), + pytest.raises( + AuthenticationError, + match=("Redirect without Location header"), ), ): - await b2c_login_with_credentials( - AUTH_URI, "user@test.com", "pass" - ) + await b2c_login_with_credentials(AUTH_URI, "user@test.com", "pass") async def test_relative_redirect_resolved(self): """Relative Location is resolved against current URL.""" @@ -713,9 +641,7 @@ async def test_relative_redirect_resolved(self): text=SAMPLE_B2C_HTML, url=SAMPLE_PAGE_URL, ) - post_resp = _make_mock_response( - status=200, text='{"status":"200"}' - ) + post_resp = _make_mock_response(status=200, text='{"status":"200"}') relative_resp = _make_mock_response( status=302, headers={"Location": "/some/relative/path"}, @@ -737,9 +663,7 @@ async def test_relative_redirect_resolved(self): ) with _patch_session(session): - result = await b2c_login_with_credentials( - AUTH_URI, "user@test.com", "pass" - ) + result = await b2c_login_with_credentials(AUTH_URI, "user@test.com", "pass") assert result == REDIRECT_URL @@ -750,31 +674,25 @@ async def test_200_without_redirect_url_raises(self): text=SAMPLE_B2C_HTML, url=SAMPLE_PAGE_URL, ) - post_resp = _make_mock_response( - status=200, text='{"status":"200"}' - ) + post_resp = _make_mock_response(status=200, text='{"status":"200"}') no_redir = _make_mock_response( status=200, text="No redirect here", ) session = _make_mock_session( - get=MagicMock( - side_effect=[login_resp, no_redir] - ), + get=MagicMock(side_effect=[login_resp, no_redir]), post=MagicMock(return_value=post_resp), ) - with _patch_session(session), pytest.raises( - AuthenticationError, - match=( - "Reached 200 response without" - " finding redirect URL" + with ( + _patch_session(session), + pytest.raises( + AuthenticationError, + match=("Reached 200 response without finding redirect URL"), ), ): - await b2c_login_with_credentials( - AUTH_URI, "user@test.com", "pass" - ) + await b2c_login_with_credentials(AUTH_URI, "user@test.com", "pass") async def test_unexpected_http_status_raises(self): """Unexpected HTTP status during redirect chain.""" @@ -783,30 +701,22 @@ async def test_unexpected_http_status_raises(self): text=SAMPLE_B2C_HTML, url=SAMPLE_PAGE_URL, ) - post_resp = _make_mock_response( - status=200, text='{"status":"200"}' - ) - forbidden = _make_mock_response( - status=403, text="Forbidden" - ) + post_resp = _make_mock_response(status=200, text='{"status":"200"}') + forbidden = _make_mock_response(status=403, text="Forbidden") session = _make_mock_session( - get=MagicMock( - side_effect=[login_resp, forbidden] - ), + get=MagicMock(side_effect=[login_resp, forbidden]), post=MagicMock(return_value=post_resp), ) - with _patch_session(session), pytest.raises( - AuthenticationError, - match=( - "Unexpected HTTP 403" - " during redirect chain" + with ( + _patch_session(session), + pytest.raises( + AuthenticationError, + match=("Unexpected HTTP 403 during redirect chain"), ), ): - await b2c_login_with_credentials( - AUTH_URI, "user@test.com", "pass" - ) + await b2c_login_with_credentials(AUTH_URI, "user@test.com", "pass") async def test_too_many_redirects_raises(self): """Exceeding max redirect hops raises.""" @@ -815,32 +725,25 @@ async def test_too_many_redirects_raises(self): text=SAMPLE_B2C_HTML, url=SAMPLE_PAGE_URL, ) - post_resp = _make_mock_response( - status=200, text='{"status":"200"}' - ) + post_resp = _make_mock_response(status=200, text='{"status":"200"}') redirect_resp = _make_mock_response( status=302, - headers={ - "Location": "https://example.com/loop" - }, + headers={"Location": "https://example.com/loop"}, ) session = _make_mock_session( - get=MagicMock( - side_effect=( - [login_resp] + [redirect_resp] * 21 - ) - ), + get=MagicMock(side_effect=([login_resp] + [redirect_resp] * 21)), post=MagicMock(return_value=post_resp), ) - with _patch_session(session), pytest.raises( - AuthenticationError, - match="Too many redirects during B2C login", + with ( + _patch_session(session), + pytest.raises( + AuthenticationError, + match="Too many redirects during B2C login", + ), ): - await b2c_login_with_credentials( - AUTH_URI, "user@test.com", "pass" - ) + await b2c_login_with_credentials(AUTH_URI, "user@test.com", "pass") async def test_network_error_wrapped(self): """aiohttp.ClientError wrapped in AuthError.""" @@ -853,19 +756,16 @@ async def test_network_error_wrapped(self): session = _make_mock_session( get=MagicMock(side_effect=[login_resp]), ) - session.post = MagicMock( - side_effect=aiohttp.ClientError( - "Connection reset" - ) - ) + session.post = MagicMock(side_effect=aiohttp.ClientError("Connection reset")) - with _patch_session(session), pytest.raises( - AuthenticationError, - match="Network error during B2C login", + with ( + _patch_session(session), + pytest.raises( + AuthenticationError, + match="Network error during B2C login", + ), ): - await b2c_login_with_credentials( - AUTH_URI, "user@test.com", "pass" - ) + await b2c_login_with_credentials(AUTH_URI, "user@test.com", "pass") # ------------------------------------------------------------------- @@ -883,9 +783,7 @@ async def test_session_created_with_cookie_jar(self): text=SAMPLE_B2C_HTML, url=SAMPLE_PAGE_URL, ) - post_resp = _make_mock_response( - status=200, text='{"status":"200"}' - ) + post_resp = _make_mock_response(status=200, text='{"status":"200"}') confirmed_resp = _make_mock_response( status=302, headers={"Location": REDIRECT_URL}, @@ -899,12 +797,8 @@ async def test_session_created_with_cookie_jar(self): get=MagicMock(return_value=confirmed_resp), ) - with _patch_sessions( - login_session, raw_session - ) as (jar_cls, _jar): - await b2c_login_with_credentials( - AUTH_URI, "user@test.com", "pass" - ) + with _patch_sessions(login_session, raw_session) as (jar_cls, _jar): + await b2c_login_with_credentials(AUTH_URI, "user@test.com", "pass") jar_cls.assert_called_once_with(unsafe=True) @@ -915,9 +809,7 @@ async def test_session_headers_user_agent(self): text=SAMPLE_B2C_HTML, url=SAMPLE_PAGE_URL, ) - post_resp = _make_mock_response( - status=200, text='{"status":"200"}' - ) + post_resp = _make_mock_response(status=200, text='{"status":"200"}') confirmed_resp = _make_mock_response( status=302, headers={"Location": REDIRECT_URL}, @@ -932,17 +824,11 @@ def capture_cs(**kwargs): call_idx[0] += 1 if idx == 0: return _make_mock_session( - get=MagicMock( - return_value=login_resp - ), + get=MagicMock(return_value=login_resp), ) return _make_mock_session( - post=MagicMock( - return_value=post_resp - ), - get=MagicMock( - return_value=confirmed_resp - ), + post=MagicMock(return_value=post_resp), + get=MagicMock(return_value=confirmed_resp), ) with ( @@ -956,16 +842,14 @@ def capture_cs(**kwargs): jar = MagicMock() jar.filter_cookies.return_value = {} jar_cls.return_value = jar - await b2c_login_with_credentials( - AUTH_URI, "user@test.com", "pass" - ) + await b2c_login_with_credentials(AUTH_URI, "user@test.com", "pass") assert len(cs_calls) == 2 for i, call in enumerate(cs_calls): hdrs = call.get("headers", {}) - assert hdrs.get("User-Agent") == ( - _USER_AGENT - ), f"Session {i} missing User-Agent" + assert hdrs.get("User-Agent") == (_USER_AGENT), ( + f"Session {i} missing User-Agent" + ) async def test_get_auth_uri_with_redirects(self): """Initial GET uses auth_uri with allow_redirects.""" @@ -974,9 +858,7 @@ async def test_get_auth_uri_with_redirects(self): text=SAMPLE_B2C_HTML, url=SAMPLE_PAGE_URL, ) - post_resp = _make_mock_response( - status=200, text='{"status":"200"}' - ) + post_resp = _make_mock_response(status=200, text='{"status":"200"}') confirmed_resp = _make_mock_response( status=302, headers={"Location": REDIRECT_URL}, @@ -990,12 +872,8 @@ async def test_get_auth_uri_with_redirects(self): get=MagicMock(return_value=confirmed_resp), ) - with _patch_sessions( - login_session, raw_session - ): - await b2c_login_with_credentials( - AUTH_URI, "user@test.com", "pass" - ) + with _patch_sessions(login_session, raw_session): + await b2c_login_with_credentials(AUTH_URI, "user@test.com", "pass") # Verify initial GET call = login_session.get.call_args @@ -1009,9 +887,7 @@ async def test_post_data_fields(self): text=SAMPLE_B2C_HTML, url=SAMPLE_PAGE_URL, ) - post_resp = _make_mock_response( - status=200, text='{"status":"200"}' - ) + post_resp = _make_mock_response(status=200, text='{"status":"200"}') confirmed_resp = _make_mock_response( status=302, headers={"Location": REDIRECT_URL}, @@ -1025,12 +901,8 @@ async def test_post_data_fields(self): get=MagicMock(return_value=confirmed_resp), ) - with _patch_sessions( - login_session, raw_session - ): - await b2c_login_with_credentials( - AUTH_URI, "me@test.com", "mypass" - ) + with _patch_sessions(login_session, raw_session): + await b2c_login_with_credentials(AUTH_URI, "me@test.com", "mypass") call = raw_session.post.call_args data = call[1]["data"] @@ -1045,9 +917,7 @@ async def test_post_headers(self): text=SAMPLE_B2C_HTML, url=SAMPLE_PAGE_URL, ) - post_resp = _make_mock_response( - status=200, text='{"status":"200"}' - ) + post_resp = _make_mock_response(status=200, text='{"status":"200"}') confirmed_resp = _make_mock_response( status=302, headers={"Location": REDIRECT_URL}, @@ -1061,30 +931,18 @@ async def test_post_headers(self): get=MagicMock(return_value=confirmed_resp), ) - with _patch_sessions( - login_session, raw_session - ): - await b2c_login_with_credentials( - AUTH_URI, "me@test.com", "pass" - ) + with _patch_sessions(login_session, raw_session): + await b2c_login_with_credentials(AUTH_URI, "me@test.com", "pass") call = raw_session.post.call_args hdrs = call[1]["headers"] - assert hdrs["X-CSRF-TOKEN"] == ( - "dGVzdC1jc3JmLXRva2Vu" - ) - assert hdrs["X-Requested-With"] == ( - "XMLHttpRequest" - ) + assert hdrs["X-CSRF-TOKEN"] == ("dGVzdC1jc3JmLXRva2Vu") + assert hdrs["X-Requested-With"] == ("XMLHttpRequest") assert hdrs["Referer"] == AUTH_URI assert "Origin" in hdrs - assert hdrs["Accept"] == ( - "application/json, text/javascript," - " */*; q=0.01" - ) + assert hdrs["Accept"] == ("application/json, text/javascript, */*; q=0.01") assert hdrs["Content-Type"] == ( - "application/x-www-form-urlencoded;" - " charset=UTF-8" + "application/x-www-form-urlencoded; charset=UTF-8" ) assert "Cookie" in hdrs @@ -1095,9 +953,7 @@ async def test_post_no_redirects(self): text=SAMPLE_B2C_HTML, url=SAMPLE_PAGE_URL, ) - post_resp = _make_mock_response( - status=200, text='{"status":"200"}' - ) + post_resp = _make_mock_response(status=200, text='{"status":"200"}') confirmed_resp = _make_mock_response( status=302, headers={"Location": REDIRECT_URL}, @@ -1111,12 +967,8 @@ async def test_post_no_redirects(self): get=MagicMock(return_value=confirmed_resp), ) - with _patch_sessions( - login_session, raw_session - ): - await b2c_login_with_credentials( - AUTH_URI, "me@test.com", "pass" - ) + with _patch_sessions(login_session, raw_session): + await b2c_login_with_credentials(AUTH_URI, "me@test.com", "pass") call = raw_session.post.call_args assert call[1]["allow_redirects"] is False @@ -1128,9 +980,7 @@ async def test_confirmed_url_has_query(self): text=SAMPLE_B2C_HTML, url=SAMPLE_PAGE_URL, ) - post_resp = _make_mock_response( - status=200, text='{"status":"200"}' - ) + post_resp = _make_mock_response(status=200, text='{"status":"200"}') confirmed_resp = _make_mock_response( status=302, headers={"Location": REDIRECT_URL}, @@ -1144,28 +994,19 @@ async def test_confirmed_url_has_query(self): get=MagicMock(return_value=confirmed_resp), ) - with _patch_sessions( - login_session, raw_session - ): - await b2c_login_with_credentials( - AUTH_URI, "me@test.com", "pass" - ) + with _patch_sessions(login_session, raw_session): + await b2c_login_with_credentials(AUTH_URI, "me@test.com", "pass") # First GET on raw_session is the confirmed URL get_call = raw_session.get.call_args - url_arg = get_call[1].get( - "url", get_call[0][0] - ) + url_arg = get_call[1].get("url", get_call[0][0]) url_str = str(url_arg) assert "?" in url_str assert "rememberMe=false" in url_str assert "csrf_token=" in url_str assert f"p={_POLICY}" in url_str assert "tx=" in url_str - assert ( - "api/CombinedSigninAndSignup/confirmed" - in url_str - ) + assert "api/CombinedSigninAndSignup/confirmed" in url_str async def test_confirmed_get_no_redirects(self): """Confirmed GET uses allow_redirects=False.""" @@ -1174,9 +1015,7 @@ async def test_confirmed_get_no_redirects(self): text=SAMPLE_B2C_HTML, url=SAMPLE_PAGE_URL, ) - post_resp = _make_mock_response( - status=200, text='{"status":"200"}' - ) + post_resp = _make_mock_response(status=200, text='{"status":"200"}') confirmed_resp = _make_mock_response( status=302, headers={"Location": REDIRECT_URL}, @@ -1190,12 +1029,8 @@ async def test_confirmed_get_no_redirects(self): get=MagicMock(return_value=confirmed_resp), ) - with _patch_sessions( - login_session, raw_session - ): - await b2c_login_with_credentials( - AUTH_URI, "me@test.com", "pass" - ) + with _patch_sessions(login_session, raw_session): + await b2c_login_with_credentials(AUTH_URI, "me@test.com", "pass") get_call = raw_session.get.call_args assert get_call[1]["allow_redirects"] is False @@ -1207,9 +1042,7 @@ async def test_confirmed_get_has_cookie_header(self): text=SAMPLE_B2C_HTML, url=SAMPLE_PAGE_URL, ) - post_resp = _make_mock_response( - status=200, text='{"status":"200"}' - ) + post_resp = _make_mock_response(status=200, text='{"status":"200"}') confirmed_resp = _make_mock_response( status=302, headers={"Location": REDIRECT_URL}, @@ -1223,12 +1056,8 @@ async def test_confirmed_get_has_cookie_header(self): get=MagicMock(return_value=confirmed_resp), ) - with _patch_sessions( - login_session, raw_session - ): - await b2c_login_with_credentials( - AUTH_URI, "me@test.com", "pass" - ) + with _patch_sessions(login_session, raw_session): + await b2c_login_with_credentials(AUTH_URI, "me@test.com", "pass") get_call = raw_session.get.call_args hdrs = get_call[1]["headers"] @@ -1269,12 +1098,8 @@ async def test_cookie_merging_from_post_response(self): get=MagicMock(return_value=confirmed_resp), ) - with _patch_sessions( - login_session, raw_session - ): - result = await b2c_login_with_credentials( - AUTH_URI, "me@test.com", "pass" - ) + with _patch_sessions(login_session, raw_session): + result = await b2c_login_with_credentials(AUTH_URI, "me@test.com", "pass") assert result == REDIRECT_URL # Verify cookies were merged @@ -1291,16 +1116,10 @@ async def test_all_redirect_status_codes(self): text=SAMPLE_B2C_HTML, url=SAMPLE_PAGE_URL, ) - post_resp = _make_mock_response( - status=200, text='{"status":"200"}' - ) + post_resp = _make_mock_response(status=200, text='{"status":"200"}') redir = _make_mock_response( status=code, - headers={ - "Location": ( - "https://example.com/hop" - ) - }, + headers={"Location": ("https://example.com/hop")}, ) final = _make_mock_response( status=302, @@ -1319,16 +1138,12 @@ async def test_all_redirect_status_codes(self): ) with _patch_session(session): - result = ( - await b2c_login_with_credentials( - AUTH_URI, - "user@test.com", - "pass", - ) + result = await b2c_login_with_credentials( + AUTH_URI, + "user@test.com", + "pass", ) - assert result == REDIRECT_URL, ( - f"Failed for status {code}" - ) + assert result == REDIRECT_URL, f"Failed for status {code}" # ------------------------------------------------------------------- @@ -1340,10 +1155,7 @@ class TestConstants: """Kill mutants on module-level constant strings.""" def test_b2c_policy_value(self): - assert _B2C_POLICY == ( - "B2C_1A_FirePhoneSignUpOrSignIn" - "WithPhoneOrEmail" - ) + assert _B2C_POLICY == ("B2C_1A_FirePhoneSignUpOrSignInWithPhoneOrEmail") def test_user_agent_contains_mozilla(self): assert "Mozilla/5.0" in _USER_AGENT diff --git a/tests/test_cli_commands.py b/tests/test_cli_commands.py index 809db81..0dff54b 100644 --- a/tests/test_cli_commands.py +++ b/tests/test_cli_commands.py @@ -556,9 +556,15 @@ def test_dispatches_timer(self, capsys): def test_dispatches_software_version(self, capsys): param = SoftwareVersionParam( - ui_major=1, ui_minor=0, ui_test=0, - control_major=1, control_minor=0, control_test=0, - relay_major=1, relay_minor=0, relay_test=0, + ui_major=1, + ui_minor=0, + ui_test=0, + control_major=1, + control_minor=0, + control_test=0, + relay_major=1, + relay_minor=0, + relay_test=0, ) _display_parameter(param) assert "[327] Software Version" in capsys.readouterr().out @@ -696,9 +702,7 @@ class TestCmdSet: async def test_dispatch_mode(self, capsys): client = AsyncMock() - overview = FireOverview( - fire=_make_fire(), parameters=[_make_mode_param()] - ) + overview = FireOverview(fire=_make_fire(), parameters=[_make_mode_param()]) client.get_fire_overview.return_value = overview await cmd_set(client, FIRE_ID, "mode", "standby") out = capsys.readouterr().out @@ -756,9 +760,7 @@ async def test_dispatch_media_theme(self, capsys): async def test_dispatch_heat_mode(self, capsys): client = AsyncMock() - overview = FireOverview( - fire=_make_fire(), parameters=[_make_heat_param()] - ) + overview = FireOverview(fire=_make_fire(), parameters=[_make_heat_param()]) client.get_fire_overview.return_value = overview await cmd_set(client, FIRE_ID, "heat-mode", "eco") out = capsys.readouterr().out @@ -766,9 +768,7 @@ async def test_dispatch_heat_mode(self, capsys): async def test_dispatch_heat_temp(self, capsys): client = AsyncMock() - overview = FireOverview( - fire=_make_fire(), parameters=[_make_heat_param()] - ) + overview = FireOverview(fire=_make_fire(), parameters=[_make_heat_param()]) client.get_fire_overview.return_value = overview await cmd_set(client, FIRE_ID, "heat-temp", "25.0") out = capsys.readouterr().out @@ -1102,9 +1102,7 @@ class TestSetHeatModeEdgeCases: async def test_boost_with_duration(self, capsys): client = AsyncMock() - overview = FireOverview( - fire=_make_fire(), parameters=[_make_heat_param()] - ) + overview = FireOverview(fire=_make_fire(), parameters=[_make_heat_param()]) client.get_fire_overview.return_value = overview await cmd_set(client, FIRE_ID, "heat-mode", "boost:15") out = capsys.readouterr().out @@ -1275,8 +1273,10 @@ async def test_list_command(self): mock_client.__aenter__ = AsyncMock(return_value=mock_client) mock_client.__aexit__ = AsyncMock(return_value=False) - with patch("flameconnect.cli.MsalAuth"), \ - patch("flameconnect.cli.FlameConnectClient", return_value=mock_client): + with ( + patch("flameconnect.cli.MsalAuth"), + patch("flameconnect.cli.FlameConnectClient", return_value=mock_client), + ): await async_main(args) mock_client.get_fires.assert_awaited_once() @@ -1288,8 +1288,10 @@ async def test_status_command(self): mock_client.__aenter__ = AsyncMock(return_value=mock_client) mock_client.__aexit__ = AsyncMock(return_value=False) - with patch("flameconnect.cli.MsalAuth"), \ - patch("flameconnect.cli.FlameConnectClient", return_value=mock_client): + with ( + patch("flameconnect.cli.MsalAuth"), + patch("flameconnect.cli.FlameConnectClient", return_value=mock_client), + ): await async_main(args) mock_client.get_fire_overview.assert_awaited_once_with(FIRE_ID) @@ -1299,8 +1301,10 @@ async def test_on_command(self): mock_client.__aenter__ = AsyncMock(return_value=mock_client) mock_client.__aexit__ = AsyncMock(return_value=False) - with patch("flameconnect.cli.MsalAuth"), \ - patch("flameconnect.cli.FlameConnectClient", return_value=mock_client): + with ( + patch("flameconnect.cli.MsalAuth"), + patch("flameconnect.cli.FlameConnectClient", return_value=mock_client), + ): await async_main(args) mock_client.turn_on.assert_awaited_once_with(FIRE_ID) @@ -1310,8 +1314,10 @@ async def test_off_command(self): mock_client.__aenter__ = AsyncMock(return_value=mock_client) mock_client.__aexit__ = AsyncMock(return_value=False) - with patch("flameconnect.cli.MsalAuth"), \ - patch("flameconnect.cli.FlameConnectClient", return_value=mock_client): + with ( + patch("flameconnect.cli.MsalAuth"), + patch("flameconnect.cli.FlameConnectClient", return_value=mock_client), + ): await async_main(args) mock_client.turn_off.assert_awaited_once_with(FIRE_ID) @@ -1327,8 +1333,10 @@ async def test_set_command(self): mock_client.__aenter__ = AsyncMock(return_value=mock_client) mock_client.__aexit__ = AsyncMock(return_value=False) - with patch("flameconnect.cli.MsalAuth"), \ - patch("flameconnect.cli.FlameConnectClient", return_value=mock_client): + with ( + patch("flameconnect.cli.MsalAuth"), + patch("flameconnect.cli.FlameConnectClient", return_value=mock_client), + ): await async_main(args) mock_client.write_parameters.assert_awaited_once() @@ -1348,8 +1356,10 @@ class TestMain: """Tests for the synchronous main() entry point.""" def test_main_calls_async_main(self): - with patch("flameconnect.cli.build_parser") as mock_parser_fn, \ - patch("flameconnect.cli.asyncio") as mock_asyncio: + with ( + patch("flameconnect.cli.build_parser") as mock_parser_fn, + patch("flameconnect.cli.asyncio") as mock_asyncio, + ): mock_parser = MagicMock() mock_args = argparse.Namespace(command="list", verbose=False) mock_parser.parse_args.return_value = mock_args @@ -1360,9 +1370,11 @@ def test_main_calls_async_main(self): mock_asyncio.run.assert_called_once() def test_main_verbose_logging(self): - with patch("flameconnect.cli.build_parser") as mock_parser_fn, \ - patch("flameconnect.cli.asyncio"), \ - patch("flameconnect.cli.logging") as mock_logging: + with ( + patch("flameconnect.cli.build_parser") as mock_parser_fn, + patch("flameconnect.cli.asyncio"), + patch("flameconnect.cli.logging") as mock_logging, + ): import logging as real_logging mock_parser = MagicMock() @@ -1374,14 +1386,14 @@ def test_main_verbose_logging(self): main() - mock_logging.basicConfig.assert_called_once_with( - level=real_logging.DEBUG - ) + mock_logging.basicConfig.assert_called_once_with(level=real_logging.DEBUG) def test_main_no_verbose_logging(self): - with patch("flameconnect.cli.build_parser") as mock_parser_fn, \ - patch("flameconnect.cli.asyncio"), \ - patch("flameconnect.cli.logging") as mock_logging: + with ( + patch("flameconnect.cli.build_parser") as mock_parser_fn, + patch("flameconnect.cli.asyncio"), + patch("flameconnect.cli.logging") as mock_logging, + ): import logging as real_logging mock_parser = MagicMock() @@ -1393,9 +1405,7 @@ def test_main_no_verbose_logging(self): main() - mock_logging.basicConfig.assert_called_once_with( - level=real_logging.WARNING - ) + mock_logging.basicConfig.assert_called_once_with(level=real_logging.WARNING) # =================================================================== @@ -1455,9 +1465,11 @@ def _run_masked(chars: list[str], prompt: str = "Password: ") -> str: mock_termios.TCSADRAIN = 1 mock_tty = MagicMock() - with patch("sys.stdin", mock_stdin), \ - patch("sys.stdout", mock_stdout), \ - patch.dict("sys.modules", {"termios": mock_termios, "tty": mock_tty}): + with ( + patch("sys.stdin", mock_stdin), + patch("sys.stdout", mock_stdout), + patch.dict("sys.modules", {"termios": mock_termios, "tty": mock_tty}), + ): return _masked_input(prompt) def test_basic_input(self): @@ -1623,9 +1635,15 @@ async def test_all_parameter_types(self, capsys): HeatModeParam(heat_control=HeatControl.ENABLED), TimerParam(timer_status=TimerStatus.ENABLED, duration=60), SoftwareVersionParam( - ui_major=1, ui_minor=0, ui_test=0, - control_major=2, control_minor=0, control_test=0, - relay_major=3, relay_minor=0, relay_test=0, + ui_major=1, + ui_minor=0, + ui_test=0, + control_major=2, + control_minor=0, + control_test=0, + relay_major=3, + relay_minor=0, + relay_test=0, ), ErrorParam(error_byte1=0, error_byte2=0, error_byte3=0, error_byte4=0), TempUnitParam(unit=TempUnit.CELSIUS), diff --git a/tests/test_cli_set.py b/tests/test_cli_set.py index 216c9f8..a8fca0b 100644 --- a/tests/test_cli_set.py +++ b/tests/test_cli_set.py @@ -286,9 +286,7 @@ async def test_set_mode_invalid(self, mock_api, token_auth, capsys): class TestSetFlameSpeed: """Tests for the _set_flame_speed CLI command.""" - async def test_set_flame_speed_valid( - self, mock_api, token_auth, overview_payload - ): + async def test_set_flame_speed_valid(self, mock_api, token_auth, overview_payload): mock_api.get(OVERVIEW_URL, payload=overview_payload) mock_api.post(WRITE_URL, payload={}) @@ -447,18 +445,14 @@ async def test_set_heat_mode_boost_duration_invalid_format( captured = capsys.readouterr() assert "Error" in captured.out - async def test_set_heat_mode_boost_duration_21( - self, mock_api, token_auth, capsys - ): + async def test_set_heat_mode_boost_duration_21(self, mock_api, token_auth, capsys): async with FlameConnectClient(token_auth) as client: with pytest.raises(SystemExit): await _set_heat_mode(client, FIRE_ID, "boost:21") captured = capsys.readouterr() assert "Error" in captured.out - async def test_set_heat_mode_reject_fan_only( - self, mock_api, token_auth, capsys - ): + async def test_set_heat_mode_reject_fan_only(self, mock_api, token_auth, capsys): async with FlameConnectClient(token_auth) as client: with pytest.raises(SystemExit): await _set_heat_mode(client, FIRE_ID, "fan-only") @@ -550,9 +544,7 @@ async def test_dispatch_pulsating(self, mock_api, token_auth, overview_payload): key = ("POST", URL(WRITE_URL)) assert len(mock_api.requests[key]) == 1 - async def test_dispatch_flame_color( - self, mock_api, token_auth, overview_payload - ): + async def test_dispatch_flame_color(self, mock_api, token_auth, overview_payload): mock_api.get(OVERVIEW_URL, payload=overview_payload) mock_api.post(WRITE_URL, payload={}) @@ -562,9 +554,7 @@ async def test_dispatch_flame_color( key = ("POST", URL(WRITE_URL)) assert len(mock_api.requests[key]) == 1 - async def test_dispatch_media_theme( - self, mock_api, token_auth, overview_payload - ): + async def test_dispatch_media_theme(self, mock_api, token_auth, overview_payload): mock_api.get(OVERVIEW_URL, payload=overview_payload) mock_api.post(WRITE_URL, payload={}) @@ -626,9 +616,7 @@ def test_wrong_count(self): class TestSetFlameEffect: """Tests for the _set_flame_effect CLI command.""" - async def test_set_flame_effect_on( - self, mock_api, token_auth, overview_payload - ): + async def test_set_flame_effect_on(self, mock_api, token_auth, overview_payload): mock_api.get(OVERVIEW_URL, payload=overview_payload) mock_api.post(WRITE_URL, payload={}) @@ -659,9 +647,7 @@ async def test_set_flame_effect_invalid(self, mock_api, token_auth, capsys): class TestSetMediaLight: """Tests for the _set_media_light CLI command.""" - async def test_set_media_light_on( - self, mock_api, token_auth, overview_payload - ): + async def test_set_media_light_on(self, mock_api, token_auth, overview_payload): mock_api.get(OVERVIEW_URL, payload=overview_payload) mock_api.post(WRITE_URL, payload={}) @@ -691,9 +677,7 @@ async def test_set_media_light_invalid(self, mock_api, token_auth, capsys): class TestSetOverheadLight: """Tests for the _set_overhead_light CLI command.""" - async def test_set_overhead_light_on( - self, mock_api, token_auth, overview_payload - ): + async def test_set_overhead_light_on(self, mock_api, token_auth, overview_payload): mock_api.get(OVERVIEW_URL, payload=overview_payload) mock_api.post(WRITE_URL, payload={}) @@ -723,9 +707,7 @@ async def test_set_overhead_light_invalid(self, mock_api, token_auth, capsys): class TestSetAmbientSensor: """Tests for the _set_ambient_sensor CLI command.""" - async def test_set_ambient_sensor_on( - self, mock_api, token_auth, overview_payload - ): + async def test_set_ambient_sensor_on(self, mock_api, token_auth, overview_payload): mock_api.get(OVERVIEW_URL, payload=overview_payload) mock_api.post(WRITE_URL, payload={}) @@ -755,9 +737,7 @@ async def test_set_ambient_sensor_invalid(self, mock_api, token_auth, capsys): class TestSetMediaColor: """Tests for the _set_media_color CLI command.""" - async def test_set_media_color_named( - self, mock_api, token_auth, overview_payload - ): + async def test_set_media_color_named(self, mock_api, token_auth, overview_payload): mock_api.get(OVERVIEW_URL, payload=overview_payload) mock_api.post(WRITE_URL, payload={}) @@ -771,9 +751,7 @@ async def test_set_media_color_named( assert body["FireId"] == FIRE_ID assert body["Parameters"][0]["ParameterId"] == 322 - async def test_set_media_color_rgbw( - self, mock_api, token_auth, overview_payload - ): + async def test_set_media_color_rgbw(self, mock_api, token_auth, overview_payload): mock_api.get(OVERVIEW_URL, payload=overview_payload) mock_api.post(WRITE_URL, payload={}) diff --git a/tests/test_client.py b/tests/test_client.py index a4282b3..2bb518d 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -61,16 +61,12 @@ def token_auth() -> TokenAuth: @pytest.fixture def get_fires_payload() -> list[dict]: - return json.loads( - (FIXTURES_DIR / "get_fires.json").read_text() - ) + return json.loads((FIXTURES_DIR / "get_fires.json").read_text()) @pytest.fixture def get_fire_overview_payload() -> dict: - return json.loads( - (FIXTURES_DIR / "get_fire_overview.json").read_text() - ) + return json.loads((FIXTURES_DIR / "get_fire_overview.json").read_text()) def _make_overview_payload( @@ -120,9 +116,7 @@ def _make_overview_payload( class TestGetFires: """Test the get_fires() method.""" - async def test_returns_fire_list( - self, mock_api, token_auth, get_fires_payload - ): + async def test_returns_fire_list(self, mock_api, token_auth, get_fires_payload): url = f"{API_BASE}/api/Fires/GetFires" mock_api.get(url, payload=get_fires_payload) @@ -181,10 +175,7 @@ async def test_decodes_all_parameters( self, mock_api, token_auth, get_fire_overview_payload ): fire_id = "test-fire-001" - url = ( - f"{API_BASE}/api/Fires/" - f"GetFireOverview?FireId={fire_id}" - ) + url = f"{API_BASE}/api/Fires/GetFireOverview?FireId={fire_id}" mock_api.get(url, payload=get_fire_overview_payload) async with FlameConnectClient(token_auth) as client: @@ -210,20 +201,13 @@ async def test_mode_param_values( self, mock_api, token_auth, get_fire_overview_payload ): fire_id = "test-fire-001" - url = ( - f"{API_BASE}/api/Fires/" - f"GetFireOverview?FireId={fire_id}" - ) + url = f"{API_BASE}/api/Fires/GetFireOverview?FireId={fire_id}" mock_api.get(url, payload=get_fire_overview_payload) async with FlameConnectClient(token_auth) as client: overview = await client.get_fire_overview(fire_id) - mode = next( - p - for p in overview.parameters - if isinstance(p, ModeParam) - ) + mode = next(p for p in overview.parameters if isinstance(p, ModeParam)) assert mode.mode == FireMode.MANUAL assert mode.target_temperature == pytest.approx(22.5) @@ -231,20 +215,13 @@ async def test_flame_effect_param_values( self, mock_api, token_auth, get_fire_overview_payload ): fire_id = "test-fire-001" - url = ( - f"{API_BASE}/api/Fires/" - f"GetFireOverview?FireId={fire_id}" - ) + url = f"{API_BASE}/api/Fires/GetFireOverview?FireId={fire_id}" mock_api.get(url, payload=get_fire_overview_payload) async with FlameConnectClient(token_auth) as client: overview = await client.get_fire_overview(fire_id) - flame = next( - p - for p in overview.parameters - if isinstance(p, FlameEffectParam) - ) + flame = next(p for p in overview.parameters if isinstance(p, FlameEffectParam)) assert flame.flame_effect == FlameEffect.ON assert flame.flame_speed == 3 assert flame.brightness == Brightness.LOW @@ -258,20 +235,13 @@ async def test_heat_param_values( self, mock_api, token_auth, get_fire_overview_payload ): fire_id = "test-fire-001" - url = ( - f"{API_BASE}/api/Fires/" - f"GetFireOverview?FireId={fire_id}" - ) + url = f"{API_BASE}/api/Fires/GetFireOverview?FireId={fire_id}" mock_api.get(url, payload=get_fire_overview_payload) async with FlameConnectClient(token_auth) as client: overview = await client.get_fire_overview(fire_id) - heat = next( - p - for p in overview.parameters - if isinstance(p, HeatParam) - ) + heat = next(p for p in overview.parameters if isinstance(p, HeatParam)) assert heat.heat_status == HeatStatus.ON assert heat.heat_mode == HeatMode.NORMAL assert heat.setpoint_temperature == pytest.approx(22.0) @@ -281,20 +251,13 @@ async def test_software_version_values( self, mock_api, token_auth, get_fire_overview_payload ): fire_id = "test-fire-001" - url = ( - f"{API_BASE}/api/Fires/" - f"GetFireOverview?FireId={fire_id}" - ) + url = f"{API_BASE}/api/Fires/GetFireOverview?FireId={fire_id}" mock_api.get(url, payload=get_fire_overview_payload) async with FlameConnectClient(token_auth) as client: overview = await client.get_fire_overview(fire_id) - sw = next( - p - for p in overview.parameters - if isinstance(p, SoftwareVersionParam) - ) + sw = next(p for p in overview.parameters if isinstance(p, SoftwareVersionParam)) assert sw.ui_major == 1 assert sw.ui_minor == 2 assert sw.ui_test == 3 @@ -310,10 +273,7 @@ async def test_overview_uses_uppercase_get( ): """Kills mutant overview__mutmut_8: 'GET' -> 'get'.""" fire_id = "test-fire-001" - url = ( - f"{API_BASE}/api/Fires/" - f"GetFireOverview?FireId={fire_id}" - ) + url = f"{API_BASE}/api/Fires/GetFireOverview?FireId={fire_id}" mock_api.get(url, payload=get_fire_overview_payload) async with FlameConnectClient(token_auth) as client: @@ -331,10 +291,7 @@ async def test_overview_fire_fields_from_fixture( and mutants 46-102 (wrong .get() keys/defaults). """ fire_id = "test-fire-001" - url = ( - f"{API_BASE}/api/Fires/" - f"GetFireOverview?FireId={fire_id}" - ) + url = f"{API_BASE}/api/Fires/GetFireOverview?FireId={fire_id}" mock_api.get(url, payload=get_fire_overview_payload) async with FlameConnectClient(token_auth) as client: @@ -351,22 +308,15 @@ async def test_overview_fire_fields_from_fixture( assert f.with_heat is True assert f.is_iot_fire is True - async def test_overview_defaults_when_keys_missing( - self, mock_api, token_auth - ): + async def test_overview_defaults_when_keys_missing(self, mock_api, token_auth): """When optional keys are absent, defaults are used. Kills mutants for .get() default values (brand="", product_type="", etc.) and .get() key name mutations. """ fire_id = "minimal-fire" - url = ( - f"{API_BASE}/api/Fires/" - f"GetFireOverview?FireId={fire_id}" - ) - payload = _make_overview_payload( - fire_id=fire_id, parameters=[] - ) + url = f"{API_BASE}/api/Fires/GetFireOverview?FireId={fire_id}" + payload = _make_overview_payload(fire_id=fire_id, parameters=[]) mock_api.get(url, payload=payload) async with FlameConnectClient(token_auth) as client: @@ -388,22 +338,15 @@ async def test_overview_defaults_when_keys_missing( # No parameters assert overview.parameters == [] - async def test_overview_no_parameters_key( - self, mock_api, token_auth - ): + async def test_overview_no_parameters_key(self, mock_api, token_auth): """When Parameters key is absent, defaults to empty list. Kills mutants 105/107: Parameters default None or removed. """ fire_id = "no-params-fire" - url = ( - f"{API_BASE}/api/Fires/" - f"GetFireOverview?FireId={fire_id}" - ) + url = f"{API_BASE}/api/Fires/GetFireOverview?FireId={fire_id}" # Don't include "Parameters" key at all - payload = { - "WifiFireOverview": {"FireId": fire_id} - } + payload = {"WifiFireOverview": {"FireId": fire_id}} mock_api.get(url, payload=payload) async with FlameConnectClient(token_auth) as client: @@ -411,19 +354,14 @@ async def test_overview_no_parameters_key( assert overview.parameters == [] - async def test_continue_on_decode_failure_not_break( - self, mock_api, token_auth - ): + async def test_continue_on_decode_failure_not_break(self, mock_api, token_auth): """After a bad parameter, good ones still decode. Kills mutant 135: continue -> break. Place the bad param between two good ones. """ fire_id = "test-fire-001" - url = ( - f"{API_BASE}/api/Fires/" - f"GetFireOverview?FireId={fire_id}" - ) + url = f"{API_BASE}/api/Fires/GetFireOverview?FireId={fire_id}" mode_val = encode_parameter( ModeParam( mode=FireMode.MANUAL, @@ -463,20 +401,14 @@ async def test_continue_on_decode_failure_not_break( class TestWriteParameters: """Test the write_parameters() method.""" - async def test_sends_correct_payload( - self, mock_api, token_auth - ): + async def test_sends_correct_payload(self, mock_api, token_auth): url = f"{API_BASE}/api/Fires/WriteWifiParameters" mock_api.post(url, payload={}) - mode = ModeParam( - mode=FireMode.MANUAL, target_temperature=22.0 - ) + mode = ModeParam(mode=FireMode.MANUAL, target_temperature=22.0) async with FlameConnectClient(token_auth) as client: - await client.write_parameters( - "test-fire-001", [mode] - ) + await client.write_parameters("test-fire-001", [mode]) key = ("POST", URL(url)) calls = mock_api.requests[key] @@ -492,37 +424,25 @@ async def test_multiple_params(self, mock_api, token_auth): url = f"{API_BASE}/api/Fires/WriteWifiParameters" mock_api.post(url, payload={}) - mode = ModeParam( - mode=FireMode.MANUAL, target_temperature=22.0 - ) - heat_mode = HeatModeParam( - heat_control=HeatControl.ENABLED - ) + mode = ModeParam(mode=FireMode.MANUAL, target_temperature=22.0) + heat_mode = HeatModeParam(heat_control=HeatControl.ENABLED) async with FlameConnectClient(token_auth) as client: - await client.write_parameters( - "test-fire-001", [mode, heat_mode] - ) + await client.write_parameters("test-fire-001", [mode, heat_mode]) key = ("POST", URL(url)) calls = mock_api.requests[key] body = calls[0].kwargs["json"] assert len(body["Parameters"]) == 2 - param_ids = { - p["ParameterId"] for p in body["Parameters"] - } + param_ids = {p["ParameterId"] for p in body["Parameters"]} assert param_ids == {321, 325} - async def test_write_uses_post_method( - self, mock_api, token_auth - ): + async def test_write_uses_post_method(self, mock_api, token_auth): """Verify write uses POST (not lowercase).""" url = f"{API_BASE}/api/Fires/WriteWifiParameters" mock_api.post(url, payload={}) - mode = ModeParam( - mode=FireMode.MANUAL, target_temperature=22.0 - ) + mode = ModeParam(mode=FireMode.MANUAL, target_temperature=22.0) async with FlameConnectClient(token_auth) as client: await client.write_parameters("fire-1", [mode]) @@ -546,17 +466,10 @@ async def test_turn_on_preserves_settings( get_fire_overview_payload, ): fire_id = "test-fire-001" - overview_url = ( - f"{API_BASE}/api/Fires/" - f"GetFireOverview?FireId={fire_id}" - ) - write_url = ( - f"{API_BASE}/api/Fires/WriteWifiParameters" - ) + overview_url = f"{API_BASE}/api/Fires/GetFireOverview?FireId={fire_id}" + write_url = f"{API_BASE}/api/Fires/WriteWifiParameters" - mock_api.get( - overview_url, payload=get_fire_overview_payload - ) + mock_api.get(overview_url, payload=get_fire_overview_payload) mock_api.post(write_url, payload={}) async with FlameConnectClient(token_auth) as client: @@ -569,9 +482,7 @@ async def test_turn_on_preserves_settings( body = calls[0].kwargs["json"] assert body["FireId"] == fire_id - param_ids = { - p["ParameterId"] for p in body["Parameters"] - } + param_ids = {p["ParameterId"] for p in body["Parameters"]} assert 321 in param_ids assert 322 in param_ids @@ -589,16 +500,9 @@ async def test_turn_on_preserves_existing_temperature( be preserved in the written ModeParam. """ fire_id = "test-fire-001" - overview_url = ( - f"{API_BASE}/api/Fires/" - f"GetFireOverview?FireId={fire_id}" - ) - write_url = ( - f"{API_BASE}/api/Fires/WriteWifiParameters" - ) - mock_api.get( - overview_url, payload=get_fire_overview_payload - ) + overview_url = f"{API_BASE}/api/Fires/GetFireOverview?FireId={fire_id}" + write_url = f"{API_BASE}/api/Fires/WriteWifiParameters" + mock_api.get(overview_url, payload=get_fire_overview_payload) mock_api.post(write_url, payload={}) async with FlameConnectClient(token_auth) as client: @@ -607,11 +511,7 @@ async def test_turn_on_preserves_existing_temperature( key = ("POST", URL(write_url)) body = mock_api.requests[key][0].kwargs["json"] # Decode the written ModeParam - mode_wire = next( - p - for p in body["Parameters"] - if p["ParameterId"] == 321 - ) + mode_wire = next(p for p in body["Parameters"] if p["ParameterId"] == 321) raw = base64.b64decode(mode_wire["Value"]) # Byte 3 is mode (1=MANUAL), bytes 4-5 are temperature assert raw[3] == FireMode.MANUAL @@ -630,16 +530,9 @@ async def test_turn_on_flame_effect_set_to_on( no args (which would preserve OFF if it was OFF). """ fire_id = "test-fire-001" - overview_url = ( - f"{API_BASE}/api/Fires/" - f"GetFireOverview?FireId={fire_id}" - ) - write_url = ( - f"{API_BASE}/api/Fires/WriteWifiParameters" - ) - mock_api.get( - overview_url, payload=get_fire_overview_payload - ) + overview_url = f"{API_BASE}/api/Fires/GetFireOverview?FireId={fire_id}" + write_url = f"{API_BASE}/api/Fires/WriteWifiParameters" + mock_api.get(overview_url, payload=get_fire_overview_payload) mock_api.post(write_url, payload={}) async with FlameConnectClient(token_auth) as client: @@ -647,35 +540,22 @@ async def test_turn_on_flame_effect_set_to_on( key = ("POST", URL(write_url)) body = mock_api.requests[key][0].kwargs["json"] - flame_wire = next( - p - for p in body["Parameters"] - if p["ParameterId"] == 322 - ) + flame_wire = next(p for p in body["Parameters"] if p["ParameterId"] == 322) raw = base64.b64decode(flame_wire["Value"]) # Byte 3 is flame_effect: 1=ON assert raw[3] == FlameEffect.ON - async def test_turn_on_default_temp_no_mode_param( - self, mock_api, token_auth - ): + async def test_turn_on_default_temp_no_mode_param(self, mock_api, token_auth): """When no ModeParam exists, default temp is 22.0. Kills turn_on__mutmut_3 (current_mode="" instead of None), turn_on__mutmut_4, turn_on__mutmut_8. """ fire_id = "no-mode-fire" - overview_url = ( - f"{API_BASE}/api/Fires/" - f"GetFireOverview?FireId={fire_id}" - ) - write_url = ( - f"{API_BASE}/api/Fires/WriteWifiParameters" - ) + overview_url = f"{API_BASE}/api/Fires/GetFireOverview?FireId={fire_id}" + write_url = f"{API_BASE}/api/Fires/WriteWifiParameters" # Overview with NO parameters - payload = _make_overview_payload( - fire_id=fire_id, parameters=[] - ) + payload = _make_overview_payload(fire_id=fire_id, parameters=[]) mock_api.get(overview_url, payload=payload) mock_api.post(write_url, payload={}) @@ -684,20 +564,14 @@ async def test_turn_on_default_temp_no_mode_param( key = ("POST", URL(write_url)) body = mock_api.requests[key][0].kwargs["json"] - mode_wire = next( - p - for p in body["Parameters"] - if p["ParameterId"] == 321 - ) + mode_wire = next(p for p in body["Parameters"] if p["ParameterId"] == 321) raw = base64.b64decode(mode_wire["Value"]) temp = float(raw[4]) + float(raw[5]) / 10.0 assert temp == pytest.approx(22.0) # No FlameEffectParam in overview -> only ModeParam assert len(body["Parameters"]) == 1 - async def test_turn_on_sets_flame_on_when_initially_off( - self, mock_api, token_auth - ): + async def test_turn_on_sets_flame_on_when_initially_off(self, mock_api, token_auth): """When flame effect is OFF, turn_on sets it to ON. Kills turn_on__mutmut_20: replace(current_flame, ) @@ -705,13 +579,8 @@ async def test_turn_on_sets_flame_on_when_initially_off( code sets it to ON. """ fire_id = "flame-off-fire" - overview_url = ( - f"{API_BASE}/api/Fires/" - f"GetFireOverview?FireId={fire_id}" - ) - write_url = ( - f"{API_BASE}/api/Fires/WriteWifiParameters" - ) + overview_url = f"{API_BASE}/api/Fires/GetFireOverview?FireId={fire_id}" + write_url = f"{API_BASE}/api/Fires/WriteWifiParameters" mode_val = encode_parameter( ModeParam( mode=FireMode.STANDBY, @@ -752,31 +621,20 @@ async def test_turn_on_sets_flame_on_when_initially_off( key = ("POST", URL(write_url)) body = mock_api.requests[key][0].kwargs["json"] - flame_wire = next( - p - for p in body["Parameters"] - if p["ParameterId"] == 322 - ) + flame_wire = next(p for p in body["Parameters"] if p["ParameterId"] == 322) raw = base64.b64decode(flame_wire["Value"]) # Byte 3 is flame_effect: must be 1 (ON) assert raw[3] == FlameEffect.ON - async def test_turn_on_no_flame_param_writes_only_mode( - self, mock_api, token_auth - ): + async def test_turn_on_no_flame_param_writes_only_mode(self, mock_api, token_auth): """When no FlameEffectParam, only ModeParam is written. Kills turn_on__mutmut_4 (current_flame="" instead of None). """ fire_id = "no-flame-fire" - overview_url = ( - f"{API_BASE}/api/Fires/" - f"GetFireOverview?FireId={fire_id}" - ) - write_url = ( - f"{API_BASE}/api/Fires/WriteWifiParameters" - ) + overview_url = f"{API_BASE}/api/Fires/GetFireOverview?FireId={fire_id}" + write_url = f"{API_BASE}/api/Fires/WriteWifiParameters" mode_val = encode_parameter( ModeParam( mode=FireMode.STANDBY, @@ -817,16 +675,9 @@ async def test_turn_off_sends_standby( get_fire_overview_payload, ): fire_id = "test-fire-001" - overview_url = ( - f"{API_BASE}/api/Fires/" - f"GetFireOverview?FireId={fire_id}" - ) - write_url = ( - f"{API_BASE}/api/Fires/WriteWifiParameters" - ) - mock_api.get( - overview_url, payload=get_fire_overview_payload - ) + overview_url = f"{API_BASE}/api/Fires/GetFireOverview?FireId={fire_id}" + write_url = f"{API_BASE}/api/Fires/WriteWifiParameters" + mock_api.get(overview_url, payload=get_fire_overview_payload) mock_api.post(write_url, payload={}) async with FlameConnectClient(token_auth) as client: @@ -853,16 +704,9 @@ async def test_turn_off_preserves_temperature( temperature is read from existing ModeParam. """ fire_id = "test-fire-001" - overview_url = ( - f"{API_BASE}/api/Fires/" - f"GetFireOverview?FireId={fire_id}" - ) - write_url = ( - f"{API_BASE}/api/Fires/WriteWifiParameters" - ) - mock_api.get( - overview_url, payload=get_fire_overview_payload - ) + overview_url = f"{API_BASE}/api/Fires/GetFireOverview?FireId={fire_id}" + write_url = f"{API_BASE}/api/Fires/WriteWifiParameters" + mock_api.get(overview_url, payload=get_fire_overview_payload) mock_api.post(write_url, payload={}) async with FlameConnectClient(token_auth) as client: @@ -878,24 +722,15 @@ async def test_turn_off_preserves_temperature( temp = float(raw[4]) + float(raw[5]) / 10.0 assert temp == pytest.approx(22.5) - async def test_turn_off_default_temp_no_mode( - self, mock_api, token_auth - ): + async def test_turn_off_default_temp_no_mode(self, mock_api, token_auth): """When no ModeParam, default temp is 22.0. Kills turn_off__mutmut_3 and turn_off__mutmut_7. """ fire_id = "no-mode-fire" - overview_url = ( - f"{API_BASE}/api/Fires/" - f"GetFireOverview?FireId={fire_id}" - ) - write_url = ( - f"{API_BASE}/api/Fires/WriteWifiParameters" - ) - payload = _make_overview_payload( - fire_id=fire_id, parameters=[] - ) + overview_url = f"{API_BASE}/api/Fires/GetFireOverview?FireId={fire_id}" + write_url = f"{API_BASE}/api/Fires/WriteWifiParameters" + payload = _make_overview_payload(fire_id=fire_id, parameters=[]) mock_api.get(overview_url, payload=payload) mock_api.post(write_url, payload={}) @@ -918,9 +753,7 @@ async def test_turn_off_default_temp_no_mode( class TestApiErrorHandling: """Test non-2xx response handling.""" - async def test_401_raises_api_error( - self, mock_api, token_auth - ): + async def test_401_raises_api_error(self, mock_api, token_auth): url = f"{API_BASE}/api/Fires/GetFires" mock_api.get(url, status=401, body="Unauthorized") @@ -930,13 +763,9 @@ async def test_401_raises_api_error( assert exc_info.value.status == 401 - async def test_500_raises_api_error( - self, mock_api, token_auth - ): + async def test_500_raises_api_error(self, mock_api, token_auth): url = f"{API_BASE}/api/Fires/GetFires" - mock_api.get( - url, status=500, body="Internal Server Error" - ) + mock_api.get(url, status=500, body="Internal Server Error") async with FlameConnectClient(token_auth) as client: with pytest.raises(ApiError) as exc_info: @@ -944,14 +773,9 @@ async def test_500_raises_api_error( assert exc_info.value.status == 500 - async def test_404_raises_api_error( - self, mock_api, token_auth - ): + async def test_404_raises_api_error(self, mock_api, token_auth): fire_id = "nonexistent" - url = ( - f"{API_BASE}/api/Fires/" - f"GetFireOverview?FireId={fire_id}" - ) + url = f"{API_BASE}/api/Fires/GetFireOverview?FireId={fire_id}" mock_api.get(url, status=404, body="Not Found") async with FlameConnectClient(token_auth) as client: @@ -960,9 +784,7 @@ async def test_404_raises_api_error( assert exc_info.value.status == 404 - async def test_no_session_raises_runtime_error( - self, token_auth - ): + async def test_no_session_raises_runtime_error(self, token_auth): """Using client without context manager or session. Kills _request__mutmut_3/6/7/8 by matching both parts @@ -978,35 +800,27 @@ async def test_no_session_raises_runtime_error( ): await client.get_fires() - async def test_external_session( - self, mock_api, token_auth - ): + async def test_external_session(self, mock_api, token_auth): """Client should work with externally-provided session.""" url = f"{API_BASE}/api/Fires/GetFires" mock_api.get(url, payload=[]) async with ( aiohttp.ClientSession() as session, - FlameConnectClient( - token_auth, session=session - ) as client, + FlameConnectClient(token_auth, session=session) as client, ): fires = await client.get_fires() assert fires == [] - async def test_300_raises_api_error( - self, mock_api, token_auth - ): + async def test_300_raises_api_error(self, mock_api, token_auth): """Status 300 should raise ApiError. Kills mutant _request__mutmut_41 (>= 300 -> > 300) and _request__mutmut_42 (>= 300 -> >= 301). """ url = f"{API_BASE}/api/Fires/GetFires" - mock_api.get( - url, status=300, body="Multiple Choices" - ) + mock_api.get(url, status=300, body="Multiple Choices") async with FlameConnectClient(token_auth) as client: with pytest.raises(ApiError) as exc_info: @@ -1014,18 +828,14 @@ async def test_300_raises_api_error( assert exc_info.value.status == 300 - async def test_api_error_includes_response_text( - self, mock_api, token_auth - ): + async def test_api_error_includes_response_text(self, mock_api, token_auth): """ApiError message should contain response body text. Kills _request__mutmut_43 (text=None) and _request__mutmut_45 (ApiError(status, None)). """ url = f"{API_BASE}/api/Fires/GetFires" - mock_api.get( - url, status=503, body="Service Unavailable" - ) + mock_api.get(url, status=503, body="Service Unavailable") async with FlameConnectClient(token_auth) as client: with pytest.raises(ApiError) as exc_info: @@ -1050,18 +860,14 @@ def test_sound_param(self): def test_log_effect_param(self): param = LogEffectParam( log_effect=LogEffect.ON, - color=RGBWColor( - red=0, green=0, blue=0, white=0 - ), + color=RGBWColor(red=0, green=0, blue=0, white=0), pattern=0, ) assert _get_parameter_id(param) == 370 def test_unknown_type_raises_value_error(self): """Unknown type raises ValueError with type name.""" - with pytest.raises( - ValueError, match="Unknown parameter type: str" - ): + with pytest.raises(ValueError, match="Unknown parameter type: str"): _get_parameter_id("not-a-param") def test_mode_param(self): @@ -1098,15 +904,11 @@ def test_heat_param(self): assert _get_parameter_id(param) == 323 def test_heat_mode_param(self): - param = HeatModeParam( - heat_control=HeatControl.ENABLED - ) + param = HeatModeParam(heat_control=HeatControl.ENABLED) assert _get_parameter_id(param) == 325 def test_timer_param(self): - param = TimerParam( - timer_status=TimerStatus.DISABLED, duration=0 - ) + param = TimerParam(timer_status=TimerStatus.DISABLED, duration=0) assert _get_parameter_id(param) == 326 def test_temp_unit_param(self): @@ -1122,15 +924,10 @@ def test_temp_unit_param(self): class TestGetFireOverviewDecodeFailure: """Test decode failures in get_fire_overview.""" - async def test_bad_parameter_skipped( - self, mock_api, token_auth - ): + async def test_bad_parameter_skipped(self, mock_api, token_auth): """Parameter that fails to decode is skipped.""" fire_id = "test-fire-001" - url = ( - f"{API_BASE}/api/Fires/" - f"GetFireOverview?FireId={fire_id}" - ) + url = f"{API_BASE}/api/Fires/GetFireOverview?FireId={fire_id}" payload = { "WifiFireOverview": { @@ -1159,14 +956,10 @@ async def test_bad_parameter_skipped( mock_api.get(url, payload=payload) async with FlameConnectClient(token_auth) as client: - overview = await client.get_fire_overview( - fire_id - ) + overview = await client.get_fire_overview(fire_id) assert len(overview.parameters) == 1 - assert isinstance( - overview.parameters[0], ModeParam - ) + assert isinstance(overview.parameters[0], ModeParam) # ------------------------------------------------------------------- @@ -1181,9 +974,7 @@ class TestSessionHandling: __aexit____mutmut_1/2. """ - async def test_external_session_flag_is_true( - self, token_auth - ): + async def test_external_session_flag_is_true(self, token_auth): """When session is provided, _external_session=True. Kills __init____mutmut_2 (_external_session=None) @@ -1191,16 +982,12 @@ async def test_external_session_flag_is_true( """ session = aiohttp.ClientSession() try: - client = FlameConnectClient( - token_auth, session=session - ) + client = FlameConnectClient(token_auth, session=session) assert client._external_session is True finally: await session.close() - async def test_no_session_flag_is_false( - self, token_auth - ): + async def test_no_session_flag_is_false(self, token_auth): """When no session provided, _external_session=False. Kills __init____mutmut_3 (is None instead of @@ -1209,25 +996,19 @@ async def test_no_session_flag_is_false( client = FlameConnectClient(token_auth) assert client._external_session is False - async def test_init_stores_provided_session( - self, token_auth - ): + async def test_init_stores_provided_session(self, token_auth): """Provided session is stored in _session. Kills __init____mutmut_4 (_session = None). """ session = aiohttp.ClientSession() try: - client = FlameConnectClient( - token_auth, session=session - ) + client = FlameConnectClient(token_auth, session=session) assert client._session is session finally: await session.close() - async def test_aexit_closes_own_session( - self, mock_api, token_auth - ): + async def test_aexit_closes_own_session(self, mock_api, token_auth): """When we created the session, __aexit__ closes it. Kills __aexit____mutmut_1 (and -> or) and @@ -1245,9 +1026,7 @@ async def test_aexit_closes_own_session( # Session should be closed after __aexit__ assert session.closed - async def test_aexit_does_not_close_external_session( - self, mock_api, token_auth - ): + async def test_aexit_does_not_close_external_session(self, mock_api, token_auth): """External session should NOT be closed by client. Kills __aexit____mutmut_2 (removed 'not' for @@ -1258,9 +1037,7 @@ async def test_aexit_does_not_close_external_session( session = aiohttp.ClientSession() try: - async with FlameConnectClient( - token_auth, session=session - ) as client: + async with FlameConnectClient(token_auth, session=session) as client: await client.get_fires() # External session should still be open @@ -1281,9 +1058,7 @@ class TestRequestInternals: headers dict, and DEFAULT_HEADERS integration. """ - async def test_request_sends_authorization_header( - self, mock_api, token_auth - ): + async def test_request_sends_authorization_header(self, mock_api, token_auth): """Verify Authorization header with Bearer token. Kills _request__mutmut_10 (token=None), @@ -1300,13 +1075,9 @@ async def test_request_sends_authorization_header( call = mock_api.requests[key][0] headers = call.kwargs["headers"] assert "Authorization" in headers - assert headers["Authorization"] == ( - "Bearer test-token-123" - ) + assert headers["Authorization"] == ("Bearer test-token-123") - async def test_request_sends_content_type( - self, mock_api, token_auth - ): + async def test_request_sends_content_type(self, mock_api, token_auth): """Verify Content-Type header is application/json. Kills _request__mutmut_15-19. @@ -1323,9 +1094,7 @@ async def test_request_sends_content_type( assert "Content-Type" in headers assert headers["Content-Type"] == "application/json" - async def test_request_includes_default_headers( - self, mock_api, token_auth - ): + async def test_request_includes_default_headers(self, mock_api, token_auth): """Verify DEFAULT_HEADERS are included. Kills _request__mutmut_22 (headers=None) and @@ -1343,16 +1112,12 @@ async def test_request_includes_default_headers( for hdr_key, hdr_val in DEFAULT_HEADERS.items(): assert headers.get(hdr_key) == hdr_val - async def test_request_passes_json_body( - self, mock_api, token_auth - ): + async def test_request_passes_json_body(self, mock_api, token_auth): """Verify json body is passed through to request.""" url = f"{API_BASE}/api/Fires/WriteWifiParameters" mock_api.post(url, payload={}) - mode = ModeParam( - mode=FireMode.MANUAL, target_temperature=22.0 - ) + mode = ModeParam(mode=FireMode.MANUAL, target_temperature=22.0) async with FlameConnectClient(token_auth) as client: await client.write_parameters("f1", [mode]) @@ -1363,9 +1128,7 @@ async def test_request_passes_json_body( assert body is not None assert body["FireId"] == "f1" - async def test_request_uses_token_from_auth( - self, mock_api - ): + async def test_request_uses_token_from_auth(self, mock_api): """Verify the token from auth provider is used. Kills _request__mutmut_10 (token=None). @@ -1379,9 +1142,7 @@ async def test_request_uses_token_from_auth( key = ("GET", URL(url)) call = mock_api.requests[key][0] - assert call.kwargs["headers"]["Authorization"] == ( - "Bearer my-special-token" - ) + assert call.kwargs["headers"]["Authorization"] == ("Bearer my-special-token") # ------------------------------------------------------------------- @@ -1396,27 +1157,20 @@ class TestRequestLogging: _LOGGER.debug() call arguments. """ - async def test_request_logs_method_url_status( - self, mock_api, token_auth, caplog - ): + async def test_request_logs_method_url_status(self, mock_api, token_auth, caplog): """Verify debug log contains method, URL, status.""" url = f"{API_BASE}/api/Fires/GetFires" mock_api.get(url, payload=[]) - with caplog.at_level( - logging.DEBUG, logger="flameconnect.client" - ): - async with FlameConnectClient( - token_auth - ) as client: + with caplog.at_level(logging.DEBUG, logger="flameconnect.client"): + async with FlameConnectClient(token_auth) as client: await client.get_fires() # Find the debug message from _request found = [ r for r in caplog.records - if r.name == "flameconnect.client" - and "GET" in r.message + if r.name == "flameconnect.client" and "GET" in r.message ] assert len(found) >= 1 msg = found[0].message @@ -1432,15 +1186,10 @@ class TestOverviewDecodeWarningLogging: _LOGGER.warning() call format/args. """ - async def test_decode_failure_logs_warning( - self, mock_api, token_auth, caplog - ): + async def test_decode_failure_logs_warning(self, mock_api, token_auth, caplog): """Verify warning log on decode failure.""" fire_id = "test-fire-001" - url = ( - f"{API_BASE}/api/Fires/" - f"GetFireOverview?FireId={fire_id}" - ) + url = f"{API_BASE}/api/Fires/GetFireOverview?FireId={fire_id}" mode_val = encode_parameter( ModeParam( mode=FireMode.MANUAL, @@ -1456,19 +1205,14 @@ async def test_decode_failure_logs_warning( ) mock_api.get(url, payload=payload) - with caplog.at_level( - logging.WARNING, logger="flameconnect.client" - ): - async with FlameConnectClient( - token_auth - ) as client: + with caplog.at_level(logging.WARNING, logger="flameconnect.client"): + async with FlameConnectClient(token_auth) as client: await client.get_fire_overview(fire_id) warnings = [ r for r in caplog.records - if r.name == "flameconnect.client" - and r.levelno == logging.WARNING + if r.name == "flameconnect.client" and r.levelno == logging.WARNING ] assert len(warnings) >= 1 msg = warnings[0].message diff --git a/tests/test_fireplace_visual.py b/tests/test_fireplace_visual.py index b28263e..f330cf4 100644 --- a/tests/test_fireplace_visual.py +++ b/tests/test_fireplace_visual.py @@ -100,9 +100,7 @@ def test_flames_hidden_in_standby(self): if "\u2591" in inner or "\u2593" in inner: continue # Inner content should be spaces only - assert inner.strip() == "", ( - f"Expected blank flame row, got: {inner!r}" - ) + assert inner.strip() == "", f"Expected blank flame row, got: {inner!r}" # --------------------------------------------------------------------------- @@ -124,8 +122,7 @@ def test_led_style_applied(self): if ch == "\u2591": span_style = _style_at(text, idx) assert led in span_style, ( - f"Expected led_style {led!r} at offset {idx}, " - f"got {span_style!r}" + f"Expected led_style {led!r} at offset {idx}, got {span_style!r}" ) found = True break @@ -149,9 +146,7 @@ def test_media_style_applied(self): # Compute the absolute offset of this line in the full # plain text. line_offset = plain.index(line) - for rel, ch in enumerate( - line[inner_start:inner_end] - ): + for rel, ch in enumerate(line[inner_start:inner_end]): if ch == "\u2593": abs_offset = line_offset + inner_start + rel span_style = _style_at(text, abs_offset) @@ -184,8 +179,7 @@ def test_outer_hearth_always_dim(self): abs_offset = line_offset + hearth_start + rel span_style = _style_at(text, abs_offset) assert "dim" in span_style, ( - f"Expected 'dim' style on outer hearth, " - f"got {span_style!r}" + f"Expected 'dim' style on outer hearth, got {span_style!r}" ) return raise AssertionError( # noqa: TRY003 @@ -223,10 +217,7 @@ def test_height_adaptation_minimum(self): # not LED, media, or structural) flame_count = 0 for line in lines: - if ( - line.startswith("\u2502\u2502") - and line.endswith("\u2502\u2502") - ): + if line.startswith("\u2502\u2502") and line.endswith("\u2502\u2502"): inner = line[2:-2] if "\u2591" not in inner and "\u2593" not in inner: flame_count += 1 @@ -244,9 +235,7 @@ class TestFlamePalette: def test_flame_palette_applied(self): """Flame spans use the given palette style strings.""" palette = ("bright_cyan", "bright_blue", "blue") - text = _build_fire_art( - 50, 20, fire_on=True, flame_palette=palette - ) + text = _build_fire_art(50, 20, fire_on=True, flame_palette=palette) # Collect all unique style strings from spans styles_found: set[str] = set() for span in text._spans: @@ -256,8 +245,7 @@ def test_flame_palette_applied(self): # At least one of the palette entries should appear in spans palette_found = styles_found & set(palette) assert palette_found, ( - f"Expected one of {palette} in spans, " - f"found styles: {styles_found}" + f"Expected one of {palette} in spans, found styles: {styles_found}" ) @@ -281,8 +269,7 @@ def test_width_consistency(self): lines = text.plain.split("\n") for i, line in enumerate(lines): assert len(line) >= w, ( - f"Line {i} has width {len(line)}, expected >= {w}: " - f"{line!r}" + f"Line {i} has width {len(line)}, expected >= {w}: {line!r}" ) def test_structural_lines_exact_width(self): @@ -292,8 +279,7 @@ def test_structural_lines_exact_width(self): lines = text.plain.split("\n") for i, line in enumerate(lines): assert len(line) == w, ( - f"Line {i} has width {len(line)}, expected {w}: " - f"{line!r}" + f"Line {i} has width {len(line)}, expected {w}: {line!r}" ) diff --git a/tests/test_protocol.py b/tests/test_protocol.py index 5bbfe9d..eaa5280 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -579,37 +579,27 @@ class TestCheckLengthErrorMessages: def test_mode_error_says_mode(self): raw = _make_header(321, 3) + bytes([1]) - with pytest.raises( - ProtocolError, match=r"for Mode:" - ): + with pytest.raises(ProtocolError, match=r"for Mode:"): decode_parameter(ParameterId.MODE, raw) def test_flame_effect_error_says_flame_effect(self): raw = _make_header(322, 20) + bytes([0] * 5) - with pytest.raises( - ProtocolError, match=r"for FlameEffect:" - ): + with pytest.raises(ProtocolError, match=r"for FlameEffect:"): decode_parameter(ParameterId.FLAME_EFFECT, raw) def test_heat_settings_error_says_heat_settings(self): raw = _make_header(323, 7) + bytes([0]) - with pytest.raises( - ProtocolError, match=r"for HeatSettings:" - ): + with pytest.raises(ProtocolError, match=r"for HeatSettings:"): decode_parameter(ParameterId.HEAT_SETTINGS, raw) def test_heat_mode_error_says_heat_mode(self): raw = _make_header(325, 1) - with pytest.raises( - ProtocolError, match=r"for HeatMode:" - ): + with pytest.raises(ProtocolError, match=r"for HeatMode:"): decode_parameter(ParameterId.HEAT_MODE, raw) def test_timer_error_says_timer(self): raw = _make_header(326, 3) + bytes([0]) - with pytest.raises( - ProtocolError, match=r"for Timer:" - ): + with pytest.raises(ProtocolError, match=r"for Timer:"): decode_parameter(ParameterId.TIMER, raw) def test_software_version_error(self): @@ -618,38 +608,26 @@ def test_software_version_error(self): ProtocolError, match=r"for SoftwareVersion:", ): - decode_parameter( - ParameterId.SOFTWARE_VERSION, raw - ) + decode_parameter(ParameterId.SOFTWARE_VERSION, raw) def test_error_param_error_says_error(self): raw = _make_header(329, 4) + bytes([0]) - with pytest.raises( - ProtocolError, match=r"for Error:" - ): + with pytest.raises(ProtocolError, match=r"for Error:"): decode_parameter(ParameterId.ERROR, raw) def test_temp_unit_error_says_temp_unit(self): raw = _make_header(236, 1) - with pytest.raises( - ProtocolError, match=r"for TempUnit:" - ): - decode_parameter( - ParameterId.TEMPERATURE_UNIT, raw - ) + with pytest.raises(ProtocolError, match=r"for TempUnit:"): + decode_parameter(ParameterId.TEMPERATURE_UNIT, raw) def test_sound_error_says_sound(self): raw = _make_header(369, 2) + bytes([0]) - with pytest.raises( - ProtocolError, match=r"for Sound:" - ): + with pytest.raises(ProtocolError, match=r"for Sound:"): decode_parameter(ParameterId.SOUND, raw) def test_log_effect_error_says_log_effect(self): raw = _make_header(370, 8) + bytes([0] * 3) - with pytest.raises( - ProtocolError, match=r"for LogEffect:" - ): + with pytest.raises(ProtocolError, match=r"for LogEffect:"): decode_parameter(ParameterId.LOG_EFFECT, raw) def test_error_includes_expected_and_got(self): @@ -672,9 +650,15 @@ class TestEncodeErrorMessages: def test_software_version_error_text(self): param = SoftwareVersionParam( - ui_major=1, ui_minor=0, ui_test=0, - control_major=1, control_minor=0, control_test=0, - relay_major=1, relay_minor=0, relay_test=0, + ui_major=1, + ui_minor=0, + ui_test=0, + control_major=1, + control_minor=0, + control_test=0, + relay_major=1, + relay_minor=0, + relay_test=0, ) with pytest.raises(ProtocolError) as exc_info: encode_parameter(param) @@ -684,8 +668,10 @@ def test_software_version_error_text(self): def test_error_param_error_text(self): param = ErrorParam( - error_byte1=0, error_byte2=0, - error_byte3=0, error_byte4=0, + error_byte1=0, + error_byte2=0, + error_byte3=0, + error_byte4=0, ) with pytest.raises(ProtocolError) as exc_info: encode_parameter(param) @@ -720,9 +706,7 @@ def test_temp_unit_encoded_bytes(self): assert raw[3] == 1 # CELSIUS = 1 def test_mode_encoded_bytes(self): - param = ModeParam( - mode=FireMode.MANUAL, target_temperature=22.5 - ) + param = ModeParam(mode=FireMode.MANUAL, target_temperature=22.5) b64 = encode_parameter(param) raw = base64.b64decode(b64) assert len(raw) == 6 @@ -742,9 +726,7 @@ def test_heat_mode_encoded_bytes(self): assert raw[3] == 2 # ENABLED = 2 def test_timer_encoded_bytes(self): - param = TimerParam( - timer_status=TimerStatus.ENABLED, duration=300 - ) + param = TimerParam(timer_status=TimerStatus.ENABLED, duration=300) b64 = encode_parameter(param) raw = base64.b64decode(b64) assert len(raw) == 6 @@ -814,13 +796,9 @@ def test_flame_effect_encoded_bytes(self): pulsating_effect=PulsatingEffect.ON, media_theme=MediaTheme.BLUE, media_light=LightStatus.ON, - media_color=RGBWColor( - red=10, green=20, blue=30, white=40 - ), + media_color=RGBWColor(red=10, green=20, blue=30, white=40), overhead_light=LightStatus.ON, - overhead_color=RGBWColor( - red=50, green=60, blue=70, white=80 - ), + overhead_color=RGBWColor(red=50, green=60, blue=70, white=80), light_status=LightStatus.ON, flame_color=FlameColor.YELLOW_RED, ambient_sensor=LightStatus.ON, @@ -883,18 +861,14 @@ class TestTemperatureEncodingExact: """Verify exact byte values for temperature encoding.""" def test_temp_22_5_encodes_to_22_and_5(self): - param = ModeParam( - mode=FireMode.MANUAL, target_temperature=22.5 - ) + param = ModeParam(mode=FireMode.MANUAL, target_temperature=22.5) b64 = encode_parameter(param) raw = base64.b64decode(b64) assert raw[4] == 22 # integer assert raw[5] == 5 # 0.5 * 10 = 5 def test_temp_18_5_encodes_to_18_and_5(self): - param = ModeParam( - mode=FireMode.MANUAL, target_temperature=18.5 - ) + param = ModeParam(mode=FireMode.MANUAL, target_temperature=18.5) b64 = encode_parameter(param) raw = base64.b64decode(b64) assert raw[4] == 18 @@ -919,9 +893,26 @@ def test_pulsating_on_decodes_from_bit1(self): """brightness_byte=0b10 -> brightness=LOW(0), pulsating=ON.""" raw = _make_header(322, 20) + bytes( [ - 1, 0, 0b10, 0, 0, - 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 1, + 0, + 0b10, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, ] ) result = decode_parameter(ParameterId.FLAME_EFFECT, raw) @@ -932,9 +923,26 @@ def test_both_on_decodes_from_0b11(self): """brightness_byte=0b11 -> brightness=LOW(1), pulsating=ON.""" raw = _make_header(322, 20) + bytes( [ - 1, 0, 0b11, 0, 0, - 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 1, + 0, + 0b11, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, ] ) result = decode_parameter(ParameterId.FLAME_EFFECT, raw) @@ -973,9 +981,22 @@ class TestFlameColorIndex: def test_flame_color_from_index_19(self): raw = _make_header(322, 20) + bytes( [ - 0, 0, 0, 0, 0, - 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, 3, # raw[19] = BLUE(3) 99, # raw[20] = padding (different) 99, # raw[21] = padding @@ -996,45 +1017,29 @@ class TestHeatSettingsBoostBoundary: def test_exactly_7_bytes_boost_defaults_zero(self): """With exactly 7 bytes, boost_lo defaults to 0.""" - raw = _make_header(323, 4) + bytes( - [1, 0, 22, 0] - ) # 7 bytes total - result = decode_parameter( - ParameterId.HEAT_SETTINGS, raw - ) + raw = _make_header(323, 4) + bytes([1, 0, 22, 0]) # 7 bytes total + result = decode_parameter(ParameterId.HEAT_SETTINGS, raw) # boost_lo=0, boost_hi=0 => duration=(0|0)+1=1 assert result.boost_duration == 1 def test_exactly_8_bytes_boost_lo_read(self): """With 8 bytes, boost_lo is read from raw[7].""" - raw = _make_header(323, 5) + bytes( - [1, 0, 22, 0, 5] - ) # 8 bytes - result = decode_parameter( - ParameterId.HEAT_SETTINGS, raw - ) + raw = _make_header(323, 5) + bytes([1, 0, 22, 0, 5]) # 8 bytes + result = decode_parameter(ParameterId.HEAT_SETTINGS, raw) # boost_lo=5, boost_hi=0 => (5|0)+1=6 assert result.boost_duration == 6 def test_exactly_9_bytes_boost_hi_read(self): """With 9 bytes, boost_hi is read from raw[8].""" - raw = _make_header(323, 6) + bytes( - [1, 0, 22, 0, 5, 2] - ) # 9 bytes - result = decode_parameter( - ParameterId.HEAT_SETTINGS, raw - ) + raw = _make_header(323, 6) + bytes([1, 0, 22, 0, 5, 2]) # 9 bytes + result = decode_parameter(ParameterId.HEAT_SETTINGS, raw) # boost_lo=5, boost_hi=2 => (5|(2<<8))+1 = (5|512)+1=518 assert result.boost_duration == 518 def test_boost_hi_shift_amount(self): """Verify boost_hi is shifted left by 8 (not 9).""" - raw = _make_header(323, 6) + bytes( - [1, 0, 22, 0, 0, 1] - ) # 9 bytes - result = decode_parameter( - ParameterId.HEAT_SETTINGS, raw - ) + raw = _make_header(323, 6) + bytes([1, 0, 22, 0, 0, 1]) # 9 bytes + result = decode_parameter(ParameterId.HEAT_SETTINGS, raw) # boost_lo=0, boost_hi=1 => (0|(1<<8))+1 = 257 assert result.boost_duration == 257 @@ -1043,18 +1048,14 @@ def test_boost_duration_7_bytes_fallback(self): raw = bytes(7) raw = _make_header(323, 4) + bytes([0, 0, 20, 0]) assert len(raw) == 7 - result = decode_parameter( - ParameterId.HEAT_SETTINGS, raw - ) + result = decode_parameter(ParameterId.HEAT_SETTINGS, raw) assert result.boost_duration == 1 def test_check_length_mutant_7_vs_8(self): """_check_length is called with 7 (not 8).""" # Exactly 7 bytes should NOT raise raw = _make_header(323, 4) + bytes([0, 0, 20, 0]) - result = decode_parameter( - ParameterId.HEAT_SETTINGS, raw - ) + result = decode_parameter(ParameterId.HEAT_SETTINGS, raw) assert isinstance(result, HeatParam) @@ -1096,8 +1097,7 @@ def test_decode_mode_logs(self, caplog): def test_decode_flame_effect_logs(self, caplog): raw = _make_header(322, 20) + bytes( - [1, 2, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0] + [1, 2, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] ) with caplog.at_level(logging.DEBUG, "flameconnect"): decode_parameter(ParameterId.FLAME_EFFECT, raw) @@ -1106,9 +1106,7 @@ def test_decode_flame_effect_logs(self, caplog): assert "Decoded FlameEffect" in caplog.text def test_decode_heat_settings_logs(self, caplog): - raw = _make_header(323, 7) + bytes( - [1, 0, 22, 0, 0, 0, 0] - ) + raw = _make_header(323, 7) + bytes([1, 0, 22, 0, 0, 0, 0]) with caplog.at_level(logging.DEBUG, "flameconnect"): decode_parameter(ParameterId.HEAT_SETTINGS, raw) _assert_no_xx(caplog.text) @@ -1133,21 +1131,15 @@ def test_decode_timer_logs(self, caplog): assert "Decoded Timer" in caplog.text def test_decode_software_version_logs(self, caplog): - raw = _make_header(327, 9) + bytes( - [1, 2, 3, 4, 5, 6, 7, 8, 9] - ) + raw = _make_header(327, 9) + bytes([1, 2, 3, 4, 5, 6, 7, 8, 9]) with caplog.at_level(logging.DEBUG, "flameconnect"): - decode_parameter( - ParameterId.SOFTWARE_VERSION, raw - ) + decode_parameter(ParameterId.SOFTWARE_VERSION, raw) _assert_no_xx(caplog.text) _assert_no_none(caplog.text) assert "Decoded SoftwareVersion" in caplog.text def test_decode_error_logs(self, caplog): - raw = _make_header(329, 4) + bytes( - [0xFF, 1, 0x80, 0x42] - ) + raw = _make_header(329, 4) + bytes([0xFF, 1, 0x80, 0x42]) with caplog.at_level(logging.DEBUG, "flameconnect"): decode_parameter(ParameterId.ERROR, raw) _assert_no_xx(caplog.text) @@ -1163,9 +1155,7 @@ def test_decode_sound_logs(self, caplog): assert "Decoded Sound" in caplog.text def test_decode_log_effect_logs(self, caplog): - raw = _make_header(370, 8) + bytes( - [1, 0, 100, 200, 150, 50, 1, 0] - ) + raw = _make_header(370, 8) + bytes([1, 0, 100, 200, 150, 50, 1, 0]) with caplog.at_level(logging.DEBUG, "flameconnect"): decode_parameter(ParameterId.LOG_EFFECT, raw) _assert_no_xx(caplog.text) @@ -1193,9 +1183,7 @@ def test_encode_temp_unit_logs(self, caplog): assert "Encoding TempUnit" in caplog.text def test_encode_mode_logs(self, caplog): - param = ModeParam( - mode=FireMode.MANUAL, target_temperature=22.5 - ) + param = ModeParam(mode=FireMode.MANUAL, target_temperature=22.5) with caplog.at_level(logging.DEBUG, "flameconnect"): encode_parameter(param) _assert_no_xx(caplog.text) @@ -1237,9 +1225,7 @@ def test_encode_heat_settings_logs(self, caplog): assert "Encoding HeatSettings" in caplog.text def test_encode_heat_mode_logs(self, caplog): - param = HeatModeParam( - heat_control=HeatControl.ENABLED - ) + param = HeatModeParam(heat_control=HeatControl.ENABLED) with caplog.at_level(logging.DEBUG, "flameconnect"): encode_parameter(param) _assert_no_xx(caplog.text) @@ -1248,9 +1234,7 @@ def test_encode_heat_mode_logs(self, caplog): assert "%s" not in caplog.text def test_encode_timer_logs(self, caplog): - param = TimerParam( - timer_status=TimerStatus.ENABLED, duration=120 - ) + param = TimerParam(timer_status=TimerStatus.ENABLED, duration=120) with caplog.at_level(logging.DEBUG, "flameconnect"): encode_parameter(param) _assert_no_xx(caplog.text) @@ -1307,13 +1291,10 @@ def test_mode_log_values(self, caplog): def test_flame_effect_log_values(self, caplog): """Verify log has speed, brightness, pulsating.""" raw = _make_header(322, 20) + bytes( - [1, 4, 0b11, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0] + [1, 4, 0b11, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] ) with caplog.at_level(logging.DEBUG, "flameconnect"): - decode_parameter( - ParameterId.FLAME_EFFECT, raw - ) + decode_parameter(ParameterId.FLAME_EFFECT, raw) _assert_no_xx(caplog.text) _assert_no_none(caplog.text) assert "5" in caplog.text @@ -1321,13 +1302,9 @@ def test_flame_effect_log_values(self, caplog): assert "ON" in caplog.text def test_heat_settings_log_values(self, caplog): - raw = _make_header(323, 7) + bytes( - [1, 1, 25, 5, 14, 0, 0] - ) + raw = _make_header(323, 7) + bytes([1, 1, 25, 5, 14, 0, 0]) with caplog.at_level(logging.DEBUG, "flameconnect"): - decode_parameter( - ParameterId.HEAT_SETTINGS, raw - ) + decode_parameter(ParameterId.HEAT_SETTINGS, raw) _assert_no_xx(caplog.text) _assert_no_none(caplog.text) assert "25.5" in caplog.text @@ -1352,22 +1329,16 @@ def test_timer_log_values(self, caplog): assert "1" in caplog.text def test_software_version_log_values(self, caplog): - raw = _make_header(327, 9) + bytes( - [2, 3, 4, 5, 6, 7, 8, 9, 10] - ) + raw = _make_header(327, 9) + bytes([2, 3, 4, 5, 6, 7, 8, 9, 10]) with caplog.at_level(logging.DEBUG, "flameconnect"): - decode_parameter( - ParameterId.SOFTWARE_VERSION, raw - ) + decode_parameter(ParameterId.SOFTWARE_VERSION, raw) _assert_no_xx(caplog.text) _assert_no_none(caplog.text) for n in [2, 3, 4, 5, 6, 7, 8, 9, 10]: assert str(n) in caplog.text def test_error_log_values(self, caplog): - raw = _make_header(329, 4) + bytes( - [0xAA, 0xBB, 0xCC, 0xDD] - ) + raw = _make_header(329, 4) + bytes([0xAA, 0xBB, 0xCC, 0xDD]) with caplog.at_level(logging.DEBUG, "flameconnect"): decode_parameter(ParameterId.ERROR, raw) _assert_no_xx(caplog.text) @@ -1388,9 +1359,7 @@ def test_sound_log_values(self, caplog): assert "7" in caplog.text def test_log_effect_log_values(self, caplog): - raw = _make_header(370, 8) + bytes( - [1, 0, 10, 20, 30, 40, 3, 0] - ) + raw = _make_header(370, 8) + bytes([1, 0, 10, 20, 30, 40, 3, 0]) with caplog.at_level(logging.DEBUG, "flameconnect"): decode_parameter(ParameterId.LOG_EFFECT, raw) _assert_no_xx(caplog.text) @@ -1420,9 +1389,7 @@ def test_encode_temp_unit_log_value(self, caplog): assert "0" in caplog.text def test_encode_mode_log_values(self, caplog): - param = ModeParam( - mode=FireMode.MANUAL, target_temperature=22.5 - ) + param = ModeParam(mode=FireMode.MANUAL, target_temperature=22.5) with caplog.at_level(logging.DEBUG, "flameconnect"): encode_parameter(param) _assert_no_xx(caplog.text) @@ -1466,9 +1433,7 @@ def test_encode_heat_settings_log_values(self, caplog): assert "15" in caplog.text def test_encode_heat_mode_log_value(self, caplog): - param = HeatModeParam( - heat_control=HeatControl.SOFTWARE_DISABLED - ) + param = HeatModeParam(heat_control=HeatControl.SOFTWARE_DISABLED) with caplog.at_level(logging.DEBUG, "flameconnect"): encode_parameter(param) _assert_no_xx(caplog.text) @@ -1477,9 +1442,7 @@ def test_encode_heat_mode_log_value(self, caplog): assert "0" in caplog.text def test_encode_timer_log_values(self, caplog): - param = TimerParam( - timer_status=TimerStatus.ENABLED, duration=300 - ) + param = TimerParam(timer_status=TimerStatus.ENABLED, duration=300) with caplog.at_level(logging.DEBUG, "flameconnect"): encode_parameter(param) _assert_no_xx(caplog.text) @@ -1524,12 +1487,8 @@ def test_mode_fields(self): assert result.target_temperature == 30.7 def test_heat_settings_fields(self): - raw = _make_header(323, 7) + bytes( - [1, 2, 20, 5, 10, 1, 0] - ) - result = decode_parameter( - ParameterId.HEAT_SETTINGS, raw - ) + raw = _make_header(323, 7) + bytes([1, 2, 20, 5, 10, 1, 0]) + result = decode_parameter(ParameterId.HEAT_SETTINGS, raw) assert result.heat_status == HeatStatus.ON assert result.heat_mode == HeatMode.ECO assert result.setpoint_temperature == 20.5 @@ -1554,9 +1513,7 @@ def test_sound_fields(self): assert result.sound_file == 7 def test_log_effect_fields(self): - raw = _make_header(370, 8) + bytes( - [1, 0, 10, 20, 30, 40, 5, 0] - ) + raw = _make_header(370, 8) + bytes([1, 0, 10, 20, 30, 40, 5, 0]) result = decode_parameter(ParameterId.LOG_EFFECT, raw) assert result.log_effect == LogEffect.ON assert result.color.red == 10 @@ -1569,19 +1526,26 @@ def test_flame_effect_all_fields(self): """Verify every field in a flame effect decode.""" raw = _make_header(322, 20) + bytes( [ - 1, # flame_effect ON - 4, # wire speed -> model 5 + 1, # flame_effect ON + 4, # wire speed -> model 5 0b11, # brightness LOW, pulsating ON - 7, # media_theme KALEIDOSCOPE - 1, # media_light ON - 10, 20, 30, 40, # media RBGW - 0, # padding - 1, # overhead_light ON - 50, 60, 70, 80, # overhead RBGW - 1, # light_status ON - 5, # flame_color YELLOW - 0, 0, # padding - 1, # ambient_sensor ON + 7, # media_theme KALEIDOSCOPE + 1, # media_light ON + 10, + 20, + 30, + 40, # media RBGW + 0, # padding + 1, # overhead_light ON + 50, + 60, + 70, + 80, # overhead RBGW + 1, # light_status ON + 5, # flame_color YELLOW + 0, + 0, # padding + 1, # ambient_sensor ON ] ) r = decode_parameter(ParameterId.FLAME_EFFECT, raw) @@ -1591,13 +1555,9 @@ def test_flame_effect_all_fields(self): assert r.pulsating_effect == PulsatingEffect.ON assert r.media_theme == MediaTheme.KALEIDOSCOPE assert r.media_light == LightStatus.ON - assert r.media_color == RGBWColor( - red=10, blue=20, green=30, white=40 - ) + assert r.media_color == RGBWColor(red=10, blue=20, green=30, white=40) assert r.overhead_light == LightStatus.ON - assert r.overhead_color == RGBWColor( - red=50, blue=60, green=70, white=80 - ) + assert r.overhead_color == RGBWColor(red=50, blue=60, green=70, white=80) assert r.light_status == LightStatus.ON assert r.flame_color == FlameColor.YELLOW assert r.ambient_sensor == LightStatus.ON diff --git a/tests/test_tui_actions.py b/tests/test_tui_actions.py index 737427f..a44f55a 100644 --- a/tests/test_tui_actions.py +++ b/tests/test_tui_actions.py @@ -837,9 +837,7 @@ async def test_no_op_when_write_in_progress(self, mock_client, mock_dashboard): with patch.object(type(app), "screen", new_callable=PropertyMock) as prop: prop.return_value = mock_dashboard - app._apply_overhead_color( - RGBWColor(red=0, green=0, blue=255, white=80) - ) + app._apply_overhead_color(RGBWColor(red=0, green=0, blue=255, white=80)) await _run_workers(app) mock_client.write_parameters.assert_not_awaited() @@ -889,7 +887,8 @@ async def test_timer_disable_error_logs_and_clears_flag( self, mock_client, mock_dashboard ): mock_dashboard.current_parameters[TimerParam] = TimerParam( - timer_status=TimerStatus.ENABLED, duration=60, + timer_status=TimerStatus.ENABLED, + duration=60, ) mock_client.write_parameters.side_effect = Exception("timeout") app = _make_app(mock_client, mock_dashboard) @@ -1131,12 +1130,8 @@ def test_creates_downloads_dir_before_delivery( target = tmp_path / "nonexistent" / "downloads" with ( - patch( - "platformdirs.user_downloads_path", return_value=target - ), - patch.object( - App, "deliver_screenshot", return_value=None - ) as super_deliver, + patch("platformdirs.user_downloads_path", return_value=target), + patch.object(App, "deliver_screenshot", return_value=None) as super_deliver, ): app.deliver_screenshot() @@ -2338,9 +2333,7 @@ async def test_no_fires_shows_error(self, mock_client, mock_dashboard): async def test_multiple_fires_shows_selector(self, mock_client, mock_dashboard): """When multiple fires, show selection list.""" - mock_client.get_fires = AsyncMock( - return_value=[_TEST_FIRE, _TEST_FIRE_2] - ) + mock_client.get_fires = AsyncMock(return_value=[_TEST_FIRE, _TEST_FIRE_2]) app = _make_app(mock_client, mock_dashboard) app.mount = AsyncMock() @@ -2349,16 +2342,12 @@ async def test_multiple_fires_shows_selector(self, mock_client, mock_dashboard): await app._load_fires() - mock_loading.update.assert_called_once_with( - "[bold]Select a fireplace:[/bold]" - ) + mock_loading.update.assert_called_once_with("[bold]Select a fireplace:[/bold]") app.mount.assert_awaited_once() async def test_get_fires_exception_notifies(self, mock_client, mock_dashboard): """When get_fires raises, notify the user.""" - mock_client.get_fires = AsyncMock( - side_effect=Exception("connection refused") - ) + mock_client.get_fires = AsyncMock(side_effect=Exception("connection refused")) app = _make_app(mock_client, mock_dashboard) app.notify = MagicMock() @@ -2630,9 +2619,7 @@ class TestApplyMediaThemeWorker: async def test_worker_error_logs_and_clears_flag(self, mock_client, mock_dashboard): """_apply_media_theme worker logs error and clears write flag on failure.""" - mock_client.write_parameters = AsyncMock( - side_effect=Exception("API failure") - ) + mock_client.write_parameters = AsyncMock(side_effect=Exception("API failure")) app = _make_app(mock_client, mock_dashboard) with patch.object(type(app), "screen", new_callable=PropertyMock) as prop: diff --git a/tests/test_tui_screens.py b/tests/test_tui_screens.py index e7d3436..b9843b4 100644 --- a/tests/test_tui_screens.py +++ b/tests/test_tui_screens.py @@ -395,9 +395,7 @@ def on_mount(self) -> None: def _on_dismiss(result): self.dismiss_result = result - self.push_screen( - HeatModeScreen(self._mode, self._boost), callback=_on_dismiss - ) + self.push_screen(HeatModeScreen(self._mode, self._boost), callback=_on_dismiss) class TestHeatModeScreen: @@ -789,18 +787,14 @@ def on_mount(self) -> None: def _on_dismiss(result): self.dismiss_result = result - self.push_screen( - ColorScreen(self._current, self._title), callback=_on_dismiss - ) + self.push_screen(ColorScreen(self._current, self._title), callback=_on_dismiss) class TestColorScreen: """Tests for ColorScreen.""" async def test_compose_shows_title_and_current(self): - app = ColorScreenApp( - RGBWColor(red=10, green=20, blue=30, white=40), "My Color" - ) + app = ColorScreenApp(RGBWColor(red=10, green=20, blue=30, white=40), "My Color") async with app.run_test(size=(100, 30)): title = app.screen.query_one("#color-title", Static) rendered = str(title._Static__content) @@ -860,9 +854,7 @@ async def test_set_rgbw_button_custom_values(self): btn = app.screen.query_one("#set-rgbw", Button) btn.press() await pilot.pause() - assert app.dismiss_result == RGBWColor( - red=128, green=64, blue=32, white=16 - ) + assert app.dismiss_result == RGBWColor(red=128, green=64, blue=32, white=16) async def test_set_rgbw_invalid_value_does_not_dismiss(self): app = ColorScreenApp() @@ -1486,9 +1478,7 @@ async def test_update_display_tracks_param_changes(self): parameters=[changed_mode, _DEFAULT_FLAME_EFFECT], ) client = MagicMock() - client.get_fire_overview = AsyncMock( - side_effect=[overview1, overview2] - ) + client.get_fire_overview = AsyncMock(side_effect=[overview1, overview2]) app = DashboardApp(client=client, fire=_TEST_FIRE) async with app.run_test(size=(120, 40)) as pilot: await app.screen.refresh_state() @@ -1675,9 +1665,7 @@ async def test_logs_changed_fields(self): overview1 = FireOverview(fire=_TEST_FIRE, parameters=[old_mode]) overview2 = FireOverview(fire=_TEST_FIRE, parameters=[new_mode]) client = MagicMock() - client.get_fire_overview = AsyncMock( - side_effect=[overview1, overview2] - ) + client.get_fire_overview = AsyncMock(side_effect=[overview1, overview2]) app = DashboardApp(client=client, fire=_TEST_FIRE) async with app.run_test(size=(120, 40)) as pilot: await app.screen.refresh_state() @@ -1687,9 +1675,7 @@ async def test_logs_changed_fields(self): async def test_no_log_for_unchanged_params(self): """Identical params between refreshes should not trigger log.""" - overview = FireOverview( - fire=_TEST_FIRE, parameters=[_DEFAULT_MODE] - ) + overview = FireOverview(fire=_TEST_FIRE, parameters=[_DEFAULT_MODE]) client = MagicMock() client.get_fire_overview = AsyncMock(return_value=overview) app = DashboardApp(client=client, fire=_TEST_FIRE) @@ -1707,9 +1693,7 @@ async def test_log_new_param_type_not_in_old(self): parameters=[_DEFAULT_MODE, _DEFAULT_FLAME_EFFECT], ) client = MagicMock() - client.get_fire_overview = AsyncMock( - side_effect=[overview1, overview2] - ) + client.get_fire_overview = AsyncMock(side_effect=[overview1, overview2]) app = DashboardApp(client=client, fire=_TEST_FIRE) async with app.run_test(size=(120, 40)) as pilot: await app.screen.refresh_state() @@ -1724,9 +1708,7 @@ async def test_multiple_field_changes_logged(self): overview1 = FireOverview(fire=_TEST_FIRE, parameters=[old_mode]) overview2 = FireOverview(fire=_TEST_FIRE, parameters=[new_mode]) client = MagicMock() - client.get_fire_overview = AsyncMock( - side_effect=[overview1, overview2] - ) + client.get_fire_overview = AsyncMock(side_effect=[overview1, overview2]) app = DashboardApp(client=client, fire=_TEST_FIRE) async with app.run_test(size=(120, 40)) as pilot: await app.screen.refresh_state() @@ -1837,9 +1819,7 @@ def on_mount(self) -> None: def _on_dismiss(result): self.dismiss_result = result - self.push_screen( - TimerScreen(self._duration), callback=_on_dismiss - ) + self.push_screen(TimerScreen(self._duration), callback=_on_dismiss) class TestTimerScreen: diff --git a/tests/test_widgets_format.py b/tests/test_widgets_format.py index d4758b6..59e25bb 100644 --- a/tests/test_widgets_format.py +++ b/tests/test_widgets_format.py @@ -273,7 +273,7 @@ def test_flame_effect_labels_and_actions(self): param = _sample_flame_effect() result = _format_flame_effect(param) labels_and_actions = [(r[0], r[2]) for r in result] - assert ("Flame Effect" in labels_and_actions[0][0]) + assert "Flame Effect" in labels_and_actions[0][0] assert labels_and_actions[0][1] == "toggle_flame_effect" assert labels_and_actions[1][1] == "set_flame_color" assert labels_and_actions[2][1] == "set_flame_speed" @@ -536,9 +536,7 @@ class TestFormatError: """Tests for _format_error.""" def test_no_errors(self): - param = ErrorParam( - error_byte1=0, error_byte2=0, error_byte3=0, error_byte4=0 - ) + param = ErrorParam(error_byte1=0, error_byte2=0, error_byte3=0, error_byte4=0) result = _format_error(param) assert len(result) == 1 assert "No Errors Recorded" in result[0][1] @@ -568,24 +566,18 @@ def test_has_error_all_bytes(self): assert "0x78" in result[0][1] def test_error_only_byte4(self): - param = ErrorParam( - error_byte1=0, error_byte2=0, error_byte3=0, error_byte4=1 - ) + param = ErrorParam(error_byte1=0, error_byte2=0, error_byte3=0, error_byte4=1) result = _format_error(param) # Any non-zero byte should flag an error assert "Error" in result[0][0] def test_no_error_label_bold(self): - param = ErrorParam( - error_byte1=0, error_byte2=0, error_byte3=0, error_byte4=0 - ) + param = ErrorParam(error_byte1=0, error_byte2=0, error_byte3=0, error_byte4=0) result = _format_error(param) assert "[bold]Errors:[/bold]" in result[0][0] def test_error_label_bold_red(self): - param = ErrorParam( - error_byte1=1, error_byte2=0, error_byte3=0, error_byte4=0 - ) + param = ErrorParam(error_byte1=1, error_byte2=0, error_byte3=0, error_byte4=0) result = _format_error(param) assert "bold red" in result[0][0] @@ -801,9 +793,7 @@ def test_single_sound_param(self): assert any("Sound" in r[0] for r in result) def test_single_log_effect_param(self): - params = [ - LogEffectParam(log_effect=LogEffect.ON, color=_black(), pattern=0) - ] + params = [LogEffectParam(log_effect=LogEffect.ON, color=_black(), pattern=0)] result = format_parameters(params) assert any("Log Effect" in r[0] for r in result) @@ -836,9 +826,7 @@ def test_temp_unit_applied_to_heat(self): def test_display_order(self): """Parameters are returned in the defined display order.""" params = [ - ErrorParam( - error_byte1=0, error_byte2=0, error_byte3=0, error_byte4=0 - ), + ErrorParam(error_byte1=0, error_byte2=0, error_byte3=0, error_byte4=0), ModeParam(mode=FireMode.MANUAL, target_temperature=22.0), TempUnitParam(unit=TempUnit.CELSIUS), ] @@ -877,9 +865,7 @@ def test_all_param_types_together(self): TempUnitParam(unit=TempUnit.CELSIUS), SoundParam(volume=5, sound_file=1), LogEffectParam(log_effect=LogEffect.ON, color=_black(), pattern=0), - ErrorParam( - error_byte1=0, error_byte2=0, error_byte3=0, error_byte4=0 - ), + ErrorParam(error_byte1=0, error_byte2=0, error_byte3=0, error_byte4=0), ] result = format_parameters(params) # Should have results for all types; at minimum > 20 rows @@ -889,38 +875,20 @@ def test_all_param_types_together(self): # before FlameEffect before Timer before Software before # TempUnit before Sound before LogEffect before Error labels = [r[0] for r in result] - first_mode = next( - i for i, lbl in enumerate(labels) if "Mode" in lbl - ) + first_mode = next(i for i, lbl in enumerate(labels) if "Mode" in lbl) first_heat = next( i for i, lbl in enumerate(labels) - if "Heat" in lbl - and "Control" not in lbl - and "Mode" not in lbl - ) - first_flame = next( - i for i, lbl in enumerate(labels) if "Flame Effect" in lbl - ) - first_timer = next( - i for i, lbl in enumerate(labels) if "Timer" in lbl - ) - first_sw = next( - i for i, lbl in enumerate(labels) if "Software" in lbl - ) - first_temp_unit = next( - i for i, lbl in enumerate(labels) if "Temp Unit" in lbl - ) - first_sound = next( - i for i, lbl in enumerate(labels) if "Sound" in lbl - ) - first_log = next( - i for i, lbl in enumerate(labels) if "Log Effect" in lbl + if "Heat" in lbl and "Control" not in lbl and "Mode" not in lbl ) + first_flame = next(i for i, lbl in enumerate(labels) if "Flame Effect" in lbl) + first_timer = next(i for i, lbl in enumerate(labels) if "Timer" in lbl) + first_sw = next(i for i, lbl in enumerate(labels) if "Software" in lbl) + first_temp_unit = next(i for i, lbl in enumerate(labels) if "Temp Unit" in lbl) + first_sound = next(i for i, lbl in enumerate(labels) if "Sound" in lbl) + first_log = next(i for i, lbl in enumerate(labels) if "Log Effect" in lbl) first_error = next( - i - for i, lbl in enumerate(labels) - if "Errors" in lbl or "Error" in lbl + i for i, lbl in enumerate(labels) if "Errors" in lbl or "Error" in lbl ) assert ( first_mode