Skip to content

Commit

Permalink
Implement get on CaseInsensitiveDict
Browse files Browse the repository at this point in the history
get was previously provided by the parent class which
had to raise KeyError for missing values. Since try/except
is only cheap for the non-exception case the performance
was not good when the key was missing

similar to python/cpython#106665
but in the HA case we call this even more frequently
  • Loading branch information
bdraco committed Aug 13, 2023
1 parent ac4171a commit a7628d3
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 26 deletions.
24 changes: 13 additions & 11 deletions async_upnp_client/advertisement.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,24 +66,26 @@ def __init__(

def _on_data(self, request_line: str, headers: CaseInsensitiveDict) -> None:
"""Handle data."""
if headers.get("MAN") == SSDP_DISCOVER:
if headers.get_lower("man") == SSDP_DISCOVER:
# Ignore discover packets.
return
if "NTS" not in headers:

notification_sub_type = headers.get_lower("nts")
if notification_sub_type is None:
_LOGGER.debug("Got non-advertisement packet: %s, %s", request_line, headers)
return

_LOGGER.debug(
"Received advertisement, _remote_addr: %s, NT: %s, NTS: %s, USN: %s, location: %s",
headers.get("_remote_addr", ""),
headers.get("NT", "<no NT>"),
headers.get("NTS", "<no NTS>"),
headers.get("USN", "<no USN>"),
headers.get("location", ""),
)
if _LOGGER.isEnabledFor(logging.DEBUG):
_LOGGER.debug(
"Received advertisement, _remote_addr: %s, NT: %s, NTS: %s, USN: %s, location: %s",
headers.get_lower("_remote_addr", ""),
headers.get_lower("nt", "<no NT>"),
headers.get_lower("nts", "<no NTS>"),
headers.get_lower("usn", "<no USN>"),
headers.get_lower("location", ""),
)

headers["_source"] = SsdpSource.ADVERTISEMENT
notification_sub_type = headers["NTS"]
if notification_sub_type == NotificationSubType.SSDP_ALIVE:
if self.async_on_alive:
coro = self.async_on_alive(headers)
Expand Down
24 changes: 12 additions & 12 deletions async_upnp_client/ssdp_listener.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,9 @@
def valid_search_headers(headers: CaseInsensitiveDict) -> bool:
"""Validate if this search is usable."""
# pylint: disable=invalid-name
udn = headers.get("_udn") # type: Optional[str]
st = headers.get("st") # type: Optional[str]
location = headers.get("location", "") # type: str
udn = headers.get_lower("_udn") # type: Optional[str]
st = headers.get_lower("st") # type: Optional[str]
location = headers.get_lower("location", "") # type: str
return bool(
udn
and st
Expand All @@ -60,10 +60,10 @@ def valid_search_headers(headers: CaseInsensitiveDict) -> bool:
def valid_advertisement_headers(headers: CaseInsensitiveDict) -> bool:
"""Validate if this advertisement is usable for connecting to a device."""
# pylint: disable=invalid-name
udn = headers.get("_udn") # type: Optional[str]
nt = headers.get("nt") # type: Optional[str]
nts = headers.get("nts") # type: Optional[str]
location = headers.get("location", "") # type: str
udn = headers.get_lower("_udn") # type: Optional[str]
nt = headers.get_lower("nt") # type: Optional[str]
nts = headers.get_lower("nts") # type: Optional[str]
location = headers.get_lower("location", "") # type: str
return bool(
udn
and nt
Expand All @@ -81,15 +81,15 @@ def valid_advertisement_headers(headers: CaseInsensitiveDict) -> bool:
def valid_byebye_headers(headers: CaseInsensitiveDict) -> bool:
"""Validate if this advertisement has required headers for byebye."""
# pylint: disable=invalid-name
udn = headers.get("_udn") # type: Optional[str]
nt = headers.get("nt") # type: Optional[str]
nts = headers.get("nts") # type: Optional[str]
udn = headers.get_lower("_udn") # type: Optional[str]
nt = headers.get_lower("nt") # type: Optional[str]
nts = headers.get_lower("nts") # type: Optional[str]
return bool(udn and nt and nts)


def extract_valid_to(headers: CaseInsensitiveDict) -> datetime:
"""Extract/create valid to."""
cache_control = headers.get("cache-control", "")
cache_control = headers.get_lower("cache-control", "")
match = CACHE_CONTROL_RE.search(cache_control)
if match:
max_age = int(match[1])
Expand Down Expand Up @@ -247,7 +247,7 @@ def ip_version_from_location(location: str) -> Optional[int]:

def location_changed(ssdp_device: SsdpDevice, headers: CaseInsensitiveDict) -> bool:
"""Test if location changed for device."""
new_location = headers.get("location", "")
new_location = headers.get_lower("location", "")
if not new_location:
return False

Expand Down
19 changes: 16 additions & 3 deletions async_upnp_client/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,12 @@ def as_lower_dict(self) -> Dict[str, Any]:
"""Return the underlying dict in lowercase."""
return {k.lower(): v for k, v in self._data.items()}

def get_lower(self, lower_key: str) -> Any:
def get_lower(self, lower_key: str, default: Any = None) -> Any:
"""Get a lower case key."""
data_key = self._case_map.get(lower_key, _SENTINEL)
if data_key is not _SENTINEL:
return self._data[data_key]
return None
return self._data.get(data_key, default)
return default

def replace(self, new_data: abcMapping) -> None:
"""Replace the underlying dict."""
Expand All @@ -56,6 +56,19 @@ def replace(self, new_data: abcMapping) -> None:
self._data = {**new_data}
self._case_map = {k.lower(): k for k in self._data}

def get(self, key: str, default: Any = None) -> Any:
"""Get item with default.
This implementation is case insensitive and avoids
calling __getitem__ which would raise KeyError and
cause unnecessary exception handling.
"""
case_map = self._case_map
data_key = case_map.get(key, case_map.get(key.lower(), _SENTINEL))
if data_key is not _SENTINEL:
return self._data.get(data_key, default)
return default

def __setitem__(self, key: str, value: Any) -> None:
"""Set item."""
lower_key = key.lower()
Expand Down

0 comments on commit a7628d3

Please sign in to comment.