Skip to content

Commit

Permalink
Linting updates and MQTT disconnect functions
Browse files Browse the repository at this point in the history
  • Loading branch information
jordanruthe committed Jan 14, 2024
1 parent 50326b0 commit 32dac15
Show file tree
Hide file tree
Showing 4 changed files with 88 additions and 56 deletions.
45 changes: 31 additions & 14 deletions aiophyn/mqtt.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import asyncio
import logging

from datetime import datetime, timezone
from typing import Any, Dict, Union, Optional

import inspect
import json
Expand All @@ -15,13 +15,12 @@
import socks
import paho.mqtt.client as paho_mqtt

from typing import Any, Dict, Union, Optional

from .const import API_BASE

_LOGGER = logging.getLogger(__name__)

class AIOHelper:
"""Helper class for Asynchronous IO"""
def __init__(self, client: paho_mqtt.Client) -> None:
self.loop = asyncio.get_running_loop()
self.client = client
Expand Down Expand Up @@ -62,9 +61,11 @@ def _on_socket_unregister_write(self,
userdata: Any,
sock: socket.socket
) -> None:
# pylint: disable=unused-argument
self.loop.remove_writer(sock)

async def misc_loop(self) -> None:
"""Loop for MQTT"""
while self.client.loop_misc() == paho_mqtt.MQTT_ERR_SUCCESS:
try:
await asyncio.sleep(1)
Expand Down Expand Up @@ -103,7 +104,8 @@ def start(self, timeout):
self._task = asyncio.create_task(self._job(timeout))

class MQTTClient:
def __init__(self, api, client_id=None, verify_ssl=True, proxy=None, proxy_port=None):
"""AIO MQTT client """
def __init__(self, api, client_id: str =None, verify_ssl: bool =True, proxy: str =None, proxy_port: int =None):
self.event_loop = asyncio.get_running_loop()
self.api = api
self.pending_acks = {}
Expand All @@ -114,11 +116,12 @@ def __init__(self, api, client_id=None, verify_ssl=True, proxy=None, proxy_port=
self.reconnect_evt: asyncio.Event = asyncio.Event()
self.host = None
self.port = 443

if client_id is None:
client_id = "aiophyn-%s" % int(time.time())

self.client = paho_mqtt.Client(client_id=client_id, transport="websockets")
self.helper: AIOHelper = None
self.reconnect_timer = Timer(self._process_reconnect)

self.verify_ssl: bool = verify_ssl
Expand All @@ -137,6 +140,7 @@ def __init__(self, api, client_id=None, verify_ssl=True, proxy=None, proxy_port=
}

async def add_event_handler(self, type, target):
"""Add an event handler for MQTT events"""
if type not in self._handlers.keys():
return False

Expand All @@ -152,13 +156,13 @@ async def connect(self):
if self.verify_ssl:
self.client.tls_set()
else:
context = ssl.SSLContext()
context.verify_mode = ssl.CERT_NONE
context = ssl.SSLContext()
context.verify_mode = ssl.CERT_NONE
context.check_hostname = False
self.client.tls_set_context(context)
self.client.tls_insecure_set(True)
if self.proxy is not None and self.proxy_port is not None:

if self.proxy is not None and self.proxy_port is not None:
self.client.proxy_set(proxy_type=socks.HTTP, proxy_addr=self.proxy, proxy_port=self.proxy_port)

self.helper = AIOHelper(self.client)
Expand All @@ -170,12 +174,23 @@ async def connect(self):
self.host,
self.port,
)

def disconnect(self):
"""Disconnect from server"""
self.disconnect_evt = asyncio.Event()
_LOGGER.info("MQTT client disconnecting...")
self.client.disconnect()

async def disconnect_and_wait(self):
"""Disconnect from server and wait"""
self.disconnect()
await self.disconnect_evt.wait()

async def get_mqtt_info(self):
""" Gets WebSocket URL and parameters for a MQTT connection
Returns a list of url and path
"""
user_id = urllib.parse.quote_plus(self.api._username)
user_id = urllib.parse.quote_plus(self.api.username)
try:
wss_data = await self.api._request("post", f"{API_BASE}/users/{user_id}/iot_policy", token_type="id")
except:
Expand All @@ -189,6 +204,7 @@ async def get_mqtt_info(self):


