Skip to content

Commit

Permalink
Cache network information instead of fetching it each time (#43)
Browse files Browse the repository at this point in the history
  • Loading branch information
bdraco committed Mar 2, 2024
1 parent 9d572f9 commit cdbb9b4
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 16 deletions.
32 changes: 19 additions & 13 deletions aiodiscover/discovery.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,17 @@
from contextlib import suppress
from functools import lru_cache
from ipaddress import IPv4Address
from typing import Any, Optional, cast
from typing import TYPE_CHECKING, Any, Optional, cast

from dns import exception, message, rdatatype
from dns.message import Message, QueryMessage
from dns.name import Name

from .network import SystemNetworkData

if TYPE_CHECKING:
from pyroute2.iproute import IPRoute # noqa: F401

HOSTNAME = "hostname"
MAC_ADDRESS = "macaddress"
IP_ADDRESS = "ip"
Expand Down Expand Up @@ -183,24 +186,27 @@ class DiscoverHosts:

def __init__(self) -> None:
"""Init the discovery hosts."""
self._ip_route = None
self._sys_network_data: SystemNetworkData | None = None

def _get_sys_network_data(self) -> SystemNetworkData:
if not self._ip_route:
with suppress(Exception):
from pyroute2.iproute import (
IPRoute,
) # type: ignore # pylint: disable=import-outside-toplevel
def _setup_sys_network_data(self) -> None:
ip_route: "IPRoute" | None = None
with suppress(Exception):
from pyroute2.iproute import ( # noqa: F811
IPRoute,
) # type: ignore # pylint: disable=import-outside-toplevel

self._ip_route = IPRoute()
sys_network_data = SystemNetworkData(self._ip_route)
ip_route = IPRoute()
sys_network_data = SystemNetworkData(ip_route)
sys_network_data.setup()
return sys_network_data
self._sys_network_data = sys_network_data

async def async_discover(self) -> list[dict[str, str]]:
"""Discover hosts on the network by ARP and PTR lookup."""
loop = asyncio.get_running_loop()
sys_network_data = await loop.run_in_executor(None, self._get_sys_network_data)
if not self._sys_network_data:
await asyncio.get_running_loop().run_in_executor(
None, self._setup_sys_network_data
)
sys_network_data = self._sys_network_data
network = sys_network_data.network
assert network is not None
if network.num_addresses > MAX_ADDRESSES:
Expand Down
8 changes: 5 additions & 3 deletions aiodiscover/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,15 @@
import sys
from contextlib import suppress
from ipaddress import IPv4Network, ip_network
from typing import Any, Iterable
from typing import TYPE_CHECKING, Any, Iterable

import ifaddr # type: ignore
from cached_ipaddress import cached_ip_addresses

from .util import asyncio_timeout

if TYPE_CHECKING:
from pyroute2.iproute import IPRoute
# Some MAC addresses will drop the leading zero so
# our mac validation must allow a single char
VALID_MAC_ADDRESS = re.compile("^([0-9A-Fa-f]{1,2}[:-]){5}([0-9A-Fa-f]{1,2})$")
Expand Down Expand Up @@ -93,7 +95,7 @@ def get_attrs_key(data: Any, key: Any) -> Any:
return attr_value


def get_router_ip(ipr: Any) -> Any:
def get_router_ip(ipr: "IPRoute") -> Any:
"""Obtain the router ip from the default route."""
return get_attrs_key(ipr.get_default_routes()[0], "RTA_GATEWAY")

Expand Down Expand Up @@ -132,7 +134,7 @@ def async_populate_arp(ip_addresses):
class SystemNetworkData:
"""Gather system network data."""

def __init__(self, ip_route: Any, local_ip: str | None = None) -> None:
def __init__(self, ip_route: "IPRoute" | None, local_ip: str | None = None) -> None:
"""Init system network data."""
self.ip_route = ip_route
self.local_ip = local_ip
Expand Down

0 comments on commit cdbb9b4

Please sign in to comment.