From cdbb9b4d605a164cbae1c564af9e3b5b6d39a424 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 1 Mar 2024 22:38:25 -1000 Subject: [PATCH] Cache network information instead of fetching it each time (#43) --- aiodiscover/discovery.py | 32 +++++++++++++++++++------------- aiodiscover/network.py | 8 +++++--- 2 files changed, 24 insertions(+), 16 deletions(-) diff --git a/aiodiscover/discovery.py b/aiodiscover/discovery.py index 5bc6325..492ac4b 100644 --- a/aiodiscover/discovery.py +++ b/aiodiscover/discovery.py @@ -6,7 +6,7 @@ from contextlib import suppress from functools import lru_cache from ipaddress import IPv4Address -from typing import Any, Optional, cast +from typing import TYPE_CHECKING, Any, Optional, cast from dns import exception, message, rdatatype from dns.message import Message, QueryMessage @@ -14,6 +14,9 @@ from .network import SystemNetworkData +if TYPE_CHECKING: + from pyroute2.iproute import IPRoute # noqa: F401 + HOSTNAME = "hostname" MAC_ADDRESS = "macaddress" IP_ADDRESS = "ip" @@ -183,24 +186,27 @@ class DiscoverHosts: def __init__(self) -> None: """Init the discovery hosts.""" - self._ip_route = None + self._sys_network_data: SystemNetworkData | None = None - def _get_sys_network_data(self) -> SystemNetworkData: - if not self._ip_route: - with suppress(Exception): - from pyroute2.iproute import ( - IPRoute, - ) # type: ignore # pylint: disable=import-outside-toplevel + def _setup_sys_network_data(self) -> None: + ip_route: "IPRoute" | None = None + with suppress(Exception): + from pyroute2.iproute import ( # noqa: F811 + IPRoute, + ) # type: ignore # pylint: disable=import-outside-toplevel - self._ip_route = IPRoute() - sys_network_data = SystemNetworkData(self._ip_route) + ip_route = IPRoute() + sys_network_data = SystemNetworkData(ip_route) sys_network_data.setup() - return sys_network_data + self._sys_network_data = sys_network_data async def async_discover(self) -> list[dict[str, str]]: """Discover hosts on the network by ARP and PTR lookup.""" - loop = asyncio.get_running_loop() - sys_network_data = await loop.run_in_executor(None, self._get_sys_network_data) + if not self._sys_network_data: + await asyncio.get_running_loop().run_in_executor( + None, self._setup_sys_network_data + ) + sys_network_data = self._sys_network_data network = sys_network_data.network assert network is not None if network.num_addresses > MAX_ADDRESSES: diff --git a/aiodiscover/network.py b/aiodiscover/network.py index 846aebc..634d106 100644 --- a/aiodiscover/network.py +++ b/aiodiscover/network.py @@ -6,13 +6,15 @@ import sys from contextlib import suppress from ipaddress import IPv4Network, ip_network -from typing import Any, Iterable +from typing import TYPE_CHECKING, Any, Iterable import ifaddr # type: ignore from cached_ipaddress import cached_ip_addresses from .util import asyncio_timeout +if TYPE_CHECKING: + from pyroute2.iproute import IPRoute # Some MAC addresses will drop the leading zero so # our mac validation must allow a single char VALID_MAC_ADDRESS = re.compile("^([0-9A-Fa-f]{1,2}[:-]){5}([0-9A-Fa-f]{1,2})$") @@ -93,7 +95,7 @@ def get_attrs_key(data: Any, key: Any) -> Any: return attr_value -def get_router_ip(ipr: Any) -> Any: +def get_router_ip(ipr: "IPRoute") -> Any: """Obtain the router ip from the default route.""" return get_attrs_key(ipr.get_default_routes()[0], "RTA_GATEWAY") @@ -132,7 +134,7 @@ def async_populate_arp(ip_addresses): class SystemNetworkData: """Gather system network data.""" - def __init__(self, ip_route: Any, local_ip: str | None = None) -> None: + def __init__(self, ip_route: "IPRoute" | None, local_ip: str | None = None) -> None: """Init system network data.""" self.ip_route = ip_route self.local_ip = local_ip