async def subscribe(self, topic):
"""Subscribe to a MQTT topic"""
_LOGGER.info("Attempting to subscribe to: %s", topic)
res, msg_id = self.client.subscribe(topic, 0)
self.pending_acks[msg_id] = topic
Expand All @@ -213,7 +229,7 @@ def _on_connect(self,
err_str = paho_mqtt.connack_string(reason_code)
else:
err_str = reason_code.getName()
_LOGGER.info(f"MQTT Connection Failed: {err_str}")
_LOGGER.info("MQTT Connection Failed: %s", err_str)

def _on_disconnect(self,
client: paho_mqtt.Client,
Expand All @@ -224,6 +240,7 @@ def _on_disconnect(self,
# pylint: disable=unused-argument
if self.disconnect_evt is not None:
self.disconnect_evt.set()
_LOGGER.info("Client disconnected, not attempting to reconnect")
elif self.is_connected():
# The server connection was dropped, attempt to reconnect
_LOGGER.info("MQTT Server Disconnected, reason: %s", paho_mqtt.error_string(reason_code))
Expand Down Expand Up @@ -251,7 +268,7 @@ async def _do_reconnect(self, first: bool = False) -> None:
if self.reconnect_evt.is_set():
_LOGGER.info("Already attempting to reconnect, second attemp cancelled.")
return

_LOGGER.info("Attempting MQTT Connect/Reconnect")
self.reconnect_evt.set()
last_err: Exception = Exception()
Expand Down Expand Up @@ -281,7 +298,7 @@ async def _do_reconnect(self, first: bool = False) -> None:
self.host,
self.port,
)

await asyncio.wait_for(self.connect_evt.wait(), timeout=2.)
if not self.connect_evt.is_set():
_LOGGER.info("Timeout while waiting for MQTT connection")
Expand Down Expand Up @@ -309,7 +326,7 @@ def _on_message(
) -> None:
# pylint: disable=unused-argument
msg = message.payload.decode()
_LOGGER.debug("Message received on %s", message.topic)
_LOGGER.debug("Message received on %s %s", message.topic, msg)
try:
data = json.loads(msg)
except json.decoder.JSONDecodeError:
Expand Down
53 changes: 33 additions & 20 deletions aiophyn/partners/kohler.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,23 +6,24 @@
import json
import base64
import binascii
from Crypto.Cipher import AES
from Crypto.Util.Padding import unpad

from datetime import datetime, timedelta

from typing import Optional

from aiohttp import ClientSession, ClientTimeout, CookieJar


from Crypto.Cipher import AES
from Crypto.Util.Padding import unpad

_LOGGER = logging.getLogger(__name__)

DEFAULT_TIMEOUT: int = 10

class KOHLER_API:
"""API for Kohler to access Phyn Devices"""
def __init__(
self, username: str, password: str, session: Optional[ClientSession] = None,
verify_ssl: bool = True, proxy: Optional[str] = None, proxy_port: Optional[int] = None
self, username: str, password: str, verify_ssl: bool = True, proxy: Optional[str] = None,
proxy_port: Optional[int] = None
):
self._username: str = username
self._password: str = password
Expand All @@ -41,20 +42,24 @@ def __init__(
self.proxy_port = proxy_port
self.proxy_url: Optional[str] = None
if self.proxy is not None and self.proxy_port is not None:
self.proxy_url = "https://%s:%s" % (proxy, proxy_port)
self.proxy_url = f"https://{proxy}:{proxy_port}"

self._session: ClientSession = None

def get_cognito_info(self):
"""Get cognito information"""
return self._mobile_data['cognito']

def get_mqtt_info(self):
"""Get MQTT url"""
return self._mobile_data['wss']

def get_phyn_password(self):
"""Get phyn password"""
return self._phyn_password

async def authenticate(self):
"""Authenticate with Kohler and Phyn"""
use_running_session = self._session and not self._session.closed
if not use_running_session:
self._session = ClientSession(timeout=ClientTimeout(total=DEFAULT_TIMEOUT), cookie_jar=CookieJar())
Expand All @@ -65,6 +70,7 @@ async def authenticate(self):
self._phyn_password = await self.token_to_password(token)

async def b2c_login(self):
"""Login to Kohler"""
_LOGGER.debug("Logging into Kohler")
client_request_id = str(uuid.uuid4())

Expand All @@ -73,7 +79,8 @@ async def b2c_login(self):
"response_type": "code",
"client_id": "8caf9530-1d13-48e6-867c-0f082878debc",
"client-request-id": client_request_id,
"scope": "https%3A%2F%2Fkonnectkohler.onmicrosoft.com%2Ff5d87f3d-bdeb-4933-ab70-ef56cc343744%2Fapiaccess%20openid%20offline_access%20profile",
"scope": "https%3A%2F%2Fkonnectkohler.onmicrosoft.com%2Ff5d87f3d-bdeb-4933-ab70-ef56cc343744%2Fapiaccess%20" +
"openid%20offline_access%20profile",
"redirect_uri": "msauth%3A%2F%2Fcom.kohler.hermoth%2F2DuDM2vGmcL4bKPn2xKzKpsy68k%253D",
"prompt": "login",
}
Expand All @@ -99,8 +106,9 @@ async def b2c_login(self):
"signInName": self._username,
"password": self._password,
}
resp = await self._session.post("https://konnectkohler.b2clogin.com/konnectkohler.onmicrosoft.com/B2C_1A_signin/SelfAsserted?p=B2C_1A_signin&" + state_properties, headers=headers, data=login_vars,
ssl=self.ssl, proxy=self.proxy_url)
resp = await self._session.post("https://konnectkohler.b2clogin.com/konnectkohler.onmicrosoft.com/" +
"B2C_1A_signin/SelfAsserted?p=B2C_1A_signin&" + state_properties,
headers=headers, data=login_vars, ssl=self.ssl, proxy=self.proxy_url)

params = {
"rememberMe": "false",
Expand All @@ -109,8 +117,9 @@ async def b2c_login(self):
"p": "B2C_1A_signin"
}
args = '&'.join([ f"{x[0]}={x[1]}" for x in params.items() ])
resp = await self._session.get("https://konnectkohler.b2clogin.com/konnectkohler.onmicrosoft.com/B2C_1A_signin/api/CombinedSigninAndSignup/confirmed?" + args, allow_redirects=False,
ssl=self.ssl, proxy=self.proxy_url)
resp = await self._session.get("https://konnectkohler.b2clogin.com/konnectkohler.onmicrosoft.com/" +
"B2C_1A_signin/api/CombinedSigninAndSignup/confirmed?" + args,
allow_redirects=False, ssl=self.ssl, proxy=self.proxy_url)
matches = re.search(r'code=([^&]+)', resp.headers['Location'])
code = matches.group(1)

Expand All @@ -126,12 +135,14 @@ async def b2c_login(self):
"x-app-name": "com.kohler.hermoth",
"x-app-ver": "2.7",
"redirect_uri": "msauth://com.kohler.hermoth/2DuDM2vGmcL4bKPn2xKzKpsy68k%3D",
"scope": "https://konnectkohler.onmicrosoft.com/f5d87f3d-bdeb-4933-ab70-ef56cc343744/apiaccess openid offline_access profile",
"scope": "https://konnectkohler.onmicrosoft.com/f5d87f3d-bdeb-4933-ab70-ef56cc343744/apiaccess" +
" openid offline_access profile",
"grant_type": "authorization_code",
"code": code,
}
resp = await self._session.post("https://konnectkohler.b2clogin.com/tfp/konnectkohler.onmicrosoft.com/B2C_1A_signin/%2FoAuth2%2Fv2.0%2Ftoken", data=params,
ssl=self.ssl, proxy=self.proxy_url)
resp = await self._session.post("https://konnectkohler.b2clogin.com/tfp/konnectkohler.onmicrosoft.com/" +
"B2C_1A_signin/%2FoAuth2%2Fv2.0%2Ftoken", data=params, ssl=self.ssl,
proxy=self.proxy_url)

