diff --git a/.coveragerc b/.coveragerc new file mode 100644 index 00000000..7648cf0d --- /dev/null +++ b/.coveragerc @@ -0,0 +1,5 @@ +[report] +exclude_lines = + pragma: no cover + if TYPE_CHECKING: + if sys.version_info diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 00000000..01b181fe --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,64 @@ +name: CI + +on: + push: + branches: + - master + pull_request: + branches: + - "**" + +jobs: + build: + runs-on: ${{ matrix.os }} + strategy: + matrix: + os: [ubuntu-latest, macos-latest, windows-latest] + python-version: [3.6, 3.7, 3.8, 3.9, pypy3] + include: + - os: ubuntu-latest + venvcmd: . env/bin/activate + - os: macos-latest + venvcmd: . env/bin/activate + - os: windows-latest + venvcmd: env\Scripts\Activate.ps1 + steps: + - uses: actions/checkout@v2 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + - uses: actions/cache@v2 + id: cache + with: + path: env + key: ${{ runner.os }}-pip-${{ matrix.python-version }}-${{ hashFiles('**/Makefile') }}-${{ hashFiles('**/requirements-dev.txt') }} + restore-keys: | + ${{ runner.os }}-pip-${{ matrix.python-version }}-${{ hashFiles('**/Makefile') }} + - name: Install dependencies + if: steps.cache.outputs.cache-hit != 'true' + run: | + python -m venv env + ${{ matrix.venvcmd }} + pip install --upgrade -r requirements-dev.txt pytest-github-actions-annotate-failures + - name: Run flake8 + if: ${{ runner.os == 'Linux' && matrix.python-version != 'pypy3' }} + run: | + ${{ matrix.venvcmd }} + make flake8 + - name: Run mypy + if: ${{ runner.os == 'Linux' && matrix.python-version != 'pypy3' }} + run: | + ${{ matrix.venvcmd }} + make mypy + - name: Run black_check + if: ${{ runner.os == 'Linux' && matrix.python-version != 'pypy3' }} + run: | + ${{ matrix.venvcmd }} + make black_check + - name: Run tests + run: | + ${{ matrix.venvcmd }} + make test_coverage + - name: Report coverage to Codecov + uses: codecov/codecov-action@v1 diff --git a/.gitignore b/.gitignore index ddf8a0d7..0af9ce1e 100644 --- a/.gitignore +++ b/.gitignore @@ -11,3 +11,6 @@ Thumbs.db .cache .mypy_cache/ docs/_build/ +.vscode +/dist/ +/zeroconf.egg-info/ diff --git a/.travis.yml b/.travis.yml deleted file mode 100644 index acf47981..00000000 --- a/.travis.yml +++ /dev/null @@ -1,19 +0,0 @@ -language: python -python: - - "3.5" - - "3.6" - - "3.7" - - "3.8" - - "pypy3.5" - - "pypy3" -install: - - pip install -r requirements-dev.txt - # mypy can't be installed on pypy - - if [[ "${TRAVIS_PYTHON_VERSION}" != "pypy"* ]] ; then pip install mypy ; fi - - if [[ "${TRAVIS_PYTHON_VERSION}" != *"3.5"* && "${TRAVIS_PYTHON_VERSION}" != "pypy"* ]] ; then - pip install black ; fi -script: - # no IPv6 support in Travis :( - - make TEST_ARGS='-a "!IPv6"' ci -after_success: - - coveralls diff --git a/Makefile b/Makefile index ea5f8c64..88980ff2 100644 --- a/Makefile +++ b/Makefile @@ -1,30 +1,28 @@ +# version: 1.1 + .PHONY: all virtualenv MAX_LINE_LENGTH=110 PYTHON_IMPLEMENTATION:=$(shell python -c "import sys;import platform;sys.stdout.write(platform.python_implementation())") PYTHON_VERSION:=$(shell python -c "import sys;sys.stdout.write('%d.%d' % sys.version_info[:2])") -TEST_ARGS= LINT_TARGETS:=flake8 ifneq ($(findstring PyPy,$(PYTHON_IMPLEMENTATION)),PyPy) - LINT_TARGETS:=$(LINT_TARGETS) mypy -endif -ifeq ($(or $(findstring 3.5,$(PYTHON_VERSION)),$(findstring PyPy,$(PYTHON_IMPLEMENTATION))),) - LINT_TARGETS:=$(LINT_TARGETS) black_check + LINT_TARGETS:=$(LINT_TARGETS) mypy black_check pylint endif virtualenv: ./env/requirements.built env: - virtualenv env + python -m venv env ./env/requirements.built: env requirements-dev.txt ./env/bin/pip install -r requirements-dev.txt cp requirements-dev.txt ./env/requirements.built .PHONY: ci -ci: test_coverage lint +ci: lint test_coverage .PHONY: lint lint: $(LINT_TARGETS) @@ -32,18 +30,24 @@ lint: $(LINT_TARGETS) flake8: flake8 --max-line-length=$(MAX_LINE_LENGTH) setup.py examples zeroconf +pylint: + pylint zeroconf + .PHONY: black_check black_check: black --check setup.py examples zeroconf mypy: - mypy examples/*.py zeroconf/*.py +# --no-warn-redundant-casts --no-warn-unused-ignores is needed since we support multiple python versions +# We should be able to drop this once python 3.6 goes away + mypy --no-warn-redundant-casts --no-warn-unused-ignores examples/*.py zeroconf test: - nosetests -v $(TEST_ARGS) + pytest --durations=20 --timeout=60 -v tests test_coverage: - nosetests -v --with-coverage --cover-package=zeroconf $(TEST_ARGS) + pytest --durations=20 --timeout=60 -v --cov=zeroconf --cov-branch --cov-report xml --cov-report html --cov-report term-missing tests autopep8: autopep8 --max-line-length=$(MAX_LINE_LENGTH) -i setup.py examples zeroconf + diff --git a/README.rst b/README.rst index 432d3c4b..ea7c3d1c 100644 --- a/README.rst +++ b/README.rst @@ -1,14 +1,14 @@ python-zeroconf =============== -.. image:: https://travis-ci.org/jstasiak/python-zeroconf.svg?branch=master - :target: https://travis-ci.org/jstasiak/python-zeroconf - +.. image:: https://github.com/jstasiak/python-zeroconf/workflows/CI/badge.svg + :target: https://github.com/jstasiak/python-zeroconf?query=workflow%3ACI+branch%3Amaster + .. image:: https://img.shields.io/pypi/v/zeroconf.svg :target: https://pypi.python.org/pypi/zeroconf -.. image:: https://img.shields.io/coveralls/jstasiak/python-zeroconf.svg - :target: https://coveralls.io/r/jstasiak/python-zeroconf +.. image:: https://codecov.io/gh/jstasiak/python-zeroconf/branch/master/graph/badge.svg + :target: https://codecov.io/gh/jstasiak/python-zeroconf `Documentation `_. @@ -37,15 +37,15 @@ Compared to some other Zeroconf/Bonjour/Avahi Python packages, python-zeroconf: * isn't tied to Bonjour or Avahi * doesn't use D-Bus -* doesn't force you to use particular event loop or Twisted +* doesn't force you to use particular event loop or Twisted (asyncio is used under the hood but not required) * is pip-installable * has PyPI distribution Python compatibility -------------------- -* CPython 3.5+ -* PyPy3 5.8+ +* CPython 3.6+ +* PyPy3 7.2+ Versioning ---------- @@ -59,8 +59,14 @@ This project's versions follow the following pattern: MAJOR.MINOR.PATCH. Status ------ -There are some people using this package. I don't actively use it and as such -any help I can offer with regard to any issues is very limited. +This project is actively maintained. + +Traffic Reduction +----------------- + +Before version 0.32, most traffic reduction techniques described in https://datatracker.ietf.org/doc/html/rfc6762#section-7 +where not implemented which could lead to excessive network traffic. It is highly recommended that version 0.32 or later +is used if this is a concern. IPv6 support ------------ @@ -69,8 +75,6 @@ IPv6 support is relatively new and currently limited, specifically: * `InterfaceChoice.All` is an alias for `InterfaceChoice.Default` on non-POSIX systems. -* On Windows specific interfaces can only be requested as interface indexes, - not as IP addresses. * Dual-stack IPv6 sockets are used, which may not be supported everywhere (some BSD variants do not have them). * Listening on localhost (`::1`) does not work. Help with understanding why is @@ -134,6 +138,655 @@ See examples directory for more. Changelog ========= +0.36.7 +====== + +* Improved performance of responding to queries (#994) (#996) (#997) @bdraco +* Improved log message when receiving an invalid or corrupt packet (#998) @bdraco + +0.36.6 +====== + +* Improved performance of sending outgoing packets (#990) @bdraco + +0.36.5 +====== + +* Reduced memory usage for incoming and outgoing packets (#987) @bdraco + +0.36.4 +====== + +* Improved performance of constructing outgoing packets (#978) (#979) @bdraco +* Deferred parsing of incoming packets when it can be avoided (#983) @bdraco + +0.36.3 +====== + +* Improved performance of parsing incoming packets (#975) @bdraco + +0.36.2 +====== + +* Include NSEC records for non-existent types when responding with addresses (#972) (#971) @bdraco + Implements RFC6762 sec 6.2 (http://datatracker.ietf.org/doc/html/rfc6762#section-6.2) + +0.36.1 +====== + +* Skip goodbye packets for addresses when there is another service registered with the same name (#968) @bdraco + + If a ServiceInfo that used the same server name as another ServiceInfo + was unregistered, goodbye packets would be sent for the addresses and + would cause the other service to be seen as offline. +* Fixed equality and hash for dns records with the unique bit (#969) @bdraco + + These records should have the same hash and equality since + the unique bit (cache flush bit) is not considered when adding or removing + the records from the cache. + +0.36.0 +====== + +Technically backwards incompatible: + +* Fill incomplete IPv6 tuples to avoid WinError on windows (#965) @lokesh2019 + + Fixed #932 + +0.35.1 +====== + +* Only reschedule types if the send next time changes (#958) @bdraco + + When the PTR response was seen again, the timer was being canceled and + rescheduled even if the timer was for the same time. While this did + not cause any breakage, it is quite inefficient. +* Cache DNS record and question hashes (#960) @bdraco + + The hash was being recalculated every time the object + was being used in a set or dict. Since the hashes are + effectively immutable, we only calculate them once now. + +0.35.0 +====== + +* Reduced chance of accidental synchronization of ServiceInfo requests (#955) @bdraco +* Sort aggregated responses to increase chance of name compression (#954) @bdraco + +Technically backwards incompatible: + +* Send unicast replies on the same socket the query was received (#952) @bdraco + + When replying to a QU question, we do not know if the sending host is reachable + from all of the sending sockets. We now avoid this problem by replying via + the receiving socket. This was the existing behavior when `InterfaceChoice.Default` + is set. + + This change extends the unicast relay behavior to used with `InterfaceChoice.Default` + to apply when `InterfaceChoice.All` or interfaces are explicitly passed when + instantiating a `Zeroconf` instance. + + Fixes #951 + +0.34.3 +====== + +* Fix sending immediate multicast responses (#949) @bdraco + +0.34.2 +====== + +* Coalesce aggregated multicast answers (#945) @bdraco + + When the random delay is shorter than the last scheduled response, + answers are now added to the same outgoing time group. + + This reduces traffic when we already know we will be sending a group of answers + inside the random delay window described in + datatracker.ietf.org/doc/html/rfc6762#section-6.3 +* Ensure ServiceInfo requests can be answered inside the default timeout with network protection (#946) @bdraco + + Adjust the time windows to ensure responses that have triggered the + protection against against excessive packet flooding due to + software bugs or malicious attack described in RFC6762 section 6 + can respond in under 1350ms to ensure ServiceInfo can ask two + questions within the default timeout of 3000ms + +0.34.1 +====== + +* Ensure multicast aggregation sends responses within 620ms (#942) @bdraco + + Responses that trigger the protection against against excessive + packet flooding due to software bugs or malicious attack described + in RFC6762 section 6 could cause the multicast aggregation response + to be delayed longer than 620ms (The maximum random delay of 120ms + and 500ms additional for aggregation). + + Only responses that trigger the protection are delayed longer than 620ms + +0.34.0 +====== + +* Implemented Multicast Response Aggregation (#940) @bdraco + + Responses are now aggregated when possible per rules in RFC6762 + section 6.4 + + Responses that trigger the protection against against excessive + packet flooding due to software bugs or malicious attack described + in RFC6762 section 6 are delayed instead of discarding as it was + causing responders that implement Passive Observation Of Failures + (POOF) to evict the records. + + Probe responses are now always sent immediately as there were cases + where they would fail to be answered in time to defend a name. + +0.33.4 +====== + +* Ensure zeroconf can be loaded when the system disables IPv6 (#933) @che0 + +0.33.3 +====== + +* Added support for forward dns compression pointers (#934) @bdraco +* Provide sockname when logging a protocol error (#935) @bdraco + +0.33.2 +====== + +* Handle duplicate goodbye answers in the same packet (#928) @bdraco + + Solves an exception being thrown when we tried to remove the known answer + from the cache when the second goodbye answer in the same packet was processed + + Fixed #926 +* Skip ipv6 interfaces that return ENODEV (#930) @bdraco + +0.33.1 +====== + +* Version number change only with less restrictive directory permissions + + Fixed #923 + +0.33.0 +====== + +This release eliminates all threading locks as all non-threadsafe operations +now happen in the event loop. + +* Let connection_lost close the underlying socket (#918) @bdraco + + The socket was closed during shutdown before asyncio's connection_lost + handler had a chance to close it which resulted in a traceback on + windows. + + Fixed #917 + +Technically backwards incompatible: + +* Removed duplicate unregister_all_services code (#910) @bdraco + + Calling Zeroconf.close from same asyncio event loop zeroconf is running in + will now skip unregister_all_services and log a warning as this a blocking + operation and is not async safe and never has been. + + Use AsyncZeroconf instead, or for legacy code call async_unregister_all_services before Zeroconf.close + +0.32.1 +====== + +* Increased timeout in ServiceInfo.request to handle loaded systems (#895) @bdraco + + It can take a few seconds for a loaded system to run the `async_request` + coroutine when the event loop is busy, or the system is CPU bound (example being + Home Assistant startup). We now add an additional `_LOADED_SYSTEM_TIMEOUT` (10s) + to the `run_coroutine_threadsafe` calls to ensure the coroutine has the total + amount of time to run up to its internal timeout (default of 3000ms). + + Ten seconds is a bit large of a timeout; however, it is only used in cases + where we wrap other timeouts. We now expect the only instance the + `run_coroutine_threadsafe` result timeout will happen in a production + circumstance is when someone is running a `ServiceInfo.request()` in a thread and + another thread calls `Zeroconf.close()` at just the right moment that the future + is never completed unless the system is so loaded that it is nearly unresponsive. + + The timeout for `run_coroutine_threadsafe` is the maximum time a thread can + cleanly shut down when zeroconf is closed out in another thread, which should + always be longer than the underlying thread operation. + +0.32.0 +====== + +This release offers 100% line and branch coverage. + +* Made ServiceInfo first question QU (#852) @bdraco + + We want an immediate response when requesting with ServiceInfo + by asking a QU question; most responders will not delay the response + and respond right away to our question. This also improves compatibility + with split networks as we may not have been able to see the response + otherwise. If the responder has not multicast the record recently, + it may still choose to do so in addition to responding via unicast + + Reduces traffic when there are multiple zeroconf instances running + on the network running ServiceBrowsers + + If we don't get an answer on the first try, we ask a QM question + in the event, we can't receive a unicast response for some reason + + This change puts ServiceInfo inline with ServiceBrowser which + also asks the first question as QU since ServiceInfo is commonly + called from ServiceBrowser callbacks +* Limited duplicate packet suppression to 1s intervals (#841) @bdraco + + Only suppress duplicate packets that happen within the same + second. Legitimate queriers will retry the question if they + are suppressed. The limit was reduced to one second to be + in line with rfc6762 +* Made multipacket known answer suppression per interface (#836) @bdraco + + The suppression was happening per instance of Zeroconf instead + of per interface. Since the same network can be seen on multiple + interfaces (usually and wifi and ethernet), this would confuse the + multi-packet known answer supression since it was not expecting + to get the same data more than once +* New ServiceBrowsers now request QU in the first outgoing when unspecified (#812) @bdraco + + https://datatracker.ietf.org/doc/html/rfc6762#section-5.4 + When we start a ServiceBrowser and zeroconf has just started up, the known + answer list will be small. By asking a QU question first, it is likely + that we have a large known answer list by the time we ask the QM question + a second later (current default which is likely too low but would be + a breaking change to increase). This reduces the amount of traffic on + the network, and has the secondary advantage that most responders will + answer a QU question without the typical delay answering QM questions. +* IPv6 link-local addresses are now qualified with scope_id (#343) @ibygrave + + When a service is advertised on an IPv6 address where + the scope is link local, i.e. fe80::/64 (see RFC 4007) + the resolved IPv6 address must be extended with the + scope_id that identifies through the "%" symbol the + local interface to be used when routing to that address. + A new API `parsed_scoped_addresses()` is provided to + return qualified addresses to avoid breaking compatibility + on the existing parsed_addresses(). +* Network adapters that are disconnected are now skipped (#327) @ZLJasonG +* Fixed listeners missing initial packets if Engine starts too quickly (#387) @bdraco + + When manually creating a zeroconf.Engine object, it is no longer started automatically. + It must manually be started by calling .start() on the created object. + + The Engine thread is now started after all the listeners have been added to avoid a + race condition where packets could be missed at startup. +* Fixed answering matching PTR queries with the ANY query (#618) @bdraco +* Fixed lookup of uppercase names in the registry (#597) @bdraco + + If the ServiceInfo was registered with an uppercase name and the query was + for a lowercase name, it would not be found and vice-versa. +* Fixed unicast responses from any source port (#598) @bdraco + + Unicast responses were only being sent if the source port + was 53, this prevented responses when testing with dig: + + dig -p 5353 @224.0.0.251 media-12.local + + The above query will now see a response +* Fixed queries for AAAA records not being answered (#616) @bdraco +* Removed second level caching from ServiceBrowsers (#737) @bdraco + + The ServiceBrowser had its own cache of the last time it + saw a service that was reimplementing the DNSCache and + presenting a source of truth problem that lead to unexpected + queries when the two disagreed. +* Fixed server cache not being case-insensitive (#731) @bdraco + + If the server name had uppercase chars and any of the + matching records were lowercase, and the server would not be + found +* Fixed cache handling of records with different TTLs (#729) @bdraco + + There should only be one unique record in the cache at + a time as having multiple unique records will different + TTLs in the cache can result in unexpected behavior since + some functions returned all matching records and some + fetched from the right side of the list to return the + newest record. Instead we now store the records in a dict + to ensure that the newest record always replaces the same + unique record, and we never have a source of truth problem + determining the TTL of a record from the cache. +* Fixed ServiceInfo with multiple A records (#725) @bdraco + + If there were multiple A records for the host, ServiceInfo + would always return the last one that was in the incoming + packet, which was usually not the one that was wanted. +* Fixed stale unique records expiring too quickly (#706) @bdraco + + Records now expire 1s in the future instead of instant removal. + + tools.ietf.org/html/rfc6762#section-10.2 + Queriers receiving a Multicast DNS response with a TTL of zero SHOULD + NOT immediately delete the record from the cache, but instead record + a TTL of 1 and then delete the record one second later. In the case + of multiple Multicast DNS responders on the network described in + Section 6.6 above, if one of the responders shuts down and + incorrectly sends goodbye packets for its records, it gives the other + cooperating responders one second to send out their own response to + "rescue" the records before they expire and are deleted. +* Fixed exception when unregistering a service multiple times (#679) @bdraco +* Added an AsyncZeroconfServiceTypes to mirror ZeroconfServiceTypes to zeroconf.asyncio (#658) @bdraco +* Fixed interface_index_to_ip6_address not skiping ipv4 adapters (#651) @bdraco +* Added async_unregister_all_services to AsyncZeroconf (#649) @bdraco +* Fixed services not being removed from the registry when calling unregister_all_services (#644) @bdraco + + There was a race condition where a query could be answered for a service + in the registry, while goodbye packets which could result in a fresh record + being broadcast after the goodbye if a query came in at just the right + time. To avoid this, we now remove the services from the registry right + after we generate the goodbye packet +* Fixed zeroconf exception on load when the system disables IPv6 (#624) @bdraco +* Fixed the QU bit missing from for probe queries (#609) @bdraco + + The bit should be set per + datatracker.ietf.org/doc/html/rfc6762#section-8.1 + +* Fixed the TC bit missing for query packets where the known answers span multiple packets (#494) @bdraco +* Fixed packets not being properly separated when exceeding maximum size (#498) @bdraco + + Ensure that questions that exceed the max packet size are + moved to the next packet. This fixes DNSQuestions being + sent in multiple packets in violation of: + datatracker.ietf.org/doc/html/rfc6762#section-7.2 + + Ensure only one resource record is sent when a record + exceeds _MAX_MSG_TYPICAL + datatracker.ietf.org/doc/html/rfc6762#section-17 +* Fixed PTR questions asked in uppercase not being answered (#465) @bdraco +* Added Support for context managers in Zeroconf and AsyncZeroconf (#284) @shenek +* Implemented an AsyncServiceBrowser to compliment the sync ServiceBrowser (#429) @bdraco +* Added async_get_service_info to AsyncZeroconf and async_request to AsyncServiceInfo (#408) @bdraco +* Implemented allowing passing in a sync Zeroconf instance to AsyncZeroconf (#406) @bdraco +* Fixed IPv6 setup under MacOS when binding to "" (#392) @bdraco +* Fixed ZeroconfServiceTypes.find not always cancels the ServiceBrowser (#389) @bdraco + + There was a short window where the ServiceBrowser thread + could be left running after Zeroconf is closed because + the .join() was never waited for when a new Zeroconf + object was created +* Fixed duplicate packets triggering duplicate updates (#376) @bdraco + + If TXT or SRV records update was already processed and then + received again, it was possible for a second update to be + called back in the ServiceBrowser +* Fixed ServiceStateChange.Updated event happening for IPs that already existed (#375) @bdraco +* Fixed RFC6762 Section 10.2 paragraph 2 compliance (#374) @bdraco +* Reduced length of ServiceBrowser thread name with many types (#373) @bdraco +* Fixed empty answers being added in ServiceInfo.request (#367) @bdraco +* Fixed ServiceInfo not populating all AAAA records (#366) @bdraco + + Use get_all_by_details to ensure all records are loaded + into addresses. + + Only load A/AAAA records from the cache once in load_from_cache + if there is a SRV record present + + Move duplicate code that checked if the ServiceInfo was complete + into its own function +* Fixed a case where the cache list can change during iteration (#363) @bdraco +* Return task objects created by AsyncZeroconf (#360) @nocarryr + +Traffic Reduction: + +* Added support for handling QU questions (#621) @bdraco + + Implements RFC 6762 sec 5.4: + Questions Requesting Unicast Responses + datatracker.ietf.org/doc/html/rfc6762#section-5.4 +* Implemented protect the network against excessive packet flooding (#619) @bdraco +* Additionals are now suppressed when they are already in the answers section (#617) @bdraco +* Additionals are no longer included when the answer is suppressed by known-answer suppression (#614) @bdraco +* Implemented multi-packet known answer supression (#687) @bdraco + + Implements datatracker.ietf.org/doc/html/rfc6762#section-7.2 +* Implemented efficient bucketing of queries with known answers (#698) @bdraco +* Implemented duplicate question suppression (#770) @bdraco + + http://datatracker.ietf.org/doc/html/rfc6762#section-7.3 + +Technically backwards incompatible: + +* Update internal version check to match docs (3.6+) (#491) @bdraco + + Python version earlier then 3.6 were likely broken with zeroconf + already, however, the version is now explicitly checked. +* Update python compatibility as PyPy3 7.2 is required (#523) @bdraco + +Backwards incompatible: + +* Drop oversize packets before processing them (#826) @bdraco + + Oversized packets can quickly overwhelm the system and deny + service to legitimate queriers. In practice, this is usually due to broken mDNS + implementations rather than malicious actors. +* Guard against excessive ServiceBrowser queries from PTR records significantly lowerthan recommended (#824) @bdraco + + We now enforce a minimum TTL for PTR records to avoid + ServiceBrowsers generating excessive queries refresh queries. + Apple uses a 15s minimum TTL, however, we do not have the same + level of rate limit and safeguards, so we use 1/4 of the recommended value. +* RecordUpdateListener now uses async_update_records instead of update_record (#419, #726) @bdraco + + This allows the listener to receive all the records that have + been updated in a single transaction such as a packet or + cache expiry. + + update_record has been deprecated in favor of async_update_records + A compatibility shim exists to ensure classes that use + RecordUpdateListener as a base class continue to have + update_record called, however, they should be updated + as soon as possible. + + A new method async_update_records_complete is now called on each + listener when all listeners have completed processing updates + and the cache has been updated. This allows ServiceBrowsers + to delay calling handlers until they are sure the cache + has been updated as its a common pattern to call for + ServiceInfo when a ServiceBrowser handler fires. + + The async\_ prefix was chosen to make it clear that these + functions run in the eventloop and should never do blocking + I/O. Before 0.32+ these functions ran in a select() loop and + should not have been doing any blocking I/O, but it was not + clear to implementors that I/O would block the loop. +* Pass both the new and old records to async_update_records (#792) @bdraco + + Pass the old_record (cached) as the value and the new_record (wire) + to async_update_records instead of forcing each consumer to + check the cache since we will always have the old_record + when generating the async_update_records call. This avoids + the overhead of multiple cache lookups for each listener. + +0.31.0 +====== + +* Separated cache loading from I/O in ServiceInfo and fixed cache lookup (#356), + thanks to J. Nick Koston. + + The ServiceInfo class gained a load_from_cache() method to only fetch information + from Zeroconf cache (if it exists) with no IO performed. Additionally this should + reduce IO in cases where cache lookups were previously incorrectly failing. + +0.30.0 +====== + +* Some nice refactoring work including removal of the Reaper thread, + thanks to J. Nick Koston. + +* Fixed a Windows-specific The requested address is not valid in its context regression, + thanks to Timothee ‘TTimo’ Besset and J. Nick Koston. + +* Provided an asyncio-compatible service registration layer (in the zeroconf.asyncio module), + thanks to J. Nick Koston. + +0.29.0 +====== + +* A single socket is used for listening on responding when `InterfaceChoice.Default` is chosen. + Thanks to J. Nick Koston. + +Backwards incompatible: + +* Dropped Python 3.5 support + +0.28.8 +====== + +* Fixed the packet generation when multiple packets are necessary, previously invalid + packets were generated sometimes. Patch thanks to J. Nick Koston. + +0.28.7 +====== + +* Fixed the IPv6 address rendering in the browser example, thanks to Alexey Vazhnov. +* Fixed a crash happening when a service is added or removed during handle_response + and improved exception handling, thanks to J. Nick Koston. + +0.28.6 +====== + +* Loosened service name validation when receiving from the network this lets us handle + some real world devices previously causing errors, thanks to J. Nick Koston. + +0.28.5 +====== + +* Enabled ignoring duplicated messages which decreases CPU usage, thanks to J. Nick Koston. +* Fixed spurious AttributeError: module 'unittest' has no attribute 'mock' in tests. + +0.28.4 +====== + +* Improved cache reaper performance significantly, thanks to J. Nick Koston. +* Added ServiceListener to __all__ as it's part of the public API, thanks to Justin Nesselrotte. + +0.28.3 +====== + +* Reduced a time an internal lock is held which should eliminate deadlocks in high-traffic networks, + thanks to J. Nick Koston. + +0.28.2 +====== + +* Stopped asking questions we already have answers for in cache, thanks to Paul Daumlechner. +* Removed initial delay before querying for service info, thanks to Erik Montnemery. + +0.28.1 +====== + +* Fixed a resource leak connected to using ServiceBrowser with multiple types, thanks to + J. Nick Koston. + +0.28.0 +====== + +* Improved Windows support when using socket errno checks, thanks to Sandy Patterson. +* Added support for passing text addresses to ServiceInfo. +* Improved logging (includes fixing an incorrect logging call) +* Improved Windows compatibility by using Adapter.index from ifaddr, thanks to PhilippSelenium. +* Improved Windows compatibility by stopping using socket.if_nameindex. +* Fixed an OS X edge case which should also eliminate a memory leak, thanks to Emil Styrke. + +Technically backwards incompatible: + +* ``ifaddr`` 0.1.7 or newer is required now. + +0.27.1 +------ + +* Improved the logging situation (includes fixing a false-positive "packets() made no progress + adding records", thanks to Greg Badros) + +0.27.0 +------ + +* Large multi-resource responses are now split into separate packets which fixes a bad + mdns-repeater/ChromeCast Audio interaction ending with ChromeCast Audio crash (and possibly + some others) and improves RFC 6762 compliance, thanks to Greg Badros +* Added a warning presented when the listener passed to ServiceBrowser lacks update_service() + callback +* Added support for finding all services available in the browser example, thanks to Perry Kunder + +Backwards incompatible: + +* Removed previously deprecated ServiceInfo address constructor parameter and property + +0.26.3 +------ + +* Improved readability of logged incoming data, thanks to Erik Montnemery +* Threads are given unique names now to aid debugging, thanks to Erik Montnemery +* Fixed a regression where get_service_info() called within a listener add_service method + would deadlock, timeout and incorrectly return None, fix thanks to Erik Montnemery, but + Matt Saxon and Hmmbob were also involved in debugging it. + +0.26.2 +------ + +* Added support for multiple types to ServiceBrowser, thanks to J. Nick Koston +* Fixed a race condition where a listener gets a message before the lock is created, thanks to + J. Nick Koston + +0.26.1 +------ + +* Fixed a performance regression introduced in 0.26.0, thanks to J. Nick Koston (this is close in + spirit to an optimization made in 0.24.5 by the same author) + +0.26.0 +------ + +* Fixed a regression where service update listener wasn't called on IP address change (it's called + on SRV/A/AAAA record changes now), thanks to Matt Saxon + +Technically backwards incompatible: + +* Service update hook is no longer called on service addition (service added hook is still called), + this is related to the fix above + +0.25.1 +------ + +* Eliminated 5s hangup when calling Zeroconf.close(), thanks to Erik Montnemery + +0.25.0 +------ + +* Reverted uniqueness assertions when browsing, they caused a regression + +Backwards incompatible: + +* Rationalized handling of TXT records. Non-bytes values are converted to str and encoded to bytes + using UTF-8 now, None values mean value-less attributes. When receiving TXT records no decoding + is performed now, keys are always bytes and values are either bytes or None in value-less + attributes. + +0.24.5 +------ + +* Fixed issues with shared records being used where they shouldn't be (TXT, SRV, A records are + unique now), thanks to Matt Saxon +* Stopped unnecessarily excluding host-only interfaces from InterfaceChoice.all as they don't + forbid multicast, thanks to Andreas Oberritter +* Fixed repr() of IPv6 DNSAddress, thanks to Aldo Hoeben +* Removed duplicate update messages sent to listeners, thanks to Matt Saxon +* Added support for cooperating responders, thanks to Matt Saxon +* Optimized handle_response cache check, thanks to J. Nick Koston +* Fixed memory leak in DNSCache, thanks to J. Nick Koston + 0.24.4 ------ diff --git a/docs/api.rst b/docs/api.rst index 5bd2508f..1704db5a 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -5,3 +5,8 @@ python-zeroconf API reference :members: :undoc-members: :show-inheritance: + +.. automodule:: zeroconf.asyncio + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/index.rst b/docs/index.rst index c4fa6143..8929f417 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -1,16 +1,14 @@ Welcome to python-zeroconf documentation! ========================================= -.. image:: https://travis-ci.org/jstasiak/python-zeroconf.svg?branch=master - :alt: Build status - :target: https://travis-ci.org/jstasiak/python-zeroconf +.. image:: https://github.com/jstasiak/python-zeroconf/workflows/CI/badge.svg + :target: https://github.com/jstasiak/python-zeroconf?query=workflow%3ACI+branch%3Amaster .. image:: https://img.shields.io/pypi/v/zeroconf.svg :target: https://pypi.python.org/pypi/zeroconf - -.. image:: https://coveralls.io/repos/github/jstasiak/python-zeroconf/badge.svg?branch=master - :alt: Covergage status - :target: https://coveralls.io/github/jstasiak/python-zeroconf?branch=master + +.. image:: https://codecov.io/gh/jstasiak/python-zeroconf/branch/master/graph/badge.svg + :target: https://codecov.io/gh/jstasiak/python-zeroconf GitHub (code repository, issues): https://github.com/jstasiak/python-zeroconf @@ -18,7 +16,7 @@ PyPI (installable, stable distributions): https://pypi.org/project/zeroconf. You pip install zeroconf -python-zeroconf works with CPython 3.5+ and PyPy 3 implementing Python 3.5+. +python-zeroconf works with CPython 3.6+ and PyPy 3 implementing Python 3.6+. Contents -------- @@ -28,4 +26,4 @@ Contents api -See `the project's README `_ for more information. +See `the project's README `_ for more information. diff --git a/examples/async_apple_scanner.py b/examples/async_apple_scanner.py new file mode 100644 index 00000000..88b54e4a --- /dev/null +++ b/examples/async_apple_scanner.py @@ -0,0 +1,119 @@ +#!/usr/bin/env python3 + +""" Scan for apple devices. """ + +import argparse +import asyncio +import logging +from typing import Any, Optional, cast + +from zeroconf import DNSQuestionType, IPVersion, ServiceStateChange, Zeroconf +from zeroconf.asyncio import AsyncServiceBrowser, AsyncServiceInfo, AsyncZeroconf + +HOMESHARING_SERVICE: str = "_appletv-v2._tcp.local." +DEVICE_SERVICE: str = "_touch-able._tcp.local." +MEDIAREMOTE_SERVICE: str = "_mediaremotetv._tcp.local." +AIRPLAY_SERVICE: str = "_airplay._tcp.local." +COMPANION_SERVICE: str = "_companion-link._tcp.local." +RAOP_SERVICE: str = "_raop._tcp.local." +AIRPORT_ADMIN_SERVICE: str = "_airport._tcp.local." +DEVICE_INFO_SERVICE: str = "_device-info._tcp.local." + +ALL_SERVICES = [ + HOMESHARING_SERVICE, + DEVICE_SERVICE, + MEDIAREMOTE_SERVICE, + AIRPLAY_SERVICE, + COMPANION_SERVICE, + RAOP_SERVICE, + AIRPORT_ADMIN_SERVICE, + DEVICE_INFO_SERVICE, +] + +log = logging.getLogger(__name__) + + +def async_on_service_state_change( + zeroconf: Zeroconf, service_type: str, name: str, state_change: ServiceStateChange +) -> None: + print(f"Service {name} of type {service_type} state changed: {state_change}") + if state_change is not ServiceStateChange.Added: + return + base_name = name[: -len(service_type) - 1] + device_name = f"{base_name}.{DEVICE_INFO_SERVICE}" + asyncio.ensure_future(_async_show_service_info(zeroconf, service_type, name)) + # Also probe for device info + asyncio.ensure_future(_async_show_service_info(zeroconf, DEVICE_INFO_SERVICE, device_name)) + + +async def _async_show_service_info(zeroconf: Zeroconf, service_type: str, name: str) -> None: + info = AsyncServiceInfo(service_type, name) + await info.async_request(zeroconf, 3000, question_type=DNSQuestionType.QU) + print("Info from zeroconf.get_service_info: %r" % (info)) + if info: + addresses = ["%s:%d" % (addr, cast(int, info.port)) for addr in info.parsed_addresses()] + print(" Name: %s" % name) + print(" Addresses: %s" % ", ".join(addresses)) + print(" Weight: %d, priority: %d" % (info.weight, info.priority)) + print(f" Server: {info.server}") + if info.properties: + print(" Properties are:") + for key, value in info.properties.items(): + print(f" {key}: {value}") + else: + print(" No properties") + else: + print(" No info") + print('\n') + + +class AsyncAppleScanner: + def __init__(self, args: Any) -> None: + self.args = args + self.aiobrowser: Optional[AsyncServiceBrowser] = None + self.aiozc: Optional[AsyncZeroconf] = None + + async def async_run(self) -> None: + self.aiozc = AsyncZeroconf(ip_version=ip_version) + await self.aiozc.zeroconf.async_wait_for_start() + print("\nBrowsing %s service(s), press Ctrl-C to exit...\n" % ALL_SERVICES) + kwargs = {'handlers': [async_on_service_state_change], 'question_type': DNSQuestionType.QU} + if self.args.target: + kwargs["addr"] = self.args.target + self.aiobrowser = AsyncServiceBrowser(self.aiozc.zeroconf, ALL_SERVICES, **kwargs) # type: ignore + while True: + await asyncio.sleep(1) + + async def async_close(self) -> None: + assert self.aiozc is not None + assert self.aiobrowser is not None + await self.aiobrowser.async_cancel() + await self.aiozc.async_close() + + +if __name__ == '__main__': + logging.basicConfig(level=logging.DEBUG) + + parser = argparse.ArgumentParser() + parser.add_argument('--debug', action='store_true') + version_group = parser.add_mutually_exclusive_group() + version_group.add_argument('--target', help='Unicast target') + version_group.add_argument('--v6', action='store_true') + version_group.add_argument('--v6-only', action='store_true') + args = parser.parse_args() + + if args.debug: + logging.getLogger('zeroconf').setLevel(logging.DEBUG) + if args.v6: + ip_version = IPVersion.All + elif args.v6_only: + ip_version = IPVersion.V6Only + else: + ip_version = IPVersion.V4Only + + loop = asyncio.get_event_loop() + runner = AsyncAppleScanner(args) + try: + loop.run_until_complete(runner.async_run()) + except KeyboardInterrupt: + loop.run_until_complete(runner.async_close()) diff --git a/examples/async_browser.py b/examples/async_browser.py new file mode 100644 index 00000000..1cce5c20 --- /dev/null +++ b/examples/async_browser.py @@ -0,0 +1,101 @@ +#!/usr/bin/env python3 + +""" Example of browsing for a service. + +The default is HTTP and HAP; use --find to search for all available services in the network +""" + +import argparse +import asyncio +import logging +from typing import Any, Optional, cast + +from zeroconf import IPVersion, ServiceStateChange, Zeroconf +from zeroconf.asyncio import AsyncServiceBrowser, AsyncServiceInfo, AsyncZeroconf, AsyncZeroconfServiceTypes + + +def async_on_service_state_change( + zeroconf: Zeroconf, service_type: str, name: str, state_change: ServiceStateChange +) -> None: + print(f"Service {name} of type {service_type} state changed: {state_change}") + if state_change is not ServiceStateChange.Added: + return + asyncio.ensure_future(async_display_service_info(zeroconf, service_type, name)) + + +async def async_display_service_info(zeroconf: Zeroconf, service_type: str, name: str) -> None: + info = AsyncServiceInfo(service_type, name) + await info.async_request(zeroconf, 3000) + print("Info from zeroconf.get_service_info: %r" % (info)) + if info: + addresses = ["%s:%d" % (addr, cast(int, info.port)) for addr in info.parsed_scoped_addresses()] + print(" Name: %s" % name) + print(" Addresses: %s" % ", ".join(addresses)) + print(" Weight: %d, priority: %d" % (info.weight, info.priority)) + print(f" Server: {info.server}") + if info.properties: + print(" Properties are:") + for key, value in info.properties.items(): + print(f" {key}: {value}") + else: + print(" No properties") + else: + print(" No info") + print('\n') + + +class AsyncRunner: + def __init__(self, args: Any) -> None: + self.args = args + self.aiobrowser: Optional[AsyncServiceBrowser] = None + self.aiozc: Optional[AsyncZeroconf] = None + + async def async_run(self) -> None: + self.aiozc = AsyncZeroconf(ip_version=ip_version) + + services = ["_http._tcp.local.", "_hap._tcp.local."] + if self.args.find: + services = list( + await AsyncZeroconfServiceTypes.async_find(aiozc=self.aiozc, ip_version=ip_version) + ) + + print("\nBrowsing %s service(s), press Ctrl-C to exit...\n" % services) + self.aiobrowser = AsyncServiceBrowser( + self.aiozc.zeroconf, services, handlers=[async_on_service_state_change] + ) + while True: + await asyncio.sleep(1) + + async def async_close(self) -> None: + assert self.aiozc is not None + assert self.aiobrowser is not None + await self.aiobrowser.async_cancel() + await self.aiozc.async_close() + + +if __name__ == '__main__': + logging.basicConfig(level=logging.DEBUG) + + parser = argparse.ArgumentParser() + parser.add_argument('--debug', action='store_true') + parser.add_argument('--find', action='store_true', help='Browse all available services') + version_group = parser.add_mutually_exclusive_group() + version_group.add_argument('--v6', action='store_true') + version_group.add_argument('--v6-only', action='store_true') + args = parser.parse_args() + + if args.debug: + logging.getLogger('zeroconf').setLevel(logging.DEBUG) + if args.v6: + ip_version = IPVersion.All + elif args.v6_only: + ip_version = IPVersion.V6Only + else: + ip_version = IPVersion.V4Only + + loop = asyncio.get_event_loop() + runner = AsyncRunner(args) + try: + loop.run_until_complete(runner.async_run()) + except KeyboardInterrupt: + loop.run_until_complete(runner.async_close()) diff --git a/examples/async_registration.py b/examples/async_registration.py new file mode 100644 index 00000000..c3aab326 --- /dev/null +++ b/examples/async_registration.py @@ -0,0 +1,74 @@ +#!/usr/bin/env python3 +"""Example of announcing 250 services (in this case, a fake HTTP server).""" + +import argparse +import asyncio +import logging +import socket +from typing import List, Optional + +from zeroconf import IPVersion +from zeroconf.asyncio import AsyncServiceInfo, AsyncZeroconf + + +class AsyncRunner: + def __init__(self, ip_version: IPVersion) -> None: + self.ip_version = ip_version + self.aiozc: Optional[AsyncZeroconf] = None + + async def register_services(self, infos: List[AsyncServiceInfo]) -> None: + self.aiozc = AsyncZeroconf(ip_version=self.ip_version) + tasks = [self.aiozc.async_register_service(info) for info in infos] + background_tasks = await asyncio.gather(*tasks) + await asyncio.gather(*background_tasks) + print("Finished registration, press Ctrl-C to exit...") + while True: + await asyncio.sleep(1) + + async def unregister_services(self, infos: List[AsyncServiceInfo]) -> None: + assert self.aiozc is not None + tasks = [self.aiozc.async_unregister_service(info) for info in infos] + background_tasks = await asyncio.gather(*tasks) + await asyncio.gather(*background_tasks) + await self.aiozc.async_close() + + +if __name__ == '__main__': + logging.basicConfig(level=logging.DEBUG) + + parser = argparse.ArgumentParser() + parser.add_argument('--debug', action='store_true') + version_group = parser.add_mutually_exclusive_group() + version_group.add_argument('--v6', action='store_true') + version_group.add_argument('--v6-only', action='store_true') + args = parser.parse_args() + + if args.debug: + logging.getLogger('zeroconf').setLevel(logging.DEBUG) + if args.v6: + ip_version = IPVersion.All + elif args.v6_only: + ip_version = IPVersion.V6Only + else: + ip_version = IPVersion.V4Only + + infos = [] + for i in range(250): + infos.append( + AsyncServiceInfo( + "_http._tcp.local.", + f"Paul's Test Web Site {i}._http._tcp.local.", + addresses=[socket.inet_aton("127.0.0.1")], + port=80, + properties={'path': '/~paulsm/'}, + server=f"zcdemohost-{i}.local.", + ) + ) + + print("Registration of 250 services...") + loop = asyncio.get_event_loop() + runner = AsyncRunner(ip_version) + try: + loop.run_until_complete(runner.register_services(infos)) + except KeyboardInterrupt: + loop.run_until_complete(runner.unregister_services(infos)) diff --git a/examples/async_service_info_request.py b/examples/async_service_info_request.py new file mode 100644 index 00000000..dd8265b7 --- /dev/null +++ b/examples/async_service_info_request.py @@ -0,0 +1,103 @@ +#!/usr/bin/env python3 +"""Example of perodic dump of homekit services. + +This example is useful when a user wants an ondemand +list of HomeKit devices on the network. + +""" + +import argparse +import asyncio +import logging +from typing import Any, Optional, cast + + +from zeroconf import IPVersion, ServiceBrowser, ServiceStateChange, Zeroconf +from zeroconf.asyncio import AsyncServiceInfo, AsyncZeroconf + + +HAP_TYPE = "_hap._tcp.local." + + +async def async_watch_services(aiozc: AsyncZeroconf) -> None: + zeroconf = aiozc.zeroconf + while True: + await asyncio.sleep(5) + infos = [] + for name in zeroconf.cache.names(): + if not name.endswith(HAP_TYPE): + continue + infos.append(AsyncServiceInfo(HAP_TYPE, name)) + tasks = [info.async_request(aiozc.zeroconf, 3000) for info in infos] + await asyncio.gather(*tasks) + for info in infos: + print("Info for %s" % (info.name)) + if info: + addresses = ["%s:%d" % (addr, cast(int, info.port)) for addr in info.parsed_addresses()] + print(" Addresses: %s" % ", ".join(addresses)) + print(" Weight: %d, priority: %d" % (info.weight, info.priority)) + print(f" Server: {info.server}") + if info.properties: + print(" Properties are:") + for key, value in info.properties.items(): + print(f" {key}: {value}") + else: + print(" No properties") + else: + print(" No info") + print('\n') + + +class AsyncRunner: + def __init__(self, args: Any) -> None: + self.args = args + self.threaded_browser: Optional[ServiceBrowser] = None + self.aiozc: Optional[AsyncZeroconf] = None + + async def async_run(self) -> None: + self.aiozc = AsyncZeroconf(ip_version=ip_version) + assert self.aiozc is not None + + def on_service_state_change( + zeroconf: Zeroconf, service_type: str, state_change: ServiceStateChange, name: str + ) -> None: + """Dummy handler.""" + + self.threaded_browser = ServiceBrowser( + self.aiozc.zeroconf, [HAP_TYPE], handlers=[on_service_state_change] + ) + await async_watch_services(self.aiozc) + + async def async_close(self) -> None: + assert self.aiozc is not None + assert self.threaded_browser is not None + self.threaded_browser.cancel() + await self.aiozc.async_close() + + +if __name__ == '__main__': + logging.basicConfig(level=logging.DEBUG) + + parser = argparse.ArgumentParser() + parser.add_argument('--debug', action='store_true') + version_group = parser.add_mutually_exclusive_group() + version_group.add_argument('--v6', action='store_true') + version_group.add_argument('--v6-only', action='store_true') + args = parser.parse_args() + + if args.debug: + logging.getLogger('zeroconf').setLevel(logging.DEBUG) + if args.v6: + ip_version = IPVersion.All + elif args.v6_only: + ip_version = IPVersion.V6Only + else: + ip_version = IPVersion.V4Only + + print(f"Services with {HAP_TYPE} will be shown every 5s, press Ctrl-C to exit...") + loop = asyncio.get_event_loop() + runner = AsyncRunner(args) + try: + loop.run_until_complete(runner.async_run()) + except KeyboardInterrupt: + loop.run_until_complete(runner.async_close()) diff --git a/examples/browser.py b/examples/browser.py index bf3ebfbd..8c50e409 100755 --- a/examples/browser.py +++ b/examples/browser.py @@ -1,32 +1,36 @@ #!/usr/bin/env python3 -""" Example of browsing for a service (in this case, HTTP) """ +""" Example of browsing for a service. + +The default is HTTP and HAP; use --find to search for all available services in the network +""" import argparse import logging -import socket from time import sleep from typing import cast -from zeroconf import IPVersion, ServiceBrowser, ServiceStateChange, Zeroconf +from zeroconf import IPVersion, ServiceBrowser, ServiceStateChange, Zeroconf, ZeroconfServiceTypes def on_service_state_change( zeroconf: Zeroconf, service_type: str, name: str, state_change: ServiceStateChange ) -> None: - print("Service %s of type %s state changed: %s" % (name, service_type, state_change)) + print(f"Service {name} of type {service_type} state changed: {state_change}") if state_change is ServiceStateChange.Added: info = zeroconf.get_service_info(service_type, name) + print("Info from zeroconf.get_service_info: %r" % (info)) + if info: - addresses = ["%s:%d" % (socket.inet_ntoa(addr), cast(int, info.port)) for addr in info.addresses] + addresses = ["%s:%d" % (addr, cast(int, info.port)) for addr in info.parsed_scoped_addresses()] print(" Addresses: %s" % ", ".join(addresses)) print(" Weight: %d, priority: %d" % (info.weight, info.priority)) - print(" Server: %s" % (info.server,)) + print(f" Server: {info.server}") if info.properties: print(" Properties are:") for key, value in info.properties.items(): - print(" %s: %s" % (key, value)) + print(f" {key}: {value}") else: print(" No properties") else: @@ -39,6 +43,7 @@ def on_service_state_change( parser = argparse.ArgumentParser() parser.add_argument('--debug', action='store_true') + parser.add_argument('--find', action='store_true', help='Browse all available services') version_group = parser.add_mutually_exclusive_group() version_group.add_argument('--v6', action='store_true') version_group.add_argument('--v6-only', action='store_true') @@ -54,8 +59,13 @@ def on_service_state_change( ip_version = IPVersion.V4Only zeroconf = Zeroconf(ip_version=ip_version) - print("\nBrowsing services, press Ctrl-C to exit...\n") - browser = ServiceBrowser(zeroconf, "_http._tcp.local.", handlers=[on_service_state_change]) + + services = ["_http._tcp.local.", "_hap._tcp.local."] + if args.find: + services = list(ZeroconfServiceTypes.find(zc=zeroconf)) + + print("\nBrowsing %d service(s), press Ctrl-C to exit...\n" % len(services)) + browser = ServiceBrowser(zeroconf, services, handlers=[on_service_state_change]) try: while True: diff --git a/examples/self_test.py b/examples/self_test.py index 35007db1..2178629b 100755 --- a/examples/self_test.py +++ b/examples/self_test.py @@ -14,7 +14,7 @@ # Test a few module features, including service registration, service # query (for Zoe), and service unregistration. - print("Multicast DNS Service Discovery for Python, version %s" % (__version__,)) + print(f"Multicast DNS Service Discovery for Python, version {__version__}") r = Zeroconf() print("1. Testing registration of a service...") desc = {'version': '0.10', 'a': 'test value', 'b': 'another value'} @@ -40,7 +40,7 @@ queried_info = r.get_service_info("_http._tcp.local.", "My Service Name._http._tcp.local.") assert queried_info assert set(queried_info.parsed_addresses()) == expected - print(" Getting self: %s" % (queried_info,)) + print(f" Getting self: {queried_info}") print(" Query done.") print("4. Testing unregister of service information...") r.unregister_service(info) diff --git a/pyproject.toml b/pyproject.toml index a5d30b54..cd79b3e2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,4 +1,37 @@ [tool.black] line-length = 110 -target_version = ['py35', 'py36', 'py37'] +target_version = ['py35', 'py36', 'py37', 'py38'] skip_string_normalization = true + +[tool.pylint.BASIC] +class-const-naming-style = "any" +good-names = [ + "e", + "er", + "h", + "i", + "id", + "ip", + "os", + "n", + "rr", + "rs", + "s", + "t", + "wr", + "zc", + "_GLOBAL_DONE", +] + +[tool.pylint."MESSAGES CONTROL"] +disable = [ + "duplicate-code", + "fixme", + "format", + "missing-class-docstring", + "missing-function-docstring", + "too-few-public-methods", + "too-many-arguments", + "too-many-instance-attributes", + "too-many-public-methods" +] diff --git a/requirements-dev.txt b/requirements-dev.txt index ec443c0b..e7483666 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,9 +1,16 @@ autopep8 +black;implementation_name=="cpython" +bump2version coveralls coverage -# Version restricted because of https://github.com/PyCQA/pycodestyle/issues/741 -flake8>=3.6.0 +flake8 flake8-import-order ifaddr -nose -pep8-naming!=0.6.0 +mypy;implementation_name=="cpython" +# 0.11.0 breaks things https://github.com/PyCQA/pep8-naming/issues/152 +pep8-naming!=0.6.0,!=0.11.0 +pylint +pytest +pytest-asyncio +pytest-cov +pytest-timeout diff --git a/setup.cfg b/setup.cfg index 5610cf68..67e50408 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,12 +1,25 @@ +[bumpversion] +current_version = 0.36.7 +commit = True +tag = True +tag_name = {new_version} + +[bumpversion:file:zeroconf/__init__.py] +search = __version__ = '{current_version}' +replace = __version__ = '{new_version}' + +[tool:pytest] +testpaths = tests + [flake8] show-source = 1 -application-import-names=zeroconf -max-line-length=110 -ignore=E203,W503 +application-import-names = zeroconf +max-line-length = 110 +ignore = E203,W503,N818 [mypy] ignore_missing_imports = true -follow_imports = error +follow_imports = skip check_untyped_defs = true no_implicit_optional = true warn_incomplete_stub = true @@ -15,7 +28,6 @@ warn_redundant_casts = true warn_unused_configs = true warn_unused_ignores = true warn_return_any = true -# TODO: disallow untyped calls and defs once we have full type hint coverage disallow_untyped_calls = false disallow_untyped_defs = true diff --git a/setup.py b/setup.py index a011e602..41e74842 100755 --- a/setup.py +++ b/setup.py @@ -9,7 +9,7 @@ readme = f.read() version = ( - [l for l in open(join(PROJECT_ROOT, 'zeroconf', '__init__.py')) if '__version__' in l][0] + [ln for ln in open(join(PROJECT_ROOT, 'zeroconf', '__init__.py')) if '__version__' in ln][0] .split('=')[-1] .strip() .strip('\'"') @@ -23,7 +23,7 @@ author='Paul Scott-Murphy, William McBrine, Jakub Stasiak', url='https://github.com/jstasiak/python-zeroconf', package_data={"zeroconf": ["py.typed"]}, - packages=["zeroconf"], + packages=["zeroconf", "zeroconf._protocol", "zeroconf._services", "zeroconf._utils"], platforms=['unix', 'linux', 'osx'], license='LGPL', zip_safe=False, @@ -39,9 +39,10 @@ 'Programming Language :: Python :: 3.5', 'Programming Language :: Python :: 3.6', 'Programming Language :: Python :: 3.7', + 'Programming Language :: Python :: 3.8', 'Programming Language :: Python :: Implementation :: CPython', 'Programming Language :: Python :: Implementation :: PyPy', ], keywords=['Bonjour', 'Avahi', 'Zeroconf', 'Multicast DNS', 'Service Discovery', 'mDNS'], - install_requires=['ifaddr', 'typing;python_version<"3.5"'], + install_requires=['ifaddr>=0.1.7'], ) diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 00000000..2671fe62 --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1,80 @@ +""" Multicast DNS Service Discovery for Python, v0.14-wmcbrine + Copyright 2003 Paul Scott-Murphy, 2014 William McBrine + + This module provides a framework for the use of DNS Service Discovery + using IP multicast. + + This library is free software; you can redistribute it and/or + modify it under the terms of the GNU Lesser General Public + License as published by the Free Software Foundation; either + version 2.1 of the License, or (at your option) any later version. + + This library is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + Lesser General Public License for more details. + + You should have received a copy of the GNU Lesser General Public + License along with this library; if not, write to the Free Software + Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 + USA +""" + +import asyncio +import socket +from functools import lru_cache +from typing import List + +import ifaddr + + +from zeroconf import DNSIncoming, Zeroconf + + +def _inject_responses(zc: Zeroconf, msgs: List[DNSIncoming]) -> None: + """Inject a DNSIncoming response.""" + assert zc.loop is not None + + async def _wait_for_response(): + for msg in msgs: + zc.handle_response(msg) + + asyncio.run_coroutine_threadsafe(_wait_for_response(), zc.loop).result() + + +def _inject_response(zc: Zeroconf, msg: DNSIncoming) -> None: + """Inject a DNSIncoming response.""" + _inject_responses(zc, [msg]) + + +def _wait_for_start(zc: Zeroconf) -> None: + """Wait for all sockets to be up and running.""" + assert zc.loop is not None + asyncio.run_coroutine_threadsafe(zc.async_wait_for_start(), zc.loop).result() + + +@lru_cache(maxsize=None) +def has_working_ipv6(): + """Return True if if the system can bind an IPv6 address.""" + if not socket.has_ipv6: + return False + + try: + sock = socket.socket(socket.AF_INET6) + sock.bind(('::1', 0)) + except Exception: + return False + finally: + if sock: + sock.close() + + for iface in ifaddr.get_adapters(): + for addr in iface.ips: + if addr.is_IPv6 and iface.index is not None: + return True + return False + + +def _clear_cache(zc): + zc.cache.cache.clear() + zc.question_history._history.clear() diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 00000000..f900e094 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,30 @@ +#!/usr/bin/env python + + +""" conftest for zeroconf tests. """ + +import threading + +import pytest + +import unittest + +from zeroconf import _core, const + + +@pytest.fixture(autouse=True) +def verify_threads_ended(): + """Verify that the threads are not running after the test.""" + threads_before = frozenset(threading.enumerate()) + yield + threads = frozenset(threading.enumerate()) - threads_before + assert not threads + + +@pytest.fixture +def run_isolated(): + """Change the mDNS port to run the test in isolation.""" + with unittest.mock.patch.object(_core, "_MDNS_PORT", 5454), unittest.mock.patch.object( + const, "_MDNS_PORT", 5454 + ): + yield diff --git a/tests/services/__init__.py b/tests/services/__init__.py new file mode 100644 index 00000000..2ef4b15b --- /dev/null +++ b/tests/services/__init__.py @@ -0,0 +1,21 @@ +""" Multicast DNS Service Discovery for Python, v0.14-wmcbrine + Copyright 2003 Paul Scott-Murphy, 2014 William McBrine + + This module provides a framework for the use of DNS Service Discovery + using IP multicast. + + This library is free software; you can redistribute it and/or + modify it under the terms of the GNU Lesser General Public + License as published by the Free Software Foundation; either + version 2.1 of the License, or (at your option) any later version. + + This library is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + Lesser General Public License for more details. + + You should have received a copy of the GNU Lesser General Public + License along with this library; if not, write to the Free Software + Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 + USA +""" diff --git a/tests/services/test_browser.py b/tests/services/test_browser.py new file mode 100644 index 00000000..e22ebfe3 --- /dev/null +++ b/tests/services/test_browser.py @@ -0,0 +1,1020 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + + +""" Unit tests for zeroconf._services.browser. """ + +import asyncio +import logging +import socket +import time +import os +import unittest +from threading import Event +from unittest.mock import patch + +import pytest + +import zeroconf as r +from zeroconf import DNSPointer, DNSQuestion, const, current_time_millis, millis_to_seconds +import zeroconf._services.browser as _services_browser +from zeroconf import Zeroconf +from zeroconf._services import ServiceStateChange +from zeroconf._services.browser import ServiceBrowser +from zeroconf._services.info import ServiceInfo +from zeroconf.asyncio import AsyncZeroconf + +from .. import has_working_ipv6, _inject_response, _wait_for_start + + +log = logging.getLogger('zeroconf') +original_logging_level = logging.NOTSET + + +def setup_module(): + global original_logging_level + original_logging_level = log.level + log.setLevel(logging.DEBUG) + + +def teardown_module(): + if original_logging_level != logging.NOTSET: + log.setLevel(original_logging_level) + + +def test_service_browser_cancel_multiple_times(): + """Test we can cancel a ServiceBrowser multiple times before close.""" + + # instantiate a zeroconf instance + zc = Zeroconf(interfaces=['127.0.0.1']) + # start a browser + type_ = "_hap._tcp.local." + + class MyServiceListener(r.ServiceListener): + pass + + listener = MyServiceListener() + + browser = r.ServiceBrowser(zc, type_, None, listener) + + browser.cancel() + browser.cancel() + browser.cancel() + + zc.close() + + +def test_service_browser_cancel_multiple_times_after_close(): + """Test we can cancel a ServiceBrowser multiple times after close.""" + + # instantiate a zeroconf instance + zc = Zeroconf(interfaces=['127.0.0.1']) + # start a browser + type_ = "_hap._tcp.local." + + class MyServiceListener(r.ServiceListener): + pass + + listener = MyServiceListener() + + browser = r.ServiceBrowser(zc, type_, None, listener) + + zc.close() + + browser.cancel() + browser.cancel() + browser.cancel() + + +def test_service_browser_started_after_zeroconf_closed(): + """Test starting a ServiceBrowser after close raises RuntimeError.""" + # instantiate a zeroconf instance + zc = Zeroconf(interfaces=['127.0.0.1']) + # start a browser + type_ = "_hap._tcp.local." + + class MyServiceListener(r.ServiceListener): + pass + + listener = MyServiceListener() + zc.close() + + with pytest.raises(RuntimeError): + browser = r.ServiceBrowser(zc, type_, None, listener) + + +def test_multiple_instances_running_close(): + """Test we can shutdown multiple instances.""" + + # instantiate a zeroconf instance + zc = Zeroconf(interfaces=['127.0.0.1']) + zc2 = Zeroconf(interfaces=['127.0.0.1']) + zc3 = Zeroconf(interfaces=['127.0.0.1']) + + assert zc.loop != zc2.loop + assert zc.loop != zc3.loop + + class MyServiceListener(r.ServiceListener): + pass + + listener = MyServiceListener() + + zc2.add_service_listener("zca._hap._tcp.local.", listener) + + zc.close() + zc2.remove_service_listener(listener) + zc2.close() + zc3.close() + + +class TestServiceBrowser(unittest.TestCase): + def test_update_record(self): + enable_ipv6 = has_working_ipv6() and not os.environ.get('SKIP_IPV6') + + service_name = 'name._type._tcp.local.' + service_type = '_type._tcp.local.' + service_server = 'ash-1.local.' + service_text = b'path=/~matt1/' + service_address = '10.0.1.2' + service_v6_address = "2001:db8::1" + service_v6_second_address = "6001:db8::1" + + service_added_count = 0 + service_removed_count = 0 + service_updated_count = 0 + service_add_event = Event() + service_removed_event = Event() + service_updated_event = Event() + + class MyServiceListener(r.ServiceListener): + def add_service(self, zc, type_, name) -> None: + nonlocal service_added_count + service_added_count += 1 + service_add_event.set() + + def remove_service(self, zc, type_, name) -> None: + nonlocal service_removed_count + service_removed_count += 1 + service_removed_event.set() + + def update_service(self, zc, type_, name) -> None: + nonlocal service_updated_count + service_updated_count += 1 + service_info = zc.get_service_info(type_, name) + assert socket.inet_aton(service_address) in service_info.addresses + if enable_ipv6: + assert socket.inet_pton( + socket.AF_INET6, service_v6_address + ) in service_info.addresses_by_version(r.IPVersion.V6Only) + assert socket.inet_pton( + socket.AF_INET6, service_v6_second_address + ) in service_info.addresses_by_version(r.IPVersion.V6Only) + assert service_info.text == service_text + assert service_info.server == service_server + service_updated_event.set() + + def mock_incoming_msg(service_state_change: r.ServiceStateChange) -> r.DNSIncoming: + + generated = r.DNSOutgoing(const._FLAGS_QR_RESPONSE) + assert generated.is_response() is True + + if service_state_change == r.ServiceStateChange.Removed: + ttl = 0 + else: + ttl = 120 + + generated.add_answer_at_time( + r.DNSText( + service_name, const._TYPE_TXT, const._CLASS_IN | const._CLASS_UNIQUE, ttl, service_text + ), + 0, + ) + + generated.add_answer_at_time( + r.DNSService( + service_name, + const._TYPE_SRV, + const._CLASS_IN | const._CLASS_UNIQUE, + ttl, + 0, + 0, + 80, + service_server, + ), + 0, + ) + + # Send the IPv6 address first since we previously + # had a bug where the IPv4 would be missing if the + # IPv6 was seen first + if enable_ipv6: + generated.add_answer_at_time( + r.DNSAddress( + service_server, + const._TYPE_AAAA, + const._CLASS_IN | const._CLASS_UNIQUE, + ttl, + socket.inet_pton(socket.AF_INET6, service_v6_address), + ), + 0, + ) + generated.add_answer_at_time( + r.DNSAddress( + service_server, + const._TYPE_AAAA, + const._CLASS_IN | const._CLASS_UNIQUE, + ttl, + socket.inet_pton(socket.AF_INET6, service_v6_second_address), + ), + 0, + ) + generated.add_answer_at_time( + r.DNSAddress( + service_server, + const._TYPE_A, + const._CLASS_IN | const._CLASS_UNIQUE, + ttl, + socket.inet_aton(service_address), + ), + 0, + ) + + generated.add_answer_at_time( + r.DNSPointer(service_type, const._TYPE_PTR, const._CLASS_IN, ttl, service_name), 0 + ) + + return r.DNSIncoming(generated.packets()[0]) + + zeroconf = r.Zeroconf(interfaces=['127.0.0.1']) + service_browser = r.ServiceBrowser(zeroconf, service_type, listener=MyServiceListener()) + + try: + wait_time = 3 + + # service added + _inject_response(zeroconf, mock_incoming_msg(r.ServiceStateChange.Added)) + service_add_event.wait(wait_time) + assert service_added_count == 1 + assert service_updated_count == 0 + assert service_removed_count == 0 + + # service SRV updated + service_updated_event.clear() + service_server = 'ash-2.local.' + _inject_response(zeroconf, mock_incoming_msg(r.ServiceStateChange.Updated)) + service_updated_event.wait(wait_time) + assert service_added_count == 1 + assert service_updated_count == 1 + assert service_removed_count == 0 + + # service TXT updated + service_updated_event.clear() + service_text = b'path=/~matt2/' + _inject_response(zeroconf, mock_incoming_msg(r.ServiceStateChange.Updated)) + service_updated_event.wait(wait_time) + assert service_added_count == 1 + assert service_updated_count == 2 + assert service_removed_count == 0 + + # service TXT updated - duplicate update should not trigger another service_updated + service_updated_event.clear() + service_text = b'path=/~matt2/' + _inject_response(zeroconf, mock_incoming_msg(r.ServiceStateChange.Updated)) + service_updated_event.wait(wait_time) + assert service_added_count == 1 + assert service_updated_count == 2 + assert service_removed_count == 0 + + # service A updated + service_updated_event.clear() + service_address = '10.0.1.3' + # Verify we match on uppercase + service_server = service_server.upper() + _inject_response(zeroconf, mock_incoming_msg(r.ServiceStateChange.Updated)) + service_updated_event.wait(wait_time) + assert service_added_count == 1 + assert service_updated_count == 3 + assert service_removed_count == 0 + + # service all updated + service_updated_event.clear() + service_server = 'ash-3.local.' + service_text = b'path=/~matt3/' + service_address = '10.0.1.3' + _inject_response(zeroconf, mock_incoming_msg(r.ServiceStateChange.Updated)) + service_updated_event.wait(wait_time) + assert service_added_count == 1 + assert service_updated_count == 4 + assert service_removed_count == 0 + + # service removed + _inject_response(zeroconf, mock_incoming_msg(r.ServiceStateChange.Removed)) + service_removed_event.wait(wait_time) + assert service_added_count == 1 + assert service_updated_count == 4 + assert service_removed_count == 1 + + finally: + assert len(zeroconf.listeners) == 1 + service_browser.cancel() + time.sleep(0.2) + assert len(zeroconf.listeners) == 0 + zeroconf.remove_all_service_listeners() + zeroconf.close() + + +class TestServiceBrowserMultipleTypes(unittest.TestCase): + def test_update_record(self): + + service_names = ['name2._type2._tcp.local.', 'name._type._tcp.local.', 'name._type._udp.local'] + service_types = ['_type2._tcp.local.', '_type._tcp.local.', '_type._udp.local.'] + + service_added_count = 0 + service_removed_count = 0 + service_add_event = Event() + service_removed_event = Event() + + class MyServiceListener(r.ServiceListener): + def add_service(self, zc, type_, name) -> None: + nonlocal service_added_count + service_added_count += 1 + if service_added_count == 3: + service_add_event.set() + + def remove_service(self, zc, type_, name) -> None: + nonlocal service_removed_count + service_removed_count += 1 + if service_removed_count == 3: + service_removed_event.set() + + def mock_incoming_msg( + service_state_change: r.ServiceStateChange, service_type: str, service_name: str, ttl: int + ) -> r.DNSIncoming: + generated = r.DNSOutgoing(const._FLAGS_QR_RESPONSE) + generated.add_answer_at_time( + r.DNSPointer(service_type, const._TYPE_PTR, const._CLASS_IN, ttl, service_name), 0 + ) + return r.DNSIncoming(generated.packets()[0]) + + zeroconf = r.Zeroconf(interfaces=['127.0.0.1']) + service_browser = r.ServiceBrowser(zeroconf, service_types, listener=MyServiceListener()) + + try: + wait_time = 3 + + # all three services added + _inject_response( + zeroconf, + mock_incoming_msg(r.ServiceStateChange.Added, service_types[0], service_names[0], 120), + ) + _inject_response( + zeroconf, + mock_incoming_msg(r.ServiceStateChange.Added, service_types[1], service_names[1], 120), + ) + time.sleep(0.1) + + called_with_refresh_time_check = False + + def _mock_get_expiration_time(self, percent): + nonlocal called_with_refresh_time_check + if percent == const._EXPIRE_REFRESH_TIME_PERCENT: + called_with_refresh_time_check = True + return 0 + return self.created + (percent * self.ttl * 10) + + # Set an expire time that will force a refresh + with patch("zeroconf.DNSRecord.get_expiration_time", new=_mock_get_expiration_time): + _inject_response( + zeroconf, + mock_incoming_msg(r.ServiceStateChange.Added, service_types[0], service_names[0], 120), + ) + # Add the last record after updating the first one + # to ensure the service_add_event only gets set + # after the update + _inject_response( + zeroconf, + mock_incoming_msg(r.ServiceStateChange.Added, service_types[2], service_names[2], 120), + ) + service_add_event.wait(wait_time) + assert called_with_refresh_time_check is True + assert service_added_count == 3 + assert service_removed_count == 0 + + _inject_response( + zeroconf, + mock_incoming_msg(r.ServiceStateChange.Updated, service_types[0], service_names[0], 0), + ) + + # all three services removed + _inject_response( + zeroconf, + mock_incoming_msg(r.ServiceStateChange.Removed, service_types[0], service_names[0], 0), + ) + _inject_response( + zeroconf, + mock_incoming_msg(r.ServiceStateChange.Removed, service_types[1], service_names[1], 0), + ) + _inject_response( + zeroconf, + mock_incoming_msg(r.ServiceStateChange.Removed, service_types[2], service_names[2], 0), + ) + service_removed_event.wait(wait_time) + assert service_added_count == 3 + assert service_removed_count == 3 + + finally: + assert len(zeroconf.listeners) == 1 + service_browser.cancel() + time.sleep(0.2) + assert len(zeroconf.listeners) == 0 + zeroconf.remove_all_service_listeners() + zeroconf.close() + + +def test_backoff(): + got_query = Event() + + type_ = "_http._tcp.local." + zeroconf_browser = Zeroconf(interfaces=['127.0.0.1']) + _wait_for_start(zeroconf_browser) + + # we are going to patch the zeroconf send to check query transmission + old_send = zeroconf_browser.async_send + + time_offset = 0.0 + start_time = time.time() * 1000 + initial_query_interval = _services_browser._BROWSER_TIME / 1000 + + def current_time_millis(): + """Current system time in milliseconds""" + return start_time + time_offset * 1000 + + def send(out, addr=const._MDNS_ADDR, port=const._MDNS_PORT, v6_flow_scope=()): + """Sends an outgoing packet.""" + got_query.set() + old_send(out, addr=addr, port=port, v6_flow_scope=v6_flow_scope) + + # patch the zeroconf send + # patch the zeroconf current_time_millis + # patch the backoff limit to prevent test running forever + with patch.object(zeroconf_browser, "async_send", send), patch.object( + zeroconf_browser.question_history, "suppresses", return_value=False + ), patch.object(_services_browser, "current_time_millis", current_time_millis), patch.object( + _services_browser, "_BROWSER_BACKOFF_LIMIT", 10 + ), patch.object( + _services_browser, "_FIRST_QUERY_DELAY_RANDOM_INTERVAL", (0, 0) + ): + # dummy service callback + def on_service_state_change(zeroconf, service_type, state_change, name): + pass + + browser = ServiceBrowser(zeroconf_browser, type_, [on_service_state_change]) + + try: + # Test that queries are sent at increasing intervals + sleep_count = 0 + next_query_interval = 0.0 + expected_query_time = 0.0 + while True: + sleep_count += 1 + got_query.wait(0.1) + if time_offset == expected_query_time: + assert got_query.is_set() + got_query.clear() + if next_query_interval == _services_browser._BROWSER_BACKOFF_LIMIT: + # Only need to test up to the point where we've seen a query + # after the backoff limit has been hit + break + elif next_query_interval == 0: + next_query_interval = initial_query_interval + expected_query_time = initial_query_interval + else: + next_query_interval = min( + 2 * next_query_interval, _services_browser._BROWSER_BACKOFF_LIMIT + ) + expected_query_time += next_query_interval + else: + assert not got_query.is_set() + time_offset += initial_query_interval + zeroconf_browser.loop.call_soon_threadsafe(browser._async_send_ready_queries_schedule_next) + + finally: + browser.cancel() + zeroconf_browser.close() + + +def test_first_query_delay(): + """Verify the first query is delayed. + + https://datatracker.ietf.org/doc/html/rfc6762#section-5.2 + """ + type_ = "_http._tcp.local." + zeroconf_browser = Zeroconf(interfaces=['127.0.0.1']) + _wait_for_start(zeroconf_browser) + + # we are going to patch the zeroconf send to check query transmission + old_send = zeroconf_browser.async_send + + first_query_time = None + + def send(out, addr=const._MDNS_ADDR, port=const._MDNS_PORT): + """Sends an outgoing packet.""" + nonlocal first_query_time + if first_query_time is None: + first_query_time = current_time_millis() + old_send(out, addr=addr, port=port) + + # patch the zeroconf send + with patch.object(zeroconf_browser, "async_send", send): + # dummy service callback + def on_service_state_change(zeroconf, service_type, state_change, name): + pass + + start_time = current_time_millis() + browser = ServiceBrowser(zeroconf_browser, type_, [on_service_state_change]) + time.sleep(millis_to_seconds(_services_browser._FIRST_QUERY_DELAY_RANDOM_INTERVAL[1] + 5)) + try: + assert ( + current_time_millis() - start_time > _services_browser._FIRST_QUERY_DELAY_RANDOM_INTERVAL[0] + ) + finally: + browser.cancel() + zeroconf_browser.close() + + +def test_asking_default_is_asking_qm_questions_after_the_first_qu(): + """Verify the service browser's first question is QU and subsequent ones are QM questions.""" + type_ = "_quservice._tcp.local." + zeroconf_browser = Zeroconf(interfaces=['127.0.0.1']) + + # we are going to patch the zeroconf send to check query transmission + old_send = zeroconf_browser.async_send + + first_outgoing = None + second_outgoing = None + + def send(out, addr=const._MDNS_ADDR, port=const._MDNS_PORT): + """Sends an outgoing packet.""" + nonlocal first_outgoing + nonlocal second_outgoing + if first_outgoing is not None and second_outgoing is None: + second_outgoing = out + if first_outgoing is None: + first_outgoing = out + old_send(out, addr=addr, port=port) + + # patch the zeroconf send + with patch.object(zeroconf_browser, "async_send", send): + # dummy service callback + def on_service_state_change(zeroconf, service_type, state_change, name): + pass + + browser = ServiceBrowser(zeroconf_browser, type_, [on_service_state_change], delay=5) + time.sleep(millis_to_seconds(_services_browser._FIRST_QUERY_DELAY_RANDOM_INTERVAL[1] + 120 + 5)) + try: + assert first_outgoing.questions[0].unicast == True + assert second_outgoing.questions[0].unicast == False + finally: + browser.cancel() + zeroconf_browser.close() + + +def test_asking_qm_questions(): + """Verify explictly asking QM questions.""" + type_ = "_quservice._tcp.local." + zeroconf_browser = Zeroconf(interfaces=['127.0.0.1']) + + # we are going to patch the zeroconf send to check query transmission + old_send = zeroconf_browser.async_send + + first_outgoing = None + + def send(out, addr=const._MDNS_ADDR, port=const._MDNS_PORT): + """Sends an outgoing packet.""" + nonlocal first_outgoing + if first_outgoing is None: + first_outgoing = out + old_send(out, addr=addr, port=port) + + # patch the zeroconf send + with patch.object(zeroconf_browser, "async_send", send): + # dummy service callback + def on_service_state_change(zeroconf, service_type, state_change, name): + pass + + browser = ServiceBrowser( + zeroconf_browser, type_, [on_service_state_change], question_type=r.DNSQuestionType.QM + ) + time.sleep(millis_to_seconds(_services_browser._FIRST_QUERY_DELAY_RANDOM_INTERVAL[1] + 5)) + try: + assert first_outgoing.questions[0].unicast == False + finally: + browser.cancel() + zeroconf_browser.close() + + +def test_asking_qu_questions(): + """Verify the service browser can ask QU questions.""" + type_ = "_quservice._tcp.local." + zeroconf_browser = Zeroconf(interfaces=['127.0.0.1']) + + # we are going to patch the zeroconf send to check query transmission + old_send = zeroconf_browser.async_send + + first_outgoing = None + + def send(out, addr=const._MDNS_ADDR, port=const._MDNS_PORT): + """Sends an outgoing packet.""" + nonlocal first_outgoing + if first_outgoing is None: + first_outgoing = out + old_send(out, addr=addr, port=port) + + # patch the zeroconf send + with patch.object(zeroconf_browser, "async_send", send): + # dummy service callback + def on_service_state_change(zeroconf, service_type, state_change, name): + pass + + browser = ServiceBrowser( + zeroconf_browser, type_, [on_service_state_change], question_type=r.DNSQuestionType.QU + ) + time.sleep(millis_to_seconds(_services_browser._FIRST_QUERY_DELAY_RANDOM_INTERVAL[1] + 5)) + try: + assert first_outgoing.questions[0].unicast == True + finally: + browser.cancel() + zeroconf_browser.close() + + +def test_legacy_record_update_listener(): + """Test a RecordUpdateListener that does not implement update_records.""" + + # instantiate a zeroconf instance + zc = Zeroconf(interfaces=['127.0.0.1']) + + with pytest.raises(RuntimeError): + r.RecordUpdateListener().update_record( + zc, 0, r.DNSRecord('irrelevant', const._TYPE_SRV, const._CLASS_IN, const._DNS_HOST_TTL) + ) + + updates = [] + + class LegacyRecordUpdateListener(r.RecordUpdateListener): + """A RecordUpdateListener that does not implement update_records.""" + + def update_record(self, zc: 'Zeroconf', now: float, record: r.DNSRecord) -> None: + nonlocal updates + updates.append(record) + + listener = LegacyRecordUpdateListener() + + zc.add_listener(listener, None) + + # dummy service callback + def on_service_state_change(zeroconf, service_type, state_change, name): + pass + + # start a browser + type_ = "_homeassistant._tcp.local." + name = "MyTestHome" + browser = ServiceBrowser(zc, type_, [on_service_state_change]) + + info_service = ServiceInfo( + type_, + '%s.%s' % (name, type_), + 80, + 0, + 0, + {'path': '/~paulsm/'}, + "ash-2.local.", + addresses=[socket.inet_aton("10.0.1.2")], + ) + + zc.register_service(info_service) + + time.sleep(0.001) + + browser.cancel() + + assert len(updates) + assert len([isinstance(update, r.DNSPointer) and update.name == type_ for update in updates]) >= 1 + + zc.remove_listener(listener) + # Removing a second time should not throw + zc.remove_listener(listener) + + zc.close() + + +def test_service_browser_is_aware_of_port_changes(): + """Test that the ServiceBrowser is aware of port changes.""" + + # instantiate a zeroconf instance + zc = Zeroconf(interfaces=['127.0.0.1']) + # start a browser + type_ = "_hap._tcp.local." + registration_name = "xxxyyy.%s" % type_ + + callbacks = [] + # dummy service callback + def on_service_state_change(zeroconf, service_type, state_change, name): + nonlocal callbacks + if name == registration_name: + callbacks.append((service_type, state_change, name)) + + browser = ServiceBrowser(zc, type_, [on_service_state_change]) + + desc = {'path': '/~paulsm/'} + address_parsed = "10.0.1.2" + address = socket.inet_aton(address_parsed) + info = ServiceInfo(type_, registration_name, 80, 0, 0, desc, "ash-2.local.", addresses=[address]) + + def mock_incoming_msg(records) -> r.DNSIncoming: + generated = r.DNSOutgoing(const._FLAGS_QR_RESPONSE) + for record in records: + generated.add_answer_at_time(record, 0) + return r.DNSIncoming(generated.packets()[0]) + + _inject_response( + zc, + mock_incoming_msg([info.dns_pointer(), info.dns_service(), info.dns_text(), *info.dns_addresses()]), + ) + time.sleep(0.1) + + assert callbacks == [('_hap._tcp.local.', ServiceStateChange.Added, 'xxxyyy._hap._tcp.local.')] + assert zc.get_service_info(type_, registration_name).port == 80 + + info.port = 400 + _inject_response( + zc, + mock_incoming_msg([info.dns_service()]), + ) + time.sleep(0.1) + + assert callbacks == [ + ('_hap._tcp.local.', ServiceStateChange.Added, 'xxxyyy._hap._tcp.local.'), + ('_hap._tcp.local.', ServiceStateChange.Updated, 'xxxyyy._hap._tcp.local.'), + ] + assert zc.get_service_info(type_, registration_name).port == 400 + browser.cancel() + + zc.close() + + +def test_service_browser_listeners_update_service(): + """Test that the ServiceBrowser ServiceListener that implements update_service.""" + + # instantiate a zeroconf instance + zc = Zeroconf(interfaces=['127.0.0.1']) + # start a browser + type_ = "_hap._tcp.local." + registration_name = "xxxyyy.%s" % type_ + callbacks = [] + + class MyServiceListener(r.ServiceListener): + def add_service(self, zc, type_, name) -> None: + nonlocal callbacks + if name == registration_name: + callbacks.append(("add", type_, name)) + + def remove_service(self, zc, type_, name) -> None: + nonlocal callbacks + if name == registration_name: + callbacks.append(("remove", type_, name)) + + def update_service(self, zc, type_, name) -> None: + nonlocal callbacks + if name == registration_name: + callbacks.append(("update", type_, name)) + + listener = MyServiceListener() + + browser = r.ServiceBrowser(zc, type_, None, listener) + + desc = {'path': '/~paulsm/'} + address_parsed = "10.0.1.2" + address = socket.inet_aton(address_parsed) + info = ServiceInfo(type_, registration_name, 80, 0, 0, desc, "ash-2.local.", addresses=[address]) + + def mock_incoming_msg(records) -> r.DNSIncoming: + generated = r.DNSOutgoing(const._FLAGS_QR_RESPONSE) + for record in records: + generated.add_answer_at_time(record, 0) + return r.DNSIncoming(generated.packets()[0]) + + _inject_response( + zc, + mock_incoming_msg([info.dns_pointer(), info.dns_service(), info.dns_text(), *info.dns_addresses()]), + ) + time.sleep(0.2) + info.port = 400 + _inject_response( + zc, + mock_incoming_msg([info.dns_service()]), + ) + time.sleep(0.2) + + assert callbacks == [ + ('add', type_, registration_name), + ('update', type_, registration_name), + ] + browser.cancel() + + zc.close() + + +def test_service_browser_listeners_no_update_service(): + """Test that the ServiceBrowser ServiceListener that does not implement update_service.""" + + # instantiate a zeroconf instance + zc = Zeroconf(interfaces=['127.0.0.1']) + # start a browser + type_ = "_hap._tcp.local." + registration_name = "xxxyyy.%s" % type_ + callbacks = [] + + class MyServiceListener: + def add_service(self, zc, type_, name) -> None: + nonlocal callbacks + if name == registration_name: + callbacks.append(("add", type_, name)) + + def remove_service(self, zc, type_, name) -> None: + nonlocal callbacks + if name == registration_name: + callbacks.append(("remove", type_, name)) + + listener = MyServiceListener() + + browser = r.ServiceBrowser(zc, type_, None, listener) + + desc = {'path': '/~paulsm/'} + address_parsed = "10.0.1.2" + address = socket.inet_aton(address_parsed) + info = ServiceInfo(type_, registration_name, 80, 0, 0, desc, "ash-2.local.", addresses=[address]) + + def mock_incoming_msg(records) -> r.DNSIncoming: + generated = r.DNSOutgoing(const._FLAGS_QR_RESPONSE) + for record in records: + generated.add_answer_at_time(record, 0) + return r.DNSIncoming(generated.packets()[0]) + + _inject_response( + zc, + mock_incoming_msg([info.dns_pointer(), info.dns_service(), info.dns_text(), *info.dns_addresses()]), + ) + time.sleep(0.2) + info.port = 400 + _inject_response( + zc, + mock_incoming_msg([info.dns_service()]), + ) + time.sleep(0.2) + + assert callbacks == [ + ('add', type_, registration_name), + ] + browser.cancel() + + zc.close() + + +def test_servicebrowser_uses_non_strict_names(): + """Verify we can look for technically invalid names as we cannot change what others do.""" + + # dummy service callback + def on_service_state_change(zeroconf, service_type, state_change, name): + pass + + zc = r.Zeroconf(interfaces=['127.0.0.1']) + browser = ServiceBrowser(zc, ["_tivo-videostream._tcp.local."], [on_service_state_change]) + browser.cancel() + + # Still fail on completely invalid + with pytest.raises(r.BadTypeInNameException): + browser = ServiceBrowser(zc, ["tivo-videostream._tcp.local."], [on_service_state_change]) + zc.close() + + +def test_group_ptr_queries_with_known_answers(): + questions_with_known_answers: _services_browser._QuestionWithKnownAnswers = {} + now = current_time_millis() + for i in range(120): + name = f"_hap{i}._tcp._local." + questions_with_known_answers[DNSQuestion(name, const._TYPE_PTR, const._CLASS_IN)] = set( + DNSPointer( + name, + const._TYPE_PTR, + const._CLASS_IN, + 4500, + f"zoo{counter}.{name}", + ) + for counter in range(i) + ) + outs = _services_browser._group_ptr_queries_with_known_answers(now, True, questions_with_known_answers) + for out in outs: + packets = out.packets() + # If we generate multiple packets there must + # only be one question + assert len(packets) == 1 or len(out.questions) == 1 + + +# This test uses asyncio because it needs to access the cache directly +# which is not threadsafe +@pytest.mark.asyncio +async def test_generate_service_query_suppress_duplicate_questions(): + """Generate a service query for sending with zeroconf.send.""" + aiozc = AsyncZeroconf(interfaces=['127.0.0.1']) + zc = aiozc.zeroconf + now = current_time_millis() + name = "_suppresstest._tcp.local." + question = r.DNSQuestion(name, const._TYPE_PTR, const._CLASS_IN) + answer = r.DNSPointer( + name, + const._TYPE_PTR, + const._CLASS_IN, + 10000, + f'known-to-other.{name}', + ) + other_known_answers = set([answer]) + zc.question_history.add_question_at_time(question, now, other_known_answers) + assert zc.question_history.suppresses(question, now, other_known_answers) + + # The known answer list is different, do not suppress + outs = _services_browser.generate_service_query(zc, now, [name], multicast=True) + assert outs + + zc.cache.async_add_records([answer]) + # The known answer list contains all the asked questions in the history + # we should suppress + + outs = _services_browser.generate_service_query(zc, now, [name], multicast=True) + assert not outs + + # We do not suppress once the question history expires + outs = _services_browser.generate_service_query(zc, now + 1000, [name], multicast=True) + assert outs + + # We do not suppress QU queries ever + outs = _services_browser.generate_service_query(zc, now, [name], multicast=False) + assert outs + + zc.question_history.async_expire(now + 2000) + # No suppression after clearing the history + outs = _services_browser.generate_service_query(zc, now, [name], multicast=True) + assert outs + + # The previous query we just sent is still remembered and + # the next one is suppressed + outs = _services_browser.generate_service_query(zc, now, [name], multicast=True) + assert not outs + + await aiozc.async_close() + + +@pytest.mark.asyncio +async def test_query_scheduler(): + delay = const._BROWSER_TIME + types_ = set(["_hap._tcp.local.", "_http._tcp.local."]) + query_scheduler = _services_browser.QueryScheduler(types_, delay, (0, 0)) + + now = current_time_millis() + query_scheduler.start(now) + + # Test query interval is increasing + assert query_scheduler.millis_to_wait(now - 1) == 1 + assert query_scheduler.millis_to_wait(now) is 0 + assert query_scheduler.millis_to_wait(now + 1) is 0 + + assert set(query_scheduler.process_ready_types(now)) == types_ + assert set(query_scheduler.process_ready_types(now)) == set() + assert query_scheduler.millis_to_wait(now) == delay + + assert set(query_scheduler.process_ready_types(now + delay)) == types_ + assert set(query_scheduler.process_ready_types(now + delay)) == set() + assert query_scheduler.millis_to_wait(now) == delay * 3 + + assert set(query_scheduler.process_ready_types(now + delay * 3)) == types_ + assert set(query_scheduler.process_ready_types(now + delay * 3)) == set() + assert query_scheduler.millis_to_wait(now) == delay * 7 + + assert set(query_scheduler.process_ready_types(now + delay * 7)) == types_ + assert set(query_scheduler.process_ready_types(now + delay * 7)) == set() + assert query_scheduler.millis_to_wait(now) == delay * 15 + + assert set(query_scheduler.process_ready_types(now + delay * 15)) == types_ + assert set(query_scheduler.process_ready_types(now + delay * 15)) == set() + + # Test if we reschedule 1 second later, the millis_to_wait goes up by 1 + query_scheduler.reschedule_type("_hap._tcp.local.", now + delay * 16) + assert query_scheduler.millis_to_wait(now) == delay * 16 + + assert set(query_scheduler.process_ready_types(now + delay * 15)) == set() + + # Test if we reschedule 1 second later... and its ready for processing + assert set(query_scheduler.process_ready_types(now + delay * 16)) == set(["_hap._tcp.local."]) + assert query_scheduler.millis_to_wait(now) == delay * 31 + assert set(query_scheduler.process_ready_types(now + delay * 20)) == set() + + assert set(query_scheduler.process_ready_types(now + delay * 31)) == set(["_http._tcp.local."]) diff --git a/tests/services/test_info.py b/tests/services/test_info.py new file mode 100644 index 00000000..a72d82f9 --- /dev/null +++ b/tests/services/test_info.py @@ -0,0 +1,779 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + + +""" Unit tests for zeroconf._services.info. """ + +import logging +import socket +import threading +import os +import unittest +from unittest.mock import patch +from threading import Event +from typing import List + +import pytest + +import zeroconf as r +from zeroconf import DNSAddress, const +from zeroconf._services.info import ServiceInfo +from zeroconf.asyncio import AsyncZeroconf + +from .. import has_working_ipv6, _inject_response + + +log = logging.getLogger('zeroconf') +original_logging_level = logging.NOTSET + + +def setup_module(): + global original_logging_level + original_logging_level = log.level + log.setLevel(logging.DEBUG) + + +def teardown_module(): + if original_logging_level != logging.NOTSET: + log.setLevel(original_logging_level) + + +class TestServiceInfo(unittest.TestCase): + def test_get_name(self): + """Verify the name accessor can strip the type.""" + desc = {'path': '/~paulsm/'} + service_name = 'name._type._tcp.local.' + service_type = '_type._tcp.local.' + service_server = 'ash-1.local.' + service_address = socket.inet_aton("10.0.1.2") + info = ServiceInfo( + service_type, service_name, 22, 0, 0, desc, service_server, addresses=[service_address] + ) + assert info.get_name() == "name" + + def test_service_info_rejects_non_matching_updates(self): + """Verify records with the wrong name are rejected.""" + + zc = r.Zeroconf(interfaces=['127.0.0.1']) + desc = {'path': '/~paulsm/'} + service_name = 'name._type._tcp.local.' + service_type = '_type._tcp.local.' + service_server = 'ash-1.local.' + service_address = socket.inet_aton("10.0.1.2") + ttl = 120 + now = r.current_time_millis() + info = ServiceInfo( + service_type, service_name, 22, 0, 0, desc, service_server, addresses=[service_address] + ) + # Verify backwards compatiblity with calling with None + info.update_record(zc, now, None) + # Matching updates + info.update_record( + zc, + now, + r.DNSText( + service_name, + const._TYPE_TXT, + const._CLASS_IN | const._CLASS_UNIQUE, + ttl, + b'\x04ff=0\x04ci=2\x04sf=0\x0bsh=6fLM5A==', + ), + ) + assert info.properties[b"ci"] == b"2" + info.update_record( + zc, + now, + r.DNSService( + service_name, + const._TYPE_SRV, + const._CLASS_IN | const._CLASS_UNIQUE, + ttl, + 0, + 0, + 80, + 'ASH-2.local.', + ), + ) + assert info.server_key == 'ash-2.local.' + assert info.server == 'ASH-2.local.' + new_address = socket.inet_aton("10.0.1.3") + info.update_record( + zc, + now, + r.DNSAddress( + 'ASH-2.local.', + const._TYPE_A, + const._CLASS_IN | const._CLASS_UNIQUE, + ttl, + new_address, + ), + ) + assert new_address in info.addresses + # Non-matching updates + info.update_record( + zc, + now, + r.DNSText( + "incorrect.name.", + const._TYPE_TXT, + const._CLASS_IN | const._CLASS_UNIQUE, + ttl, + b'\x04ff=0\x04ci=3\x04sf=0\x0bsh=6fLM5A==', + ), + ) + assert info.properties[b"ci"] == b"2" + info.update_record( + zc, + now, + r.DNSService( + "incorrect.name.", + const._TYPE_SRV, + const._CLASS_IN | const._CLASS_UNIQUE, + ttl, + 0, + 0, + 80, + 'ASH-2.local.', + ), + ) + assert info.server_key == 'ash-2.local.' + assert info.server == 'ASH-2.local.' + new_address = socket.inet_aton("10.0.1.4") + info.update_record( + zc, + now, + r.DNSAddress( + "incorrect.name.", + const._TYPE_A, + const._CLASS_IN | const._CLASS_UNIQUE, + ttl, + new_address, + ), + ) + assert new_address not in info.addresses + zc.close() + + def test_service_info_rejects_expired_records(self): + """Verify records that are expired are rejected.""" + zc = r.Zeroconf(interfaces=['127.0.0.1']) + desc = {'path': '/~paulsm/'} + service_name = 'name._type._tcp.local.' + service_type = '_type._tcp.local.' + service_server = 'ash-1.local.' + service_address = socket.inet_aton("10.0.1.2") + ttl = 120 + now = r.current_time_millis() + info = ServiceInfo( + service_type, service_name, 22, 0, 0, desc, service_server, addresses=[service_address] + ) + # Matching updates + info.update_record( + zc, + now, + r.DNSText( + service_name, + const._TYPE_TXT, + const._CLASS_IN | const._CLASS_UNIQUE, + ttl, + b'\x04ff=0\x04ci=2\x04sf=0\x0bsh=6fLM5A==', + ), + ) + assert info.properties[b"ci"] == b"2" + # Expired record + expired_record = r.DNSText( + service_name, + const._TYPE_TXT, + const._CLASS_IN | const._CLASS_UNIQUE, + ttl, + b'\x04ff=0\x04ci=3\x04sf=0\x0bsh=6fLM5A==', + ) + expired_record.set_created_ttl(1000, 1) + info.update_record(zc, now, expired_record) + assert info.properties[b"ci"] == b"2" + zc.close() + + @unittest.skipIf(not has_working_ipv6(), 'Requires IPv6') + @unittest.skipIf(os.environ.get('SKIP_IPV6'), 'IPv6 tests disabled') + def test_get_info_partial(self): + + zc = r.Zeroconf(interfaces=['127.0.0.1']) + + service_name = 'name._type._tcp.local.' + service_type = '_type._tcp.local.' + service_server = 'ash-1.local.' + service_text = b'path=/~matt1/' + service_address = '10.0.1.2' + service_address_v6_ll = 'fe80::52e:c2f2:bc5f:e9c6' + service_scope_id = 12 + + service_info = None + send_event = Event() + service_info_event = Event() + + last_sent = None # type: Optional[r.DNSOutgoing] + + def send(out, addr=const._MDNS_ADDR, port=const._MDNS_PORT, v6_flow_scope=()): + """Sends an outgoing packet.""" + nonlocal last_sent + + last_sent = out + send_event.set() + + # patch the zeroconf send + with patch.object(zc, "async_send", send): + + def mock_incoming_msg(records) -> r.DNSIncoming: + + generated = r.DNSOutgoing(const._FLAGS_QR_RESPONSE) + + for record in records: + generated.add_answer_at_time(record, 0) + + return r.DNSIncoming(generated.packets()[0]) + + def get_service_info_helper(zc, type, name): + nonlocal service_info + service_info = zc.get_service_info(type, name) + service_info_event.set() + + try: + ttl = 120 + helper_thread = threading.Thread( + target=get_service_info_helper, args=(zc, service_type, service_name) + ) + helper_thread.start() + wait_time = 1 + + # Expext query for SRV, TXT, A, AAAA + send_event.wait(wait_time) + assert last_sent is not None + assert len(last_sent.questions) == 4 + assert r.DNSQuestion(service_name, const._TYPE_SRV, const._CLASS_IN) in last_sent.questions + assert r.DNSQuestion(service_name, const._TYPE_TXT, const._CLASS_IN) in last_sent.questions + assert r.DNSQuestion(service_name, const._TYPE_A, const._CLASS_IN) in last_sent.questions + assert r.DNSQuestion(service_name, const._TYPE_AAAA, const._CLASS_IN) in last_sent.questions + assert service_info is None + + # Expext query for SRV, A, AAAA + last_sent = None + send_event.clear() + _inject_response( + zc, + mock_incoming_msg( + [ + r.DNSText( + service_name, + const._TYPE_TXT, + const._CLASS_IN | const._CLASS_UNIQUE, + ttl, + service_text, + ) + ] + ), + ) + send_event.wait(wait_time) + assert last_sent is not None + assert len(last_sent.questions) == 3 + assert r.DNSQuestion(service_name, const._TYPE_SRV, const._CLASS_IN) in last_sent.questions + assert r.DNSQuestion(service_name, const._TYPE_A, const._CLASS_IN) in last_sent.questions + assert r.DNSQuestion(service_name, const._TYPE_AAAA, const._CLASS_IN) in last_sent.questions + assert service_info is None + + # Expext query for A, AAAA + last_sent = None + send_event.clear() + _inject_response( + zc, + mock_incoming_msg( + [ + r.DNSService( + service_name, + const._TYPE_SRV, + const._CLASS_IN | const._CLASS_UNIQUE, + ttl, + 0, + 0, + 80, + service_server, + ) + ] + ), + ) + send_event.wait(wait_time) + assert last_sent is not None + assert len(last_sent.questions) == 2 + assert r.DNSQuestion(service_server, const._TYPE_A, const._CLASS_IN) in last_sent.questions + assert r.DNSQuestion(service_server, const._TYPE_AAAA, const._CLASS_IN) in last_sent.questions + last_sent = None + assert service_info is None + + # Expext no further queries + last_sent = None + send_event.clear() + _inject_response( + zc, + mock_incoming_msg( + [ + r.DNSAddress( + service_server, + const._TYPE_A, + const._CLASS_IN | const._CLASS_UNIQUE, + ttl, + socket.inet_pton(socket.AF_INET, service_address), + ), + r.DNSAddress( + service_server, + const._TYPE_AAAA, + const._CLASS_IN | const._CLASS_UNIQUE, + ttl, + socket.inet_pton(socket.AF_INET6, service_address_v6_ll), + scope_id=service_scope_id, + ), + ] + ), + ) + send_event.wait(wait_time) + assert last_sent is None + assert service_info is not None + + finally: + helper_thread.join() + zc.remove_all_service_listeners() + zc.close() + + def test_get_info_single(self): + + zc = r.Zeroconf(interfaces=['127.0.0.1']) + + service_name = 'name._type._tcp.local.' + service_type = '_type._tcp.local.' + service_server = 'ash-1.local.' + service_text = b'path=/~matt1/' + service_address = '10.0.1.2' + + service_info = None + send_event = Event() + service_info_event = Event() + + last_sent = None # type: Optional[r.DNSOutgoing] + + def send(out, addr=const._MDNS_ADDR, port=const._MDNS_PORT, v6_flow_scope=()): + """Sends an outgoing packet.""" + nonlocal last_sent + + last_sent = out + send_event.set() + + # patch the zeroconf send + with patch.object(zc, "async_send", send): + + def mock_incoming_msg(records) -> r.DNSIncoming: + + generated = r.DNSOutgoing(const._FLAGS_QR_RESPONSE) + + for record in records: + generated.add_answer_at_time(record, 0) + + return r.DNSIncoming(generated.packets()[0]) + + def get_service_info_helper(zc, type, name): + nonlocal service_info + service_info = zc.get_service_info(type, name) + service_info_event.set() + + try: + ttl = 120 + helper_thread = threading.Thread( + target=get_service_info_helper, args=(zc, service_type, service_name) + ) + helper_thread.start() + wait_time = 1 + + # Expext query for SRV, TXT, A, AAAA + send_event.wait(wait_time) + assert last_sent is not None + assert len(last_sent.questions) == 4 + assert r.DNSQuestion(service_name, const._TYPE_SRV, const._CLASS_IN) in last_sent.questions + assert r.DNSQuestion(service_name, const._TYPE_TXT, const._CLASS_IN) in last_sent.questions + assert r.DNSQuestion(service_name, const._TYPE_A, const._CLASS_IN) in last_sent.questions + assert r.DNSQuestion(service_name, const._TYPE_AAAA, const._CLASS_IN) in last_sent.questions + assert service_info is None + + # Expext no further queries + last_sent = None + send_event.clear() + _inject_response( + zc, + mock_incoming_msg( + [ + r.DNSText( + service_name, + const._TYPE_TXT, + const._CLASS_IN | const._CLASS_UNIQUE, + ttl, + service_text, + ), + r.DNSService( + service_name, + const._TYPE_SRV, + const._CLASS_IN | const._CLASS_UNIQUE, + ttl, + 0, + 0, + 80, + service_server, + ), + r.DNSAddress( + service_server, + const._TYPE_A, + const._CLASS_IN | const._CLASS_UNIQUE, + ttl, + socket.inet_pton(socket.AF_INET, service_address), + ), + ] + ), + ) + send_event.wait(wait_time) + assert last_sent is None + assert service_info is not None + + finally: + helper_thread.join() + zc.remove_all_service_listeners() + zc.close() + + def test_service_info_duplicate_properties_txt_records(self): + """Verify the first property is always used when there are duplicates in a txt record.""" + + zc = r.Zeroconf(interfaces=['127.0.0.1']) + desc = {'path': '/~paulsm/'} + service_name = 'name._type._tcp.local.' + service_type = '_type._tcp.local.' + service_server = 'ash-1.local.' + service_address = socket.inet_aton("10.0.1.2") + ttl = 120 + now = r.current_time_millis() + info = ServiceInfo( + service_type, service_name, 22, 0, 0, desc, service_server, addresses=[service_address] + ) + info.async_update_records( + zc, + now, + [ + r.RecordUpdate( + r.DNSText( + service_name, + const._TYPE_TXT, + const._CLASS_IN | const._CLASS_UNIQUE, + ttl, + b'\x04ff=0\x04ci=2\x04sf=0\x0bsh=6fLM5A==\x04dd=0\x04jl=2\x04qq=0\x0brr=6fLM5A==\x04ci=3', + ), + None, + ) + ], + ) + assert info.properties[b"dd"] == b"0" + assert info.properties[b"jl"] == b"2" + assert info.properties[b"ci"] == b"2" + zc.close() + + +def test_multiple_addresses(): + type_ = "_http._tcp.local." + registration_name = "xxxyyy.%s" % type_ + desc = {'path': '/~paulsm/'} + address_parsed = "10.0.1.2" + address = socket.inet_aton(address_parsed) + + # New kwarg way + info = ServiceInfo(type_, registration_name, 80, 0, 0, desc, "ash-2.local.", addresses=[address, address]) + + assert info.addresses == [address, address] + assert info.parsed_addresses() == [address_parsed, address_parsed] + assert info.parsed_scoped_addresses() == [address_parsed, address_parsed] + + info = ServiceInfo( + type_, + registration_name, + 80, + 0, + 0, + desc, + "ash-2.local.", + parsed_addresses=[address_parsed, address_parsed], + ) + assert info.addresses == [address, address] + assert info.parsed_addresses() == [address_parsed, address_parsed] + assert info.parsed_scoped_addresses() == [address_parsed, address_parsed] + + if has_working_ipv6() and not os.environ.get('SKIP_IPV6'): + address_v6_parsed = "2001:db8::1" + address_v6 = socket.inet_pton(socket.AF_INET6, address_v6_parsed) + address_v6_ll_parsed = "fe80::52e:c2f2:bc5f:e9c6" + address_v6_ll_scoped_parsed = "fe80::52e:c2f2:bc5f:e9c6%12" + address_v6_ll = socket.inet_pton(socket.AF_INET6, address_v6_ll_parsed) + interface_index = 12 + infos = [ + ServiceInfo( + type_, + registration_name, + 80, + 0, + 0, + desc, + "ash-2.local.", + addresses=[address, address_v6, address_v6_ll], + interface_index=interface_index, + ), + ServiceInfo( + type_, + registration_name, + 80, + 0, + 0, + desc, + "ash-2.local.", + parsed_addresses=[address_parsed, address_v6_parsed, address_v6_ll_parsed], + interface_index=interface_index, + ), + ] + for info in infos: + assert info.addresses == [address] + assert info.addresses_by_version(r.IPVersion.All) == [address, address_v6, address_v6_ll] + assert info.addresses_by_version(r.IPVersion.V4Only) == [address] + assert info.addresses_by_version(r.IPVersion.V6Only) == [address_v6, address_v6_ll] + assert info.parsed_addresses() == [address_parsed, address_v6_parsed, address_v6_ll_parsed] + assert info.parsed_addresses(r.IPVersion.V4Only) == [address_parsed] + assert info.parsed_addresses(r.IPVersion.V6Only) == [address_v6_parsed, address_v6_ll_parsed] + assert info.parsed_scoped_addresses() == [ + address_v6_ll_scoped_parsed, + address_parsed, + address_v6_parsed, + ] + assert info.parsed_scoped_addresses(r.IPVersion.V4Only) == [address_parsed] + assert info.parsed_scoped_addresses(r.IPVersion.V6Only) == [ + address_v6_ll_scoped_parsed, + address_v6_parsed, + ] + + +# This test uses asyncio because it needs to access the cache directly +# which is not threadsafe +@pytest.mark.asyncio +async def test_multiple_a_addresses(): + type_ = "_http._tcp.local." + registration_name = "multiarec.%s" % type_ + desc = {'path': '/~paulsm/'} + aiozc = AsyncZeroconf(interfaces=['127.0.0.1']) + cache = aiozc.zeroconf.cache + host = "multahost.local." + record1 = r.DNSAddress(host, const._TYPE_A, const._CLASS_IN, 1000, b'a') + record2 = r.DNSAddress(host, const._TYPE_A, const._CLASS_IN, 1000, b'b') + cache.async_add_records([record1, record2]) + + # New kwarg way + info = ServiceInfo(type_, registration_name, 80, 0, 0, desc, host) + info.load_from_cache(aiozc.zeroconf) + assert set(info.addresses) == set([b'a', b'b']) + await aiozc.async_close() + + +@unittest.skipIf(not has_working_ipv6(), 'Requires IPv6') +@unittest.skipIf(os.environ.get('SKIP_IPV6'), 'IPv6 tests disabled') +def test_filter_address_by_type_from_service_info(): + """Verify dns_addresses can filter by ipversion.""" + desc = {'path': '/~paulsm/'} + type_ = "_homeassistant._tcp.local." + name = "MyTestHome" + registration_name = "%s.%s" % (name, type_) + ipv4 = socket.inet_aton("10.0.1.2") + ipv6 = socket.inet_pton(socket.AF_INET6, "2001:db8::1") + info = ServiceInfo(type_, registration_name, 80, 0, 0, desc, "ash-2.local.", addresses=[ipv4, ipv6]) + + def dns_addresses_to_addresses(dns_address: List[DNSAddress]): + return [address.address for address in dns_address] + + assert dns_addresses_to_addresses(info.dns_addresses()) == [ipv4, ipv6] + assert dns_addresses_to_addresses(info.dns_addresses(version=r.IPVersion.All)) == [ipv4, ipv6] + assert dns_addresses_to_addresses(info.dns_addresses(version=r.IPVersion.V4Only)) == [ipv4] + assert dns_addresses_to_addresses(info.dns_addresses(version=r.IPVersion.V6Only)) == [ipv6] + + +def test_changing_name_updates_serviceinfo_key(): + """Verify a name change will adjust the underlying key value.""" + type_ = "_homeassistant._tcp.local." + name = "MyTestHome" + info_service = ServiceInfo( + type_, + '%s.%s' % (name, type_), + 80, + 0, + 0, + {'path': '/~paulsm/'}, + "ash-2.local.", + addresses=[socket.inet_aton("10.0.1.2")], + ) + assert info_service.key == "mytesthome._homeassistant._tcp.local." + info_service.name = "YourTestHome._homeassistant._tcp.local." + assert info_service.key == "yourtesthome._homeassistant._tcp.local." + + +def test_serviceinfo_address_updates(): + """Verify adding/removing/setting addresses on ServiceInfo.""" + type_ = "_homeassistant._tcp.local." + name = "MyTestHome" + + # Verify addresses and parsed_addresses are mutually exclusive + with pytest.raises(TypeError): + info_service = ServiceInfo( + type_, + '%s.%s' % (name, type_), + 80, + 0, + 0, + {'path': '/~paulsm/'}, + "ash-2.local.", + addresses=[socket.inet_aton("10.0.1.2")], + parsed_addresses=["10.0.1.2"], + ) + + info_service = ServiceInfo( + type_, + '%s.%s' % (name, type_), + 80, + 0, + 0, + {'path': '/~paulsm/'}, + "ash-2.local.", + addresses=[socket.inet_aton("10.0.1.2")], + ) + info_service.addresses = [socket.inet_aton("10.0.1.3")] + assert info_service.addresses == [socket.inet_aton("10.0.1.3")] + + +def test_serviceinfo_accepts_bytes_or_string_dict(): + """Verify a bytes or string dict can be passed to ServiceInfo.""" + type_ = "_homeassistant._tcp.local." + name = "MyTestHome" + addresses = [socket.inet_aton("10.0.1.2")] + server_name = "ash-2.local." + info_service = ServiceInfo( + type_, '%s.%s' % (name, type_), 80, 0, 0, {b'path': b'/~paulsm/'}, server_name, addresses=addresses + ) + assert info_service.dns_text().text == b'\x0epath=/~paulsm/' + info_service = ServiceInfo( + type_, + '%s.%s' % (name, type_), + 80, + 0, + 0, + {'path': '/~paulsm/'}, + server_name, + addresses=addresses, + ) + assert info_service.dns_text().text == b'\x0epath=/~paulsm/' + info_service = ServiceInfo( + type_, + '%s.%s' % (name, type_), + 80, + 0, + 0, + {b'path': '/~paulsm/'}, + server_name, + addresses=addresses, + ) + assert info_service.dns_text().text == b'\x0epath=/~paulsm/' + info_service = ServiceInfo( + type_, + '%s.%s' % (name, type_), + 80, + 0, + 0, + {'path': b'/~paulsm/'}, + server_name, + addresses=addresses, + ) + assert info_service.dns_text().text == b'\x0epath=/~paulsm/' + + +def test_asking_qu_questions(): + """Verify explictly asking QU questions.""" + type_ = "_quservice._tcp.local." + zeroconf = r.Zeroconf(interfaces=['127.0.0.1']) + + # we are going to patch the zeroconf send to check query transmission + old_send = zeroconf.async_send + + first_outgoing = None + + def send(out, addr=const._MDNS_ADDR, port=const._MDNS_PORT): + """Sends an outgoing packet.""" + nonlocal first_outgoing + if first_outgoing is None: + first_outgoing = out + old_send(out, addr=addr, port=port) + + # patch the zeroconf send + with patch.object(zeroconf, "async_send", send): + zeroconf.get_service_info(f"name.{type_}", type_, 500, question_type=r.DNSQuestionType.QU) + assert first_outgoing.questions[0].unicast == True + zeroconf.close() + + +def test_asking_qm_questions(): + """Verify explictly asking QM questions.""" + type_ = "_quservice._tcp.local." + zeroconf = r.Zeroconf(interfaces=['127.0.0.1']) + + # we are going to patch the zeroconf send to check query transmission + old_send = zeroconf.async_send + + first_outgoing = None + + def send(out, addr=const._MDNS_ADDR, port=const._MDNS_PORT): + """Sends an outgoing packet.""" + nonlocal first_outgoing + if first_outgoing is None: + first_outgoing = out + old_send(out, addr=addr, port=port) + + # patch the zeroconf send + with patch.object(zeroconf, "async_send", send): + zeroconf.get_service_info(f"name.{type_}", type_, 500, question_type=r.DNSQuestionType.QM) + assert first_outgoing.questions[0].unicast == False + zeroconf.close() + + +def test_request_timeout(): + """Test that the timeout does not throw an exception and finishes close to the actual timeout.""" + zeroconf = r.Zeroconf(interfaces=['127.0.0.1']) + start_time = r.current_time_millis() + assert zeroconf.get_service_info("_notfound.local.", "notthere._notfound.local.") is None + end_time = r.current_time_millis() + zeroconf.close() + # 3000ms for the default timeout + # 1000ms for loaded systems + schedule overhead + assert (end_time - start_time) < 3000 + 1000 + + +@pytest.mark.asyncio +async def test_we_try_four_times_with_random_delay(): + """Verify we try four times even with the random delay.""" + type_ = "_typethatisnothere._tcp.local." + aiozc = AsyncZeroconf(interfaces=['127.0.0.1']) + + # we are going to patch the zeroconf send to check query transmission + request_count = 0 + + def async_send(out, addr=const._MDNS_ADDR, port=const._MDNS_PORT): + """Sends an outgoing packet.""" + nonlocal request_count + request_count += 1 + + # patch the zeroconf send + with patch.object(aiozc.zeroconf, "async_send", async_send): + await aiozc.async_get_service_info(f"willnotbefound.{type_}", type_) + + await aiozc.async_close() + + assert request_count == 4 diff --git a/tests/services/test_registry.py b/tests/services/test_registry.py new file mode 100644 index 00000000..3c105cbb --- /dev/null +++ b/tests/services/test_registry.py @@ -0,0 +1,132 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + + +"""Unit tests for zeroconf._services.registry.""" + +import unittest +import socket + +import zeroconf as r +from zeroconf import ServiceInfo + + +class TestServiceRegistry(unittest.TestCase): + def test_only_register_once(self): + type_ = "_test-srvc-type._tcp.local." + name = "xxxyyy" + registration_name = "%s.%s" % (name, type_) + + desc = {'path': '/~paulsm/'} + info = ServiceInfo( + type_, registration_name, 80, 0, 0, desc, "ash-2.local.", addresses=[socket.inet_aton("10.0.1.2")] + ) + + registry = r.ServiceRegistry() + registry.async_add(info) + self.assertRaises(r.ServiceNameAlreadyRegistered, registry.async_add, info) + registry.async_remove(info) + registry.async_add(info) + + def test_register_same_server(self): + type_ = "_test-srvc-type._tcp.local." + name = "xxxyyy" + name2 = "xxxyyy2" + registration_name = "%s.%s" % (name, type_) + registration_name2 = "%s.%s" % (name2, type_) + + desc = {'path': '/~paulsm/'} + info = ServiceInfo( + type_, registration_name, 80, 0, 0, desc, "same.local.", addresses=[socket.inet_aton("10.0.1.2")] + ) + info2 = ServiceInfo( + type_, registration_name2, 80, 0, 0, desc, "same.local.", addresses=[socket.inet_aton("10.0.1.2")] + ) + registry = r.ServiceRegistry() + registry.async_add(info) + registry.async_add(info2) + assert registry.async_get_infos_server("same.local.") == [info, info2] + registry.async_remove(info) + assert registry.async_get_infos_server("same.local.") == [info2] + registry.async_remove(info2) + assert registry.async_get_infos_server("same.local.") == [] + + def test_unregister_multiple_times(self): + """Verify we can unregister a service multiple times. + + In production unregister_service and unregister_all_services + may happen at the same time during shutdown. We want to treat + this as non-fatal since its expected to happen and it is unlikely + that the callers know about each other. + """ + type_ = "_test-srvc-type._tcp.local." + name = "xxxyyy" + registration_name = "%s.%s" % (name, type_) + + desc = {'path': '/~paulsm/'} + info = ServiceInfo( + type_, registration_name, 80, 0, 0, desc, "ash-2.local.", addresses=[socket.inet_aton("10.0.1.2")] + ) + + registry = r.ServiceRegistry() + registry.async_add(info) + self.assertRaises(r.ServiceNameAlreadyRegistered, registry.async_add, info) + registry.async_remove(info) + registry.async_remove(info) + + def test_lookups(self): + type_ = "_test-srvc-type._tcp.local." + name = "xxxyyy" + registration_name = "%s.%s" % (name, type_) + + desc = {'path': '/~paulsm/'} + info = ServiceInfo( + type_, registration_name, 80, 0, 0, desc, "ash-2.local.", addresses=[socket.inet_aton("10.0.1.2")] + ) + + registry = r.ServiceRegistry() + registry.async_add(info) + + assert registry.async_get_service_infos() == [info] + assert registry.async_get_info_name(registration_name) == info + assert registry.async_get_infos_type(type_) == [info] + assert registry.async_get_infos_server("ash-2.local.") == [info] + assert registry.async_get_types() == [type_] + + def test_lookups_upper_case_by_lower_case(self): + type_ = "_test-SRVC-type._tcp.local." + name = "Xxxyyy" + registration_name = "%s.%s" % (name, type_) + + desc = {'path': '/~paulsm/'} + info = ServiceInfo( + type_, registration_name, 80, 0, 0, desc, "ASH-2.local.", addresses=[socket.inet_aton("10.0.1.2")] + ) + + registry = r.ServiceRegistry() + registry.async_add(info) + + assert registry.async_get_service_infos() == [info] + assert registry.async_get_info_name(registration_name.lower()) == info + assert registry.async_get_infos_type(type_.lower()) == [info] + assert registry.async_get_infos_server("ash-2.local.") == [info] + assert registry.async_get_types() == [type_.lower()] + + def test_lookups_lower_case_by_upper_case(self): + type_ = "_test-srvc-type._tcp.local." + name = "xxxyyy" + registration_name = "%s.%s" % (name, type_) + + desc = {'path': '/~paulsm/'} + info = ServiceInfo( + type_, registration_name, 80, 0, 0, desc, "ash-2.local.", addresses=[socket.inet_aton("10.0.1.2")] + ) + + registry = r.ServiceRegistry() + registry.async_add(info) + + assert registry.async_get_service_infos() == [info] + assert registry.async_get_info_name(registration_name.upper()) == info + assert registry.async_get_infos_type(type_.upper()) == [info] + assert registry.async_get_infos_server("ASH-2.local.") == [info] + assert registry.async_get_types() == [type_] diff --git a/tests/services/test_types.py b/tests/services/test_types.py new file mode 100644 index 00000000..b1c312db --- /dev/null +++ b/tests/services/test_types.py @@ -0,0 +1,177 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + + +"""Unit tests for zeroconf._services.types.""" + +import logging +import os +import unittest +import socket +import sys +from unittest.mock import patch + +import zeroconf as r +from zeroconf import Zeroconf, ServiceInfo, ZeroconfServiceTypes + +from .. import _clear_cache, has_working_ipv6 + +log = logging.getLogger('zeroconf') +original_logging_level = logging.NOTSET + + +def setup_module(): + global original_logging_level + original_logging_level = log.level + log.setLevel(logging.DEBUG) + + +def teardown_module(): + if original_logging_level != logging.NOTSET: + log.setLevel(original_logging_level) + + +class ServiceTypesQuery(unittest.TestCase): + def test_integration_with_listener(self): + + type_ = "_test-listen-type._tcp.local." + name = "xxxyyy" + registration_name = "%s.%s" % (name, type_) + + zeroconf_registrar = Zeroconf(interfaces=['127.0.0.1']) + desc = {'path': '/~paulsm/'} + info = ServiceInfo( + type_, + registration_name, + 80, + 0, + 0, + desc, + "ash-2.local.", + addresses=[socket.inet_aton("10.0.1.2")], + ) + zeroconf_registrar.registry.async_add(info) + try: + with patch.object( + zeroconf_registrar.engine.protocols[0], "suppress_duplicate_packet", return_value=False + ), patch.object( + zeroconf_registrar.engine.protocols[1], "suppress_duplicate_packet", return_value=False + ): + service_types = ZeroconfServiceTypes.find(interfaces=['127.0.0.1'], timeout=2) + assert type_ in service_types + _clear_cache(zeroconf_registrar) + service_types = ZeroconfServiceTypes.find(zc=zeroconf_registrar, timeout=2) + assert type_ in service_types + + finally: + zeroconf_registrar.close() + + @unittest.skipIf(not has_working_ipv6(), 'Requires IPv6') + @unittest.skipIf(os.environ.get('SKIP_IPV6'), 'IPv6 tests disabled') + def test_integration_with_listener_v6_records(self): + + type_ = "_test-listenv6rec-type._tcp.local." + name = "xxxyyy" + registration_name = "%s.%s" % (name, type_) + addr = "2606:2800:220:1:248:1893:25c8:1946" # example.com + + zeroconf_registrar = Zeroconf(interfaces=['127.0.0.1']) + desc = {'path': '/~paulsm/'} + info = ServiceInfo( + type_, + registration_name, + 80, + 0, + 0, + desc, + "ash-2.local.", + addresses=[socket.inet_pton(socket.AF_INET6, addr)], + ) + zeroconf_registrar.registry.async_add(info) + try: + with patch.object( + zeroconf_registrar.engine.protocols[0], "suppress_duplicate_packet", return_value=False + ), patch.object( + zeroconf_registrar.engine.protocols[1], "suppress_duplicate_packet", return_value=False + ): + service_types = ZeroconfServiceTypes.find(interfaces=['127.0.0.1'], timeout=2) + assert type_ in service_types + _clear_cache(zeroconf_registrar) + service_types = ZeroconfServiceTypes.find(zc=zeroconf_registrar, timeout=2) + assert type_ in service_types + + finally: + zeroconf_registrar.close() + + @unittest.skipIf(not has_working_ipv6() or sys.platform == 'win32', 'Requires IPv6') + @unittest.skipIf(os.environ.get('SKIP_IPV6'), 'IPv6 tests disabled') + def test_integration_with_listener_ipv6(self): + + type_ = "_test-listenv6ip-type._tcp.local." + name = "xxxyyy" + registration_name = "%s.%s" % (name, type_) + addr = "2606:2800:220:1:248:1893:25c8:1946" # example.com + + zeroconf_registrar = Zeroconf(ip_version=r.IPVersion.V6Only) + desc = {'path': '/~paulsm/'} + info = ServiceInfo( + type_, + registration_name, + 80, + 0, + 0, + desc, + "ash-2.local.", + addresses=[socket.inet_pton(socket.AF_INET6, addr)], + ) + zeroconf_registrar.registry.async_add(info) + try: + with patch.object( + zeroconf_registrar.engine.protocols[0], "suppress_duplicate_packet", return_value=False + ), patch.object( + zeroconf_registrar.engine.protocols[1], "suppress_duplicate_packet", return_value=False + ): + service_types = ZeroconfServiceTypes.find(ip_version=r.IPVersion.V6Only, timeout=2) + assert type_ in service_types + _clear_cache(zeroconf_registrar) + service_types = ZeroconfServiceTypes.find(zc=zeroconf_registrar, timeout=2) + assert type_ in service_types + + finally: + zeroconf_registrar.close() + + def test_integration_with_subtype_and_listener(self): + subtype_ = "_subtype._sub" + type_ = "_listen._tcp.local." + name = "xxxyyy" + # Note: discovery returns only DNS-SD type not subtype + discovery_type = "%s.%s" % (subtype_, type_) + registration_name = "%s.%s" % (name, type_) + + zeroconf_registrar = Zeroconf(interfaces=['127.0.0.1']) + desc = {'path': '/~paulsm/'} + info = ServiceInfo( + discovery_type, + registration_name, + 80, + 0, + 0, + desc, + "ash-2.local.", + addresses=[socket.inet_aton("10.0.1.2")], + ) + zeroconf_registrar.registry.async_add(info) + try: + with patch.object( + zeroconf_registrar.engine.protocols[0], "suppress_duplicate_packet", return_value=False + ), patch.object( + zeroconf_registrar.engine.protocols[1], "suppress_duplicate_packet", return_value=False + ): + service_types = ZeroconfServiceTypes.find(interfaces=['127.0.0.1'], timeout=2) + assert discovery_type in service_types + _clear_cache(zeroconf_registrar) + service_types = ZeroconfServiceTypes.find(zc=zeroconf_registrar, timeout=2) + assert discovery_type in service_types + + finally: + zeroconf_registrar.close() diff --git a/tests/test_asyncio.py b/tests/test_asyncio.py new file mode 100644 index 00000000..ea80d6f5 --- /dev/null +++ b/tests/test_asyncio.py @@ -0,0 +1,1124 @@ +#!/usr/bin/env python + + +"""Unit tests for aio.py.""" + +import asyncio +import logging +import os +import socket +import time +import threading +from unittest.mock import ANY, call, patch, MagicMock + + +import pytest + +from zeroconf.asyncio import AsyncServiceBrowser, AsyncServiceInfo, AsyncZeroconf, AsyncZeroconfServiceTypes +from zeroconf import ( + DNSIncoming, + DNSOutgoing, + DNSQuestion, + DNSPointer, + DNSService, + DNSAddress, + DNSText, + ServiceStateChange, + Zeroconf, + const, +) +from zeroconf.const import _LISTENER_TIME +from zeroconf._core import AsyncListener +from zeroconf._exceptions import BadTypeInNameException, NonUniqueNameException, ServiceNameAlreadyRegistered +from zeroconf._services import ServiceListener +import zeroconf._services.browser as _services_browser +from zeroconf._services.info import ServiceInfo +from zeroconf._utils.time import current_time_millis + +from . import _clear_cache, has_working_ipv6 + +log = logging.getLogger('zeroconf') +original_logging_level = logging.NOTSET + + +def setup_module(): + global original_logging_level + original_logging_level = log.level + log.setLevel(logging.DEBUG) + + +def teardown_module(): + if original_logging_level != logging.NOTSET: + log.setLevel(original_logging_level) + + +@pytest.fixture(autouse=True) +def verify_threads_ended(): + """Verify that the threads are not running after the test.""" + threads_before = frozenset(threading.enumerate()) + yield + threads_after = frozenset(threading.enumerate()) + non_executor_threads = frozenset( + thread + for thread in threads_after + if "asyncio" not in thread.name and "ThreadPoolExecutor" not in thread.name + ) + threads = non_executor_threads - threads_before + assert not threads + + +@pytest.mark.asyncio +async def test_async_basic_usage() -> None: + """Test we can create and close the instance.""" + aiozc = AsyncZeroconf(interfaces=['127.0.0.1']) + await aiozc.async_close() + + +@pytest.mark.asyncio +async def test_async_close_twice() -> None: + """Test we can close twice.""" + aiozc = AsyncZeroconf(interfaces=['127.0.0.1']) + await aiozc.async_close() + await aiozc.async_close() + + +@pytest.mark.asyncio +async def test_async_with_sync_passed_in() -> None: + """Test we can create and close the instance when passing in a sync Zeroconf.""" + zc = Zeroconf(interfaces=['127.0.0.1']) + aiozc = AsyncZeroconf(zc=zc) + assert aiozc.zeroconf is zc + await aiozc.async_close() + + +@pytest.mark.asyncio +async def test_async_with_sync_passed_in_closed_in_async() -> None: + """Test caller closes the sync version in async.""" + zc = Zeroconf(interfaces=['127.0.0.1']) + aiozc = AsyncZeroconf(zc=zc) + assert aiozc.zeroconf is zc + zc.close() + await aiozc.async_close() + + +@pytest.mark.asyncio +async def test_sync_within_event_loop_executor() -> None: + """Test sync version still works from an executor within an event loop.""" + + def sync_code(): + zc = Zeroconf(interfaces=['127.0.0.1']) + assert zc.get_service_info("_neverused._tcp.local.", "xneverused._neverused._tcp.local.", 10) is None + zc.close() + + await asyncio.get_event_loop().run_in_executor(None, sync_code) + + +@pytest.mark.asyncio +async def test_async_service_registration() -> None: + """Test registering services broadcasts the registration by default.""" + aiozc = AsyncZeroconf(interfaces=['127.0.0.1']) + type_ = "_test1-srvc-type._tcp.local." + name = "xxxyyy" + registration_name = f"{name}.{type_}" + + calls = [] + + class MyListener(ServiceListener): + def add_service(self, zeroconf: Zeroconf, type: str, name: str) -> None: + calls.append(("add", type, name)) + + def remove_service(self, zeroconf: Zeroconf, type: str, name: str) -> None: + calls.append(("remove", type, name)) + + def update_service(self, zeroconf: Zeroconf, type: str, name: str) -> None: + calls.append(("update", type, name)) + + listener = MyListener() + + aiozc.zeroconf.add_service_listener(type_, listener) + + desc = {'path': '/~paulsm/'} + info = ServiceInfo( + type_, + registration_name, + 80, + 0, + 0, + desc, + "ash-2.local.", + addresses=[socket.inet_aton("10.0.1.2")], + ) + task = await aiozc.async_register_service(info) + await task + new_info = ServiceInfo( + type_, + registration_name, + 80, + 0, + 0, + desc, + "ash-2.local.", + addresses=[socket.inet_aton("10.0.1.3")], + ) + task = await aiozc.async_update_service(new_info) + await task + task = await aiozc.async_unregister_service(new_info) + await task + await aiozc.async_close() + + assert calls == [ + ('add', type_, registration_name), + ('update', type_, registration_name), + ('remove', type_, registration_name), + ] + + +@pytest.mark.asyncio +async def test_async_service_registration_same_server_different_ports() -> None: + """Test registering services with the same server with different srv records.""" + aiozc = AsyncZeroconf(interfaces=['127.0.0.1']) + type_ = "_test1-srvc-type._tcp.local." + name = "xxxyyy" + name2 = "xxxyyy2" + + registration_name = f"{name}.{type_}" + registration_name2 = f"{name2}.{type_}" + + calls = [] + + class MyListener(ServiceListener): + def add_service(self, zeroconf: Zeroconf, type: str, name: str) -> None: + calls.append(("add", type, name)) + + def remove_service(self, zeroconf: Zeroconf, type: str, name: str) -> None: + calls.append(("remove", type, name)) + + def update_service(self, zeroconf: Zeroconf, type: str, name: str) -> None: + calls.append(("update", type, name)) + + listener = MyListener() + + aiozc.zeroconf.add_service_listener(type_, listener) + + desc = {'path': '/~paulsm/'} + info = ServiceInfo( + type_, + registration_name, + 80, + 0, + 0, + desc, + "ash-2.local.", + addresses=[socket.inet_aton("10.0.1.2")], + ) + info2 = ServiceInfo( + type_, + registration_name2, + 81, + 0, + 0, + desc, + "ash-2.local.", + addresses=[socket.inet_aton("10.0.1.2")], + ) + tasks = [] + tasks.append(await aiozc.async_register_service(info)) + tasks.append(await aiozc.async_register_service(info2)) + await asyncio.gather(*tasks) + + task = await aiozc.async_unregister_service(info) + await task + entries = aiozc.zeroconf.cache.async_entries_with_server("ash-2.local.") + assert len(entries) == 1 + assert info2.dns_service() in entries + await aiozc.async_close() + assert calls == [ + ('add', type_, registration_name), + ('add', type_, registration_name2), + ('remove', type_, registration_name), + ('remove', type_, registration_name2), + ] + + +@pytest.mark.asyncio +async def test_async_service_registration_same_server_same_ports() -> None: + """Test registering services with the same server with the exact same srv record.""" + aiozc = AsyncZeroconf(interfaces=['127.0.0.1']) + type_ = "_test1-srvc-type._tcp.local." + name = "xxxyyy" + name2 = "xxxyyy2" + + registration_name = f"{name}.{type_}" + registration_name2 = f"{name2}.{type_}" + + calls = [] + + class MyListener(ServiceListener): + def add_service(self, zeroconf: Zeroconf, type: str, name: str) -> None: + calls.append(("add", type, name)) + + def remove_service(self, zeroconf: Zeroconf, type: str, name: str) -> None: + calls.append(("remove", type, name)) + + def update_service(self, zeroconf: Zeroconf, type: str, name: str) -> None: + calls.append(("update", type, name)) + + listener = MyListener() + + aiozc.zeroconf.add_service_listener(type_, listener) + + desc = {'path': '/~paulsm/'} + info = ServiceInfo( + type_, + registration_name, + 80, + 0, + 0, + desc, + "ash-2.local.", + addresses=[socket.inet_aton("10.0.1.2")], + ) + info2 = ServiceInfo( + type_, + registration_name2, + 80, + 0, + 0, + desc, + "ash-2.local.", + addresses=[socket.inet_aton("10.0.1.2")], + ) + tasks = [] + tasks.append(await aiozc.async_register_service(info)) + tasks.append(await aiozc.async_register_service(info2)) + await asyncio.gather(*tasks) + + task = await aiozc.async_unregister_service(info) + await task + entries = aiozc.zeroconf.cache.async_entries_with_server("ash-2.local.") + assert len(entries) == 1 + assert info2.dns_service() in entries + await aiozc.async_close() + assert calls == [ + ('add', type_, registration_name), + ('add', type_, registration_name2), + ('remove', type_, registration_name), + ('remove', type_, registration_name2), + ] + + +@pytest.mark.asyncio +async def test_async_service_registration_name_conflict() -> None: + """Test registering services throws on name conflict.""" + aiozc = AsyncZeroconf(interfaces=['127.0.0.1']) + type_ = "_test-srvc2-type._tcp.local." + name = "xxxyyy" + registration_name = f"{name}.{type_}" + + desc = {'path': '/~paulsm/'} + info = ServiceInfo( + type_, + registration_name, + 80, + 0, + 0, + desc, + "ash-2.local.", + addresses=[socket.inet_aton("10.0.1.2")], + ) + task = await aiozc.async_register_service(info) + await task + + with pytest.raises(NonUniqueNameException): + task = await aiozc.async_register_service(info) + await task + + with pytest.raises(ServiceNameAlreadyRegistered): + task = await aiozc.async_register_service(info, cooperating_responders=True) + await task + + conflicting_info = ServiceInfo( + type_, + registration_name, + 80, + 0, + 0, + desc, + "ash-3.local.", + addresses=[socket.inet_aton("10.0.1.3")], + ) + + with pytest.raises(NonUniqueNameException): + task = await aiozc.async_register_service(conflicting_info) + await task + + await aiozc.async_close() + + +@pytest.mark.asyncio +async def test_async_service_registration_name_does_not_match_type() -> None: + """Test registering services throws when the name does not match the type.""" + aiozc = AsyncZeroconf(interfaces=['127.0.0.1']) + type_ = "_test-srvc3-type._tcp.local." + name = "xxxyyy" + registration_name = f"{name}.{type_}" + + desc = {'path': '/~paulsm/'} + info = ServiceInfo( + type_, + registration_name, + 80, + 0, + 0, + desc, + "ash-2.local.", + addresses=[socket.inet_aton("10.0.1.2")], + ) + info.type = "_wrong._tcp.local." + with pytest.raises(BadTypeInNameException): + task = await aiozc.async_register_service(info) + await task + await aiozc.async_close() + + +@pytest.mark.asyncio +async def test_async_tasks() -> None: + """Test awaiting broadcast tasks""" + + aiozc = AsyncZeroconf(interfaces=['127.0.0.1']) + type_ = "_test-srvc4-type._tcp.local." + name = "xxxyyy" + registration_name = f"{name}.{type_}" + + calls = [] + + class MyListener(ServiceListener): + def add_service(self, zeroconf: Zeroconf, type: str, name: str) -> None: + calls.append(("add", type, name)) + + def remove_service(self, zeroconf: Zeroconf, type: str, name: str) -> None: + calls.append(("remove", type, name)) + + def update_service(self, zeroconf: Zeroconf, type: str, name: str) -> None: + calls.append(("update", type, name)) + + listener = MyListener() + aiozc.zeroconf.add_service_listener(type_, listener) + + desc = {'path': '/~paulsm/'} + info = ServiceInfo( + type_, + registration_name, + 80, + 0, + 0, + desc, + "ash-2.local.", + addresses=[socket.inet_aton("10.0.1.2")], + ) + task = await aiozc.async_register_service(info) + assert isinstance(task, asyncio.Task) + await task + + new_info = ServiceInfo( + type_, + registration_name, + 80, + 0, + 0, + desc, + "ash-2.local.", + addresses=[socket.inet_aton("10.0.1.3")], + ) + task = await aiozc.async_update_service(new_info) + assert isinstance(task, asyncio.Task) + await task + + task = await aiozc.async_unregister_service(new_info) + assert isinstance(task, asyncio.Task) + await task + + await aiozc.async_close() + + assert calls == [ + ('add', type_, registration_name), + ('update', type_, registration_name), + ('remove', type_, registration_name), + ] + + +@pytest.mark.asyncio +async def test_async_wait_unblocks_on_update() -> None: + """Test async_wait will unblock on update.""" + + aiozc = AsyncZeroconf(interfaces=['127.0.0.1']) + type_ = "_test-srvc4-type._tcp.local." + name = "xxxyyy" + registration_name = f"{name}.{type_}" + + desc = {'path': '/~paulsm/'} + info = ServiceInfo( + type_, + registration_name, + 80, + 0, + 0, + desc, + "ash-2.local.", + addresses=[socket.inet_aton("10.0.1.2")], + ) + task = await aiozc.async_register_service(info) + + # Should unblock due to update from the + # registration + now = current_time_millis() + await aiozc.zeroconf.async_wait(50000) + assert current_time_millis() - now < 3000 + await task + + now = current_time_millis() + await aiozc.zeroconf.async_wait(50) + assert current_time_millis() - now < 1000 + + await aiozc.async_close() + + +@pytest.mark.asyncio +async def test_service_info_async_request() -> None: + """Test registering services broadcasts and query with AsyncServceInfo.async_request.""" + if not has_working_ipv6() or os.environ.get('SKIP_IPV6'): + pytest.skip('Requires IPv6') + + aiozc = AsyncZeroconf(interfaces=['127.0.0.1']) + type_ = "_test1-srvc-type._tcp.local." + name = "xxxyyy" + name2 = "abc" + registration_name = f"{name}.{type_}" + registration_name2 = f"{name2}.{type_}" + + # Start a tasks BEFORE the registration that will keep trying + # and see the registration a bit later + get_service_info_task1 = asyncio.ensure_future(aiozc.async_get_service_info(type_, registration_name)) + await asyncio.sleep(_LISTENER_TIME / 1000 / 2) + get_service_info_task2 = asyncio.ensure_future(aiozc.async_get_service_info(type_, registration_name)) + + desc = {'path': '/~paulsm/'} + info = ServiceInfo( + type_, + registration_name, + 80, + 0, + 0, + desc, + "ash-1.local.", + addresses=[socket.inet_aton("10.0.1.2")], + ) + info2 = ServiceInfo( + type_, + registration_name2, + 80, + 0, + 0, + desc, + "ash-5.local.", + addresses=[socket.inet_aton("10.0.1.5")], + ) + tasks = [] + tasks.append(await aiozc.async_register_service(info)) + tasks.append(await aiozc.async_register_service(info2)) + await asyncio.gather(*tasks) + + aiosinfo = await get_service_info_task1 + assert aiosinfo is not None + assert aiosinfo.addresses == [socket.inet_aton("10.0.1.2")] + + aiosinfo = await get_service_info_task2 + assert aiosinfo is not None + assert aiosinfo.addresses == [socket.inet_aton("10.0.1.2")] + + aiosinfo = await aiozc.async_get_service_info(type_, registration_name) + assert aiosinfo is not None + assert aiosinfo.addresses == [socket.inet_aton("10.0.1.2")] + + new_info = ServiceInfo( + type_, + registration_name, + 80, + 0, + 0, + desc, + "ash-2.local.", + addresses=[socket.inet_aton("10.0.1.3"), socket.inet_pton(socket.AF_INET6, "6001:db8::1")], + ) + + task = await aiozc.async_update_service(new_info) + await task + + aiosinfo = await aiozc.async_get_service_info(type_, registration_name) + assert aiosinfo is not None + assert aiosinfo.addresses == [socket.inet_aton("10.0.1.3")] + + aiosinfos = await asyncio.gather( + aiozc.async_get_service_info(type_, registration_name), + aiozc.async_get_service_info(type_, registration_name2), + ) + assert aiosinfos[0] is not None + assert aiosinfos[0].addresses == [socket.inet_aton("10.0.1.3")] + assert aiosinfos[1] is not None + assert aiosinfos[1].addresses == [socket.inet_aton("10.0.1.5")] + + aiosinfo = AsyncServiceInfo(type_, registration_name) + _clear_cache(aiozc.zeroconf) + # Generating the race condition is almost impossible + # without patching since its a TOCTOU race + with patch("zeroconf.asyncio.AsyncServiceInfo._is_complete", False): + await aiosinfo.async_request(aiozc.zeroconf, 3000) + assert aiosinfo is not None + assert aiosinfo.addresses == [socket.inet_aton("10.0.1.3")] + + task = await aiozc.async_unregister_service(new_info) + await task + + aiosinfo = await aiozc.async_get_service_info(type_, registration_name) + assert aiosinfo is None + + await aiozc.async_close() + + +@pytest.mark.asyncio +async def test_async_service_browser() -> None: + """Test AsyncServiceBrowser.""" + aiozc = AsyncZeroconf(interfaces=['127.0.0.1']) + type_ = "_test9-srvc-type._tcp.local." + name = "xxxyyy" + registration_name = f"{name}.{type_}" + + calls = [] + + class MyListener(ServiceListener): + def add_service(self, aiozc: AsyncZeroconf, type: str, name: str) -> None: + calls.append(("add", type, name)) + + def remove_service(self, aiozc: AsyncZeroconf, type: str, name: str) -> None: + calls.append(("remove", type, name)) + + def update_service(self, aiozc: AsyncZeroconf, type: str, name: str) -> None: + calls.append(("update", type, name)) + + listener = MyListener() + await aiozc.async_add_service_listener(type_, listener) + + desc = {'path': '/~paulsm/'} + info = ServiceInfo( + type_, + registration_name, + 80, + 0, + 0, + desc, + "ash-2.local.", + addresses=[socket.inet_aton("10.0.1.2")], + ) + task = await aiozc.async_register_service(info) + await task + new_info = ServiceInfo( + type_, + registration_name, + 80, + 0, + 0, + desc, + "ash-2.local.", + addresses=[socket.inet_aton("10.0.1.3")], + ) + task = await aiozc.async_update_service(new_info) + await task + task = await aiozc.async_unregister_service(new_info) + await task + await aiozc.zeroconf.async_wait(1) + await aiozc.async_close() + + assert calls == [ + ('add', type_, registration_name), + ('update', type_, registration_name), + ('remove', type_, registration_name), + ] + + +@pytest.mark.asyncio +async def test_async_context_manager() -> None: + """Test using an async context manager.""" + type_ = "_test10-sr-type._tcp.local." + name = "xxxyyy" + registration_name = f"{name}.{type_}" + + async with AsyncZeroconf(interfaces=['127.0.0.1']) as aiozc: + info = ServiceInfo( + type_, + registration_name, + 80, + 0, + 0, + {'path': '/~paulsm/'}, + "ash-2.local.", + addresses=[socket.inet_aton("10.0.1.2")], + ) + task = await aiozc.async_register_service(info) + await task + aiosinfo = await aiozc.async_get_service_info(type_, registration_name) + assert aiosinfo is not None + + +@pytest.mark.asyncio +async def test_async_unregister_all_services() -> None: + """Test unregistering all services.""" + aiozc = AsyncZeroconf(interfaces=['127.0.0.1']) + type_ = "_test1-srvc-type._tcp.local." + name = "xxxyyy" + name2 = "abc" + registration_name = f"{name}.{type_}" + registration_name2 = f"{name2}.{type_}" + + desc = {'path': '/~paulsm/'} + info = ServiceInfo( + type_, + registration_name, + 80, + 0, + 0, + desc, + "ash-1.local.", + addresses=[socket.inet_aton("10.0.1.2")], + ) + info2 = ServiceInfo( + type_, + registration_name2, + 80, + 0, + 0, + desc, + "ash-5.local.", + addresses=[socket.inet_aton("10.0.1.5")], + ) + tasks = [] + tasks.append(await aiozc.async_register_service(info)) + tasks.append(await aiozc.async_register_service(info2)) + await asyncio.gather(*tasks) + + tasks = [] + tasks.append(aiozc.async_get_service_info(type_, registration_name)) + tasks.append(aiozc.async_get_service_info(type_, registration_name2)) + results = await asyncio.gather(*tasks) + assert results[0] is not None + assert results[1] is not None + + await aiozc.async_unregister_all_services() + _clear_cache(aiozc.zeroconf) + + tasks = [] + tasks.append(aiozc.async_get_service_info(type_, registration_name)) + tasks.append(aiozc.async_get_service_info(type_, registration_name2)) + results = await asyncio.gather(*tasks) + assert results[0] is None + assert results[1] is None + + # Verify we can call again + await aiozc.async_unregister_all_services() + + await aiozc.async_close() + + +@pytest.mark.asyncio +async def test_async_zeroconf_service_types(): + type_ = "_test-srvc-type._tcp.local." + name = "xxxyyy" + registration_name = f"{name}.{type_}" + + zeroconf_registrar = AsyncZeroconf(interfaces=['127.0.0.1']) + desc = {'path': '/~paulsm/'} + info = ServiceInfo( + type_, + registration_name, + 80, + 0, + 0, + desc, + "ash-2.local.", + addresses=[socket.inet_aton("10.0.1.2")], + ) + task = await zeroconf_registrar.async_register_service(info) + await task + # Ensure we do not clear the cache until after the last broadcast is processed + await asyncio.sleep(0.2) + _clear_cache(zeroconf_registrar.zeroconf) + try: + service_types = await AsyncZeroconfServiceTypes.async_find(interfaces=['127.0.0.1'], timeout=2) + assert type_ in service_types + _clear_cache(zeroconf_registrar.zeroconf) + service_types = await AsyncZeroconfServiceTypes.async_find(aiozc=zeroconf_registrar, timeout=2) + assert type_ in service_types + + finally: + await zeroconf_registrar.async_close() + + +@pytest.mark.asyncio +async def test_guard_against_running_serviceinfo_request_event_loop() -> None: + """Test that running ServiceInfo.request from the event loop throws.""" + aiozc = AsyncZeroconf(interfaces=['127.0.0.1']) + + service_info = AsyncServiceInfo("_hap._tcp.local.", "doesnotmatter._hap._tcp.local.") + with pytest.raises(RuntimeError): + service_info.request(aiozc.zeroconf, 3000) + await aiozc.async_close() + + +@pytest.mark.asyncio +async def test_service_browser_instantiation_generates_add_events_from_cache(): + """Test that the ServiceBrowser will generate Add events with the existing cache when starting.""" + + # instantiate a zeroconf instance + aiozc = AsyncZeroconf(interfaces=['127.0.0.1']) + zc = aiozc.zeroconf + type_ = "_hap._tcp.local." + registration_name = "xxxyyy.%s" % type_ + callbacks = [] + + class MyServiceListener(ServiceListener): + def add_service(self, zc, type_, name) -> None: + nonlocal callbacks + if name == registration_name: + callbacks.append(("add", type_, name)) + + def remove_service(self, zc, type_, name) -> None: + nonlocal callbacks + if name == registration_name: + callbacks.append(("remove", type_, name)) + + def update_service(self, zc, type_, name) -> None: + nonlocal callbacks + if name == registration_name: + callbacks.append(("update", type_, name)) + + listener = MyServiceListener() + + desc = {'path': '/~paulsm/'} + address_parsed = "10.0.1.2" + address = socket.inet_aton(address_parsed) + info = ServiceInfo(type_, registration_name, 80, 0, 0, desc, "ash-2.local.", addresses=[address]) + zc.cache.async_add_records( + [info.dns_pointer(), info.dns_service(), *info.dns_addresses(), info.dns_text()] + ) + + browser = AsyncServiceBrowser(zc, type_, None, listener) + + await asyncio.sleep(0) + + assert callbacks == [ + ('add', type_, registration_name), + ] + await browser.async_cancel() + + await aiozc.async_close() + + +@pytest.mark.asyncio +async def test_integration(): + service_added = asyncio.Event() + service_removed = asyncio.Event() + unexpected_ttl = asyncio.Event() + got_query = asyncio.Event() + + type_ = "_http._tcp.local." + registration_name = "xxxyyy.%s" % type_ + + def on_service_state_change(zeroconf, service_type, state_change, name): + if name == registration_name: + if state_change is ServiceStateChange.Added: + service_added.set() + elif state_change is ServiceStateChange.Removed: + service_removed.set() + + aiozc = AsyncZeroconf(interfaces=['127.0.0.1']) + zeroconf_browser = aiozc.zeroconf + await zeroconf_browser.async_wait_for_start() + + # we are going to patch the zeroconf send to check packet sizes + old_send = zeroconf_browser.async_send + + time_offset = 0.0 + + def _new_current_time_millis(): + """Current system time in milliseconds""" + return (time.time() * 1000) + (time_offset * 1000) + + expected_ttl = const._DNS_HOST_TTL + nbr_answers = 0 + + def send(out, addr=const._MDNS_ADDR, port=const._MDNS_PORT, v6_flow_scope=()): + """Sends an outgoing packet.""" + pout = DNSIncoming(out.packets()[0]) + nonlocal nbr_answers + for answer in pout.answers: + nbr_answers += 1 + if not answer.ttl > expected_ttl / 2: + unexpected_ttl.set() + + got_query.set() + + old_send(out, addr=addr, port=port, v6_flow_scope=v6_flow_scope) + + assert len(zeroconf_browser.engine.protocols) == 2 + + aio_zeroconf_registrar = AsyncZeroconf(interfaces=['127.0.0.1']) + zeroconf_registrar = aio_zeroconf_registrar.zeroconf + await aio_zeroconf_registrar.zeroconf.async_wait_for_start() + + assert len(zeroconf_registrar.engine.protocols) == 2 + # patch the zeroconf send + # patch the zeroconf current_time_millis + # patch the backoff limit to ensure we always get one query every 1/4 of the DNS TTL + # Disable duplicate question suppression and duplicate packet suppression for this test as it works + # by asking the same question over and over + with patch.object( + zeroconf_registrar.engine.protocols[0], "suppress_duplicate_packet", return_value=False + ), patch.object( + zeroconf_registrar.engine.protocols[1], "suppress_duplicate_packet", return_value=False + ), patch.object( + zeroconf_browser.engine.protocols[0], "suppress_duplicate_packet", return_value=False + ), patch.object( + zeroconf_browser.engine.protocols[1], "suppress_duplicate_packet", return_value=False + ), patch.object( + zeroconf_browser.question_history, "suppresses", return_value=False + ), patch.object( + zeroconf_browser, "async_send", send + ), patch( + "zeroconf._services.browser.current_time_millis", _new_current_time_millis + ), patch.object( + _services_browser, "_BROWSER_BACKOFF_LIMIT", int(expected_ttl / 4) + ): + service_added = asyncio.Event() + service_removed = asyncio.Event() + + browser = AsyncServiceBrowser(zeroconf_browser, type_, [on_service_state_change]) + + desc = {'path': '/~paulsm/'} + info = ServiceInfo( + type_, registration_name, 80, 0, 0, desc, "ash-2.local.", addresses=[socket.inet_aton("10.0.1.2")] + ) + task = await aio_zeroconf_registrar.async_register_service(info) + await task + + try: + await asyncio.wait_for(service_added.wait(), 1) + assert service_added.is_set() + + # Test that we receive queries containing answers only if the remaining TTL + # is greater than half the original TTL + sleep_count = 0 + test_iterations = 50 + + while nbr_answers < test_iterations: + # Increase simulated time shift by 1/4 of the TTL in seconds + time_offset += expected_ttl / 4 + now = _new_current_time_millis() + browser.reschedule_type(type_, now) + sleep_count += 1 + await asyncio.wait_for(got_query.wait(), 1) + got_query.clear() + # Prevent the test running indefinitely in an error condition + assert sleep_count < test_iterations * 4 + assert not unexpected_ttl.is_set() + # Don't remove service, allow close() to cleanup + finally: + await aio_zeroconf_registrar.async_close() + await asyncio.wait_for(service_removed.wait(), 1) + assert service_removed.is_set() + await browser.async_cancel() + await aiozc.async_close() + + +@pytest.mark.asyncio +async def test_info_asking_default_is_asking_qm_questions_after_the_first_qu(): + """Verify the service info first question is QU and subsequent ones are QM questions.""" + type_ = "_quservice._tcp.local." + aiozc = AsyncZeroconf(interfaces=['127.0.0.1']) + zeroconf_info = aiozc.zeroconf + + name = "xxxyyy" + registration_name = f"{name}.{type_}" + + desc = {'path': '/~paulsm/'} + info = ServiceInfo( + type_, registration_name, 80, 0, 0, desc, "ash-2.local.", addresses=[socket.inet_aton("10.0.1.2")] + ) + + zeroconf_info.registry.async_add(info) + + # we are going to patch the zeroconf send to check query transmission + old_send = zeroconf_info.async_send + + first_outgoing = None + second_outgoing = None + + def send(out, addr=const._MDNS_ADDR, port=const._MDNS_PORT): + """Sends an outgoing packet.""" + nonlocal first_outgoing + nonlocal second_outgoing + if out.questions: + if first_outgoing is not None and second_outgoing is None: + second_outgoing = out + if first_outgoing is None: + first_outgoing = out + old_send(out, addr=addr, port=port) + + # patch the zeroconf send + with patch.object(zeroconf_info, "async_send", send): + aiosinfo = AsyncServiceInfo(type_, registration_name) + # Patch _is_complete so we send multiple times + with patch("zeroconf.asyncio.AsyncServiceInfo._is_complete", False): + await aiosinfo.async_request(aiozc.zeroconf, 1200) + try: + assert first_outgoing.questions[0].unicast == True + assert second_outgoing.questions[0].unicast == False + finally: + await aiozc.async_close() + + +@pytest.mark.asyncio +async def test_service_browser_ignores_unrelated_updates(): + """Test that the ServiceBrowser ignores unrelated updates.""" + + # instantiate a zeroconf instance + aiozc = AsyncZeroconf(interfaces=['127.0.0.1']) + zc = aiozc.zeroconf + type_ = "_veryuniqueone._tcp.local." + registration_name = "xxxyyy.%s" % type_ + callbacks = [] + + class MyServiceListener(ServiceListener): + def add_service(self, zc, type_, name) -> None: + nonlocal callbacks + if name == registration_name: + callbacks.append(("add", type_, name)) + + def remove_service(self, zc, type_, name) -> None: + nonlocal callbacks + if name == registration_name: + callbacks.append(("remove", type_, name)) + + def update_service(self, zc, type_, name) -> None: + nonlocal callbacks + if name == registration_name: + callbacks.append(("update", type_, name)) + + listener = MyServiceListener() + + desc = {'path': '/~paulsm/'} + address_parsed = "10.0.1.2" + address = socket.inet_aton(address_parsed) + info = ServiceInfo(type_, registration_name, 80, 0, 0, desc, "ash-2.local.", addresses=[address]) + zc.cache.async_add_records( + [ + info.dns_pointer(), + info.dns_service(), + *info.dns_addresses(), + info.dns_text(), + DNSService( + "zoom._unrelated._tcp.local.", + const._TYPE_SRV, + const._CLASS_IN, + const._DNS_HOST_TTL, + 0, + 0, + 81, + 'unrelated.local.', + ), + ] + ) + + browser = AsyncServiceBrowser(zc, type_, None, listener) + + generated = DNSOutgoing(const._FLAGS_QR_RESPONSE) + generated.add_answer_at_time( + DNSPointer( + "_unrelated._tcp.local.", + const._TYPE_PTR, + const._CLASS_IN, + const._DNS_OTHER_TTL, + "zoom._unrelated._tcp.local.", + ), + 0, + ) + generated.add_answer_at_time( + DNSAddress("unrelated.local.", const._TYPE_A, const._CLASS_IN, const._DNS_HOST_TTL, b"1234"), + 0, + ) + generated.add_answer_at_time( + DNSText( + "zoom._unrelated._tcp.local.", + const._TYPE_TXT, + const._CLASS_IN | const._CLASS_UNIQUE, + const._DNS_OTHER_TTL, + b"zoom", + ), + 0, + ) + + zc.handle_response(DNSIncoming(generated.packets()[0])) + + await browser.async_cancel() + await asyncio.sleep(0) + + assert callbacks == [ + ('add', type_, registration_name), + ] + await aiozc.async_close() + + +@pytest.mark.asyncio +async def test_async_request_timeout(): + """Test that the timeout does not throw an exception and finishes close to the actual timeout.""" + aiozc = AsyncZeroconf(interfaces=['127.0.0.1']) + await aiozc.zeroconf.async_wait_for_start() + start_time = current_time_millis() + assert await aiozc.async_get_service_info("_notfound.local.", "notthere._notfound.local.") is None + end_time = current_time_millis() + await aiozc.async_close() + # 3000ms for the default timeout + # 1000ms for loaded systems + schedule overhead + assert (end_time - start_time) < 3000 + 1000 + + +@pytest.mark.asyncio +async def test_legacy_unicast_response(run_isolated): + """Verify legacy unicast responses include questions and correct id.""" + type_ = "_mservice._tcp.local." + aiozc = AsyncZeroconf(interfaces=['127.0.0.1']) + await aiozc.zeroconf.async_wait_for_start() + + name = "xxxyyy" + registration_name = f"{name}.{type_}" + + desc = {'path': '/~paulsm/'} + info = ServiceInfo( + type_, registration_name, 80, 0, 0, desc, "ash-2.local.", addresses=[socket.inet_aton("10.0.1.2")] + ) + + aiozc.zeroconf.registry.async_add(info) + query = DNSOutgoing(const._FLAGS_QR_QUERY, multicast=False, id_=888) + question = DNSQuestion(info.type, const._TYPE_PTR, const._CLASS_IN) + query.add_question(question) + protocol = aiozc.zeroconf.engine.protocols[0] + + with patch.object(aiozc.zeroconf, "async_send") as send_mock: + protocol.datagram_received(query.packets()[0], ('127.0.0.1', 6503)) + + calls = send_mock.mock_calls + # Verify the response is sent back on the socket it was recieved from + assert calls == [call(ANY, '127.0.0.1', 6503, (), protocol.transport)] + outgoing = send_mock.call_args[0][0] + assert isinstance(outgoing, DNSOutgoing) + assert outgoing.questions == [question] + assert outgoing.id == query.id + await aiozc.async_close() diff --git a/tests/test_cache.py b/tests/test_cache.py new file mode 100644 index 00000000..559b4357 --- /dev/null +++ b/tests/test_cache.py @@ -0,0 +1,200 @@ +#!/usr/bin/env python + + +""" Unit tests for zeroconf._cache. """ + +import logging +import unittest +import unittest.mock + +import zeroconf as r +from zeroconf import const + +log = logging.getLogger('zeroconf') +original_logging_level = logging.NOTSET + + +def setup_module(): + global original_logging_level + original_logging_level = log.level + log.setLevel(logging.DEBUG) + + +def teardown_module(): + if original_logging_level != logging.NOTSET: + log.setLevel(original_logging_level) + + +class TestDNSCache(unittest.TestCase): + def test_order(self): + record1 = r.DNSAddress('a', const._TYPE_SOA, const._CLASS_IN, 1, b'a') + record2 = r.DNSAddress('a', const._TYPE_SOA, const._CLASS_IN, 1, b'b') + cache = r.DNSCache() + cache.async_add_records([record1, record2]) + entry = r.DNSEntry('a', const._TYPE_SOA, const._CLASS_IN) + cached_record = cache.get(entry) + assert cached_record == record2 + + def test_adding_same_record_to_cache_different_ttls(self): + """We should always get back the last entry we added if there are different TTLs. + + This ensures we only have one source of truth for TTLs as a record cannot + be both expired and not expired. + """ + record1 = r.DNSAddress('a', const._TYPE_A, const._CLASS_IN, 1, b'a') + record2 = r.DNSAddress('a', const._TYPE_A, const._CLASS_IN, 10, b'a') + cache = r.DNSCache() + cache.async_add_records([record1, record2]) + entry = r.DNSEntry(record2) + cached_record = cache.get(entry) + assert cached_record == record2 + + def test_adding_same_record_to_cache_different_ttls(self): + """Verify we only get one record back. + + The last record added should replace the previous since two + records with different ttls are __eq__. This ensures we + only have one source of truth for TTLs as a record cannot + be both expired and not expired. + """ + record1 = r.DNSAddress('a', const._TYPE_A, const._CLASS_IN, 1, b'a') + record2 = r.DNSAddress('a', const._TYPE_A, const._CLASS_IN, 10, b'a') + cache = r.DNSCache() + cache.async_add_records([record1, record2]) + cached_records = cache.get_all_by_details('a', const._TYPE_A, const._CLASS_IN) + assert cached_records == [record2] + + def test_cache_empty_does_not_leak_memory_by_leaving_empty_list(self): + record1 = r.DNSAddress('a', const._TYPE_SOA, const._CLASS_IN, 1, b'a') + record2 = r.DNSAddress('a', const._TYPE_SOA, const._CLASS_IN, 1, b'b') + cache = r.DNSCache() + cache.async_add_records([record1, record2]) + assert 'a' in cache.cache + cache.async_remove_records([record1, record2]) + assert 'a' not in cache.cache + + def test_cache_empty_multiple_calls(self): + record1 = r.DNSAddress('a', const._TYPE_SOA, const._CLASS_IN, 1, b'a') + record2 = r.DNSAddress('a', const._TYPE_SOA, const._CLASS_IN, 1, b'b') + cache = r.DNSCache() + cache.async_add_records([record1, record2]) + assert 'a' in cache.cache + cache.async_remove_records([record1, record2]) + assert 'a' not in cache.cache + + +class TestDNSAsyncCacheAPI(unittest.TestCase): + def test_async_get_unique(self): + record1 = r.DNSAddress('a', const._TYPE_A, const._CLASS_IN, 1, b'a') + record2 = r.DNSAddress('a', const._TYPE_A, const._CLASS_IN, 1, b'b') + cache = r.DNSCache() + cache.async_add_records([record1, record2]) + assert cache.async_get_unique(record1) == record1 + assert cache.async_get_unique(record2) == record2 + + def test_async_all_by_details(self): + record1 = r.DNSAddress('a', const._TYPE_A, const._CLASS_IN, 1, b'a') + record2 = r.DNSAddress('a', const._TYPE_A, const._CLASS_IN, 1, b'b') + cache = r.DNSCache() + cache.async_add_records([record1, record2]) + assert set(cache.async_all_by_details('a', const._TYPE_A, const._CLASS_IN)) == {record1, record2} + + def test_async_entries_with_server(self): + record1 = r.DNSService( + 'irrelevant', const._TYPE_SRV, const._CLASS_IN, const._DNS_HOST_TTL, 0, 0, 85, 'ab' + ) + record2 = r.DNSService( + 'irrelevant', const._TYPE_SRV, const._CLASS_IN, const._DNS_HOST_TTL, 0, 0, 80, 'ab' + ) + cache = r.DNSCache() + cache.async_add_records([record1, record2]) + assert set(cache.async_entries_with_server('ab')) == {record1, record2} + assert set(cache.async_entries_with_server('AB')) == {record1, record2} + + def test_async_entries_with_name(self): + record1 = r.DNSService( + 'irrelevant', const._TYPE_SRV, const._CLASS_IN, const._DNS_HOST_TTL, 0, 0, 85, 'ab' + ) + record2 = r.DNSService( + 'irrelevant', const._TYPE_SRV, const._CLASS_IN, const._DNS_HOST_TTL, 0, 0, 80, 'ab' + ) + cache = r.DNSCache() + cache.async_add_records([record1, record2]) + assert set(cache.async_entries_with_name('irrelevant')) == {record1, record2} + assert set(cache.async_entries_with_name('Irrelevant')) == {record1, record2} + + +# These functions have been seen in other projects so +# we try to maintain a stable API for all the threadsafe getters +class TestDNSCacheAPI(unittest.TestCase): + def test_get(self): + record1 = r.DNSAddress('a', const._TYPE_A, const._CLASS_IN, 1, b'a') + record2 = r.DNSAddress('a', const._TYPE_A, const._CLASS_IN, 1, b'b') + record3 = r.DNSAddress('a', const._TYPE_AAAA, const._CLASS_IN, 1, b'ipv6') + cache = r.DNSCache() + cache.async_add_records([record1, record2, record3]) + assert cache.get(record1) == record1 + assert cache.get(record2) == record2 + assert cache.get(r.DNSEntry('a', const._TYPE_A, const._CLASS_IN)) == record2 + assert cache.get(r.DNSEntry('a', const._TYPE_AAAA, const._CLASS_IN)) == record3 + assert cache.get(r.DNSEntry('notthere', const._TYPE_A, const._CLASS_IN)) is None + + def test_get_by_details(self): + record1 = r.DNSAddress('a', const._TYPE_A, const._CLASS_IN, 1, b'a') + record2 = r.DNSAddress('a', const._TYPE_A, const._CLASS_IN, 1, b'b') + cache = r.DNSCache() + cache.async_add_records([record1, record2]) + assert cache.get_by_details('a', const._TYPE_A, const._CLASS_IN) == record2 + + def test_get_all_by_details(self): + record1 = r.DNSAddress('a', const._TYPE_A, const._CLASS_IN, 1, b'a') + record2 = r.DNSAddress('a', const._TYPE_A, const._CLASS_IN, 1, b'b') + cache = r.DNSCache() + cache.async_add_records([record1, record2]) + assert set(cache.get_all_by_details('a', const._TYPE_A, const._CLASS_IN)) == {record1, record2} + + def test_entries_with_server(self): + record1 = r.DNSService( + 'irrelevant', const._TYPE_SRV, const._CLASS_IN, const._DNS_HOST_TTL, 0, 0, 85, 'ab' + ) + record2 = r.DNSService( + 'irrelevant', const._TYPE_SRV, const._CLASS_IN, const._DNS_HOST_TTL, 0, 0, 80, 'ab' + ) + cache = r.DNSCache() + cache.async_add_records([record1, record2]) + assert set(cache.entries_with_server('ab')) == {record1, record2} + assert set(cache.entries_with_server('AB')) == {record1, record2} + + def test_entries_with_name(self): + record1 = r.DNSService( + 'irrelevant', const._TYPE_SRV, const._CLASS_IN, const._DNS_HOST_TTL, 0, 0, 85, 'ab' + ) + record2 = r.DNSService( + 'irrelevant', const._TYPE_SRV, const._CLASS_IN, const._DNS_HOST_TTL, 0, 0, 80, 'ab' + ) + cache = r.DNSCache() + cache.async_add_records([record1, record2]) + assert set(cache.entries_with_name('irrelevant')) == {record1, record2} + assert set(cache.entries_with_name('Irrelevant')) == {record1, record2} + + def test_current_entry_with_name_and_alias(self): + record1 = r.DNSPointer( + 'irrelevant', const._TYPE_PTR, const._CLASS_IN, const._DNS_OTHER_TTL, 'x.irrelevant' + ) + record2 = r.DNSPointer( + 'irrelevant', const._TYPE_PTR, const._CLASS_IN, const._DNS_OTHER_TTL, 'y.irrelevant' + ) + cache = r.DNSCache() + cache.async_add_records([record1, record2]) + assert cache.current_entry_with_name_and_alias('irrelevant', 'x.irrelevant') == record1 + + def test_name(self): + record1 = r.DNSService( + 'irrelevant', const._TYPE_SRV, const._CLASS_IN, const._DNS_HOST_TTL, 0, 0, 85, 'ab' + ) + record2 = r.DNSService( + 'irrelevant', const._TYPE_SRV, const._CLASS_IN, const._DNS_HOST_TTL, 0, 0, 80, 'ab' + ) + cache = r.DNSCache() + cache.async_add_records([record1, record2]) + assert cache.names() == ['irrelevant'] diff --git a/tests/test_core.py b/tests/test_core.py new file mode 100644 index 00000000..eab769be --- /dev/null +++ b/tests/test_core.py @@ -0,0 +1,800 @@ +#!/usr/bin/env python + + +""" Unit tests for zeroconf._core """ + +import asyncio +import itertools +import logging +import os +import pytest +import socket +import sys +import time +import threading +import unittest +import unittest.mock +from typing import cast +from unittest.mock import patch + +import zeroconf as r +from zeroconf import _core, const, Zeroconf, current_time_millis +from zeroconf.asyncio import AsyncZeroconf +from zeroconf._protocol import outgoing + +from . import has_working_ipv6, _clear_cache, _inject_response, _wait_for_start + +log = logging.getLogger('zeroconf') +original_logging_level = logging.NOTSET + + +def setup_module(): + global original_logging_level + original_logging_level = log.level + log.setLevel(logging.DEBUG) + + +def teardown_module(): + if original_logging_level != logging.NOTSET: + log.setLevel(original_logging_level) + + +def threadsafe_query(zc, protocol, *args): + async def make_query(): + protocol.handle_query_or_defer(*args) + + asyncio.run_coroutine_threadsafe(make_query(), zc.loop).result() + + +# This test uses asyncio because it needs to access the cache directly +# which is not threadsafe +@pytest.mark.asyncio +async def test_reaper(): + with patch.object(_core, "_CACHE_CLEANUP_INTERVAL", 10): + assert _core._CACHE_CLEANUP_INTERVAL == 10 + aiozc = AsyncZeroconf(interfaces=['127.0.0.1']) + zeroconf = aiozc.zeroconf + cache = zeroconf.cache + original_entries = list(itertools.chain(*(cache.entries_with_name(name) for name in cache.names()))) + record_with_10s_ttl = r.DNSAddress('a', const._TYPE_SOA, const._CLASS_IN, 10, b'a') + record_with_1s_ttl = r.DNSAddress('a', const._TYPE_SOA, const._CLASS_IN, 1, b'b') + zeroconf.cache.async_add_records([record_with_10s_ttl, record_with_1s_ttl]) + question = r.DNSQuestion("_hap._tcp._local.", const._TYPE_PTR, const._CLASS_IN) + now = r.current_time_millis() + other_known_answers = { + r.DNSPointer( + "_hap._tcp.local.", + const._TYPE_PTR, + const._CLASS_IN, + 10000, + 'known-to-other._hap._tcp.local.', + ) + } + zeroconf.question_history.add_question_at_time(question, now, other_known_answers) + assert zeroconf.question_history.suppresses(question, now, other_known_answers) + entries_with_cache = list(itertools.chain(*(cache.entries_with_name(name) for name in cache.names()))) + await asyncio.sleep(1.2) + entries = list(itertools.chain(*(cache.entries_with_name(name) for name in cache.names()))) + assert zeroconf.cache.get(record_with_1s_ttl) is None + await aiozc.async_close() + assert not zeroconf.question_history.suppresses(question, now, other_known_answers) + assert entries != original_entries + assert entries_with_cache != original_entries + assert record_with_10s_ttl in entries + assert record_with_1s_ttl not in entries + + +@pytest.mark.asyncio +async def test_reaper_aborts_when_done(): + """Ensure cache cleanup stops when zeroconf is done.""" + with patch.object(_core, "_CACHE_CLEANUP_INTERVAL", 10): + assert _core._CACHE_CLEANUP_INTERVAL == 10 + aiozc = AsyncZeroconf(interfaces=['127.0.0.1']) + zeroconf = aiozc.zeroconf + record_with_10s_ttl = r.DNSAddress('a', const._TYPE_SOA, const._CLASS_IN, 10, b'a') + record_with_1s_ttl = r.DNSAddress('a', const._TYPE_SOA, const._CLASS_IN, 1, b'b') + zeroconf.cache.async_add_records([record_with_10s_ttl, record_with_1s_ttl]) + assert zeroconf.cache.get(record_with_10s_ttl) is not None + assert zeroconf.cache.get(record_with_1s_ttl) is not None + await aiozc.async_close() + await asyncio.sleep(1.2) + assert zeroconf.cache.get(record_with_10s_ttl) is not None + assert zeroconf.cache.get(record_with_1s_ttl) is not None + + +class Framework(unittest.TestCase): + def test_launch_and_close(self): + rv = r.Zeroconf(interfaces=r.InterfaceChoice.All) + rv.close() + rv = r.Zeroconf(interfaces=r.InterfaceChoice.Default) + rv.close() + + def test_launch_and_close_context_manager(self): + with r.Zeroconf(interfaces=r.InterfaceChoice.All) as rv: + assert rv.done is False + assert rv.done is True + + with r.Zeroconf(interfaces=r.InterfaceChoice.Default) as rv: + assert rv.done is False + assert rv.done is True + + def test_launch_and_close_unicast(self): + rv = r.Zeroconf(interfaces=r.InterfaceChoice.All, unicast=True) + rv.close() + rv = r.Zeroconf(interfaces=r.InterfaceChoice.Default, unicast=True) + rv.close() + + def test_close_multiple_times(self): + rv = r.Zeroconf(interfaces=r.InterfaceChoice.Default) + rv.close() + rv.close() + + @unittest.skipIf(not has_working_ipv6(), 'Requires IPv6') + @unittest.skipIf(os.environ.get('SKIP_IPV6'), 'IPv6 tests disabled') + def test_launch_and_close_v4_v6(self): + rv = r.Zeroconf(interfaces=r.InterfaceChoice.All, ip_version=r.IPVersion.All) + rv.close() + rv = r.Zeroconf(interfaces=r.InterfaceChoice.Default, ip_version=r.IPVersion.All) + rv.close() + + @unittest.skipIf(not has_working_ipv6(), 'Requires IPv6') + @unittest.skipIf(os.environ.get('SKIP_IPV6'), 'IPv6 tests disabled') + def test_launch_and_close_v6_only(self): + rv = r.Zeroconf(interfaces=r.InterfaceChoice.All, ip_version=r.IPVersion.V6Only) + rv.close() + rv = r.Zeroconf(interfaces=r.InterfaceChoice.Default, ip_version=r.IPVersion.V6Only) + rv.close() + + @unittest.skipIf(sys.platform == 'darwin', reason="apple_p2p failure path not testable on mac") + def test_launch_and_close_apple_p2p_not_mac(self): + with pytest.raises(RuntimeError): + r.Zeroconf(apple_p2p=True) + + @unittest.skipIf(sys.platform != 'darwin', reason="apple_p2p happy path only testable on mac") + def test_launch_and_close_apple_p2p_on_mac(self): + rv = r.Zeroconf(apple_p2p=True) + rv.close() + + def test_handle_response(self): + def mock_incoming_msg(service_state_change: r.ServiceStateChange) -> r.DNSIncoming: + ttl = 120 + generated = r.DNSOutgoing(const._FLAGS_QR_RESPONSE) + + if service_state_change == r.ServiceStateChange.Updated: + generated.add_answer_at_time( + r.DNSText( + service_name, + const._TYPE_TXT, + const._CLASS_IN | const._CLASS_UNIQUE, + ttl, + service_text, + ), + 0, + ) + return r.DNSIncoming(generated.packets()[0]) + + if service_state_change == r.ServiceStateChange.Removed: + ttl = 0 + + generated.add_answer_at_time( + r.DNSPointer(service_type, const._TYPE_PTR, const._CLASS_IN, ttl, service_name), 0 + ) + generated.add_answer_at_time( + r.DNSService( + service_name, + const._TYPE_SRV, + const._CLASS_IN | const._CLASS_UNIQUE, + ttl, + 0, + 0, + 80, + service_server, + ), + 0, + ) + generated.add_answer_at_time( + r.DNSText( + service_name, const._TYPE_TXT, const._CLASS_IN | const._CLASS_UNIQUE, ttl, service_text + ), + 0, + ) + generated.add_answer_at_time( + r.DNSAddress( + service_server, + const._TYPE_A, + const._CLASS_IN | const._CLASS_UNIQUE, + ttl, + socket.inet_aton(service_address), + ), + 0, + ) + + return r.DNSIncoming(generated.packets()[0]) + + def mock_split_incoming_msg(service_state_change: r.ServiceStateChange) -> r.DNSIncoming: + """Mock an incoming message for the case where the packet is split.""" + ttl = 120 + generated = r.DNSOutgoing(const._FLAGS_QR_RESPONSE) + generated.add_answer_at_time( + r.DNSAddress( + service_server, + const._TYPE_A, + const._CLASS_IN | const._CLASS_UNIQUE, + ttl, + socket.inet_aton(service_address), + ), + 0, + ) + generated.add_answer_at_time( + r.DNSService( + service_name, + const._TYPE_SRV, + const._CLASS_IN | const._CLASS_UNIQUE, + ttl, + 0, + 0, + 80, + service_server, + ), + 0, + ) + return r.DNSIncoming(generated.packets()[0]) + + service_name = 'name._type._tcp.local.' + service_type = '_type._tcp.local.' + service_server = 'ash-2.local.' + service_text = b'path=/~paulsm/' + service_address = '10.0.1.2' + + zeroconf = r.Zeroconf(interfaces=['127.0.0.1']) + + try: + # service added + _inject_response(zeroconf, mock_incoming_msg(r.ServiceStateChange.Added)) + dns_text = zeroconf.cache.get_by_details(service_name, const._TYPE_TXT, const._CLASS_IN) + assert dns_text is not None + assert cast(r.DNSText, dns_text).text == service_text # service_text is b'path=/~paulsm/' + all_dns_text = zeroconf.cache.get_all_by_details(service_name, const._TYPE_TXT, const._CLASS_IN) + assert [dns_text] == all_dns_text + + # https://tools.ietf.org/html/rfc6762#section-10.2 + # Instead of merging this new record additively into the cache in addition + # to any previous records with the same name, rrtype, and rrclass, + # all old records with that name, rrtype, and rrclass that were received + # more than one second ago are declared invalid, + # and marked to expire from the cache in one second. + time.sleep(1.1) + + # service updated. currently only text record can be updated + service_text = b'path=/~humingchun/' + _inject_response(zeroconf, mock_incoming_msg(r.ServiceStateChange.Updated)) + dns_text = zeroconf.cache.get_by_details(service_name, const._TYPE_TXT, const._CLASS_IN) + assert dns_text is not None + assert cast(r.DNSText, dns_text).text == service_text # service_text is b'path=/~humingchun/' + + time.sleep(1.1) + + # The split message only has a SRV and A record. + # This should not evict TXT records from the cache + _inject_response(zeroconf, mock_split_incoming_msg(r.ServiceStateChange.Updated)) + time.sleep(1.1) + dns_text = zeroconf.cache.get_by_details(service_name, const._TYPE_TXT, const._CLASS_IN) + assert dns_text is not None + assert cast(r.DNSText, dns_text).text == service_text # service_text is b'path=/~humingchun/' + + # service removed + _inject_response(zeroconf, mock_incoming_msg(r.ServiceStateChange.Removed)) + dns_text = zeroconf.cache.get_by_details(service_name, const._TYPE_TXT, const._CLASS_IN) + assert dns_text.is_expired(current_time_millis() + 1000) + + finally: + zeroconf.close() + + +def test_generate_service_query_set_qu_bit(): + """Test generate_service_query sets the QU bit.""" + + zeroconf_registrar = Zeroconf(interfaces=['127.0.0.1']) + desc = {'path': '/~paulsm/'} + type_ = "._hap._tcp.local." + registration_name = "this-host-is-not-used._hap._tcp.local." + info = r.ServiceInfo( + type_, registration_name, 80, 0, 0, desc, "ash-2.local.", addresses=[socket.inet_aton("10.0.1.2")] + ) + out = zeroconf_registrar.generate_service_query(info) + assert out.questions[0].unicast is True + zeroconf_registrar.close() + + +def test_invalid_packets_ignored_and_does_not_cause_loop_exception(): + """Ensure an invalid packet cannot cause the loop to collapse.""" + zc = Zeroconf(interfaces=['127.0.0.1']) + generated = r.DNSOutgoing(0) + packet = generated.packets()[0] + packet = packet[:8] + b'deadbeef' + packet[8:] + parsed = r.DNSIncoming(packet) + assert parsed.valid is False + + # Invalid Packet + mock_out = unittest.mock.Mock() + mock_out.packets = lambda: [packet] + zc.send(mock_out) + + # Invalid oversized packet + mock_out = unittest.mock.Mock() + mock_out.packets = lambda: [packet * 1000] + zc.send(mock_out) + + generated = r.DNSOutgoing(const._FLAGS_QR_RESPONSE) + entry = r.DNSText( + "didnotcrashincoming._crash._tcp.local.", + const._TYPE_TXT, + const._CLASS_IN | const._CLASS_UNIQUE, + 500, + b'path=/~paulsm/', + ) + assert isinstance(entry, r.DNSText) + assert isinstance(entry, r.DNSRecord) + assert isinstance(entry, r.DNSEntry) + + generated.add_answer_at_time(entry, 0) + zc.send(generated) + time.sleep(0.2) + zc.close() + assert zc.cache.get(entry) is not None + + +def test_goodbye_all_services(): + """Verify generating the goodbye query does not change with time.""" + zc = Zeroconf(interfaces=['127.0.0.1']) + out = zc.generate_unregister_all_services() + assert out is None + type_ = "_http._tcp.local." + registration_name = "xxxyyy.%s" % type_ + desc = {'path': '/~paulsm/'} + info = r.ServiceInfo( + type_, registration_name, 80, 0, 0, desc, "ash-2.local.", addresses=[socket.inet_aton("10.0.1.2")] + ) + zc.registry.async_add(info) + out = zc.generate_unregister_all_services() + assert out is not None + first_packet = out.packets() + zc.registry.async_add(info) + out2 = zc.generate_unregister_all_services() + assert out2 is not None + second_packet = out.packets() + assert second_packet == first_packet + + # Verify the registery is empty + out3 = zc.generate_unregister_all_services() + assert out3 is None + assert zc.registry.async_get_service_infos() == [] + + zc.close() + + +def test_register_service_with_custom_ttl(): + """Test a registering a service with a custom ttl.""" + + # instantiate a zeroconf instance + zc = Zeroconf(interfaces=['127.0.0.1']) + + # start a browser + type_ = "_homeassistant._tcp.local." + name = "MyTestHome" + info_service = r.ServiceInfo( + type_, + f'{name}.{type_}', + 80, + 0, + 0, + {'path': '/~paulsm/'}, + "ash-90.local.", + addresses=[socket.inet_aton("10.0.1.2")], + ) + + zc.register_service(info_service, ttl=3000) + assert zc.cache.get(info_service.dns_pointer()).ttl == 3000 + zc.close() + + +def test_logging_packets(caplog): + """Test packets are only logged with debug logging.""" + + # instantiate a zeroconf instance + zc = Zeroconf(interfaces=['127.0.0.1']) + + # start a browser + type_ = "_logging._tcp.local." + name = "TLD" + info_service = r.ServiceInfo( + type_, + f'{name}.{type_}', + 80, + 0, + 0, + {'path': '/~paulsm/'}, + "ash-90.local.", + addresses=[socket.inet_aton("10.0.1.2")], + ) + + logging.getLogger('zeroconf').setLevel(logging.DEBUG) + caplog.clear() + zc.register_service(info_service, ttl=3000) + assert "Sending to" in caplog.text + assert zc.cache.get(info_service.dns_pointer()).ttl == 3000 + logging.getLogger('zeroconf').setLevel(logging.INFO) + caplog.clear() + zc.unregister_service(info_service) + assert "Sending to" not in caplog.text + logging.getLogger('zeroconf').setLevel(logging.DEBUG) + + zc.close() + + +def test_get_service_info_failure_path(): + """Verify get_service_info return None when the underlying call returns False.""" + zc = Zeroconf(interfaces=['127.0.0.1']) + assert zc.get_service_info("_neverused._tcp.local.", "xneverused._neverused._tcp.local.", 10) is None + zc.close() + + +def test_sending_unicast(): + """Test sending unicast response.""" + zc = Zeroconf(interfaces=['127.0.0.1']) + generated = r.DNSOutgoing(const._FLAGS_QR_RESPONSE) + entry = r.DNSText( + "didnotcrashincoming._crash._tcp.local.", + const._TYPE_TXT, + const._CLASS_IN | const._CLASS_UNIQUE, + 500, + b'path=/~paulsm/', + ) + generated.add_answer_at_time(entry, 0) + zc.send(generated, "2001:db8::1", const._MDNS_PORT) # https://www.iana.org/go/rfc3849 + time.sleep(0.2) + assert zc.cache.get(entry) is None + + zc.send(generated, "198.51.100.0", const._MDNS_PORT) # Documentation (TEST-NET-2) + time.sleep(0.2) + assert zc.cache.get(entry) is None + + zc.send(generated) + time.sleep(0.2) + assert zc.cache.get(entry) is not None + + zc.close() + + +def test_tc_bit_defers(): + zc = Zeroconf(interfaces=['127.0.0.1']) + _wait_for_start(zc) + type_ = "_tcbitdefer._tcp.local." + name = "knownname" + name2 = "knownname2" + name3 = "knownname3" + + registration_name = f"{name}.{type_}" + registration2_name = f"{name2}.{type_}" + registration3_name = f"{name3}.{type_}" + + desc = {'path': '/~paulsm/'} + server_name = "ash-2.local." + server_name2 = "ash-3.local." + server_name3 = "ash-4.local." + + info = r.ServiceInfo( + type_, registration_name, 80, 0, 0, desc, server_name, addresses=[socket.inet_aton("10.0.1.2")] + ) + info2 = r.ServiceInfo( + type_, registration2_name, 80, 0, 0, desc, server_name2, addresses=[socket.inet_aton("10.0.1.2")] + ) + info3 = r.ServiceInfo( + type_, registration3_name, 80, 0, 0, desc, server_name3, addresses=[socket.inet_aton("10.0.1.2")] + ) + zc.registry.async_add(info) + zc.registry.async_add(info2) + zc.registry.async_add(info3) + + protocol = zc.engine.protocols[0] + now = r.current_time_millis() + _clear_cache(zc) + + generated = r.DNSOutgoing(const._FLAGS_QR_QUERY) + question = r.DNSQuestion(type_, const._TYPE_PTR, const._CLASS_IN) + generated.add_question(question) + for _ in range(300): + # Add so many answers we end up with another packet + generated.add_answer_at_time(info.dns_pointer(), now) + generated.add_answer_at_time(info2.dns_pointer(), now) + generated.add_answer_at_time(info3.dns_pointer(), now) + packets = generated.packets() + assert len(packets) == 4 + expected_deferred = [] + source_ip = '203.0.113.13' + + next_packet = r.DNSIncoming(packets.pop(0)) + expected_deferred.append(next_packet) + threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT, None) + assert protocol._deferred[source_ip] == expected_deferred + assert source_ip in protocol._timers + + next_packet = r.DNSIncoming(packets.pop(0)) + expected_deferred.append(next_packet) + threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT, None) + assert protocol._deferred[source_ip] == expected_deferred + assert source_ip in protocol._timers + threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT, None) + assert protocol._deferred[source_ip] == expected_deferred + assert source_ip in protocol._timers + + next_packet = r.DNSIncoming(packets.pop(0)) + expected_deferred.append(next_packet) + threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT, None) + assert protocol._deferred[source_ip] == expected_deferred + assert source_ip in protocol._timers + + next_packet = r.DNSIncoming(packets.pop(0)) + expected_deferred.append(next_packet) + threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT, None) + assert source_ip not in protocol._deferred + assert source_ip not in protocol._timers + + # unregister + zc.unregister_service(info) + zc.close() + + +def test_tc_bit_defers_last_response_missing(): + zc = Zeroconf(interfaces=['127.0.0.1']) + _wait_for_start(zc) + type_ = "_knowndefer._tcp.local." + name = "knownname" + name2 = "knownname2" + name3 = "knownname3" + + registration_name = f"{name}.{type_}" + registration2_name = f"{name2}.{type_}" + registration3_name = f"{name3}.{type_}" + + desc = {'path': '/~paulsm/'} + server_name = "ash-2.local." + server_name2 = "ash-3.local." + server_name3 = "ash-4.local." + + info = r.ServiceInfo( + type_, registration_name, 80, 0, 0, desc, server_name, addresses=[socket.inet_aton("10.0.1.2")] + ) + info2 = r.ServiceInfo( + type_, registration2_name, 80, 0, 0, desc, server_name2, addresses=[socket.inet_aton("10.0.1.2")] + ) + info3 = r.ServiceInfo( + type_, registration3_name, 80, 0, 0, desc, server_name3, addresses=[socket.inet_aton("10.0.1.2")] + ) + zc.registry.async_add(info) + zc.registry.async_add(info2) + zc.registry.async_add(info3) + + protocol = zc.engine.protocols[0] + now = r.current_time_millis() + _clear_cache(zc) + source_ip = '203.0.113.12' + + generated = r.DNSOutgoing(const._FLAGS_QR_QUERY) + question = r.DNSQuestion(type_, const._TYPE_PTR, const._CLASS_IN) + generated.add_question(question) + for _ in range(300): + # Add so many answers we end up with another packet + generated.add_answer_at_time(info.dns_pointer(), now) + generated.add_answer_at_time(info2.dns_pointer(), now) + generated.add_answer_at_time(info3.dns_pointer(), now) + packets = generated.packets() + assert len(packets) == 4 + expected_deferred = [] + + next_packet = r.DNSIncoming(packets.pop(0)) + expected_deferred.append(next_packet) + threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT, None) + assert protocol._deferred[source_ip] == expected_deferred + timer1 = protocol._timers[source_ip] + + next_packet = r.DNSIncoming(packets.pop(0)) + expected_deferred.append(next_packet) + threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT, None) + assert protocol._deferred[source_ip] == expected_deferred + timer2 = protocol._timers[source_ip] + if sys.version_info >= (3, 7): + assert timer1.cancelled() + assert timer2 != timer1 + + # Send the same packet again to similar multi interfaces + threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT, None) + assert protocol._deferred[source_ip] == expected_deferred + assert source_ip in protocol._timers + timer3 = protocol._timers[source_ip] + if sys.version_info >= (3, 7): + assert not timer3.cancelled() + assert timer3 == timer2 + + next_packet = r.DNSIncoming(packets.pop(0)) + expected_deferred.append(next_packet) + threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT, None) + assert protocol._deferred[source_ip] == expected_deferred + assert source_ip in protocol._timers + timer4 = protocol._timers[source_ip] + if sys.version_info >= (3, 7): + assert timer3.cancelled() + assert timer4 != timer3 + + for _ in range(8): + time.sleep(0.1) + if source_ip not in protocol._timers and source_ip not in protocol._deferred: + break + + assert source_ip not in protocol._deferred + assert source_ip not in protocol._timers + + # unregister + zc.registry.async_remove(info) + zc.close() + + +@pytest.mark.asyncio +async def test_open_close_twice_from_async() -> None: + """Test we can close twice from a coroutine when using Zeroconf. + + Ideally callers switch to using AsyncZeroconf, however there will + be a peroid where they still call the sync wrapper that we want + to ensure will not deadlock on shutdown. + + This test is expected to throw warnings about tasks being destroyed + since we force shutdown right away since we don't want to block + callers event loops and since they aren't using the AsyncZeroconf + version they won't yield with an await like async_close we don't + have much choice but to force things down. + """ + zc = Zeroconf(interfaces=['127.0.0.1']) + zc.close() + zc.close() + await asyncio.sleep(0) + + +@pytest.mark.asyncio +async def test_multiple_sync_instances_stared_from_async_close(): + """Test we can shutdown multiple sync instances from async.""" + + # instantiate a zeroconf instance + zc = Zeroconf(interfaces=['127.0.0.1']) + zc2 = Zeroconf(interfaces=['127.0.0.1']) + + assert zc.loop == zc2.loop + + zc.close() + assert zc.loop.is_running() + zc2.close() + assert zc2.loop.is_running() + + zc3 = Zeroconf(interfaces=['127.0.0.1']) + assert zc3.loop == zc2.loop + + zc3.close() + assert zc3.loop.is_running() + + await asyncio.sleep(0) + + +def test_guard_against_oversized_packets(): + """Ensure we do not process oversized packets. + + These packets can quickly overwhelm the system. + """ + zc = Zeroconf(interfaces=['127.0.0.1']) + + generated = r.DNSOutgoing(const._FLAGS_QR_RESPONSE) + + for i in range(5000): + generated.add_answer_at_time( + r.DNSText( + "packet{i}.local.", + const._TYPE_TXT, + const._CLASS_IN | const._CLASS_UNIQUE, + 500, + b'path=/~paulsm/', + ), + 0, + ) + + # We are patching to generate an oversized packet + with patch.object(outgoing, "_MAX_MSG_ABSOLUTE", 100000), patch.object( + outgoing, "_MAX_MSG_TYPICAL", 100000 + ): + over_sized_packet = generated.packets()[0] + assert len(over_sized_packet) > const._MAX_MSG_ABSOLUTE + + generated = r.DNSOutgoing(const._FLAGS_QR_RESPONSE) + okpacket_record = r.DNSText( + "okpacket.local.", + const._TYPE_TXT, + const._CLASS_IN | const._CLASS_UNIQUE, + 500, + b'path=/~paulsm/', + ) + + generated.add_answer_at_time( + okpacket_record, + 0, + ) + ok_packet = generated.packets()[0] + + # We cannot test though the network interface as some operating systems + # will guard against the oversized packet and we won't see it. + listener = _core.AsyncListener(zc) + listener.transport = unittest.mock.MagicMock() + + listener.datagram_received(ok_packet, ('127.0.0.1', const._MDNS_PORT)) + assert zc.cache.async_get_unique(okpacket_record) is not None + + listener.datagram_received(over_sized_packet, ('127.0.0.1', const._MDNS_PORT)) + assert ( + zc.cache.async_get_unique( + r.DNSText( + "packet0.local.", + const._TYPE_TXT, + const._CLASS_IN | const._CLASS_UNIQUE, + 500, + b'path=/~paulsm/', + ) + ) + is None + ) + + zc.close() + + +def test_guard_against_duplicate_packets(): + """Ensure we do not process duplicate packets. + These packets can quickly overwhelm the system. + """ + zc = Zeroconf(interfaces=['127.0.0.1']) + listener = _core.AsyncListener(zc) + assert listener.suppress_duplicate_packet(b"first packet", current_time_millis()) is False + assert listener.suppress_duplicate_packet(b"first packet", current_time_millis()) is True + assert listener.suppress_duplicate_packet(b"first packet", current_time_millis()) is True + assert listener.suppress_duplicate_packet(b"first packet", current_time_millis() + 1000) is False + assert listener.suppress_duplicate_packet(b"first packet", current_time_millis()) is True + assert listener.suppress_duplicate_packet(b"other packet", current_time_millis()) is False + assert listener.suppress_duplicate_packet(b"other packet", current_time_millis()) is True + assert listener.suppress_duplicate_packet(b"other packet", current_time_millis() + 1000) is False + assert listener.suppress_duplicate_packet(b"first packet", current_time_millis()) is False + zc.close() + + +def test_shutdown_while_register_in_process(): + """Test we can shutdown while registering a service in another thread.""" + + # instantiate a zeroconf instance + zc = Zeroconf(interfaces=['127.0.0.1']) + + # start a browser + type_ = "_homeassistant._tcp.local." + name = "MyTestHome" + info_service = r.ServiceInfo( + type_, + f'{name}.{type_}', + 80, + 0, + 0, + {'path': '/~paulsm/'}, + "ash-90.local.", + addresses=[socket.inet_aton("10.0.1.2")], + ) + + def _background_register(): + zc.register_service(info_service) + + bgthread = threading.Thread(target=_background_register, daemon=True) + bgthread.start() + time.sleep(0.3) + + zc.close() + bgthread.join() diff --git a/tests/test_dns.py b/tests/test_dns.py new file mode 100644 index 00000000..c2669205 --- /dev/null +++ b/tests/test_dns.py @@ -0,0 +1,386 @@ +#!/usr/bin/env python + + +""" Unit tests for zeroconf._dns. """ + +import logging +import os +import socket +import time +import unittest +import unittest.mock + +import zeroconf as r +from zeroconf import const, current_time_millis +from zeroconf._dns import DNSRRSet +from zeroconf import ( + DNSHinfo, + DNSText, + ServiceInfo, +) + +from . import has_working_ipv6 + +log = logging.getLogger('zeroconf') +original_logging_level = logging.NOTSET + + +def setup_module(): + global original_logging_level + original_logging_level = log.level + log.setLevel(logging.DEBUG) + + +def teardown_module(): + if original_logging_level != logging.NOTSET: + log.setLevel(original_logging_level) + + +class TestDunder(unittest.TestCase): + def test_dns_text_repr(self): + # There was an issue on Python 3 that prevented DNSText's repr + # from working when the text was longer than 10 bytes + text = DNSText('irrelevant', 0, 0, 0, b'12345678901') + repr(text) + + text = DNSText('irrelevant', 0, 0, 0, b'123') + repr(text) + + def test_dns_hinfo_repr_eq(self): + hinfo = DNSHinfo('irrelevant', const._TYPE_HINFO, 0, 0, 'cpu', 'os') + assert hinfo == hinfo + repr(hinfo) + + def test_dns_pointer_repr(self): + pointer = r.DNSPointer('irrelevant', const._TYPE_PTR, const._CLASS_IN, const._DNS_OTHER_TTL, '123') + repr(pointer) + + @unittest.skipIf(not has_working_ipv6(), 'Requires IPv6') + @unittest.skipIf(os.environ.get('SKIP_IPV6'), 'IPv6 tests disabled') + def test_dns_address_repr(self): + address = r.DNSAddress('irrelevant', const._TYPE_SOA, const._CLASS_IN, 1, b'a') + assert repr(address).endswith("b'a'") + + address_ipv4 = r.DNSAddress( + 'irrelevant', const._TYPE_SOA, const._CLASS_IN, 1, socket.inet_pton(socket.AF_INET, '127.0.0.1') + ) + assert repr(address_ipv4).endswith('127.0.0.1') + + address_ipv6 = r.DNSAddress( + 'irrelevant', const._TYPE_SOA, const._CLASS_IN, 1, socket.inet_pton(socket.AF_INET6, '::1') + ) + assert repr(address_ipv6).endswith('::1') + + def test_dns_question_repr(self): + question = r.DNSQuestion('irrelevant', const._TYPE_SRV, const._CLASS_IN | const._CLASS_UNIQUE) + repr(question) + assert not question != question + + def test_dns_service_repr(self): + service = r.DNSService( + 'irrelevant', const._TYPE_SRV, const._CLASS_IN, const._DNS_HOST_TTL, 0, 0, 80, 'a' + ) + repr(service) + + def test_dns_record_abc(self): + record = r.DNSRecord('irrelevant', const._TYPE_SRV, const._CLASS_IN, const._DNS_HOST_TTL) + self.assertRaises(r.AbstractMethodException, record.__eq__, record) + self.assertRaises(r.AbstractMethodException, record.write, None) + + def test_dns_record_reset_ttl(self): + record = r.DNSRecord('irrelevant', const._TYPE_SRV, const._CLASS_IN, const._DNS_HOST_TTL) + time.sleep(1) + record2 = r.DNSRecord('irrelevant', const._TYPE_SRV, const._CLASS_IN, const._DNS_HOST_TTL) + now = r.current_time_millis() + + assert record.created != record2.created + assert record.get_remaining_ttl(now) != record2.get_remaining_ttl(now) + + record.reset_ttl(record2) + + assert record.ttl == record2.ttl + assert record.created == record2.created + assert record.get_remaining_ttl(now) == record2.get_remaining_ttl(now) + + def test_service_info_dunder(self): + type_ = "_test-srvc-type._tcp.local." + name = "xxxyyy" + registration_name = f"{name}.{type_}" + info = ServiceInfo( + type_, + registration_name, + 80, + 0, + 0, + b'', + "ash-2.local.", + addresses=[socket.inet_aton("10.0.1.2")], + ) + + assert not info != info + repr(info) + + def test_service_info_text_properties_not_given(self): + type_ = "_test-srvc-type._tcp.local." + name = "xxxyyy" + registration_name = f"{name}.{type_}" + info = ServiceInfo( + type_=type_, + name=registration_name, + addresses=[socket.inet_aton("10.0.1.2")], + port=80, + server="ash-2.local.", + ) + + assert isinstance(info.text, bytes) + repr(info) + + def test_dns_outgoing_repr(self): + dns_outgoing = r.DNSOutgoing(const._FLAGS_QR_QUERY) + repr(dns_outgoing) + + def test_dns_record_is_expired(self): + record = r.DNSRecord('irrelevant', const._TYPE_SRV, const._CLASS_IN, 8) + now = current_time_millis() + assert record.is_expired(now) is False + assert record.is_expired(now + (8 / 2 * 1000)) is False + assert record.is_expired(now + (8 * 1000)) is True + + def test_dns_record_is_stale(self): + record = r.DNSRecord('irrelevant', const._TYPE_SRV, const._CLASS_IN, 8) + now = current_time_millis() + assert record.is_stale(now) is False + assert record.is_stale(now + (8 / 4.1 * 1000)) is False + assert record.is_stale(now + (8 / 2 * 1000)) is True + assert record.is_stale(now + (8 * 1000)) is True + + def test_dns_record_is_recent(self): + now = current_time_millis() + record = r.DNSRecord('irrelevant', const._TYPE_SRV, const._CLASS_IN, 8) + assert record.is_recent(now + (8 / 4.1 * 1000)) is True + assert record.is_recent(now + (8 / 3 * 1000)) is False + assert record.is_recent(now + (8 / 2 * 1000)) is False + assert record.is_recent(now + (8 * 1000)) is False + + +def test_dns_question_hashablity(): + """Test DNSQuestions are hashable.""" + + record1 = r.DNSQuestion('irrelevant', const._TYPE_A, const._CLASS_IN) + record2 = r.DNSQuestion('irrelevant', const._TYPE_A, const._CLASS_IN) + + record_set = {record1, record2} + assert len(record_set) == 1 + + record_set.add(record1) + assert len(record_set) == 1 + + record3_dupe = r.DNSQuestion('irrelevant', const._TYPE_A, const._CLASS_IN) + assert record2 == record3_dupe + assert record2.__hash__() == record3_dupe.__hash__() + + record_set.add(record3_dupe) + assert len(record_set) == 1 + + record4_dupe = r.DNSQuestion('notsame', const._TYPE_A, const._CLASS_IN) + assert record2 != record4_dupe + assert record2.__hash__() != record4_dupe.__hash__() + + record_set.add(record4_dupe) + assert len(record_set) == 2 + + +def test_dns_record_hashablity_does_not_consider_ttl(): + """Test DNSRecord are hashable.""" + + # Verify the TTL is not considered in the hash + record1 = r.DNSAddress('irrelevant', const._TYPE_A, const._CLASS_IN, const._DNS_OTHER_TTL, b'same') + record2 = r.DNSAddress('irrelevant', const._TYPE_A, const._CLASS_IN, const._DNS_HOST_TTL, b'same') + + record_set = {record1, record2} + assert len(record_set) == 1 + + record_set.add(record1) + assert len(record_set) == 1 + + record3_dupe = r.DNSAddress('irrelevant', const._TYPE_A, const._CLASS_IN, const._DNS_HOST_TTL, b'same') + assert record2 == record3_dupe + assert record2.__hash__() == record3_dupe.__hash__() + + record_set.add(record3_dupe) + assert len(record_set) == 1 + + +def test_dns_record_hashablity_does_not_consider_unique(): + """Test DNSRecord are hashable and unique is ignored.""" + + # Verify the unique value is not considered in the hash + record1 = r.DNSAddress( + 'irrelevant', const._TYPE_A, const._CLASS_IN | const._CLASS_UNIQUE, const._DNS_OTHER_TTL, b'same' + ) + record2 = r.DNSAddress('irrelevant', const._TYPE_A, const._CLASS_IN, const._DNS_OTHER_TTL, b'same') + + assert record1.class_ == record2.class_ + assert record1.__hash__() == record2.__hash__() + record_set = {record1, record2} + assert len(record_set) == 1 + + +def test_dns_address_record_hashablity(): + """Test DNSAddress are hashable.""" + address1 = r.DNSAddress('irrelevant', const._TYPE_A, const._CLASS_IN, 1, b'a') + address2 = r.DNSAddress('irrelevant', const._TYPE_A, const._CLASS_IN, 1, b'b') + address3 = r.DNSAddress('irrelevant', const._TYPE_A, const._CLASS_IN, 1, b'c') + address4 = r.DNSAddress('irrelevant', const._TYPE_AAAA, const._CLASS_IN, 1, b'c') + + record_set = {address1, address2, address3, address4} + assert len(record_set) == 4 + + record_set.add(address1) + assert len(record_set) == 4 + + address3_dupe = r.DNSAddress('irrelevant', const._TYPE_A, const._CLASS_IN, 1, b'c') + + record_set.add(address3_dupe) + assert len(record_set) == 4 + + # Verify we can remove records + additional_set = {address1, address2} + record_set -= additional_set + assert record_set == {address3, address4} + + +def test_dns_hinfo_record_hashablity(): + """Test DNSHinfo are hashable.""" + hinfo1 = r.DNSHinfo('irrelevant', const._TYPE_HINFO, 0, 0, 'cpu1', 'os') + hinfo2 = r.DNSHinfo('irrelevant', const._TYPE_HINFO, 0, 0, 'cpu2', 'os') + + record_set = {hinfo1, hinfo2} + assert len(record_set) == 2 + + record_set.add(hinfo1) + assert len(record_set) == 2 + + hinfo2_dupe = r.DNSHinfo('irrelevant', const._TYPE_HINFO, 0, 0, 'cpu2', 'os') + assert hinfo2 == hinfo2_dupe + assert hinfo2.__hash__() == hinfo2_dupe.__hash__() + + record_set.add(hinfo2_dupe) + assert len(record_set) == 2 + + +def test_dns_pointer_record_hashablity(): + """Test DNSPointer are hashable.""" + ptr1 = r.DNSPointer('irrelevant', const._TYPE_PTR, const._CLASS_IN, const._DNS_OTHER_TTL, '123') + ptr2 = r.DNSPointer('irrelevant', const._TYPE_PTR, const._CLASS_IN, const._DNS_OTHER_TTL, '456') + + record_set = {ptr1, ptr2} + assert len(record_set) == 2 + + record_set.add(ptr1) + assert len(record_set) == 2 + + ptr2_dupe = r.DNSPointer('irrelevant', const._TYPE_PTR, const._CLASS_IN, const._DNS_OTHER_TTL, '456') + assert ptr2 == ptr2 + assert ptr2.__hash__() == ptr2_dupe.__hash__() + + record_set.add(ptr2_dupe) + assert len(record_set) == 2 + + +def test_dns_text_record_hashablity(): + """Test DNSText are hashable.""" + text1 = r.DNSText('irrelevant', 0, 0, const._DNS_OTHER_TTL, b'12345678901') + text2 = r.DNSText('irrelevant', 1, 0, const._DNS_OTHER_TTL, b'12345678901') + text3 = r.DNSText('irrelevant', 0, 1, const._DNS_OTHER_TTL, b'12345678901') + text4 = r.DNSText('irrelevant', 0, 0, const._DNS_OTHER_TTL, b'ABCDEFGHIJK') + + record_set = {text1, text2, text3, text4} + + assert len(record_set) == 4 + + record_set.add(text1) + assert len(record_set) == 4 + + text1_dupe = r.DNSText('irrelevant', 0, 0, const._DNS_OTHER_TTL, b'12345678901') + assert text1 == text1_dupe + assert text1.__hash__() == text1_dupe.__hash__() + + record_set.add(text1_dupe) + assert len(record_set) == 4 + + +def test_dns_service_record_hashablity(): + """Test DNSService are hashable.""" + srv1 = r.DNSService('irrelevant', const._TYPE_SRV, const._CLASS_IN, const._DNS_HOST_TTL, 0, 0, 80, 'a') + srv2 = r.DNSService('irrelevant', const._TYPE_SRV, const._CLASS_IN, const._DNS_HOST_TTL, 0, 1, 80, 'a') + srv3 = r.DNSService('irrelevant', const._TYPE_SRV, const._CLASS_IN, const._DNS_HOST_TTL, 0, 0, 81, 'a') + srv4 = r.DNSService('irrelevant', const._TYPE_SRV, const._CLASS_IN, const._DNS_HOST_TTL, 0, 0, 80, 'ab') + + record_set = {srv1, srv2, srv3, srv4} + + assert len(record_set) == 4 + + record_set.add(srv1) + assert len(record_set) == 4 + + srv1_dupe = r.DNSService( + 'irrelevant', const._TYPE_SRV, const._CLASS_IN, const._DNS_HOST_TTL, 0, 0, 80, 'a' + ) + assert srv1 == srv1_dupe + assert srv1.__hash__() == srv1_dupe.__hash__() + + record_set.add(srv1_dupe) + assert len(record_set) == 4 + + +def test_dns_nsec_record_hashablity(): + """Test DNSNsec are hashable.""" + nsec1 = r.DNSNsec( + 'irrelevant', const._TYPE_PTR, const._CLASS_IN, const._DNS_OTHER_TTL, 'irrelevant', [1, 2, 3] + ) + nsec2 = r.DNSNsec( + 'irrelevant', const._TYPE_PTR, const._CLASS_IN, const._DNS_OTHER_TTL, 'irrelevant', [1, 2] + ) + + record_set = {nsec1, nsec2} + assert len(record_set) == 2 + + record_set.add(nsec1) + assert len(record_set) == 2 + + nsec2_dupe = r.DNSNsec( + 'irrelevant', const._TYPE_PTR, const._CLASS_IN, const._DNS_OTHER_TTL, 'irrelevant', [1, 2] + ) + assert nsec2 == nsec2_dupe + assert nsec2.__hash__() == nsec2_dupe.__hash__() + + record_set.add(nsec2_dupe) + assert len(record_set) == 2 + + +def test_rrset_does_not_consider_ttl(): + """Test DNSRRSet does not consider the ttl in the hash.""" + + longarec = r.DNSAddress('irrelevant', const._TYPE_A, const._CLASS_IN, 100, b'same') + shortarec = r.DNSAddress('irrelevant', const._TYPE_A, const._CLASS_IN, 10, b'same') + longaaaarec = r.DNSAddress('irrelevant', const._TYPE_AAAA, const._CLASS_IN, 100, b'same') + shortaaaarec = r.DNSAddress('irrelevant', const._TYPE_AAAA, const._CLASS_IN, 10, b'same') + + rrset = DNSRRSet([longarec, shortaaaarec]) + + assert rrset.suppresses(longarec) + assert rrset.suppresses(shortarec) + assert not rrset.suppresses(longaaaarec) + assert rrset.suppresses(shortaaaarec) + + verylongarec = r.DNSAddress('irrelevant', const._TYPE_A, const._CLASS_IN, 1000, b'same') + longarec = r.DNSAddress('irrelevant', const._TYPE_A, const._CLASS_IN, 100, b'same') + mediumarec = r.DNSAddress('irrelevant', const._TYPE_A, const._CLASS_IN, 60, b'same') + shortarec = r.DNSAddress('irrelevant', const._TYPE_A, const._CLASS_IN, 10, b'same') + + rrset2 = DNSRRSet([mediumarec]) + assert not rrset2.suppresses(verylongarec) + assert rrset2.suppresses(longarec) + assert rrset2.suppresses(mediumarec) + assert rrset2.suppresses(shortarec) diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py new file mode 100644 index 00000000..47e68b75 --- /dev/null +++ b/tests/test_exceptions.py @@ -0,0 +1,154 @@ +#!/usr/bin/env python + + +""" Unit tests for zeroconf._exceptions """ + +import logging +import unittest +import unittest.mock + +import zeroconf as r +from zeroconf import ( + ServiceInfo, + Zeroconf, +) + + +log = logging.getLogger('zeroconf') +original_logging_level = logging.NOTSET + + +def setup_module(): + global original_logging_level + original_logging_level = log.level + log.setLevel(logging.DEBUG) + + +def teardown_module(): + if original_logging_level != logging.NOTSET: + log.setLevel(original_logging_level) + + +class Exceptions(unittest.TestCase): + + browser = None # type: Zeroconf + + @classmethod + def setUpClass(cls): + cls.browser = Zeroconf(interfaces=['127.0.0.1']) + + @classmethod + def tearDownClass(cls): + cls.browser.close() + del cls.browser + + def test_bad_service_info_name(self): + self.assertRaises(r.BadTypeInNameException, self.browser.get_service_info, "type", "type_not") + + def test_bad_service_names(self): + bad_names_to_try = ( + '', + 'local', + '_tcp.local.', + '_udp.local.', + '._udp.local.', + '_@._tcp.local.', + '_A@._tcp.local.', + '_x--x._tcp.local.', + '_-x._udp.local.', + '_x-._tcp.local.', + '_22._udp.local.', + '_2-2._tcp.local.', + '\x00._x._udp.local.', + ) + for name in bad_names_to_try: + self.assertRaises(r.BadTypeInNameException, self.browser.get_service_info, name, 'x.' + name) + + def test_bad_local_names_for_get_service_info(self): + bad_names_to_try = ( + 'homekitdev._nothttp._tcp.local.', + 'homekitdev._http._udp.local.', + ) + for name in bad_names_to_try: + self.assertRaises( + r.BadTypeInNameException, self.browser.get_service_info, '_http._tcp.local.', name + ) + + def test_good_instance_names(self): + assert r.service_type_name('.._x._tcp.local.') == '_x._tcp.local.' + assert r.service_type_name('x.sub._http._tcp.local.') == '_http._tcp.local.' + assert ( + r.service_type_name('6d86f882b90facee9170ad3439d72a4d6ee9f511._zget._http._tcp.local.') + == '_http._tcp.local.' + ) + + def test_good_instance_names_without_protocol(self): + good_names_to_try = ( + "Rachio-C73233.local.", + 'YeelightColorBulb-3AFD.local.', + 'YeelightTunableBulb-7220.local.', + "AlexanderHomeAssistant 74651D.local.", + 'iSmartGate-152.local.', + 'MyQ-FGA.local.', + 'lutron-02c4392a.local.', + 'WICED-hap-3E2734.local.', + 'MyHost.local.', + 'MyHost.sub.local.', + ) + for name in good_names_to_try: + assert r.service_type_name(name, strict=False) == 'local.' + + for name in good_names_to_try: + # Raises without strict=False + self.assertRaises(r.BadTypeInNameException, r.service_type_name, name) + + def test_bad_types(self): + bad_names_to_try = ( + '._x._tcp.local.', + 'a' * 64 + '._sub._http._tcp.local.', + 'a' * 62 + 'â._sub._http._tcp.local.', + ) + for name in bad_names_to_try: + self.assertRaises(r.BadTypeInNameException, r.service_type_name, name) + + def test_bad_sub_types(self): + bad_names_to_try = ( + '_sub._http._tcp.local.', + '._sub._http._tcp.local.', + '\x7f._sub._http._tcp.local.', + '\x1f._sub._http._tcp.local.', + ) + for name in bad_names_to_try: + self.assertRaises(r.BadTypeInNameException, r.service_type_name, name) + + def test_good_service_names(self): + good_names_to_try = ( + ('_x._tcp.local.', '_x._tcp.local.'), + ('_x._udp.local.', '_x._udp.local.'), + ('_12345-67890-abc._udp.local.', '_12345-67890-abc._udp.local.'), + ('x._sub._http._tcp.local.', '_http._tcp.local.'), + ('a' * 63 + '._sub._http._tcp.local.', '_http._tcp.local.'), + ('a' * 61 + 'â._sub._http._tcp.local.', '_http._tcp.local.'), + ) + + for name, result in good_names_to_try: + assert r.service_type_name(name) == result + + assert r.service_type_name('_one_two._tcp.local.', strict=False) == '_one_two._tcp.local.' + + def test_invalid_addresses(self): + type_ = "_test-srvc-type._tcp.local." + name = "xxxyyy" + registration_name = f"{name}.{type_}" + + bad = ('127.0.0.1', '::1', 42) + for addr in bad: + self.assertRaisesRegex( + TypeError, + 'Addresses must be bytes', + ServiceInfo, + type_, + registration_name, + port=80, + addresses=[addr], + ) diff --git a/tests/test_handlers.py b/tests/test_handlers.py new file mode 100644 index 00000000..44ee1d5a --- /dev/null +++ b/tests/test_handlers.py @@ -0,0 +1,1540 @@ +#!/usr/bin/env python + + +""" Unit tests for zeroconf._handlers """ + +import asyncio +import logging +import os +import pytest +import socket +import time +import unittest +import unittest.mock +from typing import List + +import zeroconf as r +from zeroconf import _handlers, ServiceInfo, Zeroconf, current_time_millis +from zeroconf import const +from zeroconf._handlers import construct_outgoing_multicast_answers, MulticastOutgoingQueue +from zeroconf._utils.time import millis_to_seconds +from zeroconf.asyncio import AsyncZeroconf + + +from . import _clear_cache, _inject_response, has_working_ipv6 + +log = logging.getLogger('zeroconf') +original_logging_level = logging.NOTSET + + +def setup_module(): + global original_logging_level + original_logging_level = log.level + log.setLevel(logging.DEBUG) + + +def teardown_module(): + if original_logging_level != logging.NOTSET: + log.setLevel(original_logging_level) + + +class TestRegistrar(unittest.TestCase): + def test_ttl(self): + + # instantiate a zeroconf instance + zc = Zeroconf(interfaces=['127.0.0.1']) + + # service definition + type_ = "_test-srvc-type._tcp.local." + name = "xxxyyy" + registration_name = f"{name}.{type_}" + + desc = {'path': '/~paulsm/'} + info = ServiceInfo( + type_, + registration_name, + 80, + 0, + 0, + desc, + "ash-2.local.", + addresses=[socket.inet_aton("10.0.1.2")], + ) + + nbr_answers = nbr_additionals = nbr_authorities = 0 + + def get_ttl(record_type): + if expected_ttl is not None: + return expected_ttl + elif record_type in [const._TYPE_A, const._TYPE_SRV]: + return const._DNS_HOST_TTL + else: + return const._DNS_OTHER_TTL + + def _process_outgoing_packet(out): + """Sends an outgoing packet.""" + nonlocal nbr_answers, nbr_additionals, nbr_authorities + + for answer, time_ in out.answers: + nbr_answers += 1 + assert answer.ttl == get_ttl(answer.type) + for answer in out.additionals: + nbr_additionals += 1 + assert answer.ttl == get_ttl(answer.type) + for answer in out.authorities: + nbr_authorities += 1 + assert answer.ttl == get_ttl(answer.type) + + # register service with default TTL + expected_ttl = None + for _ in range(3): + _process_outgoing_packet(zc.generate_service_query(info)) + zc.registry.async_add(info) + for _ in range(3): + _process_outgoing_packet(zc.generate_service_broadcast(info, None)) + assert nbr_answers == 12 and nbr_additionals == 0 and nbr_authorities == 3 + nbr_answers = nbr_additionals = nbr_authorities = 0 + + # query + query = r.DNSOutgoing(const._FLAGS_QR_QUERY | const._FLAGS_AA) + assert query.is_query() is True + query.add_question(r.DNSQuestion(info.type, const._TYPE_PTR, const._CLASS_IN)) + query.add_question(r.DNSQuestion(info.name, const._TYPE_SRV, const._CLASS_IN)) + query.add_question(r.DNSQuestion(info.name, const._TYPE_TXT, const._CLASS_IN)) + query.add_question(r.DNSQuestion(info.server, const._TYPE_A, const._CLASS_IN)) + question_answers = zc.query_handler.async_response( + [r.DNSIncoming(packet) for packet in query.packets()], False + ) + _process_outgoing_packet(construct_outgoing_multicast_answers(question_answers.mcast_aggregate)) + + # The additonals should all be suppresed since they are all in the answers section + # There will be one NSEC additional to indicate the lack of AAAA record + # + assert nbr_answers == 4 and nbr_additionals == 1 and nbr_authorities == 0 + nbr_answers = nbr_additionals = nbr_authorities = 0 + + # unregister + expected_ttl = 0 + zc.registry.async_remove(info) + for _ in range(3): + _process_outgoing_packet(zc.generate_service_broadcast(info, 0)) + assert nbr_answers == 12 and nbr_additionals == 0 and nbr_authorities == 0 + nbr_answers = nbr_additionals = nbr_authorities = 0 + + expected_ttl = None + for _ in range(3): + _process_outgoing_packet(zc.generate_service_query(info)) + zc.registry.async_add(info) + # register service with custom TTL + expected_ttl = const._DNS_HOST_TTL * 2 + assert expected_ttl != const._DNS_HOST_TTL + for _ in range(3): + _process_outgoing_packet(zc.generate_service_broadcast(info, expected_ttl)) + assert nbr_answers == 12 and nbr_additionals == 0 and nbr_authorities == 3 + nbr_answers = nbr_additionals = nbr_authorities = 0 + + # query + expected_ttl = None + query = r.DNSOutgoing(const._FLAGS_QR_QUERY | const._FLAGS_AA) + query.add_question(r.DNSQuestion(info.type, const._TYPE_PTR, const._CLASS_IN)) + query.add_question(r.DNSQuestion(info.name, const._TYPE_SRV, const._CLASS_IN)) + query.add_question(r.DNSQuestion(info.name, const._TYPE_TXT, const._CLASS_IN)) + query.add_question(r.DNSQuestion(info.server, const._TYPE_A, const._CLASS_IN)) + question_answers = zc.query_handler.async_response( + [r.DNSIncoming(packet) for packet in query.packets()], False + ) + _process_outgoing_packet(construct_outgoing_multicast_answers(question_answers.mcast_aggregate)) + + # There will be one NSEC additional to indicate the lack of AAAA record + assert nbr_answers == 4 and nbr_additionals == 1 and nbr_authorities == 0 + nbr_answers = nbr_additionals = nbr_authorities = 0 + + # unregister + expected_ttl = 0 + zc.registry.async_remove(info) + for _ in range(3): + _process_outgoing_packet(zc.generate_service_broadcast(info, 0)) + assert nbr_answers == 12 and nbr_additionals == 0 and nbr_authorities == 0 + nbr_answers = nbr_additionals = nbr_authorities = 0 + zc.close() + + def test_name_conflicts(self): + # instantiate a zeroconf instance + zc = Zeroconf(interfaces=['127.0.0.1']) + type_ = "_homeassistant._tcp.local." + name = "Home" + registration_name = f"{name}.{type_}" + + info = ServiceInfo( + type_, + name=registration_name, + server="random123.local.", + addresses=[socket.inet_pton(socket.AF_INET, "1.2.3.4")], + port=80, + properties={"version": "1.0"}, + ) + zc.register_service(info) + + conflicting_info = ServiceInfo( + type_, + name=registration_name, + server="random456.local.", + addresses=[socket.inet_pton(socket.AF_INET, "4.5.6.7")], + port=80, + properties={"version": "1.0"}, + ) + with pytest.raises(r.NonUniqueNameException): + zc.register_service(conflicting_info) + zc.close() + + def test_register_and_lookup_type_by_uppercase_name(self): + # instantiate a zeroconf instance + zc = Zeroconf(interfaces=['127.0.0.1']) + type_ = "_mylowertype._tcp.local." + name = "Home" + registration_name = f"{name}.{type_}" + + info = ServiceInfo( + type_, + name=registration_name, + server="random123.local.", + addresses=[socket.inet_pton(socket.AF_INET, "1.2.3.4")], + port=80, + properties={"version": "1.0"}, + ) + zc.register_service(info) + _clear_cache(zc) + info = ServiceInfo(type_, registration_name) + info.load_from_cache(zc) + assert info.addresses == [] + + out = r.DNSOutgoing(const._FLAGS_QR_QUERY) + out.add_question(r.DNSQuestion(type_.upper(), const._TYPE_PTR, const._CLASS_IN)) + zc.send(out) + time.sleep(1) + info = ServiceInfo(type_, registration_name) + info.load_from_cache(zc) + assert info.addresses == [socket.inet_pton(socket.AF_INET, "1.2.3.4")] + assert info.properties == {b"version": b"1.0"} + zc.close() + + +def test_ptr_optimization(): + + # instantiate a zeroconf instance + zc = Zeroconf(interfaces=['127.0.0.1']) + + # service definition + type_ = "_test-srvc-type._tcp.local." + name = "xxxyyy" + registration_name = f"{name}.{type_}" + + desc = {'path': '/~paulsm/'} + info = ServiceInfo( + type_, registration_name, 80, 0, 0, desc, "ash-2.local.", addresses=[socket.inet_aton("10.0.1.2")] + ) + + # register + zc.register_service(info) + + # Verify we won't respond for 1s with the same multicast + query = r.DNSOutgoing(const._FLAGS_QR_QUERY) + query.add_question(r.DNSQuestion(info.type, const._TYPE_PTR, const._CLASS_IN)) + question_answers = zc.query_handler.async_response( + [r.DNSIncoming(packet) for packet in query.packets()], False + ) + assert not question_answers.ucast + assert not question_answers.mcast_now + assert not question_answers.mcast_aggregate + # Since we sent the PTR in the last second, they + # should end up in the delayed at least one second bucket + assert question_answers.mcast_aggregate_last_second + + # Clear the cache to allow responding again + _clear_cache(zc) + + # Verify we will now respond + query = r.DNSOutgoing(const._FLAGS_QR_QUERY) + query.add_question(r.DNSQuestion(info.type, const._TYPE_PTR, const._CLASS_IN)) + question_answers = zc.query_handler.async_response( + [r.DNSIncoming(packet) for packet in query.packets()], False + ) + assert not question_answers.ucast + assert not question_answers.mcast_now + assert not question_answers.mcast_aggregate_last_second + has_srv = has_txt = has_a = False + nbr_additionals = 0 + nbr_answers = len(question_answers.mcast_aggregate) + additionals = set().union(*question_answers.mcast_aggregate.values()) + for answer in additionals: + nbr_additionals += 1 + if answer.type == const._TYPE_SRV: + has_srv = True + elif answer.type == const._TYPE_TXT: + has_txt = True + elif answer.type == const._TYPE_A: + has_a = True + assert nbr_answers == 1 and nbr_additionals == 4 + # There will be one NSEC additional to indicate the lack of AAAA record + + assert has_srv and has_txt and has_a + + # unregister + zc.unregister_service(info) + zc.close() + + +@unittest.skipIf(not has_working_ipv6(), 'Requires IPv6') +@unittest.skipIf(os.environ.get('SKIP_IPV6'), 'IPv6 tests disabled') +def test_any_query_for_ptr(): + """Test that queries for ANY will return PTR records and the response is aggregated.""" + zc = Zeroconf(interfaces=['127.0.0.1']) + type_ = "_anyptr._tcp.local." + name = "knownname" + registration_name = f"{name}.{type_}" + desc = {'path': '/~paulsm/'} + server_name = "ash-2.local." + ipv6_address = socket.inet_pton(socket.AF_INET6, "2001:db8::1") + info = ServiceInfo(type_, registration_name, 80, 0, 0, desc, server_name, addresses=[ipv6_address]) + zc.registry.async_add(info) + + _clear_cache(zc) + generated = r.DNSOutgoing(const._FLAGS_QR_QUERY) + question = r.DNSQuestion(type_, const._TYPE_ANY, const._CLASS_IN) + generated.add_question(question) + packets = generated.packets() + question_answers = zc.query_handler.async_response([r.DNSIncoming(packet) for packet in packets], False) + mcast_answers = list(question_answers.mcast_aggregate) + assert mcast_answers[0].name == type_ + assert mcast_answers[0].alias == registration_name + # unregister + zc.registry.async_remove(info) + zc.close() + + +@unittest.skipIf(not has_working_ipv6(), 'Requires IPv6') +@unittest.skipIf(os.environ.get('SKIP_IPV6'), 'IPv6 tests disabled') +def test_aaaa_query(): + """Test that queries for AAAA records work and should respond right away.""" + zc = Zeroconf(interfaces=['127.0.0.1']) + type_ = "_knownaaaservice._tcp.local." + name = "knownname" + registration_name = f"{name}.{type_}" + desc = {'path': '/~paulsm/'} + server_name = "ash-2.local." + ipv6_address = socket.inet_pton(socket.AF_INET6, "2001:db8::1") + info = ServiceInfo(type_, registration_name, 80, 0, 0, desc, server_name, addresses=[ipv6_address]) + zc.registry.async_add(info) + + generated = r.DNSOutgoing(const._FLAGS_QR_QUERY) + question = r.DNSQuestion(server_name, const._TYPE_AAAA, const._CLASS_IN) + generated.add_question(question) + packets = generated.packets() + question_answers = zc.query_handler.async_response([r.DNSIncoming(packet) for packet in packets], False) + mcast_answers = list(question_answers.mcast_now) + assert mcast_answers[0].address == ipv6_address + # unregister + zc.registry.async_remove(info) + zc.close() + + +@unittest.skipIf(not has_working_ipv6(), 'Requires IPv6') +@unittest.skipIf(os.environ.get('SKIP_IPV6'), 'IPv6 tests disabled') +def test_a_and_aaaa_record_fate_sharing(): + """Test that queries for AAAA always return A records in the additionals and should respond right away.""" + zc = Zeroconf(interfaces=['127.0.0.1']) + type_ = "_a-and-aaaa-service._tcp.local." + name = "knownname" + registration_name = f"{name}.{type_}" + desc = {'path': '/~paulsm/'} + server_name = "ash-2.local." + ipv6_address = socket.inet_pton(socket.AF_INET6, "2001:db8::1") + ipv4_address = socket.inet_aton("10.0.1.2") + info = ServiceInfo( + type_, registration_name, 80, 0, 0, desc, server_name, addresses=[ipv6_address, ipv4_address] + ) + aaaa_record = info.dns_addresses(version=r.IPVersion.V6Only)[0] + a_record = info.dns_addresses(version=r.IPVersion.V4Only)[0] + + zc.registry.async_add(info) + + # Test AAAA query + generated = r.DNSOutgoing(const._FLAGS_QR_QUERY) + question = r.DNSQuestion(server_name, const._TYPE_AAAA, const._CLASS_IN) + generated.add_question(question) + packets = generated.packets() + question_answers = zc.query_handler.async_response([r.DNSIncoming(packet) for packet in packets], False) + additionals = set().union(*question_answers.mcast_now.values()) + assert aaaa_record in question_answers.mcast_now + assert a_record in additionals + assert len(question_answers.mcast_now) == 1 + assert len(additionals) == 1 + + # Test A query + generated = r.DNSOutgoing(const._FLAGS_QR_QUERY) + question = r.DNSQuestion(server_name, const._TYPE_A, const._CLASS_IN) + generated.add_question(question) + packets = generated.packets() + question_answers = zc.query_handler.async_response([r.DNSIncoming(packet) for packet in packets], False) + additionals = set().union(*question_answers.mcast_now.values()) + assert a_record in question_answers.mcast_now + assert aaaa_record in additionals + assert len(question_answers.mcast_now) == 1 + assert len(additionals) == 1 + + # unregister + zc.registry.async_remove(info) + zc.close() + + +def test_unicast_response(): + """Ensure we send a unicast response when the source port is not the MDNS port.""" + # instantiate a zeroconf instance + zc = Zeroconf(interfaces=['127.0.0.1']) + + # service definition + type_ = "_test-srvc-type._tcp.local." + name = "xxxyyy" + registration_name = f"{name}.{type_}" + desc = {'path': '/~paulsm/'} + info = ServiceInfo( + type_, registration_name, 80, 0, 0, desc, "ash-2.local.", addresses=[socket.inet_aton("10.0.1.2")] + ) + # register + zc.registry.async_add(info) + _clear_cache(zc) + + # query + query = r.DNSOutgoing(const._FLAGS_QR_QUERY) + query.add_question(r.DNSQuestion(info.type, const._TYPE_PTR, const._CLASS_IN)) + question_answers = zc.query_handler.async_response( + [r.DNSIncoming(packet) for packet in query.packets()], True + ) + for answers in (question_answers.ucast, question_answers.mcast_aggregate): + has_srv = has_txt = has_a = has_aaaa = has_nsec = False + nbr_additionals = 0 + nbr_answers = len(answers) + additionals = set().union(*answers.values()) + for answer in additionals: + nbr_additionals += 1 + if answer.type == const._TYPE_SRV: + has_srv = True + elif answer.type == const._TYPE_TXT: + has_txt = True + elif answer.type == const._TYPE_A: + has_a = True + elif answer.type == const._TYPE_AAAA: + has_aaaa = True + elif answer.type == const._TYPE_NSEC: + has_nsec = True + # There will be one NSEC additional to indicate the lack of AAAA record + assert nbr_answers == 1 and nbr_additionals == 4 + assert has_srv and has_txt and has_a and has_nsec + assert not has_aaaa + + # unregister + zc.registry.async_remove(info) + zc.close() + + +@pytest.mark.asyncio +async def test_probe_answered_immediately(): + """Verify probes are responded to immediately.""" + # instantiate a zeroconf instance + zc = Zeroconf(interfaces=['127.0.0.1']) + + # service definition + type_ = "_test-srvc-type._tcp.local." + name = "xxxyyy" + registration_name = f"{name}.{type_}" + desc = {'path': '/~paulsm/'} + info = ServiceInfo( + type_, registration_name, 80, 0, 0, desc, "ash-2.local.", addresses=[socket.inet_aton("10.0.1.2")] + ) + zc.registry.async_add(info) + query = r.DNSOutgoing(const._FLAGS_QR_QUERY) + question = r.DNSQuestion(info.type, const._TYPE_PTR, const._CLASS_IN) + query.add_question(question) + query.add_authorative_answer(info.dns_pointer()) + question_answers = zc.query_handler.async_response( + [r.DNSIncoming(packet) for packet in query.packets()], False + ) + assert not question_answers.ucast + assert not question_answers.mcast_aggregate + assert not question_answers.mcast_aggregate_last_second + assert question_answers.mcast_now + + query = r.DNSOutgoing(const._FLAGS_QR_QUERY) + question = r.DNSQuestion(info.type, const._TYPE_PTR, const._CLASS_IN) + question.unicast = True + query.add_question(question) + query.add_authorative_answer(info.dns_pointer()) + question_answers = zc.query_handler.async_response( + [r.DNSIncoming(packet) for packet in query.packets()], False + ) + assert question_answers.ucast + assert question_answers.mcast_now + assert not question_answers.mcast_aggregate + assert not question_answers.mcast_aggregate_last_second + zc.close() + + +def test_qu_response(): + """Handle multicast incoming with the QU bit set.""" + # instantiate a zeroconf instance + zc = Zeroconf(interfaces=['127.0.0.1']) + + # service definition + type_ = "_test-srvc-type._tcp.local." + other_type_ = "_notthesame._tcp.local." + name = "xxxyyy" + registration_name = f"{name}.{type_}" + registration_name2 = f"{name}.{other_type_}" + desc = {'path': '/~paulsm/'} + info = ServiceInfo( + type_, registration_name, 80, 0, 0, desc, "ash-2.local.", addresses=[socket.inet_aton("10.0.1.2")] + ) + info2 = ServiceInfo( + other_type_, + registration_name2, + 80, + 0, + 0, + desc, + "ash-other.local.", + addresses=[socket.inet_aton("10.0.4.2")], + ) + # register + zc.register_service(info) + + def _validate_complete_response(answers): + has_srv = has_txt = has_a = has_aaaa = has_nsec = False + nbr_answers = len(answers.keys()) + additionals = set().union(*answers.values()) + nbr_additionals = len(additionals) + + for answer in additionals: + if answer.type == const._TYPE_SRV: + has_srv = True + elif answer.type == const._TYPE_TXT: + has_txt = True + elif answer.type == const._TYPE_A: + has_a = True + elif answer.type == const._TYPE_AAAA: + has_aaaa = True + elif answer.type == const._TYPE_NSEC: + has_nsec = True + assert nbr_answers == 1 and nbr_additionals == 4 + assert has_srv and has_txt and has_a and has_nsec + assert not has_aaaa + + # With QU should respond to only unicast when the answer has been recently multicast + query = r.DNSOutgoing(const._FLAGS_QR_QUERY) + question = r.DNSQuestion(info.type, const._TYPE_PTR, const._CLASS_IN) + question.unicast = True # Set the QU bit + assert question.unicast is True + query.add_question(question) + + question_answers = zc.query_handler.async_response( + [r.DNSIncoming(packet) for packet in query.packets()], False + ) + _validate_complete_response(question_answers.ucast) + assert not question_answers.mcast_now + assert not question_answers.mcast_aggregate + assert not question_answers.mcast_aggregate_last_second + + _clear_cache(zc) + # With QU should respond to only multicast since the response hasn't been seen since 75% of the ttl + query = r.DNSOutgoing(const._FLAGS_QR_QUERY) + question = r.DNSQuestion(info.type, const._TYPE_PTR, const._CLASS_IN) + question.unicast = True # Set the QU bit + assert question.unicast is True + query.add_question(question) + question_answers = zc.query_handler.async_response( + [r.DNSIncoming(packet) for packet in query.packets()], False + ) + assert not question_answers.ucast + assert not question_answers.mcast_aggregate + assert not question_answers.mcast_aggregate + _validate_complete_response(question_answers.mcast_now) + + # With QU set and an authorative answer (probe) should respond to both unitcast and multicast since the response hasn't been seen since 75% of the ttl + query = r.DNSOutgoing(const._FLAGS_QR_QUERY) + question = r.DNSQuestion(info.type, const._TYPE_PTR, const._CLASS_IN) + question.unicast = True # Set the QU bit + assert question.unicast is True + query.add_question(question) + query.add_authorative_answer(info2.dns_pointer()) + question_answers = zc.query_handler.async_response( + [r.DNSIncoming(packet) for packet in query.packets()], False + ) + _validate_complete_response(question_answers.ucast) + _validate_complete_response(question_answers.mcast_now) + + _inject_response( + zc, r.DNSIncoming(construct_outgoing_multicast_answers(question_answers.mcast_now).packets()[0]) + ) + # With the cache repopulated; should respond to only unicast when the answer has been recently multicast + query = r.DNSOutgoing(const._FLAGS_QR_QUERY) + question = r.DNSQuestion(info.type, const._TYPE_PTR, const._CLASS_IN) + question.unicast = True # Set the QU bit + assert question.unicast is True + query.add_question(question) + question_answers = zc.query_handler.async_response( + [r.DNSIncoming(packet) for packet in query.packets()], False + ) + assert not question_answers.mcast_now + assert not question_answers.mcast_aggregate + assert not question_answers.mcast_aggregate_last_second + _validate_complete_response(question_answers.ucast) + # unregister + zc.unregister_service(info) + zc.close() + + +def test_known_answer_supression(): + zc = Zeroconf(interfaces=['127.0.0.1']) + type_ = "_knownanswersv8._tcp.local." + name = "knownname" + registration_name = f"{name}.{type_}" + desc = {'path': '/~paulsm/'} + server_name = "ash-2.local." + info = ServiceInfo( + type_, registration_name, 80, 0, 0, desc, server_name, addresses=[socket.inet_aton("10.0.1.2")] + ) + zc.registry.async_add(info) + + now = current_time_millis() + _clear_cache(zc) + # Test PTR supression + generated = r.DNSOutgoing(const._FLAGS_QR_QUERY) + question = r.DNSQuestion(type_, const._TYPE_PTR, const._CLASS_IN) + generated.add_question(question) + packets = generated.packets() + question_answers = zc.query_handler.async_response([r.DNSIncoming(packet) for packet in packets], False) + assert not question_answers.ucast + assert not question_answers.mcast_now + assert question_answers.mcast_aggregate + assert not question_answers.mcast_aggregate_last_second + + generated = r.DNSOutgoing(const._FLAGS_QR_QUERY) + question = r.DNSQuestion(type_, const._TYPE_PTR, const._CLASS_IN) + generated.add_question(question) + generated.add_answer_at_time(info.dns_pointer(), now) + packets = generated.packets() + question_answers = zc.query_handler.async_response([r.DNSIncoming(packet) for packet in packets], False) + assert not question_answers.ucast + assert not question_answers.mcast_now + assert not question_answers.mcast_aggregate + assert not question_answers.mcast_aggregate_last_second + + # Test A supression + generated = r.DNSOutgoing(const._FLAGS_QR_QUERY) + question = r.DNSQuestion(server_name, const._TYPE_A, const._CLASS_IN) + generated.add_question(question) + packets = generated.packets() + question_answers = zc.query_handler.async_response([r.DNSIncoming(packet) for packet in packets], False) + assert not question_answers.ucast + assert question_answers.mcast_now + assert not question_answers.mcast_aggregate + assert not question_answers.mcast_aggregate_last_second + + generated = r.DNSOutgoing(const._FLAGS_QR_QUERY) + question = r.DNSQuestion(server_name, const._TYPE_A, const._CLASS_IN) + generated.add_question(question) + for dns_address in info.dns_addresses(): + generated.add_answer_at_time(dns_address, now) + packets = generated.packets() + question_answers = zc.query_handler.async_response([r.DNSIncoming(packet) for packet in packets], False) + assert not question_answers.ucast + assert not question_answers.mcast_now + assert not question_answers.mcast_aggregate + assert not question_answers.mcast_aggregate_last_second + + # Test NSEC record returned when there is no AAAA record and we expectly ask + generated = r.DNSOutgoing(const._FLAGS_QR_QUERY) + question = r.DNSQuestion(server_name, const._TYPE_AAAA, const._CLASS_IN) + generated.add_question(question) + for dns_address in info.dns_addresses(): + generated.add_answer_at_time(dns_address, now) + packets = generated.packets() + question_answers = zc.query_handler.async_response([r.DNSIncoming(packet) for packet in packets], False) + assert not question_answers.ucast + expected_nsec_record: r.DNSNsec = list(question_answers.mcast_now)[0] + assert const._TYPE_A not in expected_nsec_record.rdtypes + assert const._TYPE_AAAA in expected_nsec_record.rdtypes + assert not question_answers.mcast_aggregate + assert not question_answers.mcast_aggregate_last_second + + # Test SRV supression + generated = r.DNSOutgoing(const._FLAGS_QR_QUERY) + question = r.DNSQuestion(registration_name, const._TYPE_SRV, const._CLASS_IN) + generated.add_question(question) + packets = generated.packets() + question_answers = zc.query_handler.async_response([r.DNSIncoming(packet) for packet in packets], False) + assert not question_answers.ucast + assert question_answers.mcast_now + assert not question_answers.mcast_aggregate + assert not question_answers.mcast_aggregate_last_second + + generated = r.DNSOutgoing(const._FLAGS_QR_QUERY) + question = r.DNSQuestion(registration_name, const._TYPE_SRV, const._CLASS_IN) + generated.add_question(question) + generated.add_answer_at_time(info.dns_service(), now) + packets = generated.packets() + question_answers = zc.query_handler.async_response([r.DNSIncoming(packet) for packet in packets], False) + assert not question_answers.ucast + assert not question_answers.mcast_now + assert not question_answers.mcast_aggregate + assert not question_answers.mcast_aggregate_last_second + + # Test TXT supression + generated = r.DNSOutgoing(const._FLAGS_QR_QUERY) + question = r.DNSQuestion(registration_name, const._TYPE_TXT, const._CLASS_IN) + generated.add_question(question) + packets = generated.packets() + question_answers = zc.query_handler.async_response([r.DNSIncoming(packet) for packet in packets], False) + assert not question_answers.ucast + assert not question_answers.mcast_now + assert question_answers.mcast_aggregate + assert not question_answers.mcast_aggregate_last_second + + generated = r.DNSOutgoing(const._FLAGS_QR_QUERY) + question = r.DNSQuestion(registration_name, const._TYPE_TXT, const._CLASS_IN) + generated.add_question(question) + generated.add_answer_at_time(info.dns_text(), now) + packets = generated.packets() + question_answers = zc.query_handler.async_response([r.DNSIncoming(packet) for packet in packets], False) + assert not question_answers.ucast + assert not question_answers.mcast_now + assert not question_answers.mcast_aggregate + assert not question_answers.mcast_aggregate_last_second + + # unregister + zc.registry.async_remove(info) + zc.close() + + +def test_multi_packet_known_answer_supression(): + zc = Zeroconf(interfaces=['127.0.0.1']) + type_ = "_handlermultis._tcp.local." + name = "knownname" + name2 = "knownname2" + name3 = "knownname3" + + registration_name = f"{name}.{type_}" + registration2_name = f"{name2}.{type_}" + registration3_name = f"{name3}.{type_}" + + desc = {'path': '/~paulsm/'} + server_name = "ash-2.local." + server_name2 = "ash-3.local." + server_name3 = "ash-4.local." + + info = ServiceInfo( + type_, registration_name, 80, 0, 0, desc, server_name, addresses=[socket.inet_aton("10.0.1.2")] + ) + info2 = ServiceInfo( + type_, registration2_name, 80, 0, 0, desc, server_name2, addresses=[socket.inet_aton("10.0.1.2")] + ) + info3 = ServiceInfo( + type_, registration3_name, 80, 0, 0, desc, server_name3, addresses=[socket.inet_aton("10.0.1.2")] + ) + zc.registry.async_add(info) + zc.registry.async_add(info2) + zc.registry.async_add(info3) + + now = current_time_millis() + _clear_cache(zc) + # Test PTR supression + generated = r.DNSOutgoing(const._FLAGS_QR_QUERY) + question = r.DNSQuestion(type_, const._TYPE_PTR, const._CLASS_IN) + generated.add_question(question) + for _ in range(1000): + # Add so many answers we end up with another packet + generated.add_answer_at_time(info.dns_pointer(), now) + generated.add_answer_at_time(info2.dns_pointer(), now) + generated.add_answer_at_time(info3.dns_pointer(), now) + packets = generated.packets() + assert len(packets) > 1 + question_answers = zc.query_handler.async_response([r.DNSIncoming(packet) for packet in packets], False) + assert not question_answers.ucast + assert not question_answers.mcast_now + assert not question_answers.mcast_aggregate + assert not question_answers.mcast_aggregate_last_second + # unregister + zc.registry.async_remove(info) + zc.registry.async_remove(info2) + zc.registry.async_remove(info3) + zc.close() + + +def test_known_answer_supression_service_type_enumeration_query(): + zc = Zeroconf(interfaces=['127.0.0.1']) + type_ = "_otherknown._tcp.local." + name = "knownname" + registration_name = f"{name}.{type_}" + desc = {'path': '/~paulsm/'} + server_name = "ash-2.local." + info = ServiceInfo( + type_, registration_name, 80, 0, 0, desc, server_name, addresses=[socket.inet_aton("10.0.1.2")] + ) + zc.registry.async_add(info) + + type_2 = "_otherknown2._tcp.local." + name = "knownname" + registration_name2 = f"{name}.{type_2}" + desc = {'path': '/~paulsm/'} + server_name2 = "ash-3.local." + info2 = ServiceInfo( + type_2, registration_name2, 80, 0, 0, desc, server_name2, addresses=[socket.inet_aton("10.0.1.2")] + ) + zc.registry.async_add(info2) + now = current_time_millis() + _clear_cache(zc) + + # Test PTR supression + generated = r.DNSOutgoing(const._FLAGS_QR_QUERY) + question = r.DNSQuestion(const._SERVICE_TYPE_ENUMERATION_NAME, const._TYPE_PTR, const._CLASS_IN) + generated.add_question(question) + packets = generated.packets() + question_answers = zc.query_handler.async_response([r.DNSIncoming(packet) for packet in packets], False) + assert not question_answers.ucast + assert not question_answers.mcast_now + assert question_answers.mcast_aggregate + assert not question_answers.mcast_aggregate_last_second + + generated = r.DNSOutgoing(const._FLAGS_QR_QUERY) + question = r.DNSQuestion(const._SERVICE_TYPE_ENUMERATION_NAME, const._TYPE_PTR, const._CLASS_IN) + generated.add_question(question) + generated.add_answer_at_time( + r.DNSPointer( + const._SERVICE_TYPE_ENUMERATION_NAME, + const._TYPE_PTR, + const._CLASS_IN, + const._DNS_OTHER_TTL, + type_, + ), + now, + ) + generated.add_answer_at_time( + r.DNSPointer( + const._SERVICE_TYPE_ENUMERATION_NAME, + const._TYPE_PTR, + const._CLASS_IN, + const._DNS_OTHER_TTL, + type_2, + ), + now, + ) + packets = generated.packets() + question_answers = zc.query_handler.async_response([r.DNSIncoming(packet) for packet in packets], False) + assert not question_answers.ucast + assert not question_answers.mcast_now + assert not question_answers.mcast_aggregate + assert not question_answers.mcast_aggregate_last_second + + # unregister + zc.registry.async_remove(info) + zc.registry.async_remove(info2) + zc.close() + + +# This test uses asyncio because it needs to access the cache directly +# which is not threadsafe +@pytest.mark.asyncio +async def test_qu_response_only_sends_additionals_if_sends_answer(): + """Test that a QU response does not send additionals unless it sends the answer as well.""" + # instantiate a zeroconf instance + aiozc = AsyncZeroconf(interfaces=['127.0.0.1']) + zc = aiozc.zeroconf + + type_ = "_addtest1._tcp.local." + name = "knownname" + registration_name = f"{name}.{type_}" + desc = {'path': '/~paulsm/'} + server_name = "ash-2.local." + info = ServiceInfo( + type_, registration_name, 80, 0, 0, desc, server_name, addresses=[socket.inet_aton("10.0.1.2")] + ) + zc.registry.async_add(info) + + type_2 = "_addtest2._tcp.local." + name = "knownname" + registration_name2 = f"{name}.{type_2}" + desc = {'path': '/~paulsm/'} + server_name2 = "ash-3.local." + info2 = ServiceInfo( + type_2, registration_name2, 80, 0, 0, desc, server_name2, addresses=[socket.inet_aton("10.0.1.2")] + ) + zc.registry.async_add(info2) + + ptr_record = info.dns_pointer() + + # Add the PTR record to the cache + zc.cache.async_add_records([ptr_record]) + + # Add the A record to the cache with 50% ttl remaining + a_record = info.dns_addresses()[0] + a_record.set_created_ttl(current_time_millis() - (a_record.ttl * 1000 / 2), a_record.ttl) + assert not a_record.is_recent(current_time_millis()) + zc.cache.async_add_records([a_record]) + + # With QU should respond to only unicast when the answer has been recently multicast + # even if the additional has not been recently multicast + query = r.DNSOutgoing(const._FLAGS_QR_QUERY) + question = r.DNSQuestion(info.type, const._TYPE_PTR, const._CLASS_IN) + question.unicast = True # Set the QU bit + assert question.unicast is True + query.add_question(question) + + question_answers = zc.query_handler.async_response( + [r.DNSIncoming(packet) for packet in query.packets()], False + ) + assert not question_answers.mcast_now + assert not question_answers.mcast_aggregate + assert not question_answers.mcast_aggregate_last_second + + additionals = set().union(*question_answers.ucast.values()) + assert a_record in additionals + assert ptr_record in question_answers.ucast + + # Remove the 50% A record and add a 100% A record + zc.cache.async_remove_records([a_record]) + a_record = info.dns_addresses()[0] + assert a_record.is_recent(current_time_millis()) + zc.cache.async_add_records([a_record]) + # With QU should respond to only unicast when the answer has been recently multicast + # even if the additional has not been recently multicast + query = r.DNSOutgoing(const._FLAGS_QR_QUERY) + question = r.DNSQuestion(info.type, const._TYPE_PTR, const._CLASS_IN) + question.unicast = True # Set the QU bit + assert question.unicast is True + query.add_question(question) + + question_answers = zc.query_handler.async_response( + [r.DNSIncoming(packet) for packet in query.packets()], False + ) + assert not question_answers.mcast_now + assert not question_answers.mcast_aggregate + assert not question_answers.mcast_aggregate_last_second + additionals = set().union(*question_answers.ucast.values()) + assert a_record in additionals + assert ptr_record in question_answers.ucast + + # Remove the 100% PTR record and add a 50% PTR record + zc.cache.async_remove_records([ptr_record]) + ptr_record.set_created_ttl(current_time_millis() - (ptr_record.ttl * 1000 / 2), ptr_record.ttl) + assert not ptr_record.is_recent(current_time_millis()) + zc.cache.async_add_records([ptr_record]) + # With QU should respond to only multicast since the has less + # than 75% of its ttl remaining + query = r.DNSOutgoing(const._FLAGS_QR_QUERY) + question = r.DNSQuestion(info.type, const._TYPE_PTR, const._CLASS_IN) + question.unicast = True # Set the QU bit + assert question.unicast is True + query.add_question(question) + + question_answers = zc.query_handler.async_response( + [r.DNSIncoming(packet) for packet in query.packets()], False + ) + assert not question_answers.ucast + assert not question_answers.mcast_aggregate + assert not question_answers.mcast_aggregate_last_second + additionals = set().union(*question_answers.mcast_now.values()) + assert a_record in additionals + assert info.dns_text() in additionals + assert info.dns_service() in additionals + assert ptr_record in question_answers.mcast_now + + # Ask 2 QU questions, with info the PTR is at 50%, with info2 the PTR is at 100% + # We should get back a unicast reply for info2, but info should be multicasted since its within 75% of its TTL + # With QU should respond to only multicast since the has less + # than 75% of its ttl remaining + query = r.DNSOutgoing(const._FLAGS_QR_QUERY) + question = r.DNSQuestion(info.type, const._TYPE_PTR, const._CLASS_IN) + question.unicast = True # Set the QU bit + assert question.unicast is True + query.add_question(question) + + question = r.DNSQuestion(info2.type, const._TYPE_PTR, const._CLASS_IN) + question.unicast = True # Set the QU bit + assert question.unicast is True + query.add_question(question) + zc.cache.async_add_records([info2.dns_pointer()]) # Add 100% TTL for info2 to the cache + + question_answers = zc.query_handler.async_response( + [r.DNSIncoming(packet) for packet in query.packets()], False + ) + assert not question_answers.mcast_aggregate + assert not question_answers.mcast_aggregate_last_second + + mcast_now_additionals = set().union(*question_answers.mcast_now.values()) + assert a_record in mcast_now_additionals + assert info.dns_text() in mcast_now_additionals + assert info.dns_addresses()[0] in mcast_now_additionals + assert info.dns_pointer() in question_answers.mcast_now + + ucast_additionals = set().union(*question_answers.ucast.values()) + assert info2.dns_pointer() in question_answers.ucast + assert info2.dns_text() in ucast_additionals + assert info2.dns_service() in ucast_additionals + assert info2.dns_addresses()[0] in ucast_additionals + + # unregister + zc.registry.async_remove(info) + await aiozc.async_close() + + +# This test uses asyncio because it needs to access the cache directly +# which is not threadsafe +@pytest.mark.asyncio +async def test_cache_flush_bit(): + """Test that the cache flush bit sets the TTL to one for matching records.""" + # instantiate a zeroconf instance + aiozc = AsyncZeroconf(interfaces=['127.0.0.1']) + zc = aiozc.zeroconf + + type_ = "_cacheflush._tcp.local." + name = "knownname" + registration_name = f"{name}.{type_}" + desc = {'path': '/~paulsm/'} + server_name = "server-uu1.local." + info = ServiceInfo( + type_, registration_name, 80, 0, 0, desc, server_name, addresses=[socket.inet_aton("10.0.1.2")] + ) + a_record = info.dns_addresses()[0] + zc.cache.async_add_records([info.dns_pointer(), a_record, info.dns_text(), info.dns_service()]) + + info.addresses = [socket.inet_aton("10.0.1.5"), socket.inet_aton("10.0.1.6")] + new_records = info.dns_addresses() + for new_record in new_records: + assert new_record.unique is True + + original_a_record = zc.cache.async_get_unique(a_record) + # Do the run within 1s to verify the original record is not going to be expired + out = r.DNSOutgoing(const._FLAGS_QR_RESPONSE | const._FLAGS_AA, multicast=True) + for answer in new_records: + out.add_answer_at_time(answer, 0) + for packet in out.packets(): + zc.record_manager.async_updates_from_response(r.DNSIncoming(packet)) + assert zc.cache.async_get_unique(a_record) is original_a_record + assert original_a_record.ttl != 1 + for record in new_records: + assert zc.cache.async_get_unique(record) is not None + + original_a_record.created = current_time_millis() - 1001 + + # Do the run within 1s to verify the original record is not going to be expired + out = r.DNSOutgoing(const._FLAGS_QR_RESPONSE | const._FLAGS_AA, multicast=True) + for answer in new_records: + out.add_answer_at_time(answer, 0) + for packet in out.packets(): + zc.record_manager.async_updates_from_response(r.DNSIncoming(packet)) + assert original_a_record.ttl == 1 + for record in new_records: + assert zc.cache.async_get_unique(record) is not None + + cached_records = [zc.cache.async_get_unique(record) for record in new_records] + for record in cached_records: + record.created = current_time_millis() - 1001 + + fresh_address = socket.inet_aton("4.4.4.4") + info.addresses = [fresh_address] + # Do the run within 1s to verify the two new records get marked as expired + out = r.DNSOutgoing(const._FLAGS_QR_RESPONSE | const._FLAGS_AA, multicast=True) + for answer in info.dns_addresses(): + out.add_answer_at_time(answer, 0) + for packet in out.packets(): + zc.record_manager.async_updates_from_response(r.DNSIncoming(packet)) + for record in cached_records: + assert record.ttl == 1 + + for entry in zc.cache.async_all_by_details(server_name, const._TYPE_A, const._CLASS_IN): + if entry.address == fresh_address: + assert entry.ttl > 1 + else: + assert entry.ttl == 1 + + # Wait for the ttl 1 records to expire + await asyncio.sleep(1.1) + + loaded_info = r.ServiceInfo(type_, registration_name) + loaded_info.load_from_cache(zc) + assert loaded_info.addresses == info.addresses + + await aiozc.async_close() + + +# This test uses asyncio because it needs to access the cache directly +# which is not threadsafe +@pytest.mark.asyncio +async def test_record_update_manager_add_listener_callsback_existing_records(): + """Test that the RecordUpdateManager will callback existing records.""" + + aiozc = AsyncZeroconf(interfaces=['127.0.0.1']) + zc: Zeroconf = aiozc.zeroconf + updated = [] + + class MyListener(r.RecordUpdateListener): + """A RecordUpdateListener that does not implement update_records.""" + + def async_update_records(self, zc: 'Zeroconf', now: float, records: List[r.RecordUpdate]) -> None: + """Update multiple records in one shot.""" + updated.extend(records) + + type_ = "_cacheflush._tcp.local." + name = "knownname" + registration_name = f"{name}.{type_}" + desc = {'path': '/~paulsm/'} + server_name = "server-uu1.local." + info = ServiceInfo( + type_, registration_name, 80, 0, 0, desc, server_name, addresses=[socket.inet_aton("10.0.1.2")] + ) + a_record = info.dns_addresses()[0] + ptr_record = info.dns_pointer() + zc.cache.async_add_records([ptr_record, a_record, info.dns_text(), info.dns_service()]) + + listener = MyListener() + + zc.add_listener( + listener, + [ + r.DNSQuestion(type_, const._TYPE_PTR, const._CLASS_IN), + r.DNSQuestion(server_name, const._TYPE_A, const._CLASS_IN), + ], + ) + await asyncio.sleep(0) # flush out the call_soon_threadsafe + + assert {record.new for record in updated} == {ptr_record, a_record} + + # The old records should be None so we trigger Add events + # in service browsers instead of Update events + assert {record.old for record in updated} == {None} + + await aiozc.async_close() + + +@pytest.mark.asyncio +async def test_questions_query_handler_populates_the_question_history_from_qm_questions(): + aiozc = AsyncZeroconf(interfaces=['127.0.0.1']) + zc = aiozc.zeroconf + now = current_time_millis() + _clear_cache(zc) + + generated = r.DNSOutgoing(const._FLAGS_QR_QUERY) + question = r.DNSQuestion("_hap._tcp._local.", const._TYPE_PTR, const._CLASS_IN) + question.unicast = False + known_answer = r.DNSPointer( + "_hap._tcp.local.", const._TYPE_PTR, const._CLASS_IN, 10000, 'known-to-other._hap._tcp.local.' + ) + generated.add_question(question) + generated.add_answer_at_time(known_answer, 0) + now = r.current_time_millis() + packets = generated.packets() + question_answers = zc.query_handler.async_response([r.DNSIncoming(packet) for packet in packets], False) + assert not question_answers.ucast + assert not question_answers.mcast_now + assert not question_answers.mcast_aggregate + assert not question_answers.mcast_aggregate_last_second + assert zc.question_history.suppresses(question, now, {known_answer}) + + await aiozc.async_close() + + +@pytest.mark.asyncio +async def test_questions_query_handler_does_not_put_qu_questions_in_history(): + aiozc = AsyncZeroconf(interfaces=['127.0.0.1']) + zc = aiozc.zeroconf + now = current_time_millis() + _clear_cache(zc) + + generated = r.DNSOutgoing(const._FLAGS_QR_QUERY) + question = r.DNSQuestion("_hap._tcp._local.", const._TYPE_PTR, const._CLASS_IN) + question.unicast = True + known_answer = r.DNSPointer( + "_hap._tcp.local.", const._TYPE_PTR, const._CLASS_IN, 10000, 'known-to-other._hap._tcp.local.' + ) + generated.add_question(question) + generated.add_answer_at_time(known_answer, 0) + now = r.current_time_millis() + packets = generated.packets() + question_answers = zc.query_handler.async_response([r.DNSIncoming(packet) for packet in packets], False) + assert not question_answers.ucast + assert not question_answers.mcast_now + assert not question_answers.mcast_aggregate + assert not question_answers.mcast_aggregate_last_second + assert not zc.question_history.suppresses(question, now, {known_answer}) + + await aiozc.async_close() + + +@pytest.mark.asyncio +async def test_guard_against_low_ptr_ttl(): + """Ensure we enforce a minimum for PTR record ttls to avoid excessive refresh queries from ServiceBrowsers. + + Some poorly designed IoT devices can set excessively low PTR + TTLs would will cause ServiceBrowsers to flood the network + with excessive refresh queries. + """ + aiozc = AsyncZeroconf(interfaces=['127.0.0.1']) + zc = aiozc.zeroconf + # Apple uses a 15s minimum TTL, however we do not have the same + # level of rate limit and safe guards so we use 1/4 of the recommended value + answer_with_low_ttl = r.DNSPointer( + "myservicelow_tcp._tcp.local.", + const._TYPE_PTR, + const._CLASS_IN | const._CLASS_UNIQUE, + 2, + 'low.local.', + ) + answer_with_normal_ttl = r.DNSPointer( + "myservicelow_tcp._tcp.local.", + const._TYPE_PTR, + const._CLASS_IN | const._CLASS_UNIQUE, + const._DNS_OTHER_TTL, + 'normal.local.', + ) + good_bye_answer = r.DNSPointer( + "myservicelow_tcp._tcp.local.", + const._TYPE_PTR, + const._CLASS_IN | const._CLASS_UNIQUE, + 0, + 'goodbye.local.', + ) + # TTL should be adjusted to a safe value + response = r.DNSOutgoing(const._FLAGS_QR_RESPONSE) + response.add_answer_at_time(answer_with_low_ttl, 0) + response.add_answer_at_time(answer_with_normal_ttl, 0) + response.add_answer_at_time(good_bye_answer, 0) + incoming = r.DNSIncoming(response.packets()[0]) + zc.record_manager.async_updates_from_response(incoming) + + incoming_answer_low = zc.cache.async_get_unique(answer_with_low_ttl) + assert incoming_answer_low.ttl == const._DNS_PTR_MIN_TTL + incoming_answer_normal = zc.cache.async_get_unique(answer_with_normal_ttl) + assert incoming_answer_normal.ttl == const._DNS_OTHER_TTL + assert zc.cache.async_get_unique(good_bye_answer) is None + await aiozc.async_close() + + +@pytest.mark.asyncio +async def test_duplicate_goodbye_answers_in_packet(): + """Ensure we do not throw an exception when there are duplicate goodbye records in a packet.""" + aiozc = AsyncZeroconf(interfaces=['127.0.0.1']) + zc = aiozc.zeroconf + answer_with_normal_ttl = r.DNSPointer( + "myservicelow_tcp._tcp.local.", + const._TYPE_PTR, + const._CLASS_IN | const._CLASS_UNIQUE, + const._DNS_OTHER_TTL, + 'host.local.', + ) + good_bye_answer = r.DNSPointer( + "myservicelow_tcp._tcp.local.", + const._TYPE_PTR, + const._CLASS_IN | const._CLASS_UNIQUE, + 0, + 'host.local.', + ) + response = r.DNSOutgoing(const._FLAGS_QR_RESPONSE) + response.add_answer_at_time(answer_with_normal_ttl, 0) + incoming = r.DNSIncoming(response.packets()[0]) + zc.record_manager.async_updates_from_response(incoming) + + response = r.DNSOutgoing(const._FLAGS_QR_RESPONSE) + response.add_answer_at_time(good_bye_answer, 0) + response.add_answer_at_time(good_bye_answer, 0) + incoming = r.DNSIncoming(response.packets()[0]) + zc.record_manager.async_updates_from_response(incoming) + await aiozc.async_close() + + +@pytest.mark.asyncio +async def test_response_aggregation_timings(run_isolated): + """Verify multicast respones are aggregated.""" + type_ = "_mservice._tcp.local." + type_2 = "_mservice2._tcp.local." + type_3 = "_mservice3._tcp.local." + + aiozc = AsyncZeroconf(interfaces=['127.0.0.1']) + await aiozc.zeroconf.async_wait_for_start() + + name = "xxxyyy" + registration_name = f"{name}.{type_}" + registration_name2 = f"{name}.{type_2}" + registration_name3 = f"{name}.{type_3}" + + desc = {'path': '/~paulsm/'} + info = ServiceInfo( + type_, registration_name, 80, 0, 0, desc, "ash-2.local.", addresses=[socket.inet_aton("10.0.1.2")] + ) + info2 = ServiceInfo( + type_2, registration_name2, 80, 0, 0, desc, "ash-4.local.", addresses=[socket.inet_aton("10.0.1.3")] + ) + info3 = ServiceInfo( + type_3, registration_name3, 80, 0, 0, desc, "ash-4.local.", addresses=[socket.inet_aton("10.0.1.3")] + ) + aiozc.zeroconf.registry.async_add(info) + aiozc.zeroconf.registry.async_add(info2) + aiozc.zeroconf.registry.async_add(info3) + + query = r.DNSOutgoing(const._FLAGS_QR_QUERY, multicast=True) + question = r.DNSQuestion(info.type, const._TYPE_PTR, const._CLASS_IN) + query.add_question(question) + + query2 = r.DNSOutgoing(const._FLAGS_QR_QUERY, multicast=True) + question2 = r.DNSQuestion(info2.type, const._TYPE_PTR, const._CLASS_IN) + query2.add_question(question2) + + query3 = r.DNSOutgoing(const._FLAGS_QR_QUERY, multicast=True) + question3 = r.DNSQuestion(info3.type, const._TYPE_PTR, const._CLASS_IN) + query3.add_question(question3) + + query4 = r.DNSOutgoing(const._FLAGS_QR_QUERY, multicast=True) + query4.add_question(question) + query4.add_question(question2) + + zc = aiozc.zeroconf + protocol = zc.engine.protocols[0] + + with unittest.mock.patch.object(aiozc.zeroconf, "async_send") as send_mock: + protocol.datagram_received(query.packets()[0], ('127.0.0.1', const._MDNS_PORT)) + protocol.datagram_received(query2.packets()[0], ('127.0.0.1', const._MDNS_PORT)) + protocol.datagram_received(query.packets()[0], ('127.0.0.1', const._MDNS_PORT)) + await asyncio.sleep(0.7) + + # Should aggregate into a single answer with up to a 500ms + 120ms delay + calls = send_mock.mock_calls + assert len(calls) == 1 + outgoing = send_mock.call_args[0][0] + incoming = r.DNSIncoming(outgoing.packets()[0]) + zc.handle_response(incoming) + assert info.dns_pointer() in incoming.answers + assert info2.dns_pointer() in incoming.answers + send_mock.reset_mock() + + protocol.datagram_received(query3.packets()[0], ('127.0.0.1', const._MDNS_PORT)) + await asyncio.sleep(0.3) + + # Should send within 120ms since there are no other + # answers to aggregate with + calls = send_mock.mock_calls + assert len(calls) == 1 + outgoing = send_mock.call_args[0][0] + incoming = r.DNSIncoming(outgoing.packets()[0]) + zc.handle_response(incoming) + assert info3.dns_pointer() in incoming.answers + send_mock.reset_mock() + + # Because the response was sent in the last second we need to make + # sure the next answer is delayed at least a second + aiozc.zeroconf.engine.protocols[0].datagram_received( + query4.packets()[0], ('127.0.0.1', const._MDNS_PORT) + ) + await asyncio.sleep(0.5) + + # After 0.5 seconds it should not have been sent + # Protect the network against excessive packet flooding + # https://datatracker.ietf.org/doc/html/rfc6762#section-14 + calls = send_mock.mock_calls + assert len(calls) == 0 + send_mock.reset_mock() + + await asyncio.sleep(1.2) + calls = send_mock.mock_calls + assert len(calls) == 1 + outgoing = send_mock.call_args[0][0] + incoming = r.DNSIncoming(outgoing.packets()[0]) + assert info.dns_pointer() in incoming.answers + + await aiozc.async_close() + + +@pytest.mark.asyncio +async def test_response_aggregation_timings_multiple(run_isolated): + """Verify multicast responses that are aggregated do not take longer than 620ms to send. + + 620ms is the maximum random delay of 120ms and 500ms additional for aggregation.""" + type_2 = "_mservice2._tcp.local." + + aiozc = AsyncZeroconf(interfaces=['127.0.0.1']) + await aiozc.zeroconf.async_wait_for_start() + + name = "xxxyyy" + registration_name2 = f"{name}.{type_2}" + + desc = {'path': '/~paulsm/'} + info2 = ServiceInfo( + type_2, registration_name2, 80, 0, 0, desc, "ash-4.local.", addresses=[socket.inet_aton("10.0.1.3")] + ) + aiozc.zeroconf.registry.async_add(info2) + + query2 = r.DNSOutgoing(const._FLAGS_QR_QUERY, multicast=True) + question2 = r.DNSQuestion(info2.type, const._TYPE_PTR, const._CLASS_IN) + query2.add_question(question2) + + zc = aiozc.zeroconf + protocol = zc.engine.protocols[0] + + with unittest.mock.patch.object(aiozc.zeroconf, "async_send") as send_mock, unittest.mock.patch.object( + protocol, "suppress_duplicate_packet", return_value=False + ): + send_mock.reset_mock() + protocol.datagram_received(query2.packets()[0], ('127.0.0.1', const._MDNS_PORT)) + await asyncio.sleep(0.2) + calls = send_mock.mock_calls + assert len(calls) == 1 + outgoing = send_mock.call_args[0][0] + incoming = r.DNSIncoming(outgoing.packets()[0]) + zc.handle_response(incoming) + assert info2.dns_pointer() in incoming.answers + + send_mock.reset_mock() + protocol.datagram_received(query2.packets()[0], ('127.0.0.1', const._MDNS_PORT)) + await asyncio.sleep(1.2) + calls = send_mock.mock_calls + assert len(calls) == 1 + outgoing = send_mock.call_args[0][0] + incoming = r.DNSIncoming(outgoing.packets()[0]) + zc.handle_response(incoming) + assert info2.dns_pointer() in incoming.answers + + send_mock.reset_mock() + protocol.datagram_received(query2.packets()[0], ('127.0.0.1', const._MDNS_PORT)) + protocol.datagram_received(query2.packets()[0], ('127.0.0.1', const._MDNS_PORT)) + # The delay should increase with two packets and + # 900ms is beyond the maximum aggregation delay + # when there is no network protection delay + await asyncio.sleep(0.9) + calls = send_mock.mock_calls + assert len(calls) == 0 + + # 1000ms (1s network protection delays) + # - 900ms (already slept) + # + 120ms (maximum random delay) + # + 200ms (maximum protected aggregation delay) + # + 20ms (execution time) + await asyncio.sleep(millis_to_seconds(1000 - 900 + 120 + 200 + 20)) + calls = send_mock.mock_calls + assert len(calls) == 1 + outgoing = send_mock.call_args[0][0] + incoming = r.DNSIncoming(outgoing.packets()[0]) + zc.handle_response(incoming) + assert info2.dns_pointer() in incoming.answers + + +@pytest.mark.asyncio +async def test_response_aggregation_random_delay(): + """Verify the random delay for outgoing multicast will coalesce into a single group + + When the random delay is shorter than the last outgoing group, + the groups should be combined. + """ + type_ = "_mservice._tcp.local." + type_2 = "_mservice2._tcp.local." + type_3 = "_mservice3._tcp.local." + type_4 = "_mservice4._tcp.local." + type_5 = "_mservice5._tcp.local." + + name = "xxxyyy" + registration_name = f"{name}.{type_}" + registration_name2 = f"{name}.{type_2}" + registration_name3 = f"{name}.{type_3}" + registration_name4 = f"{name}.{type_4}" + registration_name5 = f"{name}.{type_5}" + + desc = {'path': '/~paulsm/'} + info = ServiceInfo( + type_, registration_name, 80, 0, 0, desc, "ash-1.local.", addresses=[socket.inet_aton("10.0.1.2")] + ) + info2 = ServiceInfo( + type_2, registration_name2, 80, 0, 0, desc, "ash-2.local.", addresses=[socket.inet_aton("10.0.1.3")] + ) + info3 = ServiceInfo( + type_3, registration_name3, 80, 0, 0, desc, "ash-3.local.", addresses=[socket.inet_aton("10.0.1.2")] + ) + info4 = ServiceInfo( + type_4, registration_name4, 80, 0, 0, desc, "ash-4.local.", addresses=[socket.inet_aton("10.0.1.2")] + ) + info5 = ServiceInfo( + type_5, registration_name5, 80, 0, 0, desc, "ash-5.local.", addresses=[socket.inet_aton("10.0.1.2")] + ) + mocked_zc = unittest.mock.MagicMock() + outgoing_queue = MulticastOutgoingQueue(mocked_zc, 0, 500) + + now = current_time_millis() + with unittest.mock.patch.object(_handlers, "_MULTICAST_DELAY_RANDOM_INTERVAL", (500, 600)): + outgoing_queue.async_add(now, {info.dns_pointer(): set()}) + + # The second group should always be coalesced into first group since it will always come before + with unittest.mock.patch.object(_handlers, "_MULTICAST_DELAY_RANDOM_INTERVAL", (300, 400)): + outgoing_queue.async_add(now, {info2.dns_pointer(): set()}) + + # The third group should always be coalesced into first group since it will always come before + with unittest.mock.patch.object(_handlers, "_MULTICAST_DELAY_RANDOM_INTERVAL", (100, 200)): + outgoing_queue.async_add(now, {info3.dns_pointer(): set(), info4.dns_pointer(): set()}) + + assert len(outgoing_queue.queue) == 1 + assert info.dns_pointer() in outgoing_queue.queue[0].answers + assert info2.dns_pointer() in outgoing_queue.queue[0].answers + assert info3.dns_pointer() in outgoing_queue.queue[0].answers + assert info4.dns_pointer() in outgoing_queue.queue[0].answers + + # The forth group should not be coalesced because its scheduled after the last group in the queue + with unittest.mock.patch.object(_handlers, "_MULTICAST_DELAY_RANDOM_INTERVAL", (700, 800)): + outgoing_queue.async_add(now, {info5.dns_pointer(): set()}) + + assert len(outgoing_queue.queue) == 2 + assert info.dns_pointer() not in outgoing_queue.queue[1].answers + assert info2.dns_pointer() not in outgoing_queue.queue[1].answers + assert info3.dns_pointer() not in outgoing_queue.queue[1].answers + assert info4.dns_pointer() not in outgoing_queue.queue[1].answers + assert info5.dns_pointer() in outgoing_queue.queue[1].answers + + +@pytest.mark.asyncio +async def test_future_answers_are_removed_on_send(): + """Verify any future answers scheduled to be sent are removed when we send.""" + type_ = "_mservice._tcp.local." + type_2 = "_mservice2._tcp.local." + name = "xxxyyy" + registration_name = f"{name}.{type_}" + registration_name2 = f"{name}.{type_2}" + + desc = {'path': '/~paulsm/'} + info = ServiceInfo( + type_, registration_name, 80, 0, 0, desc, "ash-1.local.", addresses=[socket.inet_aton("10.0.1.2")] + ) + info2 = ServiceInfo( + type_2, registration_name2, 80, 0, 0, desc, "ash-2.local.", addresses=[socket.inet_aton("10.0.1.3")] + ) + mocked_zc = unittest.mock.MagicMock() + outgoing_queue = MulticastOutgoingQueue(mocked_zc, 0, 0) + + now = current_time_millis() + with unittest.mock.patch.object(_handlers, "_MULTICAST_DELAY_RANDOM_INTERVAL", (1, 1)): + outgoing_queue.async_add(now, {info.dns_pointer(): set()}) + + assert len(outgoing_queue.queue) == 1 + + with unittest.mock.patch.object(_handlers, "_MULTICAST_DELAY_RANDOM_INTERVAL", (2, 2)): + outgoing_queue.async_add(now, {info.dns_pointer(): set()}) + + assert len(outgoing_queue.queue) == 2 + + with unittest.mock.patch.object(_handlers, "_MULTICAST_DELAY_RANDOM_INTERVAL", (1000, 1000)): + outgoing_queue.async_add(now, {info2.dns_pointer(): set()}) + outgoing_queue.async_add(now, {info.dns_pointer(): set()}) + + assert len(outgoing_queue.queue) == 3 + + await asyncio.sleep(0.1) + outgoing_queue.async_ready() + + assert len(outgoing_queue.queue) == 1 + # The answer should get removed because we just sent it + assert info.dns_pointer() not in outgoing_queue.queue[0].answers + + # But the one we have not sent yet shoudl still go out later + assert info2.dns_pointer() in outgoing_queue.queue[0].answers diff --git a/tests/test_history.py b/tests/test_history.py new file mode 100644 index 00000000..9da6b567 --- /dev/null +++ b/tests/test_history.py @@ -0,0 +1,68 @@ +#!/usr/bin/env python + + +"""Unit tests for _history.py.""" + +from zeroconf._history import QuestionHistory +import zeroconf as r +import zeroconf.const as const + + +def test_question_suppression(): + history = QuestionHistory() + + question = r.DNSQuestion("_hap._tcp._local.", const._TYPE_PTR, const._CLASS_IN) + now = r.current_time_millis() + other_known_answers = { + r.DNSPointer( + "_hap._tcp.local.", const._TYPE_PTR, const._CLASS_IN, 10000, 'known-to-other._hap._tcp.local.' + ) + } + our_known_answers = { + r.DNSPointer( + "_hap._tcp.local.", const._TYPE_PTR, const._CLASS_IN, 10000, 'known-to-us._hap._tcp.local.' + ) + } + + history.add_question_at_time(question, now, other_known_answers) + + # Verify the question is suppressed if the known answers are the same + assert history.suppresses(question, now, other_known_answers) + + # Verify the question is suppressed if we know the answer to all the known answers + assert history.suppresses(question, now, other_known_answers | our_known_answers) + + # Verify the question is not suppressed if our known answers do no include the ones in the last question + assert not history.suppresses(question, now, set()) + + # Verify the question is not suppressed if our known answers do no include the ones in the last question + assert not history.suppresses(question, now, our_known_answers) + + # Verify the question is no longer suppressed after 1s + assert not history.suppresses(question, now + 1000, other_known_answers) + + +def test_question_expire(): + history = QuestionHistory() + + question = r.DNSQuestion("_hap._tcp._local.", const._TYPE_PTR, const._CLASS_IN) + now = r.current_time_millis() + other_known_answers = { + r.DNSPointer( + "_hap._tcp.local.", const._TYPE_PTR, const._CLASS_IN, 10000, 'known-to-other._hap._tcp.local.' + ) + } + history.add_question_at_time(question, now, other_known_answers) + + # Verify the question is suppressed if the known answers are the same + assert history.suppresses(question, now, other_known_answers) + + history.async_expire(now) + + # Verify the question is suppressed if the known answers are the same since the cache hasn't expired + assert history.suppresses(question, now, other_known_answers) + + history.async_expire(now + 1000) + + # Verify the question not longer suppressed since the cache has expired + assert not history.suppresses(question, now, other_known_answers) diff --git a/tests/test_init.py b/tests/test_init.py new file mode 100644 index 00000000..1d1f7086 --- /dev/null +++ b/tests/test_init.py @@ -0,0 +1,188 @@ +#!/usr/bin/env python + + +""" Unit tests for zeroconf.py """ + +import logging +import socket +import time +import unittest +import unittest.mock +from unittest.mock import patch + +import zeroconf as r +from zeroconf import ServiceInfo, Zeroconf, const + +from . import _inject_responses + +log = logging.getLogger('zeroconf') +original_logging_level = logging.NOTSET + + +def setup_module(): + global original_logging_level + original_logging_level = log.level + log.setLevel(logging.DEBUG) + + +def teardown_module(): + if original_logging_level != logging.NOTSET: + log.setLevel(original_logging_level) + + +class Names(unittest.TestCase): + def test_long_name(self): + generated = r.DNSOutgoing(const._FLAGS_QR_RESPONSE) + question = r.DNSQuestion( + "this.is.a.very.long.name.with.lots.of.parts.in.it.local.", const._TYPE_SRV, const._CLASS_IN + ) + generated.add_question(question) + r.DNSIncoming(generated.packets()[0]) + + def test_exceedingly_long_name(self): + generated = r.DNSOutgoing(const._FLAGS_QR_RESPONSE) + name = "%slocal." % ("part." * 1000) + question = r.DNSQuestion(name, const._TYPE_SRV, const._CLASS_IN) + generated.add_question(question) + r.DNSIncoming(generated.packets()[0]) + + def test_extra_exceedingly_long_name(self): + generated = r.DNSOutgoing(const._FLAGS_QR_RESPONSE) + name = "%slocal." % ("part." * 4000) + question = r.DNSQuestion(name, const._TYPE_SRV, const._CLASS_IN) + generated.add_question(question) + r.DNSIncoming(generated.packets()[0]) + + def test_exceedingly_long_name_part(self): + name = "%s.local." % ("a" * 1000) + generated = r.DNSOutgoing(const._FLAGS_QR_RESPONSE) + question = r.DNSQuestion(name, const._TYPE_SRV, const._CLASS_IN) + generated.add_question(question) + self.assertRaises(r.NamePartTooLongException, generated.packets) + + def test_same_name(self): + name = "paired.local." + generated = r.DNSOutgoing(const._FLAGS_QR_RESPONSE) + question = r.DNSQuestion(name, const._TYPE_SRV, const._CLASS_IN) + generated.add_question(question) + generated.add_question(question) + r.DNSIncoming(generated.packets()[0]) + + def test_verify_name_change_with_lots_of_names(self): + # instantiate a zeroconf instance + zc = Zeroconf(interfaces=['127.0.0.1']) + + # create a bunch of servers + type_ = "_my-service._tcp.local." + name = 'a wonderful service' + server_count = 300 + self.generate_many_hosts(zc, type_, name, server_count) + + # verify that name changing works + self.verify_name_change(zc, type_, name, server_count) + + zc.close() + + def test_large_packet_exception_log_handling(self): + """Verify we downgrade debug after warning.""" + + # instantiate a zeroconf instance + zc = Zeroconf(interfaces=['127.0.0.1']) + + with patch('zeroconf._logger.log.warning') as mocked_log_warn, patch( + 'zeroconf._logger.log.debug' + ) as mocked_log_debug: + # now that we have a long packet in our possession, let's verify the + # exception handling. + out = r.DNSOutgoing(const._FLAGS_QR_RESPONSE | const._FLAGS_AA) + out.data.append(b'\0' * 10000) + + # mock the zeroconf logger and check for the correct logging backoff + call_counts = mocked_log_warn.call_count, mocked_log_debug.call_count + # try to send an oversized packet + zc.send(out) + assert mocked_log_warn.call_count == call_counts[0] + zc.send(out) + assert mocked_log_warn.call_count == call_counts[0] + + # mock the zeroconf logger and check for the correct logging backoff + call_counts = mocked_log_warn.call_count, mocked_log_debug.call_count + # force receive on oversized packet + zc.send(out, const._MDNS_ADDR, const._MDNS_PORT) + zc.send(out, const._MDNS_ADDR, const._MDNS_PORT) + time.sleep(0.3) + r.log.debug( + 'warn %d debug %d was %s', + mocked_log_warn.call_count, + mocked_log_debug.call_count, + call_counts, + ) + assert mocked_log_debug.call_count > call_counts[0] + + # close our zeroconf which will close the sockets + zc.close() + + def verify_name_change(self, zc, type_, name, number_hosts): + desc = {'path': '/~paulsm/'} + info_service = ServiceInfo( + type_, + f'{name}.{type_}', + 80, + 0, + 0, + desc, + "ash-2.local.", + addresses=[socket.inet_aton("10.0.1.2")], + ) + + # verify name conflict + self.assertRaises(r.NonUniqueNameException, zc.register_service, info_service) + + # verify no name conflict https://tools.ietf.org/html/rfc6762#section-6.6 + zc.register_service(info_service, cooperating_responders=True) + + # Create a new object since allow_name_change will mutate the + # original object and then we will have the wrong service + # in the registry + info_service2 = ServiceInfo( + type_, + f'{name}.{type_}', + 80, + 0, + 0, + desc, + "ash-2.local.", + addresses=[socket.inet_aton("10.0.1.2")], + ) + zc.register_service(info_service2, allow_name_change=True) + assert info_service2.name.split('.')[0] == '%s-%d' % (name, number_hosts + 1) + + def generate_many_hosts(self, zc, type_, name, number_hosts): + block_size = 25 + number_hosts = int((number_hosts - 1) / block_size + 1) * block_size + out = r.DNSOutgoing(const._FLAGS_QR_RESPONSE | const._FLAGS_AA) + for i in range(1, number_hosts + 1): + next_name = name if i == 1 else '%s-%d' % (name, i) + self.generate_host(out, next_name, type_) + + _inject_responses(zc, [r.DNSIncoming(packet) for packet in out.packets()]) + + @staticmethod + def generate_host(out, host_name, type_): + name = '.'.join((host_name, type_)) + out.add_answer_at_time( + r.DNSPointer(type_, const._TYPE_PTR, const._CLASS_IN, const._DNS_OTHER_TTL, name), 0 + ) + out.add_answer_at_time( + r.DNSService( + type_, + const._TYPE_SRV, + const._CLASS_IN | const._CLASS_UNIQUE, + const._DNS_HOST_TTL, + 0, + 0, + 80, + name, + ), + 0, + ) diff --git a/tests/test_logger.py b/tests/test_logger.py new file mode 100644 index 00000000..2d8bbb08 --- /dev/null +++ b/tests/test_logger.py @@ -0,0 +1,84 @@ +#!/usr/bin/env python + + +"""Unit tests for logger.py.""" + +import logging +from unittest.mock import patch +from zeroconf._logger import QuietLogger, set_logger_level_if_unset + + +def test_loading_logger(): + """Test loading logger does not change level unless it is unset.""" + log = logging.getLogger('zeroconf') + log.setLevel(logging.CRITICAL) + set_logger_level_if_unset() + log = logging.getLogger('zeroconf') + assert log.level == logging.CRITICAL + + log = logging.getLogger('zeroconf') + log.setLevel(logging.NOTSET) + set_logger_level_if_unset() + log = logging.getLogger('zeroconf') + assert log.level == logging.WARNING + + +def test_log_warning_once(): + """Test we only log with warning level once.""" + quiet_logger = QuietLogger() + with patch("zeroconf._logger.log.warning") as mock_log_warning, patch( + "zeroconf._logger.log.debug" + ) as mock_log_debug: + quiet_logger.log_warning_once("the warning") + + assert mock_log_warning.mock_calls + assert not mock_log_debug.mock_calls + + with patch("zeroconf._logger.log.warning") as mock_log_warning, patch( + "zeroconf._logger.log.debug" + ) as mock_log_debug: + quiet_logger.log_warning_once("the warning") + + assert not mock_log_warning.mock_calls + assert mock_log_debug.mock_calls + + +def test_log_exception_warning(): + """Test we only log with warning level once.""" + quiet_logger = QuietLogger() + with patch("zeroconf._logger.log.warning") as mock_log_warning, patch( + "zeroconf._logger.log.debug" + ) as mock_log_debug: + quiet_logger.log_exception_warning("the exception warning") + + assert mock_log_warning.mock_calls + assert not mock_log_debug.mock_calls + + with patch("zeroconf._logger.log.warning") as mock_log_warning, patch( + "zeroconf._logger.log.debug" + ) as mock_log_debug: + quiet_logger.log_exception_warning("the exception warning") + + assert not mock_log_warning.mock_calls + assert mock_log_debug.mock_calls + + +def test_log_exception_once(): + """Test we only log with warning level once.""" + quiet_logger = QuietLogger() + exc = Exception() + with patch("zeroconf._logger.log.warning") as mock_log_warning, patch( + "zeroconf._logger.log.debug" + ) as mock_log_debug: + quiet_logger.log_exception_once(exc, "the exceptional exception warning") + + assert mock_log_warning.mock_calls + assert not mock_log_debug.mock_calls + + with patch("zeroconf._logger.log.warning") as mock_log_warning, patch( + "zeroconf._logger.log.debug" + ) as mock_log_debug: + quiet_logger.log_exception_once(exc, "the exceptional exception warning") + + assert not mock_log_warning.mock_calls + assert mock_log_debug.mock_calls diff --git a/tests/test_protocol.py b/tests/test_protocol.py new file mode 100644 index 00000000..55dbbe4d --- /dev/null +++ b/tests/test_protocol.py @@ -0,0 +1,1026 @@ +#!/usr/bin/env python + + +""" Unit tests for zeroconf._protocol """ + +import copy +import logging +import os +import socket +import struct +import unittest +import unittest.mock +from typing import cast + +import zeroconf as r +from zeroconf import DNSIncoming, const, current_time_millis +from zeroconf import ( + DNSHinfo, + DNSText, +) + +from . import has_working_ipv6 + +log = logging.getLogger('zeroconf') +original_logging_level = logging.NOTSET + + +def setup_module(): + global original_logging_level + original_logging_level = log.level + log.setLevel(logging.DEBUG) + + +def teardown_module(): + if original_logging_level != logging.NOTSET: + log.setLevel(original_logging_level) + + +class PacketGeneration(unittest.TestCase): + def test_parse_own_packet_simple(self): + generated = r.DNSOutgoing(0) + r.DNSIncoming(generated.packets()[0]) + + def test_parse_own_packet_simple_unicast(self): + generated = r.DNSOutgoing(0, False) + r.DNSIncoming(generated.packets()[0]) + + def test_parse_own_packet_flags(self): + generated = r.DNSOutgoing(const._FLAGS_QR_QUERY) + r.DNSIncoming(generated.packets()[0]) + + def test_parse_own_packet_question(self): + generated = r.DNSOutgoing(const._FLAGS_QR_QUERY) + generated.add_question(r.DNSQuestion("testname.local.", const._TYPE_SRV, const._CLASS_IN)) + r.DNSIncoming(generated.packets()[0]) + + def test_parse_own_packet_nsec(self): + answer = r.DNSNsec( + 'eufy HomeBase2-2464._hap._tcp.local.', + const._TYPE_NSEC, + const._CLASS_IN | const._CLASS_UNIQUE, + const._DNS_OTHER_TTL, + 'eufy HomeBase2-2464._hap._tcp.local.', + [const._TYPE_TXT, const._TYPE_SRV], + ) + + generated = r.DNSOutgoing(const._FLAGS_QR_RESPONSE) + generated.add_answer_at_time(answer, 0) + parsed = r.DNSIncoming(generated.packets()[0]) + assert answer in parsed.answers + + # Types > 255 should be ignored + answer_invalid_types = r.DNSNsec( + 'eufy HomeBase2-2464._hap._tcp.local.', + const._TYPE_NSEC, + const._CLASS_IN | const._CLASS_UNIQUE, + const._DNS_OTHER_TTL, + 'eufy HomeBase2-2464._hap._tcp.local.', + [const._TYPE_TXT, const._TYPE_SRV, 1000], + ) + generated = r.DNSOutgoing(const._FLAGS_QR_RESPONSE) + generated.add_answer_at_time(answer_invalid_types, 0) + parsed = r.DNSIncoming(generated.packets()[0]) + assert answer in parsed.answers + + def test_parse_own_packet_response(self): + generated = r.DNSOutgoing(const._FLAGS_QR_RESPONSE) + generated.add_answer_at_time( + r.DNSService( + "æøå.local.", + const._TYPE_SRV, + const._CLASS_IN | const._CLASS_UNIQUE, + const._DNS_HOST_TTL, + 0, + 0, + 80, + "foo.local.", + ), + 0, + ) + parsed = r.DNSIncoming(generated.packets()[0]) + assert len(generated.answers) == 1 + assert len(generated.answers) == len(parsed.answers) + + def test_adding_empty_answer(self): + generated = r.DNSOutgoing(const._FLAGS_QR_RESPONSE) + generated.add_answer_at_time( + None, + 0, + ) + generated.add_answer_at_time( + r.DNSService( + "æøå.local.", + const._TYPE_SRV, + const._CLASS_IN | const._CLASS_UNIQUE, + const._DNS_HOST_TTL, + 0, + 0, + 80, + "foo.local.", + ), + 0, + ) + parsed = r.DNSIncoming(generated.packets()[0]) + assert len(generated.answers) == 1 + assert len(generated.answers) == len(parsed.answers) + + def test_adding_expired_answer(self): + generated = r.DNSOutgoing(const._FLAGS_QR_RESPONSE) + generated.add_answer_at_time( + r.DNSService( + "æøå.local.", + const._TYPE_SRV, + const._CLASS_IN | const._CLASS_UNIQUE, + const._DNS_HOST_TTL, + 0, + 0, + 80, + "foo.local.", + ), + current_time_millis() + 1000000, + ) + parsed = r.DNSIncoming(generated.packets()[0]) + assert len(generated.answers) == 0 + assert len(generated.answers) == len(parsed.answers) + + def test_match_question(self): + generated = r.DNSOutgoing(const._FLAGS_QR_QUERY) + question = r.DNSQuestion("testname.local.", const._TYPE_SRV, const._CLASS_IN) + generated.add_question(question) + parsed = r.DNSIncoming(generated.packets()[0]) + assert len(generated.questions) == 1 + assert len(generated.questions) == len(parsed.questions) + assert question == parsed.questions[0] + + def test_suppress_answer(self): + query_generated = r.DNSOutgoing(const._FLAGS_QR_QUERY) + question = r.DNSQuestion("testname.local.", const._TYPE_SRV, const._CLASS_IN) + query_generated.add_question(question) + answer1 = r.DNSService( + "testname1.local.", + const._TYPE_SRV, + const._CLASS_IN | const._CLASS_UNIQUE, + const._DNS_HOST_TTL, + 0, + 0, + 80, + "foo.local.", + ) + staleanswer2 = r.DNSService( + "testname2.local.", + const._TYPE_SRV, + const._CLASS_IN | const._CLASS_UNIQUE, + const._DNS_HOST_TTL / 2, + 0, + 0, + 80, + "foo.local.", + ) + answer2 = r.DNSService( + "testname2.local.", + const._TYPE_SRV, + const._CLASS_IN | const._CLASS_UNIQUE, + const._DNS_HOST_TTL, + 0, + 0, + 80, + "foo.local.", + ) + query_generated.add_answer_at_time(answer1, 0) + query_generated.add_answer_at_time(staleanswer2, 0) + query = r.DNSIncoming(query_generated.packets()[0]) + + # Should be suppressed + response = r.DNSOutgoing(const._FLAGS_QR_RESPONSE) + response.add_answer(query, answer1) + assert len(response.answers) == 0 + + # Should not be suppressed, TTL in query is too short + response.add_answer(query, answer2) + assert len(response.answers) == 1 + + # Should not be suppressed, name is different + tmp = copy.copy(answer1) + tmp.key = "testname3.local." + tmp.name = "testname3.local." + response.add_answer(query, tmp) + assert len(response.answers) == 2 + + # Should not be suppressed, type is different + tmp = copy.copy(answer1) + tmp.type = const._TYPE_A + response.add_answer(query, tmp) + assert len(response.answers) == 3 + + # Should not be suppressed, class is different + tmp = copy.copy(answer1) + tmp.class_ = const._CLASS_NONE + response.add_answer(query, tmp) + assert len(response.answers) == 4 + + # ::TODO:: could add additional tests for DNSAddress, DNSHinfo, DNSPointer, DNSText, DNSService + + def test_dns_hinfo(self): + generated = r.DNSOutgoing(0) + generated.add_additional_answer(DNSHinfo('irrelevant', const._TYPE_HINFO, 0, 0, 'cpu', 'os')) + parsed = r.DNSIncoming(generated.packets()[0]) + answer = cast(r.DNSHinfo, parsed.answers[0]) + assert answer.cpu == 'cpu' + assert answer.os == 'os' + + generated = r.DNSOutgoing(0) + generated.add_additional_answer(DNSHinfo('irrelevant', const._TYPE_HINFO, 0, 0, 'cpu', 'x' * 257)) + self.assertRaises(r.NamePartTooLongException, generated.packets) + + def test_many_questions(self): + """Test many questions get seperated into multiple packets.""" + generated = r.DNSOutgoing(const._FLAGS_QR_QUERY) + questions = [] + for i in range(100): + question = r.DNSQuestion(f"testname{i}.local.", const._TYPE_SRV, const._CLASS_IN) + generated.add_question(question) + questions.append(question) + assert len(generated.questions) == 100 + + packets = generated.packets() + assert len(packets) == 2 + assert len(packets[0]) < const._MAX_MSG_TYPICAL + assert len(packets[1]) < const._MAX_MSG_TYPICAL + + parsed1 = r.DNSIncoming(packets[0]) + assert len(parsed1.questions) == 85 + parsed2 = r.DNSIncoming(packets[1]) + assert len(parsed2.questions) == 15 + + def test_many_questions_with_many_known_answers(self): + """Test many questions and known answers get seperated into multiple packets.""" + generated = r.DNSOutgoing(const._FLAGS_QR_QUERY) + questions = [] + for _ in range(30): + question = r.DNSQuestion(f"_hap._tcp.local.", const._TYPE_PTR, const._CLASS_IN) + generated.add_question(question) + questions.append(question) + assert len(generated.questions) == 30 + now = current_time_millis() + for _ in range(200): + known_answer = r.DNSPointer( + "myservice{i}_tcp._tcp.local.", + const._TYPE_PTR, + const._CLASS_IN | const._CLASS_UNIQUE, + const._DNS_OTHER_TTL, + '123.local.', + ) + generated.add_answer_at_time(known_answer, now) + packets = generated.packets() + assert len(packets) == 3 + assert len(packets[0]) <= const._MAX_MSG_TYPICAL + assert len(packets[1]) <= const._MAX_MSG_TYPICAL + assert len(packets[2]) <= const._MAX_MSG_TYPICAL + + parsed1 = r.DNSIncoming(packets[0]) + assert len(parsed1.questions) == 30 + assert len(parsed1.answers) == 88 + assert parsed1.truncated + parsed2 = r.DNSIncoming(packets[1]) + assert len(parsed2.questions) == 0 + assert len(parsed2.answers) == 101 + assert parsed2.truncated + parsed3 = r.DNSIncoming(packets[2]) + assert len(parsed3.questions) == 0 + assert len(parsed3.answers) == 11 + assert not parsed3.truncated + + def test_massive_probe_packet_split(self): + """Test probe with many authorative answers.""" + generated = r.DNSOutgoing(const._FLAGS_QR_QUERY | const._FLAGS_AA) + questions = [] + for _ in range(30): + question = r.DNSQuestion( + f"_hap._tcp.local.", const._TYPE_PTR, const._CLASS_IN | const._CLASS_UNIQUE + ) + generated.add_question(question) + questions.append(question) + assert len(generated.questions) == 30 + now = current_time_millis() + for _ in range(200): + authorative_answer = r.DNSPointer( + "myservice{i}_tcp._tcp.local.", + const._TYPE_PTR, + const._CLASS_IN | const._CLASS_UNIQUE, + const._DNS_OTHER_TTL, + '123.local.', + ) + generated.add_authorative_answer(authorative_answer) + packets = generated.packets() + assert len(packets) == 3 + assert len(packets[0]) <= const._MAX_MSG_TYPICAL + assert len(packets[1]) <= const._MAX_MSG_TYPICAL + assert len(packets[2]) <= const._MAX_MSG_TYPICAL + + parsed1 = r.DNSIncoming(packets[0]) + assert parsed1.questions[0].unicast is True + assert len(parsed1.questions) == 30 + assert parsed1.num_authorities == 88 + assert parsed1.truncated + parsed2 = r.DNSIncoming(packets[1]) + assert len(parsed2.questions) == 0 + assert parsed2.num_authorities == 101 + assert parsed2.truncated + parsed3 = r.DNSIncoming(packets[2]) + assert len(parsed3.questions) == 0 + assert parsed3.num_authorities == 11 + assert not parsed3.truncated + + def test_only_one_answer_can_by_large(self): + """Test that only the first answer in each packet can be large. + + https://datatracker.ietf.org/doc/html/rfc6762#section-17 + """ + generated = r.DNSOutgoing(const._FLAGS_QR_RESPONSE) + query = r.DNSIncoming(r.DNSOutgoing(const._FLAGS_QR_QUERY).packets()[0]) + for i in range(3): + generated.add_answer( + query, + r.DNSText( + "zoom._hap._tcp.local.", + const._TYPE_TXT, + const._CLASS_IN | const._CLASS_UNIQUE, + 1200, + b'\x04ff=0\x04ci=2\x04sf=0\x0bsh=6fLM5A==' * 100, + ), + ) + generated.add_answer( + query, + r.DNSService( + "testname1.local.", + const._TYPE_SRV, + const._CLASS_IN | const._CLASS_UNIQUE, + const._DNS_HOST_TTL, + 0, + 0, + 80, + "foo.local.", + ), + ) + assert len(generated.answers) == 4 + + packets = generated.packets() + assert len(packets) == 4 + assert len(packets[0]) <= const._MAX_MSG_ABSOLUTE + assert len(packets[0]) > const._MAX_MSG_TYPICAL + + assert len(packets[1]) <= const._MAX_MSG_ABSOLUTE + assert len(packets[1]) > const._MAX_MSG_TYPICAL + + assert len(packets[2]) <= const._MAX_MSG_ABSOLUTE + assert len(packets[2]) > const._MAX_MSG_TYPICAL + + assert len(packets[3]) <= const._MAX_MSG_TYPICAL + + for packet in packets: + parsed = r.DNSIncoming(packet) + assert len(parsed.answers) == 1 + + def test_questions_do_not_end_up_every_packet(self): + """Test that questions are not sent again when multiple packets are needed. + + https://datatracker.ietf.org/doc/html/rfc6762#section-7.2 + Sometimes a Multicast DNS querier will already have too many answers + to fit in the Known-Answer Section of its query packets.... It MUST + immediately follow the packet with another query packet containing no + questions and as many more Known-Answer records as will fit. + """ + + generated = r.DNSOutgoing(const._FLAGS_QR_QUERY) + for i in range(35): + question = r.DNSQuestion(f"testname{i}.local.", const._TYPE_SRV, const._CLASS_IN) + generated.add_question(question) + answer = r.DNSService( + f"testname{i}.local.", + const._TYPE_SRV, + const._CLASS_IN | const._CLASS_UNIQUE, + const._DNS_HOST_TTL, + 0, + 0, + 80, + f"foo{i}.local.", + ) + generated.add_answer_at_time(answer, 0) + + assert len(generated.questions) == 35 + assert len(generated.answers) == 35 + + packets = generated.packets() + assert len(packets) == 2 + assert len(packets[0]) <= const._MAX_MSG_TYPICAL + assert len(packets[1]) <= const._MAX_MSG_TYPICAL + + parsed1 = r.DNSIncoming(packets[0]) + assert len(parsed1.questions) == 35 + assert len(parsed1.answers) == 33 + + parsed2 = r.DNSIncoming(packets[1]) + assert len(parsed2.questions) == 0 + assert len(parsed2.answers) == 2 + + +class PacketForm(unittest.TestCase): + def test_transaction_id(self): + """ID must be zero in a DNS-SD packet""" + generated = r.DNSOutgoing(const._FLAGS_QR_QUERY) + bytes = generated.packets()[0] + id = bytes[0] << 8 | bytes[1] + assert id == 0 + + def test_setting_id(self): + """Test setting id in the constructor""" + generated = r.DNSOutgoing(const._FLAGS_QR_QUERY, id_=4444) + assert generated.id == 4444 + + def test_query_header_bits(self): + generated = r.DNSOutgoing(const._FLAGS_QR_QUERY) + bytes = generated.packets()[0] + flags = bytes[2] << 8 | bytes[3] + assert flags == 0x0 + + def test_response_header_bits(self): + generated = r.DNSOutgoing(const._FLAGS_QR_RESPONSE) + bytes = generated.packets()[0] + flags = bytes[2] << 8 | bytes[3] + assert flags == 0x8000 + + def test_numbers(self): + generated = r.DNSOutgoing(const._FLAGS_QR_RESPONSE) + bytes = generated.packets()[0] + (num_questions, num_answers, num_authorities, num_additionals) = struct.unpack('!4H', bytes[4:12]) + assert num_questions == 0 + assert num_answers == 0 + assert num_authorities == 0 + assert num_additionals == 0 + + def test_numbers_questions(self): + generated = r.DNSOutgoing(const._FLAGS_QR_RESPONSE) + question = r.DNSQuestion("testname.local.", const._TYPE_SRV, const._CLASS_IN) + for i in range(10): + generated.add_question(question) + bytes = generated.packets()[0] + (num_questions, num_answers, num_authorities, num_additionals) = struct.unpack('!4H', bytes[4:12]) + assert num_questions == 10 + assert num_answers == 0 + assert num_authorities == 0 + assert num_additionals == 0 + + +class TestDnsIncoming(unittest.TestCase): + def test_incoming_exception_handling(self): + generated = r.DNSOutgoing(0) + packet = generated.packets()[0] + packet = packet[:8] + b'deadbeef' + packet[8:] + parsed = r.DNSIncoming(packet) + parsed = r.DNSIncoming(packet) + assert parsed.valid is False + + def test_incoming_unknown_type(self): + generated = r.DNSOutgoing(0) + answer = r.DNSAddress('a', const._TYPE_SOA, const._CLASS_IN, 1, b'a') + generated.add_additional_answer(answer) + packet = generated.packets()[0] + parsed = r.DNSIncoming(packet) + assert len(parsed.answers) == 0 + assert parsed.is_query() != parsed.is_response() + + def test_incoming_circular_reference(self): + assert not r.DNSIncoming( + bytes.fromhex( + '01005e0000fb542a1bf0577608004500006897934000ff11d81bc0a86a31e00000fb' + '14e914e90054f9b2000084000000000100000000095f7365727669636573075f646e' + '732d7364045f756470056c6f63616c00000c0001000011940018105f73706f746966' + '792d636f6e6e656374045f746370c023' + ) + ).valid + + @unittest.skipIf(not has_working_ipv6(), 'Requires IPv6') + @unittest.skipIf(os.environ.get('SKIP_IPV6'), 'IPv6 tests disabled') + def test_incoming_ipv6(self): + addr = "2606:2800:220:1:248:1893:25c8:1946" # example.com + packed = socket.inet_pton(socket.AF_INET6, addr) + generated = r.DNSOutgoing(0) + answer = r.DNSAddress('domain', const._TYPE_AAAA, const._CLASS_IN | const._CLASS_UNIQUE, 1, packed) + generated.add_additional_answer(answer) + packet = generated.packets()[0] + parsed = r.DNSIncoming(packet) + record = parsed.answers[0] + assert isinstance(record, r.DNSAddress) + assert record.address == packed + + +def test_dns_compression_rollback_for_corruption(): + """Verify rolling back does not lead to dns compression corruption.""" + out = r.DNSOutgoing(const._FLAGS_QR_RESPONSE | const._FLAGS_AA) + address = socket.inet_pton(socket.AF_INET, "192.168.208.5") + + additionals = [ + { + "name": "HASS Bridge ZJWH FF5137._hap._tcp.local.", + "address": address, + "port": 51832, + "text": b"\x13md=HASS Bridge" + b" ZJWH\x06pv=1.0\x14id=01:6B:30:FF:51:37\x05c#=12\x04s#=1\x04ff=0\x04" + b"ci=2\x04sf=0\x0bsh=L0m/aQ==", + }, + { + "name": "HASS Bridge 3K9A C2582A._hap._tcp.local.", + "address": address, + "port": 51834, + "text": b"\x13md=HASS Bridge" + b" 3K9A\x06pv=1.0\x14id=E2:AA:5B:C2:58:2A\x05c#=12\x04s#=1\x04ff=0\x04" + b"ci=2\x04sf=0\x0bsh=b2CnzQ==", + }, + { + "name": "Master Bed TV CEDB27._hap._tcp.local.", + "address": address, + "port": 51830, + "text": b"\x10md=Master Bed" + b" TV\x06pv=1.0\x14id=9E:B7:44:CE:DB:27\x05c#=18\x04s#=1\x04ff=0\x05" + b"ci=31\x04sf=0\x0bsh=CVj1kw==", + }, + { + "name": "Living Room TV 921B77._hap._tcp.local.", + "address": address, + "port": 51833, + "text": b"\x11md=Living Room" + b" TV\x06pv=1.0\x14id=11:61:E7:92:1B:77\x05c#=17\x04s#=1\x04ff=0\x05" + b"ci=31\x04sf=0\x0bsh=qU77SQ==", + }, + { + "name": "HASS Bridge ZC8X FF413D._hap._tcp.local.", + "address": address, + "port": 51829, + "text": b"\x13md=HASS Bridge" + b" ZC8X\x06pv=1.0\x14id=96:14:45:FF:41:3D\x05c#=12\x04s#=1\x04ff=0\x04" + b"ci=2\x04sf=0\x0bsh=b0QZlg==", + }, + { + "name": "HASS Bridge WLTF 4BE61F._hap._tcp.local.", + "address": address, + "port": 51837, + "text": b"\x13md=HASS Bridge" + b" WLTF\x06pv=1.0\x14id=E0:E7:98:4B:E6:1F\x04c#=2\x04s#=1\x04ff=0\x04" + b"ci=2\x04sf=0\x0bsh=ahAISA==", + }, + { + "name": "FrontdoorCamera 8941D1._hap._tcp.local.", + "address": address, + "port": 54898, + "text": b"\x12md=FrontdoorCamera\x06pv=1.0\x14id=9F:B7:DC:89:41:D1\x04c#=2\x04" + b"s#=1\x04ff=0\x04ci=2\x04sf=0\x0bsh=0+MXmA==", + }, + { + "name": "HASS Bridge W9DN 5B5CC5._hap._tcp.local.", + "address": address, + "port": 51836, + "text": b"\x13md=HASS Bridge" + b" W9DN\x06pv=1.0\x14id=11:8E:DB:5B:5C:C5\x05c#=12\x04s#=1\x04ff=0\x04" + b"ci=2\x04sf=0\x0bsh=6fLM5A==", + }, + { + "name": "HASS Bridge Y9OO EFF0A7._hap._tcp.local.", + "address": address, + "port": 51838, + "text": b"\x13md=HASS Bridge" + b" Y9OO\x06pv=1.0\x14id=D3:FE:98:EF:F0:A7\x04c#=2\x04s#=1\x04ff=0\x04" + b"ci=2\x04sf=0\x0bsh=u3bdfw==", + }, + { + "name": "Snooze Room TV 6B89B0._hap._tcp.local.", + "address": address, + "port": 51835, + "text": b"\x11md=Snooze Room" + b" TV\x06pv=1.0\x14id=5F:D5:70:6B:89:B0\x05c#=17\x04s#=1\x04ff=0\x05" + b"ci=31\x04sf=0\x0bsh=xNTqsg==", + }, + { + "name": "AlexanderHomeAssistant 74651D._hap._tcp.local.", + "address": address, + "port": 54811, + "text": b"\x19md=AlexanderHomeAssistant\x06pv=1.0\x14id=59:8A:0B:74:65:1D\x05" + b"c#=14\x04s#=1\x04ff=0\x04ci=2\x04sf=0\x0bsh=ccZLPA==", + }, + { + "name": "HASS Bridge OS95 39C053._hap._tcp.local.", + "address": address, + "port": 51831, + "text": b"\x13md=HASS Bridge" + b" OS95\x06pv=1.0\x14id=7E:8C:E6:39:C0:53\x05c#=12\x04s#=1\x04ff=0\x04ci=2" + b"\x04sf=0\x0bsh=Xfe5LQ==", + }, + ] + + out.add_answer_at_time( + DNSText( + "HASS Bridge W9DN 5B5CC5._hap._tcp.local.", + const._TYPE_TXT, + const._CLASS_IN | const._CLASS_UNIQUE, + const._DNS_OTHER_TTL, + b'\x13md=HASS Bridge W9DN\x06pv=1.0\x14id=11:8E:DB:5B:5C:C5\x05c#=12\x04s#=1' + b'\x04ff=0\x04ci=2\x04sf=0\x0bsh=6fLM5A==', + ), + 0, + ) + + for record in additionals: + out.add_additional_answer( + r.DNSService( + record["name"], # type: ignore + const._TYPE_SRV, + const._CLASS_IN | const._CLASS_UNIQUE, + const._DNS_HOST_TTL, + 0, + 0, + record["port"], # type: ignore + record["name"], # type: ignore + ) + ) + out.add_additional_answer( + r.DNSText( + record["name"], # type: ignore + const._TYPE_TXT, + const._CLASS_IN | const._CLASS_UNIQUE, + const._DNS_OTHER_TTL, + record["text"], # type: ignore + ) + ) + out.add_additional_answer( + r.DNSAddress( + record["name"], # type: ignore + const._TYPE_A, + const._CLASS_IN | const._CLASS_UNIQUE, + const._DNS_HOST_TTL, + record["address"], # type: ignore + ) + ) + + for packet in out.packets(): + # Verify we can process the packets we created to + # ensure there is no corruption with the dns compression + incoming = r.DNSIncoming(packet) + assert incoming.valid is True + assert ( + len(incoming.answers) + == incoming.num_answers + incoming.num_authorities + incoming.num_additionals + ) + + +def test_tc_bit_in_query_packet(): + """Verify the TC bit is set when known answers exceed the packet size.""" + out = r.DNSOutgoing(const._FLAGS_QR_QUERY | const._FLAGS_AA) + type_ = "_hap._tcp.local." + out.add_question(r.DNSQuestion(type_, const._TYPE_PTR, const._CLASS_IN)) + + for i in range(30): + out.add_answer_at_time( + DNSText( + ("HASS Bridge W9DN %s._hap._tcp.local." % i), + const._TYPE_TXT, + const._CLASS_IN | const._CLASS_UNIQUE, + const._DNS_OTHER_TTL, + b'\x13md=HASS Bridge W9DN\x06pv=1.0\x14id=11:8E:DB:5B:5C:C5\x05c#=12\x04s#=1' + b'\x04ff=0\x04ci=2\x04sf=0\x0bsh=6fLM5A==', + ), + 0, + ) + + packets = out.packets() + assert len(packets) == 3 + + first_packet = r.DNSIncoming(packets[0]) + assert first_packet.truncated + assert first_packet.valid is True + + second_packet = r.DNSIncoming(packets[1]) + assert second_packet.truncated + assert second_packet.valid is True + + third_packet = r.DNSIncoming(packets[2]) + assert not third_packet.truncated + assert third_packet.valid is True + + +def test_tc_bit_not_set_in_answer_packet(): + """Verify the TC bit is not set when there are no questions and answers exceed the packet size.""" + out = r.DNSOutgoing(const._FLAGS_QR_RESPONSE | const._FLAGS_AA) + for i in range(30): + out.add_answer_at_time( + DNSText( + ("HASS Bridge W9DN %s._hap._tcp.local." % i), + const._TYPE_TXT, + const._CLASS_IN | const._CLASS_UNIQUE, + const._DNS_OTHER_TTL, + b'\x13md=HASS Bridge W9DN\x06pv=1.0\x14id=11:8E:DB:5B:5C:C5\x05c#=12\x04s#=1' + b'\x04ff=0\x04ci=2\x04sf=0\x0bsh=6fLM5A==', + ), + 0, + ) + + packets = out.packets() + assert len(packets) == 3 + + first_packet = r.DNSIncoming(packets[0]) + assert not first_packet.truncated + assert first_packet.valid is True + + second_packet = r.DNSIncoming(packets[1]) + assert not second_packet.truncated + assert second_packet.valid is True + + third_packet = r.DNSIncoming(packets[2]) + assert not third_packet.truncated + assert third_packet.valid is True + + +# 4003 15.973052 192.168.107.68 224.0.0.251 MDNS 76 Standard query 0xffc4 PTR _raop._tcp.local, "QM" question +def test_qm_packet_parser(): + """Test we can parse a query packet with the QM bit.""" + qm_packet = ( + b'\xff\xc4\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x05_raop\x04_tcp\x05local\x00\x00\x0c\x00\x01' + ) + parsed = DNSIncoming(qm_packet) + assert parsed.questions[0].unicast is False + assert ",QM," in str(parsed.questions[0]) + + +# 389951 1450.577370 192.168.107.111 224.0.0.251 MDNS 115 Standard query 0x0000 PTR _companion-link._tcp.local, "QU" question OPT +def test_qu_packet_parser(): + """Test we can parse a query packet with the QU bit.""" + qu_packet = b'\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x01\x0f_companion-link\x04_tcp\x05local\x00\x00\x0c\x80\x01\x00\x00)\x05\xa0\x00\x00\x11\x94\x00\x12\x00\x04\x00\x0e\x00dz{\x8a6\x9czF\x84,\xcaQ\xff' + parsed = DNSIncoming(qu_packet) + assert parsed.questions[0].unicast is True + assert ",QU," in str(parsed.questions[0]) + + +def test_parse_packet_with_nsec_record(): + """Test we can parse a packet with an NSEC record.""" + nsec_packet = ( + b"\x00\x00\x84\x00\x00\x00\x00\x01\x00\x00\x00\x03\x08_meshcop\x04_udp\x05local\x00\x00\x0c\x00" + b"\x01\x00\x00\x11\x94\x00\x0f\x0cMyHome54 (2)\xc0\x0c\xc0+\x00\x10\x80\x01\x00\x00\x11\x94\x00" + b")\x0bnn=MyHome54\x13xp=695034D148CC4784\x08tv=0.0.0\xc0+\x00!\x80\x01\x00\x00\x00x\x00\x15\x00" + b"\x00\x00\x00\xc0'\x0cMaster-Bed-2\xc0\x1a\xc0+\x00/\x80\x01\x00\x00\x11\x94\x00\t\xc0+\x00\x05" + b"\x00\x00\x80\x00@" + ) + parsed = DNSIncoming(nsec_packet) + nsec_record = parsed.answers[3] + assert "nsec," in str(nsec_record) + assert nsec_record.rdtypes == [16, 33] + assert nsec_record.next_name == "MyHome54 (2)._meshcop._udp.local." + + +def test_records_same_packet_share_fate(): + """Test records in the same packet all have the same created time.""" + out = r.DNSOutgoing(const._FLAGS_QR_QUERY | const._FLAGS_AA) + type_ = "_hap._tcp.local." + out.add_question(r.DNSQuestion(type_, const._TYPE_PTR, const._CLASS_IN)) + + for i in range(30): + out.add_answer_at_time( + DNSText( + ("HASS Bridge W9DN %s._hap._tcp.local." % i), + const._TYPE_TXT, + const._CLASS_IN | const._CLASS_UNIQUE, + const._DNS_OTHER_TTL, + b'\x13md=HASS Bridge W9DN\x06pv=1.0\x14id=11:8E:DB:5B:5C:C5\x05c#=12\x04s#=1' + b'\x04ff=0\x04ci=2\x04sf=0\x0bsh=6fLM5A==', + ), + 0, + ) + + for packet in out.packets(): + dnsin = DNSIncoming(packet) + first_time = dnsin.answers[0].created + for answer in dnsin.answers: + assert answer.created == first_time + + +def test_dns_compression_invalid_skips_bad_name_compress_in_question(): + """Test our wire parser can skip bad compression in questions.""" + packet = ( + b'\x00\x00\x00\x00\x00\x04\x00\x00\x00\x07\x00\x00\x11homeassistant1128\x05l' + b'ocal\x00\x00\xff\x00\x014homeassistant1128 [534a4794e5ed41879ecf012252d3e02' + b'a]\x0c_workstation\x04_tcp\xc0\x1e\x00\xff\x00\x014homeassistant1127 [534a47' + b'94e5ed41879ecf012252d3e02a]\xc0^\x00\xff\x00\x014homeassistant1123 [534a479' + b'4e5ed41879ecf012252d3e02a]\xc0^\x00\xff\x00\x014homeassistant1118 [534a4794' + b'e5ed41879ecf012252d3e02a]\xc0^\x00\xff\x00\x01\xc0\x0c\x00\x01\x80' + b'\x01\x00\x00\x00x\x00\x04\xc0\xa8<\xc3\xc0v\x00\x10\x80\x01\x00\x00\x00' + b'x\x00\x01\x00\xc0v\x00!\x80\x01\x00\x00\x00x\x00\x1f\x00\x00\x00\x00' + b'\x00\x00\x11homeassistant1127\x05local\x00\xc0\xb1\x00\x10\x80' + b'\x01\x00\x00\x00x\x00\x01\x00\xc0\xb1\x00!\x80\x01\x00\x00\x00x\x00\x1f' + b'\x00\x00\x00\x00\x00\x00\x11homeassistant1123\x05local\x00\xc0)\x00\x10\x80' + b'\x01\x00\x00\x00x\x00\x01\x00\xc0)\x00!\x80\x01\x00\x00\x00x\x00\x1f' + b'\x00\x00\x00\x00\x00\x00\x11homeassistant1128\x05local\x00' + ) + parsed = r.DNSIncoming(packet) + assert len(parsed.questions) == 4 + + +def test_dns_compression_all_invalid(caplog): + """Test our wire parser can skip all invalid data.""" + packet = ( + b'\x00\x00\x84\x00\x00\x00\x00\x01\x00\x00\x00\x00!roborock-vacuum-s5e_miio416' + b'112328\x00\x00/\x80\x01\x00\x00\x00x\x00\t\xc0P\x00\x05@\x00\x00\x00\x00' + ) + parsed = r.DNSIncoming(packet, ("2.4.5.4", 5353)) + assert len(parsed.questions) == 0 + assert len(parsed.answers) == 0 + + assert " Unable to parse; skipping record" in caplog.text + + +def test_invalid_next_name_ignored(): + """Test our wire parser does not throw an an invalid next name. + + The RFC states it should be ignored when used with mDNS. + """ + packet = ( + b'\x00\x00\x00\x00\x00\x01\x00\x02\x00\x00\x00\x00\x07Android\x05local\x00\x00' + b'\xff\x00\x01\xc0\x0c\x00/\x00\x01\x00\x00\x00x\x00\x08\xc02\x00\x04@' + b'\x00\x00\x08\xc0\x0c\x00\x01\x00\x01\x00\x00\x00x\x00\x04\xc0\xa8X<' + ) + parsed = r.DNSIncoming(packet) + assert len(parsed.questions) == 1 + assert len(parsed.answers) == 2 + + +def test_dns_compression_invalid_skips_record(): + """Test our wire parser can skip records we do not know how to parse.""" + packet = ( + b"\x00\x00\x84\x00\x00\x00\x00\x06\x00\x00\x00\x00\x04_hap\x04_tcp\x05local\x00\x00\x0c" + b"\x00\x01\x00\x00\x11\x94\x00\x16\x13eufy HomeBase2-2464\xc0\x0c\x04Eufy\xc0\x16\x00/" + b"\x80\x01\x00\x00\x00x\x00\x08\xc0\xa6\x00\x04@\x00\x00\x08\xc0'\x00/\x80\x01\x00\x00" + b"\x11\x94\x00\t\xc0'\x00\x05\x00\x00\x80\x00@\xc0=\x00\x01\x80\x01\x00\x00\x00x\x00\x04" + b"\xc0\xa8Dp\xc0'\x00!\x80\x01\x00\x00\x00x\x00\x08\x00\x00\x00\x00\xd1_\xc0=\xc0'\x00" + b"\x10\x80\x01\x00\x00\x11\x94\x00K\x04c#=1\x04ff=2\x14id=38:71:4F:6B:76:00\x08md=T8010" + b"\x06pv=1.1\x05s#=75\x04sf=1\x04ci=2\x0bsh=xaQk4g==" + ) + parsed = r.DNSIncoming(packet) + answer = r.DNSNsec( + 'eufy HomeBase2-2464._hap._tcp.local.', + const._TYPE_NSEC, + const._CLASS_IN | const._CLASS_UNIQUE, + const._DNS_OTHER_TTL, + 'eufy HomeBase2-2464._hap._tcp.local.', + [const._TYPE_TXT, const._TYPE_SRV], + ) + assert answer in parsed.answers + + +def test_dns_compression_points_forward(): + """Test our wire parser can unpack nsec records with compression.""" + packet = ( + b"\x00\x00\x84\x00\x00\x00\x00\x07\x00\x00\x00\x00\x0eTV Beneden (2)" + b"\x10_androidtvremote\x04_tcp\x05local\x00\x00\x10\x80\x01\x00\x00\x11" + b"\x94\x00\x15\x14bt=D8:13:99:AC:98:F1\xc0\x0c\x00/\x80\x01\x00\x00\x11" + b"\x94\x00\t\xc0\x0c\x00\x05\x00\x00\x80\x00@\tAndroid-3\xc01\x00/\x80" + b"\x01\x00\x00\x00x\x00\x08\xc0\x9c\x00\x04@\x00\x00\x08\xc0l\x00\x01\x80" + b"\x01\x00\x00\x00x\x00\x04\xc0\xa8X\x0f\xc0\x0c\x00!\x80\x01\x00\x00\x00" + b"x\x00\x08\x00\x00\x00\x00\x19B\xc0l\xc0\x1b\x00\x0c\x00\x01\x00\x00\x11" + b"\x94\x00\x02\xc0\x0c\t_services\x07_dns-sd\x04_udp\xc01\x00\x0c\x00\x01" + b"\x00\x00\x11\x94\x00\x02\xc0\x1b" + ) + parsed = r.DNSIncoming(packet) + answer = r.DNSNsec( + 'TV Beneden (2)._androidtvremote._tcp.local.', + const._TYPE_NSEC, + const._CLASS_IN | const._CLASS_UNIQUE, + const._DNS_OTHER_TTL, + 'TV Beneden (2)._androidtvremote._tcp.local.', + [const._TYPE_TXT, const._TYPE_SRV], + ) + assert answer in parsed.answers + + +def test_dns_compression_points_to_itself(): + """Test our wire parser does not loop forever when a compression pointer points to itself.""" + packet = ( + b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02\x06domain\x05local\x00\x00\x01" + b"\x80\x01\x00\x00\x00\x01\x00\x04\xc0\xa8\xd0\x05\xc0(\x00\x01\x80\x01\x00\x00\x00" + b"\x01\x00\x04\xc0\xa8\xd0\x06" + ) + parsed = r.DNSIncoming(packet) + assert len(parsed.answers) == 1 + + +def test_dns_compression_points_beyond_packet(): + """Test our wire parser does not fail when the compression pointer points beyond the packet.""" + packet = ( + b'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02\x06domain\x05local\x00\x00\x01' + b'\x80\x01\x00\x00\x00\x01\x00\x04\xc0\xa8\xd0\x05\xe7\x0f\x00\x01\x80\x01\x00\x00' + b'\x00\x01\x00\x04\xc0\xa8\xd0\x06' + ) + parsed = r.DNSIncoming(packet) + assert len(parsed.answers) == 1 + + +def test_dns_compression_generic_failure(caplog): + """Test our wire parser does not loop forever when dns compression is corrupt.""" + packet = ( + b'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02\x06domain\x05local\x00\x00\x01' + b'\x80\x01\x00\x00\x00\x01\x00\x04\xc0\xa8\xd0\x05-\x0c\x00\x01\x80\x01\x00\x00' + b'\x00\x01\x00\x04\xc0\xa8\xd0\x06' + ) + parsed = r.DNSIncoming(packet, ("1.2.3.4", 5353)) + assert len(parsed.answers) == 1 + assert "Received invalid packet from ('1.2.3.4', 5353)" in caplog.text + + +def test_label_length_attack(): + """Test our wire parser does not loop forever when the name exceeds 253 chars.""" + packet = ( + b'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02\x01d\x01d\x01d\x01d\x01d\x01d' + b'\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d' + b'\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d' + b'\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d' + b'\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d' + b'\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d' + b'\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d' + b'\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d' + b'\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x00\x00\x01\x80' + b'\x01\x00\x00\x00\x01\x00\x04\xc0\xa8\xd0\x05\xc0\x0c\x00\x01\x80\x01\x00\x00\x00' + b'\x01\x00\x04\xc0\xa8\xd0\x06' + ) + parsed = r.DNSIncoming(packet) + assert len(parsed.answers) == 0 + + +def test_label_compression_attack(): + """Test our wire parser does not loop forever when exceeding the maximum number of labels.""" + packet = ( + b'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02\x03atk\x00\x00\x01\x80' + b'\x01\x00\x00\x00\x01\x00\x04\xc0\xa8\xd0\x05\x03atk\x03atk\x03atk\x03atk\x03' + b'atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03' + b'atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03' + b'atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03' + b'atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03' + b'atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03' + b'atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03' + b'atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03' + b'atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03' + b'atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03' + b'atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03' + b'atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03' + b'atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03' + b'atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03' + b'atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03' + b'atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03' + b'atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03' + b'atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03' + b'atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03' + b'atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\xc0' + b'\x0c\x00\x01\x80\x01\x00\x00\x00\x01\x00\x04\xc0\xa8\xd0\x06' + ) + parsed = r.DNSIncoming(packet) + assert len(parsed.answers) == 1 + + +def test_dns_compression_loop_attack(): + """Test our wire parser does not loop forever when dns compression is in a loop.""" + packet = ( + b'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x07\x03atk\x03dns\x05loc' + b'al\xc0\x10\x00\x01\x80\x01\x00\x00\x00\x01\x00\x04\xc0\xa8\xd0\x05\x04a' + b'tk2\x04dns2\xc0\x14\x00\x01\x80\x01\x00\x00\x00\x01\x00\x04\xc0\xa8\xd0\x05' + b'\x04atk3\xc0\x10\x00\x01\x80\x01\x00\x00\x00\x01\x00\x04\xc0\xa8\xd0' + b'\x05\x04atk4\x04dns5\xc0\x14\x00\x01\x80\x01\x00\x00\x00\x01\x00\x04\xc0' + b'\xa8\xd0\x05\x04atk5\x04dns2\xc0^\x00\x01\x80\x01\x00\x00\x00\x01\x00' + b'\x04\xc0\xa8\xd0\x05\xc0s\x00\x01\x80\x01\x00\x00\x00\x01\x00' + b'\x04\xc0\xa8\xd0\x05\xc0s\x00\x01\x80\x01\x00\x00\x00\x01\x00' + b'\x04\xc0\xa8\xd0\x05' + ) + parsed = r.DNSIncoming(packet) + assert len(parsed.answers) == 0 + + +def test_txt_after_invalid_nsec_name_still_usable(): + """Test that we can see the txt record after the invalid nsec record.""" + packet = ( + b'\x00\x00\x84\x00\x00\x00\x00\x06\x00\x00\x00\x00\x06_sonos\x04_tcp\x05loc' + b'al\x00\x00\x0c\x00\x01\x00\x00\x11\x94\x00\x15\x12Sonos-542A1BC9220E' + b'\xc0\x0c\x12Sonos-542A1BC9220E\xc0\x18\x00/\x80\x01\x00\x00\x00x\x00' + b'\x08\xc1t\x00\x04@\x00\x00\x08\xc0)\x00/\x80\x01\x00\x00\x11\x94\x00' + b'\t\xc0)\x00\x05\x00\x00\x80\x00@\xc0)\x00!\x80\x01\x00\x00\x00x' + b'\x00\x08\x00\x00\x00\x00\x05\xa3\xc0>\xc0>\x00\x01\x80\x01\x00\x00\x00x' + b'\x00\x04\xc0\xa8\x02:\xc0)\x00\x10\x80\x01\x00\x00\x11\x94\x01*2info=/api' + b'/v1/players/RINCON_542A1BC9220E01400/info\x06vers=3\x10protovers=1.24.1\nbo' + b'otseq=11%hhid=Sonos_rYn9K9DLXJe0f3LP9747lbvFvh;mhhid=Sonos_rYn9K9DLXJe0f3LP9' + b'747lbvFvh.Q45RuMaeC07rfXh7OJGm None: + nonlocal updates + updates.append(record) + + listener = LegacyRecordUpdateListener() + + zc.add_listener(listener, None) + + # dummy service callback + def on_service_state_change(zeroconf, service_type, state_change, name): + pass + + # start a browser + type_ = "_homeassistant._tcp.local." + name = "MyTestHome" + browser = ServiceBrowser(zc, type_, [on_service_state_change]) + + info_service = ServiceInfo( + type_, + f'{name}.{type_}', + 80, + 0, + 0, + {'path': '/~paulsm/'}, + "ash-2.local.", + addresses=[socket.inet_aton("10.0.1.2")], + ) + + zc.register_service(info_service) + + time.sleep(0.001) + + browser.cancel() + + assert len(updates) + assert len([isinstance(update, r.DNSPointer) and update.name == type_ for update in updates]) >= 1 + + zc.remove_listener(listener) + # Removing a second time should not throw + zc.remove_listener(listener) + + zc.close() diff --git a/tests/utils/__init__.py b/tests/utils/__init__.py new file mode 100644 index 00000000..2ef4b15b --- /dev/null +++ b/tests/utils/__init__.py @@ -0,0 +1,21 @@ +""" Multicast DNS Service Discovery for Python, v0.14-wmcbrine + Copyright 2003 Paul Scott-Murphy, 2014 William McBrine + + This module provides a framework for the use of DNS Service Discovery + using IP multicast. + + This library is free software; you can redistribute it and/or + modify it under the terms of the GNU Lesser General Public + License as published by the Free Software Foundation; either + version 2.1 of the License, or (at your option) any later version. + + This library is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + Lesser General Public License for more details. + + You should have received a copy of the GNU Lesser General Public + License along with this library; if not, write to the Free Software + Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 + USA +""" diff --git a/tests/utils/test_asyncio.py b/tests/utils/test_asyncio.py new file mode 100644 index 00000000..2939b5ab --- /dev/null +++ b/tests/utils/test_asyncio.py @@ -0,0 +1,114 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + + +"""Unit tests for zeroconf._utils.asyncio.""" + +import asyncio +import concurrent.futures +import contextlib +import threading +import time +from unittest.mock import patch + +import pytest + +from zeroconf._core import _CLOSE_TIMEOUT +from zeroconf._utils import asyncio as aioutils +from zeroconf.const import _LOADED_SYSTEM_TIMEOUT + + +@pytest.mark.asyncio +async def test_async_get_all_tasks() -> None: + """Test we can get all tasks in the event loop. + + We make sure we handle RuntimeError here as + this is not thread safe under PyPy + """ + await aioutils._async_get_all_tasks(aioutils.get_running_loop()) + if not hasattr(asyncio, 'all_tasks'): + return + with patch("zeroconf._utils.asyncio.asyncio.all_tasks", side_effect=RuntimeError): + await aioutils._async_get_all_tasks(aioutils.get_running_loop()) + + +@pytest.mark.asyncio +async def test_get_running_loop_from_async() -> None: + """Test we can get the event loop.""" + assert isinstance(aioutils.get_running_loop(), asyncio.AbstractEventLoop) + + +def test_get_running_loop_no_loop() -> None: + """Test we get None when there is no loop running.""" + assert aioutils.get_running_loop() is None + + +@pytest.mark.asyncio +async def test_wait_event_or_timeout_times_out() -> None: + """Test wait_event_or_timeout will timeout.""" + test_event = asyncio.Event() + await aioutils.wait_event_or_timeout(test_event, 0.1) + + task = asyncio.ensure_future(test_event.wait()) + await asyncio.sleep(0.1) + + async def _async_wait_or_timeout(): + await aioutils.wait_event_or_timeout(test_event, 0.1) + + # Test high lock contention + await asyncio.gather(*[_async_wait_or_timeout() for _ in range(100)]) + + task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await task + + +def test_shutdown_loop() -> None: + """Test shutting down an event loop.""" + loop = None + loop_thread_ready = threading.Event() + runcoro_thread_ready = threading.Event() + + def _run_loop() -> None: + nonlocal loop + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + loop_thread_ready.set() + loop.run_forever() + + loop_thread = threading.Thread(target=_run_loop, daemon=True) + loop_thread.start() + loop_thread_ready.wait() + + async def _still_running(): + await asyncio.sleep(5) + + def _run_coro() -> None: + runcoro_thread_ready.set() + with contextlib.suppress(concurrent.futures.TimeoutError): + asyncio.run_coroutine_threadsafe(_still_running(), loop).result(1) + + runcoro_thread = threading.Thread(target=_run_coro, daemon=True) + runcoro_thread.start() + runcoro_thread_ready.wait() + + time.sleep(0.1) + aioutils.shutdown_loop(loop) + for _ in range(5): + if not loop.is_running(): + break + time.sleep(0.05) + + assert loop.is_running() is False + runcoro_thread.join() + + +def test_cumulative_timeouts_less_than_close_plus_buffer(): + """Test that the combined async timeouts are shorter than the close timeout with the buffer. + + We want to make sure that the close timeout is the one that gets + raised if something goes wrong. + """ + assert ( + aioutils._TASK_AWAIT_TIMEOUT + aioutils._GET_ALL_TASKS_TIMEOUT + aioutils._WAIT_FOR_LOOP_TASKS_TIMEOUT + ) < 1 + _CLOSE_TIMEOUT + _LOADED_SYSTEM_TIMEOUT diff --git a/tests/utils/test_name.py b/tests/utils/test_name.py new file mode 100644 index 00000000..6f8b417d --- /dev/null +++ b/tests/utils/test_name.py @@ -0,0 +1,26 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + + +"""Unit tests for zeroconf._utils.name.""" + +import pytest + +from zeroconf._utils import name as nameutils +from zeroconf import BadTypeInNameException + + +def test_service_type_name_overlong_type(): + """Test overlong service_type_name type.""" + with pytest.raises(BadTypeInNameException): + nameutils.service_type_name("Tivo1._tivo-videostream._tcp.local.") + nameutils.service_type_name("Tivo1._tivo-videostream._tcp.local.", strict=False) + + +def test_service_type_name_overlong_full_name(): + """Test overlong service_type_name full name.""" + long_name = "Tivo1Tivo1Tivo1Tivo1Tivo1Tivo1Tivo1Tivo1" * 100 + with pytest.raises(BadTypeInNameException): + nameutils.service_type_name(f"{long_name}._tivo-videostream._tcp.local.") + with pytest.raises(BadTypeInNameException): + nameutils.service_type_name(f"{long_name}._tivo-videostream._tcp.local.", strict=False) diff --git a/tests/utils/test_net.py b/tests/utils/test_net.py new file mode 100644 index 00000000..41fdb7aa --- /dev/null +++ b/tests/utils/test_net.py @@ -0,0 +1,199 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + + +"""Unit tests for zeroconf._utils.net.""" +from unittest.mock import Mock, patch + +import errno +import ifaddr +import pytest +import socket +import unittest + +from zeroconf._utils import net as netutils +import zeroconf as r + + +def _generate_mock_adapters(): + mock_lo0 = Mock(spec=ifaddr.Adapter) + mock_lo0.nice_name = "lo0" + mock_lo0.ips = [ifaddr.IP("127.0.0.1", 8, "lo0")] + mock_lo0.index = 0 + mock_eth0 = Mock(spec=ifaddr.Adapter) + mock_eth0.nice_name = "eth0" + mock_eth0.ips = [ifaddr.IP(("2001:db8::", 1, 1), 8, "eth0")] + mock_eth0.index = 1 + mock_eth1 = Mock(spec=ifaddr.Adapter) + mock_eth1.nice_name = "eth1" + mock_eth1.ips = [ifaddr.IP("192.168.1.5", 23, "eth1")] + mock_eth1.index = 2 + mock_vtun0 = Mock(spec=ifaddr.Adapter) + mock_vtun0.nice_name = "vtun0" + mock_vtun0.ips = [ifaddr.IP("169.254.3.2", 16, "vtun0")] + mock_vtun0.index = 3 + return [mock_eth0, mock_lo0, mock_eth1, mock_vtun0] + + +def test_ip6_to_address_and_index(): + """Test we can extract from mocked adapters.""" + adapters = _generate_mock_adapters() + assert netutils.ip6_to_address_and_index(adapters, "2001:db8::") == (('2001:db8::', 1, 1), 1) + with pytest.raises(RuntimeError): + assert netutils.ip6_to_address_and_index(adapters, "2005:db8::") + + +def test_interface_index_to_ip6_address(): + """Test we can extract from mocked adapters.""" + adapters = _generate_mock_adapters() + assert netutils.interface_index_to_ip6_address(adapters, 1) == ('2001:db8::', 1, 1) + + # call with invalid adapter + with pytest.raises(RuntimeError): + assert netutils.interface_index_to_ip6_address(adapters, 6) + + # call with adapter that has ipv4 address only + with pytest.raises(RuntimeError): + assert netutils.interface_index_to_ip6_address(adapters, 2) + + +def test_ip6_addresses_to_indexes(): + """Test we can extract from mocked adapters.""" + interfaces = [1] + with patch("zeroconf._utils.net.ifaddr.get_adapters", return_value=_generate_mock_adapters()): + assert netutils.ip6_addresses_to_indexes(interfaces) == [(('2001:db8::', 1, 1), 1)] + + interfaces = ['2001:db8::'] + with patch("zeroconf._utils.net.ifaddr.get_adapters", return_value=_generate_mock_adapters()): + assert netutils.ip6_addresses_to_indexes(interfaces) == [(('2001:db8::', 1, 1), 1)] + + +def test_normalize_interface_choice_errors(): + """Test we generate exception on invalid input.""" + with patch("zeroconf._utils.net.get_all_addresses", return_value=[]), patch( + "zeroconf._utils.net.get_all_addresses_v6", return_value=[] + ), pytest.raises(RuntimeError): + netutils.normalize_interface_choice(r.InterfaceChoice.All) + + with pytest.raises(TypeError): + netutils.normalize_interface_choice("1.2.3.4") + + +@pytest.mark.parametrize( + "errno,expected_result", + [(errno.EADDRINUSE, False), (errno.EADDRNOTAVAIL, False), (errno.EINVAL, False), (0, True)], +) +def test_add_multicast_member_socket_errors(errno, expected_result): + """Test we handle socket errors when adding multicast members.""" + if errno: + setsockopt_mock = unittest.mock.Mock(side_effect=OSError(errno, "Error: {}".format(errno))) + else: + setsockopt_mock = unittest.mock.Mock() + fileno_mock = unittest.mock.PropertyMock(return_value=10) + socket_mock = unittest.mock.Mock(setsockopt=setsockopt_mock, fileno=fileno_mock) + assert r.add_multicast_member(socket_mock, "0.0.0.0") == expected_result + + +def test_autodetect_ip_version(): + """Tests for auto detecting IPVersion based on interface ips.""" + assert r.autodetect_ip_version(["1.3.4.5"]) is r.IPVersion.V4Only + assert r.autodetect_ip_version([]) is r.IPVersion.V4Only + assert r.autodetect_ip_version(["::1", "1.2.3.4"]) is r.IPVersion.All + assert r.autodetect_ip_version(["::1"]) is r.IPVersion.V6Only + + +def test_disable_ipv6_only_or_raise(): + """Test that IPV6_V6ONLY failing logs a nice error message and still raises.""" + errors_logged = [] + + def _log_error(*args): + nonlocal errors_logged + errors_logged.append(args) + + sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + with pytest.raises(OSError), patch.object(netutils.log, "error", _log_error), patch( + "socket.socket.setsockopt", side_effect=OSError + ): + netutils.disable_ipv6_only_or_raise(sock) + + assert ( + errors_logged[0][0] + == 'Support for dual V4-V6 sockets is not present, use IPVersion.V4 or IPVersion.V6' + ) + + +@pytest.mark.skipif(not hasattr(socket, 'SO_REUSEPORT'), reason="System does not have SO_REUSEPORT") +def test_set_so_reuseport_if_available_is_present(): + """Test that setting socket.SO_REUSEPORT only OSError errno.ENOPROTOOPT is trapped.""" + sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + with pytest.raises(OSError), patch("socket.socket.setsockopt", side_effect=OSError): + netutils.set_so_reuseport_if_available(sock) + + with patch("socket.socket.setsockopt", side_effect=OSError(errno.ENOPROTOOPT, None)): + netutils.set_so_reuseport_if_available(sock) + + +@pytest.mark.skipif(hasattr(socket, 'SO_REUSEPORT'), reason="System has SO_REUSEPORT") +def test_set_so_reuseport_if_available_not_present(): + """Test that we do not try to set SO_REUSEPORT if it is not present.""" + sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + with patch("socket.socket.setsockopt", side_effect=OSError): + netutils.set_so_reuseport_if_available(sock) + + +def test_set_mdns_port_socket_options_for_ip_version(): + """Test OSError with errno with EINVAL and bind address '' from setsockopt IP_MULTICAST_TTL does not raise.""" + sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + + # Should raise on EPERM always + with pytest.raises(OSError), patch("socket.socket.setsockopt", side_effect=OSError(errno.EPERM, None)): + netutils.set_mdns_port_socket_options_for_ip_version(sock, ('',), r.IPVersion.V4Only) + + # Should raise on EINVAL always when bind address is not '' + with pytest.raises(OSError), patch("socket.socket.setsockopt", side_effect=OSError(errno.EINVAL, None)): + netutils.set_mdns_port_socket_options_for_ip_version(sock, ('127.0.0.1',), r.IPVersion.V4Only) + + # Should not raise on EINVAL when bind address is '' + with patch("socket.socket.setsockopt", side_effect=OSError(errno.EINVAL, None)): + netutils.set_mdns_port_socket_options_for_ip_version(sock, ('',), r.IPVersion.V4Only) + + +def test_add_multicast_member(): + sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + interface = '127.0.0.1' + + # EPERM should always raise + with pytest.raises(OSError), patch("socket.socket.setsockopt", side_effect=OSError(errno.EPERM, None)): + netutils.add_multicast_member(sock, interface) + + # EADDRINUSE should return False + with patch("socket.socket.setsockopt", side_effect=OSError(errno.EADDRINUSE, None)): + assert netutils.add_multicast_member(sock, interface) is False + + # EADDRNOTAVAIL should return False + with patch("socket.socket.setsockopt", side_effect=OSError(errno.EADDRNOTAVAIL, None)): + assert netutils.add_multicast_member(sock, interface) is False + + # EINVAL should return False + with patch("socket.socket.setsockopt", side_effect=OSError(errno.EINVAL, None)): + assert netutils.add_multicast_member(sock, interface) is False + + # ENOPROTOOPT should return False + with patch("socket.socket.setsockopt", side_effect=OSError(errno.ENOPROTOOPT, None)): + assert netutils.add_multicast_member(sock, interface) is False + + # ENODEV should raise for ipv4 + with pytest.raises(OSError), patch("socket.socket.setsockopt", side_effect=OSError(errno.ENODEV, None)): + netutils.add_multicast_member(sock, interface) is False + + # ENODEV should return False for ipv6 + with patch("socket.socket.setsockopt", side_effect=OSError(errno.ENODEV, None)): + assert netutils.add_multicast_member(sock, ('2001:db8::', 1, 1)) is False + + # No IPv6 support should return False for IPv6 + with patch("socket.inet_pton", side_effect=OSError()): + assert netutils.add_multicast_member(sock, ('2001:db8::', 1, 1)) is False + + # No error should return True + with patch("socket.socket.setsockopt"): + assert netutils.add_multicast_member(sock, interface) is True diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index 35fac80a..4821cbb8 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -20,2608 +20,88 @@ USA """ -import enum -import errno -import ipaddress -import itertools -import logging -import os -import platform -import re -import select -import socket -import struct import sys -import threading -import time -import warnings -from typing import Dict, List, Optional, Sequence, Union, cast -from typing import Any, Callable, Set, Tuple # noqa # used in type hints -import ifaddr +from ._cache import DNSCache # noqa # import needed for backwards compat +from ._core import Zeroconf # noqa # import needed for backwards compat +from ._dns import ( # noqa # import needed for backwards compat + DNSAddress, + DNSEntry, + DNSHinfo, + DNSNsec, + DNSPointer, + DNSQuestion, + DNSRecord, + DNSService, + DNSText, + DNSQuestionType, +) +from ._logger import QuietLogger, log # noqa # import needed for backwards compat +from ._exceptions import ( # noqa # import needed for backwards compat + AbstractMethodException, + BadTypeInNameException, + Error, + IncomingDecodeError, + NamePartTooLongException, + NonUniqueNameException, + ServiceNameAlreadyRegistered, +) +from ._protocol.incoming import DNSIncoming # noqa # import needed for backwards compat +from ._protocol.outgoing import DNSOutgoing # noqa # import needed for backwards compat +from ._services import ( # noqa # import needed for backwards compat + Signal, + SignalRegistrationInterface, + ServiceListener, + ServiceStateChange, +) +from ._services.browser import ( # noqa # import needed for backwards compat + ServiceBrowser, +) +from ._services.info import ( # noqa # import needed for backwards compat + instance_name_from_service_info, + ServiceInfo, +) +from ._services.registry import ServiceRegistry # noqa # import needed for backwards compat +from ._services.types import ZeroconfServiceTypes +from ._updates import RecordUpdate, RecordUpdateListener # noqa # import needed for backwards compat +from ._utils.name import service_type_name # noqa # import needed for backwards compat +from ._utils.net import ( # noqa # import needed for backwards compat + add_multicast_member, + autodetect_ip_version, + create_sockets, + get_all_addresses_v6, + InterfaceChoice, + InterfacesType, + IPVersion, + get_all_addresses, +) +from ._utils.time import current_time_millis, millis_to_seconds # noqa # import needed for backwards compat __author__ = 'Paul Scott-Murphy, William McBrine' __maintainer__ = 'Jakub Stasiak ' -__version__ = '0.24.4' +__version__ = '0.36.7' __license__ = 'LGPL' __all__ = [ "__version__", + "DNSQuestionType", "Zeroconf", "ServiceInfo", "ServiceBrowser", + "ServiceListener", "Error", "InterfaceChoice", "ServiceStateChange", "IPVersion", + "ZeroconfServiceTypes", ] -if sys.version_info <= (3, 3): - raise ImportError( +if sys.version_info <= (3, 6): # pragma: no cover + raise ImportError( # pragma: no cover ''' -Python version > 3.3 required for python-zeroconf. -If you need support for Python 2 or Python 3.3 please use version 19.1 +Python version > 3.6 required for python-zeroconf. +If you need support for Python 2 or Python 3.3-3.4 please use version 19.1 +If you need support for Python 3.5 please use version 0.28.0 ''' ) - -log = logging.getLogger(__name__) -log.addHandler(logging.NullHandler()) - -if log.level == logging.NOTSET: - log.setLevel(logging.WARN) - -# Some timing constants - -_UNREGISTER_TIME = 125 # ms -_CHECK_TIME = 175 # ms -_REGISTER_TIME = 225 # ms -_LISTENER_TIME = 200 # ms -_BROWSER_TIME = 1000 # ms -_BROWSER_BACKOFF_LIMIT = 3600 # s - -# Some DNS constants - -_MDNS_ADDR = '224.0.0.251' -_MDNS_ADDR_BYTES = socket.inet_aton(_MDNS_ADDR) -_MDNS_ADDR6 = 'ff02::fb' -_MDNS_ADDR6_BYTES = socket.inet_pton(socket.AF_INET6, _MDNS_ADDR6) -_MDNS_PORT = 5353 -_DNS_PORT = 53 -_DNS_HOST_TTL = 120 # two minute for host records (A, SRV etc) as-per RFC6762 -_DNS_OTHER_TTL = 4500 # 75 minutes for non-host records (PTR, TXT etc) as-per RFC6762 - -_MAX_MSG_TYPICAL = 1460 # unused -_MAX_MSG_ABSOLUTE = 8966 - -_FLAGS_QR_MASK = 0x8000 # query response mask -_FLAGS_QR_QUERY = 0x0000 # query -_FLAGS_QR_RESPONSE = 0x8000 # response - -_FLAGS_AA = 0x0400 # Authoritative answer -_FLAGS_TC = 0x0200 # Truncated -_FLAGS_RD = 0x0100 # Recursion desired -_FLAGS_RA = 0x8000 # Recursion available - -_FLAGS_Z = 0x0040 # Zero -_FLAGS_AD = 0x0020 # Authentic data -_FLAGS_CD = 0x0010 # Checking disabled - -_CLASS_IN = 1 -_CLASS_CS = 2 -_CLASS_CH = 3 -_CLASS_HS = 4 -_CLASS_NONE = 254 -_CLASS_ANY = 255 -_CLASS_MASK = 0x7FFF -_CLASS_UNIQUE = 0x8000 - -_TYPE_A = 1 -_TYPE_NS = 2 -_TYPE_MD = 3 -_TYPE_MF = 4 -_TYPE_CNAME = 5 -_TYPE_SOA = 6 -_TYPE_MB = 7 -_TYPE_MG = 8 -_TYPE_MR = 9 -_TYPE_NULL = 10 -_TYPE_WKS = 11 -_TYPE_PTR = 12 -_TYPE_HINFO = 13 -_TYPE_MINFO = 14 -_TYPE_MX = 15 -_TYPE_TXT = 16 -_TYPE_AAAA = 28 -_TYPE_SRV = 33 -_TYPE_ANY = 255 - -# Mapping constants to names - -_CLASSES = { - _CLASS_IN: "in", - _CLASS_CS: "cs", - _CLASS_CH: "ch", - _CLASS_HS: "hs", - _CLASS_NONE: "none", - _CLASS_ANY: "any", -} - -_TYPES = { - _TYPE_A: "a", - _TYPE_NS: "ns", - _TYPE_MD: "md", - _TYPE_MF: "mf", - _TYPE_CNAME: "cname", - _TYPE_SOA: "soa", - _TYPE_MB: "mb", - _TYPE_MG: "mg", - _TYPE_MR: "mr", - _TYPE_NULL: "null", - _TYPE_WKS: "wks", - _TYPE_PTR: "ptr", - _TYPE_HINFO: "hinfo", - _TYPE_MINFO: "minfo", - _TYPE_MX: "mx", - _TYPE_TXT: "txt", - _TYPE_AAAA: "quada", - _TYPE_SRV: "srv", - _TYPE_ANY: "any", -} - -_HAS_A_TO_Z = re.compile(r'[A-Za-z]') -_HAS_ONLY_A_TO_Z_NUM_HYPHEN = re.compile(r'^[A-Za-z0-9\-]+$') -_HAS_ONLY_A_TO_Z_NUM_HYPHEN_UNDERSCORE = re.compile(r'^[A-Za-z0-9\-\_]+$') -_HAS_ASCII_CONTROL_CHARS = re.compile(r'[\x00-\x1f\x7f]') - -try: - _IPPROTO_IPV6 = socket.IPPROTO_IPV6 -except AttributeError: - # Sigh: https://bugs.python.org/issue29515 - _IPPROTO_IPV6 = 41 - -int2byte = struct.Struct(">B").pack - - -@enum.unique -class InterfaceChoice(enum.Enum): - Default = 1 - All = 2 - - -InterfacesType = Union[List[Union[str, int]], InterfaceChoice] - - -@enum.unique -class ServiceStateChange(enum.Enum): - Added = 1 - Removed = 2 - Updated = 3 - - -@enum.unique -class IPVersion(enum.Enum): - V4Only = 1 - V6Only = 2 - All = 3 - - -# utility functions - - -def current_time_millis() -> float: - """Current system time in milliseconds""" - return time.time() * 1000 - - -def _is_v6_address(addr: bytes) -> bool: - return len(addr) == 16 - - -def service_type_name(type_: str, *, allow_underscores: bool = False) -> str: - """ - Validate a fully qualified service name, instance or subtype. [rfc6763] - - Returns fully qualified service name. - - Domain names used by mDNS-SD take the following forms: - - . <_tcp|_udp> . local. - . . <_tcp|_udp> . local. - ._sub . . <_tcp|_udp> . local. - - 1) must end with 'local.' - - This is true because we are implementing mDNS and since the 'm' means - multi-cast, the 'local.' domain is mandatory. - - 2) local is preceded with either '_udp.' or '_tcp.' - - 3) service name precedes <_tcp|_udp> - - The rules for Service Names [RFC6335] state that they may be no more - than fifteen characters long (not counting the mandatory underscore), - consisting of only letters, digits, and hyphens, must begin and end - with a letter or digit, must not contain consecutive hyphens, and - must contain at least one letter. - - The instance name and sub type may be up to 63 bytes. - - The portion of the Service Instance Name is a user- - friendly name consisting of arbitrary Net-Unicode text [RFC5198]. It - MUST NOT contain ASCII control characters (byte values 0x00-0x1F and - 0x7F) [RFC20] but otherwise is allowed to contain any characters, - without restriction, including spaces, uppercase, lowercase, - punctuation -- including dots -- accented characters, non-Roman text, - and anything else that may be represented using Net-Unicode. - - :param type_: Type, SubType or service name to validate - :return: fully qualified service name (eg: _http._tcp.local.) - """ - if not (type_.endswith('._tcp.local.') or type_.endswith('._udp.local.')): - raise BadTypeInNameException("Type '%s' must end with '._tcp.local.' or '._udp.local.'" % type_) - - remaining = type_[: -len('._tcp.local.')].split('.') - name = remaining.pop() - if not name: - raise BadTypeInNameException("No Service name found") - - if len(remaining) == 1 and len(remaining[0]) == 0: - raise BadTypeInNameException("Type '%s' must not start with '.'" % type_) - - if name[0] != '_': - raise BadTypeInNameException("Service name (%s) must start with '_'" % name) - - # remove leading underscore - name = name[1:] - - if len(name) > 15: - raise BadTypeInNameException("Service name (%s) must be <= 15 bytes" % name) - - if '--' in name: - raise BadTypeInNameException("Service name (%s) must not contain '--'" % name) - - if '-' in (name[0], name[-1]): - raise BadTypeInNameException("Service name (%s) may not start or end with '-'" % name) - - if not _HAS_A_TO_Z.search(name): - raise BadTypeInNameException("Service name (%s) must contain at least one letter (eg: 'A-Z')" % name) - - allowed_characters_re = ( - _HAS_ONLY_A_TO_Z_NUM_HYPHEN_UNDERSCORE if allow_underscores else _HAS_ONLY_A_TO_Z_NUM_HYPHEN - ) - - if not allowed_characters_re.search(name): - raise BadTypeInNameException( - "Service name (%s) must contain only these characters: " - "A-Z, a-z, 0-9, hyphen ('-')%s" % (name, ", underscore ('_')" if allow_underscores else "") - ) - - if remaining and remaining[-1] == '_sub': - remaining.pop() - if len(remaining) == 0 or len(remaining[0]) == 0: - raise BadTypeInNameException("_sub requires a subtype name") - - if len(remaining) > 1: - remaining = ['.'.join(remaining)] - - if remaining: - length = len(remaining[0].encode('utf-8')) - if length > 63: - raise BadTypeInNameException("Too long: '%s'" % remaining[0]) - - if _HAS_ASCII_CONTROL_CHARS.search(remaining[0]): - raise BadTypeInNameException( - "Ascii control character 0x00-0x1F and 0x7F illegal in '%s'" % remaining[0] - ) - - return '_' + name + type_[-len('._tcp.local.') :] - - -# Exceptions - - -class Error(Exception): - pass - - -class IncomingDecodeError(Error): - pass - - -class NonUniqueNameException(Error): - pass - - -class NamePartTooLongException(Error): - pass - - -class AbstractMethodException(Error): - pass - - -class BadTypeInNameException(Error): - pass - - -# implementation classes - - -class QuietLogger: - _seen_logs = {} # type: Dict[str, Union[int, tuple]] - - @classmethod - def log_exception_warning(cls, logger_data: Optional[Tuple] = None) -> None: - exc_info = sys.exc_info() - exc_str = str(exc_info[1]) - if exc_str not in cls._seen_logs: - # log at warning level the first time this is seen - cls._seen_logs[exc_str] = exc_info - logger = log.warning - else: - logger = log.debug - if logger_data is not None: - logger(*logger_data) - logger('Exception occurred:', exc_info=True) - - @classmethod - def log_warning_once(cls, *args: Any) -> None: - msg_str = args[0] - if msg_str not in cls._seen_logs: - cls._seen_logs[msg_str] = 0 - logger = log.warning - else: - logger = log.debug - cls._seen_logs[msg_str] = cast(int, cls._seen_logs[msg_str]) + 1 - logger(*args) - - -class DNSEntry: - - """A DNS entry""" - - def __init__(self, name: str, type_: int, class_: int) -> None: - self.key = name.lower() - self.name = name - self.type = type_ - self.class_ = class_ & _CLASS_MASK - self.unique = (class_ & _CLASS_UNIQUE) != 0 - - def __eq__(self, other: Any) -> bool: - """Equality test on name, type, and class""" - return ( - self.name == other.name - and self.type == other.type - and self.class_ == other.class_ - and isinstance(other, DNSEntry) - ) - - def __ne__(self, other: Any) -> bool: - """Non-equality test""" - return not self.__eq__(other) - - @staticmethod - def get_class_(class_: int) -> str: - """Class accessor""" - return _CLASSES.get(class_, "?(%s)" % class_) - - @staticmethod - def get_type(t: int) -> str: - """Type accessor""" - return _TYPES.get(t, "?(%s)" % t) - - def entry_to_string(self, hdr: str, other: Optional[Union[bytes, str]]) -> str: - """String representation with additional information""" - result = "%s[%s,%s" % (hdr, self.get_type(self.type), self.get_class_(self.class_)) - if self.unique: - result += "-unique," - else: - result += "," - result += self.name - if other is not None: - result += "]=%s" % cast(Any, other) - else: - result += "]" - return result - - -class DNSQuestion(DNSEntry): - - """A DNS question entry""" - - def __init__(self, name: str, type_: int, class_: int) -> None: - DNSEntry.__init__(self, name, type_, class_) - - def answered_by(self, rec: 'DNSRecord') -> bool: - """Returns true if the question is answered by the record""" - return ( - self.class_ == rec.class_ - and (self.type == rec.type or self.type == _TYPE_ANY) - and self.name == rec.name - ) - - def __repr__(self) -> str: - """String representation""" - return DNSEntry.entry_to_string(self, "question", None) - - -class DNSRecord(DNSEntry): - - """A DNS record - like a DNS entry, but has a TTL""" - - # TODO: Switch to just int ttl - def __init__(self, name: str, type_: int, class_: int, ttl: Union[float, int]) -> None: - DNSEntry.__init__(self, name, type_, class_) - self.ttl = ttl - self.created = current_time_millis() - self._expiration_time = self.get_expiration_time(100) - self._stale_time = self.get_expiration_time(50) - - def __eq__(self, other: Any) -> bool: - """Abstract method""" - raise AbstractMethodException - - def __ne__(self, other: Any) -> bool: - """Non-equality test""" - return not self.__eq__(other) - - def suppressed_by(self, msg: 'DNSIncoming') -> bool: - """Returns true if any answer in a message can suffice for the - information held in this record.""" - for record in msg.answers: - if self.suppressed_by_answer(record): - return True - return False - - def suppressed_by_answer(self, other: 'DNSRecord') -> bool: - """Returns true if another record has same name, type and class, - and if its TTL is at least half of this record's.""" - return self == other and other.ttl > (self.ttl / 2) - - def get_expiration_time(self, percent: int) -> float: - """Returns the time at which this record will have expired - by a certain percentage.""" - return self.created + (percent * self.ttl * 10) - - # TODO: Switch to just int here - def get_remaining_ttl(self, now: float) -> Union[int, float]: - """Returns the remaining TTL in seconds.""" - return max(0, (self._expiration_time - now) / 1000.0) - - def is_expired(self, now: float) -> bool: - """Returns true if this record has expired.""" - return self._expiration_time <= now - - def is_stale(self, now: float) -> bool: - """Returns true if this record is at least half way expired.""" - return self._stale_time <= now - - def reset_ttl(self, other: 'DNSRecord') -> None: - """Sets this record's TTL and created time to that of - another record.""" - self.created = other.created - self.ttl = other.ttl - self._expiration_time = self.get_expiration_time(100) - self._stale_time = self.get_expiration_time(50) - - def write(self, out: 'DNSOutgoing') -> None: - """Abstract method""" - raise AbstractMethodException - - def to_string(self, other: Union[bytes, str]) -> str: - """String representation with additional information""" - arg = "%s/%s,%s" % (self.ttl, int(self.get_remaining_ttl(current_time_millis())), cast(Any, other)) - return DNSEntry.entry_to_string(self, "record", arg) - - -class DNSAddress(DNSRecord): - - """A DNS address record""" - - def __init__(self, name: str, type_: int, class_: int, ttl: int, address: bytes) -> None: - DNSRecord.__init__(self, name, type_, class_, ttl) - self.address = address - - def write(self, out: 'DNSOutgoing') -> None: - """Used in constructing an outgoing packet""" - out.write_string(self.address) - - def __eq__(self, other: Any) -> bool: - """Tests equality on address""" - return ( - isinstance(other, DNSAddress) and DNSEntry.__eq__(self, other) and self.address == other.address - ) - - def __ne__(self, other: Any) -> bool: - """Non-equality test""" - return not self.__eq__(other) - - def __repr__(self) -> str: - """String representation""" - try: - return self.to_string(str(socket.inet_ntoa(self.address))) - except Exception: # TODO stop catching all Exceptions - return self.to_string(str(self.address)) - - -class DNSHinfo(DNSRecord): - - """A DNS host information record""" - - def __init__( - self, name: str, type_: int, class_: int, ttl: int, cpu: Union[bytes, str], os: Union[bytes, str] - ) -> None: - DNSRecord.__init__(self, name, type_, class_, ttl) - try: - self.cpu = cast(bytes, cpu).decode('utf-8') - except AttributeError: - self.cpu = cast(str, cpu) - try: - self.os = cast(bytes, os).decode('utf-8') - except AttributeError: - self.os = cast(str, os) - - def write(self, out: 'DNSOutgoing') -> None: - """Used in constructing an outgoing packet""" - out.write_character_string(self.cpu.encode('utf-8')) - out.write_character_string(self.os.encode('utf-8')) - - def __eq__(self, other: Any) -> bool: - """Tests equality on cpu and os""" - return ( - isinstance(other, DNSHinfo) - and DNSEntry.__eq__(self, other) - and self.cpu == other.cpu - and self.os == other.os - ) - - def __ne__(self, other: Any) -> bool: - """Non-equality test""" - return not self.__eq__(other) - - def __repr__(self) -> str: - """String representation""" - return self.to_string(self.cpu + " " + self.os) - - -class DNSPointer(DNSRecord): - - """A DNS pointer record""" - - def __init__(self, name: str, type_: int, class_: int, ttl: int, alias: str) -> None: - DNSRecord.__init__(self, name, type_, class_, ttl) - self.alias = alias - - def write(self, out: 'DNSOutgoing') -> None: - """Used in constructing an outgoing packet""" - out.write_name(self.alias) - - def __eq__(self, other: Any) -> bool: - """Tests equality on alias""" - return isinstance(other, DNSPointer) and self.alias == other.alias and DNSEntry.__eq__(self, other) - - def __ne__(self, other: Any) -> bool: - """Non-equality test""" - return not self.__eq__(other) - - def __repr__(self) -> str: - """String representation""" - return self.to_string(self.alias) - - -class DNSText(DNSRecord): - - """A DNS text record""" - - def __init__(self, name: str, type_: int, class_: int, ttl: int, text: bytes) -> None: - assert isinstance(text, (bytes, type(None))) - DNSRecord.__init__(self, name, type_, class_, ttl) - self.text = text - - def write(self, out: 'DNSOutgoing') -> None: - """Used in constructing an outgoing packet""" - out.write_string(self.text) - - def __eq__(self, other: Any) -> bool: - """Tests equality on text""" - return isinstance(other, DNSText) and self.text == other.text and DNSEntry.__eq__(self, other) - - def __ne__(self, other: Any) -> bool: - """Non-equality test""" - return not self.__eq__(other) - - def __repr__(self) -> str: - """String representation""" - if len(self.text) > 10: - return self.to_string(self.text[:7]) + "..." - else: - return self.to_string(self.text) - - -class DNSService(DNSRecord): - - """A DNS service record""" - - def __init__( - self, - name: str, - type_: int, - class_: int, - ttl: Union[float, int], - priority: int, - weight: int, - port: int, - server: str, - ) -> None: - DNSRecord.__init__(self, name, type_, class_, ttl) - self.priority = priority - self.weight = weight - self.port = port - self.server = server - - def write(self, out: 'DNSOutgoing') -> None: - """Used in constructing an outgoing packet""" - out.write_short(self.priority) - out.write_short(self.weight) - out.write_short(self.port) - out.write_name(self.server) - - def __eq__(self, other: Any) -> bool: - """Tests equality on priority, weight, port and server""" - return ( - isinstance(other, DNSService) - and self.priority == other.priority - and self.weight == other.weight - and self.port == other.port - and self.server == other.server - and DNSEntry.__eq__(self, other) - ) - - def __ne__(self, other: Any) -> bool: - """Non-equality test""" - return not self.__eq__(other) - - def __repr__(self) -> str: - """String representation""" - return self.to_string("%s:%s" % (self.server, self.port)) - - -class DNSIncoming(QuietLogger): - - """Object representation of an incoming DNS packet""" - - def __init__(self, data: bytes) -> None: - """Constructor from string holding bytes of packet""" - self.offset = 0 - self.data = data - self.questions = [] # type: List[DNSQuestion] - self.answers = [] # type: List[DNSRecord] - self.id = 0 - self.flags = 0 # type: int - self.num_questions = 0 - self.num_answers = 0 - self.num_authorities = 0 - self.num_additionals = 0 - self.valid = False - - try: - self.read_header() - self.read_questions() - self.read_others() - self.valid = True - - except (IndexError, struct.error, IncomingDecodeError): - self.log_exception_warning(('Choked at offset %d while unpacking %r', self.offset, data)) - - def unpack(self, format_: bytes) -> tuple: - length = struct.calcsize(format_) - info = struct.unpack(format_, self.data[self.offset : self.offset + length]) - self.offset += length - return info - - def read_header(self) -> None: - """Reads header portion of packet""" - ( - self.id, - self.flags, - self.num_questions, - self.num_answers, - self.num_authorities, - self.num_additionals, - ) = self.unpack(b'!6H') - - def read_questions(self) -> None: - """Reads questions section of packet""" - for i in range(self.num_questions): - name = self.read_name() - type_, class_ = self.unpack(b'!HH') - - question = DNSQuestion(name, type_, class_) - self.questions.append(question) - - # def read_int(self): - # """Reads an integer from the packet""" - # return self.unpack(b'!I')[0] - - def read_character_string(self) -> bytes: - """Reads a character string from the packet""" - length = self.data[self.offset] - self.offset += 1 - return self.read_string(length) - - def read_string(self, length: int) -> bytes: - """Reads a string of a given length from the packet""" - info = self.data[self.offset : self.offset + length] - self.offset += length - return info - - def read_unsigned_short(self) -> int: - """Reads an unsigned short from the packet""" - return cast(int, self.unpack(b'!H')[0]) - - def read_others(self) -> None: - """Reads the answers, authorities and additionals section of the - packet""" - n = self.num_answers + self.num_authorities + self.num_additionals - for i in range(n): - domain = self.read_name() - type_, class_, ttl, length = self.unpack(b'!HHiH') - - rec = None # type: Optional[DNSRecord] - if type_ == _TYPE_A: - rec = DNSAddress(domain, type_, class_, ttl, self.read_string(4)) - elif type_ == _TYPE_CNAME or type_ == _TYPE_PTR: - rec = DNSPointer(domain, type_, class_, ttl, self.read_name()) - elif type_ == _TYPE_TXT: - rec = DNSText(domain, type_, class_, ttl, self.read_string(length)) - elif type_ == _TYPE_SRV: - rec = DNSService( - domain, - type_, - class_, - ttl, - self.read_unsigned_short(), - self.read_unsigned_short(), - self.read_unsigned_short(), - self.read_name(), - ) - elif type_ == _TYPE_HINFO: - rec = DNSHinfo( - domain, type_, class_, ttl, self.read_character_string(), self.read_character_string() - ) - elif type_ == _TYPE_AAAA: - rec = DNSAddress(domain, type_, class_, ttl, self.read_string(16)) - else: - # Try to ignore types we don't know about - # Skip the payload for the resource record so the next - # records can be parsed correctly - self.offset += length - - if rec is not None: - self.answers.append(rec) - - def is_query(self) -> bool: - """Returns true if this is a query""" - return (self.flags & _FLAGS_QR_MASK) == _FLAGS_QR_QUERY - - def is_response(self) -> bool: - """Returns true if this is a response""" - return (self.flags & _FLAGS_QR_MASK) == _FLAGS_QR_RESPONSE - - def read_utf(self, offset: int, length: int) -> str: - """Reads a UTF-8 string of a given length from the packet""" - return str(self.data[offset : offset + length], 'utf-8', 'replace') - - def read_name(self) -> str: - """Reads a domain name from the packet""" - result = '' - off = self.offset - next_ = -1 - first = off - - while True: - length = self.data[off] - off += 1 - if length == 0: - break - t = length & 0xC0 - if t == 0x00: - result = ''.join((result, self.read_utf(off, length) + '.')) - off += length - elif t == 0xC0: - if next_ < 0: - next_ = off + 1 - off = ((length & 0x3F) << 8) | self.data[off] - if off >= first: - raise IncomingDecodeError("Bad domain name (circular) at %s" % (off,)) - first = off - else: - raise IncomingDecodeError("Bad domain name at %s" % (off,)) - - if next_ >= 0: - self.offset = next_ - else: - self.offset = off - - return result - - -class DNSOutgoing: - - """Object representation of an outgoing packet""" - - def __init__(self, flags: int, multicast: bool = True) -> None: - self.finished = False - self.id = 0 - self.multicast = multicast - self.flags = flags - self.names = {} # type: Dict[str, int] - self.data = [] # type: List[bytes] - self.size = 12 - self.state = self.State.init - - self.questions = [] # type: List[DNSQuestion] - self.answers = [] # type: List[Tuple[DNSRecord, float]] - self.authorities = [] # type: List[DNSPointer] - self.additionals = [] # type: List[DNSRecord] - - def __repr__(self) -> str: - return '' % ', '.join( - [ - 'multicast=%s' % self.multicast, - 'flags=%s' % self.flags, - 'questions=%s' % self.questions, - 'answers=%s' % self.answers, - 'authorities=%s' % self.authorities, - 'additionals=%s' % self.additionals, - ] - ) - - class State(enum.Enum): - init = 0 - finished = 1 - - def add_question(self, record: DNSQuestion) -> None: - """Adds a question""" - self.questions.append(record) - - def add_answer(self, inp: DNSIncoming, record: DNSRecord) -> None: - """Adds an answer""" - if not record.suppressed_by(inp): - self.add_answer_at_time(record, 0) - - def add_answer_at_time(self, record: Optional[DNSRecord], now: Union[float, int]) -> None: - """Adds an answer if it does not expire by a certain time""" - if record is not None: - if now == 0 or not record.is_expired(now): - self.answers.append((record, now)) - - def add_authorative_answer(self, record: DNSPointer) -> None: - """Adds an authoritative answer""" - self.authorities.append(record) - - def add_additional_answer(self, record: DNSRecord) -> None: - """ Adds an additional answer - - From: RFC 6763, DNS-Based Service Discovery, February 2013 - - 12. DNS Additional Record Generation - - DNS has an efficiency feature whereby a DNS server may place - additional records in the additional section of the DNS message. - These additional records are records that the client did not - explicitly request, but the server has reasonable grounds to expect - that the client might request them shortly, so including them can - save the client from having to issue additional queries. - - This section recommends which additional records SHOULD be generated - to improve network efficiency, for both Unicast and Multicast DNS-SD - responses. - - 12.1. PTR Records - - When including a DNS-SD Service Instance Enumeration or Selective - Instance Enumeration (subtype) PTR record in a response packet, the - server/responder SHOULD include the following additional records: - - o The SRV record(s) named in the PTR rdata. - o The TXT record(s) named in the PTR rdata. - o All address records (type "A" and "AAAA") named in the SRV rdata. - - 12.2. SRV Records - - When including an SRV record in a response packet, the - server/responder SHOULD include the following additional records: - - o All address records (type "A" and "AAAA") named in the SRV rdata. - - """ - self.additionals.append(record) - - def pack(self, format_: Union[bytes, str], value: Any) -> None: - self.data.append(struct.pack(format_, value)) - self.size += struct.calcsize(format_) - - def write_byte(self, value: int) -> None: - """Writes a single byte to the packet""" - self.pack(b'!c', int2byte(value)) - - def insert_short(self, index: int, value: int) -> None: - """Inserts an unsigned short in a certain position in the packet""" - self.data.insert(index, struct.pack(b'!H', value)) - self.size += 2 - - def write_short(self, value: int) -> None: - """Writes an unsigned short to the packet""" - self.pack(b'!H', value) - - def write_int(self, value: Union[float, int]) -> None: - """Writes an unsigned integer to the packet""" - self.pack(b'!I', int(value)) - - def write_string(self, value: bytes) -> None: - """Writes a string to the packet""" - assert isinstance(value, bytes) - self.data.append(value) - self.size += len(value) - - def write_utf(self, s: str) -> None: - """Writes a UTF-8 string of a given length to the packet""" - utfstr = s.encode('utf-8') - length = len(utfstr) - if length > 64: - raise NamePartTooLongException - self.write_byte(length) - self.write_string(utfstr) - - def write_character_string(self, value: bytes) -> None: - assert isinstance(value, bytes) - length = len(value) - if length > 256: - raise NamePartTooLongException - self.write_byte(length) - self.write_string(value) - - def write_name(self, name: str) -> None: - """ - Write names to packet - - 18.14. Name Compression - - When generating Multicast DNS messages, implementations SHOULD use - name compression wherever possible to compress the names of resource - records, by replacing some or all of the resource record name with a - compact two-byte reference to an appearance of that data somewhere - earlier in the message [RFC1035]. - """ - - # split name into each label - parts = name.split('.') - if not parts[-1]: - parts.pop() - - # construct each suffix - name_suffices = ['.'.join(parts[i:]) for i in range(len(parts))] - - # look for an existing name or suffix - for count, sub_name in enumerate(name_suffices): - if sub_name in self.names: - break - else: - count = len(name_suffices) - - # note the new names we are saving into the packet - name_length = len(name.encode('utf-8')) - for suffix in name_suffices[:count]: - self.names[suffix] = self.size + name_length - len(suffix.encode('utf-8')) - 1 - - # write the new names out. - for part in parts[:count]: - self.write_utf(part) - - # if we wrote part of the name, create a pointer to the rest - if count != len(name_suffices): - # Found substring in packet, create pointer - index = self.names[name_suffices[count]] - self.write_byte((index >> 8) | 0xC0) - self.write_byte(index & 0xFF) - else: - # this is the end of a name - self.write_byte(0) - - def write_question(self, question: DNSQuestion) -> None: - """Writes a question to the packet""" - self.write_name(question.name) - self.write_short(question.type) - self.write_short(question.class_) - - def write_record(self, record: DNSRecord, now: float) -> int: - """Writes a record (answer, authoritative answer, additional) to - the packet""" - if self.state == self.State.finished: - return 1 - - start_data_length, start_size = len(self.data), self.size - self.write_name(record.name) - self.write_short(record.type) - if record.unique and self.multicast: - self.write_short(record.class_ | _CLASS_UNIQUE) - else: - self.write_short(record.class_) - if now == 0: - self.write_int(record.ttl) - else: - self.write_int(record.get_remaining_ttl(now)) - index = len(self.data) - - # Adjust size for the short we will write before this record - self.size += 2 - record.write(self) - self.size -= 2 - - length = sum((len(d) for d in self.data[index:])) - # Here is the short we adjusted for - self.insert_short(index, length) - - # if we go over, then rollback and quit - if self.size > _MAX_MSG_ABSOLUTE: - while len(self.data) > start_data_length: - self.data.pop() - self.size = start_size - self.state = self.State.finished - return 1 - return 0 - - def packet(self) -> bytes: - """Returns a string containing the packet's bytes - - No further parts should be added to the packet once this - is done.""" - - overrun_answers, overrun_authorities, overrun_additionals = 0, 0, 0 - - if self.state != self.State.finished: - for question in self.questions: - self.write_question(question) - for answer, time_ in self.answers: - overrun_answers += self.write_record(answer, time_) - for authority in self.authorities: - overrun_authorities += self.write_record(authority, 0) - for additional in self.additionals: - overrun_additionals += self.write_record(additional, 0) - self.state = self.State.finished - - self.insert_short(0, len(self.additionals) - overrun_additionals) - self.insert_short(0, len(self.authorities) - overrun_authorities) - self.insert_short(0, len(self.answers) - overrun_answers) - self.insert_short(0, len(self.questions)) - self.insert_short(0, self.flags) - if self.multicast: - self.insert_short(0, 0) - else: - self.insert_short(0, self.id) - return b''.join(self.data) - - -class DNSCache: - - """A cache of DNS entries""" - - def __init__(self) -> None: - self.cache = {} # type: Dict[str, List[DNSRecord]] - - def add(self, entry: DNSRecord) -> None: - """Adds an entry""" - # Insert first in list so get returns newest entry - self.cache.setdefault(entry.key, []).insert(0, entry) - - def remove(self, entry: DNSRecord) -> None: - """Removes an entry""" - try: - list_ = self.cache[entry.key] - list_.remove(entry) - except (KeyError, ValueError): - pass - - def get(self, entry: DNSEntry) -> Optional[DNSRecord]: - """Gets an entry by key. Will return None if there is no - matching entry.""" - try: - list_ = self.cache[entry.key] - for cached_entry in list_: - if entry.__eq__(cached_entry): - return cached_entry - return None - except (KeyError, ValueError): - return None - - def get_by_details(self, name: str, type_: int, class_: int) -> Optional[DNSRecord]: - """Gets an entry by details. Will return None if there is - no matching entry.""" - entry = DNSEntry(name, type_, class_) - return self.get(entry) - - def entries_with_name(self, name: str) -> List[DNSRecord]: - """Returns a list of entries whose key matches the name.""" - try: - return self.cache[name.lower()] - except KeyError: - return [] - - def current_entry_with_name_and_alias(self, name: str, alias: str) -> Optional[DNSRecord]: - now = current_time_millis() - for record in self.entries_with_name(name): - if ( - record.type == _TYPE_PTR - and not record.is_expired(now) - and cast(DNSPointer, record).alias == alias - ): - return record - return None - - def entries(self) -> List[DNSRecord]: - """Returns a list of all entries""" - if not self.cache: - return [] - else: - # avoid size change during iteration by copying the cache - values = list(self.cache.values()) - return list(itertools.chain.from_iterable(values)) - - -class Engine(threading.Thread): - - """An engine wraps read access to sockets, allowing objects that - need to receive data from sockets to be called back when the - sockets are ready. - - A reader needs a handle_read() method, which is called when the socket - it is interested in is ready for reading. - - Writers are not implemented here, because we only send short - packets. - """ - - def __init__(self, zc: 'Zeroconf') -> None: - threading.Thread.__init__(self, name='zeroconf-Engine') - self.daemon = True - self.zc = zc - self.readers = {} # type: Dict[socket.socket, Listener] - self.timeout = 5 - self.condition = threading.Condition() - self.start() - - def run(self) -> None: - while not self.zc.done: - with self.condition: - rs = self.readers.keys() - if len(rs) == 0: - # No sockets to manage, but we wait for the timeout - # or addition of a socket - self.condition.wait(self.timeout) - - if len(rs) != 0: - try: - rr, wr, er = select.select(cast(Sequence[Any], rs), [], [], self.timeout) - if not self.zc.done: - for socket_ in rr: - reader = self.readers.get(socket_) - if reader: - reader.handle_read(socket_) - - except (select.error, socket.error) as e: - # If the socket was closed by another thread, during - # shutdown, ignore it and exit - if e.args[0] not in (errno.EBADF, errno.ENOTCONN) or not self.zc.done: - raise - - def add_reader(self, reader: 'Listener', socket_: socket.socket) -> None: - with self.condition: - self.readers[socket_] = reader - self.condition.notify() - - def del_reader(self, socket_: socket.socket) -> None: - with self.condition: - del self.readers[socket_] - self.condition.notify() - - -class Listener(QuietLogger): - - """A Listener is used by this module to listen on the multicast - group to which DNS messages are sent, allowing the implementation - to cache information as it arrives. - - It requires registration with an Engine object in order to have - the read() method called when a socket is available for reading.""" - - def __init__(self, zc: 'Zeroconf') -> None: - self.zc = zc - self.data = None # type: Optional[bytes] - - def handle_read(self, socket_: socket.socket) -> None: - try: - data, (addr, port, *_v6) = socket_.recvfrom(_MAX_MSG_ABSOLUTE) - except Exception: - self.log_exception_warning() - return - - log.debug('Received from %r:%r: %r ', addr, port, data) - - self.data = data - msg = DNSIncoming(data) - if not msg.valid: - pass - - elif msg.is_query(): - # Always multicast responses - if port == _MDNS_PORT: - self.zc.handle_query(msg, None, _MDNS_PORT) - - # If it's not a multicast query, reply via unicast - # and multicast - elif port == _DNS_PORT: - self.zc.handle_query(msg, addr, port) - self.zc.handle_query(msg, None, _MDNS_PORT) - - else: - self.zc.handle_response(msg) - - -class Reaper(threading.Thread): - - """A Reaper is used by this module to remove cache entries that - have expired.""" - - def __init__(self, zc: 'Zeroconf') -> None: - threading.Thread.__init__(self, name='zeroconf-Reaper') - self.daemon = True - self.zc = zc - self.start() - - def run(self) -> None: - while True: - self.zc.wait(10 * 1000) - if self.zc.done: - return - now = current_time_millis() - for record in self.zc.cache.entries(): - if record.is_expired(now): - self.zc.update_record(now, record) - self.zc.cache.remove(record) - - -class Signal: - def __init__(self) -> None: - self._handlers = [] # type: List[Callable[..., None]] - - def fire(self, **kwargs: Any) -> None: - for h in list(self._handlers): - h(**kwargs) - - @property - def registration_interface(self) -> 'SignalRegistrationInterface': - return SignalRegistrationInterface(self._handlers) - - -# NOTE: Callable quoting needed on Python 3.5.2, see -# https://github.com/jstasiak/python-zeroconf/issues/208 for details. -class SignalRegistrationInterface: - def __init__(self, handlers: List['Callable[..., None]']) -> None: - self._handlers = handlers - - def register_handler(self, handler: 'Callable[..., None]') -> 'SignalRegistrationInterface': - self._handlers.append(handler) - return self - - def unregister_handler(self, handler: 'Callable[..., None]') -> 'SignalRegistrationInterface': - self._handlers.remove(handler) - return self - - -class RecordUpdateListener: - def update_record(self, zc: 'Zeroconf', now: float, record: DNSRecord) -> None: - raise NotImplementedError() - - -class ServiceListener: - def add_service(self, zc: 'Zeroconf', type_: str, name: str) -> None: - raise NotImplementedError() - - def remove_service(self, zc: 'Zeroconf', type_: str, name: str) -> None: - raise NotImplementedError() - - def update_service(self, zc: 'Zeroconf', type_: str, name: str) -> None: - raise NotImplementedError() - - -class ServiceBrowser(RecordUpdateListener, threading.Thread): - - """Used to browse for a service of a specific type. - - The listener object will have its add_service() and - remove_service() methods called when this browser - discovers changes in the services availability.""" - - def __init__( - self, - zc: 'Zeroconf', - type_: str, - # NOTE: Callable quoting needed on Python 3.5.2, see - # https://github.com/jstasiak/python-zeroconf/issues/208 for details. - handlers: Optional[Union[ServiceListener, List['Callable[..., None]']]] = None, - listener: Optional[ServiceListener] = None, - addr: Optional[str] = None, - port: int = _MDNS_PORT, - delay: int = _BROWSER_TIME, - ) -> None: - """Creates a browser for a specific type""" - assert handlers or listener, 'You need to specify at least one handler' - if not type_.endswith(service_type_name(type_, allow_underscores=True)): - raise BadTypeInNameException - threading.Thread.__init__(self, name='zeroconf-ServiceBrowser_' + type_) - self.daemon = True - self.zc = zc - self.type = type_ - self.addr = addr - self.port = port - self.multicast = self.addr in (None, _MDNS_ADDR, _MDNS_ADDR6) - self.services = {} # type: Dict[str, DNSRecord] - self.next_time = current_time_millis() - self.delay = delay - self._handlers_to_call = [] # type: List[Callable[[Zeroconf], None]] - - self._service_state_changed = Signal() - - self.done = False - - if hasattr(handlers, 'add_service'): - listener = cast(ServiceListener, handlers) - handlers = None - - # NOTE: Callable quoting needed on Python 3.5.2, see - # https://github.com/jstasiak/python-zeroconf/issues/208 for details. - handlers = cast(List['Callable[..., None]'], handlers or []) - - if listener: - - def on_change( - zeroconf: 'Zeroconf', service_type: str, name: str, state_change: ServiceStateChange - ) -> None: - assert listener is not None - args = (zeroconf, service_type, name) - if state_change is ServiceStateChange.Added: - listener.add_service(*args) - elif state_change is ServiceStateChange.Removed: - listener.remove_service(*args) - elif state_change is ServiceStateChange.Updated: - if hasattr(listener, 'update_service'): - listener.update_service(*args) - else: - raise NotImplementedError(state_change) - - handlers.append(on_change) - - for h in handlers: - self.service_state_changed.register_handler(h) - - self.start() - - @property - def service_state_changed(self) -> SignalRegistrationInterface: - return self._service_state_changed.registration_interface - - def update_record(self, zc: 'Zeroconf', now: float, record: DNSRecord) -> None: - """Callback invoked by Zeroconf when new information arrives. - - Updates information required by browser in the Zeroconf cache.""" - - def enqueue_callback(state_change: ServiceStateChange, name: str) -> None: - self._handlers_to_call.append( - lambda zeroconf: self._service_state_changed.fire( - zeroconf=zeroconf, service_type=self.type, name=name, state_change=state_change - ) - ) - - if record.type == _TYPE_PTR and record.name == self.type: - assert isinstance(record, DNSPointer) - expired = record.is_expired(now) - service_key = record.alias.lower() - try: - old_record = self.services[service_key] - except KeyError: - if not expired: - self.services[service_key] = record - enqueue_callback(ServiceStateChange.Added, record.alias) - else: - if not expired: - old_record.reset_ttl(record) - else: - del self.services[service_key] - enqueue_callback(ServiceStateChange.Removed, record.alias) - return - - expires = record.get_expiration_time(75) - if expires < self.next_time: - self.next_time = expires - - elif record.type == _TYPE_TXT and record.name.endswith(self.type): - assert isinstance(record, DNSText) - expired = record.is_expired(now) - if not expired: - enqueue_callback(ServiceStateChange.Updated, record.name) - - def cancel(self) -> None: - self.done = True - self.zc.remove_listener(self) - self.join() - - def run(self) -> None: - self.zc.add_listener(self, DNSQuestion(self.type, _TYPE_PTR, _CLASS_IN)) - - while True: - now = current_time_millis() - if len(self._handlers_to_call) == 0 and self.next_time > now: - self.zc.wait(self.next_time - now) - if self.zc.done or self.done: - return - now = current_time_millis() - if self.next_time <= now: - out = DNSOutgoing(_FLAGS_QR_QUERY, multicast=self.multicast) - out.add_question(DNSQuestion(self.type, _TYPE_PTR, _CLASS_IN)) - for record in self.services.values(): - if not record.is_stale(now): - out.add_answer_at_time(record, now) - - self.zc.send(out, addr=self.addr, port=self.port) - self.next_time = now + self.delay - self.delay = min(_BROWSER_BACKOFF_LIMIT * 1000, self.delay * 2) - - if len(self._handlers_to_call) > 0 and not self.zc.done: - handler = self._handlers_to_call.pop(0) - handler(self.zc) - - -class ServiceInfo(RecordUpdateListener): - text = b'' - - """Service information""" - - # FIXME(dtantsur): black 19.3b0 produces code that is not valid syntax on - # Python 3.5: https://github.com/python/black/issues/759 - # fmt: off - def __init__( - self, - type_: str, - name: str, - address: Optional[Union[bytes, List[bytes]]] = None, - port: Optional[int] = None, - weight: int = 0, - priority: int = 0, - properties: Union[bytes, Dict] = b'', - server: Optional[str] = None, - host_ttl: int = _DNS_HOST_TTL, - other_ttl: int = _DNS_OTHER_TTL, - *, - addresses: Optional[List[bytes]] = None - ) -> None: - """Create a service description. - - type_: fully qualified service type name - name: fully qualified service name - address: IP address as unsigned short, network byte order (deprecated, use addresses) - port: port that the service runs on - weight: weight of the service - priority: priority of the service - properties: dictionary of properties (or a string holding the - bytes for the text field) - server: fully qualified name for service host (defaults to name) - host_ttl: ttl used for A/SRV records - other_ttl: ttl used for PTR/TXT records - addresses: List of IP addresses as unsigned short (IPv4) or unsigned - 128 bit number (IPv6), network byte order - """ - - # Accept both none, or one, but not both. - if address is not None and addresses is not None: - raise TypeError("address and addresses cannot be provided together") - - if not type_.endswith(service_type_name(name, allow_underscores=True)): - raise BadTypeInNameException - self.type = type_ - self.name = name - if addresses is not None: - self._addresses = addresses - elif address is not None: - warnings.warn("address is deprecated, use addresses instead", DeprecationWarning) - if isinstance(address, list): - self._addresses = address - else: - self._addresses = [address] - else: - self._addresses = [] - # This results in an ugly error when registering, better check now - invalid = [a for a in self._addresses - if not isinstance(a, bytes) or len(a) not in (4, 16)] - if invalid: - raise TypeError('Addresses must be bytes, got %s. Hint: convert string addresses ' - 'with socket.inet_pton' % invalid) - self.port = port - self.weight = weight - self.priority = priority - if server: - self.server = server - else: - self.server = name - self._properties = {} # type: Dict - self._set_properties(properties) - self.host_ttl = host_ttl - self.other_ttl = other_ttl - # fmt: on - - @property - def address(self) -> Optional[bytes]: - warnings.warn("ServiceInfo.address is deprecated, use addresses instead", DeprecationWarning) - try: - # Return the first V4 address for compatibility - return self.addresses[0] - except IndexError: - return None - - @address.setter - def address(self, value: bytes) -> None: - warnings.warn("ServiceInfo.address is deprecated, use addresses instead", DeprecationWarning) - if value is None: - self._addresses = [] - else: - self._addresses = [value] - - @property - def addresses(self) -> List[bytes]: - """IPv4 addresses of this service. - - Only IPv4 addresses are returned for backward compatibility. - Use :meth:`addresses_by_version` or :meth:`parsed_addresses` to - include IPv6 addresses as well. - """ - return self.addresses_by_version(IPVersion.V4Only) - - @addresses.setter - def addresses(self, value: List[bytes]) -> None: - """Replace the addresses list. - - This replaces all currently stored addresses, both IPv4 and IPv6. - """ - self._addresses = value - - @property - def properties(self) -> Dict: - return self._properties - - def addresses_by_version(self, version: IPVersion) -> List[bytes]: - """List addresses matching IP version.""" - if version == IPVersion.V4Only: - return [addr for addr in self._addresses if not _is_v6_address(addr)] - elif version == IPVersion.V6Only: - return list(filter(_is_v6_address, self._addresses)) - else: - return self._addresses - - def parsed_addresses(self, version: IPVersion = IPVersion.All) -> List[str]: - """List addresses in their parsed string form.""" - result = self.addresses_by_version(version) - return [ - socket.inet_ntop(socket.AF_INET6 if _is_v6_address(addr) else socket.AF_INET, addr) - for addr in result - ] - - def _set_properties(self, properties: Union[bytes, Dict]) -> None: - """Sets properties and text of this info from a dictionary""" - if isinstance(properties, dict): - self._properties = properties - list_ = [] - result = b'' - for key, value in properties.items(): - if isinstance(key, str): - key = key.encode('utf-8') - - if value is None: - suffix = b'' - elif isinstance(value, str): - suffix = value.encode('utf-8') - elif isinstance(value, bytes): - suffix = value - elif isinstance(value, int): - if value: - suffix = b'true' - else: - suffix = b'false' - else: - suffix = b'' - list_.append(b'='.join((key, suffix))) - for item in list_: - result = b''.join((result, int2byte(len(item)), item)) - self.text = result - else: - self.text = properties - - def _set_text(self, text: bytes) -> None: - """Sets properties and text given a text field""" - self.text = text - result = {} # type: Dict - end = len(text) - index = 0 - strs = [] - while index < end: - length = text[index] - index += 1 - strs.append(text[index : index + length]) - index += length - - for s in strs: - parts = s.split(b'=', 1) - try: - key, value = parts # type: Tuple[bytes, Union[bool, bytes]] - except ValueError: - # No equals sign at all - key = s - value = False - else: - if value == b'true': - value = True - elif value == b'false' or not value: - value = False - - # Only update non-existent properties - if key and result.get(key) is None: - result[key] = value - - self._properties = result - - def get_name(self) -> str: - """Name accessor""" - if self.type is not None and self.name.endswith("." + self.type): - return self.name[: len(self.name) - len(self.type) - 1] - return self.name - - def update_record(self, zc: 'Zeroconf', now: float, record: Optional[DNSRecord]) -> None: - """Updates service information from a DNS record""" - if record is not None and not record.is_expired(now): - if record.type in [_TYPE_A, _TYPE_AAAA]: - assert isinstance(record, DNSAddress) - # if record.name == self.name: - if record.name == self.server: - if record.address not in self._addresses: - self._addresses.append(record.address) - elif record.type == _TYPE_SRV: - assert isinstance(record, DNSService) - if record.name == self.name: - self.server = record.server - self.port = record.port - self.weight = record.weight - self.priority = record.priority - # self.address = None - self.update_record(zc, now, zc.cache.get_by_details(self.server, _TYPE_A, _CLASS_IN)) - self.update_record(zc, now, zc.cache.get_by_details(self.server, _TYPE_AAAA, _CLASS_IN)) - elif record.type == _TYPE_TXT: - assert isinstance(record, DNSText) - if record.name == self.name: - self._set_text(record.text) - - def request(self, zc: 'Zeroconf', timeout: float) -> bool: - """Returns true if the service could be discovered on the - network, and updates this object with details discovered. - """ - now = current_time_millis() - delay = _LISTENER_TIME - next_ = now + delay - last = now + timeout - - record_types_for_check_cache = [(_TYPE_SRV, _CLASS_IN), (_TYPE_TXT, _CLASS_IN)] - if self.server is not None: - record_types_for_check_cache.append((_TYPE_A, _CLASS_IN)) - record_types_for_check_cache.append((_TYPE_AAAA, _CLASS_IN)) - for record_type in record_types_for_check_cache: - cached = zc.cache.get_by_details(self.name, *record_type) - if cached: - self.update_record(zc, now, cached) - - if self.server is not None and self.text is not None and self._addresses: - return True - - try: - zc.add_listener(self, DNSQuestion(self.name, _TYPE_ANY, _CLASS_IN)) - while self.server is None or self.text is None or not self._addresses: - if last <= now: - return False - if next_ <= now: - out = DNSOutgoing(_FLAGS_QR_QUERY) - out.add_question(DNSQuestion(self.name, _TYPE_SRV, _CLASS_IN)) - out.add_answer_at_time(zc.cache.get_by_details(self.name, _TYPE_SRV, _CLASS_IN), now) - - out.add_question(DNSQuestion(self.name, _TYPE_TXT, _CLASS_IN)) - out.add_answer_at_time(zc.cache.get_by_details(self.name, _TYPE_TXT, _CLASS_IN), now) - - if self.server is not None: - out.add_question(DNSQuestion(self.server, _TYPE_A, _CLASS_IN)) - out.add_answer_at_time(zc.cache.get_by_details(self.server, _TYPE_A, _CLASS_IN), now) - out.add_question(DNSQuestion(self.server, _TYPE_AAAA, _CLASS_IN)) - out.add_answer_at_time( - zc.cache.get_by_details(self.server, _TYPE_AAAA, _CLASS_IN), now - ) - zc.send(out) - next_ = now + delay - delay *= 2 - - zc.wait(min(next_, last) - now) - now = current_time_millis() - finally: - zc.remove_listener(self) - - return True - - def __eq__(self, other: object) -> bool: - """Tests equality of service name""" - return isinstance(other, ServiceInfo) and other.name == self.name - - def __ne__(self, other: object) -> bool: - """Non-equality test""" - return not self.__eq__(other) - - def __repr__(self) -> str: - """String representation""" - return '%s(%s)' % ( - type(self).__name__, - ', '.join( - '%s=%r' % (name, getattr(self, name)) - for name in ( - 'type', - 'name', - 'addresses', - 'port', - 'weight', - 'priority', - 'server', - 'properties', - ) - ), - ) - - -class ZeroconfServiceTypes(ServiceListener): - """ - Return all of the advertised services on any local networks - """ - - def __init__(self) -> None: - self.found_services = set() # type: Set[str] - - def add_service(self, zc: 'Zeroconf', type_: str, name: str) -> None: - self.found_services.add(name) - - def remove_service(self, zc: 'Zeroconf', type_: str, name: str) -> None: - pass - - @classmethod - def find( - cls, - zc: Optional['Zeroconf'] = None, - timeout: Union[int, float] = 5, - interfaces: InterfacesType = InterfaceChoice.All, - ip_version: Optional[IPVersion] = None, - ) -> Tuple[str, ...]: - """ - Return all of the advertised services on any local networks. - - :param zc: Zeroconf() instance. Pass in if already have an - instance running or if non-default interfaces are needed - :param timeout: seconds to wait for any responses - :param interfaces: interfaces to listen on. - :param ip_version: IP protocol version to use. - :return: tuple of service type strings - """ - local_zc = zc or Zeroconf(interfaces=interfaces, ip_version=ip_version) - listener = cls() - browser = ServiceBrowser(local_zc, '_services._dns-sd._udp.local.', listener=listener) - - # wait for responses - time.sleep(timeout) - - # close down anything we opened - if zc is None: - local_zc.close() - else: - browser.cancel() - - return tuple(sorted(listener.found_services)) - - -def get_all_addresses() -> List[str]: - return list( - set( - addr.ip - for iface in ifaddr.get_adapters() - for addr in iface.ips - if addr.is_IPv4 and addr.network_prefix != 32 # Host only netmask 255.255.255.255 - ) - ) - - -def get_all_addresses_v6() -> List[int]: - # IPv6 multicast uses positive indexes for interfaces - try: - nameindex = socket.if_nameindex - except AttributeError: - # Requires Python 3.8 on Windows. Fall back to Default. - QuietLogger.log_warning_once( - 'if_nameindex is not available, falling back to using the default IPv6 interface' - ) - return [0] - - return [tpl[0] for tpl in nameindex()] - - -def ip_to_index(adapters: List[Any], ip: str) -> int: - if os.name != 'posix': - # Adapter names that ifaddr reports are not compatible with what if_nametoindex expects on Windows. - # We need https://github.com/pydron/ifaddr/pull/21 but it seems stuck on review. - raise RuntimeError('Converting from IP addresses to indexes is not supported on non-POSIX systems') - - ipaddr = ipaddress.ip_address(ip) - for adapter in adapters: - for adapter_ip in adapter.ips: - # IPv6 addresses are represented as tuples - if isinstance(adapter_ip.ip, tuple) and ipaddress.ip_address(adapter_ip.ip[0]) == ipaddr: - return socket.if_nametoindex(adapter.name) - - raise RuntimeError('No adapter found for IP address %s' % ip) - - -def ip6_addresses_to_indexes(interfaces: List[Union[str, int]]) -> List[int]: - """Convert IPv6 interface addresses to interface indexes. - - IPv4 addresses are ignored. The conversion currently only works on POSIX - systems. - - :param interfaces: List of IP addresses and indexes. - :returns: List of indexes. - """ - result = [] - adapters = ifaddr.get_adapters() - - for iface in interfaces: - if isinstance(iface, int): - result.append(iface) - elif isinstance(iface, str) and ipaddress.ip_address(iface).version == 6: - result.append(ip_to_index(adapters, iface)) - - return result - - -def normalize_interface_choice( - choice: InterfacesType, ip_version: IPVersion = IPVersion.V4Only -) -> List[Union[str, int]]: - """Convert the interfaces choice into internal representation. - - :param choice: `InterfaceChoice` or list of interface addresses or indexes (IPv6 only). - :param ip_address: IP version to use (ignored if `choice` is a list). - :returns: List of IP addresses (for IPv4) and indexes (for IPv6). - """ - result = [] # type: List[Union[str, int]] - if choice is InterfaceChoice.Default: - if ip_version != IPVersion.V4Only: - # IPv6 multicast uses interface 0 to mean the default - result.append(0) - if ip_version != IPVersion.V6Only: - result.append('0.0.0.0') - elif choice is InterfaceChoice.All: - if ip_version != IPVersion.V4Only: - result.extend(get_all_addresses_v6()) - if ip_version != IPVersion.V6Only: - result.extend(get_all_addresses()) - if not result: - raise RuntimeError( - 'No interfaces to listen on, check that any interfaces have IP version %s' % ip_version - ) - elif isinstance(choice, list): - # First, take IPv4 addresses. - result = [i for i in choice if isinstance(i, str) and ipaddress.ip_address(i).version == 4] - # Unlike IP_ADD_MEMBERSHIP, IPV6_JOIN_GROUP requires interface indexes. - result += ip6_addresses_to_indexes(choice) - else: - raise TypeError("choice must be a list or InterfaceChoice, got %r" % choice) - return result - - -def new_socket( - port: int = _MDNS_PORT, ip_version: IPVersion = IPVersion.V4Only, apple_p2p: bool = False -) -> socket.socket: - if ip_version == IPVersion.V4Only: - s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) - else: - s = socket.socket(socket.AF_INET6, socket.SOCK_DGRAM) - - if ip_version == IPVersion.All: - # make V6 sockets work for both V4 and V6 (required for Windows) - try: - s.setsockopt(_IPPROTO_IPV6, socket.IPV6_V6ONLY, False) - except OSError: - log.error('Support for dual V4-V6 sockets is not present, use IPVersion.V4 or IPVersion.V6') - raise - - s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - - # SO_REUSEADDR should be equivalent to SO_REUSEPORT for - # multicast UDP sockets (p 731, "TCP/IP Illustrated, - # Volume 2"), but some BSD-derived systems require - # SO_REUSEPORT to be specified explicitly. Also, not all - # versions of Python have SO_REUSEPORT available. - # Catch OSError and socket.error for kernel versions <3.9 because lacking - # SO_REUSEPORT support. - try: - reuseport = socket.SO_REUSEPORT - except AttributeError: - pass - else: - try: - s.setsockopt(socket.SOL_SOCKET, reuseport, 1) - except (OSError, socket.error) as err: - # OSError on python 3, socket.error on python 2 - if not err.errno == errno.ENOPROTOOPT: - raise - - if port is _MDNS_PORT: - ttl = struct.pack(b'B', 255) - loop = struct.pack(b'B', 1) - if ip_version != IPVersion.V6Only: - # OpenBSD needs the ttl and loop values for the IP_MULTICAST_TTL and - # IP_MULTICAST_LOOP socket options as an unsigned char. - s.setsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_TTL, ttl) - s.setsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_LOOP, loop) - if ip_version != IPVersion.V4Only: - # However, char doesn't work here (at least on Linux) - s.setsockopt(_IPPROTO_IPV6, socket.IPV6_MULTICAST_HOPS, 255) - s.setsockopt(_IPPROTO_IPV6, socket.IPV6_MULTICAST_LOOP, True) - - if apple_p2p: - # SO_RECV_ANYIF = 0x1104 - # https://opensource.apple.com/source/xnu/xnu-4570.41.2/bsd/sys/socket.h - s.setsockopt(socket.SOL_SOCKET, 0x1104, 1) - - s.bind(('', port)) - return s - - -def add_multicast_member( - listen_socket: socket.socket, interface: Union[str, int], apple_p2p: bool = False -) -> Optional[socket.socket]: - # This is based on assumptions in normalize_interface_choice - is_v6 = isinstance(interface, int) - log.debug('Adding %r to multicast group', interface) - try: - if is_v6: - iface_bin = struct.pack('@I', cast(int, interface)) - _value = _MDNS_ADDR6_BYTES + iface_bin - listen_socket.setsockopt(_IPPROTO_IPV6, socket.IPV6_JOIN_GROUP, _value) - else: - _value = _MDNS_ADDR_BYTES + socket.inet_aton(cast(str, interface)) - listen_socket.setsockopt(socket.IPPROTO_IP, socket.IP_ADD_MEMBERSHIP, _value) - except socket.error as e: - _errno = get_errno(e) - if _errno == errno.EADDRINUSE: - log.info( - 'Address in use when adding %s to multicast group, ' - 'it is expected to happen on some systems', - interface, - ) - return None - elif _errno == errno.EADDRNOTAVAIL: - log.info( - 'Address not available when adding %s to multicast ' - 'group, it is expected to happen on some systems', - interface, - ) - return None - elif _errno == errno.EINVAL: - log.info('Interface of %s does not support multicast, ' 'it is expected in WSL', interface) - return None - else: - raise - - respond_socket = new_socket( - ip_version=(IPVersion.V6Only if is_v6 else IPVersion.V4Only), apple_p2p=apple_p2p - ) - log.debug('Configuring %s with multicast interface %s', respond_socket, interface) - if is_v6: - respond_socket.setsockopt(_IPPROTO_IPV6, socket.IPV6_MULTICAST_IF, iface_bin) - else: - respond_socket.setsockopt( - socket.IPPROTO_IP, socket.IP_MULTICAST_IF, socket.inet_aton(cast(str, interface)) - ) - return respond_socket - - -def create_sockets( - interfaces: InterfacesType = InterfaceChoice.All, - unicast: bool = False, - ip_version: IPVersion = IPVersion.V4Only, - apple_p2p: bool = False, -) -> Tuple[Optional[socket.socket], List[socket.socket]]: - if unicast: - listen_socket = None - else: - listen_socket = new_socket(ip_version=ip_version, apple_p2p=apple_p2p) - - interfaces = normalize_interface_choice(interfaces, ip_version) - - respond_sockets = [] - - for i in interfaces: - if not unicast: - respond_socket = add_multicast_member(cast(socket.socket, listen_socket), i, apple_p2p=apple_p2p) - else: - respond_socket = new_socket(port=0, ip_version=ip_version, apple_p2p=apple_p2p) - - if respond_socket is not None: - respond_sockets.append(respond_socket) - - return listen_socket, respond_sockets - - -def get_errno(e: Exception) -> int: - assert isinstance(e, socket.error) - return cast(int, e.args[0]) - - -def can_send_to(sock: socket.socket, address: str) -> bool: - addr = ipaddress.ip_address(address) - return cast(bool, addr.version == 6 if sock.family == socket.AF_INET6 else addr.version == 4) - - -class Zeroconf(QuietLogger): - - """Implementation of Zeroconf Multicast DNS Service Discovery - - Supports registration, unregistration, queries and browsing. - """ - - def __init__( - self, - interfaces: InterfacesType = InterfaceChoice.All, - unicast: bool = False, - ip_version: Optional[IPVersion] = None, - apple_p2p: bool = False, - ) -> None: - """Creates an instance of the Zeroconf class, establishing - multicast communications, listening and reaping threads. - - :param interfaces: :class:`InterfaceChoice` or a list of IP addresses - (IPv4 and IPv6) and interface indexes (IPv6 only). - - IPv6 notes for non-POSIX systems: - * IPv6 addresses are not supported, use indexes instead. - * `InterfaceChoice.All` is an alias for `InterfaceChoice.Default` - on Python versions before 3.8. - - Also listening on loopback (``::1``) doesn't work, use a real address. - :param ip_version: IP versions to support. If `choice` is a list, the default is detected - from it. Otherwise defaults to V4 only for backward compatibility. - :param apple_p2p: use AWDL interface (only macOS) - """ - if ip_version is None and isinstance(interfaces, list): - has_v6 = any( - isinstance(i, int) or (isinstance(i, str) and ipaddress.ip_address(i).version == 6) - for i in interfaces - ) - has_v4 = any(isinstance(i, str) and ipaddress.ip_address(i).version == 4 for i in interfaces) - if has_v4 and has_v6: - ip_version = IPVersion.All - elif has_v6: - ip_version = IPVersion.V6Only - - if ip_version is None: - ip_version = IPVersion.V4Only - - # hook for threads - self._GLOBAL_DONE = False - self.unicast = unicast - - if apple_p2p and not platform.system() == 'Darwin': - raise RuntimeError('Option `apple_p2p` is not supported on non-Apple platforms.') - - self._listen_socket, self._respond_sockets = create_sockets( - interfaces, unicast, ip_version, apple_p2p=apple_p2p - ) - - self.listeners = [] # type: List[RecordUpdateListener] - self.browsers = {} # type: Dict[ServiceListener, ServiceBrowser] - self.services = {} # type: Dict[str, ServiceInfo] - self.servicetypes = {} # type: Dict[str, int] - - self.cache = DNSCache() - - self.condition = threading.Condition() - - self.engine = Engine(self) - self.listener = Listener(self) - if not unicast: - self.engine.add_reader(self.listener, cast(socket.socket, self._listen_socket)) - else: - for s in self._respond_sockets: - self.engine.add_reader(self.listener, s) - self.reaper = Reaper(self) - - self.debug = None # type: Optional[DNSOutgoing] - - @property - def done(self) -> bool: - return self._GLOBAL_DONE - - def wait(self, timeout: float) -> None: - """Calling thread waits for a given number of milliseconds or - until notified.""" - with self.condition: - self.condition.wait(timeout / 1000.0) - - def notify_all(self) -> None: - """Notifies all waiting threads""" - with self.condition: - self.condition.notify_all() - - def get_service_info(self, type_: str, name: str, timeout: int = 3000) -> Optional[ServiceInfo]: - """Returns network's service information for a particular - name and type, or None if no service matches by the timeout, - which defaults to 3 seconds.""" - info = ServiceInfo(type_, name) - if info.request(self, timeout): - return info - return None - - def add_service_listener(self, type_: str, listener: ServiceListener) -> None: - """Adds a listener for a particular service type. This object - will then have its add_service and remove_service methods called when - services of that type become available and unavailable.""" - self.remove_service_listener(listener) - self.browsers[listener] = ServiceBrowser(self, type_, listener) - - def remove_service_listener(self, listener: ServiceListener) -> None: - """Removes a listener from the set that is currently listening.""" - if listener in self.browsers: - self.browsers[listener].cancel() - del self.browsers[listener] - - def remove_all_service_listeners(self) -> None: - """Removes a listener from the set that is currently listening.""" - for listener in [k for k in self.browsers]: - self.remove_service_listener(listener) - - def register_service( - self, info: ServiceInfo, ttl: Optional[int] = None, allow_name_change: bool = False - ) -> None: - """Registers service information to the network with a default TTL. - Zeroconf will then respond to requests for information for that - service. The name of the service may be changed if needed to make - it unique on the network.""" - if ttl is not None: - # ttl argument is used to maintain backward compatibility - # Setting TTLs via ServiceInfo is preferred - info.host_ttl = ttl - info.other_ttl = ttl - self.check_service(info, allow_name_change) - self.services[info.name.lower()] = info - if info.type in self.servicetypes: - self.servicetypes[info.type] += 1 - else: - self.servicetypes[info.type] = 1 - - self._broadcast_service(info) - - def update_service(self, info: ServiceInfo) -> None: - """Registers service information to the network with a default TTL. - Zeroconf will then respond to requests for information for that - service.""" - - assert self.services[info.name.lower()] is not None - - self.services[info.name.lower()] = info - - self._broadcast_service(info) - - def _broadcast_service(self, info: ServiceInfo) -> None: - - now = current_time_millis() - next_time = now - i = 0 - while i < 3: - if now < next_time: - self.wait(next_time - now) - now = current_time_millis() - continue - out = DNSOutgoing(_FLAGS_QR_RESPONSE | _FLAGS_AA) - out.add_answer_at_time(DNSPointer(info.type, _TYPE_PTR, _CLASS_IN, info.other_ttl, info.name), 0) - out.add_answer_at_time( - DNSService( - info.name, - _TYPE_SRV, - _CLASS_IN, - info.host_ttl, - info.priority, - info.weight, - cast(int, info.port), - info.server, - ), - 0, - ) - - out.add_answer_at_time(DNSText(info.name, _TYPE_TXT, _CLASS_IN, info.other_ttl, info.text), 0) - for address in info.addresses_by_version(IPVersion.All): - type_ = _TYPE_AAAA if _is_v6_address(address) else _TYPE_A - out.add_answer_at_time(DNSAddress(info.server, type_, _CLASS_IN, info.host_ttl, address), 0) - self.send(out) - i += 1 - next_time += _REGISTER_TIME - - def unregister_service(self, info: ServiceInfo) -> None: - """Unregister a service.""" - try: - del self.services[info.name.lower()] - if self.servicetypes[info.type] > 1: - self.servicetypes[info.type] -= 1 - else: - del self.servicetypes[info.type] - except Exception as e: # TODO stop catching all Exceptions - log.exception('Unknown error, possibly benign: %r', e) - now = current_time_millis() - next_time = now - i = 0 - while i < 3: - if now < next_time: - self.wait(next_time - now) - now = current_time_millis() - continue - out = DNSOutgoing(_FLAGS_QR_RESPONSE | _FLAGS_AA) - out.add_answer_at_time(DNSPointer(info.type, _TYPE_PTR, _CLASS_IN, 0, info.name), 0) - out.add_answer_at_time( - DNSService( - info.name, - _TYPE_SRV, - _CLASS_IN, - 0, - info.priority, - info.weight, - cast(int, info.port), - info.name, - ), - 0, - ) - out.add_answer_at_time(DNSText(info.name, _TYPE_TXT, _CLASS_IN, 0, info.text), 0) - - for address in info.addresses_by_version(IPVersion.All): - type_ = _TYPE_AAAA if _is_v6_address(address) else _TYPE_A - out.add_answer_at_time(DNSAddress(info.server, type_, _CLASS_IN, 0, address), 0) - self.send(out) - i += 1 - next_time += _UNREGISTER_TIME - - def unregister_all_services(self) -> None: - """Unregister all registered services.""" - if len(self.services) > 0: - now = current_time_millis() - next_time = now - i = 0 - while i < 3: - if now < next_time: - self.wait(next_time - now) - now = current_time_millis() - continue - out = DNSOutgoing(_FLAGS_QR_RESPONSE | _FLAGS_AA) - for info in self.services.values(): - out.add_answer_at_time(DNSPointer(info.type, _TYPE_PTR, _CLASS_IN, 0, info.name), 0) - out.add_answer_at_time( - DNSService( - info.name, - _TYPE_SRV, - _CLASS_IN, - 0, - info.priority, - info.weight, - cast(int, info.port), - info.server, - ), - 0, - ) - out.add_answer_at_time(DNSText(info.name, _TYPE_TXT, _CLASS_IN, 0, info.text), 0) - for address in info.addresses_by_version(IPVersion.All): - type_ = _TYPE_AAAA if _is_v6_address(address) else _TYPE_A - out.add_answer_at_time(DNSAddress(info.server, type_, _CLASS_IN, 0, address), 0) - self.send(out) - i += 1 - next_time += _UNREGISTER_TIME - - def check_service(self, info: ServiceInfo, allow_name_change: bool) -> None: - """Checks the network for a unique service name, modifying the - ServiceInfo passed in if it is not unique.""" - - # This is kind of funky because of the subtype based tests - # need to make subtypes a first class citizen - service_name = service_type_name(info.name) - if not info.type.endswith(service_name): - raise BadTypeInNameException - - instance_name = info.name[: -len(service_name) - 1] - next_instance_number = 2 - - now = current_time_millis() - next_time = now - i = 0 - while i < 3: - # check for a name conflict - while self.cache.current_entry_with_name_and_alias(info.type, info.name): - if not allow_name_change: - raise NonUniqueNameException - - # change the name and look for a conflict - info.name = '%s-%s.%s' % (instance_name, next_instance_number, info.type) - next_instance_number += 1 - service_type_name(info.name) - next_time = now - i = 0 - - if now < next_time: - self.wait(next_time - now) - now = current_time_millis() - continue - - out = DNSOutgoing(_FLAGS_QR_QUERY | _FLAGS_AA) - self.debug = out - out.add_question(DNSQuestion(info.type, _TYPE_PTR, _CLASS_IN)) - out.add_authorative_answer(DNSPointer(info.type, _TYPE_PTR, _CLASS_IN, info.other_ttl, info.name)) - self.send(out) - i += 1 - next_time += _CHECK_TIME - - def add_listener(self, listener: RecordUpdateListener, question: Optional[DNSQuestion]) -> None: - """Adds a listener for a given question. The listener will have - its update_record method called when information is available to - answer the question.""" - now = current_time_millis() - self.listeners.append(listener) - if question is not None: - for record in self.cache.entries_with_name(question.name): - if question.answered_by(record) and not record.is_expired(now): - listener.update_record(self, now, record) - self.notify_all() - - def remove_listener(self, listener: RecordUpdateListener) -> None: - """Removes a listener.""" - try: - self.listeners.remove(listener) - self.notify_all() - except Exception as e: # TODO stop catching all Exceptions - log.exception('Unknown error, possibly benign: %r', e) - - def update_record(self, now: float, rec: DNSRecord) -> None: - """Used to notify listeners of new information that has updated - a record.""" - for listener in self.listeners: - listener.update_record(self, now, rec) - self.notify_all() - - def handle_response(self, msg: DNSIncoming) -> None: - """Deal with incoming response packets. All answers - are held in the cache, and listeners are notified.""" - now = current_time_millis() - for record in msg.answers: - if record.unique: # https://tools.ietf.org/html/rfc6762#section-10.2 - for entry in self.cache.entries(): - if DNSEntry.__eq__(entry, record) and (record.created - entry.created > 1000): - self.cache.remove(entry) - - expired = record.is_expired(now) - maybe_entry = self.cache.get(record) - if not expired: - if maybe_entry is not None: - maybe_entry.reset_ttl(record) - else: - self.cache.add(record) - self.update_record(now, record) - else: - if maybe_entry is not None: - self.update_record(now, record) - self.cache.remove(maybe_entry) - - def handle_query(self, msg: DNSIncoming, addr: Optional[str], port: int) -> None: - """Deal with incoming query packets. Provides a response if - possible.""" - out = None - - # Support unicast client responses - # - if port != _MDNS_PORT: - out = DNSOutgoing(_FLAGS_QR_RESPONSE | _FLAGS_AA, multicast=False) - for question in msg.questions: - out.add_question(question) - - for question in msg.questions: - if question.type == _TYPE_PTR: - if question.name == "_services._dns-sd._udp.local.": - for stype in self.servicetypes.keys(): - if out is None: - out = DNSOutgoing(_FLAGS_QR_RESPONSE | _FLAGS_AA) - out.add_answer( - msg, - DNSPointer( - "_services._dns-sd._udp.local.", _TYPE_PTR, _CLASS_IN, _DNS_OTHER_TTL, stype - ), - ) - for service in self.services.values(): - if question.name == service.type: - if out is None: - out = DNSOutgoing(_FLAGS_QR_RESPONSE | _FLAGS_AA) - out.add_answer( - msg, - DNSPointer(service.type, _TYPE_PTR, _CLASS_IN, service.other_ttl, service.name), - ) - - # Add recommended additional answers according to - # https://tools.ietf.org/html/rfc6763#section-12.1. - out.add_additional_answer( - DNSService( - service.name, - _TYPE_SRV, - _CLASS_IN | _CLASS_UNIQUE, - service.host_ttl, - service.priority, - service.weight, - cast(int, service.port), - service.server, - ) - ) - out.add_additional_answer( - DNSText( - service.name, - _TYPE_TXT, - _CLASS_IN | _CLASS_UNIQUE, - service.other_ttl, - service.text, - ) - ) - for address in service.addresses_by_version(IPVersion.All): - type_ = _TYPE_AAAA if _is_v6_address(address) else _TYPE_A - out.add_additional_answer( - DNSAddress( - service.server, - type_, - _CLASS_IN | _CLASS_UNIQUE, - service.host_ttl, - address, - ) - ) - else: - try: - if out is None: - out = DNSOutgoing(_FLAGS_QR_RESPONSE | _FLAGS_AA) - - # Answer A record queries for any service addresses we know - if question.type in (_TYPE_A, _TYPE_ANY): - for service in self.services.values(): - if service.server == question.name.lower(): - for address in service.addresses_by_version(IPVersion.All): - type_ = _TYPE_AAAA if _is_v6_address(address) else _TYPE_A - out.add_answer( - msg, - DNSAddress( - question.name, - type_, - _CLASS_IN | _CLASS_UNIQUE, - service.host_ttl, - address, - ), - ) - - name_to_find = question.name.lower() - if name_to_find not in self.services: - continue - service = self.services[name_to_find] - - if question.type in (_TYPE_SRV, _TYPE_ANY): - out.add_answer( - msg, - DNSService( - question.name, - _TYPE_SRV, - _CLASS_IN | _CLASS_UNIQUE, - service.host_ttl, - service.priority, - service.weight, - cast(int, service.port), - service.server, - ), - ) - if question.type in (_TYPE_TXT, _TYPE_ANY): - out.add_answer( - msg, - DNSText( - question.name, - _TYPE_TXT, - _CLASS_IN | _CLASS_UNIQUE, - service.other_ttl, - service.text, - ), - ) - if question.type == _TYPE_SRV: - for address in service.addresses_by_version(IPVersion.All): - type_ = _TYPE_AAAA if _is_v6_address(address) else _TYPE_A - out.add_additional_answer( - DNSAddress( - service.server, - type_, - _CLASS_IN | _CLASS_UNIQUE, - service.host_ttl, - address, - ) - ) - except Exception: # TODO stop catching all Exceptions - self.log_exception_warning() - - if out is not None and out.answers: - out.id = msg.id - self.send(out, addr, port) - - def send(self, out: DNSOutgoing, addr: Optional[str] = None, port: int = _MDNS_PORT) -> None: - """Sends an outgoing packet.""" - packet = out.packet() - if len(packet) > _MAX_MSG_ABSOLUTE: - self.log_warning_once("Dropping %r over-sized packet (%d bytes) %r", out, len(packet), packet) - return - log.debug('Sending %r (%d bytes) as %r...', out, len(packet), packet) - for s in self._respond_sockets: - if self._GLOBAL_DONE: - return - try: - if addr is None: - real_addr = _MDNS_ADDR6 if s.family == socket.AF_INET6 else _MDNS_ADDR - elif not can_send_to(s, addr): - continue - else: - real_addr = addr - bytes_sent = s.sendto(packet, 0, (real_addr, port)) - except Exception as exc: # TODO stop catching all Exceptions - if ( - isinstance(exc, OSError) - and exc.errno == errno.ENETUNREACH - and s.family == socket.AF_INET6 - ): - # with IPv6 we don't have a reliable way to determine if an interface actually has IPv6 - # support, so we have to try and ignore errors. - continue - # on send errors, log the exception and keep going - self.log_exception_warning() - else: - if bytes_sent != len(packet): - self.log_warning_once('!!! sent %d out of %d bytes to %r' % (bytes_sent, len(packet), s)) - - def close(self) -> None: - """Ends the background threads, and prevent this instance from - servicing further queries.""" - if not self._GLOBAL_DONE: - # remove service listeners - self.remove_all_service_listeners() - self.unregister_all_services() - self._GLOBAL_DONE = True - - # shutdown recv socket and thread - if not self.unicast: - self.engine.del_reader(cast(socket.socket, self._listen_socket)) - cast(socket.socket, self._listen_socket).close() - else: - for s in self._respond_sockets: - self.engine.del_reader(s) - self.engine.join() - - # shutdown the rest - self.notify_all() - self.reaper.join() - for s in self._respond_sockets: - s.close() diff --git a/zeroconf/_cache.py b/zeroconf/_cache.py new file mode 100644 index 00000000..24b6a233 --- /dev/null +++ b/zeroconf/_cache.py @@ -0,0 +1,209 @@ +""" Multicast DNS Service Discovery for Python, v0.14-wmcbrine + Copyright 2003 Paul Scott-Murphy, 2014 William McBrine + + This module provides a framework for the use of DNS Service Discovery + using IP multicast. + + This library is free software; you can redistribute it and/or + modify it under the terms of the GNU Lesser General Public + License as published by the Free Software Foundation; either + version 2.1 of the License, or (at your option) any later version. + + This library is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + Lesser General Public License for more details. + + You should have received a copy of the GNU Lesser General Public + License along with this library; if not, write to the Free Software + Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 + USA +""" + +import itertools +from typing import Dict, Iterable, Iterator, List, Optional, Union, cast + +from ._dns import ( + DNSAddress, + DNSEntry, + DNSHinfo, + DNSPointer, + DNSRecord, + DNSService, + DNSText, + dns_entry_matches, +) +from ._utils.time import current_time_millis +from .const import _TYPE_PTR + +_UNIQUE_RECORD_TYPES = (DNSAddress, DNSHinfo, DNSPointer, DNSText, DNSService) +_UniqueRecordsType = Union[DNSAddress, DNSHinfo, DNSPointer, DNSText, DNSService] +_DNSRecordCacheType = Dict[str, Dict[DNSRecord, DNSRecord]] + + +def _remove_key(cache: _DNSRecordCacheType, key: str, entry: DNSRecord) -> None: + """Remove a key from a DNSRecord cache + + This function must be run in from event loop. + """ + del cache[key][entry] + if not cache[key]: + del cache[key] + + +class DNSCache: + """A cache of DNS entries.""" + + def __init__(self) -> None: + self.cache: _DNSRecordCacheType = {} + self.service_cache: _DNSRecordCacheType = {} + + # Functions prefixed with async_ are NOT threadsafe and must + # be run in the event loop. + + def _async_add(self, entry: DNSRecord) -> None: + """Adds an entry. + + This function must be run in from event loop. + """ + # Previously storage of records was implemented as a list + # instead a dict. Since DNSRecords are now hashable, the implementation + # uses a dict to ensure that adding a new record to the cache + # replaces any existing records that are __eq__ to each other which + # removes the risk that accessing the cache from the wrong + # direction would return the old incorrect entry. + self.cache.setdefault(entry.key, {})[entry] = entry + if isinstance(entry, DNSService): + self.service_cache.setdefault(entry.server, {})[entry] = entry + + def async_add_records(self, entries: Iterable[DNSRecord]) -> None: + """Add multiple records. + + This function must be run in from event loop. + """ + for entry in entries: + self._async_add(entry) + + def _async_remove(self, entry: DNSRecord) -> None: + """Removes an entry. + + This function must be run in from event loop. + """ + if isinstance(entry, DNSService): + _remove_key(self.service_cache, entry.server, entry) + _remove_key(self.cache, entry.key, entry) + + def async_remove_records(self, entries: Iterable[DNSRecord]) -> None: + """Remove multiple records. + + This function must be run in from event loop. + """ + for entry in entries: + self._async_remove(entry) + + def async_expire(self, now: float) -> List[DNSRecord]: + """Purge expired entries from the cache. + + This function must be run in from event loop. + """ + expired = [record for record in itertools.chain(*self.cache.values()) if record.is_expired(now)] + self.async_remove_records(expired) + return expired + + def async_get_unique(self, entry: _UniqueRecordsType) -> Optional[DNSRecord]: + """Gets a unique entry by key. Will return None if there is no + matching entry. + + This function is not threadsafe and must be called from + the event loop. + """ + return self.cache.get(entry.key, {}).get(entry) + + def async_all_by_details(self, name: str, type_: int, class_: int) -> Iterator[DNSRecord]: + """Gets all matching entries by details. + + This function is not threadsafe and must be called from + the event loop. + """ + key = name.lower() + for entry in self.cache.get(key, []): + if dns_entry_matches(entry, key, type_, class_): + yield entry + + def async_entries_with_name(self, name: str) -> Dict[DNSRecord, DNSRecord]: + """Returns a dict of entries whose key matches the name. + + This function is not threadsafe and must be called from + the event loop. + """ + return self.cache.get(name.lower(), {}) + + def async_entries_with_server(self, name: str) -> Dict[DNSRecord, DNSRecord]: + """Returns a dict of entries whose key matches the server. + + This function is not threadsafe and must be called from + the event loop. + """ + return self.service_cache.get(name.lower(), {}) + + # The below functions are threadsafe and do not need to be run in the + # event loop, however they all make copies so they significantly + # inefficent + + def get(self, entry: DNSEntry) -> Optional[DNSRecord]: + """Gets an entry by key. Will return None if there is no + matching entry.""" + if isinstance(entry, _UNIQUE_RECORD_TYPES): + return self.cache.get(entry.key, {}).get(entry) + for cached_entry in reversed(list(self.cache.get(entry.key, []))): + if entry.__eq__(cached_entry): + return cached_entry + return None + + def get_by_details(self, name: str, type_: int, class_: int) -> Optional[DNSRecord]: + """Gets the first matching entry by details. Returns None if no entries match. + + Calling this function is not recommended as it will only + return one record even if there are multiple entries. + + For example if there are multiple A or AAAA addresses this + function will return the last one that was added to the cache + which may not be the one you expect. + + Use get_all_by_details instead. + """ + key = name.lower() + for cached_entry in reversed(list(self.cache.get(key, []))): + if dns_entry_matches(cached_entry, key, type_, class_): + return cached_entry + return None + + def get_all_by_details(self, name: str, type_: int, class_: int) -> List[DNSRecord]: + """Gets all matching entries by details.""" + key = name.lower() + return [ + entry for entry in list(self.cache.get(key, [])) if dns_entry_matches(entry, key, type_, class_) + ] + + def entries_with_server(self, server: str) -> List[DNSRecord]: + """Returns a list of entries whose server matches the name.""" + return list(self.service_cache.get(server.lower(), [])) + + def entries_with_name(self, name: str) -> List[DNSRecord]: + """Returns a list of entries whose key matches the name.""" + return list(self.cache.get(name.lower(), [])) + + def current_entry_with_name_and_alias(self, name: str, alias: str) -> Optional[DNSRecord]: + now = current_time_millis() + for record in reversed(self.entries_with_name(name)): + if ( + record.type == _TYPE_PTR + and not record.is_expired(now) + and cast(DNSPointer, record).alias == alias + ): + return record + return None + + def names(self) -> List[str]: + """Return a copy of the list of current cache names.""" + return list(self.cache) diff --git a/zeroconf/_core.py b/zeroconf/_core.py new file mode 100644 index 00000000..1575eba2 --- /dev/null +++ b/zeroconf/_core.py @@ -0,0 +1,928 @@ +""" Multicast DNS Service Discovery for Python, v0.14-wmcbrine + Copyright 2003 Paul Scott-Murphy, 2014 William McBrine + + This module provides a framework for the use of DNS Service Discovery + using IP multicast. + + This library is free software; you can redistribute it and/or + modify it under the terms of the GNU Lesser General Public + License as published by the Free Software Foundation; either + version 2.1 of the License, or (at your option) any later version. + + This library is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + Lesser General Public License for more details. + + You should have received a copy of the GNU Lesser General Public + License along with this library; if not, write to the Free Software + Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 + USA +""" + +import asyncio +import itertools +import logging +import random +import socket +import sys +import threading +from types import TracebackType # noqa # used in type hints +from typing import Awaitable, Dict, List, Optional, Tuple, Type, Union, cast + +from ._cache import DNSCache +from ._dns import DNSQuestion, DNSQuestionType +from ._exceptions import NonUniqueNameException +from ._handlers import ( + MulticastOutgoingQueue, + QueryHandler, + RecordManager, + construct_outgoing_multicast_answers, + construct_outgoing_unicast_answers, +) +from ._history import QuestionHistory +from ._logger import QuietLogger, log +from ._protocol.incoming import DNSIncoming +from ._protocol.outgoing import DNSOutgoing +from ._services import ServiceListener +from ._services.browser import ServiceBrowser +from ._services.info import ServiceInfo, instance_name_from_service_info +from ._services.registry import ServiceRegistry +from ._updates import RecordUpdate, RecordUpdateListener +from ._utils.asyncio import ( + await_awaitable, + get_running_loop, + run_coro_with_timeout, + shutdown_loop, + wait_event_or_timeout, +) +from ._utils.name import service_type_name +from ._utils.net import ( + IPVersion, + InterfaceChoice, + InterfacesType, + autodetect_ip_version, + can_send_to, + create_sockets, +) +from ._utils.time import current_time_millis, millis_to_seconds +from .const import ( + _CACHE_CLEANUP_INTERVAL, + _CHECK_TIME, + _CLASS_IN, + _CLASS_UNIQUE, + _FLAGS_AA, + _FLAGS_QR_QUERY, + _FLAGS_QR_RESPONSE, + _MAX_MSG_ABSOLUTE, + _MDNS_ADDR, + _MDNS_ADDR6, + _MDNS_PORT, + _ONE_SECOND, + _REGISTER_TIME, + _TYPE_PTR, + _UNREGISTER_TIME, +) + +_TC_DELAY_RANDOM_INTERVAL = (400, 500) +# The maximum amont of time to delay a multicast +# response in order to aggregate answers +_AGGREGATION_DELAY = 500 # ms +# The maximum amont of time to delay a multicast +# response in order to aggregate answers after +# it has already been delayed to protect the network +# from excessive traffic. We use a shorter time +# window here as we want to _try_ to answer all +# queries in under 1350ms while protecting +# the network from excessive traffic to ensure +# a service info request with two questions +# can be answered in the default timeout of +# 3000ms +_PROTECTED_AGGREGATION_DELAY = 200 # ms + +_CLOSE_TIMEOUT = 3000 # ms +_REGISTER_BROADCASTS = 3 + + +class AsyncEngine: + """An engine wraps sockets in the event loop.""" + + def __init__( + self, + zeroconf: 'Zeroconf', + listen_socket: Optional[socket.socket], + respond_sockets: List[socket.socket], + ) -> None: + self.loop: Optional[asyncio.AbstractEventLoop] = None + self.zc = zeroconf + self.protocols: List[AsyncListener] = [] + self.readers: List[asyncio.DatagramTransport] = [] + self.senders: List[asyncio.DatagramTransport] = [] + self._listen_socket = listen_socket + self._respond_sockets = respond_sockets + self._cleanup_timer: Optional[asyncio.TimerHandle] = None + self._running_event: Optional[asyncio.Event] = None + + def setup(self, loop: asyncio.AbstractEventLoop, loop_thread_ready: Optional[threading.Event]) -> None: + """Set up the instance.""" + self.loop = loop + self._running_event = asyncio.Event() + self.loop.create_task(self._async_setup(loop_thread_ready)) + + async def _async_setup(self, loop_thread_ready: Optional[threading.Event]) -> None: + """Set up the instance.""" + assert self.loop is not None + self._cleanup_timer = self.loop.call_later( + millis_to_seconds(_CACHE_CLEANUP_INTERVAL), self._async_cache_cleanup + ) + await self._async_create_endpoints() + assert self._running_event is not None + self._running_event.set() + if loop_thread_ready: + loop_thread_ready.set() + + async def async_wait_for_start(self) -> None: + """Wait for start up.""" + assert self._running_event is not None + await self._running_event.wait() + + async def _async_create_endpoints(self) -> None: + """Create endpoints to send and receive.""" + assert self.loop is not None + loop = self.loop + reader_sockets = [] + sender_sockets = [] + if self._listen_socket: + reader_sockets.append(self._listen_socket) + for s in self._respond_sockets: + if s not in reader_sockets: + reader_sockets.append(s) + sender_sockets.append(s) + + for s in reader_sockets: + transport, protocol = await loop.create_datagram_endpoint(lambda: AsyncListener(self.zc), sock=s) + self.protocols.append(cast(AsyncListener, protocol)) + self.readers.append(cast(asyncio.DatagramTransport, transport)) + if s in sender_sockets: + self.senders.append(cast(asyncio.DatagramTransport, transport)) + + def _async_cache_cleanup(self) -> None: + """Periodic cache cleanup.""" + now = current_time_millis() + self.zc.question_history.async_expire(now) + self.zc.record_manager.async_updates( + now, [RecordUpdate(record, None) for record in self.zc.cache.async_expire(now)] + ) + self.zc.record_manager.async_updates_complete() + assert self.loop is not None + self._cleanup_timer = self.loop.call_later( + millis_to_seconds(_CACHE_CLEANUP_INTERVAL), self._async_cache_cleanup + ) + + async def _async_close(self) -> None: + """Cancel and wait for the cleanup task to finish.""" + self._async_shutdown() + await asyncio.sleep(0) # flush out any call soons + assert self._cleanup_timer is not None + self._cleanup_timer.cancel() + + def _async_shutdown(self) -> None: + """Shutdown transports and sockets.""" + for transport in itertools.chain(self.senders, self.readers): + transport.close() + + def close(self) -> None: + """Close from sync context.""" + assert self.loop is not None + # Guard against Zeroconf.close() being called from the eventloop + if get_running_loop() == self.loop: + self._async_shutdown() + return + if not self.loop.is_running(): + return + run_coro_with_timeout(self._async_close(), self.loop, _CLOSE_TIMEOUT) + + +class AsyncListener(asyncio.Protocol, QuietLogger): + + """A Listener is used by this module to listen on the multicast + group to which DNS messages are sent, allowing the implementation + to cache information as it arrives. + + It requires registration with an Engine object in order to have + the read() method called when a socket is available for reading.""" + + __slots__ = ('zc', 'data', 'last_time', 'transport', 'sock_description', '_deferred', '_timers') + + def __init__(self, zc: 'Zeroconf') -> None: + self.zc = zc + self.data: Optional[bytes] = None + self.last_time: float = 0 + self.transport: Optional[asyncio.DatagramTransport] = None + self.sock_description: Optional[str] = None + self._deferred: Dict[str, List[DNSIncoming]] = {} + self._timers: Dict[str, asyncio.TimerHandle] = {} + super().__init__() + + def suppress_duplicate_packet(self, data: bytes, now: float) -> bool: + """Suppress duplicate packet if the last one was the same in the last second.""" + if self.data == data and (now - 1000) < self.last_time: + return True + self.data = data + self.last_time = now + return False + + def datagram_received( + self, data: bytes, addrs: Union[Tuple[str, int], Tuple[str, int, int, int]] + ) -> None: + assert self.transport is not None + v6_flow_scope: Union[Tuple[()], Tuple[int, int]] = () + data_len = len(data) + + if len(addrs) == 2: + # https://github.com/python/mypy/issues/1178 + addr, port = addrs # type: ignore + scope = None + else: + # https://github.com/python/mypy/issues/1178 + addr, port, flow, scope = addrs # type: ignore + log.debug('IPv6 scope_id %d associated to the receiving interface', scope) + v6_flow_scope = (flow, scope) + + now = current_time_millis() + if self.suppress_duplicate_packet(data, now): + # Guard against duplicate packets + log.debug( + 'Ignoring duplicate message received from %r:%r [socket %s] (%d bytes) as [%r]', + addr, + port, + self.sock_description, + data_len, + data, + ) + return + + if data_len > _MAX_MSG_ABSOLUTE: + # Guard against oversized packets to ensure bad implementations cannot overwhelm + # the system. + log.debug( + "Discarding incoming packet with length %s, which is larger " + "than the absolute maximum size of %s", + data_len, + _MAX_MSG_ABSOLUTE, + ) + return + + msg = DNSIncoming(data, (addr, port), scope, now) + if msg.valid: + log.debug( + 'Received from %r:%r [socket %s]: %r (%d bytes) as [%r]', + addr, + port, + self.sock_description, + msg, + data_len, + data, + ) + else: + log.debug( + 'Received from %r:%r [socket %s]: (%d bytes) [%r]', + addr, + port, + self.sock_description, + data_len, + data, + ) + return + + if not msg.is_query(): + self.zc.handle_response(msg) + return + + self.handle_query_or_defer(msg, addr, port, self.transport, v6_flow_scope) + + def handle_query_or_defer( + self, + msg: DNSIncoming, + addr: str, + port: int, + transport: asyncio.DatagramTransport, + v6_flow_scope: Union[Tuple[()], Tuple[int, int]] = (), + ) -> None: + """Deal with incoming query packets. Provides a response if + possible.""" + if not msg.truncated: + self._respond_query(msg, addr, port, transport, v6_flow_scope) + return + + deferred = self._deferred.setdefault(addr, []) + # If we get the same packet we ignore it + for incoming in reversed(deferred): + if incoming.data == msg.data: + return + deferred.append(msg) + delay = millis_to_seconds(random.randint(*_TC_DELAY_RANDOM_INTERVAL)) + assert self.zc.loop is not None + self._cancel_any_timers_for_addr(addr) + self._timers[addr] = self.zc.loop.call_later( + delay, self._respond_query, None, addr, port, transport, v6_flow_scope + ) + + def _cancel_any_timers_for_addr(self, addr: str) -> None: + """Cancel any future truncated packet timers for the address.""" + if addr in self._timers: + self._timers.pop(addr).cancel() + + def _respond_query( + self, + msg: Optional[DNSIncoming], + addr: str, + port: int, + transport: asyncio.DatagramTransport, + v6_flow_scope: Union[Tuple[()], Tuple[int, int]] = (), + ) -> None: + """Respond to a query and reassemble any truncated deferred packets.""" + self._cancel_any_timers_for_addr(addr) + packets = self._deferred.pop(addr, []) + if msg: + packets.append(msg) + + self.zc.handle_assembled_query(packets, addr, port, transport, v6_flow_scope) + + def error_received(self, exc: Exception) -> None: + """Likely socket closed or IPv6.""" + # We preformat the message string with the socket as we want + # log_exception_once to log a warrning message once PER EACH + # different socket in case there are problems with multiple + # sockets + msg_str = f"Error with socket {self.sock_description}): %s" + self.log_exception_once(exc, msg_str, exc) + + def connection_made(self, transport: asyncio.BaseTransport) -> None: + self.transport = cast(asyncio.DatagramTransport, transport) + sock_name = self.transport.get_extra_info('sockname') + sock_fileno = self.transport.get_extra_info('socket').fileno() + self.sock_description = f"{sock_fileno} ({sock_name})" + + def connection_lost(self, exc: Optional[Exception]) -> None: + """Handle connection lost.""" + + +def async_send_with_transport( + log_debug: bool, + transport: asyncio.DatagramTransport, + packet: bytes, + packet_num: int, + out: DNSOutgoing, + addr: Optional[str], + port: int, + v6_flow_scope: Union[Tuple[()], Tuple[int, int]] = (), +) -> None: + s = transport.get_extra_info('socket') + ipv6_socket = s.family == socket.AF_INET6 + if addr is None: + real_addr = _MDNS_ADDR6 if ipv6_socket else _MDNS_ADDR + else: + real_addr = addr + if not can_send_to(ipv6_socket, real_addr): + return + if log_debug: + log.debug( + 'Sending to (%s, %d) via [socket %s (%s)] (%d bytes #%d) %r as %r...', + real_addr, + port or _MDNS_PORT, + s.fileno(), + transport.get_extra_info('sockname'), + len(packet), + packet_num + 1, + out, + packet, + ) + # Get flowinfo and scopeid for the IPV6 socket to create a complete IPv6 + # address tuple: https://docs.python.org/3.6/library/socket.html#socket-families + if ipv6_socket and not v6_flow_scope: + _, _, sock_flowinfo, sock_scopeid = s.getsockname() + v6_flow_scope = (sock_flowinfo, sock_scopeid) + transport.sendto(packet, (real_addr, port or _MDNS_PORT, *v6_flow_scope)) + + +class Zeroconf(QuietLogger): + + """Implementation of Zeroconf Multicast DNS Service Discovery + + Supports registration, unregistration, queries and browsing. + """ + + def __init__( + self, + interfaces: InterfacesType = InterfaceChoice.All, + unicast: bool = False, + ip_version: Optional[IPVersion] = None, + apple_p2p: bool = False, + ) -> None: + """Creates an instance of the Zeroconf class, establishing + multicast communications, listening and reaping threads. + + :param interfaces: :class:`InterfaceChoice` or a list of IP addresses + (IPv4 and IPv6) and interface indexes (IPv6 only). + + IPv6 notes for non-POSIX systems: + * `InterfaceChoice.All` is an alias for `InterfaceChoice.Default` + on Python versions before 3.8. + + Also listening on loopback (``::1``) doesn't work, use a real address. + :param ip_version: IP versions to support. If `choice` is a list, the default is detected + from it. Otherwise defaults to V4 only for backward compatibility. + :param apple_p2p: use AWDL interface (only macOS) + """ + if ip_version is None: + ip_version = autodetect_ip_version(interfaces) + + self.done = False + + if apple_p2p and sys.platform != 'darwin': + raise RuntimeError('Option `apple_p2p` is not supported on non-Apple platforms.') + + self.unicast = unicast + listen_socket, respond_sockets = create_sockets(interfaces, unicast, ip_version, apple_p2p=apple_p2p) + log.debug('Listen socket %s, respond sockets %s', listen_socket, respond_sockets) + + self.engine = AsyncEngine(self, listen_socket, respond_sockets) + + self.browsers: Dict[ServiceListener, ServiceBrowser] = {} + self.registry = ServiceRegistry() + self.cache = DNSCache() + self.question_history = QuestionHistory() + self.query_handler = QueryHandler(self.registry, self.cache, self.question_history) + self.record_manager = RecordManager(self) + + self.notify_event: Optional[asyncio.Event] = None + self.loop: Optional[asyncio.AbstractEventLoop] = None + self._loop_thread: Optional[threading.Thread] = None + + self._out_queue = MulticastOutgoingQueue(self, 0, _AGGREGATION_DELAY) + self._out_delay_queue = MulticastOutgoingQueue(self, _ONE_SECOND, _PROTECTED_AGGREGATION_DELAY) + + self.start() + + def start(self) -> None: + """Start Zeroconf.""" + self.loop = get_running_loop() + if self.loop: + self.notify_event = asyncio.Event() + self.engine.setup(self.loop, None) + return + self._start_thread() + + def _start_thread(self) -> None: + """Start a thread with a running event loop.""" + loop_thread_ready = threading.Event() + + def _run_loop() -> None: + self.loop = asyncio.new_event_loop() + asyncio.set_event_loop(self.loop) + self.notify_event = asyncio.Event() + self.engine.setup(self.loop, loop_thread_ready) + self.loop.run_forever() + + self._loop_thread = threading.Thread(target=_run_loop, daemon=True) + self._loop_thread.start() + loop_thread_ready.wait() + + async def async_wait_for_start(self) -> None: + """Wait for start up.""" + await self.engine.async_wait_for_start() + + @property + def listeners(self) -> List[RecordUpdateListener]: + return self.record_manager.listeners + + async def async_wait(self, timeout: float) -> None: + """Calling task waits for a given number of milliseconds or until notified.""" + assert self.notify_event is not None + await wait_event_or_timeout(self.notify_event, timeout=millis_to_seconds(timeout)) + + def notify_all(self) -> None: + """Notifies all waiting threads and notify listeners.""" + assert self.loop is not None + self.loop.call_soon_threadsafe(self.async_notify_all) + + def async_notify_all(self) -> None: + """Schedule an async_notify_all.""" + assert self.notify_event is not None + self.notify_event.set() + self.notify_event.clear() + + def get_service_info( + self, type_: str, name: str, timeout: int = 3000, question_type: Optional[DNSQuestionType] = None + ) -> Optional[ServiceInfo]: + """Returns network's service information for a particular + name and type, or None if no service matches by the timeout, + which defaults to 3 seconds.""" + info = ServiceInfo(type_, name) + if info.request(self, timeout, question_type): + return info + return None + + def add_service_listener(self, type_: str, listener: ServiceListener) -> None: + """Adds a listener for a particular service type. This object + will then have its add_service and remove_service methods called when + services of that type become available and unavailable.""" + self.remove_service_listener(listener) + self.browsers[listener] = ServiceBrowser(self, type_, listener) + + def remove_service_listener(self, listener: ServiceListener) -> None: + """Removes a listener from the set that is currently listening.""" + if listener in self.browsers: + self.browsers[listener].cancel() + del self.browsers[listener] + + def remove_all_service_listeners(self) -> None: + """Removes a listener from the set that is currently listening.""" + for listener in list(self.browsers): + self.remove_service_listener(listener) + + def register_service( + self, + info: ServiceInfo, + ttl: Optional[int] = None, + allow_name_change: bool = False, + cooperating_responders: bool = False, + ) -> None: + """Registers service information to the network with a default TTL. + Zeroconf will then respond to requests for information for that + service. The name of the service may be changed if needed to make + it unique on the network. Additionally multiple cooperating responders + can register the same service on the network for resilience + (if you want this behavior set `cooperating_responders` to `True`).""" + assert self.loop is not None + run_coro_with_timeout( + await_awaitable( + self.async_register_service(info, ttl, allow_name_change, cooperating_responders) + ), + self.loop, + _REGISTER_TIME * _REGISTER_BROADCASTS, + ) + + async def async_register_service( + self, + info: ServiceInfo, + ttl: Optional[int] = None, + allow_name_change: bool = False, + cooperating_responders: bool = False, + ) -> Awaitable: + """Registers service information to the network with a default TTL. + Zeroconf will then respond to requests for information for that + service. The name of the service may be changed if needed to make + it unique on the network. Additionally multiple cooperating responders + can register the same service on the network for resilience + (if you want this behavior set `cooperating_responders` to `True`).""" + if ttl is not None: + # ttl argument is used to maintain backward compatibility + # Setting TTLs via ServiceInfo is preferred + info.host_ttl = ttl + info.other_ttl = ttl + + await self.async_wait_for_start() + await self.async_check_service(info, allow_name_change, cooperating_responders) + self.registry.async_add(info) + return asyncio.ensure_future(self._async_broadcast_service(info, _REGISTER_TIME, None)) + + def update_service(self, info: ServiceInfo) -> None: + """Registers service information to the network with a default TTL. + Zeroconf will then respond to requests for information for that + service.""" + assert self.loop is not None + run_coro_with_timeout( + await_awaitable(self.async_update_service(info)), self.loop, _REGISTER_TIME * _REGISTER_BROADCASTS + ) + + async def async_update_service(self, info: ServiceInfo) -> Awaitable: + """Registers service information to the network with a default TTL. + Zeroconf will then respond to requests for information for that + service.""" + self.registry.async_update(info) + return asyncio.ensure_future(self._async_broadcast_service(info, _REGISTER_TIME, None)) + + async def _async_broadcast_service( + self, + info: ServiceInfo, + interval: int, + ttl: Optional[int], + broadcast_addresses: bool = True, + ) -> None: + """Send a broadcasts to announce a service at intervals.""" + for i in range(_REGISTER_BROADCASTS): + if i != 0: + await asyncio.sleep(millis_to_seconds(interval)) + self.async_send(self.generate_service_broadcast(info, ttl, broadcast_addresses)) + + def generate_service_broadcast( + self, + info: ServiceInfo, + ttl: Optional[int], + broadcast_addresses: bool = True, + ) -> DNSOutgoing: + """Generate a broadcast to announce a service.""" + out = DNSOutgoing(_FLAGS_QR_RESPONSE | _FLAGS_AA) + self._add_broadcast_answer(out, info, ttl, broadcast_addresses) + return out + + def generate_service_query(self, info: ServiceInfo) -> DNSOutgoing: # pylint: disable=no-self-use + """Generate a query to lookup a service.""" + out = DNSOutgoing(_FLAGS_QR_QUERY | _FLAGS_AA) + # https://datatracker.ietf.org/doc/html/rfc6762#section-8.1 + # Because of the mDNS multicast rate-limiting + # rules, the probes SHOULD be sent as "QU" questions with the unicast- + # response bit set, to allow a defending host to respond immediately + # via unicast, instead of potentially having to wait before replying + # via multicast. + # + # _CLASS_UNIQUE is the "QU" bit + out.add_question(DNSQuestion(info.type, _TYPE_PTR, _CLASS_IN | _CLASS_UNIQUE)) + out.add_authorative_answer(info.dns_pointer(created=current_time_millis())) + return out + + def _add_broadcast_answer( # pylint: disable=no-self-use + self, + out: DNSOutgoing, + info: ServiceInfo, + override_ttl: Optional[int], + broadcast_addresses: bool = True, + ) -> None: + """Add answers to broadcast a service.""" + now = current_time_millis() + other_ttl = info.other_ttl if override_ttl is None else override_ttl + host_ttl = info.host_ttl if override_ttl is None else override_ttl + out.add_answer_at_time(info.dns_pointer(override_ttl=other_ttl, created=now), 0) + out.add_answer_at_time(info.dns_service(override_ttl=host_ttl, created=now), 0) + out.add_answer_at_time(info.dns_text(override_ttl=other_ttl, created=now), 0) + if broadcast_addresses: + for dns_address in info.dns_addresses(override_ttl=host_ttl, created=now): + out.add_answer_at_time(dns_address, 0) + + def unregister_service(self, info: ServiceInfo) -> None: + """Unregister a service.""" + assert self.loop is not None + run_coro_with_timeout( + self.async_unregister_service(info), self.loop, _UNREGISTER_TIME * _REGISTER_BROADCASTS + ) + + async def async_unregister_service(self, info: ServiceInfo) -> Awaitable: + """Unregister a service.""" + self.registry.async_remove(info) + # If another server uses the same addresses, we do not want to send + # goodbye packets for the address records + + entries = self.registry.async_get_infos_server(info.server) + broadcast_addresses = not bool(entries) + return asyncio.ensure_future( + self._async_broadcast_service(info, _UNREGISTER_TIME, 0, broadcast_addresses) + ) + + def generate_unregister_all_services(self) -> Optional[DNSOutgoing]: + """Generate a DNSOutgoing goodbye for all services and remove them from the registry.""" + service_infos = self.registry.async_get_service_infos() + if not service_infos: + return None + out = DNSOutgoing(_FLAGS_QR_RESPONSE | _FLAGS_AA) + for info in service_infos: + self._add_broadcast_answer(out, info, 0) + self.registry.async_remove(service_infos) + return out + + async def async_unregister_all_services(self) -> None: + """Unregister all registered services. + + Unlike async_register_service and async_unregister_service, this + method does not return a future and is always expected to be + awaited since its only called at shutdown. + """ + # Send Goodbye packets https://datatracker.ietf.org/doc/html/rfc6762#section-10.1 + out = self.generate_unregister_all_services() + if not out: + return + for i in range(_REGISTER_BROADCASTS): + if i != 0: + await asyncio.sleep(millis_to_seconds(_UNREGISTER_TIME)) + self.async_send(out) + + def unregister_all_services(self) -> None: + """Unregister all registered services.""" + assert self.loop is not None + run_coro_with_timeout( + self.async_unregister_all_services(), self.loop, _UNREGISTER_TIME * _REGISTER_BROADCASTS + ) + + async def async_check_service( + self, info: ServiceInfo, allow_name_change: bool, cooperating_responders: bool = False + ) -> None: + """Checks the network for a unique service name, modifying the + ServiceInfo passed in if it is not unique.""" + instance_name = instance_name_from_service_info(info) + if cooperating_responders: + return + next_instance_number = 2 + next_time = now = current_time_millis() + i = 0 + while i < _REGISTER_BROADCASTS: + # check for a name conflict + while self.cache.current_entry_with_name_and_alias(info.type, info.name): + if not allow_name_change: + raise NonUniqueNameException + + # change the name and look for a conflict + info.name = f'{instance_name}-{next_instance_number}.{info.type}' + next_instance_number += 1 + service_type_name(info.name) + next_time = now + i = 0 + + if now < next_time: + await self.async_wait(next_time - now) + now = current_time_millis() + continue + + self.async_send(self.generate_service_query(info)) + i += 1 + next_time += _CHECK_TIME + + def add_listener( + self, listener: RecordUpdateListener, question: Optional[Union[DNSQuestion, List[DNSQuestion]]] + ) -> None: + """Adds a listener for a given question. The listener will have + its update_record method called when information is available to + answer the question(s). + + This function is threadsafe + """ + assert self.loop is not None + self.loop.call_soon_threadsafe(self.record_manager.async_add_listener, listener, question) + + def remove_listener(self, listener: RecordUpdateListener) -> None: + """Removes a listener. + + This function is threadsafe + """ + assert self.loop is not None + self.loop.call_soon_threadsafe(self.record_manager.async_remove_listener, listener) + + def async_add_listener( + self, listener: RecordUpdateListener, question: Optional[Union[DNSQuestion, List[DNSQuestion]]] + ) -> None: + """Adds a listener for a given question. The listener will have + its update_record method called when information is available to + answer the question(s). + + This function is not threadsafe and must be called in the eventloop. + """ + self.record_manager.async_add_listener(listener, question) + + def async_remove_listener(self, listener: RecordUpdateListener) -> None: + """Removes a listener. + + This function is not threadsafe and must be called in the eventloop. + """ + self.record_manager.async_remove_listener(listener) + + def handle_response(self, msg: DNSIncoming) -> None: + """Deal with incoming response packets. All answers + are held in the cache, and listeners are notified.""" + self.record_manager.async_updates_from_response(msg) + + def handle_assembled_query( + self, + packets: List[DNSIncoming], + addr: str, + port: int, + transport: asyncio.DatagramTransport, + v6_flow_scope: Union[Tuple[()], Tuple[int, int]] = (), + ) -> None: + """Respond to a (re)assembled query. + + If the protocol recieved packets with the TC bit set, it will + wait a bit for the rest of the packets and only call + handle_assembled_query once it has a complete set of packets + or the timer expires. If the TC bit is not set, a single + packet will be in packets. + """ + now = packets[0].now + ucast_source = port != _MDNS_PORT + question_answers = self.query_handler.async_response(packets, ucast_source) + if question_answers.ucast: + questions = packets[0].questions + id_ = packets[0].id + out = construct_outgoing_unicast_answers(question_answers.ucast, ucast_source, questions, id_) + # When sending unicast, only send back the reply + # via the same socket that it was recieved from + # as we know its reachable from that socket + self.async_send(out, addr, port, v6_flow_scope, transport) + if question_answers.mcast_now: + self.async_send(construct_outgoing_multicast_answers(question_answers.mcast_now)) + if question_answers.mcast_aggregate: + self._out_queue.async_add(now, question_answers.mcast_aggregate) + if question_answers.mcast_aggregate_last_second: + # https://datatracker.ietf.org/doc/html/rfc6762#section-14 + # If we broadcast it in the last second, we have to delay + # at least a second before we send it again + self._out_delay_queue.async_add(now, question_answers.mcast_aggregate_last_second) + + def send( + self, + out: DNSOutgoing, + addr: Optional[str] = None, + port: int = _MDNS_PORT, + v6_flow_scope: Union[Tuple[()], Tuple[int, int]] = (), + transport: Optional[asyncio.DatagramTransport] = None, + ) -> None: + """Sends an outgoing packet threadsafe.""" + assert self.loop is not None + self.loop.call_soon_threadsafe(self.async_send, out, addr, port, v6_flow_scope, transport) + + def async_send( + self, + out: DNSOutgoing, + addr: Optional[str] = None, + port: int = _MDNS_PORT, + v6_flow_scope: Union[Tuple[()], Tuple[int, int]] = (), + transport: Optional[asyncio.DatagramTransport] = None, + ) -> None: + """Sends an outgoing packet.""" + if self.done: + return + + # If no transport is specified, we send to all the ones + # with the same address family + transports = [transport] if transport else self.engine.senders + log_debug = log.isEnabledFor(logging.DEBUG) + + for packet_num, packet in enumerate(out.packets()): + if len(packet) > _MAX_MSG_ABSOLUTE: + self.log_warning_once("Dropping %r over-sized packet (%d bytes) %r", out, len(packet), packet) + return + for send_transport in transports: + async_send_with_transport( + log_debug, send_transport, packet, packet_num, out, addr, port, v6_flow_scope + ) + + def _close(self) -> None: + """Set global done and remove all service listeners.""" + if self.done: + return + self.remove_all_service_listeners() + self.done = True + + def _shutdown_threads(self) -> None: + """Shutdown any threads.""" + self.notify_all() + if not self._loop_thread: + return + assert self.loop is not None + shutdown_loop(self.loop) + self._loop_thread.join() + self._loop_thread = None + + def close(self) -> None: + """Ends the background threads, and prevent this instance from + servicing further queries. + + This method is idempotent and irreversible. + """ + assert self.loop is not None + if self.loop.is_running(): + if self.loop == get_running_loop(): + log.warning( + "unregister_all_services skipped as it does blocking i/o; use AsyncZeroconf with asyncio" + ) + else: + self.unregister_all_services() + self._close() + self.engine.close() + self._shutdown_threads() + + async def _async_close(self) -> None: + """Ends the background threads, and prevent this instance from + servicing further queries. + + This method is idempotent and irreversible. + + This call only intended to be used by AsyncZeroconf + + Callers are responsible for unregistering all services + before calling this function + """ + self._close() + await self.engine._async_close() # pylint: disable=protected-access + self._shutdown_threads() + + def __enter__(self) -> 'Zeroconf': + return self + + def __exit__( # pylint: disable=useless-return + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> Optional[bool]: + self.close() + return None diff --git a/zeroconf/_dns.py b/zeroconf/_dns.py new file mode 100644 index 00000000..a551a7da --- /dev/null +++ b/zeroconf/_dns.py @@ -0,0 +1,534 @@ +""" Multicast DNS Service Discovery for Python, v0.14-wmcbrine + Copyright 2003 Paul Scott-Murphy, 2014 William McBrine + + This module provides a framework for the use of DNS Service Discovery + using IP multicast. + + This library is free software; you can redistribute it and/or + modify it under the terms of the GNU Lesser General Public + License as published by the Free Software Foundation; either + version 2.1 of the License, or (at your option) any later version. + + This library is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + Lesser General Public License for more details. + + You should have received a copy of the GNU Lesser General Public + License along with this library; if not, write to the Free Software + Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 + USA +""" + +import enum +import socket +from typing import Any, Dict, Iterable, List, Optional, TYPE_CHECKING, Union, cast + +from ._exceptions import AbstractMethodException +from ._utils.net import _is_v6_address +from ._utils.time import current_time_millis, millis_to_seconds +from .const import ( + _CLASSES, + _CLASS_MASK, + _CLASS_UNIQUE, + _TYPES, + _TYPE_ANY, +) + +_LEN_BYTE = 1 +_LEN_SHORT = 2 +_LEN_INT = 4 + +_BASE_MAX_SIZE = _LEN_SHORT + _LEN_SHORT + _LEN_INT + _LEN_SHORT # type # class # ttl # length +_NAME_COMPRESSION_MIN_SIZE = _LEN_BYTE * 2 + +_EXPIRE_FULL_TIME_MS = 1000 +_EXPIRE_STALE_TIME_MS = 500 +_RECENT_TIME_MS = 250 + + +if TYPE_CHECKING: + from ._protocol.incoming import DNSIncoming + from ._protocol.outgoing import DNSOutgoing + + +@enum.unique +class DNSQuestionType(enum.Enum): + """An MDNS question type. + + "QU" - questions requesting unicast responses + "QM" - questions requesting multicast responses + https://datatracker.ietf.org/doc/html/rfc6762#section-5.4 + """ + + QU = 1 + QM = 2 + + +def dns_entry_matches(record: 'DNSEntry', key: str, type_: int, class_: int) -> bool: + return key == record.key and type_ == record.type and class_ == record.class_ + + +class DNSEntry: + + """A DNS entry""" + + __slots__ = ('key', 'name', 'type', 'class_', 'unique') + + def __init__(self, name: str, type_: int, class_: int) -> None: + self.key = name.lower() + self.name = name + self.type = type_ + self.class_ = class_ & _CLASS_MASK + self.unique = (class_ & _CLASS_UNIQUE) != 0 + + def __eq__(self, other: Any) -> bool: + """Equality test on key (lowercase name), type, and class""" + return dns_entry_matches(other, self.key, self.type, self.class_) and isinstance(other, DNSEntry) + + @staticmethod + def get_class_(class_: int) -> str: + """Class accessor""" + return _CLASSES.get(class_, f"?({class_})") + + @staticmethod + def get_type(t: int) -> str: + """Type accessor""" + return _TYPES.get(t, f"?({t})") + + def entry_to_string(self, hdr: str, other: Optional[Union[bytes, str]]) -> str: + """String representation with additional information""" + return "{}[{},{}{},{}]{}".format( + hdr, + self.get_type(self.type), + self.get_class_(self.class_), + "-unique" if self.unique else "", + self.name, + "=%s" % cast(Any, other) if other is not None else "", + ) + + +class DNSQuestion(DNSEntry): + + """A DNS question entry""" + + __slots__ = ('_hash',) + + def __init__(self, name: str, type_: int, class_: int) -> None: + super().__init__(name, type_, class_) + self._hash = hash((self.key, type_, self.class_)) + + def answered_by(self, rec: 'DNSRecord') -> bool: + """Returns true if the question is answered by the record""" + return self.class_ == rec.class_ and self.type in (rec.type, _TYPE_ANY) and self.name == rec.name + + def __hash__(self) -> int: + return self._hash + + def __eq__(self, other: Any) -> bool: + """Tests equality on dns question.""" + return isinstance(other, DNSQuestion) and dns_entry_matches(other, self.key, self.type, self.class_) + + @property + def max_size(self) -> int: + """Maximum size of the question in the packet.""" + return len(self.name.encode('utf-8')) + _LEN_BYTE + _LEN_SHORT + _LEN_SHORT # type # class + + @property + def unicast(self) -> bool: + """Returns true if the QU (not QM) is set. + + unique shares the same mask as the one + used for unicast. + """ + return self.unique + + @unicast.setter + def unicast(self, value: bool) -> None: + """Sets the QU bit (not QM).""" + self.unique = value + + def __repr__(self) -> str: + """String representation""" + return "{}[question,{},{},{}]".format( + self.get_type(self.type), + "QU" if self.unicast else "QM", + self.get_class_(self.class_), + self.name, + ) + + +class DNSRecord(DNSEntry): + + """A DNS record - like a DNS entry, but has a TTL""" + + __slots__ = ('ttl', 'created') + + # TODO: Switch to just int ttl + def __init__( + self, name: str, type_: int, class_: int, ttl: Union[float, int], created: Optional[float] = None + ) -> None: + super().__init__(name, type_, class_) + self.ttl = ttl + self.created = created or current_time_millis() + + def __eq__(self, other: Any) -> bool: # pylint: disable=no-self-use + """Abstract method""" + raise AbstractMethodException + + def suppressed_by(self, msg: 'DNSIncoming') -> bool: + """Returns true if any answer in a message can suffice for the + information held in this record.""" + return any(self.suppressed_by_answer(record) for record in msg.answers) + + def suppressed_by_answer(self, other: 'DNSRecord') -> bool: + """Returns true if another record has same name, type and class, + and if its TTL is at least half of this record's.""" + return self == other and other.ttl > (self.ttl / 2) + + def get_expiration_time(self, percent: int) -> float: + """Returns the time at which this record will have expired + by a certain percentage.""" + return self.created + (percent * self.ttl * 10) + + # TODO: Switch to just int here + def get_remaining_ttl(self, now: float) -> Union[int, float]: + """Returns the remaining TTL in seconds.""" + return max(0, millis_to_seconds((self.created + (_EXPIRE_FULL_TIME_MS * self.ttl)) - now)) + + def is_expired(self, now: float) -> bool: + """Returns true if this record has expired.""" + return self.created + (_EXPIRE_FULL_TIME_MS * self.ttl) <= now + + def is_stale(self, now: float) -> bool: + """Returns true if this record is at least half way expired.""" + return self.created + (_EXPIRE_STALE_TIME_MS * self.ttl) <= now + + def is_recent(self, now: float) -> bool: + """Returns true if the record more than one quarter of its TTL remaining.""" + return self.created + (_RECENT_TIME_MS * self.ttl) > now + + def reset_ttl(self, other: 'DNSRecord') -> None: + """Sets this record's TTL and created time to that of + another record.""" + self.set_created_ttl(other.created, other.ttl) + + def set_created_ttl(self, created: float, ttl: Union[float, int]) -> None: + """Set the created and ttl of a record.""" + self.created = created + self.ttl = ttl + + def write(self, out: 'DNSOutgoing') -> None: # pylint: disable=no-self-use + """Abstract method""" + raise AbstractMethodException + + def to_string(self, other: Union[bytes, str]) -> str: + """String representation with additional information""" + arg = f"{self.ttl}/{int(self.get_remaining_ttl(current_time_millis()))},{cast(Any, other)}" + return DNSEntry.entry_to_string(self, "record", arg) + + +class DNSAddress(DNSRecord): + + """A DNS address record""" + + __slots__ = ('_hash', 'address', 'scope_id') + + def __init__( + self, + name: str, + type_: int, + class_: int, + ttl: int, + address: bytes, + *, + scope_id: Optional[int] = None, + created: Optional[float] = None, + ) -> None: + super().__init__(name, type_, class_, ttl, created) + self.address = address + self.scope_id = scope_id + self._hash = hash((self.key, type_, self.class_, address, scope_id)) + + def write(self, out: 'DNSOutgoing') -> None: + """Used in constructing an outgoing packet""" + out.write_string(self.address) + + def __eq__(self, other: Any) -> bool: + """Tests equality on address""" + return ( + isinstance(other, DNSAddress) + and self.address == other.address + and self.scope_id == other.scope_id + and dns_entry_matches(other, self.key, self.type, self.class_) + ) + + def __hash__(self) -> int: + """Hash to compare like DNSAddresses.""" + return self._hash + + def __repr__(self) -> str: + """String representation""" + try: + return self.to_string( + socket.inet_ntop( + socket.AF_INET6 if _is_v6_address(self.address) else socket.AF_INET, self.address + ) + ) + except (ValueError, OSError): + return self.to_string(str(self.address)) + + +class DNSHinfo(DNSRecord): + + """A DNS host information record""" + + __slots__ = ('_hash', 'cpu', 'os') + + def __init__( + self, name: str, type_: int, class_: int, ttl: int, cpu: str, os: str, created: Optional[float] = None + ) -> None: + super().__init__(name, type_, class_, ttl, created) + self.cpu = cpu + self.os = os + self._hash = hash((self.key, type_, self.class_, cpu, os)) + + def write(self, out: 'DNSOutgoing') -> None: + """Used in constructing an outgoing packet""" + out.write_character_string(self.cpu.encode('utf-8')) + out.write_character_string(self.os.encode('utf-8')) + + def __eq__(self, other: Any) -> bool: + """Tests equality on cpu and os""" + return ( + isinstance(other, DNSHinfo) + and self.cpu == other.cpu + and self.os == other.os + and dns_entry_matches(other, self.key, self.type, self.class_) + ) + + def __hash__(self) -> int: + """Hash to compare like DNSHinfo.""" + return self._hash + + def __repr__(self) -> str: + """String representation""" + return self.to_string(self.cpu + " " + self.os) + + +class DNSPointer(DNSRecord): + + """A DNS pointer record""" + + __slots__ = ('_hash', 'alias') + + def __init__( + self, name: str, type_: int, class_: int, ttl: int, alias: str, created: Optional[float] = None + ) -> None: + super().__init__(name, type_, class_, ttl, created) + self.alias = alias + self._hash = hash((self.key, type_, self.class_, alias)) + + @property + def max_size_compressed(self) -> int: + """Maximum size of the record in the packet assuming the name has been compressed.""" + return ( + _BASE_MAX_SIZE + + _NAME_COMPRESSION_MIN_SIZE + + (len(self.alias) - len(self.name)) + + _NAME_COMPRESSION_MIN_SIZE + ) + + def write(self, out: 'DNSOutgoing') -> None: + """Used in constructing an outgoing packet""" + out.write_name(self.alias) + + def __eq__(self, other: Any) -> bool: + """Tests equality on alias""" + return ( + isinstance(other, DNSPointer) + and self.alias == other.alias + and dns_entry_matches(other, self.key, self.type, self.class_) + ) + + def __hash__(self) -> int: + """Hash to compare like DNSPointer.""" + return self._hash + + def __repr__(self) -> str: + """String representation""" + return self.to_string(self.alias) + + +class DNSText(DNSRecord): + + """A DNS text record""" + + __slots__ = ('_hash', 'text') + + def __init__( + self, name: str, type_: int, class_: int, ttl: int, text: bytes, created: Optional[float] = None + ) -> None: + assert isinstance(text, (bytes, type(None))) + super().__init__(name, type_, class_, ttl, created) + self.text = text + self._hash = hash((self.key, type_, self.class_, text)) + + def write(self, out: 'DNSOutgoing') -> None: + """Used in constructing an outgoing packet""" + out.write_string(self.text) + + def __hash__(self) -> int: + """Hash to compare like DNSText.""" + return self._hash + + def __eq__(self, other: Any) -> bool: + """Tests equality on text""" + return ( + isinstance(other, DNSText) + and self.text == other.text + and dns_entry_matches(other, self.key, self.type, self.class_) + ) + + def __repr__(self) -> str: + """String representation""" + if len(self.text) > 10: + return self.to_string(self.text[:7]) + "..." + return self.to_string(self.text) + + +class DNSService(DNSRecord): + + """A DNS service record""" + + __slots__ = ('_hash', 'priority', 'weight', 'port', 'server') + + def __init__( + self, + name: str, + type_: int, + class_: int, + ttl: Union[float, int], + priority: int, + weight: int, + port: int, + server: str, + created: Optional[float] = None, + ) -> None: + super().__init__(name, type_, class_, ttl, created) + self.priority = priority + self.weight = weight + self.port = port + self.server = server + self._hash = hash((self.key, type_, self.class_, priority, weight, port, server)) + + def write(self, out: 'DNSOutgoing') -> None: + """Used in constructing an outgoing packet""" + out.write_short(self.priority) + out.write_short(self.weight) + out.write_short(self.port) + out.write_name(self.server) + + def __eq__(self, other: Any) -> bool: + """Tests equality on priority, weight, port and server""" + return ( + isinstance(other, DNSService) + and self.priority == other.priority + and self.weight == other.weight + and self.port == other.port + and self.server == other.server + and dns_entry_matches(other, self.key, self.type, self.class_) + ) + + def __hash__(self) -> int: + """Hash to compare like DNSService.""" + return self._hash + + def __repr__(self) -> str: + """String representation""" + return self.to_string(f"{self.server}:{self.port}") + + +class DNSNsec(DNSRecord): + + """A DNS NSEC record""" + + __slots__ = ('_hash', 'next_name', 'rdtypes') + + def __init__( + self, + name: str, + type_: int, + class_: int, + ttl: int, + next_name: str, + rdtypes: List[int], + created: Optional[float] = None, + ) -> None: + super().__init__(name, type_, class_, ttl, created) + self.next_name = next_name + self.rdtypes = sorted(rdtypes) + self._hash = hash((self.key, type_, self.class_, next_name, *self.rdtypes)) + + def write(self, out: 'DNSOutgoing') -> None: + """Used in constructing an outgoing packet.""" + bitmap = bytearray(b'\0' * 32) + for rdtype in self.rdtypes: + if rdtype > 255: # mDNS only supports window 0 + continue + offset = rdtype % 256 + byte = offset // 8 + total_octets = byte + 1 + bitmap[byte] |= 0x80 >> (offset % 8) + out_bytes = bytes(bitmap[0:total_octets]) + out.write_name(self.next_name) + out.write_short(0) + out.write_short(len(out_bytes)) + out.write_string(out_bytes) + + def __eq__(self, other: Any) -> bool: + """Tests equality on cpu and os""" + return ( + isinstance(other, DNSNsec) + and self.next_name == other.next_name + and self.rdtypes == other.rdtypes + and dns_entry_matches(other, self.key, self.type, self.class_) + ) + + def __hash__(self) -> int: + """Hash to compare like DNSNSec.""" + return self._hash + + def __repr__(self) -> str: + """String representation""" + return self.to_string( + self.next_name + "," + "|".join([self.get_type(type_) for type_ in self.rdtypes]) + ) + + +class DNSRRSet: + """A set of dns records independent of the ttl.""" + + __slots__ = ('_records', '_lookup') + + def __init__(self, records: Iterable[DNSRecord]) -> None: + """Create an RRset from records.""" + self._records = records + self._lookup: Optional[Dict[DNSRecord, DNSRecord]] = None + + @property + def lookup(self) -> Dict[DNSRecord, DNSRecord]: + if self._lookup is None: + # Build the hash table so we can lookup the record independent of the ttl + self._lookup = {record: record for record in self._records} + return self._lookup + + def suppresses(self, record: DNSRecord) -> bool: + """Returns true if any answer in the rrset can suffice for the + information held in this record.""" + other = self.lookup.get(record) + return bool(other and other.ttl > (record.ttl / 2)) + + def __contains__(self, record: DNSRecord) -> bool: + """Returns true if the rrset contains the record.""" + return record in self.lookup diff --git a/zeroconf/_exceptions.py b/zeroconf/_exceptions.py new file mode 100644 index 00000000..02771140 --- /dev/null +++ b/zeroconf/_exceptions.py @@ -0,0 +1,49 @@ +""" Multicast DNS Service Discovery for Python, v0.14-wmcbrine + Copyright 2003 Paul Scott-Murphy, 2014 William McBrine + + This module provides a framework for the use of DNS Service Discovery + using IP multicast. + + This library is free software; you can redistribute it and/or + modify it under the terms of the GNU Lesser General Public + License as published by the Free Software Foundation; either + version 2.1 of the License, or (at your option) any later version. + + This library is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + Lesser General Public License for more details. + + You should have received a copy of the GNU Lesser General Public + License along with this library; if not, write to the Free Software + Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 + USA +""" + + +class Error(Exception): + pass + + +class IncomingDecodeError(Error): + pass + + +class NonUniqueNameException(Error): + pass + + +class NamePartTooLongException(Error): + pass + + +class AbstractMethodException(Error): + pass + + +class BadTypeInNameException(Error): + pass + + +class ServiceNameAlreadyRegistered(Error): + pass diff --git a/zeroconf/_handlers.py b/zeroconf/_handlers.py new file mode 100644 index 00000000..b4c31e2d --- /dev/null +++ b/zeroconf/_handlers.py @@ -0,0 +1,597 @@ +""" Multicast DNS Service Discovery for Python, v0.14-wmcbrine + Copyright 2003 Paul Scott-Murphy, 2014 William McBrine + + This module provides a framework for the use of DNS Service Discovery + using IP multicast. + + This library is free software; you can redistribute it and/or + modify it under the terms of the GNU Lesser General Public + License as published by the Free Software Foundation; either + version 2.1 of the License, or (at your option) any later version. + + This library is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + Lesser General Public License for more details. + + You should have received a copy of the GNU Lesser General Public + License along with this library; if not, write to the Free Software + Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 + USA +""" + +import itertools +import random +from collections import deque +from typing import Dict, Iterable, List, NamedTuple, Optional, Set, TYPE_CHECKING, Tuple, Union, cast + +from ._cache import DNSCache, _UniqueRecordsType +from ._dns import DNSAddress, DNSNsec, DNSPointer, DNSQuestion, DNSRRSet, DNSRecord +from ._history import QuestionHistory +from ._logger import log +from ._protocol.incoming import DNSIncoming +from ._protocol.outgoing import DNSOutgoing +from ._services.info import ServiceInfo +from ._services.registry import ServiceRegistry +from ._updates import RecordUpdate, RecordUpdateListener +from ._utils.time import current_time_millis, millis_to_seconds +from .const import ( + _CLASS_IN, + _CLASS_UNIQUE, + _DNS_OTHER_TTL, + _DNS_PTR_MIN_TTL, + _FLAGS_AA, + _FLAGS_QR_RESPONSE, + _ONE_SECOND, + _SERVICE_TYPE_ENUMERATION_NAME, + _TYPE_A, + _TYPE_AAAA, + _TYPE_ANY, + _TYPE_NSEC, + _TYPE_PTR, + _TYPE_SRV, + _TYPE_TXT, +) + +if TYPE_CHECKING: + from ._core import Zeroconf + + +_AnswerWithAdditionalsType = Dict[DNSRecord, Set[DNSRecord]] + +_MULTICAST_DELAY_RANDOM_INTERVAL = (20, 120) +_ADDRESS_RECORD_TYPES = {_TYPE_A, _TYPE_AAAA} +_RESPOND_IMMEDIATE_TYPES = {_TYPE_NSEC, _TYPE_SRV, *_ADDRESS_RECORD_TYPES} + + +class QuestionAnswers(NamedTuple): + ucast: _AnswerWithAdditionalsType + mcast_now: _AnswerWithAdditionalsType + mcast_aggregate: _AnswerWithAdditionalsType + mcast_aggregate_last_second: _AnswerWithAdditionalsType + + +class AnswerGroup(NamedTuple): + """A group of answers scheduled to be sent at the same time.""" + + send_after: float # Must be sent after this time + send_before: float # Must be sent before this time + answers: _AnswerWithAdditionalsType + + +def _message_is_probe(msg: DNSIncoming) -> bool: + return msg.num_authorities > 0 + + +def construct_nsec_record(name: str, types: List[int], now: float) -> DNSNsec: + """Construct an NSEC record for name and a list of dns types. + + This function should only be used for SRV/A/AAAA records + which have a TTL of _DNS_OTHER_TTL + """ + return DNSNsec(name, _TYPE_NSEC, _CLASS_IN | _CLASS_UNIQUE, _DNS_OTHER_TTL, name, types, created=now) + + +def construct_outgoing_multicast_answers(answers: _AnswerWithAdditionalsType) -> DNSOutgoing: + """Add answers and additionals to a DNSOutgoing.""" + out = DNSOutgoing(_FLAGS_QR_RESPONSE | _FLAGS_AA, multicast=True) + _add_answers_additionals(out, answers) + return out + + +def construct_outgoing_unicast_answers( + answers: _AnswerWithAdditionalsType, ucast_source: bool, questions: List[DNSQuestion], id_: int +) -> DNSOutgoing: + """Add answers and additionals to a DNSOutgoing.""" + out = DNSOutgoing(_FLAGS_QR_RESPONSE | _FLAGS_AA, multicast=False, id_=id_) + # Adding the questions back when the source is legacy unicast behavior + if ucast_source: + for question in questions: + out.add_question(question) + _add_answers_additionals(out, answers) + return out + + +def _add_answers_additionals(out: DNSOutgoing, answers: _AnswerWithAdditionalsType) -> None: + # Find additionals and suppress any additionals that are already in answers + sending: Set[DNSRecord] = set(answers.keys()) + # Answers are sorted to group names together to increase the chance + # that similar names will end up in the same packet and can reduce the + # overall size of the outgoing response via name compression + for answer, additionals in sorted(answers.items(), key=lambda kv: kv[0].name): + out.add_answer_at_time(answer, 0) + for additional in additionals: + if additional not in sending: + out.add_additional_answer(additional) + sending.add(additional) + + +def sanitize_incoming_record(record: DNSRecord) -> None: + """Protect zeroconf from records that can cause denial of service. + + We enforce a minimum TTL for PTR records to avoid + ServiceBrowsers generating excessive queries refresh queries. + Apple uses a 15s minimum TTL, however we do not have the same + level of rate limit and safe guards so we use 1/4 of the recommended value. + """ + if record.ttl and record.ttl < _DNS_PTR_MIN_TTL and isinstance(record, DNSPointer): + log.debug( + "Increasing effective ttl of %s to minimum of %s to protect against excessive refreshes.", + record, + _DNS_PTR_MIN_TTL, + ) + record.set_created_ttl(record.created, _DNS_PTR_MIN_TTL) + + +class _QueryResponse: + """A pair for unicast and multicast DNSOutgoing responses.""" + + def __init__(self, cache: DNSCache, msgs: List[DNSIncoming]) -> None: + """Build a query response.""" + self._is_probe = any(_message_is_probe(msg) for msg in msgs) + self._msg = msgs[0] + self._now = self._msg.now + self._cache = cache + self._additionals: _AnswerWithAdditionalsType = {} + self._ucast: Set[DNSRecord] = set() + self._mcast_now: Set[DNSRecord] = set() + self._mcast_aggregate: Set[DNSRecord] = set() + self._mcast_aggregate_last_second: Set[DNSRecord] = set() + + def add_qu_question_response(self, answers: _AnswerWithAdditionalsType) -> None: + """Generate a response to a multicast QU query.""" + for record, additionals in answers.items(): + self._additionals[record] = additionals + if self._is_probe: + self._ucast.add(record) + if not self._has_mcast_within_one_quarter_ttl(record): + self._mcast_now.add(record) + elif not self._is_probe: + self._ucast.add(record) + + def add_ucast_question_response(self, answers: _AnswerWithAdditionalsType) -> None: + """Generate a response to a unicast query.""" + self._additionals.update(answers) + self._ucast.update(answers.keys()) + + def add_mcast_question_response(self, answers: _AnswerWithAdditionalsType) -> None: + """Generate a response to a multicast query.""" + self._additionals.update(answers) + for answer in answers: + if self._is_probe: + self._mcast_now.add(answer) + continue + + if self._has_mcast_record_in_last_second(answer): + self._mcast_aggregate_last_second.add(answer) + elif len(self._msg.questions) == 1 and self._msg.questions[0].type in _RESPOND_IMMEDIATE_TYPES: + self._mcast_now.add(answer) + else: + self._mcast_aggregate.add(answer) + + def _generate_answers_with_additionals(self, rrset: Set[DNSRecord]) -> _AnswerWithAdditionalsType: + """Create answers with additionals from an rrset.""" + return {record: self._additionals[record] for record in rrset} + + def answers( + self, + ) -> QuestionAnswers: + """Return answer sets that will be queued.""" + return QuestionAnswers( + self._generate_answers_with_additionals(self._ucast), + self._generate_answers_with_additionals(self._mcast_now), + self._generate_answers_with_additionals(self._mcast_aggregate), + self._generate_answers_with_additionals(self._mcast_aggregate_last_second), + ) + + def _has_mcast_within_one_quarter_ttl(self, record: DNSRecord) -> bool: + """Check to see if a record has been mcasted recently. + + https://datatracker.ietf.org/doc/html/rfc6762#section-5.4 + When receiving a question with the unicast-response bit set, a + responder SHOULD usually respond with a unicast packet directed back + to the querier. However, if the responder has not multicast that + record recently (within one quarter of its TTL), then the responder + SHOULD instead multicast the response so as to keep all the peer + caches up to date + """ + maybe_entry = self._cache.async_get_unique(cast(_UniqueRecordsType, record)) + return bool(maybe_entry and maybe_entry.is_recent(self._now)) + + def _has_mcast_record_in_last_second(self, record: DNSRecord) -> bool: + """Check if an answer was seen in the last second. + Protect the network against excessive packet flooding + https://datatracker.ietf.org/doc/html/rfc6762#section-14 + """ + maybe_entry = self._cache.async_get_unique(cast(_UniqueRecordsType, record)) + return bool(maybe_entry and self._now - maybe_entry.created < _ONE_SECOND) + + +def _get_address_and_nsec_records(service: ServiceInfo, now: float) -> Set[DNSRecord]: + """Build a set of address records and NSEC records for non-present record types.""" + seen_types: Set[int] = set() + records: Set[DNSRecord] = set() + for dns_address in service.dns_addresses(created=now): + seen_types.add(dns_address.type) + records.add(dns_address) + missing_types: Set[int] = _ADDRESS_RECORD_TYPES - seen_types + if missing_types: + records.add(construct_nsec_record(service.server, list(missing_types), now)) + return records + + +class QueryHandler: + """Query the ServiceRegistry.""" + + def __init__(self, registry: ServiceRegistry, cache: DNSCache, question_history: QuestionHistory) -> None: + """Init the query handler.""" + self.registry = registry + self.cache = cache + self.question_history = question_history + + def _add_service_type_enumeration_query_answers( + self, answer_set: _AnswerWithAdditionalsType, known_answers: DNSRRSet, now: float + ) -> None: + """Provide an answer to a service type enumeration query. + + https://datatracker.ietf.org/doc/html/rfc6763#section-9 + """ + for stype in self.registry.async_get_types(): + dns_pointer = DNSPointer( + _SERVICE_TYPE_ENUMERATION_NAME, _TYPE_PTR, _CLASS_IN, _DNS_OTHER_TTL, stype, now + ) + if not known_answers.suppresses(dns_pointer): + answer_set[dns_pointer] = set() + + def _add_pointer_answers( + self, name: str, answer_set: _AnswerWithAdditionalsType, known_answers: DNSRRSet, now: float + ) -> None: + """Answer PTR/ANY question.""" + for service in self.registry.async_get_infos_type(name): + # Add recommended additional answers according to + # https://tools.ietf.org/html/rfc6763#section-12.1. + dns_pointer = service.dns_pointer(created=now) + if known_answers.suppresses(dns_pointer): + continue + additionals: Set[DNSRecord] = {service.dns_service(created=now), service.dns_text(created=now)} + additionals |= _get_address_and_nsec_records(service, now) + answer_set[dns_pointer] = additionals + + def _add_address_answers( + self, + name: str, + answer_set: _AnswerWithAdditionalsType, + known_answers: DNSRRSet, + now: float, + type_: int, + ) -> None: + """Answer A/AAAA/ANY question.""" + for service in self.registry.async_get_infos_server(name): + answers: List[DNSAddress] = [] + additionals: Set[DNSRecord] = set() + seen_types: Set[int] = set() + for dns_address in service.dns_addresses(created=now): + seen_types.add(dns_address.type) + if dns_address.type != type_: + additionals.add(dns_address) + elif not known_answers.suppresses(dns_address): + answers.append(dns_address) + missing_types: Set[int] = _ADDRESS_RECORD_TYPES - seen_types + if answers: + if missing_types: + additionals.add(construct_nsec_record(service.server, list(missing_types), now)) + for answer in answers: + answer_set[answer] = additionals + elif type_ in missing_types: + answer_set[construct_nsec_record(service.server, list(missing_types), now)] = set() + + def _answer_question( + self, + question: DNSQuestion, + known_answers: DNSRRSet, + now: float, + ) -> _AnswerWithAdditionalsType: + answer_set: _AnswerWithAdditionalsType = {} + + if question.type == _TYPE_PTR and question.name.lower() == _SERVICE_TYPE_ENUMERATION_NAME: + self._add_service_type_enumeration_query_answers(answer_set, known_answers, now) + return answer_set + + type_ = question.type + + if type_ in (_TYPE_PTR, _TYPE_ANY): + self._add_pointer_answers(question.name, answer_set, known_answers, now) + + if type_ in (_TYPE_A, _TYPE_AAAA, _TYPE_ANY): + self._add_address_answers(question.name, answer_set, known_answers, now, type_) + + if type_ in (_TYPE_SRV, _TYPE_TXT, _TYPE_ANY): + service = self.registry.async_get_info_name(question.name) # type: ignore + if service is not None: + if type_ in (_TYPE_SRV, _TYPE_ANY): + # Add recommended additional answers according to + # https://tools.ietf.org/html/rfc6763#section-12.2. + dns_service = service.dns_service(created=now) + if not known_answers.suppresses(dns_service): + answer_set[dns_service] = _get_address_and_nsec_records(service, now) + if type_ in (_TYPE_TXT, _TYPE_ANY): + dns_text = service.dns_text(created=now) + if not known_answers.suppresses(dns_text): + answer_set[dns_text] = set() + + return answer_set + + def async_response( # pylint: disable=unused-argument + self, msgs: List[DNSIncoming], ucast_source: bool + ) -> QuestionAnswers: + """Deal with incoming query packets. Provides a response if possible. + + This function must be run in the event loop as it is not + threadsafe. + """ + known_answers = DNSRRSet( + itertools.chain.from_iterable(msg.answers for msg in msgs if not _message_is_probe(msg)) + ) + query_res = _QueryResponse(self.cache, msgs) + + for msg in msgs: + for question in msg.questions: + if not question.unicast: + self.question_history.add_question_at_time(question, msg.now, set(known_answers.lookup)) + answer_set = self._answer_question(question, known_answers, msg.now) + if not ucast_source and question.unicast: + query_res.add_qu_question_response(answer_set) + continue + if ucast_source: + query_res.add_ucast_question_response(answer_set) + # We always multicast as well even if its a unicast + # source as long as we haven't done it recently (75% of ttl) + query_res.add_mcast_question_response(answer_set) + + return query_res.answers() + + +class RecordManager: + """Process records into the cache and notify listeners.""" + + def __init__(self, zeroconf: 'Zeroconf') -> None: + """Init the record manager.""" + self.zc = zeroconf + self.cache = zeroconf.cache + self.listeners: List[RecordUpdateListener] = [] + + def async_updates(self, now: float, records: List[RecordUpdate]) -> None: + """Used to notify listeners of new information that has updated + a record. + + This method must be called before the cache is updated. + + This method will be run in the event loop. + """ + for listener in self.listeners: + listener.async_update_records(self.zc, now, records) + + def async_updates_complete(self) -> None: + """Used to notify listeners of new information that has updated + a record. + + This method must be called after the cache is updated. + + This method will be run in the event loop. + """ + for listener in self.listeners: + listener.async_update_records_complete() + self.zc.async_notify_all() + + def async_updates_from_response(self, msg: DNSIncoming) -> None: + """Deal with incoming response packets. All answers + are held in the cache, and listeners are notified. + + This function must be run in the event loop as it is not + threadsafe. + """ + updates: List[RecordUpdate] = [] + address_adds: List[DNSAddress] = [] + other_adds: List[DNSRecord] = [] + removes: Set[DNSRecord] = set() + now = msg.now + unique_types: Set[Tuple[str, int, int]] = set() + + for record in msg.answers: + sanitize_incoming_record(record) + + if record.unique: # https://tools.ietf.org/html/rfc6762#section-10.2 + unique_types.add((record.name, record.type, record.class_)) + + maybe_entry = self.cache.async_get_unique(cast(_UniqueRecordsType, record)) + if not record.is_expired(now): + if maybe_entry is not None: + maybe_entry.reset_ttl(record) + else: + if isinstance(record, DNSAddress): + address_adds.append(record) + else: + other_adds.append(record) + updates.append(RecordUpdate(record, maybe_entry)) + # This is likely a goodbye since the record is + # expired and exists in the cache + elif maybe_entry is not None: + updates.append(RecordUpdate(record, maybe_entry)) + removes.add(record) + + if unique_types: + self._async_mark_unique_cached_records_older_than_1s_to_expire(unique_types, msg.answers, now) + + if updates: + self.async_updates(now, updates) + # The cache adds must be processed AFTER we trigger + # the updates since we compare existing data + # with the new data and updating the cache + # ahead of update_record will cause listeners + # to miss changes + # + # We must process address adds before non-addresses + # otherwise a fetch of ServiceInfo may miss an address + # because it thinks the cache is complete + # + # The cache is processed under the context manager to ensure + # that any ServiceBrowser that is going to call + # zc.get_service_info will see the cached value + # but ONLY after all the record updates have been + # processsed. + if other_adds or address_adds: + self.cache.async_add_records(itertools.chain(address_adds, other_adds)) + # Removes are processed last since + # ServiceInfo could generate an un-needed query + # because the data was not yet populated. + if removes: + self.cache.async_remove_records(removes) + if updates: + self.async_updates_complete() + + def _async_mark_unique_cached_records_older_than_1s_to_expire( + self, unique_types: Set[Tuple[str, int, int]], answers: Iterable[DNSRecord], now: float + ) -> None: + # rfc6762#section-10.2 para 2 + # Since unique is set, all old records with that name, rrtype, + # and rrclass that were received more than one second ago are declared + # invalid, and marked to expire from the cache in one second. + answers_rrset = DNSRRSet(answers) + for name, type_, class_ in unique_types: + for entry in self.cache.async_all_by_details(name, type_, class_): + if (now - entry.created > _ONE_SECOND) and entry not in answers_rrset: + # Expire in 1s + entry.set_created_ttl(now, 1) + + def async_add_listener( + self, listener: RecordUpdateListener, question: Optional[Union[DNSQuestion, List[DNSQuestion]]] + ) -> None: + """Adds a listener for a given question. The listener will have + its update_record method called when information is available to + answer the question(s). + + This function is not threadsafe and must be called in the eventloop. + """ + self.listeners.append(listener) + + if question is None: + return + + questions = [question] if isinstance(question, DNSQuestion) else question + assert self.zc.loop is not None + self._async_update_matching_records(listener, questions) + + def _async_update_matching_records( + self, listener: RecordUpdateListener, questions: List[DNSQuestion] + ) -> None: + """Calls back any existing entries in the cache that answer the question. + + This function must be run from the event loop. + """ + now = current_time_millis() + records: List[RecordUpdate] = [] + for question in questions: + for record in self.cache.async_entries_with_name(question.name): + if not record.is_expired(now) and question.answered_by(record): + records.append(RecordUpdate(record, None)) + + if not records: + return + listener.async_update_records(self.zc, now, records) + listener.async_update_records_complete() + self.zc.async_notify_all() + + def async_remove_listener(self, listener: RecordUpdateListener) -> None: + """Removes a listener. + + This function is not threadsafe and must be called in the eventloop. + """ + try: + self.listeners.remove(listener) + self.zc.async_notify_all() + except ValueError as e: + log.exception('Failed to remove listener: %r', e) + + +class MulticastOutgoingQueue: + """An outgoing queue used to aggregate multicast responses.""" + + def __init__(self, zeroconf: 'Zeroconf', additional_delay: int, max_aggregation_delay: int) -> None: + self.zc = zeroconf + self.queue: deque = deque() + # Additional delay is used to implement + # Protect the network against excessive packet flooding + # https://datatracker.ietf.org/doc/html/rfc6762#section-14 + self.additional_delay = additional_delay + self.aggregation_delay = max_aggregation_delay + + def async_add(self, now: float, answers: _AnswerWithAdditionalsType) -> None: + """Add a group of answers with additionals to the outgoing queue.""" + assert self.zc.loop is not None + random_delay = random.randint(*_MULTICAST_DELAY_RANDOM_INTERVAL) + self.additional_delay + send_after = now + random_delay + send_before = now + self.aggregation_delay + self.additional_delay + if len(self.queue): + # If we calculate a random delay for the send after time + # that is less than the last group scheduled to go out, + # we instead add the answers to the last group as this + # allows aggregating additonal responses + last_group = self.queue[-1] + if send_after <= last_group.send_after: + last_group.answers.update(answers) + return + else: + self.zc.loop.call_later(millis_to_seconds(random_delay), self.async_ready) + self.queue.append(AnswerGroup(send_after, send_before, answers)) + + def _remove_answers_from_queue(self, answers: _AnswerWithAdditionalsType) -> None: + """Remove a set of answers from the outgoing queue.""" + for pending in self.queue: + for record in answers: + pending.answers.pop(record, None) + + def async_ready(self) -> None: + """Process anything in the queue that is ready.""" + assert self.zc.loop is not None + now = current_time_millis() + + if len(self.queue) > 1 and self.queue[0].send_before > now: + # There is more than one answer in the queue, + # delay until we have to send it (first answer group reaches send_before) + self.zc.loop.call_later(millis_to_seconds(self.queue[0].send_before - now), self.async_ready) + return + + answers: _AnswerWithAdditionalsType = {} + # Add all groups that can be sent now + while len(self.queue) and self.queue[0].send_after <= now: + answers.update(self.queue.popleft().answers) + + if len(self.queue): + # If there are still groups in the queue that are not ready to send + # be sure we schedule them to go out later + self.zc.loop.call_later(millis_to_seconds(self.queue[0].send_after - now), self.async_ready) + + if answers: + # If we have the same answer scheduled to go out, remove them + self._remove_answers_from_queue(answers) + self.zc.async_send(construct_outgoing_multicast_answers(answers)) diff --git a/zeroconf/_history.py b/zeroconf/_history.py new file mode 100644 index 00000000..cbb36144 --- /dev/null +++ b/zeroconf/_history.py @@ -0,0 +1,70 @@ +""" Multicast DNS Service Discovery for Python, v0.14-wmcbrine + Copyright 2003 Paul Scott-Murphy, 2014 William McBrine + + This module provides a framework for the use of DNS Service Discovery + using IP multicast. + + This library is free software; you can redistribute it and/or + modify it under the terms of the GNU Lesser General Public + License as published by the Free Software Foundation; either + version 2.1 of the License, or (at your option) any later version. + + This library is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + Lesser General Public License for more details. + + You should have received a copy of the GNU Lesser General Public + License along with this library; if not, write to the Free Software + Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 + USA +""" + +from typing import Dict, Set, Tuple + +from ._dns import DNSQuestion, DNSRecord +from .const import _DUPLICATE_QUESTION_INTERVAL + +# The QuestionHistory is used to implement Duplicate Question Suppression +# https://datatracker.ietf.org/doc/html/rfc6762#section-7.3 + + +class QuestionHistory: + def __init__(self) -> None: + self._history: Dict[DNSQuestion, Tuple[float, Set[DNSRecord]]] = {} + + def add_question_at_time(self, question: DNSQuestion, now: float, known_answers: Set[DNSRecord]) -> None: + """Remember a question with known answers.""" + self._history[question] = (now, known_answers) + + def suppresses(self, question: DNSQuestion, now: float, known_answers: Set[DNSRecord]) -> bool: + """Check to see if a question should be suppressed. + + https://datatracker.ietf.org/doc/html/rfc6762#section-7.3 + When multiple queriers on the network are querying + for the same resource records, there is no need for them to all be + repeatedly asking the same question. + """ + previous_question = self._history.get(question) + # There was not previous question in the history + if not previous_question: + return False + than, previous_known_answers = previous_question + # The last question was older than 999ms + if now - than > _DUPLICATE_QUESTION_INTERVAL: + return False + # The last question has more known answers than + # we knew so we have to ask + if previous_known_answers - known_answers: + return False + return True + + def async_expire(self, now: float) -> None: + """Expire the history of old questions.""" + removes = [ + question + for question, now_known_answers in self._history.items() + if now - now_known_answers[0] > _DUPLICATE_QUESTION_INTERVAL + ] + for question in removes: + del self._history[question] diff --git a/zeroconf/_logger.py b/zeroconf/_logger.py new file mode 100644 index 00000000..932d1a2f --- /dev/null +++ b/zeroconf/_logger.py @@ -0,0 +1,74 @@ +""" Multicast DNS Service Discovery for Python, v0.14-wmcbrine + Copyright 2003 Paul Scott-Murphy, 2014 William McBrine + + This module provides a framework for the use of DNS Service Discovery + using IP multicast. + + This library is free software; you can redistribute it and/or + modify it under the terms of the GNU Lesser General Public + License as published by the Free Software Foundation; either + version 2.1 of the License, or (at your option) any later version. + + This library is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + Lesser General Public License for more details. + + You should have received a copy of the GNU Lesser General Public + License along with this library; if not, write to the Free Software + Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 + USA +""" + +import logging +import sys +from typing import Any, Dict, Union, cast + +log = logging.getLogger(__name__.split('.', maxsplit=1)[0]) +log.addHandler(logging.NullHandler()) + + +def set_logger_level_if_unset() -> None: + if log.level == logging.NOTSET: + log.setLevel(logging.WARN) + + +set_logger_level_if_unset() + + +class QuietLogger: + _seen_logs: Dict[str, Union[int, tuple]] = {} + + @classmethod + def log_exception_warning(cls, *logger_data: Any) -> None: + exc_info = sys.exc_info() + exc_str = str(exc_info[1]) + if exc_str not in cls._seen_logs: + # log at warning level the first time this is seen + cls._seen_logs[exc_str] = exc_info + logger = log.warning + else: + logger = log.debug + logger(*(logger_data or ['Exception occurred']), exc_info=True) + + @classmethod + def log_warning_once(cls, *args: Any) -> None: + msg_str = args[0] + if msg_str not in cls._seen_logs: + cls._seen_logs[msg_str] = 0 + logger = log.warning + else: + logger = log.debug + cls._seen_logs[msg_str] = cast(int, cls._seen_logs[msg_str]) + 1 + logger(*args) + + @classmethod + def log_exception_once(cls, exc: Exception, *args: Any) -> None: + msg_str = args[0] + if msg_str not in cls._seen_logs: + cls._seen_logs[msg_str] = 0 + logger = log.warning + else: + logger = log.debug + cls._seen_logs[msg_str] = cast(int, cls._seen_logs[msg_str]) + 1 + logger(*args, exc_info=exc) diff --git a/zeroconf/_protocol/__init__.py b/zeroconf/_protocol/__init__.py new file mode 100644 index 00000000..360b599d --- /dev/null +++ b/zeroconf/_protocol/__init__.py @@ -0,0 +1,51 @@ +""" Multicast DNS Service Discovery for Python, v0.14-wmcbrine + Copyright 2003 Paul Scott-Murphy, 2014 William McBrine + + This module provides a framework for the use of DNS Service Discovery + using IP multicast. + + This library is free software; you can redistribute it and/or + modify it under the terms of the GNU Lesser General Public + License as published by the Free Software Foundation; either + version 2.1 of the License, or (at your option) any later version. + + This library is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + Lesser General Public License for more details. + + You should have received a copy of the GNU Lesser General Public + License along with this library; if not, write to the Free Software + Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 + USA +""" + +from ..const import ( + _FLAGS_QR_MASK, + _FLAGS_QR_QUERY, + _FLAGS_QR_RESPONSE, + _FLAGS_TC, +) + + +class DNSMessage: + """A base class for DNS messages.""" + + __slots__ = ('flags',) + + def __init__(self, flags: int) -> None: + """Construct a DNS message.""" + self.flags = flags + + def is_query(self) -> bool: + """Returns true if this is a query.""" + return (self.flags & _FLAGS_QR_MASK) == _FLAGS_QR_QUERY + + def is_response(self) -> bool: + """Returns true if this is a response.""" + return (self.flags & _FLAGS_QR_MASK) == _FLAGS_QR_RESPONSE + + @property + def truncated(self) -> bool: + """Returns true if this is a truncated.""" + return (self.flags & _FLAGS_TC) == _FLAGS_TC diff --git a/zeroconf/_protocol/incoming.py b/zeroconf/_protocol/incoming.py new file mode 100644 index 00000000..6d7a6153 --- /dev/null +++ b/zeroconf/_protocol/incoming.py @@ -0,0 +1,315 @@ +""" Multicast DNS Service Discovery for Python, v0.14-wmcbrine + Copyright 2003 Paul Scott-Murphy, 2014 William McBrine + + This module provides a framework for the use of DNS Service Discovery + using IP multicast. + + This library is free software; you can redistribute it and/or + modify it under the terms of the GNU Lesser General Public + License as published by the Free Software Foundation; either + version 2.1 of the License, or (at your option) any later version. + + This library is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + Lesser General Public License for more details. + + You should have received a copy of the GNU Lesser General Public + License along with this library; if not, write to the Free Software + Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 + USA +""" + +import struct +from typing import Callable, Dict, List, Optional, Set, Tuple, cast + +from . import DNSMessage +from .._dns import DNSAddress, DNSHinfo, DNSNsec, DNSPointer, DNSQuestion, DNSRecord, DNSService, DNSText +from .._exceptions import IncomingDecodeError +from .._logger import QuietLogger, log +from .._utils.time import current_time_millis +from ..const import ( + _TYPES, + _TYPE_A, + _TYPE_AAAA, + _TYPE_CNAME, + _TYPE_HINFO, + _TYPE_NSEC, + _TYPE_PTR, + _TYPE_SRV, + _TYPE_TXT, +) + +DNS_COMPRESSION_HEADER_LEN = 1 +DNS_COMPRESSION_POINTER_LEN = 2 +MAX_DNS_LABELS = 128 +MAX_NAME_LENGTH = 253 + +DECODE_EXCEPTIONS = (IndexError, struct.error, IncomingDecodeError) + + +class DNSIncoming(DNSMessage, QuietLogger): + + """Object representation of an incoming DNS packet""" + + __slots__ = ( + 'offset', + 'data', + 'data_len', + 'name_cache', + 'questions', + '_answers', + 'id', + 'num_questions', + 'num_answers', + 'num_authorities', + 'num_additionals', + 'valid', + 'now', + 'scope_id', + 'source', + ) + + def __init__( + self, + data: bytes, + source: Optional[Tuple[str, int]] = None, + scope_id: Optional[int] = None, + now: Optional[float] = None, + ) -> None: + """Constructor from string holding bytes of packet""" + super().__init__(0) + self.offset = 0 + self.data = data + self.data_len = len(data) + self.name_cache: Dict[int, List[str]] = {} + self.questions: List[DNSQuestion] = [] + self._answers: List[DNSRecord] = [] + self.id = 0 + self.num_questions = 0 + self.num_answers = 0 + self.num_authorities = 0 + self.num_additionals = 0 + self.valid = False + self._read_others = False + self.now = now or current_time_millis() + self.source = source + self.scope_id = scope_id + self._parse_data(self._initial_parse) + + def _initial_parse(self) -> None: + """Parse the data needed to initalize the packet object.""" + self.read_header() + self.read_questions() + if not self.num_questions: + self.read_others() + self.valid = True + + def _parse_data(self, parser_call: Callable) -> None: + """Parse part of the packet and catch exceptions.""" + try: + parser_call() + except DECODE_EXCEPTIONS: + self.log_exception_warning( + 'Received invalid packet from %s at offset %d while unpacking %r', + self.source, + self.offset, + self.data, + ) + + @property + def answers(self) -> List[DNSRecord]: + """Answers in the packet.""" + if not self._read_others: + self._parse_data(self.read_others) + return self._answers + + def __repr__(self) -> str: + return '' % ', '.join( + [ + 'id=%s' % self.id, + 'flags=%s' % self.flags, + 'truncated=%s' % self.truncated, + 'n_q=%s' % self.num_questions, + 'n_ans=%s' % self.num_answers, + 'n_auth=%s' % self.num_authorities, + 'n_add=%s' % self.num_additionals, + 'questions=%s' % self.questions, + 'answers=%s' % self.answers, + ] + ) + + def unpack(self, format_: bytes, length: int) -> tuple: + self.offset += length + return struct.unpack(format_, self.data[self.offset - length : self.offset]) + + def read_header(self) -> None: + """Reads header portion of packet""" + ( + self.id, + self.flags, + self.num_questions, + self.num_answers, + self.num_authorities, + self.num_additionals, + ) = self.unpack(b'!6H', 12) + + def read_questions(self) -> None: + """Reads questions section of packet""" + self.questions = [ + DNSQuestion(self.read_name(), *self.unpack(b'!HH', 4)) for _ in range(self.num_questions) + ] + + def read_character_string(self) -> bytes: + """Reads a character string from the packet""" + length = self.data[self.offset] + self.offset += 1 + return self.read_string(length) + + def read_string(self, length: int) -> bytes: + """Reads a string of a given length from the packet""" + info = self.data[self.offset : self.offset + length] + self.offset += length + return info + + def read_unsigned_short(self) -> int: + """Reads an unsigned short from the packet""" + return cast(int, self.unpack(b'!H', 2)[0]) + + def read_others(self) -> None: + """Reads the answers, authorities and additionals section of the + packet""" + self._read_others = True + n = self.num_answers + self.num_authorities + self.num_additionals + for _ in range(n): + domain = self.read_name() + type_, class_, ttl, length = self.unpack(b'!HHiH', 10) + end = self.offset + length + rec = None + try: + rec = self.read_record(domain, type_, class_, ttl, length) + except DECODE_EXCEPTIONS: + # Skip records that fail to decode if we know the length + # If the packet is really corrupt read_name and the unpack + # above would fail and hit the exception catch in read_others + self.offset = end + log.debug( + 'Unable to parse; skipping record for %s with type %s at offset %d while unpacking %r', + domain, + _TYPES.get(type_, type_), + self.offset, + self.data, + exc_info=True, + ) + if rec is not None: + self._answers.append(rec) + + def read_record(self, domain: str, type_: int, class_: int, ttl: int, length: int) -> Optional[DNSRecord]: + """Read known records types and skip unknown ones.""" + if type_ == _TYPE_A: + return DNSAddress(domain, type_, class_, ttl, self.read_string(4), created=self.now) + if type_ in (_TYPE_CNAME, _TYPE_PTR): + return DNSPointer(domain, type_, class_, ttl, self.read_name(), self.now) + if type_ == _TYPE_TXT: + return DNSText(domain, type_, class_, ttl, self.read_string(length), self.now) + if type_ == _TYPE_SRV: + return DNSService( + domain, + type_, + class_, + ttl, + self.read_unsigned_short(), + self.read_unsigned_short(), + self.read_unsigned_short(), + self.read_name(), + self.now, + ) + if type_ == _TYPE_HINFO: + return DNSHinfo( + domain, + type_, + class_, + ttl, + self.read_character_string().decode('utf-8'), + self.read_character_string().decode('utf-8'), + self.now, + ) + if type_ == _TYPE_AAAA: + return DNSAddress( + domain, type_, class_, ttl, self.read_string(16), created=self.now, scope_id=self.scope_id + ) + if type_ == _TYPE_NSEC: + name_start = self.offset + return DNSNsec( + domain, + type_, + class_, + ttl, + self.read_name(), + self.read_bitmap(name_start + length), + self.now, + ) + # Try to ignore types we don't know about + # Skip the payload for the resource record so the next + # records can be parsed correctly + self.offset += length + return None + + def read_bitmap(self, end: int) -> List[int]: + """Reads an NSEC bitmap from the packet.""" + rdtypes = [] + while self.offset < end: + window = self.data[self.offset] + bitmap_length = self.data[self.offset + 1] + for i, byte in enumerate(self.data[self.offset + 2 : self.offset + 2 + bitmap_length]): + for bit in range(0, 8): + if byte & (0x80 >> bit): + rdtypes.append(bit + window * 256 + i * 8) + self.offset += 2 + bitmap_length + return rdtypes + + def read_name(self) -> str: + """Reads a domain name from the packet.""" + labels: List[str] = [] + seen_pointers: Set[int] = set() + self.offset = self._decode_labels_at_offset(self.offset, labels, seen_pointers) + name = ".".join(labels) + "." + if len(name) > MAX_NAME_LENGTH: + raise IncomingDecodeError(f"DNS name {name} exceeds maximum length of {MAX_NAME_LENGTH}") + return name + + def _decode_labels_at_offset(self, off: int, labels: List[str], seen_pointers: Set[int]) -> int: + # This is a tight loop that is called frequently, small optimizations can make a difference. + while off < self.data_len: + length = self.data[off] + if length == 0: + return off + DNS_COMPRESSION_HEADER_LEN + + if length < 0x40: + label_idx = off + DNS_COMPRESSION_HEADER_LEN + labels.append(str(self.data[label_idx : label_idx + length], 'utf-8', 'replace')) + off += DNS_COMPRESSION_HEADER_LEN + length + continue + + if length < 0xC0: + raise IncomingDecodeError(f"DNS compression type {length} is unknown at {off}") + + # We have a DNS compression pointer + link = (length & 0x3F) * 256 + self.data[off + 1] + if link > self.data_len: + raise IncomingDecodeError(f"DNS compression pointer at {off} points to {link} beyond packet") + if link == off: + raise IncomingDecodeError(f"DNS compression pointer at {off} points to itself") + if link in seen_pointers: + raise IncomingDecodeError(f"DNS compression pointer at {off} was seen again") + seen_pointers.add(link) + linked_labels = self.name_cache.get(link, []) + if not linked_labels: + self._decode_labels_at_offset(link, linked_labels, seen_pointers) + self.name_cache[link] = linked_labels + labels.extend(linked_labels) + if len(labels) > MAX_DNS_LABELS: + raise IncomingDecodeError(f"Maximum dns labels reached while processing pointer at {off}") + return off + DNS_COMPRESSION_POINTER_LEN + + raise IncomingDecodeError("Corrupt packet received while decoding name") diff --git a/zeroconf/_protocol/outgoing.py b/zeroconf/_protocol/outgoing.py new file mode 100644 index 00000000..59c8382e --- /dev/null +++ b/zeroconf/_protocol/outgoing.py @@ -0,0 +1,442 @@ +""" Multicast DNS Service Discovery for Python, v0.14-wmcbrine + Copyright 2003 Paul Scott-Murphy, 2014 William McBrine + + This module provides a framework for the use of DNS Service Discovery + using IP multicast. + + This library is free software; you can redistribute it and/or + modify it under the terms of the GNU Lesser General Public + License as published by the Free Software Foundation; either + version 2.1 of the License, or (at your option) any later version. + + This library is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + Lesser General Public License for more details. + + You should have received a copy of the GNU Lesser General Public + License along with this library; if not, write to the Free Software + Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 + USA +""" + +import enum +import struct +from typing import Any, Dict, List, Optional, Sequence, Tuple, Union + +from . import DNSMessage +from .incoming import DNSIncoming +from .._cache import DNSCache +from .._dns import DNSPointer, DNSQuestion, DNSRecord +from .._exceptions import NamePartTooLongException +from .._logger import log +from ..const import ( + _CLASS_UNIQUE, + _DNS_PACKET_HEADER_LEN, + _FLAGS_TC, + _MAX_MSG_ABSOLUTE, + _MAX_MSG_TYPICAL, +) + + +class DNSOutgoing(DNSMessage): + + """Object representation of an outgoing packet""" + + __slots__ = ( + 'finished', + 'id', + 'multicast', + 'packets_data', + 'names', + 'data', + 'size', + 'allow_long', + 'state', + 'questions', + 'answers', + 'authorities', + 'additionals', + ) + + def __init__(self, flags: int, multicast: bool = True, id_: int = 0) -> None: + super().__init__(flags) + self.finished = False + self.id = id_ + self.multicast = multicast + self.packets_data: List[bytes] = [] + + # these 3 are per-packet -- see also _reset_for_next_packet() + self.names: Dict[str, int] = {} + self.data: List[bytes] = [] + self.size: int = _DNS_PACKET_HEADER_LEN + self.allow_long: bool = True + + self.state = self.State.init + + self.questions: List[DNSQuestion] = [] + self.answers: List[Tuple[DNSRecord, float]] = [] + self.authorities: List[DNSPointer] = [] + self.additionals: List[DNSRecord] = [] + + def _reset_for_next_packet(self) -> None: + self.names = {} + self.data = [] + self.size = _DNS_PACKET_HEADER_LEN + self.allow_long = True + + def __repr__(self) -> str: + return '' % ', '.join( + [ + 'multicast=%s' % self.multicast, + 'flags=%s' % self.flags, + 'questions=%s' % self.questions, + 'answers=%s' % self.answers, + 'authorities=%s' % self.authorities, + 'additionals=%s' % self.additionals, + ] + ) + + class State(enum.Enum): + init = 0 + finished = 1 + + def add_question(self, record: DNSQuestion) -> None: + """Adds a question""" + self.questions.append(record) + + def add_answer(self, inp: DNSIncoming, record: DNSRecord) -> None: + """Adds an answer""" + if not record.suppressed_by(inp): + self.add_answer_at_time(record, 0) + + def add_answer_at_time(self, record: Optional[DNSRecord], now: Union[float, int]) -> None: + """Adds an answer if it does not expire by a certain time""" + if record is not None and (now == 0 or not record.is_expired(now)): + self.answers.append((record, now)) + + def add_authorative_answer(self, record: DNSPointer) -> None: + """Adds an authoritative answer""" + self.authorities.append(record) + + def add_additional_answer(self, record: DNSRecord) -> None: + """Adds an additional answer + + From: RFC 6763, DNS-Based Service Discovery, February 2013 + + 12. DNS Additional Record Generation + + DNS has an efficiency feature whereby a DNS server may place + additional records in the additional section of the DNS message. + These additional records are records that the client did not + explicitly request, but the server has reasonable grounds to expect + that the client might request them shortly, so including them can + save the client from having to issue additional queries. + + This section recommends which additional records SHOULD be generated + to improve network efficiency, for both Unicast and Multicast DNS-SD + responses. + + 12.1. PTR Records + + When including a DNS-SD Service Instance Enumeration or Selective + Instance Enumeration (subtype) PTR record in a response packet, the + server/responder SHOULD include the following additional records: + + o The SRV record(s) named in the PTR rdata. + o The TXT record(s) named in the PTR rdata. + o All address records (type "A" and "AAAA") named in the SRV rdata. + + 12.2. SRV Records + + When including an SRV record in a response packet, the + server/responder SHOULD include the following additional records: + + o All address records (type "A" and "AAAA") named in the SRV rdata. + + """ + self.additionals.append(record) + + def add_question_or_one_cache( + self, cache: DNSCache, now: float, name: str, type_: int, class_: int + ) -> None: + """Add a question if it is not already cached.""" + cached_entry = cache.get_by_details(name, type_, class_) + if not cached_entry: + self.add_question(DNSQuestion(name, type_, class_)) + else: + self.add_answer_at_time(cached_entry, now) + + def add_question_or_all_cache( + self, cache: DNSCache, now: float, name: str, type_: int, class_: int + ) -> None: + """Add a question if it is not already cached. + This is currently only used for IPv6 addresses. + """ + cached_entries = cache.get_all_by_details(name, type_, class_) + if not cached_entries: + self.add_question(DNSQuestion(name, type_, class_)) + return + for cached_entry in cached_entries: + self.add_answer_at_time(cached_entry, now) + + def _pack(self, format_: Union[bytes, str], size: int, value: Any) -> None: + self.data.append(struct.pack(format_, value)) + self.size += size + + def _write_byte(self, value: int) -> None: + """Writes a single byte to the packet""" + self._pack(b'!c', 1, bytes((value,))) + + def _insert_short_at_start(self, value: int) -> None: + """Inserts an unsigned short at the start of the packet""" + self.data.insert(0, struct.pack(b'!H', value)) + + def _replace_short(self, index: int, value: int) -> None: + """Replaces an unsigned short in a certain position in the packet""" + self.data[index] = struct.pack(b'!H', value) + + def write_short(self, value: int) -> None: + """Writes an unsigned short to the packet""" + self._pack(b'!H', 2, value) + + def _write_int(self, value: Union[float, int]) -> None: + """Writes an unsigned integer to the packet""" + self._pack(b'!I', 4, int(value)) + + def write_string(self, value: bytes) -> None: + """Writes a string to the packet""" + assert isinstance(value, bytes) + self.data.append(value) + self.size += len(value) + + def _write_utf(self, s: str) -> None: + """Writes a UTF-8 string of a given length to the packet""" + utfstr = s.encode('utf-8') + length = len(utfstr) + if length > 64: + raise NamePartTooLongException + self._write_byte(length) + self.write_string(utfstr) + + def write_character_string(self, value: bytes) -> None: + assert isinstance(value, bytes) + length = len(value) + if length > 256: + raise NamePartTooLongException + self._write_byte(length) + self.write_string(value) + + def write_name(self, name: str) -> None: + """ + Write names to packet + + 18.14. Name Compression + + When generating Multicast DNS messages, implementations SHOULD use + name compression wherever possible to compress the names of resource + records, by replacing some or all of the resource record name with a + compact two-byte reference to an appearance of that data somewhere + earlier in the message [RFC1035]. + """ + + # split name into each label + name_length = None + if name.endswith('.'): + name = name[: len(name) - 1] + labels = name.split('.') + # Write each new label or a pointer to the existing + # on in the packet + start_size = self.size + for count in range(len(labels)): + label = name if count == 0 else '.'.join(labels[count:]) + index = self.names.get(label) + if index: + # If part of the name already exists in the packet, + # create a pointer to it + self._write_byte((index >> 8) | 0xC0) + self._write_byte(index & 0xFF) + return + if name_length is None: + name_length = len(name.encode('utf-8')) + self.names[label] = start_size + name_length - len(label.encode('utf-8')) + self._write_utf(labels[count]) + + # this is the end of a name + self._write_byte(0) + + def _write_question(self, question: DNSQuestion) -> bool: + """Writes a question to the packet""" + start_data_length, start_size = len(self.data), self.size + self.write_name(question.name) + self.write_short(question.type) + self._write_record_class(question) + return self._check_data_limit_or_rollback(start_data_length, start_size) + + def _write_record_class(self, record: Union[DNSQuestion, DNSRecord]) -> None: + """Write out the record class including the unique/unicast (QU) bit.""" + if record.unique and self.multicast: + self.write_short(record.class_ | _CLASS_UNIQUE) + else: + self.write_short(record.class_) + + def _write_ttl(self, record: DNSRecord, now: float) -> None: + """Write out the record ttl.""" + self._write_int(record.ttl if now == 0 else record.get_remaining_ttl(now)) + + def _write_record(self, record: DNSRecord, now: float) -> bool: + """Writes a record (answer, authoritative answer, additional) to + the packet. Returns True on success, or False if we did not + because the packet because the record does not fit.""" + start_data_length, start_size = len(self.data), self.size + self.write_name(record.name) + self.write_short(record.type) + self._write_record_class(record) + self._write_ttl(record, now) + index = len(self.data) + self.write_short(0) # Will get replaced with the actual size + record.write(self) + # Adjust size for the short we will write before this record + length = sum(len(d) for d in self.data[index + 1 :]) + # Here we replace the 0 length short we wrote + # before with the actual length + self._replace_short(index, length) + return self._check_data_limit_or_rollback(start_data_length, start_size) + + def _check_data_limit_or_rollback(self, start_data_length: int, start_size: int) -> bool: + """Check data limit, if we go over, then rollback and return False.""" + len_limit = _MAX_MSG_ABSOLUTE if self.allow_long else _MAX_MSG_TYPICAL + self.allow_long = False + + if self.size <= len_limit: + return True + + log.debug("Reached data limit (size=%d) > (limit=%d) - rolling back", self.size, len_limit) + del self.data[start_data_length:] + self.size = start_size + + rollback_names = [name for name, idx in self.names.items() if idx >= start_size] + for name in rollback_names: + del self.names[name] + return False + + def _write_questions_from_offset(self, questions_offset: int) -> int: + questions_written = 0 + for question in self.questions[questions_offset:]: + if not self._write_question(question): + break + questions_written += 1 + return questions_written + + def _write_answers_from_offset(self, answer_offset: int) -> int: + answers_written = 0 + for answer, time_ in self.answers[answer_offset:]: + if not self._write_record(answer, time_): + break + answers_written += 1 + return answers_written + + def _write_records_from_offset(self, records: Sequence[DNSRecord], offset: int) -> int: + records_written = 0 + for record in records[offset:]: + if not self._write_record(record, 0): + break + records_written += 1 + return records_written + + def _has_more_to_add( + self, questions_offset: int, answer_offset: int, authority_offset: int, additional_offset: int + ) -> bool: + """Check if all questions, answers, authority, and additionals have been written to the packet.""" + return ( + questions_offset < len(self.questions) + or answer_offset < len(self.answers) + or authority_offset < len(self.authorities) + or additional_offset < len(self.additionals) + ) + + def packets(self) -> List[bytes]: + """Returns a list of bytestrings containing the packets' bytes + + No further parts should be added to the packet once this + is done. The packets are each restricted to _MAX_MSG_TYPICAL + or less in length, except for the case of a single answer which + will be written out to a single oversized packet no more than + _MAX_MSG_ABSOLUTE in length (and hence will be subject to IP + fragmentation potentially).""" + + if self.state == self.State.finished: + return self.packets_data + + questions_offset = 0 + answer_offset = 0 + authority_offset = 0 + additional_offset = 0 + # we have to at least write out the question + first_time = True + + while first_time or self._has_more_to_add( + questions_offset, answer_offset, authority_offset, additional_offset + ): + first_time = False + log.debug( + "offsets = questions=%d, answers=%d, authorities=%d, additionals=%d", + questions_offset, + answer_offset, + authority_offset, + additional_offset, + ) + log.debug( + "lengths = questions=%d, answers=%d, authorities=%d, additionals=%d", + len(self.questions), + len(self.answers), + len(self.authorities), + len(self.additionals), + ) + + questions_written = self._write_questions_from_offset(questions_offset) + answers_written = self._write_answers_from_offset(answer_offset) + authorities_written = self._write_records_from_offset(self.authorities, authority_offset) + additionals_written = self._write_records_from_offset(self.additionals, additional_offset) + + self._insert_short_at_start(additionals_written) + self._insert_short_at_start(authorities_written) + self._insert_short_at_start(answers_written) + self._insert_short_at_start(questions_written) + + questions_offset += questions_written + answer_offset += answers_written + authority_offset += authorities_written + additional_offset += additionals_written + log.debug( + "now offsets = questions=%d, answers=%d, authorities=%d, additionals=%d", + questions_offset, + answer_offset, + authority_offset, + additional_offset, + ) + + if self.is_query() and self._has_more_to_add( + questions_offset, answer_offset, authority_offset, additional_offset + ): + # https://datatracker.ietf.org/doc/html/rfc6762#section-7.2 + log.debug("Setting TC flag") + self._insert_short_at_start(self.flags | _FLAGS_TC) + else: + self._insert_short_at_start(self.flags) + + if self.multicast: + self._insert_short_at_start(0) + else: + self._insert_short_at_start(self.id) + + self.packets_data.append(b''.join(self.data)) + self._reset_for_next_packet() + + if (questions_written + answers_written + authorities_written + additionals_written) == 0 and ( + len(self.questions) + len(self.answers) + len(self.authorities) + len(self.additionals) + ) > 0: + log.warning("packets() made no progress adding records; returning") + break + self.state = self.State.finished + return self.packets_data diff --git a/zeroconf/_services/__init__.py b/zeroconf/_services/__init__.py new file mode 100644 index 00000000..5b9fbf01 --- /dev/null +++ b/zeroconf/_services/__init__.py @@ -0,0 +1,72 @@ +""" Multicast DNS Service Discovery for Python, v0.14-wmcbrine + Copyright 2003 Paul Scott-Murphy, 2014 William McBrine + + This module provides a framework for the use of DNS Service Discovery + using IP multicast. + + This library is free software; you can redistribute it and/or + modify it under the terms of the GNU Lesser General Public + License as published by the Free Software Foundation; either + version 2.1 of the License, or (at your option) any later version. + + This library is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + Lesser General Public License for more details. + + You should have received a copy of the GNU Lesser General Public + License along with this library; if not, write to the Free Software + Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 + USA +""" + +import enum +from typing import Any, Callable, List, TYPE_CHECKING + + +if TYPE_CHECKING: + from .._core import Zeroconf + + +@enum.unique +class ServiceStateChange(enum.Enum): + Added = 1 + Removed = 2 + Updated = 3 + + +class ServiceListener: + def add_service(self, zc: 'Zeroconf', type_: str, name: str) -> None: + raise NotImplementedError() + + def remove_service(self, zc: 'Zeroconf', type_: str, name: str) -> None: + raise NotImplementedError() + + def update_service(self, zc: 'Zeroconf', type_: str, name: str) -> None: + raise NotImplementedError() + + +class Signal: + def __init__(self) -> None: + self._handlers: List[Callable[..., None]] = [] + + def fire(self, **kwargs: Any) -> None: + for h in list(self._handlers): + h(**kwargs) + + @property + def registration_interface(self) -> 'SignalRegistrationInterface': + return SignalRegistrationInterface(self._handlers) + + +class SignalRegistrationInterface: + def __init__(self, handlers: List[Callable[..., None]]) -> None: + self._handlers = handlers + + def register_handler(self, handler: Callable[..., None]) -> 'SignalRegistrationInterface': + self._handlers.append(handler) + return self + + def unregister_handler(self, handler: Callable[..., None]) -> 'SignalRegistrationInterface': + self._handlers.remove(handler) + return self diff --git a/zeroconf/_services/browser.py b/zeroconf/_services/browser.py new file mode 100644 index 00000000..f6448fd2 --- /dev/null +++ b/zeroconf/_services/browser.py @@ -0,0 +1,539 @@ +""" Multicast DNS Service Discovery for Python, v0.14-wmcbrine + Copyright 2003 Paul Scott-Murphy, 2014 William McBrine + + This module provides a framework for the use of DNS Service Discovery + using IP multicast. + + This library is free software; you can redistribute it and/or + modify it under the terms of the GNU Lesser General Public + License as published by the Free Software Foundation; either + version 2.1 of the License, or (at your option) any later version. + + This library is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + Lesser General Public License for more details. + + You should have received a copy of the GNU Lesser General Public + License along with this library; if not, write to the Free Software + Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 + USA +""" + +import asyncio +import queue +import random +import threading +import warnings +from collections import OrderedDict +from typing import Callable, Dict, List, Optional, Set, TYPE_CHECKING, Tuple, Union, cast + +from .._dns import DNSAddress, DNSPointer, DNSQuestion, DNSQuestionType, DNSRecord +from .._logger import log +from .._protocol.outgoing import DNSOutgoing +from .._services import ( + ServiceListener, + ServiceStateChange, + Signal, + SignalRegistrationInterface, +) +from .._updates import RecordUpdate, RecordUpdateListener +from .._utils.asyncio import get_best_available_queue +from .._utils.name import service_type_name +from .._utils.time import current_time_millis, millis_to_seconds +from ..const import ( + _BROWSER_BACKOFF_LIMIT, + _BROWSER_TIME, + _CLASS_IN, + _DNS_PACKET_HEADER_LEN, + _EXPIRE_REFRESH_TIME_PERCENT, + _FLAGS_QR_QUERY, + _MAX_MSG_TYPICAL, + _MDNS_ADDR, + _MDNS_ADDR6, + _MDNS_PORT, + _TYPE_PTR, +) + +# https://datatracker.ietf.org/doc/html/rfc6762#section-5.2 +_FIRST_QUERY_DELAY_RANDOM_INTERVAL = (20, 120) # ms + +_ON_CHANGE_DISPATCH = { + ServiceStateChange.Added: "add_service", + ServiceStateChange.Removed: "remove_service", + ServiceStateChange.Updated: "update_service", +} + +if TYPE_CHECKING: + from .._core import Zeroconf + + +_QuestionWithKnownAnswers = Dict[DNSQuestion, Set[DNSPointer]] + + +class _DNSPointerOutgoingBucket: + """A DNSOutgoing bucket.""" + + def __init__(self, now: float, multicast: bool) -> None: + """Create a bucke to wrap a DNSOutgoing.""" + self.now = now + self.out = DNSOutgoing(_FLAGS_QR_QUERY, multicast=multicast) + self.bytes = 0 + + def add(self, max_compressed_size: int, question: DNSQuestion, answers: Set[DNSPointer]) -> None: + """Add a new set of questions and known answers to the outgoing.""" + self.out.add_question(question) + for answer in answers: + self.out.add_answer_at_time(answer, self.now) + self.bytes += max_compressed_size + + +def _group_ptr_queries_with_known_answers( + now: float, multicast: bool, question_with_known_answers: _QuestionWithKnownAnswers +) -> List[DNSOutgoing]: + """Aggregate queries so that as many known answers as possible fit in the same packet + without having known answers spill over into the next packet unless the + question and known answers are always going to exceed the packet size. + + Some responders do not implement multi-packet known answer suppression + so we try to keep all the known answers in the same packet as the + questions. + """ + # This is the maximum size the query + known answers can be with name compression. + # The actual size of the query + known answers may be a bit smaller since other + # parts may be shared when the final DNSOutgoing packets are constructed. The + # goal of this algorithm is to quickly bucket the query + known answers without + # the overhead of actually constructing the packets. + query_by_size: Dict[DNSQuestion, int] = { + question: (question.max_size + sum([answer.max_size_compressed for answer in known_answers])) + for question, known_answers in question_with_known_answers.items() + } + max_bucket_size = _MAX_MSG_TYPICAL - _DNS_PACKET_HEADER_LEN + query_buckets: List[_DNSPointerOutgoingBucket] = [] + for question in sorted( + query_by_size, + key=query_by_size.get, # type: ignore + reverse=True, + ): + max_compressed_size = query_by_size[question] + answers = question_with_known_answers[question] + for query_bucket in query_buckets: + if query_bucket.bytes + max_compressed_size <= max_bucket_size: + query_bucket.add(max_compressed_size, question, answers) + break + else: + # If a single question and known answers won't fit in a packet + # we will end up generating multiple packets, but there will never + # be multiple questions + query_bucket = _DNSPointerOutgoingBucket(now, multicast) + query_bucket.add(max_compressed_size, question, answers) + query_buckets.append(query_bucket) + + return [query_bucket.out for query_bucket in query_buckets] + + +def generate_service_query( + zc: 'Zeroconf', + now: float, + types_: List[str], + multicast: bool = True, + question_type: Optional[DNSQuestionType] = None, +) -> List[DNSOutgoing]: + """Generate a service query for sending with zeroconf.send.""" + questions_with_known_answers: _QuestionWithKnownAnswers = {} + qu_question = not multicast if question_type is None else question_type == DNSQuestionType.QU + for type_ in types_: + question = DNSQuestion(type_, _TYPE_PTR, _CLASS_IN) + question.unicast = qu_question + known_answers = set( + cast(DNSPointer, record) + for record in zc.cache.get_all_by_details(type_, _TYPE_PTR, _CLASS_IN) + if not record.is_stale(now) + ) + if not qu_question and zc.question_history.suppresses( + question, now, cast(Set[DNSRecord], known_answers) + ): + log.debug("Asking %s was suppressed by the question history", question) + continue + questions_with_known_answers[question] = known_answers + if not qu_question: + zc.question_history.add_question_at_time(question, now, cast(Set[DNSRecord], known_answers)) + + return _group_ptr_queries_with_known_answers(now, multicast, questions_with_known_answers) + + +def _service_state_changed_from_listener(listener: ServiceListener) -> Callable[..., None]: + """Generate a service_state_changed handlers from a listener.""" + assert listener is not None + if not hasattr(listener, 'update_service'): + warnings.warn( + "%r has no update_service method. Provide one (it can be empty if you " + "don't care about the updates), it'll become mandatory." % (listener,), + FutureWarning, + ) + + def on_change( + zeroconf: 'Zeroconf', service_type: str, name: str, state_change: ServiceStateChange + ) -> None: + getattr(listener, _ON_CHANGE_DISPATCH[state_change])(zeroconf, service_type, name) + + return on_change + + +class QueryScheduler: + """Schedule outgoing PTR queries for Continuous Multicast DNS Querying + + https://datatracker.ietf.org/doc/html/rfc6762#section-5.2 + + """ + + def __init__( + self, + types: Set[str], + delay: int, + first_random_delay_interval: Tuple[int, int], + ): + self._schedule_changed_event: Optional[asyncio.Event] = None + self._types = types + self._next_time: Dict[str, float] = {} + self._first_random_delay_interval = first_random_delay_interval + self._delay: Dict[str, float] = {check_type_: delay for check_type_ in self._types} + + def start(self, now: float) -> None: + """Start the scheduler.""" + self._schedule_changed_event = asyncio.Event() + self._generate_first_next_time(now) + + def _generate_first_next_time(self, now: float) -> None: + """Generate the initial next query times. + + https://datatracker.ietf.org/doc/html/rfc6762#section-5.2 + To avoid accidental synchronization when, for some reason, multiple + clients begin querying at exactly the same moment (e.g., because of + some common external trigger event), a Multicast DNS querier SHOULD + also delay the first query of the series by a randomly chosen amount + in the range 20-120 ms. + """ + delay = millis_to_seconds(random.randint(*self._first_random_delay_interval)) + next_time = now + delay + self._next_time = {check_type_: next_time for check_type_ in self._types} + + def millis_to_wait(self, now: float) -> float: + """Returns the number of milliseconds to wait for the next event.""" + # Wait for the type has the smallest next time + next_time = min(self._next_time.values()) + return 0 if next_time <= now else next_time - now + + def reschedule_type(self, type_: str, next_time: float) -> bool: + """Reschedule the query for a type to happen sooner.""" + if next_time >= self._next_time[type_]: + return False + self._next_time[type_] = next_time + return True + + def process_ready_types(self, now: float) -> List[str]: + """Generate a list of ready types that is due and schedule the next time.""" + if self.millis_to_wait(now): + return [] + + ready_types: List[str] = [] + + for type_, due in self._next_time.items(): + if due > now: + continue + + ready_types.append(type_) + self._next_time[type_] = now + self._delay[type_] + self._delay[type_] = min(_BROWSER_BACKOFF_LIMIT * 1000, self._delay[type_] * 2) + + return ready_types + + +class _ServiceBrowserBase(RecordUpdateListener): + """Base class for ServiceBrowser.""" + + def __init__( + self, + zc: 'Zeroconf', + type_: Union[str, list], + handlers: Optional[Union[ServiceListener, List[Callable[..., None]]]] = None, + listener: Optional[ServiceListener] = None, + addr: Optional[str] = None, + port: int = _MDNS_PORT, + delay: int = _BROWSER_TIME, + question_type: Optional[DNSQuestionType] = None, + ) -> None: + """Used to browse for a service for specific type(s). + + Constructor parameters are as follows: + + * `zc`: A Zeroconf instance + * `type_`: fully qualified service type name + * `handler`: ServiceListener or Callable that knows how to process ServiceStateChange events + * `listener`: ServiceListener + * `addr`: address to send queries (will default to multicast) + * `port`: port to send queries (will default to mdns 5353) + * `delay`: The initial delay between answering questions + * `question_type`: The type of questions to ask (DNSQuestionType.QM or DNSQuestionType.QU) + + The listener object will have its add_service() and + remove_service() methods called when this browser + discovers changes in the services availability. + """ + assert handlers or listener, 'You need to specify at least one handler' + self.types: Set[str] = set(type_ if isinstance(type_, list) else [type_]) + for check_type_ in self.types: + # Will generate BadTypeInNameException on a bad name + service_type_name(check_type_, strict=False) + self.zc = zc + self.addr = addr + self.port = port + self.multicast = self.addr in (None, _MDNS_ADDR, _MDNS_ADDR6) + self.question_type = question_type + self._pending_handlers: OrderedDict[Tuple[str, str], ServiceStateChange] = OrderedDict() + self._service_state_changed = Signal() + self.query_scheduler = QueryScheduler(self.types, delay, _FIRST_QUERY_DELAY_RANDOM_INTERVAL) + self.queue: Optional[queue.Queue] = None + self.done = False + self._first_request: bool = True + self._next_send_timer: Optional[asyncio.TimerHandle] = None + + if hasattr(handlers, 'add_service'): + listener = cast('ServiceListener', handlers) + handlers = None + + handlers = cast(List[Callable[..., None]], handlers or []) + + if listener: + handlers.append(_service_state_changed_from_listener(listener)) + + for h in handlers: + self.service_state_changed.register_handler(h) + + def _async_start(self) -> None: + """Generate the next time and setup listeners. + + Must be called by uses of this base class after they + have finished setting their properties. + """ + self.query_scheduler.start(current_time_millis()) + self.zc.async_add_listener(self, [DNSQuestion(type_, _TYPE_PTR, _CLASS_IN) for type_ in self.types]) + # Only start queries after the listener is installed + asyncio.ensure_future(self._async_start_query_sender()) + + @property + def service_state_changed(self) -> SignalRegistrationInterface: + return self._service_state_changed.registration_interface + + def _record_matching_type(self, record: DNSRecord) -> Optional[str]: + """Return the type if the record matches one of the types we are browsing.""" + return next((type_ for type_ in self.types if record.name.endswith(type_)), None) + + def _enqueue_callback( + self, + state_change: ServiceStateChange, + type_: str, + name: str, + ) -> None: + # Code to ensure we only do a single update message + # Precedence is; Added, Remove, Update + key = (name, type_) + if ( + state_change is ServiceStateChange.Added + or ( + state_change is ServiceStateChange.Removed + and self._pending_handlers.get(key) != ServiceStateChange.Added + ) + or (state_change is ServiceStateChange.Updated and key not in self._pending_handlers) + ): + self._pending_handlers[key] = state_change + + def _async_process_record_update( + self, now: float, record: DNSRecord, old_record: Optional[DNSRecord] + ) -> None: + """Process a single record update from a batch of updates.""" + if isinstance(record, DNSPointer): + if record.name not in self.types: + return + if old_record is None: + self._enqueue_callback(ServiceStateChange.Added, record.name, record.alias) + elif record.is_expired(now): + self._enqueue_callback(ServiceStateChange.Removed, record.name, record.alias) + else: + self.reschedule_type(record.name, record.get_expiration_time(_EXPIRE_REFRESH_TIME_PERCENT)) + return + + # If its expired or already exists in the cache it cannot be updated. + if old_record or record.is_expired(now): + return + + if isinstance(record, DNSAddress): + # Iterate through the DNSCache and callback any services that use this address + for service in self.zc.cache.async_entries_with_server(record.name): + type_ = self._record_matching_type(service) + if type_: + self._enqueue_callback(ServiceStateChange.Updated, type_, service.name) + break + + return + + type_ = self._record_matching_type(record) + if type_: + self._enqueue_callback(ServiceStateChange.Updated, type_, record.name) + + def async_update_records(self, zc: 'Zeroconf', now: float, records: List[RecordUpdate]) -> None: + """Callback invoked by Zeroconf when new information arrives. + + Updates information required by browser in the Zeroconf cache. + + Ensures that there is are no unecessary duplicates in the list. + + This method will be run in the event loop. + """ + for record in records: + self._async_process_record_update(now, record[0], record[1]) + + def async_update_records_complete(self) -> None: + """Called when a record update has completed for all handlers. + + At this point the cache will have the new records. + + This method will be run in the event loop. + """ + while self._pending_handlers: + event = self._pending_handlers.popitem(False) + # If there is a queue running (ServiceBrowser) + # get fired in dedicated thread + if self.queue: + self.queue.put(event) + else: + self._fire_service_state_changed_event(event) + + def _fire_service_state_changed_event(self, event: Tuple[Tuple[str, str], ServiceStateChange]) -> None: + """Fire a service state changed event. + + When running with ServiceBrowser, this will happen in the dedicated + thread. + + When running with AsyncServiceBrowser, this will happen in the event loop. + """ + name_type, state_change = event + self._service_state_changed.fire( + zeroconf=self.zc, + service_type=name_type[1], + name=name_type[0], + state_change=state_change, + ) + + def _async_cancel(self) -> None: + """Cancel the browser.""" + self.done = True + self._cancel_send_timer() + self.zc.async_remove_listener(self) + + def _generate_ready_queries(self, first_request: bool) -> List[DNSOutgoing]: + """Generate the service browser query for any type that is due.""" + now = current_time_millis() + ready_types = self.query_scheduler.process_ready_types(now) + if not ready_types: + return [] + + # If they did not specify and this is the first request, ask QU questions + # https://datatracker.ietf.org/doc/html/rfc6762#section-5.4 since we are + # just starting up and we know our cache is likely empty. This ensures + # the next outgoing will be sent with the known answers list. + question_type = DNSQuestionType.QU if not self.question_type and first_request else self.question_type + return generate_service_query(self.zc, now, ready_types, self.multicast, question_type) + + async def _async_start_query_sender(self) -> None: + """Start scheduling queries.""" + await self.zc.async_wait_for_start() + self._async_send_ready_queries() + self._async_schedule_next() + + def _cancel_send_timer(self) -> None: + """Cancel the next send.""" + if self._next_send_timer: + self._next_send_timer.cancel() + + def reschedule_type(self, type_: str, next_time: float) -> None: + """Reschedule a type to be refreshed in the future.""" + if self.query_scheduler.reschedule_type(type_, next_time): + self._cancel_send_timer() + self._async_schedule_next() + self._async_send_ready_queries() + + def _async_send_ready_queries(self) -> None: + """Send any ready queries.""" + if self.done or self.zc.done: + return + + outs = self._generate_ready_queries(self._first_request) + if outs: + self._first_request = False + for out in outs: + self.zc.async_send(out, addr=self.addr, port=self.port) + + def _async_send_ready_queries_schedule_next(self) -> None: + """Send ready queries and schedule next one.""" + self._async_send_ready_queries() + self._async_schedule_next() + + def _async_schedule_next(self) -> None: + """Scheule the next time.""" + assert self.zc.loop is not None + delay = millis_to_seconds(self.query_scheduler.millis_to_wait(current_time_millis())) + self._next_send_timer = self.zc.loop.call_later(delay, self._async_send_ready_queries_schedule_next) + + +class ServiceBrowser(_ServiceBrowserBase, threading.Thread): + """Used to browse for a service of a specific type. + + The listener object will have its add_service() and + remove_service() methods called when this browser + discovers changes in the services availability.""" + + def __init__( + self, + zc: 'Zeroconf', + type_: Union[str, list], + handlers: Optional[Union[ServiceListener, List[Callable[..., None]]]] = None, + listener: Optional[ServiceListener] = None, + addr: Optional[str] = None, + port: int = _MDNS_PORT, + delay: int = _BROWSER_TIME, + question_type: Optional[DNSQuestionType] = None, + ) -> None: + assert zc.loop is not None + if not zc.loop.is_running(): + raise RuntimeError("The event loop is not running") + threading.Thread.__init__(self) + super().__init__(zc, type_, handlers, listener, addr, port, delay, question_type) + # Add the queue before the listener is installed in _setup + # to ensure that events run in the dedicated thread and do + # not block the event loop + self.queue = get_best_available_queue() + self.daemon = True + self.start() + zc.loop.call_soon_threadsafe(self._async_start) + self.name = "zeroconf-ServiceBrowser-%s-%s" % ( + '-'.join([type_[:-7] for type_ in self.types]), + getattr(self, 'native_id', self.ident), + ) + + def cancel(self) -> None: + """Cancel the browser.""" + assert self.zc.loop is not None + assert self.queue is not None + self.queue.put(None) + self.zc.loop.call_soon_threadsafe(self._async_cancel) + self.join() + + def run(self) -> None: + """Run the browser thread.""" + assert self.queue is not None + while True: + event = self.queue.get() + if event is None: + return + self._fire_service_state_changed_event(event) diff --git a/zeroconf/_services/info.py b/zeroconf/_services/info.py new file mode 100644 index 00000000..beaf0678 --- /dev/null +++ b/zeroconf/_services/info.py @@ -0,0 +1,513 @@ +""" Multicast DNS Service Discovery for Python, v0.14-wmcbrine + Copyright 2003 Paul Scott-Murphy, 2014 William McBrine + + This module provides a framework for the use of DNS Service Discovery + using IP multicast. + + This library is free software; you can redistribute it and/or + modify it under the terms of the GNU Lesser General Public + License as published by the Free Software Foundation; either + version 2.1 of the License, or (at your option) any later version. + + This library is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + Lesser General Public License for more details. + + You should have received a copy of the GNU Lesser General Public + License along with this library; if not, write to the Free Software + Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 + USA +""" + +import ipaddress +import random +import socket +from typing import Any, Dict, List, Optional, TYPE_CHECKING, Union, cast + +from .._dns import DNSAddress, DNSPointer, DNSQuestionType, DNSRecord, DNSService, DNSText +from .._exceptions import BadTypeInNameException +from .._protocol.outgoing import DNSOutgoing +from .._updates import RecordUpdate, RecordUpdateListener +from .._utils.asyncio import get_running_loop, run_coro_with_timeout +from .._utils.name import service_type_name +from .._utils.net import ( + IPVersion, + _encode_address, + _is_v6_address, +) +from .._utils.time import current_time_millis +from ..const import ( + _CLASS_IN, + _CLASS_UNIQUE, + _DNS_HOST_TTL, + _DNS_OTHER_TTL, + _FLAGS_QR_QUERY, + _LISTENER_TIME, + _TYPE_A, + _TYPE_AAAA, + _TYPE_PTR, + _TYPE_SRV, + _TYPE_TXT, +) + + +# https://datatracker.ietf.org/doc/html/rfc6762#section-5.2 +# The most common case for calling ServiceInfo is from a +# ServiceBrowser. After the first request we add a few random +# milliseconds to the delay between requests to reduce the chance +# that there are multiple ServiceBrowser callbacks running on +# the network that are firing at the same time when they +# see the same multicast response and decide to refresh +# the A/AAAA/SRV records for a host. +_AVOID_SYNC_DELAY_RANDOM_INTERVAL = (20, 120) + +if TYPE_CHECKING: + from .._core import Zeroconf + + +def instance_name_from_service_info(info: "ServiceInfo") -> str: + """Calculate the instance name from the ServiceInfo.""" + # This is kind of funky because of the subtype based tests + # need to make subtypes a first class citizen + service_name = service_type_name(info.name) + if not info.type.endswith(service_name): + raise BadTypeInNameException + return info.name[: -len(service_name) - 1] + + +class ServiceInfo(RecordUpdateListener): + """Service information. + + Constructor parameters are as follows: + + * `type_`: fully qualified service type name + * `name`: fully qualified service name + * `port`: port that the service runs on + * `weight`: weight of the service + * `priority`: priority of the service + * `properties`: dictionary of properties (or a bytes object holding the contents of the `text` field). + converted to str and then encoded to bytes using UTF-8. Keys with `None` values are converted to + value-less attributes. + * `server`: fully qualified name for service host (defaults to name) + * `host_ttl`: ttl used for A/SRV records + * `other_ttl`: ttl used for PTR/TXT records + * `addresses` and `parsed_addresses`: List of IP addresses (either as bytes, network byte order, + or in parsed form as text; at most one of those parameters can be provided) + * interface_index: scope_id or zone_id for IPv6 link-local addresses i.e. an identifier of the interface + where the peer is connected to + """ + + text = b'' + + def __init__( + self, + type_: str, + name: str, + port: Optional[int] = None, + weight: int = 0, + priority: int = 0, + properties: Union[bytes, Dict] = b'', + server: Optional[str] = None, + host_ttl: int = _DNS_HOST_TTL, + other_ttl: int = _DNS_OTHER_TTL, + *, + addresses: Optional[List[bytes]] = None, + parsed_addresses: Optional[List[str]] = None, + interface_index: Optional[int] = None, + ) -> None: + # Accept both none, or one, but not both. + if addresses is not None and parsed_addresses is not None: + raise TypeError("addresses and parsed_addresses cannot be provided together") + if not type_.endswith(service_type_name(name, strict=False)): + raise BadTypeInNameException + self.type = type_ + self._name = name + self.key = name.lower() + if addresses is not None: + self._addresses = addresses + elif parsed_addresses is not None: + self._addresses = [_encode_address(a) for a in parsed_addresses] + else: + self._addresses = [] + # This results in an ugly error when registering, better check now + invalid = [a for a in self._addresses if not isinstance(a, bytes) or len(a) not in (4, 16)] + if invalid: + raise TypeError( + 'Addresses must be bytes, got %s. Hint: convert string addresses ' + 'with socket.inet_pton' % invalid + ) + self.port = port + self.weight = weight + self.priority = priority + self.server = server if server else name + self.server_key = self.server.lower() + self._properties: Dict[Union[str, bytes], Optional[Union[str, bytes]]] = {} + if isinstance(properties, bytes): + self._set_text(properties) + else: + self._set_properties(properties) + self.host_ttl = host_ttl + self.other_ttl = other_ttl + self.interface_index = interface_index + + @property + def name(self) -> str: + """The name of the service.""" + return self._name + + @name.setter + def name(self, name: str) -> None: + """Replace the the name and reset the key.""" + self._name = name + self.key = name.lower() + + @property + def addresses(self) -> List[bytes]: + """IPv4 addresses of this service. + + Only IPv4 addresses are returned for backward compatibility. + Use :meth:`addresses_by_version` or :meth:`parsed_addresses` to + include IPv6 addresses as well. + """ + return self.addresses_by_version(IPVersion.V4Only) + + @addresses.setter + def addresses(self, value: List[bytes]) -> None: + """Replace the addresses list. + + This replaces all currently stored addresses, both IPv4 and IPv6. + """ + self._addresses = value + + @property + def properties(self) -> Dict: + """If properties were set in the constructor this property returns the original dictionary + of type `Dict[Union[bytes, str], Any]`. + + If properties are coming from the network, after decoding a TXT record, the keys are always + bytes and the values are either bytes, if there was a value, even empty, or `None`, if there + was none. No further decoding is attempted. The type returned is `Dict[bytes, Optional[bytes]]`. + """ + return self._properties + + def addresses_by_version(self, version: IPVersion) -> List[bytes]: + """List addresses matching IP version.""" + if version == IPVersion.V4Only: + return [addr for addr in self._addresses if not _is_v6_address(addr)] + if version == IPVersion.V6Only: + return list(filter(_is_v6_address, self._addresses)) + return self._addresses + + def parsed_addresses(self, version: IPVersion = IPVersion.All) -> List[str]: + """List addresses in their parsed string form.""" + result = self.addresses_by_version(version) + return [ + socket.inet_ntop(socket.AF_INET6 if _is_v6_address(addr) else socket.AF_INET, addr) + for addr in result + ] + + def parsed_scoped_addresses(self, version: IPVersion = IPVersion.All) -> List[str]: + """Equivalent to parsed_addresses, with the exception that IPv6 Link-Local + addresses are qualified with % when available + """ + if self.interface_index is None: + return self.parsed_addresses(version) + + def is_link_local(addr_str: str) -> Any: + addr = ipaddress.ip_address(addr_str) + return addr.version == 6 and addr.is_link_local + + ll_addrs = list(filter(is_link_local, self.parsed_addresses(version))) + other_addrs = list(filter(lambda addr: not is_link_local(addr), self.parsed_addresses(version))) + return ["{}%{}".format(addr, self.interface_index) for addr in ll_addrs] + other_addrs + + def _set_properties(self, properties: Dict) -> None: + """Sets properties and text of this info from a dictionary""" + self._properties = properties + list_ = [] + result = b'' + for key, value in properties.items(): + if isinstance(key, str): + key = key.encode('utf-8') + + record = key + if value is not None: + if not isinstance(value, bytes): + value = str(value).encode('utf-8') + record += b'=' + value + list_.append(record) + for item in list_: + result = b''.join((result, bytes((len(item),)), item)) + self.text = result + + def _set_text(self, text: bytes) -> None: + """Sets properties and text given a text field""" + self.text = text + end = len(text) + if end == 0: + self._properties = {} + return + result: Dict[Union[str, bytes], Optional[Union[str, bytes]]] = {} + index = 0 + strs = [] + while index < end: + length = text[index] + index += 1 + strs.append(text[index : index + length]) + index += length + + key: bytes + value: Optional[bytes] + for s in strs: + try: + key, value = s.split(b'=', 1) + except ValueError: + # No equals sign at all + key = s + value = None + + # Only update non-existent properties + if key and result.get(key) is None: + result[key] = value + + self._properties = result + + def get_name(self) -> str: + """Name accessor""" + return self.name[: len(self.name) - len(self.type) - 1] + + def update_record(self, zc: 'Zeroconf', now: float, record: Optional[DNSRecord]) -> None: + """Updates service information from a DNS record. + + This method is deprecated and will be removed in a future version. + update_records should be implemented instead. + + This method will be run in the event loop. + """ + if record is not None: + self._process_records_threadsafe(zc, now, [RecordUpdate(record, None)]) + + def async_update_records(self, zc: 'Zeroconf', now: float, records: List[RecordUpdate]) -> None: + """Updates service information from a DNS record. + + This method will be run in the event loop. + """ + self._process_records_threadsafe(zc, now, records) + + def _process_records_threadsafe(self, zc: 'Zeroconf', now: float, records: List[RecordUpdate]) -> None: + """Thread safe record updating.""" + update_addresses = False + for record_update in records: + if isinstance(record_update[0], DNSService): + update_addresses = True + self._process_record_threadsafe(record_update[0], now) + + # Only update addresses if the DNSService (.server) has changed + if not update_addresses: + return + + for record in self._get_address_records_from_cache(zc): + self._process_record_threadsafe(record, now) + + def _process_record_threadsafe(self, record: DNSRecord, now: float) -> None: + if record.is_expired(now): + return + + if isinstance(record, DNSAddress): + if record.key == self.server_key and record.address not in self._addresses: + self._addresses.append(record.address) + if record.type is _TYPE_AAAA and ipaddress.IPv6Address(record.address).is_link_local: + self.interface_index = record.scope_id + return + + if isinstance(record, DNSService): + if record.key != self.key: + return + self.name = record.name + self.server = record.server + self.server_key = record.server.lower() + self.port = record.port + self.weight = record.weight + self.priority = record.priority + return + + if isinstance(record, DNSText): + if record.key == self.key: + self._set_text(record.text) + + def dns_addresses( + self, + override_ttl: Optional[int] = None, + version: IPVersion = IPVersion.All, + created: Optional[float] = None, + ) -> List[DNSAddress]: + """Return matching DNSAddress from ServiceInfo.""" + return [ + DNSAddress( + self.server, + _TYPE_AAAA if _is_v6_address(address) else _TYPE_A, + _CLASS_IN | _CLASS_UNIQUE, + override_ttl if override_ttl is not None else self.host_ttl, + address, + created=created, + ) + for address in self.addresses_by_version(version) + ] + + def dns_pointer(self, override_ttl: Optional[int] = None, created: Optional[float] = None) -> DNSPointer: + """Return DNSPointer from ServiceInfo.""" + return DNSPointer( + self.type, + _TYPE_PTR, + _CLASS_IN, + override_ttl if override_ttl is not None else self.other_ttl, + self.name, + created, + ) + + def dns_service(self, override_ttl: Optional[int] = None, created: Optional[float] = None) -> DNSService: + """Return DNSService from ServiceInfo.""" + return DNSService( + self.name, + _TYPE_SRV, + _CLASS_IN | _CLASS_UNIQUE, + override_ttl if override_ttl is not None else self.host_ttl, + self.priority, + self.weight, + cast(int, self.port), + self.server, + created, + ) + + def dns_text(self, override_ttl: Optional[int] = None, created: Optional[float] = None) -> DNSText: + """Return DNSText from ServiceInfo.""" + return DNSText( + self.name, + _TYPE_TXT, + _CLASS_IN | _CLASS_UNIQUE, + override_ttl if override_ttl is not None else self.other_ttl, + self.text, + created, + ) + + def _get_address_records_from_cache(self, zc: 'Zeroconf') -> List[DNSRecord]: + """Get the address records from the cache.""" + return [ + *zc.cache.get_all_by_details(self.server, _TYPE_A, _CLASS_IN), + *zc.cache.get_all_by_details(self.server, _TYPE_AAAA, _CLASS_IN), + ] + + def load_from_cache(self, zc: 'Zeroconf') -> bool: + """Populate the service info from the cache. + + This method is designed to be threadsafe. + """ + now = current_time_millis() + record_updates: List[RecordUpdate] = [] + cached_srv_record = zc.cache.get_by_details(self.name, _TYPE_SRV, _CLASS_IN) + if cached_srv_record: + # If there is a srv record, A and AAAA will already + # be called and we do not want to do it twice + record_updates.append(RecordUpdate(cached_srv_record, None)) + else: + for record in self._get_address_records_from_cache(zc): + record_updates.append(RecordUpdate(record, None)) + cached_txt_record = zc.cache.get_by_details(self.name, _TYPE_TXT, _CLASS_IN) + if cached_txt_record: + record_updates.append(RecordUpdate(cached_txt_record, None)) + self._process_records_threadsafe(zc, now, record_updates) + return self._is_complete + + @property + def _is_complete(self) -> bool: + """The ServiceInfo has all expected properties.""" + return not (self.text is None or not self._addresses) + + def request( + self, zc: 'Zeroconf', timeout: float, question_type: Optional[DNSQuestionType] = None + ) -> bool: + """Returns true if the service could be discovered on the + network, and updates this object with details discovered. + """ + assert zc.loop is not None and zc.loop.is_running() + if zc.loop == get_running_loop(): + raise RuntimeError("Use AsyncServiceInfo.async_request from the event loop") + return bool(run_coro_with_timeout(self.async_request(zc, timeout, question_type), zc.loop, timeout)) + + async def async_request( + self, zc: 'Zeroconf', timeout: float, question_type: Optional[DNSQuestionType] = None + ) -> bool: + """Returns true if the service could be discovered on the + network, and updates this object with details discovered. + """ + if self.load_from_cache(zc): + return True + + first_request = True + now = current_time_millis() + delay = _LISTENER_TIME + next_ = now + last = now + timeout + await zc.async_wait_for_start() + try: + zc.async_add_listener(self, None) + while not self._is_complete: + if last <= now: + return False + if next_ <= now: + out = self.generate_request_query( + zc, now, question_type or DNSQuestionType.QU if first_request else DNSQuestionType.QM + ) + first_request = False + if not out.questions: + return self.load_from_cache(zc) + zc.async_send(out) + next_ = now + delay + delay *= 2 + next_ += random.randint(*_AVOID_SYNC_DELAY_RANDOM_INTERVAL) + + await zc.async_wait(min(next_, last) - now) + now = current_time_millis() + finally: + zc.async_remove_listener(self) + + return True + + def generate_request_query( + self, zc: 'Zeroconf', now: float, question_type: Optional[DNSQuestionType] = None + ) -> DNSOutgoing: + """Generate the request query.""" + out = DNSOutgoing(_FLAGS_QR_QUERY) + out.add_question_or_one_cache(zc.cache, now, self.name, _TYPE_SRV, _CLASS_IN) + out.add_question_or_one_cache(zc.cache, now, self.name, _TYPE_TXT, _CLASS_IN) + out.add_question_or_all_cache(zc.cache, now, self.server, _TYPE_A, _CLASS_IN) + out.add_question_or_all_cache(zc.cache, now, self.server, _TYPE_AAAA, _CLASS_IN) + if question_type == DNSQuestionType.QU: + for question in out.questions: + question.unicast = True + return out + + def __eq__(self, other: object) -> bool: + """Tests equality of service name""" + return isinstance(other, ServiceInfo) and other.name == self.name + + def __repr__(self) -> str: + """String representation""" + return '%s(%s)' % ( + type(self).__name__, + ', '.join( + '%s=%r' % (name, getattr(self, name)) + for name in ( + 'type', + 'name', + 'addresses', + 'port', + 'weight', + 'priority', + 'server', + 'properties', + 'interface_index', + ) + ), + ) diff --git a/zeroconf/_services/registry.py b/zeroconf/_services/registry.py new file mode 100644 index 00000000..203b3b39 --- /dev/null +++ b/zeroconf/_services/registry.py @@ -0,0 +1,99 @@ +""" Multicast DNS Service Discovery for Python, v0.14-wmcbrine + Copyright 2003 Paul Scott-Murphy, 2014 William McBrine + + This module provides a framework for the use of DNS Service Discovery + using IP multicast. + + This library is free software; you can redistribute it and/or + modify it under the terms of the GNU Lesser General Public + License as published by the Free Software Foundation; either + version 2.1 of the License, or (at your option) any later version. + + This library is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + Lesser General Public License for more details. + + You should have received a copy of the GNU Lesser General Public + License along with this library; if not, write to the Free Software + Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 + USA +""" + +from typing import Dict, List, Optional, Union + + +from .info import ServiceInfo +from .._exceptions import ServiceNameAlreadyRegistered + + +class ServiceRegistry: + """A registry to keep track of services. + + The registry must only be accessed from + the event loop as it is not thread safe. + """ + + def __init__( + self, + ) -> None: + """Create the ServiceRegistry class.""" + self._services: Dict[str, ServiceInfo] = {} + self.types: Dict[str, List] = {} + self.servers: Dict[str, List] = {} + + def async_add(self, info: ServiceInfo) -> None: + """Add a new service to the registry.""" + self._add(info) + + def async_remove(self, info: Union[List[ServiceInfo], ServiceInfo]) -> None: + """Remove a new service from the registry.""" + self._remove(info if isinstance(info, list) else [info]) + + def async_update(self, info: ServiceInfo) -> None: + """Update new service in the registry.""" + self._remove([info]) + self._add(info) + + def async_get_service_infos(self) -> List[ServiceInfo]: + """Return all ServiceInfo.""" + return list(self._services.values()) + + def async_get_info_name(self, name: str) -> Optional[ServiceInfo]: + """Return all ServiceInfo for the name.""" + return self._services.get(name.lower()) + + def async_get_types(self) -> List[str]: + """Return all types.""" + return list(self.types.keys()) + + def async_get_infos_type(self, type_: str) -> List[ServiceInfo]: + """Return all ServiceInfo matching type.""" + return self._async_get_by_index(self.types, type_) + + def async_get_infos_server(self, server: str) -> List[ServiceInfo]: + """Return all ServiceInfo matching server.""" + return self._async_get_by_index(self.servers, server) + + def _async_get_by_index(self, records: Dict[str, List], key: str) -> List[ServiceInfo]: + """Return all ServiceInfo matching the index.""" + return [self._services[name] for name in records.get(key.lower(), [])] + + def _add(self, info: ServiceInfo) -> None: + """Add a new service under the lock.""" + if info.key in self._services: + raise ServiceNameAlreadyRegistered + + self._services[info.key] = info + self.types.setdefault(info.type.lower(), []).append(info.key) + self.servers.setdefault(info.server_key, []).append(info.key) + + def _remove(self, infos: List[ServiceInfo]) -> None: + """Remove a services under the lock.""" + for info in infos: + if info.key not in self._services: + continue + old_service_info = self._services[info.key] + self.types[old_service_info.type.lower()].remove(info.key) + self.servers[old_service_info.server_key].remove(info.key) + del self._services[info.key] diff --git a/zeroconf/_services/types.py b/zeroconf/_services/types.py new file mode 100644 index 00000000..34b000f1 --- /dev/null +++ b/zeroconf/_services/types.py @@ -0,0 +1,83 @@ +""" Multicast DNS Service Discovery for Python, v0.14-wmcbrine + Copyright 2003 Paul Scott-Murphy, 2014 William McBrine + + This module provides a framework for the use of DNS Service Discovery + using IP multicast. + + This library is free software; you can redistribute it and/or + modify it under the terms of the GNU Lesser General Public + License as published by the Free Software Foundation; either + version 2.1 of the License, or (at your option) any later version. + + This library is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + Lesser General Public License for more details. + + You should have received a copy of the GNU Lesser General Public + License along with this library; if not, write to the Free Software + Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 + USA +""" + +import time +from typing import Optional, Set, Tuple, Union + +from .browser import ServiceBrowser +from .._core import Zeroconf +from .._services import ServiceListener +from .._utils.net import IPVersion, InterfaceChoice, InterfacesType +from ..const import _SERVICE_TYPE_ENUMERATION_NAME + + +class ZeroconfServiceTypes(ServiceListener): + """ + Return all of the advertised services on any local networks + """ + + def __init__(self) -> None: + """Keep track of found services in a set.""" + self.found_services: Set[str] = set() + + def add_service(self, zc: Zeroconf, type_: str, name: str) -> None: + """Service added.""" + self.found_services.add(name) + + def update_service(self, zc: Zeroconf, type_: str, name: str) -> None: + """Service updated.""" + + def remove_service(self, zc: Zeroconf, type_: str, name: str) -> None: + """Service removed.""" + + @classmethod + def find( + cls, + zc: Optional[Zeroconf] = None, + timeout: Union[int, float] = 5, + interfaces: InterfacesType = InterfaceChoice.All, + ip_version: Optional[IPVersion] = None, + ) -> Tuple[str, ...]: + """ + Return all of the advertised services on any local networks. + + :param zc: Zeroconf() instance. Pass in if already have an + instance running or if non-default interfaces are needed + :param timeout: seconds to wait for any responses + :param interfaces: interfaces to listen on. + :param ip_version: IP protocol version to use. + :return: tuple of service type strings + """ + local_zc = zc or Zeroconf(interfaces=interfaces, ip_version=ip_version) + listener = cls() + browser = ServiceBrowser(local_zc, _SERVICE_TYPE_ENUMERATION_NAME, listener=listener) + + # wait for responses + time.sleep(timeout) + + browser.cancel() + + # close down anything we opened + if zc is None: + local_zc.close() + + return tuple(sorted(listener.found_services)) diff --git a/zeroconf/_updates.py b/zeroconf/_updates.py new file mode 100644 index 00000000..bc7dcab5 --- /dev/null +++ b/zeroconf/_updates.py @@ -0,0 +1,76 @@ +""" Multicast DNS Service Discovery for Python, v0.14-wmcbrine + Copyright 2003 Paul Scott-Murphy, 2014 William McBrine + + This module provides a framework for the use of DNS Service Discovery + using IP multicast. + + This library is free software; you can redistribute it and/or + modify it under the terms of the GNU Lesser General Public + License as published by the Free Software Foundation; either + version 2.1 of the License, or (at your option) any later version. + + This library is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + Lesser General Public License for more details. + + You should have received a copy of the GNU Lesser General Public + License along with this library; if not, write to the Free Software + Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 + USA +""" + +from typing import List, NamedTuple, Optional, TYPE_CHECKING + + +from ._dns import DNSRecord + + +if TYPE_CHECKING: + from ._core import Zeroconf + + +class RecordUpdate(NamedTuple): + new: DNSRecord + old: Optional[DNSRecord] + + +class RecordUpdateListener: + def update_record( # pylint: disable=no-self-use + self, zc: 'Zeroconf', now: float, record: DNSRecord + ) -> None: + """Update a single record. + + This method is deprecated and will be removed in a future version. + update_records should be implemented instead. + """ + raise RuntimeError("update_record is deprecated and will be removed in a future version.") + + def async_update_records(self, zc: 'Zeroconf', now: float, records: List[RecordUpdate]) -> None: + """Update multiple records in one shot. + + All records that are received in a single packet are passed + to update_records. + + This implementation is a compatiblity shim to ensure older code + that uses RecordUpdateListener as a base class will continue to + get calls to update_record. This method will raise + NotImplementedError in a future version. + + At this point the cache will not have the new records + + Records are passed as a list of RecordUpdate. This + allows consumers of async_update_records to avoid cache lookups. + + This method will be run in the event loop. + """ + for record in records: + self.update_record(zc, now, record[0]) + + def async_update_records_complete(self) -> None: + """Called when a record update has completed for all handlers. + + At this point the cache will have the new records. + + This method will be run in the event loop. + """ diff --git a/zeroconf/_utils/__init__.py b/zeroconf/_utils/__init__.py new file mode 100644 index 00000000..2ef4b15b --- /dev/null +++ b/zeroconf/_utils/__init__.py @@ -0,0 +1,21 @@ +""" Multicast DNS Service Discovery for Python, v0.14-wmcbrine + Copyright 2003 Paul Scott-Murphy, 2014 William McBrine + + This module provides a framework for the use of DNS Service Discovery + using IP multicast. + + This library is free software; you can redistribute it and/or + modify it under the terms of the GNU Lesser General Public + License as published by the Free Software Foundation; either + version 2.1 of the License, or (at your option) any later version. + + This library is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + Lesser General Public License for more details. + + You should have received a copy of the GNU Lesser General Public + License along with this library; if not, write to the Free Software + Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 + USA +""" diff --git a/zeroconf/_utils/asyncio.py b/zeroconf/_utils/asyncio.py new file mode 100644 index 00000000..10b8b3d9 --- /dev/null +++ b/zeroconf/_utils/asyncio.py @@ -0,0 +1,123 @@ +""" Multicast DNS Service Discovery for Python, v0.14-wmcbrine + Copyright 2003 Paul Scott-Murphy, 2014 William McBrine + + This module provides a framework for the use of DNS Service Discovery + using IP multicast. + + This library is free software; you can redistribute it and/or + modify it under the terms of the GNU Lesser General Public + License as published by the Free Software Foundation; either + version 2.1 of the License, or (at your option) any later version. + + This library is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + Lesser General Public License for more details. + + You should have received a copy of the GNU Lesser General Public + License along with this library; if not, write to the Free Software + Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 + USA +""" + +import asyncio +import contextlib +import queue +from typing import Any, Awaitable, Coroutine, List, Optional, Set, cast + +from .time import millis_to_seconds +from ..const import _LOADED_SYSTEM_TIMEOUT + +# The combined timeouts should be lower than _CLOSE_TIMEOUT + _WAIT_FOR_LOOP_TASKS_TIMEOUT +_TASK_AWAIT_TIMEOUT = 1 +_GET_ALL_TASKS_TIMEOUT = 3 +_WAIT_FOR_LOOP_TASKS_TIMEOUT = 3 # Must be larger than _TASK_AWAIT_TIMEOUT + + +def get_best_available_queue() -> queue.Queue: + """Create the best available queue type.""" + if hasattr(queue, "SimpleQueue"): + return queue.SimpleQueue() # type: ignore # pylint: disable=all + return queue.Queue() + + +# Switch to asyncio.wait_for once https://bugs.python.org/issue39032 is fixed +async def wait_event_or_timeout(event: asyncio.Event, timeout: float) -> None: + """Wait for an event or timeout.""" + loop = asyncio.get_event_loop() + future = loop.create_future() + + def _handle_timeout_or_wait_complete(*_: Any) -> None: + if not future.done(): + future.set_result(None) + + timer_handle = loop.call_later(timeout, _handle_timeout_or_wait_complete) + event_wait = loop.create_task(event.wait()) + event_wait.add_done_callback(_handle_timeout_or_wait_complete) + + try: + await future + finally: + timer_handle.cancel() + if not event_wait.done(): + event_wait.cancel() + with contextlib.suppress(asyncio.CancelledError): + await event_wait + + +async def _async_get_all_tasks(loop: asyncio.AbstractEventLoop) -> List[asyncio.Task]: + """Return all tasks running.""" + await asyncio.sleep(0) # flush out any call_soon_threadsafe + # If there are multiple event loops running, all_tasks is not + # safe EVEN WHEN CALLED FROM THE EVENTLOOP + # under PyPy so we have to try a few times. + for _ in range(3): + with contextlib.suppress(RuntimeError): + if hasattr(asyncio, 'all_tasks'): + return asyncio.all_tasks(loop) # type: ignore # pylint: disable=no-member + return asyncio.Task.all_tasks(loop) # type: ignore # pylint: disable=no-member + return [] + + +async def _wait_for_loop_tasks(wait_tasks: Set[asyncio.Task]) -> None: + """Wait for the event loop thread we started to shutdown.""" + await asyncio.wait(wait_tasks, timeout=_TASK_AWAIT_TIMEOUT) + + +async def await_awaitable(aw: Awaitable) -> None: + """Wait on an awaitable and the task it returns.""" + task = await aw + await task + + +def run_coro_with_timeout(aw: Coroutine, loop: asyncio.AbstractEventLoop, timeout: float) -> Any: + """Run a coroutine with a timeout.""" + return asyncio.run_coroutine_threadsafe(aw, loop).result( + millis_to_seconds(timeout) + _LOADED_SYSTEM_TIMEOUT + ) + + +def shutdown_loop(loop: asyncio.AbstractEventLoop) -> None: + """Wait for pending tasks and stop an event loop.""" + pending_tasks = set( + asyncio.run_coroutine_threadsafe(_async_get_all_tasks(loop), loop).result(_GET_ALL_TASKS_TIMEOUT) + ) + pending_tasks -= set(task for task in pending_tasks if task.done()) + if pending_tasks: + asyncio.run_coroutine_threadsafe(_wait_for_loop_tasks(pending_tasks), loop).result( + _WAIT_FOR_LOOP_TASKS_TIMEOUT + ) + loop.call_soon_threadsafe(loop.stop) + + +# Remove the call to _get_running_loop once we drop python 3.6 support +def get_running_loop() -> Optional[asyncio.AbstractEventLoop]: + """Check if an event loop is already running.""" + with contextlib.suppress(RuntimeError): + if hasattr(asyncio, "get_running_loop"): + return cast( + asyncio.AbstractEventLoop, + asyncio.get_running_loop(), # type: ignore # pylint: disable=no-member # noqa + ) + return asyncio._get_running_loop() # pylint: disable=no-member,protected-access + return None diff --git a/zeroconf/_utils/name.py b/zeroconf/_utils/name.py new file mode 100644 index 00000000..c59ac33a --- /dev/null +++ b/zeroconf/_utils/name.py @@ -0,0 +1,157 @@ +""" Multicast DNS Service Discovery for Python, v0.14-wmcbrine + Copyright 2003 Paul Scott-Murphy, 2014 William McBrine + + This module provides a framework for the use of DNS Service Discovery + using IP multicast. + + This library is free software; you can redistribute it and/or + modify it under the terms of the GNU Lesser General Public + License as published by the Free Software Foundation; either + version 2.1 of the License, or (at your option) any later version. + + This library is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + Lesser General Public License for more details. + + You should have received a copy of the GNU Lesser General Public + License along with this library; if not, write to the Free Software + Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 + USA +""" + +from .._exceptions import BadTypeInNameException +from ..const import ( + _HAS_ASCII_CONTROL_CHARS, + _HAS_A_TO_Z, + _HAS_ONLY_A_TO_Z_NUM_HYPHEN, + _HAS_ONLY_A_TO_Z_NUM_HYPHEN_UNDERSCORE, + _LOCAL_TRAILER, + _NONTCP_PROTOCOL_LOCAL_TRAILER, + _TCP_PROTOCOL_LOCAL_TRAILER, +) + + +def service_type_name(type_: str, *, strict: bool = True) -> str: # pylint: disable=too-many-branches + """ + Validate a fully qualified service name, instance or subtype. [rfc6763] + + Returns fully qualified service name. + + Domain names used by mDNS-SD take the following forms: + + . <_tcp|_udp> . local. + . . <_tcp|_udp> . local. + ._sub . . <_tcp|_udp> . local. + + 1) must end with 'local.' + + This is true because we are implementing mDNS and since the 'm' means + multi-cast, the 'local.' domain is mandatory. + + 2) local is preceded with either '_udp.' or '_tcp.' unless + strict is False + + 3) service name precedes <_tcp|_udp> unless + strict is False + + The rules for Service Names [RFC6335] state that they may be no more + than fifteen characters long (not counting the mandatory underscore), + consisting of only letters, digits, and hyphens, must begin and end + with a letter or digit, must not contain consecutive hyphens, and + must contain at least one letter. + + The instance name and sub type may be up to 63 bytes. + + The portion of the Service Instance Name is a user- + friendly name consisting of arbitrary Net-Unicode text [RFC5198]. It + MUST NOT contain ASCII control characters (byte values 0x00-0x1F and + 0x7F) [RFC20] but otherwise is allowed to contain any characters, + without restriction, including spaces, uppercase, lowercase, + punctuation -- including dots -- accented characters, non-Roman text, + and anything else that may be represented using Net-Unicode. + + :param type_: Type, SubType or service name to validate + :return: fully qualified service name (eg: _http._tcp.local.) + """ + if len(type_) > 256: + # https://datatracker.ietf.org/doc/html/rfc6763#section-7.2 + raise BadTypeInNameException("Full name (%s) must be > 256 bytes" % type_) + + if type_.endswith((_TCP_PROTOCOL_LOCAL_TRAILER, _NONTCP_PROTOCOL_LOCAL_TRAILER)): + remaining = type_[: -len(_TCP_PROTOCOL_LOCAL_TRAILER)].split('.') + trailer = type_[-len(_TCP_PROTOCOL_LOCAL_TRAILER) :] + has_protocol = True + elif strict: + raise BadTypeInNameException( + "Type '%s' must end with '%s' or '%s'" + % (type_, _TCP_PROTOCOL_LOCAL_TRAILER, _NONTCP_PROTOCOL_LOCAL_TRAILER) + ) + elif type_.endswith(_LOCAL_TRAILER): + remaining = type_[: -len(_LOCAL_TRAILER)].split('.') + trailer = type_[-len(_LOCAL_TRAILER) + 1 :] + has_protocol = False + else: + raise BadTypeInNameException("Type '%s' must end with '%s'" % (type_, _LOCAL_TRAILER)) + + if strict or has_protocol: + service_name = remaining.pop() + if not service_name: + raise BadTypeInNameException("No Service name found") + + if len(remaining) == 1 and len(remaining[0]) == 0: + raise BadTypeInNameException("Type '%s' must not start with '.'" % type_) + + if service_name[0] != '_': + raise BadTypeInNameException("Service name (%s) must start with '_'" % service_name) + + test_service_name = service_name[1:] + + if strict and len(test_service_name) > 15: + # https://datatracker.ietf.org/doc/html/rfc6763#section-7.2 + raise BadTypeInNameException("Service name (%s) must be <= 15 bytes" % test_service_name) + + if '--' in test_service_name: + raise BadTypeInNameException("Service name (%s) must not contain '--'" % test_service_name) + + if '-' in (test_service_name[0], test_service_name[-1]): + raise BadTypeInNameException( + "Service name (%s) may not start or end with '-'" % test_service_name + ) + + if not _HAS_A_TO_Z.search(test_service_name): + raise BadTypeInNameException( + "Service name (%s) must contain at least one letter (eg: 'A-Z')" % test_service_name + ) + + allowed_characters_re = ( + _HAS_ONLY_A_TO_Z_NUM_HYPHEN if strict else _HAS_ONLY_A_TO_Z_NUM_HYPHEN_UNDERSCORE + ) + + if not allowed_characters_re.search(test_service_name): + raise BadTypeInNameException( + "Service name (%s) must contain only these characters: " + "A-Z, a-z, 0-9, hyphen ('-')%s" % (test_service_name, "" if strict else ", underscore ('_')") + ) + else: + service_name = '' + + if remaining and remaining[-1] == '_sub': + remaining.pop() + if len(remaining) == 0 or len(remaining[0]) == 0: + raise BadTypeInNameException("_sub requires a subtype name") + + if len(remaining) > 1: + remaining = ['.'.join(remaining)] + + if remaining: + length = len(remaining[0].encode('utf-8')) + if length > 63: + raise BadTypeInNameException("Too long: '%s'" % remaining[0]) + + if _HAS_ASCII_CONTROL_CHARS.search(remaining[0]): + raise BadTypeInNameException( + "Ascii control character 0x00-0x1F and 0x7F illegal in '%s'" % remaining[0] + ) + + return service_name + trailer diff --git a/zeroconf/_utils/net.py b/zeroconf/_utils/net.py new file mode 100644 index 00000000..c53ec978 --- /dev/null +++ b/zeroconf/_utils/net.py @@ -0,0 +1,404 @@ +""" Multicast DNS Service Discovery for Python, v0.14-wmcbrine + Copyright 2003 Paul Scott-Murphy, 2014 William McBrine + + This module provides a framework for the use of DNS Service Discovery + using IP multicast. + + This library is free software; you can redistribute it and/or + modify it under the terms of the GNU Lesser General Public + License as published by the Free Software Foundation; either + version 2.1 of the License, or (at your option) any later version. + + This library is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + Lesser General Public License for more details. + + You should have received a copy of the GNU Lesser General Public + License along with this library; if not, write to the Free Software + Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 + USA +""" + +import enum +import errno +import ipaddress +import socket +import struct +import sys +from typing import Any, List, Optional, Tuple, Union, cast + +import ifaddr + +from .._logger import log +from ..const import _IPPROTO_IPV6, _MDNS_ADDR, _MDNS_ADDR6, _MDNS_PORT + + +@enum.unique +class InterfaceChoice(enum.Enum): + Default = 1 + All = 2 + + +InterfacesType = Union[List[Union[str, int, Tuple[Tuple[str, int, int], int]]], InterfaceChoice] + + +@enum.unique +class ServiceStateChange(enum.Enum): + Added = 1 + Removed = 2 + Updated = 3 + + +@enum.unique +class IPVersion(enum.Enum): + V4Only = 1 + V6Only = 2 + All = 3 + + +# utility functions + + +def _is_v6_address(addr: bytes) -> bool: + return len(addr) == 16 + + +def _encode_address(address: str) -> bytes: + is_ipv6 = ':' in address + address_family = socket.AF_INET6 if is_ipv6 else socket.AF_INET + return socket.inet_pton(address_family, address) + + +def get_all_addresses() -> List[str]: + return list(set(addr.ip for iface in ifaddr.get_adapters() for addr in iface.ips if addr.is_IPv4)) + + +def get_all_addresses_v6() -> List[Tuple[Tuple[str, int, int], int]]: + # IPv6 multicast uses positive indexes for interfaces + # TODO: What about multi-address interfaces? + return list( + set((addr.ip, iface.index) for iface in ifaddr.get_adapters() for addr in iface.ips if addr.is_IPv6) + ) + + +def ip6_to_address_and_index(adapters: List[Any], ip: str) -> Tuple[Tuple[str, int, int], int]: + ipaddr = ipaddress.ip_address(ip) + for adapter in adapters: + for adapter_ip in adapter.ips: + # IPv6 addresses are represented as tuples + if isinstance(adapter_ip.ip, tuple) and ipaddress.ip_address(adapter_ip.ip[0]) == ipaddr: + return (cast(Tuple[str, int, int], adapter_ip.ip), cast(int, adapter.index)) + + raise RuntimeError('No adapter found for IP address %s' % ip) + + +def interface_index_to_ip6_address(adapters: List[Any], index: int) -> Tuple[str, int, int]: + for adapter in adapters: + if adapter.index == index: + for adapter_ip in adapter.ips: + # IPv6 addresses are represented as tuples + if isinstance(adapter_ip.ip, tuple): + return cast(Tuple[str, int, int], adapter_ip.ip) + + raise RuntimeError('No adapter found for index %s' % index) + + +def ip6_addresses_to_indexes( + interfaces: List[Union[str, int, Tuple[Tuple[str, int, int], int]]] +) -> List[Tuple[Tuple[str, int, int], int]]: + """Convert IPv6 interface addresses to interface indexes. + + IPv4 addresses are ignored. + + :param interfaces: List of IP addresses and indexes. + :returns: List of indexes. + """ + result = [] + adapters = ifaddr.get_adapters() + + for iface in interfaces: + if isinstance(iface, int): + result.append((interface_index_to_ip6_address(adapters, iface), iface)) + elif isinstance(iface, str) and ipaddress.ip_address(iface).version == 6: + result.append(ip6_to_address_and_index(adapters, iface)) + + return result + + +def normalize_interface_choice( + choice: InterfacesType, ip_version: IPVersion = IPVersion.V4Only +) -> List[Union[str, Tuple[Tuple[str, int, int], int]]]: + """Convert the interfaces choice into internal representation. + + :param choice: `InterfaceChoice` or list of interface addresses or indexes (IPv6 only). + :param ip_address: IP version to use (ignored if `choice` is a list). + :returns: List of IP addresses (for IPv4) and indexes (for IPv6). + """ + result: List[Union[str, Tuple[Tuple[str, int, int], int]]] = [] + if choice is InterfaceChoice.Default: + if ip_version != IPVersion.V4Only: + # IPv6 multicast uses interface 0 to mean the default + result.append((('', 0, 0), 0)) + if ip_version != IPVersion.V6Only: + result.append('0.0.0.0') + elif choice is InterfaceChoice.All: + if ip_version != IPVersion.V4Only: + result.extend(get_all_addresses_v6()) + if ip_version != IPVersion.V6Only: + result.extend(get_all_addresses()) + if not result: + raise RuntimeError( + 'No interfaces to listen on, check that any interfaces have IP version %s' % ip_version + ) + elif isinstance(choice, list): + # First, take IPv4 addresses. + result = [i for i in choice if isinstance(i, str) and ipaddress.ip_address(i).version == 4] + # Unlike IP_ADD_MEMBERSHIP, IPV6_JOIN_GROUP requires interface indexes. + result += ip6_addresses_to_indexes(choice) + else: + raise TypeError("choice must be a list or InterfaceChoice, got %r" % choice) + return result + + +def disable_ipv6_only_or_raise(s: socket.socket) -> None: + """Make V6 sockets work for both V4 and V6 (required for Windows).""" + try: + s.setsockopt(_IPPROTO_IPV6, socket.IPV6_V6ONLY, False) + except OSError: + log.error('Support for dual V4-V6 sockets is not present, use IPVersion.V4 or IPVersion.V6') + raise + + +def set_so_reuseport_if_available(s: socket.socket) -> None: + """Set SO_REUSEADDR on a socket if available.""" + # SO_REUSEADDR should be equivalent to SO_REUSEPORT for + # multicast UDP sockets (p 731, "TCP/IP Illustrated, + # Volume 2"), but some BSD-derived systems require + # SO_REUSEPORT to be specified explicitly. Also, not all + # versions of Python have SO_REUSEPORT available. + # Catch OSError and socket.error for kernel versions <3.9 because lacking + # SO_REUSEPORT support. + if not hasattr(socket, 'SO_REUSEPORT'): + return + + try: + s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1) # pylint: disable=no-member + except OSError as err: + if err.errno != errno.ENOPROTOOPT: + raise + + +def set_mdns_port_socket_options_for_ip_version( + s: socket.socket, bind_addr: Union[Tuple[str], Tuple[str, int, int]], ip_version: IPVersion +) -> None: + """Set ttl/hops and loop for mdns port.""" + if ip_version != IPVersion.V6Only: + ttl = struct.pack(b'B', 255) + loop = struct.pack(b'B', 1) + # OpenBSD needs the ttl and loop values for the IP_MULTICAST_TTL and + # IP_MULTICAST_LOOP socket options as an unsigned char. + try: + s.setsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_TTL, ttl) + s.setsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_LOOP, loop) + except socket.error as e: + if bind_addr[0] != '' or get_errno(e) != errno.EINVAL: # Fails to set on MacOS + raise + + if ip_version != IPVersion.V4Only: + # However, char doesn't work here (at least on Linux) + s.setsockopt(_IPPROTO_IPV6, socket.IPV6_MULTICAST_HOPS, 255) + s.setsockopt(_IPPROTO_IPV6, socket.IPV6_MULTICAST_LOOP, True) + + +def new_socket( + bind_addr: Union[Tuple[str], Tuple[str, int, int]], + port: int = _MDNS_PORT, + ip_version: IPVersion = IPVersion.V4Only, + apple_p2p: bool = False, +) -> socket.socket: + log.debug( + 'Creating new socket with port %s, ip_version %s, apple_p2p %s and bind_addr %r', + port, + ip_version, + apple_p2p, + bind_addr, + ) + socket_family = socket.AF_INET if ip_version == IPVersion.V4Only else socket.AF_INET6 + s = socket.socket(socket_family, socket.SOCK_DGRAM) + + if ip_version == IPVersion.All: + disable_ipv6_only_or_raise(s) + + s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + set_so_reuseport_if_available(s) + + if port == _MDNS_PORT: + set_mdns_port_socket_options_for_ip_version(s, bind_addr, ip_version) + + if apple_p2p: + # SO_RECV_ANYIF = 0x1104 + # https://opensource.apple.com/source/xnu/xnu-4570.41.2/bsd/sys/socket.h + s.setsockopt(socket.SOL_SOCKET, 0x1104, 1) + + s.bind((bind_addr[0], port, *bind_addr[1:])) + log.debug('Created socket %s', s) + return s + + +def add_multicast_member( + listen_socket: socket.socket, + interface: Union[str, Tuple[Tuple[str, int, int], int]], +) -> bool: + # This is based on assumptions in normalize_interface_choice + is_v6 = isinstance(interface, tuple) + err_einval = {errno.EINVAL} + if sys.platform == 'win32': + # No WSAEINVAL definition in typeshed + err_einval |= {cast(Any, errno).WSAEINVAL} # pylint: disable=no-member + log.debug('Adding %r (socket %d) to multicast group', interface, listen_socket.fileno()) + try: + if is_v6: + try: + mdns_addr6_bytes = socket.inet_pton(socket.AF_INET6, _MDNS_ADDR6) + except OSError: + log.info( + 'Unable to translate IPv6 address when adding %s to multicast group, ' + 'this can happen if IPv6 is disabled on the system', + interface, + ) + return False + iface_bin = struct.pack('@I', cast(int, interface[1])) + _value = mdns_addr6_bytes + iface_bin + listen_socket.setsockopt(_IPPROTO_IPV6, socket.IPV6_JOIN_GROUP, _value) + else: + _value = socket.inet_aton(_MDNS_ADDR) + socket.inet_aton(cast(str, interface)) + listen_socket.setsockopt(socket.IPPROTO_IP, socket.IP_ADD_MEMBERSHIP, _value) + except socket.error as e: + _errno = get_errno(e) + if _errno == errno.EADDRINUSE: + log.info( + 'Address in use when adding %s to multicast group, ' + 'it is expected to happen on some systems', + interface, + ) + return False + if _errno == errno.EADDRNOTAVAIL: + log.info( + 'Address not available when adding %s to multicast ' + 'group, it is expected to happen on some systems', + interface, + ) + return False + if _errno in err_einval: + log.info('Interface of %s does not support multicast, ' 'it is expected in WSL', interface) + return False + if _errno == errno.ENOPROTOOPT: + log.info( + 'Failed to set socket option on %s, this can happen if ' + 'the network adapter is in a disconnected state', + interface, + ) + return False + if is_v6 and _errno == errno.ENODEV: + log.info( + 'Address in use when adding %s to multicast group, ' + 'it is expected to happen when the device does not have ipv6', + interface, + ) + return False + raise + return True + + +def new_respond_socket( + interface: Union[str, Tuple[Tuple[str, int, int], int]], + apple_p2p: bool = False, +) -> Optional[socket.socket]: + is_v6 = isinstance(interface, tuple) + respond_socket = new_socket( + ip_version=(IPVersion.V6Only if is_v6 else IPVersion.V4Only), + apple_p2p=apple_p2p, + bind_addr=cast(Tuple[Tuple[str, int, int], int], interface)[0] if is_v6 else (cast(str, interface),), + ) + log.debug('Configuring socket %s with multicast interface %s', respond_socket, interface) + if is_v6: + iface_bin = struct.pack('@I', cast(int, interface[1])) + respond_socket.setsockopt(_IPPROTO_IPV6, socket.IPV6_MULTICAST_IF, iface_bin) + else: + respond_socket.setsockopt( + socket.IPPROTO_IP, socket.IP_MULTICAST_IF, socket.inet_aton(cast(str, interface)) + ) + return respond_socket + + +def create_sockets( + interfaces: InterfacesType = InterfaceChoice.All, + unicast: bool = False, + ip_version: IPVersion = IPVersion.V4Only, + apple_p2p: bool = False, +) -> Tuple[Optional[socket.socket], List[socket.socket]]: + if unicast: + listen_socket = None + else: + listen_socket = new_socket(ip_version=ip_version, apple_p2p=apple_p2p, bind_addr=('',)) + + normalized_interfaces = normalize_interface_choice(interfaces, ip_version) + + # If we are using InterfaceChoice.Default we can use + # a single socket to listen and respond. + if not unicast and interfaces is InterfaceChoice.Default: + for i in normalized_interfaces: + add_multicast_member(cast(socket.socket, listen_socket), i) + return listen_socket, [cast(socket.socket, listen_socket)] + + respond_sockets = [] + + for i in normalized_interfaces: + if not unicast: + if add_multicast_member(cast(socket.socket, listen_socket), i): + respond_socket = new_respond_socket(i, apple_p2p=apple_p2p) + else: + respond_socket = None + else: + respond_socket = new_socket( + port=0, + ip_version=ip_version, + apple_p2p=apple_p2p, + bind_addr=i[0] if isinstance(i, tuple) else (i,), + ) + + if respond_socket is not None: + respond_sockets.append(respond_socket) + + return listen_socket, respond_sockets + + +def get_errno(e: Exception) -> int: + assert isinstance(e, socket.error) + return cast(int, e.args[0]) + + +def can_send_to(ipv6_socket: bool, address: str) -> bool: + """Check if the address type matches the socket type. + + This function does not validate if the address is a valid + ipv6 or ipv4 address. + """ + return ":" in address if ipv6_socket else ":" not in address + + +def autodetect_ip_version(interfaces: InterfacesType) -> IPVersion: + """Auto detect the IP version when it is not provided.""" + if isinstance(interfaces, list): + has_v6 = any( + isinstance(i, int) or (isinstance(i, str) and ipaddress.ip_address(i).version == 6) + for i in interfaces + ) + has_v4 = any(isinstance(i, str) and ipaddress.ip_address(i).version == 4 for i in interfaces) + if has_v4 and has_v6: + return IPVersion.All + if has_v6: + return IPVersion.V6Only + + return IPVersion.V4Only diff --git a/zeroconf/_utils/time.py b/zeroconf/_utils/time.py new file mode 100644 index 00000000..0ba91ead --- /dev/null +++ b/zeroconf/_utils/time.py @@ -0,0 +1,34 @@ +""" Multicast DNS Service Discovery for Python, v0.14-wmcbrine + Copyright 2003 Paul Scott-Murphy, 2014 William McBrine + + This module provides a framework for the use of DNS Service Discovery + using IP multicast. + + This library is free software; you can redistribute it and/or + modify it under the terms of the GNU Lesser General Public + License as published by the Free Software Foundation; either + version 2.1 of the License, or (at your option) any later version. + + This library is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + Lesser General Public License for more details. + + You should have received a copy of the GNU Lesser General Public + License along with this library; if not, write to the Free Software + Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 + USA +""" + + +import time + + +def current_time_millis() -> float: + """Current system time in milliseconds""" + return time.time() * 1000 + + +def millis_to_seconds(millis: float) -> float: + """Convert milliseconds to seconds.""" + return millis / 1000.0 diff --git a/zeroconf/asyncio.py b/zeroconf/asyncio.py new file mode 100644 index 00000000..ef7e7f64 --- /dev/null +++ b/zeroconf/asyncio.py @@ -0,0 +1,267 @@ +""" Multicast DNS Service Discovery for Python, v0.14-wmcbrine + Copyright 2003 Paul Scott-Murphy, 2014 William McBrine + + This module provides a framework for the use of DNS Service Discovery + using IP multicast. + + This library is free software; you can redistribute it and/or + modify it under the terms of the GNU Lesser General Public + License as published by the Free Software Foundation; either + version 2.1 of the License, or (at your option) any later version. + + This library is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + Lesser General Public License for more details. + + You should have received a copy of the GNU Lesser General Public + License along with this library; if not, write to the Free Software + Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 + USA +""" +import asyncio +import contextlib +from types import TracebackType # noqa # used in type hints +from typing import Awaitable, Callable, Dict, List, Optional, Tuple, Type, Union + +from ._core import Zeroconf +from ._dns import DNSQuestionType +from ._services import ServiceListener +from ._services.browser import _ServiceBrowserBase +from ._services.info import ServiceInfo +from ._services.types import ZeroconfServiceTypes +from ._utils.net import IPVersion, InterfaceChoice, InterfacesType +from .const import ( + _BROWSER_TIME, + _MDNS_PORT, + _SERVICE_TYPE_ENUMERATION_NAME, +) + + +__all__ = [ + "AsyncZeroconf", + "AsyncServiceInfo", + "AsyncServiceBrowser", + "AsyncZeroconfServiceTypes", +] + + +class AsyncServiceInfo(ServiceInfo): + """An async version of ServiceInfo.""" + + +class AsyncServiceBrowser(_ServiceBrowserBase): + """Used to browse for a service for specific type(s). + + Constructor parameters are as follows: + + * `zc`: A Zeroconf instance + * `type_`: fully qualified service type name + * `handler`: ServiceListener or Callable that knows how to process ServiceStateChange events + * `listener`: ServiceListener + * `addr`: address to send queries (will default to multicast) + * `port`: port to send queries (will default to mdns 5353) + * `delay`: The initial delay between answering questions + * `question_type`: The type of questions to ask (DNSQuestionType.QM or DNSQuestionType.QU) + + The listener object will have its add_service() and + remove_service() methods called when this browser + discovers changes in the services availability. + """ + + def __init__( + self, + zeroconf: 'Zeroconf', + type_: Union[str, list], + handlers: Optional[Union[ServiceListener, List[Callable[..., None]]]] = None, + listener: Optional[ServiceListener] = None, + addr: Optional[str] = None, + port: int = _MDNS_PORT, + delay: int = _BROWSER_TIME, + question_type: Optional[DNSQuestionType] = None, + ) -> None: + super().__init__(zeroconf, type_, handlers, listener, addr, port, delay, question_type) + self._async_start() + + async def async_cancel(self) -> None: + """Cancel the browser.""" + self._async_cancel() + + +class AsyncZeroconfServiceTypes(ZeroconfServiceTypes): + """An async version of ZeroconfServiceTypes.""" + + @classmethod + async def async_find( + cls, + aiozc: Optional['AsyncZeroconf'] = None, + timeout: Union[int, float] = 5, + interfaces: InterfacesType = InterfaceChoice.All, + ip_version: Optional[IPVersion] = None, + ) -> Tuple[str, ...]: + """ + Return all of the advertised services on any local networks. + + :param aiozc: AsyncZeroconf() instance. Pass in if already have an + instance running or if non-default interfaces are needed + :param timeout: seconds to wait for any responses + :param interfaces: interfaces to listen on. + :param ip_version: IP protocol version to use. + :return: tuple of service type strings + """ + local_zc = aiozc or AsyncZeroconf(interfaces=interfaces, ip_version=ip_version) + listener = cls() + async_browser = AsyncServiceBrowser( + local_zc.zeroconf, _SERVICE_TYPE_ENUMERATION_NAME, listener=listener + ) + + # wait for responses + await asyncio.sleep(timeout) + + await async_browser.async_cancel() + + # close down anything we opened + if aiozc is None: + await local_zc.async_close() + + return tuple(sorted(listener.found_services)) + + +class AsyncZeroconf: + """Implementation of Zeroconf Multicast DNS Service Discovery + + Supports registration, unregistration, queries and browsing. + + The async version is currently a wrapper around the sync version + with I/O being done in the executor for backwards compatibility. + """ + + def __init__( + self, + interfaces: InterfacesType = InterfaceChoice.All, + unicast: bool = False, + ip_version: Optional[IPVersion] = None, + apple_p2p: bool = False, + zc: Optional[Zeroconf] = None, + ) -> None: + """Creates an instance of the Zeroconf class, establishing + multicast communications, listening and reaping threads. + + :param interfaces: :class:`InterfaceChoice` or a list of IP addresses + (IPv4 and IPv6) and interface indexes (IPv6 only). + + IPv6 notes for non-POSIX systems: + * `InterfaceChoice.All` is an alias for `InterfaceChoice.Default` + on Python versions before 3.8. + + Also listening on loopback (``::1``) doesn't work, use a real address. + :param ip_version: IP versions to support. If `choice` is a list, the default is detected + from it. Otherwise defaults to V4 only for backward compatibility. + :param apple_p2p: use AWDL interface (only macOS) + """ + self.zeroconf = zc or Zeroconf( + interfaces=interfaces, + unicast=unicast, + ip_version=ip_version, + apple_p2p=apple_p2p, + ) + self.async_browsers: Dict[ServiceListener, AsyncServiceBrowser] = {} + + async def async_register_service( + self, + info: ServiceInfo, + ttl: Optional[int] = None, + allow_name_change: bool = False, + cooperating_responders: bool = False, + ) -> Awaitable: + """Registers service information to the network with a default TTL. + Zeroconf will then respond to requests for information for that + service. The name of the service may be changed if needed to make + it unique on the network. Additionally multiple cooperating responders + can register the same service on the network for resilience + (if you want this behavior set `cooperating_responders` to `True`). + + The service will be broadcast in a task. This task is returned + and therefore can be awaited if necessary. + """ + return await self.zeroconf.async_register_service( + info, ttl, allow_name_change, cooperating_responders + ) + + async def async_unregister_all_services(self) -> None: + """Unregister all registered services. + + Unlike async_register_service and async_unregister_service, this + method does not return a future and is always expected to be + awaited since its only called at shutdown. + """ + await self.zeroconf.async_unregister_all_services() + + async def async_unregister_service(self, info: ServiceInfo) -> Awaitable: + """Unregister a service. + + The service will be broadcast in a task. This task is returned + and therefore can be awaited if necessary. + """ + return await self.zeroconf.async_unregister_service(info) + + async def async_update_service(self, info: ServiceInfo) -> Awaitable: + """Registers service information to the network with a default TTL. + Zeroconf will then respond to requests for information for that + service. + + The service will be broadcast in a task. This task is returned + and therefore can be awaited if necessary. + """ + return await self.zeroconf.async_update_service(info) + + async def async_close(self) -> None: + """Ends the background threads, and prevent this instance from + servicing further queries.""" + with contextlib.suppress(asyncio.TimeoutError): + await asyncio.wait_for(self.zeroconf.async_wait_for_start(), timeout=1) + await self.async_remove_all_service_listeners() + await self.async_unregister_all_services() + await self.zeroconf._async_close() # pylint: disable=protected-access + + async def async_get_service_info( + self, type_: str, name: str, timeout: int = 3000, question_type: Optional[DNSQuestionType] = None + ) -> Optional[AsyncServiceInfo]: + """Returns network's service information for a particular + name and type, or None if no service matches by the timeout, + which defaults to 3 seconds.""" + info = AsyncServiceInfo(type_, name) + if await info.async_request(self.zeroconf, timeout, question_type): + return info + return None + + async def async_add_service_listener(self, type_: str, listener: ServiceListener) -> None: + """Adds a listener for a particular service type. This object + will then have its add_service and remove_service methods called when + services of that type become available and unavailable.""" + await self.async_remove_service_listener(listener) + self.async_browsers[listener] = AsyncServiceBrowser(self.zeroconf, type_, listener) + + async def async_remove_service_listener(self, listener: ServiceListener) -> None: + """Removes a listener from the set that is currently listening.""" + if listener in self.async_browsers: + await self.async_browsers[listener].async_cancel() + del self.async_browsers[listener] + + async def async_remove_all_service_listeners(self) -> None: + """Removes a listener from the set that is currently listening.""" + await asyncio.gather( + *(self.async_remove_service_listener(listener) for listener in list(self.async_browsers)) + ) + + async def __aenter__(self) -> 'AsyncZeroconf': + return self + + async def __aexit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> Optional[bool]: + await self.async_close() + return None diff --git a/zeroconf/const.py b/zeroconf/const.py new file mode 100644 index 00000000..ff5cc3a2 --- /dev/null +++ b/zeroconf/const.py @@ -0,0 +1,159 @@ +""" Multicast DNS Service Discovery for Python, v0.14-wmcbrine + Copyright 2003 Paul Scott-Murphy, 2014 William McBrine + + This module provides a framework for the use of DNS Service Discovery + using IP multicast. + + This library is free software; you can redistribute it and/or + modify it under the terms of the GNU Lesser General Public + License as published by the Free Software Foundation; either + version 2.1 of the License, or (at your option) any later version. + + This library is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + Lesser General Public License for more details. + + You should have received a copy of the GNU Lesser General Public + License along with this library; if not, write to the Free Software + Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 + USA +""" + +import re +import socket + +# Some timing constants + +_UNREGISTER_TIME = 125 # ms +_CHECK_TIME = 175 # ms +_REGISTER_TIME = 225 # ms +_LISTENER_TIME = 200 # ms +_BROWSER_TIME = 1000 # ms +_DUPLICATE_QUESTION_INTERVAL = _BROWSER_TIME - 1 # ms +_BROWSER_BACKOFF_LIMIT = 3600 # s +_CACHE_CLEANUP_INTERVAL = 10000 # ms +_LOADED_SYSTEM_TIMEOUT = 10 # s +_ONE_SECOND = 1000 # ms + +# If the system is loaded or the event +# loop was blocked by another task that was doing I/O in the loop +# (shouldn't happen but it does in practice) we need to give +# a buffer timeout to ensure a coroutine can finish before +# the future times out + +# Some DNS constants + +_MDNS_ADDR = '224.0.0.251' +_MDNS_ADDR6 = 'ff02::fb' +_MDNS_PORT = 5353 +_DNS_PORT = 53 +_DNS_HOST_TTL = 120 # two minute for host records (A, SRV etc) as-per RFC6762 +_DNS_OTHER_TTL = 4500 # 75 minutes for non-host records (PTR, TXT etc) as-per RFC6762 +# Currently we enforce a minimum TTL for PTR records to avoid +# ServiceBrowsers generating excessive queries refresh queries. +# Apple uses a 15s minimum TTL, however we do not have the same +# level of rate limit and safe guards so we use 1/4 of the recommended value +_DNS_PTR_MIN_TTL = _DNS_OTHER_TTL / 4 + +_DNS_PACKET_HEADER_LEN = 12 + +_MAX_MSG_TYPICAL = 1460 # unused +_MAX_MSG_ABSOLUTE = 8966 + +_FLAGS_QR_MASK = 0x8000 # query response mask +_FLAGS_QR_QUERY = 0x0000 # query +_FLAGS_QR_RESPONSE = 0x8000 # response + +_FLAGS_AA = 0x0400 # Authoritative answer +_FLAGS_TC = 0x0200 # Truncated +_FLAGS_RD = 0x0100 # Recursion desired +_FLAGS_RA = 0x8000 # Recursion available + +_FLAGS_Z = 0x0040 # Zero +_FLAGS_AD = 0x0020 # Authentic data +_FLAGS_CD = 0x0010 # Checking disabled + +_CLASS_IN = 1 +_CLASS_CS = 2 +_CLASS_CH = 3 +_CLASS_HS = 4 +_CLASS_NONE = 254 +_CLASS_ANY = 255 +_CLASS_MASK = 0x7FFF +_CLASS_UNIQUE = 0x8000 + +_TYPE_A = 1 +_TYPE_NS = 2 +_TYPE_MD = 3 +_TYPE_MF = 4 +_TYPE_CNAME = 5 +_TYPE_SOA = 6 +_TYPE_MB = 7 +_TYPE_MG = 8 +_TYPE_MR = 9 +_TYPE_NULL = 10 +_TYPE_WKS = 11 +_TYPE_PTR = 12 +_TYPE_HINFO = 13 +_TYPE_MINFO = 14 +_TYPE_MX = 15 +_TYPE_TXT = 16 +_TYPE_AAAA = 28 +_TYPE_SRV = 33 +_TYPE_NSEC = 47 +_TYPE_ANY = 255 + +# Mapping constants to names + +_CLASSES = { + _CLASS_IN: "in", + _CLASS_CS: "cs", + _CLASS_CH: "ch", + _CLASS_HS: "hs", + _CLASS_NONE: "none", + _CLASS_ANY: "any", +} + +_TYPES = { + _TYPE_A: "a", + _TYPE_NS: "ns", + _TYPE_MD: "md", + _TYPE_MF: "mf", + _TYPE_CNAME: "cname", + _TYPE_SOA: "soa", + _TYPE_MB: "mb", + _TYPE_MG: "mg", + _TYPE_MR: "mr", + _TYPE_NULL: "null", + _TYPE_WKS: "wks", + _TYPE_PTR: "ptr", + _TYPE_HINFO: "hinfo", + _TYPE_MINFO: "minfo", + _TYPE_MX: "mx", + _TYPE_TXT: "txt", + _TYPE_AAAA: "quada", + _TYPE_SRV: "srv", + _TYPE_ANY: "any", + _TYPE_NSEC: "nsec", +} + +_HAS_A_TO_Z = re.compile(r'[A-Za-z]') +_HAS_ONLY_A_TO_Z_NUM_HYPHEN = re.compile(r'^[A-Za-z0-9\-]+$') +_HAS_ONLY_A_TO_Z_NUM_HYPHEN_UNDERSCORE = re.compile(r'^[A-Za-z0-9\-\_]+$') +_HAS_ASCII_CONTROL_CHARS = re.compile(r'[\x00-\x1f\x7f]') + +_EXPIRE_REFRESH_TIME_PERCENT = 75 + +_LOCAL_TRAILER = '.local.' +_TCP_PROTOCOL_LOCAL_TRAILER = '._tcp.local.' +_NONTCP_PROTOCOL_LOCAL_TRAILER = '._udp.local.' + +# https://datatracker.ietf.org/doc/html/rfc6763#section-9 +_SERVICE_TYPE_ENUMERATION_NAME = "_services._dns-sd._udp.local." + +try: + _IPPROTO_IPV6 = socket.IPPROTO_IPV6 +except AttributeError: + # Sigh: https://bugs.python.org/issue29515 + _IPPROTO_IPV6 = 41 diff --git a/zeroconf/test.py b/zeroconf/test.py deleted file mode 100644 index 5a89c9f6..00000000 --- a/zeroconf/test.py +++ /dev/null @@ -1,1370 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- - - -""" Unit tests for zeroconf.py """ - -import copy -import logging -import socket -import struct -import time -import unittest -from threading import Event -from typing import Dict, Optional # noqa # used in type hints -from typing import cast - -from nose.plugins.attrib import attr - -import zeroconf as r -from zeroconf import ( - DNSHinfo, - DNSText, - ServiceBrowser, - ServiceInfo, - ServiceStateChange, - Zeroconf, - ZeroconfServiceTypes, -) - -log = logging.getLogger('zeroconf') -original_logging_level = logging.NOTSET - - -def setup_module(): - global original_logging_level - original_logging_level = log.level - log.setLevel(logging.DEBUG) - - -def teardown_module(): - if original_logging_level != logging.NOTSET: - log.setLevel(original_logging_level) - - -class TestDunder(unittest.TestCase): - def test_dns_text_repr(self): - # There was an issue on Python 3 that prevented DNSText's repr - # from working when the text was longer than 10 bytes - text = DNSText('irrelevant', 0, 0, 0, b'12345678901') - repr(text) - - text = DNSText('irrelevant', 0, 0, 0, b'123') - repr(text) - - def test_dns_hinfo_repr_eq(self): - hinfo = DNSHinfo('irrelevant', r._TYPE_HINFO, 0, 0, 'cpu', 'os') - assert hinfo == hinfo - repr(hinfo) - - def test_dns_pointer_repr(self): - pointer = r.DNSPointer('irrelevant', r._TYPE_PTR, r._CLASS_IN, r._DNS_OTHER_TTL, '123') - repr(pointer) - - def test_dns_address_repr(self): - address = r.DNSAddress('irrelevant', r._TYPE_SOA, r._CLASS_IN, 1, b'a') - repr(address) - - def test_dns_question_repr(self): - question = r.DNSQuestion('irrelevant', r._TYPE_SRV, r._CLASS_IN | r._CLASS_UNIQUE) - repr(question) - assert not question != question - - def test_dns_service_repr(self): - service = r.DNSService('irrelevant', r._TYPE_SRV, r._CLASS_IN, r._DNS_HOST_TTL, 0, 0, 80, 'a') - repr(service) - - def test_dns_record_abc(self): - record = r.DNSRecord('irrelevant', r._TYPE_SRV, r._CLASS_IN, r._DNS_HOST_TTL) - self.assertRaises(r.AbstractMethodException, record.__eq__, record) - self.assertRaises(r.AbstractMethodException, record.write, None) - - def test_dns_record_reset_ttl(self): - record = r.DNSRecord('irrelevant', r._TYPE_SRV, r._CLASS_IN, r._DNS_HOST_TTL) - time.sleep(1) - record2 = r.DNSRecord('irrelevant', r._TYPE_SRV, r._CLASS_IN, r._DNS_HOST_TTL) - now = r.current_time_millis() - - assert record.created != record2.created - assert record.get_remaining_ttl(now) != record2.get_remaining_ttl(now) - - record.reset_ttl(record2) - - assert record.ttl == record2.ttl - assert record.created == record2.created - assert record.get_remaining_ttl(now) == record2.get_remaining_ttl(now) - - def test_service_info_dunder(self): - type_ = "_test-srvc-type._tcp.local." - name = "xxxyyy" - registration_name = "%s.%s" % (name, type_) - info = ServiceInfo( - type_, registration_name, socket.inet_aton("10.0.1.2"), 80, 0, 0, b'', "ash-2.local." - ) - - assert not info != info - repr(info) - - def test_service_info_text_properties_not_given(self): - type_ = "_test-srvc-type._tcp.local." - name = "xxxyyy" - registration_name = "%s.%s" % (name, type_) - info = ServiceInfo( - type_=type_, - name=registration_name, - address=socket.inet_aton("10.0.1.2"), - port=80, - server="ash-2.local.", - ) - - assert isinstance(info.text, bytes) - repr(info) - - def test_dns_outgoing_repr(self): - dns_outgoing = r.DNSOutgoing(r._FLAGS_QR_QUERY) - repr(dns_outgoing) - - -class PacketGeneration(unittest.TestCase): - def test_parse_own_packet_simple(self): - generated = r.DNSOutgoing(0) - r.DNSIncoming(generated.packet()) - - def test_parse_own_packet_simple_unicast(self): - generated = r.DNSOutgoing(0, False) - r.DNSIncoming(generated.packet()) - - def test_parse_own_packet_flags(self): - generated = r.DNSOutgoing(r._FLAGS_QR_QUERY) - r.DNSIncoming(generated.packet()) - - def test_parse_own_packet_question(self): - generated = r.DNSOutgoing(r._FLAGS_QR_QUERY) - generated.add_question(r.DNSQuestion("testname.local.", r._TYPE_SRV, r._CLASS_IN)) - r.DNSIncoming(generated.packet()) - - def test_parse_own_packet_response(self): - generated = r.DNSOutgoing(r._FLAGS_QR_RESPONSE) - generated.add_answer_at_time( - r.DNSService("æøå.local.", r._TYPE_SRV, r._CLASS_IN, r._DNS_HOST_TTL, 0, 0, 80, "foo.local."), 0 - ) - parsed = r.DNSIncoming(generated.packet()) - self.assertEqual(len(generated.answers), 1) - self.assertEqual(len(generated.answers), len(parsed.answers)) - - def test_match_question(self): - generated = r.DNSOutgoing(r._FLAGS_QR_QUERY) - question = r.DNSQuestion("testname.local.", r._TYPE_SRV, r._CLASS_IN) - generated.add_question(question) - parsed = r.DNSIncoming(generated.packet()) - self.assertEqual(len(generated.questions), 1) - self.assertEqual(len(generated.questions), len(parsed.questions)) - self.assertEqual(question, parsed.questions[0]) - - def test_suppress_answer(self): - query_generated = r.DNSOutgoing(r._FLAGS_QR_QUERY) - question = r.DNSQuestion("testname.local.", r._TYPE_SRV, r._CLASS_IN) - query_generated.add_question(question) - answer1 = r.DNSService( - "testname1.local.", r._TYPE_SRV, r._CLASS_IN, r._DNS_HOST_TTL, 0, 0, 80, "foo.local." - ) - staleanswer2 = r.DNSService( - "testname2.local.", r._TYPE_SRV, r._CLASS_IN, r._DNS_HOST_TTL / 2, 0, 0, 80, "foo.local." - ) - answer2 = r.DNSService( - "testname2.local.", r._TYPE_SRV, r._CLASS_IN, r._DNS_HOST_TTL, 0, 0, 80, "foo.local." - ) - query_generated.add_answer_at_time(answer1, 0) - query_generated.add_answer_at_time(staleanswer2, 0) - query = r.DNSIncoming(query_generated.packet()) - - # Should be suppressed - response = r.DNSOutgoing(r._FLAGS_QR_RESPONSE) - response.add_answer(query, answer1) - assert len(response.answers) == 0 - - # Should not be suppressed, TTL in query is too short - response.add_answer(query, answer2) - assert len(response.answers) == 1 - - # Should not be suppressed, name is different - tmp = copy.copy(answer1) - tmp.name = "testname3.local." - response.add_answer(query, tmp) - assert len(response.answers) == 2 - - # Should not be suppressed, type is different - tmp = copy.copy(answer1) - tmp.type = r._TYPE_A - response.add_answer(query, tmp) - assert len(response.answers) == 3 - - # Should not be suppressed, class is different - tmp = copy.copy(answer1) - tmp.class_ = r._CLASS_NONE - response.add_answer(query, tmp) - assert len(response.answers) == 4 - - # ::TODO:: could add additional tests for DNSAddress, DNSHinfo, DNSPointer, DNSText, DNSService - - def test_dns_hinfo(self): - generated = r.DNSOutgoing(0) - generated.add_additional_answer(DNSHinfo('irrelevant', r._TYPE_HINFO, 0, 0, 'cpu', 'os')) - parsed = r.DNSIncoming(generated.packet()) - answer = cast(r.DNSHinfo, parsed.answers[0]) - self.assertEqual(answer.cpu, u'cpu') - self.assertEqual(answer.os, u'os') - - generated = r.DNSOutgoing(0) - generated.add_additional_answer(DNSHinfo('irrelevant', r._TYPE_HINFO, 0, 0, 'cpu', 'x' * 257)) - self.assertRaises(r.NamePartTooLongException, generated.packet) - - -class PacketForm(unittest.TestCase): - def test_transaction_id(self): - """ID must be zero in a DNS-SD packet""" - generated = r.DNSOutgoing(r._FLAGS_QR_QUERY) - bytes = generated.packet() - id = bytes[0] << 8 | bytes[1] - self.assertEqual(id, 0) - - def test_query_header_bits(self): - generated = r.DNSOutgoing(r._FLAGS_QR_QUERY) - bytes = generated.packet() - flags = bytes[2] << 8 | bytes[3] - self.assertEqual(flags, 0x0) - - def test_response_header_bits(self): - generated = r.DNSOutgoing(r._FLAGS_QR_RESPONSE) - bytes = generated.packet() - flags = bytes[2] << 8 | bytes[3] - self.assertEqual(flags, 0x8000) - - def test_numbers(self): - generated = r.DNSOutgoing(r._FLAGS_QR_RESPONSE) - bytes = generated.packet() - (num_questions, num_answers, num_authorities, num_additionals) = struct.unpack('!4H', bytes[4:12]) - self.assertEqual(num_questions, 0) - self.assertEqual(num_answers, 0) - self.assertEqual(num_authorities, 0) - self.assertEqual(num_additionals, 0) - - def test_numbers_questions(self): - generated = r.DNSOutgoing(r._FLAGS_QR_RESPONSE) - question = r.DNSQuestion("testname.local.", r._TYPE_SRV, r._CLASS_IN) - for i in range(10): - generated.add_question(question) - bytes = generated.packet() - (num_questions, num_answers, num_authorities, num_additionals) = struct.unpack('!4H', bytes[4:12]) - self.assertEqual(num_questions, 10) - self.assertEqual(num_answers, 0) - self.assertEqual(num_authorities, 0) - self.assertEqual(num_additionals, 0) - - -class Names(unittest.TestCase): - def test_long_name(self): - generated = r.DNSOutgoing(r._FLAGS_QR_RESPONSE) - question = r.DNSQuestion( - "this.is.a.very.long.name.with.lots.of.parts.in.it.local.", r._TYPE_SRV, r._CLASS_IN - ) - generated.add_question(question) - r.DNSIncoming(generated.packet()) - - def test_exceedingly_long_name(self): - generated = r.DNSOutgoing(r._FLAGS_QR_RESPONSE) - name = "%slocal." % ("part." * 1000) - question = r.DNSQuestion(name, r._TYPE_SRV, r._CLASS_IN) - generated.add_question(question) - r.DNSIncoming(generated.packet()) - - def test_exceedingly_long_name_part(self): - name = "%s.local." % ("a" * 1000) - generated = r.DNSOutgoing(r._FLAGS_QR_RESPONSE) - question = r.DNSQuestion(name, r._TYPE_SRV, r._CLASS_IN) - generated.add_question(question) - self.assertRaises(r.NamePartTooLongException, generated.packet) - - def test_same_name(self): - name = "paired.local." - generated = r.DNSOutgoing(r._FLAGS_QR_RESPONSE) - question = r.DNSQuestion(name, r._TYPE_SRV, r._CLASS_IN) - generated.add_question(question) - generated.add_question(question) - r.DNSIncoming(generated.packet()) - - def test_lots_of_names(self): - - # instantiate a zeroconf instance - zc = Zeroconf(interfaces=['127.0.0.1']) - - # create a bunch of servers - type_ = "_my-service._tcp.local." - name = 'a wonderful service' - server_count = 300 - self.generate_many_hosts(zc, type_, name, server_count) - - # verify that name changing works - self.verify_name_change(zc, type_, name, server_count) - - # we are going to monkey patch the zeroconf send to check packet sizes - old_send = zc.send - - longest_packet_len = 0 - longest_packet = None # type: Optional[r.DNSOutgoing] - - def send(out, addr=r._MDNS_ADDR, port=r._MDNS_PORT): - """Sends an outgoing packet.""" - packet = out.packet() - nonlocal longest_packet_len, longest_packet - if longest_packet_len < len(packet): - longest_packet_len = len(packet) - longest_packet = out - old_send(out, addr=addr, port=port) - - # monkey patch the zeroconf send - setattr(zc, "send", send) - - # dummy service callback - def on_service_state_change(zeroconf, service_type, state_change, name): - pass - - # start a browser - browser = ServiceBrowser(zc, type_, [on_service_state_change]) - - # wait until the browse request packet has maxed out in size - sleep_count = 0 - while sleep_count < 100 and longest_packet_len < r._MAX_MSG_ABSOLUTE - 100: - sleep_count += 1 - time.sleep(0.1) - - browser.cancel() - time.sleep(0.5) - - import zeroconf - - zeroconf.log.debug('sleep_count %d, sized %d', sleep_count, longest_packet_len) - - # now the browser has sent at least one request, verify the size - assert longest_packet_len <= r._MAX_MSG_ABSOLUTE - assert longest_packet_len >= r._MAX_MSG_ABSOLUTE - 100 - - # mock zeroconf's logger warning() and debug() - from unittest.mock import patch - - patch_warn = patch('zeroconf.log.warning') - patch_debug = patch('zeroconf.log.debug') - mocked_log_warn = patch_warn.start() - mocked_log_debug = patch_debug.start() - - # now that we have a long packet in our possession, let's verify the - # exception handling. - out = longest_packet - assert out is not None - out.data.append(b'\0' * 1000) - - # mock the zeroconf logger and check for the correct logging backoff - call_counts = mocked_log_warn.call_count, mocked_log_debug.call_count - # try to send an oversized packet - zc.send(out) - assert mocked_log_warn.call_count == call_counts[0] + 1 - assert mocked_log_debug.call_count == call_counts[0] - zc.send(out) - assert mocked_log_warn.call_count == call_counts[0] + 1 - assert mocked_log_debug.call_count == call_counts[0] + 1 - - # force a receive of an oversized packet - packet = out.packet() - s = zc._respond_sockets[0] - - # mock the zeroconf logger and check for the correct logging backoff - call_counts = mocked_log_warn.call_count, mocked_log_debug.call_count - # force receive on oversized packet - s.sendto(packet, 0, (r._MDNS_ADDR, r._MDNS_PORT)) - s.sendto(packet, 0, (r._MDNS_ADDR, r._MDNS_PORT)) - time.sleep(2.0) - zeroconf.log.debug( - 'warn %d debug %d was %s', mocked_log_warn.call_count, mocked_log_debug.call_count, call_counts - ) - assert mocked_log_debug.call_count > call_counts[0] - - # close our zeroconf which will close the sockets - zc.close() - - # pop the big chunk off the end of the data and send on a closed socket - out.data.pop() - zc._GLOBAL_DONE = False - - # mock the zeroconf logger and check for the correct logging backoff - call_counts = mocked_log_warn.call_count, mocked_log_debug.call_count - # send on a closed socket (force a socket error) - zc.send(out) - zeroconf.log.debug( - 'warn %d debug %d was %s', mocked_log_warn.call_count, mocked_log_debug.call_count, call_counts - ) - assert mocked_log_warn.call_count > call_counts[0] - assert mocked_log_debug.call_count > call_counts[0] - zc.send(out) - zeroconf.log.debug( - 'warn %d debug %d was %s', mocked_log_warn.call_count, mocked_log_debug.call_count, call_counts - ) - assert mocked_log_debug.call_count > call_counts[0] + 2 - - mocked_log_warn.stop() - mocked_log_debug.stop() - - def verify_name_change(self, zc, type_, name, number_hosts): - desc = {'path': '/~paulsm/'} - info_service = ServiceInfo( - type_, '%s.%s' % (name, type_), socket.inet_aton("10.0.1.2"), 80, 0, 0, desc, "ash-2.local." - ) - - # verify name conflict - self.assertRaises(r.NonUniqueNameException, zc.register_service, info_service) - - zc.register_service(info_service, allow_name_change=True) - assert info_service.name.split('.')[0] == '%s-%d' % (name, number_hosts + 1) - - def generate_many_hosts(self, zc, type_, name, number_hosts): - records_per_server = 2 - block_size = 25 - number_hosts = int(((number_hosts - 1) / block_size + 1)) * block_size - for i in range(1, number_hosts + 1): - next_name = name if i == 1 else '%s-%d' % (name, i) - self.generate_host(zc, next_name, type_) - if i % block_size == 0: - sleep_count = 0 - while sleep_count < 40 and i * records_per_server > len(zc.cache.entries_with_name(type_)): - sleep_count += 1 - time.sleep(0.05) - - @staticmethod - def generate_host(zc, host_name, type_): - name = '.'.join((host_name, type_)) - out = r.DNSOutgoing(r._FLAGS_QR_RESPONSE | r._FLAGS_AA) - out.add_answer_at_time(r.DNSPointer(type_, r._TYPE_PTR, r._CLASS_IN, r._DNS_OTHER_TTL, name), 0) - out.add_answer_at_time( - r.DNSService(type_, r._TYPE_SRV, r._CLASS_IN, r._DNS_HOST_TTL, 0, 0, 80, name), 0 - ) - zc.send(out) - - -class Framework(unittest.TestCase): - def test_launch_and_close(self): - rv = r.Zeroconf(interfaces=r.InterfaceChoice.All) - rv.close() - rv = r.Zeroconf(interfaces=r.InterfaceChoice.Default) - rv.close() - - @unittest.skipIf(not socket.has_ipv6, 'Requires IPv6') - @attr('IPv6') - def test_launch_and_close_v4_v6(self): - rv = r.Zeroconf(interfaces=r.InterfaceChoice.All, ip_version=r.IPVersion.All) - rv.close() - rv = r.Zeroconf(interfaces=r.InterfaceChoice.Default, ip_version=r.IPVersion.All) - rv.close() - - @unittest.skipIf(not socket.has_ipv6, 'Requires IPv6') - @attr('IPv6') - def test_launch_and_close_v6_only(self): - rv = r.Zeroconf(interfaces=r.InterfaceChoice.All, ip_version=r.IPVersion.V6Only) - rv.close() - rv = r.Zeroconf(interfaces=r.InterfaceChoice.Default, ip_version=r.IPVersion.V6Only) - rv.close() - - def test_handle_response(self): - def mock_incoming_msg(service_state_change: r.ServiceStateChange) -> r.DNSIncoming: - ttl = 120 - generated = r.DNSOutgoing(r._FLAGS_QR_RESPONSE) - - if service_state_change == r.ServiceStateChange.Updated: - generated.add_answer_at_time( - r.DNSText(service_name, r._TYPE_TXT, r._CLASS_IN | r._CLASS_UNIQUE, ttl, service_text), 0 - ) - return r.DNSIncoming(generated.packet()) - - if service_state_change == r.ServiceStateChange.Removed: - ttl = 0 - - generated.add_answer_at_time( - r.DNSPointer(service_type, r._TYPE_PTR, r._CLASS_IN | r._CLASS_UNIQUE, ttl, service_name), 0 - ) - generated.add_answer_at_time( - r.DNSService( - service_name, r._TYPE_SRV, r._CLASS_IN | r._CLASS_UNIQUE, ttl, 0, 0, 80, service_server - ), - 0, - ) - generated.add_answer_at_time( - r.DNSText(service_name, r._TYPE_TXT, r._CLASS_IN | r._CLASS_UNIQUE, ttl, service_text), 0 - ) - generated.add_answer_at_time( - r.DNSAddress( - service_server, - r._TYPE_A, - r._CLASS_IN | r._CLASS_UNIQUE, - ttl, - socket.inet_aton(service_address), - ), - 0, - ) - - return r.DNSIncoming(generated.packet()) - - service_name = 'name._type._tcp.local.' - service_type = '_type._tcp.local.' - service_server = 'ash-2.local.' - service_text = b'path=/~paulsm/' - service_address = '10.0.1.2' - - zeroconf = r.Zeroconf(interfaces=['127.0.0.1']) - - try: - # service added - zeroconf.handle_response(mock_incoming_msg(r.ServiceStateChange.Added)) - dns_text = zeroconf.cache.get_by_details(service_name, r._TYPE_TXT, r._CLASS_IN) - assert dns_text is not None - assert cast(DNSText, dns_text).text == service_text # service_text is b'path=/~paulsm/' - - # https://tools.ietf.org/html/rfc6762#section-10.2 - # Instead of merging this new record additively into the cache in addition - # to any previous records with the same name, rrtype, and rrclass, - # all old records with that name, rrtype, and rrclass that were received - # more than one second ago are declared invalid, - # and marked to expire from the cache in one second. - time.sleep(1.1) - - # service updated. currently only text record can be updated - service_text = b'path=/~humingchun/' - zeroconf.handle_response(mock_incoming_msg(r.ServiceStateChange.Updated)) - dns_text = zeroconf.cache.get_by_details(service_name, r._TYPE_TXT, r._CLASS_IN) - assert dns_text is not None - assert cast(DNSText, dns_text).text == service_text # service_text is b'path=/~humingchun/' - - time.sleep(1.1) - - # service removed - zeroconf.handle_response(mock_incoming_msg(r.ServiceStateChange.Removed)) - dns_text = zeroconf.cache.get_by_details(service_name, r._TYPE_TXT, r._CLASS_IN) - assert dns_text is None - - finally: - zeroconf.close() - - -class Exceptions(unittest.TestCase): - - browser = None # type: Zeroconf - - @classmethod - def setUpClass(cls): - cls.browser = Zeroconf(interfaces=['127.0.0.1']) - - @classmethod - def tearDownClass(cls): - cls.browser.close() - del cls.browser - - def test_bad_service_info_name(self): - self.assertRaises(r.BadTypeInNameException, self.browser.get_service_info, "type", "type_not") - - def test_bad_service_names(self): - bad_names_to_try = ( - '', - 'local', - '_tcp.local.', - '_udp.local.', - '._udp.local.', - '_@._tcp.local.', - '_A@._tcp.local.', - '_x--x._tcp.local.', - '_-x._udp.local.', - '_x-._tcp.local.', - '_22._udp.local.', - '_2-2._tcp.local.', - '_1234567890-abcde._udp.local.', - '\x00._x._udp.local.', - ) - for name in bad_names_to_try: - self.assertRaises(r.BadTypeInNameException, self.browser.get_service_info, name, 'x.' + name) - - def test_good_instance_names(self): - good_names_to_try = ( - '.._x._tcp.local.', - 'x.sub._http._tcp.local.', - '6d86f882b90facee9170ad3439d72a4d6ee9f511._zget._http._tcp.local.', - ) - for name in good_names_to_try: - r.service_type_name(name) - - def test_bad_types(self): - bad_names_to_try = ( - '._x._tcp.local.', - 'a' * 64 + '._sub._http._tcp.local.', - 'a' * 62 + u'â._sub._http._tcp.local.', - ) - for name in bad_names_to_try: - self.assertRaises(r.BadTypeInNameException, r.service_type_name, name) - - def test_bad_sub_types(self): - bad_names_to_try = ( - '_sub._http._tcp.local.', - '._sub._http._tcp.local.', - '\x7f._sub._http._tcp.local.', - '\x1f._sub._http._tcp.local.', - ) - for name in bad_names_to_try: - self.assertRaises(r.BadTypeInNameException, r.service_type_name, name) - - def test_good_service_names(self): - good_names_to_try = ( - '_x._tcp.local.', - '_x._udp.local.', - '_12345-67890-abc._udp.local.', - 'x._sub._http._tcp.local.', - 'a' * 63 + '._sub._http._tcp.local.', - 'a' * 61 + u'â._sub._http._tcp.local.', - ) - for name in good_names_to_try: - r.service_type_name(name) - - r.service_type_name('_one_two._tcp.local.', allow_underscores=True) - - def test_invalid_addresses(self): - type_ = "_test-srvc-type._tcp.local." - name = "xxxyyy" - registration_name = "%s.%s" % (name, type_) - - bad = ('127.0.0.1', '::1', 42) - for addr in bad: - self.assertRaisesRegex( - TypeError, - 'Addresses must be bytes', - ServiceInfo, - type_, - registration_name, - port=80, - addresses=[addr], - ) - - -class TestDnsIncoming(unittest.TestCase): - def test_incoming_exception_handling(self): - generated = r.DNSOutgoing(0) - packet = generated.packet() - packet = packet[:8] + b'deadbeef' + packet[8:] - parsed = r.DNSIncoming(packet) - parsed = r.DNSIncoming(packet) - assert parsed.valid is False - - def test_incoming_unknown_type(self): - generated = r.DNSOutgoing(0) - answer = r.DNSAddress('a', r._TYPE_SOA, r._CLASS_IN, 1, b'a') - generated.add_additional_answer(answer) - packet = generated.packet() - parsed = r.DNSIncoming(packet) - assert len(parsed.answers) == 0 - assert parsed.is_query() != parsed.is_response() - - def test_incoming_ipv6(self): - addr = "2606:2800:220:1:248:1893:25c8:1946" # example.com - packed = socket.inet_pton(socket.AF_INET6, addr) - generated = r.DNSOutgoing(0) - answer = r.DNSAddress('domain', r._TYPE_AAAA, r._CLASS_IN, 1, packed) - generated.add_additional_answer(answer) - packet = generated.packet() - parsed = r.DNSIncoming(packet) - record = parsed.answers[0] - assert isinstance(record, r.DNSAddress) - assert record.address == packed - - -class TestRegistrar(unittest.TestCase): - def test_ttl(self): - - # instantiate a zeroconf instance - zc = Zeroconf(interfaces=['127.0.0.1']) - - # service definition - type_ = "_test-srvc-type._tcp.local." - name = "xxxyyy" - registration_name = "%s.%s" % (name, type_) - - desc = {'path': '/~paulsm/'} - info = ServiceInfo( - type_, registration_name, socket.inet_aton("10.0.1.2"), 80, 0, 0, desc, "ash-2.local." - ) - - # we are going to monkey patch the zeroconf send to check packet sizes - old_send = zc.send - - nbr_answers = nbr_additionals = nbr_authorities = 0 - - def get_ttl(record_type): - if expected_ttl is not None: - return expected_ttl - elif record_type in [r._TYPE_A, r._TYPE_SRV]: - return r._DNS_HOST_TTL - else: - return r._DNS_OTHER_TTL - - def send(out, addr=r._MDNS_ADDR, port=r._MDNS_PORT): - """Sends an outgoing packet.""" - nonlocal nbr_answers, nbr_additionals, nbr_authorities - - for answer, time_ in out.answers: - nbr_answers += 1 - assert answer.ttl == get_ttl(answer.type) - for answer in out.additionals: - nbr_additionals += 1 - assert answer.ttl == get_ttl(answer.type) - for answer in out.authorities: - nbr_authorities += 1 - assert answer.ttl == get_ttl(answer.type) - old_send(out, addr=addr, port=port) - - # monkey patch the zeroconf send - setattr(zc, "send", send) - - # register service with default TTL - expected_ttl = None - zc.register_service(info) - assert nbr_answers == 12 and nbr_additionals == 0 and nbr_authorities == 3 - nbr_answers = nbr_additionals = nbr_authorities = 0 - - # query - query = r.DNSOutgoing(r._FLAGS_QR_QUERY | r._FLAGS_AA) - query.add_question(r.DNSQuestion(info.type, r._TYPE_PTR, r._CLASS_IN)) - query.add_question(r.DNSQuestion(info.name, r._TYPE_SRV, r._CLASS_IN)) - query.add_question(r.DNSQuestion(info.name, r._TYPE_TXT, r._CLASS_IN)) - query.add_question(r.DNSQuestion(info.server, r._TYPE_A, r._CLASS_IN)) - zc.handle_query(r.DNSIncoming(query.packet()), r._MDNS_ADDR, r._MDNS_PORT) - assert nbr_answers == 4 and nbr_additionals == 4 and nbr_authorities == 0 - nbr_answers = nbr_additionals = nbr_authorities = 0 - - # unregister - expected_ttl = 0 - zc.unregister_service(info) - assert nbr_answers == 12 and nbr_additionals == 0 and nbr_authorities == 0 - nbr_answers = nbr_additionals = nbr_authorities = 0 - - # register service with custom TTL - expected_ttl = r._DNS_HOST_TTL * 2 - assert expected_ttl != r._DNS_HOST_TTL - zc.register_service(info, ttl=expected_ttl) - assert nbr_answers == 12 and nbr_additionals == 0 and nbr_authorities == 3 - nbr_answers = nbr_additionals = nbr_authorities = 0 - - # query - query = r.DNSOutgoing(r._FLAGS_QR_QUERY | r._FLAGS_AA) - query.add_question(r.DNSQuestion(info.type, r._TYPE_PTR, r._CLASS_IN)) - query.add_question(r.DNSQuestion(info.name, r._TYPE_SRV, r._CLASS_IN)) - query.add_question(r.DNSQuestion(info.name, r._TYPE_TXT, r._CLASS_IN)) - query.add_question(r.DNSQuestion(info.server, r._TYPE_A, r._CLASS_IN)) - zc.handle_query(r.DNSIncoming(query.packet()), r._MDNS_ADDR, r._MDNS_PORT) - assert nbr_answers == 4 and nbr_additionals == 4 and nbr_authorities == 0 - nbr_answers = nbr_additionals = nbr_authorities = 0 - - # unregister - expected_ttl = 0 - zc.unregister_service(info) - assert nbr_answers == 12 and nbr_additionals == 0 and nbr_authorities == 0 - nbr_answers = nbr_additionals = nbr_authorities = 0 - - -class TestDNSCache(unittest.TestCase): - def test_order(self): - record1 = r.DNSAddress('a', r._TYPE_SOA, r._CLASS_IN, 1, b'a') - record2 = r.DNSAddress('a', r._TYPE_SOA, r._CLASS_IN, 1, b'b') - cache = r.DNSCache() - cache.add(record1) - cache.add(record2) - entry = r.DNSEntry('a', r._TYPE_SOA, r._CLASS_IN) - cached_record = cache.get(entry) - self.assertEqual(cached_record, record2) - - -class ServiceTypesQuery(unittest.TestCase): - def test_integration_with_listener(self): - - type_ = "_test-srvc-type._tcp.local." - name = "xxxyyy" - registration_name = "%s.%s" % (name, type_) - - zeroconf_registrar = Zeroconf(interfaces=['127.0.0.1']) - desc = {'path': '/~paulsm/'} - info = ServiceInfo( - type_, registration_name, socket.inet_aton("10.0.1.2"), 80, 0, 0, desc, "ash-2.local." - ) - zeroconf_registrar.register_service(info) - - try: - service_types = ZeroconfServiceTypes.find(interfaces=['127.0.0.1'], timeout=0.5) - assert type_ in service_types - service_types = ZeroconfServiceTypes.find(zc=zeroconf_registrar, timeout=0.5) - assert type_ in service_types - - finally: - zeroconf_registrar.close() - - @unittest.skipIf(not socket.has_ipv6, 'Requires IPv6') - def test_integration_with_listener_v6_records(self): - - type_ = "_test-srvc-type._tcp.local." - name = "xxxyyy" - registration_name = "%s.%s" % (name, type_) - addr = "2606:2800:220:1:248:1893:25c8:1946" # example.com - - zeroconf_registrar = Zeroconf(interfaces=['127.0.0.1']) - desc = {'path': '/~paulsm/'} - info = ServiceInfo( - type_, registration_name, socket.inet_pton(socket.AF_INET6, addr), 80, 0, 0, desc, "ash-2.local." - ) - zeroconf_registrar.register_service(info) - - try: - service_types = ZeroconfServiceTypes.find(interfaces=['127.0.0.1'], timeout=0.5) - assert type_ in service_types - service_types = ZeroconfServiceTypes.find(zc=zeroconf_registrar, timeout=0.5) - assert type_ in service_types - - finally: - zeroconf_registrar.close() - - @unittest.skipIf(not socket.has_ipv6, 'Requires IPv6') - @attr('IPv6') - def test_integration_with_listener_ipv6(self): - - type_ = "_test-srvc-type._tcp.local." - name = "xxxyyy" - registration_name = "%s.%s" % (name, type_) - - zeroconf_registrar = Zeroconf(ip_version=r.IPVersion.V6Only) - desc = {'path': '/~paulsm/'} - info = ServiceInfo( - type_, registration_name, socket.inet_aton("10.0.1.2"), 80, 0, 0, desc, "ash-2.local." - ) - zeroconf_registrar.register_service(info) - - try: - service_types = ZeroconfServiceTypes.find(ip_version=r.IPVersion.V6Only, timeout=0.5) - assert type_ in service_types, service_types - service_types = ZeroconfServiceTypes.find(zc=zeroconf_registrar, timeout=0.5) - assert type_ in service_types, service_types - - finally: - zeroconf_registrar.close() - - def test_integration_with_subtype_and_listener(self): - subtype_ = "_subtype._sub" - type_ = "_type._tcp.local." - name = "xxxyyy" - # Note: discovery returns only DNS-SD type not subtype - discovery_type = "%s.%s" % (subtype_, type_) - registration_name = "%s.%s" % (name, type_) - - zeroconf_registrar = Zeroconf(interfaces=['127.0.0.1']) - desc = {'path': '/~paulsm/'} - info = ServiceInfo( - discovery_type, registration_name, socket.inet_aton("10.0.1.2"), 80, 0, 0, desc, "ash-2.local." - ) - zeroconf_registrar.register_service(info) - - try: - service_types = ZeroconfServiceTypes.find(interfaces=['127.0.0.1'], timeout=0.5) - assert discovery_type in service_types - service_types = ZeroconfServiceTypes.find(zc=zeroconf_registrar, timeout=0.5) - assert discovery_type in service_types - - finally: - zeroconf_registrar.close() - - -class ListenerTest(unittest.TestCase): - def test_integration_with_listener_class(self): - - service_added = Event() - service_removed = Event() - service_updated = Event() - - subtype_name = "My special Subtype" - type_ = "_http._tcp.local." - subtype = subtype_name + "._sub." + type_ - name = "xxxyyyæøå" - registration_name = "%s.%s" % (name, subtype) - - class MyListener(r.ServiceListener): - def add_service(self, zeroconf, type, name): - zeroconf.get_service_info(type, name) - service_added.set() - - def remove_service(self, zeroconf, type, name): - service_removed.set() - - def update_service(self, zeroconf, type, name): - pass - - class MySubListener(r.ServiceListener): - def add_service(self, zeroconf, type, name): - pass - - def remove_service(self, zeroconf, type, name): - pass - - def update_service(self, zeroconf, type, name): - service_updated.set() - - listener = MyListener() - zeroconf_browser = Zeroconf(interfaces=['127.0.0.1']) - zeroconf_browser.add_service_listener(subtype, listener) - - properties = dict( - prop_none=None, - prop_string=b'a_prop', - prop_float=1.0, - prop_blank=b'a blanked string', - prop_true=1, - prop_false=0, - ) - - zeroconf_registrar = Zeroconf(interfaces=['127.0.0.1']) - desc = {'path': '/~paulsm/'} # type: Dict - desc.update(properties) - addresses = [socket.inet_aton("10.0.1.2")] - if socket.has_ipv6: - addresses.append(socket.inet_pton(socket.AF_INET6, "2001:db8::1")) - info_service = ServiceInfo( - subtype, registration_name, port=80, properties=desc, server="ash-2.local.", addresses=addresses - ) - zeroconf_registrar.register_service(info_service) - - try: - service_added.wait(1) - assert service_added.is_set() - - # short pause to allow multicast timers to expire - time.sleep(3) - - # clear the answer cache to force query - for record in zeroconf_browser.cache.entries(): - zeroconf_browser.cache.remove(record) - - # get service info without answer cache - info = zeroconf_browser.get_service_info(type_, registration_name) - assert info is not None - assert info.properties[b'prop_none'] is False - assert info.properties[b'prop_string'] == properties['prop_string'] - assert info.properties[b'prop_float'] is False - assert info.properties[b'prop_blank'] == properties['prop_blank'] - assert info.properties[b'prop_true'] is True - assert info.properties[b'prop_false'] is False - assert info.addresses == addresses[:1] # no V6 by default - all_addresses = info.addresses_by_version(r.IPVersion.All) - assert all_addresses == addresses, all_addresses - - info = zeroconf_browser.get_service_info(subtype, registration_name) - assert info is not None - assert info.properties[b'prop_none'] is False - - # Begin material test addition - sublistener = MySubListener() - zeroconf_browser.add_service_listener(registration_name, sublistener) - properties['prop_blank'] = b'an updated string' - desc.update(properties) - info_service = ServiceInfo( - subtype, registration_name, socket.inet_aton("10.0.1.2"), 80, 0, 0, desc, "ash-2.local." - ) - zeroconf_registrar.update_service(info_service) - service_updated.wait(1) - assert service_updated.is_set() - - info = zeroconf_browser.get_service_info(type_, registration_name) - assert info is not None - assert info.properties[b'prop_blank'] == properties['prop_blank'] - # End material test addition - - zeroconf_registrar.unregister_service(info_service) - service_removed.wait(1) - assert service_removed.is_set() - - finally: - zeroconf_registrar.close() - zeroconf_browser.remove_service_listener(listener) - zeroconf_browser.close() - - -class TestServiceBrowser(unittest.TestCase): - def test_update_record(self): - - service_name = 'name._type._tcp.local.' - service_type = '_type._tcp.local.' - service_server = 'ash-2.local.' - service_text = b'path=/~paulsm/' - service_address = '10.0.1.2' - - service_added = False - service_removed = False - service_updated_count = 0 - service_add_event = Event() - service_removed_event = Event() - service_updated_event = Event() - - class MyServiceListener(r.ServiceListener): - def add_service(self, zc, type_, name) -> None: - nonlocal service_added - service_added = True - service_add_event.set() - - def remove_service(self, zc, type_, name) -> None: - nonlocal service_added, service_removed - service_added = False - service_removed = True - service_removed_event.set() - - def update_service(self, zc, type_, name) -> None: - nonlocal service_updated_count - service_updated_count += 1 - - service_info = zc.get_service_info(type_, name) - assert service_info.text == service_text - service_updated_event.set() - - def mock_incoming_msg(service_state_change: r.ServiceStateChange) -> r.DNSIncoming: - ttl = 120 - generated = r.DNSOutgoing(r._FLAGS_QR_RESPONSE) - - if service_state_change == r.ServiceStateChange.Updated: - generated.add_answer_at_time( - r.DNSText(service_name, r._TYPE_TXT, r._CLASS_IN | r._CLASS_UNIQUE, ttl, service_text), 0 - ) - return r.DNSIncoming(generated.packet()) - - if service_state_change == r.ServiceStateChange.Removed: - ttl = 0 - - generated.add_answer_at_time( - r.DNSPointer(service_type, r._TYPE_PTR, r._CLASS_IN | r._CLASS_UNIQUE, ttl, service_name), 0 - ) - generated.add_answer_at_time( - r.DNSService( - service_name, r._TYPE_SRV, r._CLASS_IN | r._CLASS_UNIQUE, ttl, 0, 0, 80, service_server - ), - 0, - ) - generated.add_answer_at_time( - r.DNSText(service_name, r._TYPE_TXT, r._CLASS_IN | r._CLASS_UNIQUE, ttl, service_text), 0 - ) - generated.add_answer_at_time( - r.DNSAddress( - service_server, - r._TYPE_A, - r._CLASS_IN | r._CLASS_UNIQUE, - ttl, - socket.inet_aton(service_address), - ), - 0, - ) - - return r.DNSIncoming(generated.packet()) - - zeroconf = r.Zeroconf(interfaces=['127.0.0.1']) - service_browser = r.ServiceBrowser(zeroconf, service_type, listener=MyServiceListener()) - - try: - # service added - zeroconf.handle_response(mock_incoming_msg(r.ServiceStateChange.Added)) - service_add_event.wait(1) - service_updated_event.wait(1) - assert service_added is True - assert service_updated_count == 1 - assert service_removed is False - - # service updated. currently only text record can be updated - service_updated_event.clear() - service_text = b'path=/~humingchun/' - zeroconf.handle_response(mock_incoming_msg(r.ServiceStateChange.Updated)) - service_updated_event.wait(1) - assert service_added is True - assert service_updated_count == 2 - assert service_removed is False - - # service removed - zeroconf.handle_response(mock_incoming_msg(r.ServiceStateChange.Removed)) - service_removed_event.wait(1) - assert service_added is False - assert service_updated_count == 2 - assert service_removed is True - - finally: - service_browser.cancel() - zeroconf.remove_all_service_listeners() - zeroconf.close() - - -def test_backoff(): - got_query = Event() - - type_ = "_http._tcp.local." - zeroconf_browser = Zeroconf(interfaces=['127.0.0.1']) - - # we are going to monkey patch the zeroconf send to check query transmission - old_send = zeroconf_browser.send - - time_offset = 0.0 - start_time = time.time() * 1000 - initial_query_interval = r._BROWSER_TIME / 1000 - - def current_time_millis(): - """Current system time in milliseconds""" - return start_time + time_offset * 1000 - - def send(out, addr=r._MDNS_ADDR, port=r._MDNS_PORT): - """Sends an outgoing packet.""" - got_query.set() - old_send(out, addr=addr, port=port) - - # monkey patch the zeroconf send - setattr(zeroconf_browser, "send", send) - - # monkey patch the zeroconf current_time_millis - r.current_time_millis = current_time_millis - - # monkey patch the backoff limit to prevent test running forever - r._BROWSER_BACKOFF_LIMIT = 10 # seconds - - # dummy service callback - def on_service_state_change(zeroconf, service_type, state_change, name): - pass - - browser = ServiceBrowser(zeroconf_browser, type_, [on_service_state_change]) - - try: - # Test that queries are sent at increasing intervals - sleep_count = 0 - next_query_interval = 0.0 - expected_query_time = 0.0 - while True: - zeroconf_browser.notify_all() - sleep_count += 1 - got_query.wait(0.1) - if time_offset == expected_query_time: - assert got_query.is_set() - got_query.clear() - if next_query_interval == r._BROWSER_BACKOFF_LIMIT: - # Only need to test up to the point where we've seen a query - # after the backoff limit has been hit - break - elif next_query_interval == 0: - next_query_interval = initial_query_interval - expected_query_time = initial_query_interval - else: - next_query_interval = min(2 * next_query_interval, r._BROWSER_BACKOFF_LIMIT) - expected_query_time += next_query_interval - else: - assert not got_query.is_set() - time_offset += initial_query_interval - - finally: - browser.cancel() - zeroconf_browser.close() - - -def test_integration(): - service_added = Event() - service_removed = Event() - unexpected_ttl = Event() - got_query = Event() - - type_ = "_http._tcp.local." - registration_name = "xxxyyy.%s" % type_ - - def on_service_state_change(zeroconf, service_type, state_change, name): - if name == registration_name: - if state_change is ServiceStateChange.Added: - service_added.set() - elif state_change is ServiceStateChange.Removed: - service_removed.set() - - zeroconf_browser = Zeroconf(interfaces=['127.0.0.1']) - - # we are going to monkey patch the zeroconf send to check packet sizes - old_send = zeroconf_browser.send - - time_offset = 0.0 - - def current_time_millis(): - """Current system time in milliseconds""" - return time.time() * 1000 + time_offset * 1000 - - expected_ttl = r._DNS_HOST_TTL - - nbr_answers = 0 - - def send(out, addr=r._MDNS_ADDR, port=r._MDNS_PORT): - """Sends an outgoing packet.""" - pout = r.DNSIncoming(out.packet()) - nonlocal nbr_answers - for answer in pout.answers: - nbr_answers += 1 - if not answer.ttl > expected_ttl / 2: - unexpected_ttl.set() - - got_query.set() - old_send(out, addr=addr, port=port) - - # monkey patch the zeroconf send - setattr(zeroconf_browser, "send", send) - - # monkey patch the zeroconf current_time_millis - r.current_time_millis = current_time_millis - - # monkey patch the backoff limit to ensure we always get one query every 1/4 of the DNS TTL - r._BROWSER_BACKOFF_LIMIT = int(expected_ttl / 4) - - service_added = Event() - service_removed = Event() - - browser = ServiceBrowser(zeroconf_browser, type_, [on_service_state_change]) - - zeroconf_registrar = Zeroconf(interfaces=['127.0.0.1']) - desc = {'path': '/~paulsm/'} - info = ServiceInfo(type_, registration_name, socket.inet_aton("10.0.1.2"), 80, 0, 0, desc, "ash-2.local.") - zeroconf_registrar.register_service(info) - - try: - service_added.wait(1) - assert service_added.is_set() - - # Test that we receive queries containing answers only if the remaining TTL - # is greater than half the original TTL - sleep_count = 0 - test_iterations = 50 - while nbr_answers < test_iterations: - # Increase simulated time shift by 1/4 of the TTL in seconds - time_offset += expected_ttl / 4 - zeroconf_browser.notify_all() - sleep_count += 1 - got_query.wait(0.1) - got_query.clear() - # Prevent the test running indefinitely in an error condition - assert sleep_count < test_iterations * 4 - assert not unexpected_ttl.is_set() - - # Don't remove service, allow close() to cleanup - - finally: - zeroconf_registrar.close() - service_removed.wait(1) - assert service_removed.is_set() - browser.cancel() - zeroconf_browser.close() - - -def test_multiple_addresses(): - type_ = "_http._tcp.local." - registration_name = "xxxyyy.%s" % type_ - desc = {'path': '/~paulsm/'} - address_parsed = "10.0.1.2" - address = socket.inet_aton(address_parsed) - - # Old way - info = ServiceInfo(type_, registration_name, address, 80, 0, 0, desc, "ash-2.local.") - - assert info.address == address - assert info.addresses == [address] - - # Updating works - address2 = socket.inet_aton("10.0.1.3") - info.address = address2 - - assert info.address == address2 - assert info.addresses == [address2] - - info.address = None - - assert info.address is None - assert info.addresses == [] - - info.addresses = [address2] - - assert info.address == address2 - assert info.addresses == [address2] - - # Compatibility way - info = ServiceInfo(type_, registration_name, [address, address], 80, 0, 0, desc, "ash-2.local.") - - assert info.addresses == [address, address] - - # New kwarg way - info = ServiceInfo( - type_, registration_name, None, 80, 0, 0, desc, "ash-2.local.", addresses=[address, address] - ) - - assert info.addresses == [address, address] - - if socket.has_ipv6: - address_v6_parsed = "2001:db8::1" - address_v6 = socket.inet_pton(socket.AF_INET6, address_v6_parsed) - info = ServiceInfo(type_, registration_name, [address, address_v6], 80, 0, 0, desc, "ash-2.local.") - assert info.addresses == [address] - assert info.addresses_by_version(r.IPVersion.All) == [address, address_v6] - assert info.addresses_by_version(r.IPVersion.V4Only) == [address] - assert info.addresses_by_version(r.IPVersion.V6Only) == [address_v6] - assert info.parsed_addresses() == [address_parsed, address_v6_parsed] - assert info.parsed_addresses(r.IPVersion.V4Only) == [address_parsed] - assert info.parsed_addresses(r.IPVersion.V6Only) == [address_v6_parsed] - - -def test_ptr_optimization(): - - # instantiate a zeroconf instance - zc = Zeroconf(interfaces=['127.0.0.1']) - - # service definition - type_ = "_test-srvc-type._tcp.local." - name = "xxxyyy" - registration_name = "%s.%s" % (name, type_) - - desc = {'path': '/~paulsm/'} - info = ServiceInfo(type_, registration_name, socket.inet_aton("10.0.1.2"), 80, 0, 0, desc, "ash-2.local.") - - # we are going to monkey patch the zeroconf send to check packet sizes - old_send = zc.send - - nbr_answers = nbr_additionals = nbr_authorities = 0 - has_srv = has_txt = has_a = False - - def send(out, addr=r._MDNS_ADDR, port=r._MDNS_PORT): - """Sends an outgoing packet.""" - nonlocal nbr_answers, nbr_additionals, nbr_authorities - nonlocal has_srv, has_txt, has_a - - nbr_answers += len(out.answers) - nbr_authorities += len(out.authorities) - for answer in out.additionals: - nbr_additionals += 1 - if answer.type == r._TYPE_SRV: - has_srv = True - elif answer.type == r._TYPE_TXT: - has_txt = True - elif answer.type == r._TYPE_A: - has_a = True - - old_send(out, addr=addr, port=port) - - # monkey patch the zeroconf send - setattr(zc, "send", send) - - # register - zc.register_service(info) - nbr_answers = nbr_additionals = nbr_authorities = 0 - - # query - query = r.DNSOutgoing(r._FLAGS_QR_QUERY | r._FLAGS_AA) - query.add_question(r.DNSQuestion(info.type, r._TYPE_PTR, r._CLASS_IN)) - zc.handle_query(r.DNSIncoming(query.packet()), r._MDNS_ADDR, r._MDNS_PORT) - assert nbr_answers == 1 and nbr_additionals == 3 and nbr_authorities == 0 - assert has_srv and has_txt and has_a - - # unregister - zc.unregister_service(info)