Skip to content

Commit

Permalink
Add more type hints, fix some pylint warnings (#71)
Browse files Browse the repository at this point in the history
* More type hints

* Add more type hints, fix some pylint warnings

* More fixes

* Move zeroconf import
  • Loading branch information
KapJI committed Apr 11, 2021
1 parent 0a6265c commit 8b572bd
Show file tree
Hide file tree
Showing 9 changed files with 103 additions and 73 deletions.
3 changes: 1 addition & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,7 @@ repos:
types: [python]
- id: flake8
name: flake8
entry: poetry run flake8 --max-line-length=88
exclude: glocaltokens/google/
entry: poetry run flake8
language: system
types: [python]
- id: pylint
Expand Down
22 changes: 15 additions & 7 deletions glocaltokens/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,6 @@ def get_access_token(self) -> Optional[str]:
)
return self.access_token

# pylint: disable=no-member
def get_homegraph(self):
"""Returns the entire Google Home Foyer V2 service"""
if self.homegraph is None or self._has_expired(
Expand All @@ -289,20 +288,26 @@ def get_homegraph(self):
"There is no stored homegraph, or it has expired, getting a new one..."
)
log_prefix = "[GRPC]"
access_token = self.get_access_token()
if not access_token:
LOGGER.debug("%s Unable to obtain access token.", log_prefix)
return None
try:
LOGGER.debug("%s Creating SSL channel credentials...", log_prefix)
scc = grpc.ssl_channel_credentials(root_certificates=None)
LOGGER.debug("%s Creating access token call credentials...", log_prefix)
tok = grpc.access_token_call_credentials(self.get_access_token())
tok = grpc.access_token_call_credentials(access_token)
LOGGER.debug("%s Compositing channel credentials...", log_prefix)
ccc = grpc.composite_channel_credentials(scc, tok)
channel_credentials = grpc.composite_channel_credentials(scc, tok)