data = await resp.json()
if "client_info" not in data:
Expand All @@ -149,6 +160,7 @@ async def b2c_login(self):
_LOGGER.debug("Received Kohler Token")

async def get_phyn_token(self):
""" Get a phyn access token"""
params = {
"partner": "kohler",
"partner_user_id": self._user_id,
Expand All @@ -169,7 +181,7 @@ async def get_phyn_token(self):
mobile_data = await resp.json()
if "error_msg" in mobile_data:
await self._session.close()
raise Exception("Kohler %s" % mobile_data['error_msg'])
raise Exception(f"Kohler {mobile_data['error_msg']}")

if "cognito" not in mobile_data:
await self._session.close()
Expand All @@ -182,11 +194,11 @@ async def get_phyn_token(self):
"partner": "kohler",
"partner_user_id": self._user_id
}
args = "&".join(["%s=%s" % (x, params[x]) for x in params.keys()])
args = "&".join([ f"{x[0]}={x[1]}" for x in params.items() ])
headers = {
"Accept": "application/json, text/plain, */*",
"Accept-encoding": "gzip",
"Authorization": "Bearer partner-%s" % self._token,
"Authorization": f"Bearer partner-{self._token}",
"Content-Type": "application/json",
"x-api-key": mobile_data['pws_api']['app_api_key']
}
Expand All @@ -200,12 +212,13 @@ async def get_phyn_token(self):
return data['token']

