Skip to content

Commit

Permalink
Try both protocols
Browse files Browse the repository at this point in the history
  • Loading branch information
dswd committed Oct 5, 2023
1 parent 4f64ee0 commit f20f849
Show file tree
Hide file tree
Showing 3 changed files with 234 additions and 21 deletions.
46 changes: 40 additions & 6 deletions octoprint_psucontrol_tapo/__init__.py
@@ -1,5 +1,6 @@
# coding=utf-8
from __future__ import absolute_import
import threading

__author__ = "Dennis Schwerdel <schwerdel@gmail.com>"
__license__ = "GNU Affero General Public License http://www.gnu.org/licenses/agpl.html"
Expand All @@ -17,6 +18,7 @@ class PSUControl_Tapo(octoprint.plugin.StartupPlugin,
def __init__(self):
self.config = dict()
self.device = None
self.last_status = None


def get_settings_defaults(self):
Expand Down Expand Up @@ -44,6 +46,10 @@ def on_settings_migrate(self, target, current=None):
pass


def _reconnect(self):
self._logger.info(f"Connecting to Tapo device at {self.config['address']}")
self.device = P100(self.config["address"], self.config["username"], self.config["password"])

def reload_settings(self):
for k, v in self.get_settings_defaults().items():
if type(v) == str:
Expand All @@ -60,7 +66,7 @@ def reload_settings(self):
try:
self._logger.info(f"Config: {self.config}")
tapo.log = self._logger
self.device = P100(self.config["address"], self.config["username"], self.config["password"])
self._reconnect()
except:
self._logger.exception(f"Failed to connect to Tapo device")

Expand All @@ -76,19 +82,47 @@ def on_startup(self, host, port):


def turn_psu_on(self):
if not self.device:
self._reconnect()
self._logger.debug("Switching PSU On")
self.device.set_status(True)

try:
self.device.set_status(True)
self.last_status = True
except:
self._logger.exception(f"Failed to switch PSU On")
self.device = None
raise

def turn_psu_off(self):
if not self.device:
self._reconnect()
self._logger.debug("Switching PSU Off")
self.device.set_status(False)
try:
self.device.set_status(False)
self.last_status = False
except:
self._logger.exception(f"Failed to switch PSU Off")
self.device = None
raise


def get_psu_state(self):
def _fetch_psu_state(self):
if not self.device:
self._reconnect()
self._logger.debug("get_psu_state")
return self.device.get_status()
try:
self.last_status = self.device.get_status()
except:
self._logger.exception(f"Failed to get PSU state")
self.device = None
raise

def get_psu_state(self):
if not self.last_status:
self._fetch_psu_state()
else:
threading.Thread(target=self._fetch_psu_state).start()
return self.last_status

def get_template_configs(self):
return [
Expand Down
207 changes: 193 additions & 14 deletions octoprint_psucontrol_tapo/tapo.py
@@ -1,7 +1,9 @@
import json, time, uuid, logging
import os.path
from base64 import b64encode, b64decode
import requests
from Crypto.Cipher import AES
from Crypto.PublicKey import RSA
from Crypto.Cipher import AES, PKCS1_v1_5
from Crypto.Hash import SHA256, SHA1
from Crypto.Random import get_random_bytes
import hashlib
Expand All @@ -14,13 +16,9 @@ def sha1(data: bytes) -> bytes:
def sha256(data: bytes) -> bytes:
return SHA256.new(data).digest()

def calc_auth_hash(username: str, password: str) -> bytes:
return sha256(sha1(username.encode()) + sha1(password.encode()))

class Device:
def __init__(self, address: str, username: str, password: str, keypair_file: str = '/tmp/tapo.key'):
class NewProtocol:
def __init__(self, address: str, username: str, password: str):
self.session = requests.Session() # single session, stores cookie
self.terminal_uuid = str(uuid.uuid4())
self.address = address
self.username = username
self.password = password
Expand All @@ -29,6 +27,9 @@ def __init__(self, address: str, username: str, password: str, keypair_file: str
self.seq = None
self.sig = None

def calc_auth_hash(self, username: str, password: str) -> bytes:
return sha256(sha1(username.encode()) + sha1(password.encode()))

def _request_raw(self, path: str, data: bytes, params: dict = None):
url = f"http://{self.address}/app/{path}"
resp = self.session.post(url, data=data, timeout=2, params=params)
Expand All @@ -50,8 +51,14 @@ def _request(self, method: str, params: dict = None):
result = self._request_raw("request", encrypted, params={"seq": self.seq})
# Unwrap and decrypt result
data = json.loads(self._decrypt(result).decode("UTF-8"))
log.debug(f"Response: {data}")
return data.get("result")
# Check error code and get result
if data["error_code"] != 0:
log.error(f"Error: {data}")
self.key = None
raise Exception(f"Error code: {data['error_code']}")
result = data.get("result")
log.debug(f"Response: {result}")
return result

def _encrypt(self, data: bytes):
self.seq += 1
Expand Down Expand Up @@ -82,7 +89,7 @@ def _initialize(self):
remote_seed, server_hash = response[0:16], response[16:]
auth_hash = None
for creds in [(self.username, self.password), ("", ""), ("kasa@tp-link.net", "kasaSetup")]:
ah = calc_auth_hash(*creds)
ah = self.calc_auth_hash(*creds)
local_seed_auth_hash = sha256(local_seed + remote_seed + ah)
if local_seed_auth_hash == server_hash:
auth_hash = ah
Expand All @@ -97,12 +104,184 @@ def _initialize(self):
self.seq = int.from_bytes(ivseq[-4:], "big", signed=True)
self.sig = sha256(b"ldk" + local_seed + remote_seed + auth_hash)[:28]
log.debug(f"Initialized")



class OldProtocol:
def __init__(self, address: str, username: str, password: str, keypair_file: str = '/tmp/tapo.key'):
self.session = requests.Session() # single session, stores cookie
self.terminal_uuid = str(uuid.uuid4())
self.address = address
self.username = username
self.password = password
self.keypair_file = keypair_file
self._create_keypair()
self.key = None
self.iv = None

def _create_keypair(self):
if self.keypair_file and os.path.exists(self.keypair_file):
with open(self.keypair_file, 'r') as f:
self.keypair = RSA.importKey(f.read())
else:
self.keypair = RSA.generate(1024)
if self.keypair_file:
with open(self.keypair_file, "wb") as f:
f.write(self.keypair.exportKey("PEM"))


def _request_raw(self, method: str, params: dict = None):
# Construct url, add token if we have one
url = f"http://{self.address}/app"
if self.token:
url += f"?token={self.token}"

# Construct payload, add params if given
payload = {
"method": method,
"requestTimeMils": int(round(time.time() * 1000)),
"terminalUUID": self.terminal_uuid
}
if params:
payload["params"] = params
log.debug(f"Request raw: {payload}")

# Execute call
resp = self.session.post(url, json=payload, timeout=0.5)
resp.raise_for_status()
data = resp.json()

# Check error code and get result
if data["error_code"] != 0:
log.error(f"Error: {data}")
self.key = None
raise Exception(f"Error code: {data['error_code']}")
result = data.get("result")

log.debug(f"Response raw: {result}")
return result


def _request(self, method: str, params: dict = None):
if not self.key:
self._initialize()

# Construct payload, add params if given
payload = {
"method": method,
"requestTimeMils": int(round(time.time() * 1000)),
"terminalUUID": self.terminal_uuid
}
if params:
payload["params"] = params
log.debug(f"Request: {payload}")

# Encrypt payload and execute call
encrypted = self._encrypt(json.dumps(payload))

result = self._request_raw("securePassthrough", {"request": encrypted})

# Unwrap and decrypt result
data = json.loads(self._decrypt(result["response"]))
if data["error_code"] != 0:
log.error(f"Error: {data}")
self.key = None
raise Exception(f"Error code: {data['error_code']}")
result = data.get("result")

log.debug(f"Response: {result}")
return result


def _encrypt(self, data: str):
data = data.encode("UTF-8")

# Add PKCS#7 padding
pad_l = 16 - (len(data) % 16)
data = data + bytes([pad_l] * pad_l)

# Encrypt data with key
crypto = AES.new(self.key, AES.MODE_CBC, self.iv)
data = crypto.encrypt(data)

# Base64 encode
data = b64encode(data).decode("UTF-8")
return data


def _decrypt(self, data: str):
# Base64 decode data
data = b64decode(data.encode("UTF-8"))

# Decrypt data with key
crypto = AES.new(self.key, AES.MODE_CBC, self.iv)
data = crypto.decrypt(data)

# Remove PKCS#7 padding
data = data[:-data[-1]]
return data.decode("UTF-8")


def _initialize(self):
# Unset key and token
self.key = None
self.token = None

# Send public key and receive encrypted symmetric key
public_key = self.keypair.publickey().exportKey("PEM").decode("UTF-8")
public_key = public_key.replace("RSA PUBLIC KEY", "PUBLIC KEY")
result = self._request_raw("handshake", {
"key": public_key
})
encrypted = b64decode(result["key"].encode("UTF-8"))

# Decrypt symmetric key
cipher = PKCS1_v1_5.new(self.keypair)
decrypted = cipher.decrypt(encrypted, None)
self.key, self.iv = decrypted[:16], decrypted[16:]

# Base64 encode password and hashed username
digest = hashlib.sha1(self.username.encode("UTF-8")).hexdigest()
username = b64encode(digest.encode("UTF-8")).decode("UTF-8")
password = b64encode(self.password.encode("UTF-8")).decode("UTF-8")

# Send login info and receive session token
result = self._request("login_device", {
"username": username,
"password": password
})
self.token = result["token"]


class Device:
def __init__(self, address: str, username: str, password: str, **kwargs):
self.address = address
self.username = username
self.password = password
self.kwargs = kwargs
self.protocol = None

def _initialize(self):
for protocol_class in [NewProtocol, OldProtocol]:
if not self.protocol:
try:
protocol = protocol_class(self.address, self.username, self.password, **self.kwargs)
protocol._initialize()
self.protocol = protocol
except:
log.exception(f"Failed to initialize protocol {protocol_class.__name__}")
if not self.protocol:
raise Exception("Failed to initialize protocol")

def request(self, method: str, params: dict = None):
if not self.protocol:
self._initialize()
return self.protocol._request(method, params)

def _get_device_info(self):
return self._request("get_device_info")
return self.request("get_device_info")

def _set_device_info(self, params: dict):
return self._request("set_device_info", params)
return self.request("set_device_info", params)

def get_type(self) -> str:
return self._get_device_info()["model"]
Expand Down Expand Up @@ -133,7 +312,7 @@ def toggle(self):

class Metering(Device):
def get_energy_usage(self) -> dict:
return self._request("get_energy_usage")
return self.request("get_energy_usage")


class Dimmable(Device):
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Expand Up @@ -6,7 +6,7 @@
plugin_identifier = "psucontrol_tapo"
plugin_package = "octoprint_%s" % plugin_identifier
plugin_name = "OctoPrint-PSUControl-Tapo"
plugin_version = "0.3.0"
plugin_version = "0.4.0"
plugin_description = "Adds Tapo Smart Plug support to OctoPrint-PSUControl as a sub-plugin"
plugin_author = "Dennis Schwerdel"
plugin_author_email = "schwerdel@gmail.com"
Expand Down

0 comments on commit f20f849

Please sign in to comment.