LOGGER.debug(
"%s Establishing secure channel with "
"the Google Home Foyer API...",
log_prefix,
)
with grpc.secure_channel(GOOGLE_HOME_FOYER_API, ccc) as channel:
with grpc.secure_channel(
GOOGLE_HOME_FOYER_API, channel_credentials
) as channel:
LOGGER.debug(
"%s Getting channels StructuresServiceStub...", log_prefix
)
Expand All @@ -316,7 +321,10 @@ def get_homegraph(self):
self.homegraph_date = datetime.now()
except grpc.RpcError as rpc_error:
LOGGER.debug("%s Got an RpcError", log_prefix)
if rpc_error.code().name == "UNAUTHENTICATED":
if (
rpc_error.code().name # pylint: disable=no-member
== "UNAUTHENTICATED"
):
LOGGER.warning(
"%s The access token has expired. Getting a new one.",
log_prefix,
Expand All @@ -326,8 +334,8 @@ def get_homegraph(self):
LOGGER.error(
"%s Received unknown RPC error: code=%s message=%s",
log_prefix,
rpc_error.code(),
rpc_error.details(),
rpc_error.code(), # pylint: disable=no-member
rpc_error.details(), # pylint: disable=no-member
)
return self.homegraph

Expand Down
96 changes: 49 additions & 47 deletions glocaltokens/scanner.py
Original file line number Diff line number Diff line change
@@ -1,50 +1,61 @@
"""Zeroconf based scanner"""
import logging
from threading import Event
from typing import List, Optional
from typing import Callable, List, Optional

from zeroconf import ServiceListener
import zeroconf
from zeroconf import ServiceInfo, ServiceListener, Zeroconf

from .const import DISCOVERY_TIMEOUT
from .utils import network as net_utils, types as type_utils
from .utils import network as net_utils

LOGGER = logging.getLogger(__name__)


# pylint: disable=invalid-name
class CastListener(ServiceListener):
"""
Zeroconf Cast Services collection.
Credit (pychromecast):
https://github.com/home-assistant-libs/pychromecast/
"""

def __init__(self, add_callback=None, remove_callback=None, update_callback=None):
self.devices = []
def __init__(
self,
add_callback: Optional[Callable[[], None]] = None,
remove_callback: Optional[Callable[[], None]] = None,
update_callback: Optional[Callable[[], None]] = None,
):
self.devices: List[GoogleDevice] = []
self.add_callback = add_callback
self.remove_callback = remove_callback
self.update_callback = update_callback

@property
def count(self):
def count(self) -> int:
"""Number of discovered cast services."""
return len(self.devices)

def add_service(self, zc, type_, name):
def add_service(self, zc: Zeroconf, type_: str, name: str) -> None:
""" Add a service to the collection. """
LOGGER.debug("add_service %s, %s", type_, name)
self._add_update_service(zc, type_, name, self.add_callback)

def update_service(self, zc, type_, name):
def update_service(self, zc: Zeroconf, type_: str, name: str) -> None:
""" Update a service in the collection. """
LOGGER.debug("update_service %s, %s", type_, name)
self._add_update_service(zc, type_, name, self.update_callback)

def remove_service(self, _zconf, type_, name):
def remove_service(self, _zc: Zeroconf, type_: str, name: str) -> None:
"""Called when a cast has beeen lost (mDNS info expired or host down)."""
LOGGER.debug("remove_service %s, %s", type_, name)

def _add_update_service(self, zc, type_, name, callback):
def _add_update_service(
self,
zc: Zeroconf,
type_: str,
name: str,
callback: Optional[Callable[[], None]],
) -> None:
""" Add or update a service. """
service = None
tries = 0
Expand All @@ -64,25 +75,33 @@ def _add_update_service(self, zc, type_, name, callback):
LOGGER.debug("_add_update_service failed to add %s, %s", type_, name)
return

def get_value(key):
"""Retrieve value and decode to UTF-8."""
value = service.properties.get(key.encode("utf-8"))

if value is None or isinstance(value, str):
return value
return value.decode("utf-8")

addresses = service.parsed_addresses()
host = addresses[0] if addresses else service.server

model_name = get_value("md")
friendly_name = get_value("fn")
model_name = self.get_service_value(service, "md")
friendly_name = self.get_service_value(service, "fn")

if not model_name or not friendly_name or not service.port:
LOGGER.debug(
"Device %s doesn't have friendly name, model name or port, skipping...",
host,
)
return

self.devices.append((model_name, friendly_name, host, service.port))
self.devices.append(GoogleDevice(friendly_name, host, service.port, model_name))

if callback:
callback()

@staticmethod
def get_service_value(service: ServiceInfo, key: str) -> Optional[str]:
"""Retrieve value and decode to UTF-8."""
value = service.properties.get(key.encode("utf-8"))

if value is None or isinstance(value, str):
return value
return value.decode("utf-8")


class GoogleDevice:
"""Discovered Google device representation"""
Expand All @@ -95,10 +114,6 @@ def __init__(self, name: str, ip_address: str, port: int, model: str):
LOGGER.error("IP must be a valid IP address")
return

if not type_utils.is_integer(port):
LOGGER.error("PORT must be an integer value")
return

self.name = name
self.ip_address = ip_address
self.port = port
Expand Down Expand Up @@ -134,19 +149,16 @@ def discover_devices(
LOGGER.setLevel(logging_level)

LOGGER.debug("Discovering devices...")
LOGGER.debug("Importing zeroconf...")
# pylint: disable=import-outside-toplevel
import zeroconf

def callback():
def callback() -> None:
"""Called when zeroconf has discovered a new chromecast."""
if max_devices is not None and listener.count >= max_devices:
discovery_complete.set()

LOGGER.debug("Creating new Event for discovery completion...")
discovery_complete = Event()
LOGGER.debug("Creating new CastListener...")
listener = CastListener(callback)
listener = CastListener(add_callback=callback)
if not zeroconf_instance:
LOGGER.debug("Creating new Zeroconf instance")
zc = zeroconf.Zeroconf()
Expand All @@ -160,24 +172,14 @@ def callback():
LOGGER.debug("Waiting for discovery completion...")
discovery_complete.wait(timeout)

devices = []
devices: List[GoogleDevice] = []
LOGGER.debug("Got %s devices. Iterating...", len(listener.devices))
for service in listener.devices:
model = service[0]
name = service[1]
ip_address = service[2]
access_port = service[3]
if not models_list or model in models_list:
LOGGER.debug(
"Appending new device. name: %s, ip: %s, port: %s, model: %s",
name,
ip_address,
access_port,
model,
)
devices.append(GoogleDevice(name, ip_address, int(access_port), model))
for device in listener.devices:
if not models_list or device.model in models_list:
LOGGER.debug("Appending new device: %s", device)
devices.append(device)
else:
LOGGER.debug(
'Won\'t add device since model "%s" is not in models_list', model
'Won\'t add device since model "%s" is not in models_list', device.model
)
return devices
1 change: 0 additions & 1 deletion glocaltokens/utils/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ def is_float(variable):
return isinstance(variable, float)


# pylint: disable=too-few-public-methods
class Struct:
"""Structure type"""

Expand Down
9 changes: 9 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -51,12 +51,21 @@ ipdb = "^0.13.7"
extension-pkg-whitelist = [
"_socket",
]
ignore = ["google"]

[tool.pylint.basic]
good-names = [
"zc",
]

[tool.pylint.format]
max-line-length = 88
min-similarity-lines = 7

[tool.pylint.messages_control]
# Reasons disabled:
# too-many-* - are not enforced for the sake of readability
# too-few-* - same as too-many-*
disable = [
"too-few-public-methods",
"too-many-arguments",
Expand Down
18 changes: 18 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
@@ -1,5 +1,23 @@
[flake8]
exclude = glocaltokens/google
doctests = True
# To work with Black
max-line-length = 88
# E501: line too long
# W503: Line break occurred before a binary operator
# E203: Whitespace before ':'
# D202 No blank lines allowed after function docstring
# W504 line break after binary operator
ignore =
E501,
W503,
E203,
D202,
W504

[mypy]
python_version = 3.8
check_untyped_defs = True

[mypy-faker]
ignore_missing_imports = True
Expand Down
9 changes: 5 additions & 4 deletions tests/assertions.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
"""
Common assertion helper classes used for unittesting
"""
# pylint: disable=no-member
# pylint: disable=invalid-name
from unittest import TestCase

import glocaltokens.utils.token as token_utils


class DeviceAssertions:
class DeviceAssertions(TestCase):
"""Device specific assessors"""

def assertDevice(self, homegraph_device, homegraph_device_struct):
Expand All @@ -26,7 +27,7 @@ def assertDevice(self, homegraph_device, homegraph_device_struct):
)


class TypeAssertions:
class TypeAssertions(TestCase):
"""Type assessors"""

def assertIsString(self, variable):
Expand All @@ -36,7 +37,7 @@ def assertIsString(self, variable):
msg=f"Given variable {variable} is not String type",
)

def assertIsAASET(self, variable):
def assertIsAasEt(self, variable):
"""Assert the given variable is a of string type and follows AAS token format"""
self.assertTrue(
isinstance(variable, str) and token_utils.is_aas_et(variable),
Expand Down
6 changes: 4 additions & 2 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def test_initialization(self):
self.assertIsNone(client.access_token_date)
self.assertIsNone(client.homegraph_date)

self.assertIsAASET(client.master_token)
self.assertIsAasEt(client.master_token)

@patch("glocaltokens.client.LOGGER.error")
def test_initialization__valid(self, m_log):
Expand Down Expand Up @@ -235,6 +235,7 @@ def test_get_access_token(self, m_perform_oauth, m_get_master_token, m_log):
self.assertEqual(m_perform_oauth.call_count, 0)

# Another request with expired token must return new token (new request)
assert self.client.access_token_date is not None
self.client.access_token_date = self.client.access_token_date - timedelta(
ACCESS_TOKEN_DURATION + 1
)
Expand All @@ -250,7 +251,7 @@ def test_get_access_token(self, m_perform_oauth, m_get_master_token, m_log):
@patch("glocaltokens.client.GLocalAuthenticationTokens.get_access_token")
def test_get_homegraph(
self,
m_get_access_token, # pylint: disable=unused-argument
_m_get_access_token,
m_get_home_graph_request,
m_structure_service_stub,
m_secure_channel,
Expand Down Expand Up @@ -279,6 +280,7 @@ def test_get_homegraph(
self.assertEqual(m_get_home_graph_request.call_count, 1)

# Expired homegraph
assert self.client.homegraph_date is not None
self.client.homegraph_date = self.client.homegraph_date - timedelta(
HOMEGRAPH_DURATION + 1
)
Expand Down
Loading

0 comments on commit 8b572bd

Please sign in to comment.