async def token_to_password(self, token):
"""Convert a token to a Phyn password"""
b64hex = base64.b64decode((token + '=' * (5 - (len(token) % 4))).replace('_','/').replace('-','+')).hex()

try:
keydata = binascii.hexlify(base64.b64decode(self._mobile_data['partner']['comm_id'])).decode()
except:
raise Exception("Error getting password decryption key")
except Exception as e:
raise Exception("Error getting password decryption key") from e

key = keydata[32:]
iv = b64hex[18:(18+32)]
Expand Down
44 changes: 23 additions & 21 deletions examples/test_mqtt.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,40 +12,42 @@

USERNAME = "USERNAME_HERE"
PASSWORD = "PASSWORD_HERE"
BRAND = "BRAND" # phyn or kohler
BRAND = "phyn" # phyn or kohler

async def on_message(device_id, data):
_LOGGER.info("Message for %s: %s" % (device_id, data))
"""Display a received MQTT message"""
_LOGGER.info("Message for %s: %s", device_id, data)

async def main() -> None:
"""Create the aiohttp session and run the example."""
logging.basicConfig(level=logging.INFO)
async with ClientSession() as session:
try:
api = await async_get_api(USERNAME, PASSWORD, phyn_brand=BRAND, session=session, verify_ssl=False)
try:
api = await async_get_api(USERNAME, PASSWORD, phyn_brand=BRAND)

all_home_info = await api.home.get_homes(USERNAME)
_LOGGER.info(all_home_info)
all_home_info = await api.home.get_homes(USERNAME)
_LOGGER.info(all_home_info)

home_info = all_home_info[0]
_LOGGER.info(home_info)
home_info = all_home_info[0]
_LOGGER.info(home_info)

first_device_id = home_info["device_ids"][0]
device_state = await api.device.get_state(first_device_id)
_LOGGER.info(device_state)
first_device_id = home_info["device_ids"][0]
device_state = await api.device.get_state(first_device_id)
_LOGGER.info(device_state)

await api.mqtt.add_event_handler("update", on_message)
await api.mqtt.connect()
await api.mqtt.add_event_handler("update", on_message)
await api.mqtt.connect()

for device in home_info['devices']:
if device['product_code'] in ['PP2']:
_LOGGER.info("Found PP2: %s" % device)
await api.mqtt.subscribe("prd/app_subscriptions/%s" % device['device_id'])
for device in home_info['devices']:
if device['product_code'] in ['PP1','PP2']:
_LOGGER.info("Found Phyn Plus: %s", device)
await api.mqtt.subscribe(f"prd/app_subscriptions/{device['device_id']}")

await asyncio.sleep(60)
await asyncio.sleep(10)

except PhynError as err:
_LOGGER.error("There was an error: %s", err)
await api.mqtt.disconnect_and_wait()

except PhynError as err:
_LOGGER.error("There was an error: %s", err)


asyncio.run(main())
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "aiophyn"
version = "2024.1.0"
version = "2024.1.1"
description = "An asynchronous library for Phyn Smart Water devices"
authors = ["MizterB <5458030+MizterB@users.noreply.github.com>","jordanruthe <31575189+jordanruthe@users.noreply.github.com>"]
license = "MIT"
Expand Down

0 comments on commit 32dac15

Please sign in to comment.