diff --git a/switchbot_mqtt/__init__.py b/switchbot_mqtt/__init__.py index 61c25b1..279baf1 100644 --- a/switchbot_mqtt/__init__.py +++ b/switchbot_mqtt/__init__.py @@ -54,6 +54,7 @@ def _run( mqtt_port: int, mqtt_username: typing.Optional[str], mqtt_password: typing.Optional[str], + mqtt_disable_tls: bool, retry_count: int, device_passwords: typing.Dict[str, str], fetch_device_info: bool, @@ -68,6 +69,8 @@ def _run( ) mqtt_client.on_connect = _mqtt_on_connect _LOGGER.info("connecting to MQTT broker %s:%d", mqtt_host, mqtt_port) + if not mqtt_disable_tls: + mqtt_client.tls_set(ca_certs=None) # enable tls trusting default system certs if mqtt_username: mqtt_client.username_pw_set(username=mqtt_username, password=mqtt_password) elif mqtt_password: diff --git a/switchbot_mqtt/_cli.py b/switchbot_mqtt/_cli.py index 21fe1c6..ab487d9 100644 --- a/switchbot_mqtt/_cli.py +++ b/switchbot_mqtt/_cli.py @@ -110,6 +110,7 @@ def _main() -> None: mqtt_port=args.mqtt_port, mqtt_username=args.mqtt_username, mqtt_password=mqtt_password, + mqtt_disable_tls=True, retry_count=args.retry_count, device_passwords=device_passwords, fetch_device_info=args.fetch_device_info diff --git a/tests/test_cli.py b/tests/test_cli.py index e6dbb4a..f4e43d6 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -110,6 +110,7 @@ def test__main( mqtt_port=expected_mqtt_port, mqtt_username=expected_username, mqtt_password=expected_password, + mqtt_disable_tls=True, retry_count=expected_retry_count, device_passwords={}, fetch_device_info=False, @@ -153,6 +154,7 @@ def test__main_mqtt_password_file( mqtt_port=1883, mqtt_username="me", mqtt_password=expected_password, + mqtt_disable_tls=True, retry_count=3, device_passwords={}, fetch_device_info=False, @@ -214,6 +216,7 @@ def test__main_device_password_file( mqtt_port=1883, mqtt_username=None, mqtt_password=None, + mqtt_disable_tls=True, retry_count=3, device_passwords=device_passwords, fetch_device_info=False, @@ -235,6 +238,7 @@ def test__main_fetch_device_info() -> None: mqtt_port=1883, mqtt_username=None, mqtt_password=None, + mqtt_disable_tls=True, retry_count=3, device_passwords={}, ) diff --git a/tests/test_mqtt.py b/tests/test_mqtt.py index b177711..73eb68f 100644 --- a/tests/test_mqtt.py +++ b/tests/test_mqtt.py @@ -58,6 +58,7 @@ def test__run( mqtt_port=mqtt_port, mqtt_username=None, mqtt_password=None, + mqtt_disable_tls=False, retry_count=retry_count, device_passwords=device_passwords, fetch_device_info=fetch_device_info, @@ -72,6 +73,7 @@ def test__run( fetch_device_info=fetch_device_info, ) assert not mqtt_client_mock().username_pw_set.called + mqtt_client_mock().tls_set.assert_called_once_with(ca_certs=None) mqtt_client_mock().connect.assert_called_once_with(host=mqtt_host, port=mqtt_port) mqtt_client_mock().socket().getpeername.return_value = (mqtt_host, mqtt_port) with caplog.at_level(logging.DEBUG): @@ -125,6 +127,25 @@ def test__run( ) in caplog.record_tuples +@pytest.mark.parametrize("mqtt_disable_tls", [True, False]) +def test__run_tls(mqtt_disable_tls: bool) -> None: + with unittest.mock.patch("paho.mqtt.client.Client") as mqtt_client_mock: + switchbot_mqtt._run( + mqtt_host="mqtt.local", + mqtt_port=1234, + mqtt_username=None, + mqtt_password=None, + mqtt_disable_tls=mqtt_disable_tls, + retry_count=21, + device_passwords={}, + fetch_device_info=True, + ) + if mqtt_disable_tls: + mqtt_client_mock().tls_set.assert_not_called() + else: + mqtt_client_mock().tls_set.assert_called_once_with(ca_certs=None) + + @pytest.mark.parametrize("mqtt_host", ["mqtt-broker.local"]) @pytest.mark.parametrize("mqtt_port", [1833]) @pytest.mark.parametrize("mqtt_username", ["me"]) @@ -141,6 +162,7 @@ def test__run_authentication( mqtt_port=mqtt_port, mqtt_username=mqtt_username, mqtt_password=mqtt_password, + mqtt_disable_tls=True, retry_count=7, device_passwords={}, fetch_device_info=True, @@ -168,6 +190,7 @@ def test__run_authentication_missing_username( mqtt_port=mqtt_port, mqtt_username=None, mqtt_password=mqtt_password, + mqtt_disable_tls=True, retry_count=3, device_passwords={}, fetch_device_info=True,