diff --git a/.coveragerc b/.coveragerc
new file mode 100644
index 00000000..7648cf0d
--- /dev/null
+++ b/.coveragerc
@@ -0,0 +1,5 @@
+[report]
+exclude_lines =
+ pragma: no cover
+ if TYPE_CHECKING:
+ if sys.version_info
diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
new file mode 100644
index 00000000..01b181fe
--- /dev/null
+++ b/.github/workflows/ci.yml
@@ -0,0 +1,64 @@
+name: CI
+
+on:
+ push:
+ branches:
+ - master
+ pull_request:
+ branches:
+ - "**"
+
+jobs:
+ build:
+ runs-on: ${{ matrix.os }}
+ strategy:
+ matrix:
+ os: [ubuntu-latest, macos-latest, windows-latest]
+ python-version: [3.6, 3.7, 3.8, 3.9, pypy3]
+ include:
+ - os: ubuntu-latest
+ venvcmd: . env/bin/activate
+ - os: macos-latest
+ venvcmd: . env/bin/activate
+ - os: windows-latest
+ venvcmd: env\Scripts\Activate.ps1
+ steps:
+ - uses: actions/checkout@v2
+ - name: Set up Python ${{ matrix.python-version }}
+ uses: actions/setup-python@v2
+ with:
+ python-version: ${{ matrix.python-version }}
+ - uses: actions/cache@v2
+ id: cache
+ with:
+ path: env
+ key: ${{ runner.os }}-pip-${{ matrix.python-version }}-${{ hashFiles('**/Makefile') }}-${{ hashFiles('**/requirements-dev.txt') }}
+ restore-keys: |
+ ${{ runner.os }}-pip-${{ matrix.python-version }}-${{ hashFiles('**/Makefile') }}
+ - name: Install dependencies
+ if: steps.cache.outputs.cache-hit != 'true'
+ run: |
+ python -m venv env
+ ${{ matrix.venvcmd }}
+ pip install --upgrade -r requirements-dev.txt pytest-github-actions-annotate-failures
+ - name: Run flake8
+ if: ${{ runner.os == 'Linux' && matrix.python-version != 'pypy3' }}
+ run: |
+ ${{ matrix.venvcmd }}
+ make flake8
+ - name: Run mypy
+ if: ${{ runner.os == 'Linux' && matrix.python-version != 'pypy3' }}
+ run: |
+ ${{ matrix.venvcmd }}
+ make mypy
+ - name: Run black_check
+ if: ${{ runner.os == 'Linux' && matrix.python-version != 'pypy3' }}
+ run: |
+ ${{ matrix.venvcmd }}
+ make black_check
+ - name: Run tests
+ run: |
+ ${{ matrix.venvcmd }}
+ make test_coverage
+ - name: Report coverage to Codecov
+ uses: codecov/codecov-action@v1
diff --git a/.gitignore b/.gitignore
index ddf8a0d7..0af9ce1e 100644
--- a/.gitignore
+++ b/.gitignore
@@ -11,3 +11,6 @@ Thumbs.db
.cache
.mypy_cache/
docs/_build/
+.vscode
+/dist/
+/zeroconf.egg-info/
diff --git a/.travis.yml b/.travis.yml
deleted file mode 100644
index acf47981..00000000
--- a/.travis.yml
+++ /dev/null
@@ -1,19 +0,0 @@
-language: python
-python:
- - "3.5"
- - "3.6"
- - "3.7"
- - "3.8"
- - "pypy3.5"
- - "pypy3"
-install:
- - pip install -r requirements-dev.txt
- # mypy can't be installed on pypy
- - if [[ "${TRAVIS_PYTHON_VERSION}" != "pypy"* ]] ; then pip install mypy ; fi
- - if [[ "${TRAVIS_PYTHON_VERSION}" != *"3.5"* && "${TRAVIS_PYTHON_VERSION}" != "pypy"* ]] ; then
- pip install black ; fi
-script:
- # no IPv6 support in Travis :(
- - make TEST_ARGS='-a "!IPv6"' ci
-after_success:
- - coveralls
diff --git a/Makefile b/Makefile
index ea5f8c64..88980ff2 100644
--- a/Makefile
+++ b/Makefile
@@ -1,30 +1,28 @@
+# version: 1.1
+
.PHONY: all virtualenv
MAX_LINE_LENGTH=110
PYTHON_IMPLEMENTATION:=$(shell python -c "import sys;import platform;sys.stdout.write(platform.python_implementation())")
PYTHON_VERSION:=$(shell python -c "import sys;sys.stdout.write('%d.%d' % sys.version_info[:2])")
-TEST_ARGS=
LINT_TARGETS:=flake8
ifneq ($(findstring PyPy,$(PYTHON_IMPLEMENTATION)),PyPy)
- LINT_TARGETS:=$(LINT_TARGETS) mypy
-endif
-ifeq ($(or $(findstring 3.5,$(PYTHON_VERSION)),$(findstring PyPy,$(PYTHON_IMPLEMENTATION))),)
- LINT_TARGETS:=$(LINT_TARGETS) black_check
+ LINT_TARGETS:=$(LINT_TARGETS) mypy black_check pylint
endif
virtualenv: ./env/requirements.built
env:
- virtualenv env
+ python -m venv env
./env/requirements.built: env requirements-dev.txt
./env/bin/pip install -r requirements-dev.txt
cp requirements-dev.txt ./env/requirements.built
.PHONY: ci
-ci: test_coverage lint
+ci: lint test_coverage
.PHONY: lint
lint: $(LINT_TARGETS)
@@ -32,18 +30,24 @@ lint: $(LINT_TARGETS)
flake8:
flake8 --max-line-length=$(MAX_LINE_LENGTH) setup.py examples zeroconf
+pylint:
+ pylint zeroconf
+
.PHONY: black_check
black_check:
black --check setup.py examples zeroconf
mypy:
- mypy examples/*.py zeroconf/*.py
+# --no-warn-redundant-casts --no-warn-unused-ignores is needed since we support multiple python versions
+# We should be able to drop this once python 3.6 goes away
+ mypy --no-warn-redundant-casts --no-warn-unused-ignores examples/*.py zeroconf
test:
- nosetests -v $(TEST_ARGS)
+ pytest --durations=20 --timeout=60 -v tests
test_coverage:
- nosetests -v --with-coverage --cover-package=zeroconf $(TEST_ARGS)
+ pytest --durations=20 --timeout=60 -v --cov=zeroconf --cov-branch --cov-report xml --cov-report html --cov-report term-missing tests
autopep8:
autopep8 --max-line-length=$(MAX_LINE_LENGTH) -i setup.py examples zeroconf
+
diff --git a/README.rst b/README.rst
index 432d3c4b..ea7c3d1c 100644
--- a/README.rst
+++ b/README.rst
@@ -1,14 +1,14 @@
python-zeroconf
===============
-.. image:: https://travis-ci.org/jstasiak/python-zeroconf.svg?branch=master
- :target: https://travis-ci.org/jstasiak/python-zeroconf
-
+.. image:: https://github.com/jstasiak/python-zeroconf/workflows/CI/badge.svg
+ :target: https://github.com/jstasiak/python-zeroconf?query=workflow%3ACI+branch%3Amaster
+
.. image:: https://img.shields.io/pypi/v/zeroconf.svg
:target: https://pypi.python.org/pypi/zeroconf
-.. image:: https://img.shields.io/coveralls/jstasiak/python-zeroconf.svg
- :target: https://coveralls.io/r/jstasiak/python-zeroconf
+.. image:: https://codecov.io/gh/jstasiak/python-zeroconf/branch/master/graph/badge.svg
+ :target: https://codecov.io/gh/jstasiak/python-zeroconf
`Documentation `_.
@@ -37,15 +37,15 @@ Compared to some other Zeroconf/Bonjour/Avahi Python packages, python-zeroconf:
* isn't tied to Bonjour or Avahi
* doesn't use D-Bus
-* doesn't force you to use particular event loop or Twisted
+* doesn't force you to use particular event loop or Twisted (asyncio is used under the hood but not required)
* is pip-installable
* has PyPI distribution
Python compatibility
--------------------
-* CPython 3.5+
-* PyPy3 5.8+
+* CPython 3.6+
+* PyPy3 7.2+
Versioning
----------
@@ -59,8 +59,14 @@ This project's versions follow the following pattern: MAJOR.MINOR.PATCH.
Status
------
-There are some people using this package. I don't actively use it and as such
-any help I can offer with regard to any issues is very limited.
+This project is actively maintained.
+
+Traffic Reduction
+-----------------
+
+Before version 0.32, most traffic reduction techniques described in https://datatracker.ietf.org/doc/html/rfc6762#section-7
+where not implemented which could lead to excessive network traffic. It is highly recommended that version 0.32 or later
+is used if this is a concern.
IPv6 support
------------
@@ -69,8 +75,6 @@ IPv6 support is relatively new and currently limited, specifically:
* `InterfaceChoice.All` is an alias for `InterfaceChoice.Default` on non-POSIX
systems.
-* On Windows specific interfaces can only be requested as interface indexes,
- not as IP addresses.
* Dual-stack IPv6 sockets are used, which may not be supported everywhere (some
BSD variants do not have them).
* Listening on localhost (`::1`) does not work. Help with understanding why is
@@ -134,6 +138,655 @@ See examples directory for more.
Changelog
=========
+0.36.7
+======
+
+* Improved performance of responding to queries (#994) (#996) (#997) @bdraco
+* Improved log message when receiving an invalid or corrupt packet (#998) @bdraco
+
+0.36.6
+======
+
+* Improved performance of sending outgoing packets (#990) @bdraco
+
+0.36.5
+======
+
+* Reduced memory usage for incoming and outgoing packets (#987) @bdraco
+
+0.36.4
+======
+
+* Improved performance of constructing outgoing packets (#978) (#979) @bdraco
+* Deferred parsing of incoming packets when it can be avoided (#983) @bdraco
+
+0.36.3
+======
+
+* Improved performance of parsing incoming packets (#975) @bdraco
+
+0.36.2
+======
+
+* Include NSEC records for non-existent types when responding with addresses (#972) (#971) @bdraco
+ Implements RFC6762 sec 6.2 (http://datatracker.ietf.org/doc/html/rfc6762#section-6.2)
+
+0.36.1
+======
+
+* Skip goodbye packets for addresses when there is another service registered with the same name (#968) @bdraco
+
+ If a ServiceInfo that used the same server name as another ServiceInfo
+ was unregistered, goodbye packets would be sent for the addresses and
+ would cause the other service to be seen as offline.
+* Fixed equality and hash for dns records with the unique bit (#969) @bdraco
+
+ These records should have the same hash and equality since
+ the unique bit (cache flush bit) is not considered when adding or removing
+ the records from the cache.
+
+0.36.0
+======
+
+Technically backwards incompatible:
+
+* Fill incomplete IPv6 tuples to avoid WinError on windows (#965) @lokesh2019
+
+ Fixed #932
+
+0.35.1
+======
+
+* Only reschedule types if the send next time changes (#958) @bdraco
+
+ When the PTR response was seen again, the timer was being canceled and
+ rescheduled even if the timer was for the same time. While this did
+ not cause any breakage, it is quite inefficient.
+* Cache DNS record and question hashes (#960) @bdraco
+
+ The hash was being recalculated every time the object
+ was being used in a set or dict. Since the hashes are
+ effectively immutable, we only calculate them once now.
+
+0.35.0
+======
+
+* Reduced chance of accidental synchronization of ServiceInfo requests (#955) @bdraco
+* Sort aggregated responses to increase chance of name compression (#954) @bdraco
+
+Technically backwards incompatible:
+
+* Send unicast replies on the same socket the query was received (#952) @bdraco
+
+ When replying to a QU question, we do not know if the sending host is reachable
+ from all of the sending sockets. We now avoid this problem by replying via
+ the receiving socket. This was the existing behavior when `InterfaceChoice.Default`
+ is set.
+
+ This change extends the unicast relay behavior to used with `InterfaceChoice.Default`
+ to apply when `InterfaceChoice.All` or interfaces are explicitly passed when
+ instantiating a `Zeroconf` instance.
+
+ Fixes #951
+
+0.34.3
+======
+
+* Fix sending immediate multicast responses (#949) @bdraco
+
+0.34.2
+======
+
+* Coalesce aggregated multicast answers (#945) @bdraco
+
+ When the random delay is shorter than the last scheduled response,
+ answers are now added to the same outgoing time group.
+
+ This reduces traffic when we already know we will be sending a group of answers
+ inside the random delay window described in
+ datatracker.ietf.org/doc/html/rfc6762#section-6.3
+* Ensure ServiceInfo requests can be answered inside the default timeout with network protection (#946) @bdraco
+
+ Adjust the time windows to ensure responses that have triggered the
+ protection against against excessive packet flooding due to
+ software bugs or malicious attack described in RFC6762 section 6
+ can respond in under 1350ms to ensure ServiceInfo can ask two
+ questions within the default timeout of 3000ms
+
+0.34.1
+======
+
+* Ensure multicast aggregation sends responses within 620ms (#942) @bdraco
+
+ Responses that trigger the protection against against excessive
+ packet flooding due to software bugs or malicious attack described
+ in RFC6762 section 6 could cause the multicast aggregation response
+ to be delayed longer than 620ms (The maximum random delay of 120ms
+ and 500ms additional for aggregation).
+
+ Only responses that trigger the protection are delayed longer than 620ms
+
+0.34.0
+======
+
+* Implemented Multicast Response Aggregation (#940) @bdraco
+
+ Responses are now aggregated when possible per rules in RFC6762
+ section 6.4
+
+ Responses that trigger the protection against against excessive
+ packet flooding due to software bugs or malicious attack described
+ in RFC6762 section 6 are delayed instead of discarding as it was
+ causing responders that implement Passive Observation Of Failures
+ (POOF) to evict the records.
+
+ Probe responses are now always sent immediately as there were cases
+ where they would fail to be answered in time to defend a name.
+
+0.33.4
+======
+
+* Ensure zeroconf can be loaded when the system disables IPv6 (#933) @che0
+
+0.33.3
+======
+
+* Added support for forward dns compression pointers (#934) @bdraco
+* Provide sockname when logging a protocol error (#935) @bdraco
+
+0.33.2
+======
+
+* Handle duplicate goodbye answers in the same packet (#928) @bdraco
+
+ Solves an exception being thrown when we tried to remove the known answer
+ from the cache when the second goodbye answer in the same packet was processed
+
+ Fixed #926
+* Skip ipv6 interfaces that return ENODEV (#930) @bdraco
+
+0.33.1
+======
+
+* Version number change only with less restrictive directory permissions
+
+ Fixed #923
+
+0.33.0
+======
+
+This release eliminates all threading locks as all non-threadsafe operations
+now happen in the event loop.
+
+* Let connection_lost close the underlying socket (#918) @bdraco
+
+ The socket was closed during shutdown before asyncio's connection_lost
+ handler had a chance to close it which resulted in a traceback on
+ windows.
+
+ Fixed #917
+
+Technically backwards incompatible:
+
+* Removed duplicate unregister_all_services code (#910) @bdraco
+
+ Calling Zeroconf.close from same asyncio event loop zeroconf is running in
+ will now skip unregister_all_services and log a warning as this a blocking
+ operation and is not async safe and never has been.
+
+ Use AsyncZeroconf instead, or for legacy code call async_unregister_all_services before Zeroconf.close
+
+0.32.1
+======
+
+* Increased timeout in ServiceInfo.request to handle loaded systems (#895) @bdraco
+
+ It can take a few seconds for a loaded system to run the `async_request`
+ coroutine when the event loop is busy, or the system is CPU bound (example being
+ Home Assistant startup). We now add an additional `_LOADED_SYSTEM_TIMEOUT` (10s)
+ to the `run_coroutine_threadsafe` calls to ensure the coroutine has the total
+ amount of time to run up to its internal timeout (default of 3000ms).
+
+ Ten seconds is a bit large of a timeout; however, it is only used in cases
+ where we wrap other timeouts. We now expect the only instance the
+ `run_coroutine_threadsafe` result timeout will happen in a production
+ circumstance is when someone is running a `ServiceInfo.request()` in a thread and
+ another thread calls `Zeroconf.close()` at just the right moment that the future
+ is never completed unless the system is so loaded that it is nearly unresponsive.
+
+ The timeout for `run_coroutine_threadsafe` is the maximum time a thread can
+ cleanly shut down when zeroconf is closed out in another thread, which should
+ always be longer than the underlying thread operation.
+
+0.32.0
+======
+
+This release offers 100% line and branch coverage.
+
+* Made ServiceInfo first question QU (#852) @bdraco
+
+ We want an immediate response when requesting with ServiceInfo
+ by asking a QU question; most responders will not delay the response
+ and respond right away to our question. This also improves compatibility
+ with split networks as we may not have been able to see the response
+ otherwise. If the responder has not multicast the record recently,
+ it may still choose to do so in addition to responding via unicast
+
+ Reduces traffic when there are multiple zeroconf instances running
+ on the network running ServiceBrowsers
+
+ If we don't get an answer on the first try, we ask a QM question
+ in the event, we can't receive a unicast response for some reason
+
+ This change puts ServiceInfo inline with ServiceBrowser which
+ also asks the first question as QU since ServiceInfo is commonly
+ called from ServiceBrowser callbacks
+* Limited duplicate packet suppression to 1s intervals (#841) @bdraco
+
+ Only suppress duplicate packets that happen within the same
+ second. Legitimate queriers will retry the question if they
+ are suppressed. The limit was reduced to one second to be
+ in line with rfc6762
+* Made multipacket known answer suppression per interface (#836) @bdraco
+
+ The suppression was happening per instance of Zeroconf instead
+ of per interface. Since the same network can be seen on multiple
+ interfaces (usually and wifi and ethernet), this would confuse the
+ multi-packet known answer supression since it was not expecting
+ to get the same data more than once
+* New ServiceBrowsers now request QU in the first outgoing when unspecified (#812) @bdraco
+
+ https://datatracker.ietf.org/doc/html/rfc6762#section-5.4
+ When we start a ServiceBrowser and zeroconf has just started up, the known
+ answer list will be small. By asking a QU question first, it is likely
+ that we have a large known answer list by the time we ask the QM question
+ a second later (current default which is likely too low but would be
+ a breaking change to increase). This reduces the amount of traffic on
+ the network, and has the secondary advantage that most responders will
+ answer a QU question without the typical delay answering QM questions.
+* IPv6 link-local addresses are now qualified with scope_id (#343) @ibygrave
+
+ When a service is advertised on an IPv6 address where
+ the scope is link local, i.e. fe80::/64 (see RFC 4007)
+ the resolved IPv6 address must be extended with the
+ scope_id that identifies through the "%" symbol the
+ local interface to be used when routing to that address.
+ A new API `parsed_scoped_addresses()` is provided to
+ return qualified addresses to avoid breaking compatibility
+ on the existing parsed_addresses().
+* Network adapters that are disconnected are now skipped (#327) @ZLJasonG
+* Fixed listeners missing initial packets if Engine starts too quickly (#387) @bdraco
+
+ When manually creating a zeroconf.Engine object, it is no longer started automatically.
+ It must manually be started by calling .start() on the created object.
+
+ The Engine thread is now started after all the listeners have been added to avoid a
+ race condition where packets could be missed at startup.
+* Fixed answering matching PTR queries with the ANY query (#618) @bdraco
+* Fixed lookup of uppercase names in the registry (#597) @bdraco
+
+ If the ServiceInfo was registered with an uppercase name and the query was
+ for a lowercase name, it would not be found and vice-versa.
+* Fixed unicast responses from any source port (#598) @bdraco
+
+ Unicast responses were only being sent if the source port
+ was 53, this prevented responses when testing with dig:
+
+ dig -p 5353 @224.0.0.251 media-12.local
+
+ The above query will now see a response
+* Fixed queries for AAAA records not being answered (#616) @bdraco
+* Removed second level caching from ServiceBrowsers (#737) @bdraco
+
+ The ServiceBrowser had its own cache of the last time it
+ saw a service that was reimplementing the DNSCache and
+ presenting a source of truth problem that lead to unexpected
+ queries when the two disagreed.
+* Fixed server cache not being case-insensitive (#731) @bdraco
+
+ If the server name had uppercase chars and any of the
+ matching records were lowercase, and the server would not be
+ found
+* Fixed cache handling of records with different TTLs (#729) @bdraco
+
+ There should only be one unique record in the cache at
+ a time as having multiple unique records will different
+ TTLs in the cache can result in unexpected behavior since
+ some functions returned all matching records and some
+ fetched from the right side of the list to return the
+ newest record. Instead we now store the records in a dict
+ to ensure that the newest record always replaces the same
+ unique record, and we never have a source of truth problem
+ determining the TTL of a record from the cache.
+* Fixed ServiceInfo with multiple A records (#725) @bdraco
+
+ If there were multiple A records for the host, ServiceInfo
+ would always return the last one that was in the incoming
+ packet, which was usually not the one that was wanted.
+* Fixed stale unique records expiring too quickly (#706) @bdraco
+
+ Records now expire 1s in the future instead of instant removal.
+
+ tools.ietf.org/html/rfc6762#section-10.2
+ Queriers receiving a Multicast DNS response with a TTL of zero SHOULD
+ NOT immediately delete the record from the cache, but instead record
+ a TTL of 1 and then delete the record one second later. In the case
+ of multiple Multicast DNS responders on the network described in
+ Section 6.6 above, if one of the responders shuts down and
+ incorrectly sends goodbye packets for its records, it gives the other
+ cooperating responders one second to send out their own response to
+ "rescue" the records before they expire and are deleted.
+* Fixed exception when unregistering a service multiple times (#679) @bdraco
+* Added an AsyncZeroconfServiceTypes to mirror ZeroconfServiceTypes to zeroconf.asyncio (#658) @bdraco
+* Fixed interface_index_to_ip6_address not skiping ipv4 adapters (#651) @bdraco
+* Added async_unregister_all_services to AsyncZeroconf (#649) @bdraco
+* Fixed services not being removed from the registry when calling unregister_all_services (#644) @bdraco
+
+ There was a race condition where a query could be answered for a service
+ in the registry, while goodbye packets which could result in a fresh record
+ being broadcast after the goodbye if a query came in at just the right
+ time. To avoid this, we now remove the services from the registry right
+ after we generate the goodbye packet
+* Fixed zeroconf exception on load when the system disables IPv6 (#624) @bdraco
+* Fixed the QU bit missing from for probe queries (#609) @bdraco
+
+ The bit should be set per
+ datatracker.ietf.org/doc/html/rfc6762#section-8.1
+
+* Fixed the TC bit missing for query packets where the known answers span multiple packets (#494) @bdraco
+* Fixed packets not being properly separated when exceeding maximum size (#498) @bdraco
+
+ Ensure that questions that exceed the max packet size are
+ moved to the next packet. This fixes DNSQuestions being
+ sent in multiple packets in violation of:
+ datatracker.ietf.org/doc/html/rfc6762#section-7.2
+
+ Ensure only one resource record is sent when a record
+ exceeds _MAX_MSG_TYPICAL
+ datatracker.ietf.org/doc/html/rfc6762#section-17
+* Fixed PTR questions asked in uppercase not being answered (#465) @bdraco
+* Added Support for context managers in Zeroconf and AsyncZeroconf (#284) @shenek
+* Implemented an AsyncServiceBrowser to compliment the sync ServiceBrowser (#429) @bdraco
+* Added async_get_service_info to AsyncZeroconf and async_request to AsyncServiceInfo (#408) @bdraco
+* Implemented allowing passing in a sync Zeroconf instance to AsyncZeroconf (#406) @bdraco
+* Fixed IPv6 setup under MacOS when binding to "" (#392) @bdraco
+* Fixed ZeroconfServiceTypes.find not always cancels the ServiceBrowser (#389) @bdraco
+
+ There was a short window where the ServiceBrowser thread
+ could be left running after Zeroconf is closed because
+ the .join() was never waited for when a new Zeroconf
+ object was created
+* Fixed duplicate packets triggering duplicate updates (#376) @bdraco
+
+ If TXT or SRV records update was already processed and then
+ received again, it was possible for a second update to be
+ called back in the ServiceBrowser
+* Fixed ServiceStateChange.Updated event happening for IPs that already existed (#375) @bdraco
+* Fixed RFC6762 Section 10.2 paragraph 2 compliance (#374) @bdraco
+* Reduced length of ServiceBrowser thread name with many types (#373) @bdraco
+* Fixed empty answers being added in ServiceInfo.request (#367) @bdraco
+* Fixed ServiceInfo not populating all AAAA records (#366) @bdraco
+
+ Use get_all_by_details to ensure all records are loaded
+ into addresses.
+
+ Only load A/AAAA records from the cache once in load_from_cache
+ if there is a SRV record present
+
+ Move duplicate code that checked if the ServiceInfo was complete
+ into its own function
+* Fixed a case where the cache list can change during iteration (#363) @bdraco
+* Return task objects created by AsyncZeroconf (#360) @nocarryr
+
+Traffic Reduction:
+
+* Added support for handling QU questions (#621) @bdraco
+
+ Implements RFC 6762 sec 5.4:
+ Questions Requesting Unicast Responses
+ datatracker.ietf.org/doc/html/rfc6762#section-5.4
+* Implemented protect the network against excessive packet flooding (#619) @bdraco
+* Additionals are now suppressed when they are already in the answers section (#617) @bdraco
+* Additionals are no longer included when the answer is suppressed by known-answer suppression (#614) @bdraco
+* Implemented multi-packet known answer supression (#687) @bdraco
+
+ Implements datatracker.ietf.org/doc/html/rfc6762#section-7.2
+* Implemented efficient bucketing of queries with known answers (#698) @bdraco
+* Implemented duplicate question suppression (#770) @bdraco
+
+ http://datatracker.ietf.org/doc/html/rfc6762#section-7.3
+
+Technically backwards incompatible:
+
+* Update internal version check to match docs (3.6+) (#491) @bdraco
+
+ Python version earlier then 3.6 were likely broken with zeroconf
+ already, however, the version is now explicitly checked.
+* Update python compatibility as PyPy3 7.2 is required (#523) @bdraco
+
+Backwards incompatible:
+
+* Drop oversize packets before processing them (#826) @bdraco
+
+ Oversized packets can quickly overwhelm the system and deny
+ service to legitimate queriers. In practice, this is usually due to broken mDNS
+ implementations rather than malicious actors.
+* Guard against excessive ServiceBrowser queries from PTR records significantly lowerthan recommended (#824) @bdraco
+
+ We now enforce a minimum TTL for PTR records to avoid
+ ServiceBrowsers generating excessive queries refresh queries.
+ Apple uses a 15s minimum TTL, however, we do not have the same
+ level of rate limit and safeguards, so we use 1/4 of the recommended value.
+* RecordUpdateListener now uses async_update_records instead of update_record (#419, #726) @bdraco
+
+ This allows the listener to receive all the records that have
+ been updated in a single transaction such as a packet or
+ cache expiry.
+
+ update_record has been deprecated in favor of async_update_records
+ A compatibility shim exists to ensure classes that use
+ RecordUpdateListener as a base class continue to have
+ update_record called, however, they should be updated
+ as soon as possible.
+
+ A new method async_update_records_complete is now called on each
+ listener when all listeners have completed processing updates
+ and the cache has been updated. This allows ServiceBrowsers
+ to delay calling handlers until they are sure the cache
+ has been updated as its a common pattern to call for
+ ServiceInfo when a ServiceBrowser handler fires.
+
+ The async\_ prefix was chosen to make it clear that these
+ functions run in the eventloop and should never do blocking
+ I/O. Before 0.32+ these functions ran in a select() loop and
+ should not have been doing any blocking I/O, but it was not
+ clear to implementors that I/O would block the loop.
+* Pass both the new and old records to async_update_records (#792) @bdraco
+
+ Pass the old_record (cached) as the value and the new_record (wire)
+ to async_update_records instead of forcing each consumer to
+ check the cache since we will always have the old_record
+ when generating the async_update_records call. This avoids
+ the overhead of multiple cache lookups for each listener.
+
+0.31.0
+======
+
+* Separated cache loading from I/O in ServiceInfo and fixed cache lookup (#356),
+ thanks to J. Nick Koston.
+
+ The ServiceInfo class gained a load_from_cache() method to only fetch information
+ from Zeroconf cache (if it exists) with no IO performed. Additionally this should
+ reduce IO in cases where cache lookups were previously incorrectly failing.
+
+0.30.0
+======
+
+* Some nice refactoring work including removal of the Reaper thread,
+ thanks to J. Nick Koston.
+
+* Fixed a Windows-specific The requested address is not valid in its context regression,
+ thanks to Timothee ‘TTimo’ Besset and J. Nick Koston.
+
+* Provided an asyncio-compatible service registration layer (in the zeroconf.asyncio module),
+ thanks to J. Nick Koston.
+
+0.29.0
+======
+
+* A single socket is used for listening on responding when `InterfaceChoice.Default` is chosen.
+ Thanks to J. Nick Koston.
+
+Backwards incompatible:
+
+* Dropped Python 3.5 support
+
+0.28.8
+======
+
+* Fixed the packet generation when multiple packets are necessary, previously invalid
+ packets were generated sometimes. Patch thanks to J. Nick Koston.
+
+0.28.7
+======
+
+* Fixed the IPv6 address rendering in the browser example, thanks to Alexey Vazhnov.
+* Fixed a crash happening when a service is added or removed during handle_response
+ and improved exception handling, thanks to J. Nick Koston.
+
+0.28.6
+======
+
+* Loosened service name validation when receiving from the network this lets us handle
+ some real world devices previously causing errors, thanks to J. Nick Koston.
+
+0.28.5
+======
+
+* Enabled ignoring duplicated messages which decreases CPU usage, thanks to J. Nick Koston.
+* Fixed spurious AttributeError: module 'unittest' has no attribute 'mock' in tests.
+
+0.28.4
+======
+
+* Improved cache reaper performance significantly, thanks to J. Nick Koston.
+* Added ServiceListener to __all__ as it's part of the public API, thanks to Justin Nesselrotte.
+
+0.28.3
+======
+
+* Reduced a time an internal lock is held which should eliminate deadlocks in high-traffic networks,
+ thanks to J. Nick Koston.
+
+0.28.2
+======
+
+* Stopped asking questions we already have answers for in cache, thanks to Paul Daumlechner.
+* Removed initial delay before querying for service info, thanks to Erik Montnemery.
+
+0.28.1
+======
+
+* Fixed a resource leak connected to using ServiceBrowser with multiple types, thanks to
+ J. Nick Koston.
+
+0.28.0
+======
+
+* Improved Windows support when using socket errno checks, thanks to Sandy Patterson.
+* Added support for passing text addresses to ServiceInfo.
+* Improved logging (includes fixing an incorrect logging call)
+* Improved Windows compatibility by using Adapter.index from ifaddr, thanks to PhilippSelenium.
+* Improved Windows compatibility by stopping using socket.if_nameindex.
+* Fixed an OS X edge case which should also eliminate a memory leak, thanks to Emil Styrke.
+
+Technically backwards incompatible:
+
+* ``ifaddr`` 0.1.7 or newer is required now.
+
+0.27.1
+------
+
+* Improved the logging situation (includes fixing a false-positive "packets() made no progress
+ adding records", thanks to Greg Badros)
+
+0.27.0
+------
+
+* Large multi-resource responses are now split into separate packets which fixes a bad
+ mdns-repeater/ChromeCast Audio interaction ending with ChromeCast Audio crash (and possibly
+ some others) and improves RFC 6762 compliance, thanks to Greg Badros
+* Added a warning presented when the listener passed to ServiceBrowser lacks update_service()
+ callback
+* Added support for finding all services available in the browser example, thanks to Perry Kunder
+
+Backwards incompatible:
+
+* Removed previously deprecated ServiceInfo address constructor parameter and property
+
+0.26.3
+------
+
+* Improved readability of logged incoming data, thanks to Erik Montnemery
+* Threads are given unique names now to aid debugging, thanks to Erik Montnemery
+* Fixed a regression where get_service_info() called within a listener add_service method
+ would deadlock, timeout and incorrectly return None, fix thanks to Erik Montnemery, but
+ Matt Saxon and Hmmbob were also involved in debugging it.
+
+0.26.2
+------
+
+* Added support for multiple types to ServiceBrowser, thanks to J. Nick Koston
+* Fixed a race condition where a listener gets a message before the lock is created, thanks to
+ J. Nick Koston
+
+0.26.1
+------
+
+* Fixed a performance regression introduced in 0.26.0, thanks to J. Nick Koston (this is close in
+ spirit to an optimization made in 0.24.5 by the same author)
+
+0.26.0
+------
+
+* Fixed a regression where service update listener wasn't called on IP address change (it's called
+ on SRV/A/AAAA record changes now), thanks to Matt Saxon
+
+Technically backwards incompatible:
+
+* Service update hook is no longer called on service addition (service added hook is still called),
+ this is related to the fix above
+
+0.25.1
+------
+
+* Eliminated 5s hangup when calling Zeroconf.close(), thanks to Erik Montnemery
+
+0.25.0
+------
+
+* Reverted uniqueness assertions when browsing, they caused a regression
+
+Backwards incompatible:
+
+* Rationalized handling of TXT records. Non-bytes values are converted to str and encoded to bytes
+ using UTF-8 now, None values mean value-less attributes. When receiving TXT records no decoding
+ is performed now, keys are always bytes and values are either bytes or None in value-less
+ attributes.
+
+0.24.5
+------
+
+* Fixed issues with shared records being used where they shouldn't be (TXT, SRV, A records are
+ unique now), thanks to Matt Saxon
+* Stopped unnecessarily excluding host-only interfaces from InterfaceChoice.all as they don't
+ forbid multicast, thanks to Andreas Oberritter
+* Fixed repr() of IPv6 DNSAddress, thanks to Aldo Hoeben
+* Removed duplicate update messages sent to listeners, thanks to Matt Saxon
+* Added support for cooperating responders, thanks to Matt Saxon
+* Optimized handle_response cache check, thanks to J. Nick Koston
+* Fixed memory leak in DNSCache, thanks to J. Nick Koston
+
0.24.4
------
diff --git a/docs/api.rst b/docs/api.rst
index 5bd2508f..1704db5a 100644
--- a/docs/api.rst
+++ b/docs/api.rst
@@ -5,3 +5,8 @@ python-zeroconf API reference
:members:
:undoc-members:
:show-inheritance:
+
+.. automodule:: zeroconf.asyncio
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/index.rst b/docs/index.rst
index c4fa6143..8929f417 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -1,16 +1,14 @@
Welcome to python-zeroconf documentation!
=========================================
-.. image:: https://travis-ci.org/jstasiak/python-zeroconf.svg?branch=master
- :alt: Build status
- :target: https://travis-ci.org/jstasiak/python-zeroconf
+.. image:: https://github.com/jstasiak/python-zeroconf/workflows/CI/badge.svg
+ :target: https://github.com/jstasiak/python-zeroconf?query=workflow%3ACI+branch%3Amaster
.. image:: https://img.shields.io/pypi/v/zeroconf.svg
:target: https://pypi.python.org/pypi/zeroconf
-
-.. image:: https://coveralls.io/repos/github/jstasiak/python-zeroconf/badge.svg?branch=master
- :alt: Covergage status
- :target: https://coveralls.io/github/jstasiak/python-zeroconf?branch=master
+
+.. image:: https://codecov.io/gh/jstasiak/python-zeroconf/branch/master/graph/badge.svg
+ :target: https://codecov.io/gh/jstasiak/python-zeroconf
GitHub (code repository, issues): https://github.com/jstasiak/python-zeroconf
@@ -18,7 +16,7 @@ PyPI (installable, stable distributions): https://pypi.org/project/zeroconf. You
pip install zeroconf
-python-zeroconf works with CPython 3.5+ and PyPy 3 implementing Python 3.5+.
+python-zeroconf works with CPython 3.6+ and PyPy 3 implementing Python 3.6+.
Contents
--------
@@ -28,4 +26,4 @@ Contents
api
-See `the project's README `_ for more information.
+See `the project's README `_ for more information.
diff --git a/examples/async_apple_scanner.py b/examples/async_apple_scanner.py
new file mode 100644
index 00000000..88b54e4a
--- /dev/null
+++ b/examples/async_apple_scanner.py
@@ -0,0 +1,119 @@
+#!/usr/bin/env python3
+
+""" Scan for apple devices. """
+
+import argparse
+import asyncio
+import logging
+from typing import Any, Optional, cast
+
+from zeroconf import DNSQuestionType, IPVersion, ServiceStateChange, Zeroconf
+from zeroconf.asyncio import AsyncServiceBrowser, AsyncServiceInfo, AsyncZeroconf
+
+HOMESHARING_SERVICE: str = "_appletv-v2._tcp.local."
+DEVICE_SERVICE: str = "_touch-able._tcp.local."
+MEDIAREMOTE_SERVICE: str = "_mediaremotetv._tcp.local."
+AIRPLAY_SERVICE: str = "_airplay._tcp.local."
+COMPANION_SERVICE: str = "_companion-link._tcp.local."
+RAOP_SERVICE: str = "_raop._tcp.local."
+AIRPORT_ADMIN_SERVICE: str = "_airport._tcp.local."
+DEVICE_INFO_SERVICE: str = "_device-info._tcp.local."
+
+ALL_SERVICES = [
+ HOMESHARING_SERVICE,
+ DEVICE_SERVICE,
+ MEDIAREMOTE_SERVICE,
+ AIRPLAY_SERVICE,
+ COMPANION_SERVICE,
+ RAOP_SERVICE,
+ AIRPORT_ADMIN_SERVICE,
+ DEVICE_INFO_SERVICE,
+]
+
+log = logging.getLogger(__name__)
+
+
+def async_on_service_state_change(
+ zeroconf: Zeroconf, service_type: str, name: str, state_change: ServiceStateChange
+) -> None:
+ print(f"Service {name} of type {service_type} state changed: {state_change}")
+ if state_change is not ServiceStateChange.Added:
+ return
+ base_name = name[: -len(service_type) - 1]
+ device_name = f"{base_name}.{DEVICE_INFO_SERVICE}"
+ asyncio.ensure_future(_async_show_service_info(zeroconf, service_type, name))
+ # Also probe for device info
+ asyncio.ensure_future(_async_show_service_info(zeroconf, DEVICE_INFO_SERVICE, device_name))
+
+
+async def _async_show_service_info(zeroconf: Zeroconf, service_type: str, name: str) -> None:
+ info = AsyncServiceInfo(service_type, name)
+ await info.async_request(zeroconf, 3000, question_type=DNSQuestionType.QU)
+ print("Info from zeroconf.get_service_info: %r" % (info))
+ if info:
+ addresses = ["%s:%d" % (addr, cast(int, info.port)) for addr in info.parsed_addresses()]
+ print(" Name: %s" % name)
+ print(" Addresses: %s" % ", ".join(addresses))
+ print(" Weight: %d, priority: %d" % (info.weight, info.priority))
+ print(f" Server: {info.server}")
+ if info.properties:
+ print(" Properties are:")
+ for key, value in info.properties.items():
+ print(f" {key}: {value}")
+ else:
+ print(" No properties")
+ else:
+ print(" No info")
+ print('\n')
+
+
+class AsyncAppleScanner:
+ def __init__(self, args: Any) -> None:
+ self.args = args
+ self.aiobrowser: Optional[AsyncServiceBrowser] = None
+ self.aiozc: Optional[AsyncZeroconf] = None
+
+ async def async_run(self) -> None:
+ self.aiozc = AsyncZeroconf(ip_version=ip_version)
+ await self.aiozc.zeroconf.async_wait_for_start()
+ print("\nBrowsing %s service(s), press Ctrl-C to exit...\n" % ALL_SERVICES)
+ kwargs = {'handlers': [async_on_service_state_change], 'question_type': DNSQuestionType.QU}
+ if self.args.target:
+ kwargs["addr"] = self.args.target
+ self.aiobrowser = AsyncServiceBrowser(self.aiozc.zeroconf, ALL_SERVICES, **kwargs) # type: ignore
+ while True:
+ await asyncio.sleep(1)
+
+ async def async_close(self) -> None:
+ assert self.aiozc is not None
+ assert self.aiobrowser is not None
+ await self.aiobrowser.async_cancel()
+ await self.aiozc.async_close()
+
+
+if __name__ == '__main__':
+ logging.basicConfig(level=logging.DEBUG)
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--debug', action='store_true')
+ version_group = parser.add_mutually_exclusive_group()
+ version_group.add_argument('--target', help='Unicast target')
+ version_group.add_argument('--v6', action='store_true')
+ version_group.add_argument('--v6-only', action='store_true')
+ args = parser.parse_args()
+
+ if args.debug:
+ logging.getLogger('zeroconf').setLevel(logging.DEBUG)
+ if args.v6:
+ ip_version = IPVersion.All
+ elif args.v6_only:
+ ip_version = IPVersion.V6Only
+ else:
+ ip_version = IPVersion.V4Only
+
+ loop = asyncio.get_event_loop()
+ runner = AsyncAppleScanner(args)
+ try:
+ loop.run_until_complete(runner.async_run())
+ except KeyboardInterrupt:
+ loop.run_until_complete(runner.async_close())
diff --git a/examples/async_browser.py b/examples/async_browser.py
new file mode 100644
index 00000000..1cce5c20
--- /dev/null
+++ b/examples/async_browser.py
@@ -0,0 +1,101 @@
+#!/usr/bin/env python3
+
+""" Example of browsing for a service.
+
+The default is HTTP and HAP; use --find to search for all available services in the network
+"""
+
+import argparse
+import asyncio
+import logging
+from typing import Any, Optional, cast
+
+from zeroconf import IPVersion, ServiceStateChange, Zeroconf
+from zeroconf.asyncio import AsyncServiceBrowser, AsyncServiceInfo, AsyncZeroconf, AsyncZeroconfServiceTypes
+
+
+def async_on_service_state_change(
+ zeroconf: Zeroconf, service_type: str, name: str, state_change: ServiceStateChange
+) -> None:
+ print(f"Service {name} of type {service_type} state changed: {state_change}")
+ if state_change is not ServiceStateChange.Added:
+ return
+ asyncio.ensure_future(async_display_service_info(zeroconf, service_type, name))
+
+
+async def async_display_service_info(zeroconf: Zeroconf, service_type: str, name: str) -> None:
+ info = AsyncServiceInfo(service_type, name)
+ await info.async_request(zeroconf, 3000)
+ print("Info from zeroconf.get_service_info: %r" % (info))
+ if info:
+ addresses = ["%s:%d" % (addr, cast(int, info.port)) for addr in info.parsed_scoped_addresses()]
+ print(" Name: %s" % name)
+ print(" Addresses: %s" % ", ".join(addresses))
+ print(" Weight: %d, priority: %d" % (info.weight, info.priority))
+ print(f" Server: {info.server}")
+ if info.properties:
+ print(" Properties are:")
+ for key, value in info.properties.items():
+ print(f" {key}: {value}")
+ else:
+ print(" No properties")
+ else:
+ print(" No info")
+ print('\n')
+
+
+class AsyncRunner:
+ def __init__(self, args: Any) -> None:
+ self.args = args
+ self.aiobrowser: Optional[AsyncServiceBrowser] = None
+ self.aiozc: Optional[AsyncZeroconf] = None
+
+ async def async_run(self) -> None:
+ self.aiozc = AsyncZeroconf(ip_version=ip_version)
+
+ services = ["_http._tcp.local.", "_hap._tcp.local."]
+ if self.args.find:
+ services = list(
+ await AsyncZeroconfServiceTypes.async_find(aiozc=self.aiozc, ip_version=ip_version)
+ )
+
+ print("\nBrowsing %s service(s), press Ctrl-C to exit...\n" % services)
+ self.aiobrowser = AsyncServiceBrowser(
+ self.aiozc.zeroconf, services, handlers=[async_on_service_state_change]
+ )
+ while True:
+ await asyncio.sleep(1)
+
+ async def async_close(self) -> None:
+ assert self.aiozc is not None
+ assert self.aiobrowser is not None
+ await self.aiobrowser.async_cancel()
+ await self.aiozc.async_close()
+
+
+if __name__ == '__main__':
+ logging.basicConfig(level=logging.DEBUG)
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--debug', action='store_true')
+ parser.add_argument('--find', action='store_true', help='Browse all available services')
+ version_group = parser.add_mutually_exclusive_group()
+ version_group.add_argument('--v6', action='store_true')
+ version_group.add_argument('--v6-only', action='store_true')
+ args = parser.parse_args()
+
+ if args.debug:
+ logging.getLogger('zeroconf').setLevel(logging.DEBUG)
+ if args.v6:
+ ip_version = IPVersion.All
+ elif args.v6_only:
+ ip_version = IPVersion.V6Only
+ else:
+ ip_version = IPVersion.V4Only
+
+ loop = asyncio.get_event_loop()
+ runner = AsyncRunner(args)
+ try:
+ loop.run_until_complete(runner.async_run())
+ except KeyboardInterrupt:
+ loop.run_until_complete(runner.async_close())
diff --git a/examples/async_registration.py b/examples/async_registration.py
new file mode 100644
index 00000000..c3aab326
--- /dev/null
+++ b/examples/async_registration.py
@@ -0,0 +1,74 @@
+#!/usr/bin/env python3
+"""Example of announcing 250 services (in this case, a fake HTTP server)."""
+
+import argparse
+import asyncio
+import logging
+import socket
+from typing import List, Optional
+
+from zeroconf import IPVersion
+from zeroconf.asyncio import AsyncServiceInfo, AsyncZeroconf
+
+
+class AsyncRunner:
+ def __init__(self, ip_version: IPVersion) -> None:
+ self.ip_version = ip_version
+ self.aiozc: Optional[AsyncZeroconf] = None
+
+ async def register_services(self, infos: List[AsyncServiceInfo]) -> None:
+ self.aiozc = AsyncZeroconf(ip_version=self.ip_version)
+ tasks = [self.aiozc.async_register_service(info) for info in infos]
+ background_tasks = await asyncio.gather(*tasks)
+ await asyncio.gather(*background_tasks)
+ print("Finished registration, press Ctrl-C to exit...")
+ while True:
+ await asyncio.sleep(1)
+
+ async def unregister_services(self, infos: List[AsyncServiceInfo]) -> None:
+ assert self.aiozc is not None
+ tasks = [self.aiozc.async_unregister_service(info) for info in infos]
+ background_tasks = await asyncio.gather(*tasks)
+ await asyncio.gather(*background_tasks)
+ await self.aiozc.async_close()
+
+
+if __name__ == '__main__':
+ logging.basicConfig(level=logging.DEBUG)
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--debug', action='store_true')
+ version_group = parser.add_mutually_exclusive_group()
+ version_group.add_argument('--v6', action='store_true')
+ version_group.add_argument('--v6-only', action='store_true')
+ args = parser.parse_args()
+
+ if args.debug:
+ logging.getLogger('zeroconf').setLevel(logging.DEBUG)
+ if args.v6:
+ ip_version = IPVersion.All
+ elif args.v6_only:
+ ip_version = IPVersion.V6Only
+ else:
+ ip_version = IPVersion.V4Only
+
+ infos = []
+ for i in range(250):
+ infos.append(
+ AsyncServiceInfo(
+ "_http._tcp.local.",
+ f"Paul's Test Web Site {i}._http._tcp.local.",
+ addresses=[socket.inet_aton("127.0.0.1")],
+ port=80,
+ properties={'path': '/~paulsm/'},
+ server=f"zcdemohost-{i}.local.",
+ )
+ )
+
+ print("Registration of 250 services...")
+ loop = asyncio.get_event_loop()
+ runner = AsyncRunner(ip_version)
+ try:
+ loop.run_until_complete(runner.register_services(infos))
+ except KeyboardInterrupt:
+ loop.run_until_complete(runner.unregister_services(infos))
diff --git a/examples/async_service_info_request.py b/examples/async_service_info_request.py
new file mode 100644
index 00000000..dd8265b7
--- /dev/null
+++ b/examples/async_service_info_request.py
@@ -0,0 +1,103 @@
+#!/usr/bin/env python3
+"""Example of perodic dump of homekit services.
+
+This example is useful when a user wants an ondemand
+list of HomeKit devices on the network.
+
+"""
+
+import argparse
+import asyncio
+import logging
+from typing import Any, Optional, cast
+
+
+from zeroconf import IPVersion, ServiceBrowser, ServiceStateChange, Zeroconf
+from zeroconf.asyncio import AsyncServiceInfo, AsyncZeroconf
+
+
+HAP_TYPE = "_hap._tcp.local."
+
+
+async def async_watch_services(aiozc: AsyncZeroconf) -> None:
+ zeroconf = aiozc.zeroconf
+ while True:
+ await asyncio.sleep(5)
+ infos = []
+ for name in zeroconf.cache.names():
+ if not name.endswith(HAP_TYPE):
+ continue
+ infos.append(AsyncServiceInfo(HAP_TYPE, name))
+ tasks = [info.async_request(aiozc.zeroconf, 3000) for info in infos]
+ await asyncio.gather(*tasks)
+ for info in infos:
+ print("Info for %s" % (info.name))
+ if info:
+ addresses = ["%s:%d" % (addr, cast(int, info.port)) for addr in info.parsed_addresses()]
+ print(" Addresses: %s" % ", ".join(addresses))
+ print(" Weight: %d, priority: %d" % (info.weight, info.priority))
+ print(f" Server: {info.server}")
+ if info.properties:
+ print(" Properties are:")
+ for key, value in info.properties.items():
+ print(f" {key}: {value}")
+ else:
+ print(" No properties")
+ else:
+ print(" No info")
+ print('\n')
+
+
+class AsyncRunner:
+ def __init__(self, args: Any) -> None:
+ self.args = args
+ self.threaded_browser: Optional[ServiceBrowser] = None
+ self.aiozc: Optional[AsyncZeroconf] = None
+
+ async def async_run(self) -> None:
+ self.aiozc = AsyncZeroconf(ip_version=ip_version)
+ assert self.aiozc is not None
+
+ def on_service_state_change(
+ zeroconf: Zeroconf, service_type: str, state_change: ServiceStateChange, name: str
+ ) -> None:
+ """Dummy handler."""
+
+ self.threaded_browser = ServiceBrowser(
+ self.aiozc.zeroconf, [HAP_TYPE], handlers=[on_service_state_change]
+ )
+ await async_watch_services(self.aiozc)
+
+ async def async_close(self) -> None:
+ assert self.aiozc is not None
+ assert self.threaded_browser is not None
+ self.threaded_browser.cancel()
+ await self.aiozc.async_close()
+
+
+if __name__ == '__main__':
+ logging.basicConfig(level=logging.DEBUG)
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--debug', action='store_true')
+ version_group = parser.add_mutually_exclusive_group()
+ version_group.add_argument('--v6', action='store_true')
+ version_group.add_argument('--v6-only', action='store_true')
+ args = parser.parse_args()
+
+ if args.debug:
+ logging.getLogger('zeroconf').setLevel(logging.DEBUG)
+ if args.v6:
+ ip_version = IPVersion.All
+ elif args.v6_only:
+ ip_version = IPVersion.V6Only
+ else:
+ ip_version = IPVersion.V4Only
+
+ print(f"Services with {HAP_TYPE} will be shown every 5s, press Ctrl-C to exit...")
+ loop = asyncio.get_event_loop()
+ runner = AsyncRunner(args)
+ try:
+ loop.run_until_complete(runner.async_run())
+ except KeyboardInterrupt:
+ loop.run_until_complete(runner.async_close())
diff --git a/examples/browser.py b/examples/browser.py
index bf3ebfbd..8c50e409 100755
--- a/examples/browser.py
+++ b/examples/browser.py
@@ -1,32 +1,36 @@
#!/usr/bin/env python3
-""" Example of browsing for a service (in this case, HTTP) """
+""" Example of browsing for a service.
+
+The default is HTTP and HAP; use --find to search for all available services in the network
+"""
import argparse
import logging
-import socket
from time import sleep
from typing import cast
-from zeroconf import IPVersion, ServiceBrowser, ServiceStateChange, Zeroconf
+from zeroconf import IPVersion, ServiceBrowser, ServiceStateChange, Zeroconf, ZeroconfServiceTypes
def on_service_state_change(
zeroconf: Zeroconf, service_type: str, name: str, state_change: ServiceStateChange
) -> None:
- print("Service %s of type %s state changed: %s" % (name, service_type, state_change))
+ print(f"Service {name} of type {service_type} state changed: {state_change}")
if state_change is ServiceStateChange.Added:
info = zeroconf.get_service_info(service_type, name)
+ print("Info from zeroconf.get_service_info: %r" % (info))
+
if info:
- addresses = ["%s:%d" % (socket.inet_ntoa(addr), cast(int, info.port)) for addr in info.addresses]
+ addresses = ["%s:%d" % (addr, cast(int, info.port)) for addr in info.parsed_scoped_addresses()]
print(" Addresses: %s" % ", ".join(addresses))
print(" Weight: %d, priority: %d" % (info.weight, info.priority))
- print(" Server: %s" % (info.server,))
+ print(f" Server: {info.server}")
if info.properties:
print(" Properties are:")
for key, value in info.properties.items():
- print(" %s: %s" % (key, value))
+ print(f" {key}: {value}")
else:
print(" No properties")
else:
@@ -39,6 +43,7 @@ def on_service_state_change(
parser = argparse.ArgumentParser()
parser.add_argument('--debug', action='store_true')
+ parser.add_argument('--find', action='store_true', help='Browse all available services')
version_group = parser.add_mutually_exclusive_group()
version_group.add_argument('--v6', action='store_true')
version_group.add_argument('--v6-only', action='store_true')
@@ -54,8 +59,13 @@ def on_service_state_change(
ip_version = IPVersion.V4Only
zeroconf = Zeroconf(ip_version=ip_version)
- print("\nBrowsing services, press Ctrl-C to exit...\n")
- browser = ServiceBrowser(zeroconf, "_http._tcp.local.", handlers=[on_service_state_change])
+
+ services = ["_http._tcp.local.", "_hap._tcp.local."]
+ if args.find:
+ services = list(ZeroconfServiceTypes.find(zc=zeroconf))
+
+ print("\nBrowsing %d service(s), press Ctrl-C to exit...\n" % len(services))
+ browser = ServiceBrowser(zeroconf, services, handlers=[on_service_state_change])
try:
while True:
diff --git a/examples/self_test.py b/examples/self_test.py
index 35007db1..2178629b 100755
--- a/examples/self_test.py
+++ b/examples/self_test.py
@@ -14,7 +14,7 @@
# Test a few module features, including service registration, service
# query (for Zoe), and service unregistration.
- print("Multicast DNS Service Discovery for Python, version %s" % (__version__,))
+ print(f"Multicast DNS Service Discovery for Python, version {__version__}")
r = Zeroconf()
print("1. Testing registration of a service...")
desc = {'version': '0.10', 'a': 'test value', 'b': 'another value'}
@@ -40,7 +40,7 @@
queried_info = r.get_service_info("_http._tcp.local.", "My Service Name._http._tcp.local.")
assert queried_info
assert set(queried_info.parsed_addresses()) == expected
- print(" Getting self: %s" % (queried_info,))
+ print(f" Getting self: {queried_info}")
print(" Query done.")
print("4. Testing unregister of service information...")
r.unregister_service(info)
diff --git a/pyproject.toml b/pyproject.toml
index a5d30b54..cd79b3e2 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,4 +1,37 @@
[tool.black]
line-length = 110
-target_version = ['py35', 'py36', 'py37']
+target_version = ['py35', 'py36', 'py37', 'py38']
skip_string_normalization = true
+
+[tool.pylint.BASIC]
+class-const-naming-style = "any"
+good-names = [
+ "e",
+ "er",
+ "h",
+ "i",
+ "id",
+ "ip",
+ "os",
+ "n",
+ "rr",
+ "rs",
+ "s",
+ "t",
+ "wr",
+ "zc",
+ "_GLOBAL_DONE",
+]
+
+[tool.pylint."MESSAGES CONTROL"]
+disable = [
+ "duplicate-code",
+ "fixme",
+ "format",
+ "missing-class-docstring",
+ "missing-function-docstring",
+ "too-few-public-methods",
+ "too-many-arguments",
+ "too-many-instance-attributes",
+ "too-many-public-methods"
+]
diff --git a/requirements-dev.txt b/requirements-dev.txt
index ec443c0b..e7483666 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -1,9 +1,16 @@
autopep8
+black;implementation_name=="cpython"
+bump2version
coveralls
coverage
-# Version restricted because of https://github.com/PyCQA/pycodestyle/issues/741
-flake8>=3.6.0
+flake8
flake8-import-order
ifaddr
-nose
-pep8-naming!=0.6.0
+mypy;implementation_name=="cpython"
+# 0.11.0 breaks things https://github.com/PyCQA/pep8-naming/issues/152
+pep8-naming!=0.6.0,!=0.11.0
+pylint
+pytest
+pytest-asyncio
+pytest-cov
+pytest-timeout
diff --git a/setup.cfg b/setup.cfg
index 5610cf68..67e50408 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -1,12 +1,25 @@
+[bumpversion]
+current_version = 0.36.7
+commit = True
+tag = True
+tag_name = {new_version}
+
+[bumpversion:file:zeroconf/__init__.py]
+search = __version__ = '{current_version}'
+replace = __version__ = '{new_version}'
+
+[tool:pytest]
+testpaths = tests
+
[flake8]
show-source = 1
-application-import-names=zeroconf
-max-line-length=110
-ignore=E203,W503
+application-import-names = zeroconf
+max-line-length = 110
+ignore = E203,W503,N818
[mypy]
ignore_missing_imports = true
-follow_imports = error
+follow_imports = skip
check_untyped_defs = true
no_implicit_optional = true
warn_incomplete_stub = true
@@ -15,7 +28,6 @@ warn_redundant_casts = true
warn_unused_configs = true
warn_unused_ignores = true
warn_return_any = true
-# TODO: disallow untyped calls and defs once we have full type hint coverage
disallow_untyped_calls = false
disallow_untyped_defs = true
diff --git a/setup.py b/setup.py
index a011e602..41e74842 100755
--- a/setup.py
+++ b/setup.py
@@ -9,7 +9,7 @@
readme = f.read()
version = (
- [l for l in open(join(PROJECT_ROOT, 'zeroconf', '__init__.py')) if '__version__' in l][0]
+ [ln for ln in open(join(PROJECT_ROOT, 'zeroconf', '__init__.py')) if '__version__' in ln][0]
.split('=')[-1]
.strip()
.strip('\'"')
@@ -23,7 +23,7 @@
author='Paul Scott-Murphy, William McBrine, Jakub Stasiak',
url='https://github.com/jstasiak/python-zeroconf',
package_data={"zeroconf": ["py.typed"]},
- packages=["zeroconf"],
+ packages=["zeroconf", "zeroconf._protocol", "zeroconf._services", "zeroconf._utils"],
platforms=['unix', 'linux', 'osx'],
license='LGPL',
zip_safe=False,
@@ -39,9 +39,10 @@
'Programming Language :: Python :: 3.5',
'Programming Language :: Python :: 3.6',
'Programming Language :: Python :: 3.7',
+ 'Programming Language :: Python :: 3.8',
'Programming Language :: Python :: Implementation :: CPython',
'Programming Language :: Python :: Implementation :: PyPy',
],
keywords=['Bonjour', 'Avahi', 'Zeroconf', 'Multicast DNS', 'Service Discovery', 'mDNS'],
- install_requires=['ifaddr', 'typing;python_version<"3.5"'],
+ install_requires=['ifaddr>=0.1.7'],
)
diff --git a/tests/__init__.py b/tests/__init__.py
new file mode 100644
index 00000000..2671fe62
--- /dev/null
+++ b/tests/__init__.py
@@ -0,0 +1,80 @@
+""" Multicast DNS Service Discovery for Python, v0.14-wmcbrine
+ Copyright 2003 Paul Scott-Murphy, 2014 William McBrine
+
+ This module provides a framework for the use of DNS Service Discovery
+ using IP multicast.
+
+ This library is free software; you can redistribute it and/or
+ modify it under the terms of the GNU Lesser General Public
+ License as published by the Free Software Foundation; either
+ version 2.1 of the License, or (at your option) any later version.
+
+ This library is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+ Lesser General Public License for more details.
+
+ You should have received a copy of the GNU Lesser General Public
+ License along with this library; if not, write to the Free Software
+ Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301
+ USA
+"""
+
+import asyncio
+import socket
+from functools import lru_cache
+from typing import List
+
+import ifaddr
+
+
+from zeroconf import DNSIncoming, Zeroconf
+
+
+def _inject_responses(zc: Zeroconf, msgs: List[DNSIncoming]) -> None:
+ """Inject a DNSIncoming response."""
+ assert zc.loop is not None
+
+ async def _wait_for_response():
+ for msg in msgs:
+ zc.handle_response(msg)
+
+ asyncio.run_coroutine_threadsafe(_wait_for_response(), zc.loop).result()
+
+
+def _inject_response(zc: Zeroconf, msg: DNSIncoming) -> None:
+ """Inject a DNSIncoming response."""
+ _inject_responses(zc, [msg])
+
+
+def _wait_for_start(zc: Zeroconf) -> None:
+ """Wait for all sockets to be up and running."""
+ assert zc.loop is not None
+ asyncio.run_coroutine_threadsafe(zc.async_wait_for_start(), zc.loop).result()
+
+
+@lru_cache(maxsize=None)
+def has_working_ipv6():
+ """Return True if if the system can bind an IPv6 address."""
+ if not socket.has_ipv6:
+ return False
+
+ try:
+ sock = socket.socket(socket.AF_INET6)
+ sock.bind(('::1', 0))
+ except Exception:
+ return False
+ finally:
+ if sock:
+ sock.close()
+
+ for iface in ifaddr.get_adapters():
+ for addr in iface.ips:
+ if addr.is_IPv6 and iface.index is not None:
+ return True
+ return False
+
+
+def _clear_cache(zc):
+ zc.cache.cache.clear()
+ zc.question_history._history.clear()
diff --git a/tests/conftest.py b/tests/conftest.py
new file mode 100644
index 00000000..f900e094
--- /dev/null
+++ b/tests/conftest.py
@@ -0,0 +1,30 @@
+#!/usr/bin/env python
+
+
+""" conftest for zeroconf tests. """
+
+import threading
+
+import pytest
+
+import unittest
+
+from zeroconf import _core, const
+
+
+@pytest.fixture(autouse=True)
+def verify_threads_ended():
+ """Verify that the threads are not running after the test."""
+ threads_before = frozenset(threading.enumerate())
+ yield
+ threads = frozenset(threading.enumerate()) - threads_before
+ assert not threads
+
+
+@pytest.fixture
+def run_isolated():
+ """Change the mDNS port to run the test in isolation."""
+ with unittest.mock.patch.object(_core, "_MDNS_PORT", 5454), unittest.mock.patch.object(
+ const, "_MDNS_PORT", 5454
+ ):
+ yield
diff --git a/tests/services/__init__.py b/tests/services/__init__.py
new file mode 100644
index 00000000..2ef4b15b
--- /dev/null
+++ b/tests/services/__init__.py
@@ -0,0 +1,21 @@
+""" Multicast DNS Service Discovery for Python, v0.14-wmcbrine
+ Copyright 2003 Paul Scott-Murphy, 2014 William McBrine
+
+ This module provides a framework for the use of DNS Service Discovery
+ using IP multicast.
+
+ This library is free software; you can redistribute it and/or
+ modify it under the terms of the GNU Lesser General Public
+ License as published by the Free Software Foundation; either
+ version 2.1 of the License, or (at your option) any later version.
+
+ This library is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+ Lesser General Public License for more details.
+
+ You should have received a copy of the GNU Lesser General Public
+ License along with this library; if not, write to the Free Software
+ Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301
+ USA
+"""
diff --git a/tests/services/test_browser.py b/tests/services/test_browser.py
new file mode 100644
index 00000000..e22ebfe3
--- /dev/null
+++ b/tests/services/test_browser.py
@@ -0,0 +1,1020 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+
+""" Unit tests for zeroconf._services.browser. """
+
+import asyncio
+import logging
+import socket
+import time
+import os
+import unittest
+from threading import Event
+from unittest.mock import patch
+
+import pytest
+
+import zeroconf as r
+from zeroconf import DNSPointer, DNSQuestion, const, current_time_millis, millis_to_seconds
+import zeroconf._services.browser as _services_browser
+from zeroconf import Zeroconf
+from zeroconf._services import ServiceStateChange
+from zeroconf._services.browser import ServiceBrowser
+from zeroconf._services.info import ServiceInfo
+from zeroconf.asyncio import AsyncZeroconf
+
+from .. import has_working_ipv6, _inject_response, _wait_for_start
+
+
+log = logging.getLogger('zeroconf')
+original_logging_level = logging.NOTSET
+
+
+def setup_module():
+ global original_logging_level
+ original_logging_level = log.level
+ log.setLevel(logging.DEBUG)
+
+
+def teardown_module():
+ if original_logging_level != logging.NOTSET:
+ log.setLevel(original_logging_level)
+
+
+def test_service_browser_cancel_multiple_times():
+ """Test we can cancel a ServiceBrowser multiple times before close."""
+
+ # instantiate a zeroconf instance
+ zc = Zeroconf(interfaces=['127.0.0.1'])
+ # start a browser
+ type_ = "_hap._tcp.local."
+
+ class MyServiceListener(r.ServiceListener):
+ pass
+
+ listener = MyServiceListener()
+
+ browser = r.ServiceBrowser(zc, type_, None, listener)
+
+ browser.cancel()
+ browser.cancel()
+ browser.cancel()
+
+ zc.close()
+
+
+def test_service_browser_cancel_multiple_times_after_close():
+ """Test we can cancel a ServiceBrowser multiple times after close."""
+
+ # instantiate a zeroconf instance
+ zc = Zeroconf(interfaces=['127.0.0.1'])
+ # start a browser
+ type_ = "_hap._tcp.local."
+
+ class MyServiceListener(r.ServiceListener):
+ pass
+
+ listener = MyServiceListener()
+
+ browser = r.ServiceBrowser(zc, type_, None, listener)
+
+ zc.close()
+
+ browser.cancel()
+ browser.cancel()
+ browser.cancel()
+
+
+def test_service_browser_started_after_zeroconf_closed():
+ """Test starting a ServiceBrowser after close raises RuntimeError."""
+ # instantiate a zeroconf instance
+ zc = Zeroconf(interfaces=['127.0.0.1'])
+ # start a browser
+ type_ = "_hap._tcp.local."
+
+ class MyServiceListener(r.ServiceListener):
+ pass
+
+ listener = MyServiceListener()
+ zc.close()
+
+ with pytest.raises(RuntimeError):
+ browser = r.ServiceBrowser(zc, type_, None, listener)
+
+
+def test_multiple_instances_running_close():
+ """Test we can shutdown multiple instances."""
+
+ # instantiate a zeroconf instance
+ zc = Zeroconf(interfaces=['127.0.0.1'])
+ zc2 = Zeroconf(interfaces=['127.0.0.1'])
+ zc3 = Zeroconf(interfaces=['127.0.0.1'])
+
+ assert zc.loop != zc2.loop
+ assert zc.loop != zc3.loop
+
+ class MyServiceListener(r.ServiceListener):
+ pass
+
+ listener = MyServiceListener()
+
+ zc2.add_service_listener("zca._hap._tcp.local.", listener)
+
+ zc.close()
+ zc2.remove_service_listener(listener)
+ zc2.close()
+ zc3.close()
+
+
+class TestServiceBrowser(unittest.TestCase):
+ def test_update_record(self):
+ enable_ipv6 = has_working_ipv6() and not os.environ.get('SKIP_IPV6')
+
+ service_name = 'name._type._tcp.local.'
+ service_type = '_type._tcp.local.'
+ service_server = 'ash-1.local.'
+ service_text = b'path=/~matt1/'
+ service_address = '10.0.1.2'
+ service_v6_address = "2001:db8::1"
+ service_v6_second_address = "6001:db8::1"
+
+ service_added_count = 0
+ service_removed_count = 0
+ service_updated_count = 0
+ service_add_event = Event()
+ service_removed_event = Event()
+ service_updated_event = Event()
+
+ class MyServiceListener(r.ServiceListener):
+ def add_service(self, zc, type_, name) -> None:
+ nonlocal service_added_count
+ service_added_count += 1
+ service_add_event.set()
+
+ def remove_service(self, zc, type_, name) -> None:
+ nonlocal service_removed_count
+ service_removed_count += 1
+ service_removed_event.set()
+
+ def update_service(self, zc, type_, name) -> None:
+ nonlocal service_updated_count
+ service_updated_count += 1
+ service_info = zc.get_service_info(type_, name)
+ assert socket.inet_aton(service_address) in service_info.addresses
+ if enable_ipv6:
+ assert socket.inet_pton(
+ socket.AF_INET6, service_v6_address
+ ) in service_info.addresses_by_version(r.IPVersion.V6Only)
+ assert socket.inet_pton(
+ socket.AF_INET6, service_v6_second_address
+ ) in service_info.addresses_by_version(r.IPVersion.V6Only)
+ assert service_info.text == service_text
+ assert service_info.server == service_server
+ service_updated_event.set()
+
+ def mock_incoming_msg(service_state_change: r.ServiceStateChange) -> r.DNSIncoming:
+
+ generated = r.DNSOutgoing(const._FLAGS_QR_RESPONSE)
+ assert generated.is_response() is True
+
+ if service_state_change == r.ServiceStateChange.Removed:
+ ttl = 0
+ else:
+ ttl = 120
+
+ generated.add_answer_at_time(
+ r.DNSText(
+ service_name, const._TYPE_TXT, const._CLASS_IN | const._CLASS_UNIQUE, ttl, service_text
+ ),
+ 0,
+ )
+
+ generated.add_answer_at_time(
+ r.DNSService(
+ service_name,
+ const._TYPE_SRV,
+ const._CLASS_IN | const._CLASS_UNIQUE,
+ ttl,
+ 0,
+ 0,
+ 80,
+ service_server,
+ ),
+ 0,
+ )
+
+ # Send the IPv6 address first since we previously
+ # had a bug where the IPv4 would be missing if the
+ # IPv6 was seen first
+ if enable_ipv6:
+ generated.add_answer_at_time(
+ r.DNSAddress(
+ service_server,
+ const._TYPE_AAAA,
+ const._CLASS_IN | const._CLASS_UNIQUE,
+ ttl,
+ socket.inet_pton(socket.AF_INET6, service_v6_address),
+ ),
+ 0,
+ )
+ generated.add_answer_at_time(
+ r.DNSAddress(
+ service_server,
+ const._TYPE_AAAA,
+ const._CLASS_IN | const._CLASS_UNIQUE,
+ ttl,
+ socket.inet_pton(socket.AF_INET6, service_v6_second_address),
+ ),
+ 0,
+ )
+ generated.add_answer_at_time(
+ r.DNSAddress(
+ service_server,
+ const._TYPE_A,
+ const._CLASS_IN | const._CLASS_UNIQUE,
+ ttl,
+ socket.inet_aton(service_address),
+ ),
+ 0,
+ )
+
+ generated.add_answer_at_time(
+ r.DNSPointer(service_type, const._TYPE_PTR, const._CLASS_IN, ttl, service_name), 0
+ )
+
+ return r.DNSIncoming(generated.packets()[0])
+
+ zeroconf = r.Zeroconf(interfaces=['127.0.0.1'])
+ service_browser = r.ServiceBrowser(zeroconf, service_type, listener=MyServiceListener())
+
+ try:
+ wait_time = 3
+
+ # service added
+ _inject_response(zeroconf, mock_incoming_msg(r.ServiceStateChange.Added))
+ service_add_event.wait(wait_time)
+ assert service_added_count == 1
+ assert service_updated_count == 0
+ assert service_removed_count == 0
+
+ # service SRV updated
+ service_updated_event.clear()
+ service_server = 'ash-2.local.'
+ _inject_response(zeroconf, mock_incoming_msg(r.ServiceStateChange.Updated))
+ service_updated_event.wait(wait_time)
+ assert service_added_count == 1
+ assert service_updated_count == 1
+ assert service_removed_count == 0
+
+ # service TXT updated
+ service_updated_event.clear()
+ service_text = b'path=/~matt2/'
+ _inject_response(zeroconf, mock_incoming_msg(r.ServiceStateChange.Updated))
+ service_updated_event.wait(wait_time)
+ assert service_added_count == 1
+ assert service_updated_count == 2
+ assert service_removed_count == 0
+
+ # service TXT updated - duplicate update should not trigger another service_updated
+ service_updated_event.clear()
+ service_text = b'path=/~matt2/'
+ _inject_response(zeroconf, mock_incoming_msg(r.ServiceStateChange.Updated))
+ service_updated_event.wait(wait_time)
+ assert service_added_count == 1
+ assert service_updated_count == 2
+ assert service_removed_count == 0
+
+ # service A updated
+ service_updated_event.clear()
+ service_address = '10.0.1.3'
+ # Verify we match on uppercase
+ service_server = service_server.upper()
+ _inject_response(zeroconf, mock_incoming_msg(r.ServiceStateChange.Updated))
+ service_updated_event.wait(wait_time)
+ assert service_added_count == 1
+ assert service_updated_count == 3
+ assert service_removed_count == 0
+
+ # service all updated
+ service_updated_event.clear()
+ service_server = 'ash-3.local.'
+ service_text = b'path=/~matt3/'
+ service_address = '10.0.1.3'
+ _inject_response(zeroconf, mock_incoming_msg(r.ServiceStateChange.Updated))
+ service_updated_event.wait(wait_time)
+ assert service_added_count == 1
+ assert service_updated_count == 4
+ assert service_removed_count == 0
+
+ # service removed
+ _inject_response(zeroconf, mock_incoming_msg(r.ServiceStateChange.Removed))
+ service_removed_event.wait(wait_time)
+ assert service_added_count == 1
+ assert service_updated_count == 4
+ assert service_removed_count == 1
+
+ finally:
+ assert len(zeroconf.listeners) == 1
+ service_browser.cancel()
+ time.sleep(0.2)
+ assert len(zeroconf.listeners) == 0
+ zeroconf.remove_all_service_listeners()
+ zeroconf.close()
+
+
+class TestServiceBrowserMultipleTypes(unittest.TestCase):
+ def test_update_record(self):
+
+ service_names = ['name2._type2._tcp.local.', 'name._type._tcp.local.', 'name._type._udp.local']
+ service_types = ['_type2._tcp.local.', '_type._tcp.local.', '_type._udp.local.']
+
+ service_added_count = 0
+ service_removed_count = 0
+ service_add_event = Event()
+ service_removed_event = Event()
+
+ class MyServiceListener(r.ServiceListener):
+ def add_service(self, zc, type_, name) -> None:
+ nonlocal service_added_count
+ service_added_count += 1
+ if service_added_count == 3:
+ service_add_event.set()
+
+ def remove_service(self, zc, type_, name) -> None:
+ nonlocal service_removed_count
+ service_removed_count += 1
+ if service_removed_count == 3:
+ service_removed_event.set()
+
+ def mock_incoming_msg(
+ service_state_change: r.ServiceStateChange, service_type: str, service_name: str, ttl: int
+ ) -> r.DNSIncoming:
+ generated = r.DNSOutgoing(const._FLAGS_QR_RESPONSE)
+ generated.add_answer_at_time(
+ r.DNSPointer(service_type, const._TYPE_PTR, const._CLASS_IN, ttl, service_name), 0
+ )
+ return r.DNSIncoming(generated.packets()[0])
+
+ zeroconf = r.Zeroconf(interfaces=['127.0.0.1'])
+ service_browser = r.ServiceBrowser(zeroconf, service_types, listener=MyServiceListener())
+
+ try:
+ wait_time = 3
+
+ # all three services added
+ _inject_response(
+ zeroconf,
+ mock_incoming_msg(r.ServiceStateChange.Added, service_types[0], service_names[0], 120),
+ )
+ _inject_response(
+ zeroconf,
+ mock_incoming_msg(r.ServiceStateChange.Added, service_types[1], service_names[1], 120),
+ )
+ time.sleep(0.1)
+
+ called_with_refresh_time_check = False
+
+ def _mock_get_expiration_time(self, percent):
+ nonlocal called_with_refresh_time_check
+ if percent == const._EXPIRE_REFRESH_TIME_PERCENT:
+ called_with_refresh_time_check = True
+ return 0
+ return self.created + (percent * self.ttl * 10)
+
+ # Set an expire time that will force a refresh
+ with patch("zeroconf.DNSRecord.get_expiration_time", new=_mock_get_expiration_time):
+ _inject_response(
+ zeroconf,
+ mock_incoming_msg(r.ServiceStateChange.Added, service_types[0], service_names[0], 120),
+ )
+ # Add the last record after updating the first one
+ # to ensure the service_add_event only gets set
+ # after the update
+ _inject_response(
+ zeroconf,
+ mock_incoming_msg(r.ServiceStateChange.Added, service_types[2], service_names[2], 120),
+ )
+ service_add_event.wait(wait_time)
+ assert called_with_refresh_time_check is True
+ assert service_added_count == 3
+ assert service_removed_count == 0
+
+ _inject_response(
+ zeroconf,
+ mock_incoming_msg(r.ServiceStateChange.Updated, service_types[0], service_names[0], 0),
+ )
+
+ # all three services removed
+ _inject_response(
+ zeroconf,
+ mock_incoming_msg(r.ServiceStateChange.Removed, service_types[0], service_names[0], 0),
+ )
+ _inject_response(
+ zeroconf,
+ mock_incoming_msg(r.ServiceStateChange.Removed, service_types[1], service_names[1], 0),
+ )
+ _inject_response(
+ zeroconf,
+ mock_incoming_msg(r.ServiceStateChange.Removed, service_types[2], service_names[2], 0),
+ )
+ service_removed_event.wait(wait_time)
+ assert service_added_count == 3
+ assert service_removed_count == 3
+
+ finally:
+ assert len(zeroconf.listeners) == 1
+ service_browser.cancel()
+ time.sleep(0.2)
+ assert len(zeroconf.listeners) == 0
+ zeroconf.remove_all_service_listeners()
+ zeroconf.close()
+
+
+def test_backoff():
+ got_query = Event()
+
+ type_ = "_http._tcp.local."
+ zeroconf_browser = Zeroconf(interfaces=['127.0.0.1'])
+ _wait_for_start(zeroconf_browser)
+
+ # we are going to patch the zeroconf send to check query transmission
+ old_send = zeroconf_browser.async_send
+
+ time_offset = 0.0
+ start_time = time.time() * 1000
+ initial_query_interval = _services_browser._BROWSER_TIME / 1000
+
+ def current_time_millis():
+ """Current system time in milliseconds"""
+ return start_time + time_offset * 1000
+
+ def send(out, addr=const._MDNS_ADDR, port=const._MDNS_PORT, v6_flow_scope=()):
+ """Sends an outgoing packet."""
+ got_query.set()
+ old_send(out, addr=addr, port=port, v6_flow_scope=v6_flow_scope)
+
+ # patch the zeroconf send
+ # patch the zeroconf current_time_millis
+ # patch the backoff limit to prevent test running forever
+ with patch.object(zeroconf_browser, "async_send", send), patch.object(
+ zeroconf_browser.question_history, "suppresses", return_value=False
+ ), patch.object(_services_browser, "current_time_millis", current_time_millis), patch.object(
+ _services_browser, "_BROWSER_BACKOFF_LIMIT", 10
+ ), patch.object(
+ _services_browser, "_FIRST_QUERY_DELAY_RANDOM_INTERVAL", (0, 0)
+ ):
+ # dummy service callback
+ def on_service_state_change(zeroconf, service_type, state_change, name):
+ pass
+
+ browser = ServiceBrowser(zeroconf_browser, type_, [on_service_state_change])
+
+ try:
+ # Test that queries are sent at increasing intervals
+ sleep_count = 0
+ next_query_interval = 0.0
+ expected_query_time = 0.0
+ while True:
+ sleep_count += 1
+ got_query.wait(0.1)
+ if time_offset == expected_query_time:
+ assert got_query.is_set()
+ got_query.clear()
+ if next_query_interval == _services_browser._BROWSER_BACKOFF_LIMIT:
+ # Only need to test up to the point where we've seen a query
+ # after the backoff limit has been hit
+ break
+ elif next_query_interval == 0:
+ next_query_interval = initial_query_interval
+ expected_query_time = initial_query_interval
+ else:
+ next_query_interval = min(
+ 2 * next_query_interval, _services_browser._BROWSER_BACKOFF_LIMIT
+ )
+ expected_query_time += next_query_interval
+ else:
+ assert not got_query.is_set()
+ time_offset += initial_query_interval
+ zeroconf_browser.loop.call_soon_threadsafe(browser._async_send_ready_queries_schedule_next)
+
+ finally:
+ browser.cancel()
+ zeroconf_browser.close()
+
+
+def test_first_query_delay():
+ """Verify the first query is delayed.
+
+ https://datatracker.ietf.org/doc/html/rfc6762#section-5.2
+ """
+ type_ = "_http._tcp.local."
+ zeroconf_browser = Zeroconf(interfaces=['127.0.0.1'])
+ _wait_for_start(zeroconf_browser)
+
+ # we are going to patch the zeroconf send to check query transmission
+ old_send = zeroconf_browser.async_send
+
+ first_query_time = None
+
+ def send(out, addr=const._MDNS_ADDR, port=const._MDNS_PORT):
+ """Sends an outgoing packet."""
+ nonlocal first_query_time
+ if first_query_time is None:
+ first_query_time = current_time_millis()
+ old_send(out, addr=addr, port=port)
+
+ # patch the zeroconf send
+ with patch.object(zeroconf_browser, "async_send", send):
+ # dummy service callback
+ def on_service_state_change(zeroconf, service_type, state_change, name):
+ pass
+
+ start_time = current_time_millis()
+ browser = ServiceBrowser(zeroconf_browser, type_, [on_service_state_change])
+ time.sleep(millis_to_seconds(_services_browser._FIRST_QUERY_DELAY_RANDOM_INTERVAL[1] + 5))
+ try:
+ assert (
+ current_time_millis() - start_time > _services_browser._FIRST_QUERY_DELAY_RANDOM_INTERVAL[0]
+ )
+ finally:
+ browser.cancel()
+ zeroconf_browser.close()
+
+
+def test_asking_default_is_asking_qm_questions_after_the_first_qu():
+ """Verify the service browser's first question is QU and subsequent ones are QM questions."""
+ type_ = "_quservice._tcp.local."
+ zeroconf_browser = Zeroconf(interfaces=['127.0.0.1'])
+
+ # we are going to patch the zeroconf send to check query transmission
+ old_send = zeroconf_browser.async_send
+
+ first_outgoing = None
+ second_outgoing = None
+
+ def send(out, addr=const._MDNS_ADDR, port=const._MDNS_PORT):
+ """Sends an outgoing packet."""
+ nonlocal first_outgoing
+ nonlocal second_outgoing
+ if first_outgoing is not None and second_outgoing is None:
+ second_outgoing = out
+ if first_outgoing is None:
+ first_outgoing = out
+ old_send(out, addr=addr, port=port)
+
+ # patch the zeroconf send
+ with patch.object(zeroconf_browser, "async_send", send):
+ # dummy service callback
+ def on_service_state_change(zeroconf, service_type, state_change, name):
+ pass
+
+ browser = ServiceBrowser(zeroconf_browser, type_, [on_service_state_change], delay=5)
+ time.sleep(millis_to_seconds(_services_browser._FIRST_QUERY_DELAY_RANDOM_INTERVAL[1] + 120 + 5))
+ try:
+ assert first_outgoing.questions[0].unicast == True
+ assert second_outgoing.questions[0].unicast == False
+ finally:
+ browser.cancel()
+ zeroconf_browser.close()
+
+
+def test_asking_qm_questions():
+ """Verify explictly asking QM questions."""
+ type_ = "_quservice._tcp.local."
+ zeroconf_browser = Zeroconf(interfaces=['127.0.0.1'])
+
+ # we are going to patch the zeroconf send to check query transmission
+ old_send = zeroconf_browser.async_send
+
+ first_outgoing = None
+
+ def send(out, addr=const._MDNS_ADDR, port=const._MDNS_PORT):
+ """Sends an outgoing packet."""
+ nonlocal first_outgoing
+ if first_outgoing is None:
+ first_outgoing = out
+ old_send(out, addr=addr, port=port)
+
+ # patch the zeroconf send
+ with patch.object(zeroconf_browser, "async_send", send):
+ # dummy service callback
+ def on_service_state_change(zeroconf, service_type, state_change, name):
+ pass
+
+ browser = ServiceBrowser(
+ zeroconf_browser, type_, [on_service_state_change], question_type=r.DNSQuestionType.QM
+ )
+ time.sleep(millis_to_seconds(_services_browser._FIRST_QUERY_DELAY_RANDOM_INTERVAL[1] + 5))
+ try:
+ assert first_outgoing.questions[0].unicast == False
+ finally:
+ browser.cancel()
+ zeroconf_browser.close()
+
+
+def test_asking_qu_questions():
+ """Verify the service browser can ask QU questions."""
+ type_ = "_quservice._tcp.local."
+ zeroconf_browser = Zeroconf(interfaces=['127.0.0.1'])
+
+ # we are going to patch the zeroconf send to check query transmission
+ old_send = zeroconf_browser.async_send
+
+ first_outgoing = None
+
+ def send(out, addr=const._MDNS_ADDR, port=const._MDNS_PORT):
+ """Sends an outgoing packet."""
+ nonlocal first_outgoing
+ if first_outgoing is None:
+ first_outgoing = out
+ old_send(out, addr=addr, port=port)
+
+ # patch the zeroconf send
+ with patch.object(zeroconf_browser, "async_send", send):
+ # dummy service callback
+ def on_service_state_change(zeroconf, service_type, state_change, name):
+ pass
+
+ browser = ServiceBrowser(
+ zeroconf_browser, type_, [on_service_state_change], question_type=r.DNSQuestionType.QU
+ )
+ time.sleep(millis_to_seconds(_services_browser._FIRST_QUERY_DELAY_RANDOM_INTERVAL[1] + 5))
+ try:
+ assert first_outgoing.questions[0].unicast == True
+ finally:
+ browser.cancel()
+ zeroconf_browser.close()
+
+
+def test_legacy_record_update_listener():
+ """Test a RecordUpdateListener that does not implement update_records."""
+
+ # instantiate a zeroconf instance
+ zc = Zeroconf(interfaces=['127.0.0.1'])
+
+ with pytest.raises(RuntimeError):
+ r.RecordUpdateListener().update_record(
+ zc, 0, r.DNSRecord('irrelevant', const._TYPE_SRV, const._CLASS_IN, const._DNS_HOST_TTL)
+ )
+
+ updates = []
+
+ class LegacyRecordUpdateListener(r.RecordUpdateListener):
+ """A RecordUpdateListener that does not implement update_records."""
+
+ def update_record(self, zc: 'Zeroconf', now: float, record: r.DNSRecord) -> None:
+ nonlocal updates
+ updates.append(record)
+
+ listener = LegacyRecordUpdateListener()
+
+ zc.add_listener(listener, None)
+
+ # dummy service callback
+ def on_service_state_change(zeroconf, service_type, state_change, name):
+ pass
+
+ # start a browser
+ type_ = "_homeassistant._tcp.local."
+ name = "MyTestHome"
+ browser = ServiceBrowser(zc, type_, [on_service_state_change])
+
+ info_service = ServiceInfo(
+ type_,
+ '%s.%s' % (name, type_),
+ 80,
+ 0,
+ 0,
+ {'path': '/~paulsm/'},
+ "ash-2.local.",
+ addresses=[socket.inet_aton("10.0.1.2")],
+ )
+
+ zc.register_service(info_service)
+
+ time.sleep(0.001)
+
+ browser.cancel()
+
+ assert len(updates)
+ assert len([isinstance(update, r.DNSPointer) and update.name == type_ for update in updates]) >= 1
+
+ zc.remove_listener(listener)
+ # Removing a second time should not throw
+ zc.remove_listener(listener)
+
+ zc.close()
+
+
+def test_service_browser_is_aware_of_port_changes():
+ """Test that the ServiceBrowser is aware of port changes."""
+
+ # instantiate a zeroconf instance
+ zc = Zeroconf(interfaces=['127.0.0.1'])
+ # start a browser
+ type_ = "_hap._tcp.local."
+ registration_name = "xxxyyy.%s" % type_
+
+ callbacks = []
+ # dummy service callback
+ def on_service_state_change(zeroconf, service_type, state_change, name):
+ nonlocal callbacks
+ if name == registration_name:
+ callbacks.append((service_type, state_change, name))
+
+ browser = ServiceBrowser(zc, type_, [on_service_state_change])
+
+ desc = {'path': '/~paulsm/'}
+ address_parsed = "10.0.1.2"
+ address = socket.inet_aton(address_parsed)
+ info = ServiceInfo(type_, registration_name, 80, 0, 0, desc, "ash-2.local.", addresses=[address])
+
+ def mock_incoming_msg(records) -> r.DNSIncoming:
+ generated = r.DNSOutgoing(const._FLAGS_QR_RESPONSE)
+ for record in records:
+ generated.add_answer_at_time(record, 0)
+ return r.DNSIncoming(generated.packets()[0])
+
+ _inject_response(
+ zc,
+ mock_incoming_msg([info.dns_pointer(), info.dns_service(), info.dns_text(), *info.dns_addresses()]),
+ )
+ time.sleep(0.1)
+
+ assert callbacks == [('_hap._tcp.local.', ServiceStateChange.Added, 'xxxyyy._hap._tcp.local.')]
+ assert zc.get_service_info(type_, registration_name).port == 80
+
+ info.port = 400
+ _inject_response(
+ zc,
+ mock_incoming_msg([info.dns_service()]),
+ )
+ time.sleep(0.1)
+
+ assert callbacks == [
+ ('_hap._tcp.local.', ServiceStateChange.Added, 'xxxyyy._hap._tcp.local.'),
+ ('_hap._tcp.local.', ServiceStateChange.Updated, 'xxxyyy._hap._tcp.local.'),
+ ]
+ assert zc.get_service_info(type_, registration_name).port == 400
+ browser.cancel()
+
+ zc.close()
+
+
+def test_service_browser_listeners_update_service():
+ """Test that the ServiceBrowser ServiceListener that implements update_service."""
+
+ # instantiate a zeroconf instance
+ zc = Zeroconf(interfaces=['127.0.0.1'])
+ # start a browser
+ type_ = "_hap._tcp.local."
+ registration_name = "xxxyyy.%s" % type_
+ callbacks = []
+
+ class MyServiceListener(r.ServiceListener):
+ def add_service(self, zc, type_, name) -> None:
+ nonlocal callbacks
+ if name == registration_name:
+ callbacks.append(("add", type_, name))
+
+ def remove_service(self, zc, type_, name) -> None:
+ nonlocal callbacks
+ if name == registration_name:
+ callbacks.append(("remove", type_, name))
+
+ def update_service(self, zc, type_, name) -> None:
+ nonlocal callbacks
+ if name == registration_name:
+ callbacks.append(("update", type_, name))
+
+ listener = MyServiceListener()
+
+ browser = r.ServiceBrowser(zc, type_, None, listener)
+
+ desc = {'path': '/~paulsm/'}
+ address_parsed = "10.0.1.2"
+ address = socket.inet_aton(address_parsed)
+ info = ServiceInfo(type_, registration_name, 80, 0, 0, desc, "ash-2.local.", addresses=[address])
+
+ def mock_incoming_msg(records) -> r.DNSIncoming:
+ generated = r.DNSOutgoing(const._FLAGS_QR_RESPONSE)
+ for record in records:
+ generated.add_answer_at_time(record, 0)
+ return r.DNSIncoming(generated.packets()[0])
+
+ _inject_response(
+ zc,
+ mock_incoming_msg([info.dns_pointer(), info.dns_service(), info.dns_text(), *info.dns_addresses()]),
+ )
+ time.sleep(0.2)
+ info.port = 400
+ _inject_response(
+ zc,
+ mock_incoming_msg([info.dns_service()]),
+ )
+ time.sleep(0.2)
+
+ assert callbacks == [
+ ('add', type_, registration_name),
+ ('update', type_, registration_name),
+ ]
+ browser.cancel()
+
+ zc.close()
+
+
+def test_service_browser_listeners_no_update_service():
+ """Test that the ServiceBrowser ServiceListener that does not implement update_service."""
+
+ # instantiate a zeroconf instance
+ zc = Zeroconf(interfaces=['127.0.0.1'])
+ # start a browser
+ type_ = "_hap._tcp.local."
+ registration_name = "xxxyyy.%s" % type_
+ callbacks = []
+
+ class MyServiceListener:
+ def add_service(self, zc, type_, name) -> None:
+ nonlocal callbacks
+ if name == registration_name:
+ callbacks.append(("add", type_, name))
+
+ def remove_service(self, zc, type_, name) -> None:
+ nonlocal callbacks
+ if name == registration_name:
+ callbacks.append(("remove", type_, name))
+
+ listener = MyServiceListener()
+
+ browser = r.ServiceBrowser(zc, type_, None, listener)
+
+ desc = {'path': '/~paulsm/'}
+ address_parsed = "10.0.1.2"
+ address = socket.inet_aton(address_parsed)
+ info = ServiceInfo(type_, registration_name, 80, 0, 0, desc, "ash-2.local.", addresses=[address])
+
+ def mock_incoming_msg(records) -> r.DNSIncoming:
+ generated = r.DNSOutgoing(const._FLAGS_QR_RESPONSE)
+ for record in records:
+ generated.add_answer_at_time(record, 0)
+ return r.DNSIncoming(generated.packets()[0])
+
+ _inject_response(
+ zc,
+ mock_incoming_msg([info.dns_pointer(), info.dns_service(), info.dns_text(), *info.dns_addresses()]),
+ )
+ time.sleep(0.2)
+ info.port = 400
+ _inject_response(
+ zc,
+ mock_incoming_msg([info.dns_service()]),
+ )
+ time.sleep(0.2)
+
+ assert callbacks == [
+ ('add', type_, registration_name),
+ ]
+ browser.cancel()
+
+ zc.close()
+
+
+def test_servicebrowser_uses_non_strict_names():
+ """Verify we can look for technically invalid names as we cannot change what others do."""
+
+ # dummy service callback
+ def on_service_state_change(zeroconf, service_type, state_change, name):
+ pass
+
+ zc = r.Zeroconf(interfaces=['127.0.0.1'])
+ browser = ServiceBrowser(zc, ["_tivo-videostream._tcp.local."], [on_service_state_change])
+ browser.cancel()
+
+ # Still fail on completely invalid
+ with pytest.raises(r.BadTypeInNameException):
+ browser = ServiceBrowser(zc, ["tivo-videostream._tcp.local."], [on_service_state_change])
+ zc.close()
+
+
+def test_group_ptr_queries_with_known_answers():
+ questions_with_known_answers: _services_browser._QuestionWithKnownAnswers = {}
+ now = current_time_millis()
+ for i in range(120):
+ name = f"_hap{i}._tcp._local."
+ questions_with_known_answers[DNSQuestion(name, const._TYPE_PTR, const._CLASS_IN)] = set(
+ DNSPointer(
+ name,
+ const._TYPE_PTR,
+ const._CLASS_IN,
+ 4500,
+ f"zoo{counter}.{name}",
+ )
+ for counter in range(i)
+ )
+ outs = _services_browser._group_ptr_queries_with_known_answers(now, True, questions_with_known_answers)
+ for out in outs:
+ packets = out.packets()
+ # If we generate multiple packets there must
+ # only be one question
+ assert len(packets) == 1 or len(out.questions) == 1
+
+
+# This test uses asyncio because it needs to access the cache directly
+# which is not threadsafe
+@pytest.mark.asyncio
+async def test_generate_service_query_suppress_duplicate_questions():
+ """Generate a service query for sending with zeroconf.send."""
+ aiozc = AsyncZeroconf(interfaces=['127.0.0.1'])
+ zc = aiozc.zeroconf
+ now = current_time_millis()
+ name = "_suppresstest._tcp.local."
+ question = r.DNSQuestion(name, const._TYPE_PTR, const._CLASS_IN)
+ answer = r.DNSPointer(
+ name,
+ const._TYPE_PTR,
+ const._CLASS_IN,
+ 10000,
+ f'known-to-other.{name}',
+ )
+ other_known_answers = set([answer])
+ zc.question_history.add_question_at_time(question, now, other_known_answers)
+ assert zc.question_history.suppresses(question, now, other_known_answers)
+
+ # The known answer list is different, do not suppress
+ outs = _services_browser.generate_service_query(zc, now, [name], multicast=True)
+ assert outs
+
+ zc.cache.async_add_records([answer])
+ # The known answer list contains all the asked questions in the history
+ # we should suppress
+
+ outs = _services_browser.generate_service_query(zc, now, [name], multicast=True)
+ assert not outs
+
+ # We do not suppress once the question history expires
+ outs = _services_browser.generate_service_query(zc, now + 1000, [name], multicast=True)
+ assert outs
+
+ # We do not suppress QU queries ever
+ outs = _services_browser.generate_service_query(zc, now, [name], multicast=False)
+ assert outs
+
+ zc.question_history.async_expire(now + 2000)
+ # No suppression after clearing the history
+ outs = _services_browser.generate_service_query(zc, now, [name], multicast=True)
+ assert outs
+
+ # The previous query we just sent is still remembered and
+ # the next one is suppressed
+ outs = _services_browser.generate_service_query(zc, now, [name], multicast=True)
+ assert not outs
+
+ await aiozc.async_close()
+
+
+@pytest.mark.asyncio
+async def test_query_scheduler():
+ delay = const._BROWSER_TIME
+ types_ = set(["_hap._tcp.local.", "_http._tcp.local."])
+ query_scheduler = _services_browser.QueryScheduler(types_, delay, (0, 0))
+
+ now = current_time_millis()
+ query_scheduler.start(now)
+
+ # Test query interval is increasing
+ assert query_scheduler.millis_to_wait(now - 1) == 1
+ assert query_scheduler.millis_to_wait(now) is 0
+ assert query_scheduler.millis_to_wait(now + 1) is 0
+
+ assert set(query_scheduler.process_ready_types(now)) == types_
+ assert set(query_scheduler.process_ready_types(now)) == set()
+ assert query_scheduler.millis_to_wait(now) == delay
+
+ assert set(query_scheduler.process_ready_types(now + delay)) == types_
+ assert set(query_scheduler.process_ready_types(now + delay)) == set()
+ assert query_scheduler.millis_to_wait(now) == delay * 3
+
+ assert set(query_scheduler.process_ready_types(now + delay * 3)) == types_
+ assert set(query_scheduler.process_ready_types(now + delay * 3)) == set()
+ assert query_scheduler.millis_to_wait(now) == delay * 7
+
+ assert set(query_scheduler.process_ready_types(now + delay * 7)) == types_
+ assert set(query_scheduler.process_ready_types(now + delay * 7)) == set()
+ assert query_scheduler.millis_to_wait(now) == delay * 15
+
+ assert set(query_scheduler.process_ready_types(now + delay * 15)) == types_
+ assert set(query_scheduler.process_ready_types(now + delay * 15)) == set()
+
+ # Test if we reschedule 1 second later, the millis_to_wait goes up by 1
+ query_scheduler.reschedule_type("_hap._tcp.local.", now + delay * 16)
+ assert query_scheduler.millis_to_wait(now) == delay * 16
+
+ assert set(query_scheduler.process_ready_types(now + delay * 15)) == set()
+
+ # Test if we reschedule 1 second later... and its ready for processing
+ assert set(query_scheduler.process_ready_types(now + delay * 16)) == set(["_hap._tcp.local."])
+ assert query_scheduler.millis_to_wait(now) == delay * 31
+ assert set(query_scheduler.process_ready_types(now + delay * 20)) == set()
+
+ assert set(query_scheduler.process_ready_types(now + delay * 31)) == set(["_http._tcp.local."])
diff --git a/tests/services/test_info.py b/tests/services/test_info.py
new file mode 100644
index 00000000..a72d82f9
--- /dev/null
+++ b/tests/services/test_info.py
@@ -0,0 +1,779 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+
+""" Unit tests for zeroconf._services.info. """
+
+import logging
+import socket
+import threading
+import os
+import unittest
+from unittest.mock import patch
+from threading import Event
+from typing import List
+
+import pytest
+
+import zeroconf as r
+from zeroconf import DNSAddress, const
+from zeroconf._services.info import ServiceInfo
+from zeroconf.asyncio import AsyncZeroconf
+
+from .. import has_working_ipv6, _inject_response
+
+
+log = logging.getLogger('zeroconf')
+original_logging_level = logging.NOTSET
+
+
+def setup_module():
+ global original_logging_level
+ original_logging_level = log.level
+ log.setLevel(logging.DEBUG)
+
+
+def teardown_module():
+ if original_logging_level != logging.NOTSET:
+ log.setLevel(original_logging_level)
+
+
+class TestServiceInfo(unittest.TestCase):
+ def test_get_name(self):
+ """Verify the name accessor can strip the type."""
+ desc = {'path': '/~paulsm/'}
+ service_name = 'name._type._tcp.local.'
+ service_type = '_type._tcp.local.'
+ service_server = 'ash-1.local.'
+ service_address = socket.inet_aton("10.0.1.2")
+ info = ServiceInfo(
+ service_type, service_name, 22, 0, 0, desc, service_server, addresses=[service_address]
+ )
+ assert info.get_name() == "name"
+
+ def test_service_info_rejects_non_matching_updates(self):
+ """Verify records with the wrong name are rejected."""
+
+ zc = r.Zeroconf(interfaces=['127.0.0.1'])
+ desc = {'path': '/~paulsm/'}
+ service_name = 'name._type._tcp.local.'
+ service_type = '_type._tcp.local.'
+ service_server = 'ash-1.local.'
+ service_address = socket.inet_aton("10.0.1.2")
+ ttl = 120
+ now = r.current_time_millis()
+ info = ServiceInfo(
+ service_type, service_name, 22, 0, 0, desc, service_server, addresses=[service_address]
+ )
+ # Verify backwards compatiblity with calling with None
+ info.update_record(zc, now, None)
+ # Matching updates
+ info.update_record(
+ zc,
+ now,
+ r.DNSText(
+ service_name,
+ const._TYPE_TXT,
+ const._CLASS_IN | const._CLASS_UNIQUE,
+ ttl,
+ b'\x04ff=0\x04ci=2\x04sf=0\x0bsh=6fLM5A==',
+ ),
+ )
+ assert info.properties[b"ci"] == b"2"
+ info.update_record(
+ zc,
+ now,
+ r.DNSService(
+ service_name,
+ const._TYPE_SRV,
+ const._CLASS_IN | const._CLASS_UNIQUE,
+ ttl,
+ 0,
+ 0,
+ 80,
+ 'ASH-2.local.',
+ ),
+ )
+ assert info.server_key == 'ash-2.local.'
+ assert info.server == 'ASH-2.local.'
+ new_address = socket.inet_aton("10.0.1.3")
+ info.update_record(
+ zc,
+ now,
+ r.DNSAddress(
+ 'ASH-2.local.',
+ const._TYPE_A,
+ const._CLASS_IN | const._CLASS_UNIQUE,
+ ttl,
+ new_address,
+ ),
+ )
+ assert new_address in info.addresses
+ # Non-matching updates
+ info.update_record(
+ zc,
+ now,
+ r.DNSText(
+ "incorrect.name.",
+ const._TYPE_TXT,
+ const._CLASS_IN | const._CLASS_UNIQUE,
+ ttl,
+ b'\x04ff=0\x04ci=3\x04sf=0\x0bsh=6fLM5A==',
+ ),
+ )
+ assert info.properties[b"ci"] == b"2"
+ info.update_record(
+ zc,
+ now,
+ r.DNSService(
+ "incorrect.name.",
+ const._TYPE_SRV,
+ const._CLASS_IN | const._CLASS_UNIQUE,
+ ttl,
+ 0,
+ 0,
+ 80,
+ 'ASH-2.local.',
+ ),
+ )
+ assert info.server_key == 'ash-2.local.'
+ assert info.server == 'ASH-2.local.'
+ new_address = socket.inet_aton("10.0.1.4")
+ info.update_record(
+ zc,
+ now,
+ r.DNSAddress(
+ "incorrect.name.",
+ const._TYPE_A,
+ const._CLASS_IN | const._CLASS_UNIQUE,
+ ttl,
+ new_address,
+ ),
+ )
+ assert new_address not in info.addresses
+ zc.close()
+
+ def test_service_info_rejects_expired_records(self):
+ """Verify records that are expired are rejected."""
+ zc = r.Zeroconf(interfaces=['127.0.0.1'])
+ desc = {'path': '/~paulsm/'}
+ service_name = 'name._type._tcp.local.'
+ service_type = '_type._tcp.local.'
+ service_server = 'ash-1.local.'
+ service_address = socket.inet_aton("10.0.1.2")
+ ttl = 120
+ now = r.current_time_millis()
+ info = ServiceInfo(
+ service_type, service_name, 22, 0, 0, desc, service_server, addresses=[service_address]
+ )
+ # Matching updates
+ info.update_record(
+ zc,
+ now,
+ r.DNSText(
+ service_name,
+ const._TYPE_TXT,
+ const._CLASS_IN | const._CLASS_UNIQUE,
+ ttl,
+ b'\x04ff=0\x04ci=2\x04sf=0\x0bsh=6fLM5A==',
+ ),
+ )
+ assert info.properties[b"ci"] == b"2"
+ # Expired record
+ expired_record = r.DNSText(
+ service_name,
+ const._TYPE_TXT,
+ const._CLASS_IN | const._CLASS_UNIQUE,
+ ttl,
+ b'\x04ff=0\x04ci=3\x04sf=0\x0bsh=6fLM5A==',
+ )
+ expired_record.set_created_ttl(1000, 1)
+ info.update_record(zc, now, expired_record)
+ assert info.properties[b"ci"] == b"2"
+ zc.close()
+
+ @unittest.skipIf(not has_working_ipv6(), 'Requires IPv6')
+ @unittest.skipIf(os.environ.get('SKIP_IPV6'), 'IPv6 tests disabled')
+ def test_get_info_partial(self):
+
+ zc = r.Zeroconf(interfaces=['127.0.0.1'])
+
+ service_name = 'name._type._tcp.local.'
+ service_type = '_type._tcp.local.'
+ service_server = 'ash-1.local.'
+ service_text = b'path=/~matt1/'
+ service_address = '10.0.1.2'
+ service_address_v6_ll = 'fe80::52e:c2f2:bc5f:e9c6'
+ service_scope_id = 12
+
+ service_info = None
+ send_event = Event()
+ service_info_event = Event()
+
+ last_sent = None # type: Optional[r.DNSOutgoing]
+
+ def send(out, addr=const._MDNS_ADDR, port=const._MDNS_PORT, v6_flow_scope=()):
+ """Sends an outgoing packet."""
+ nonlocal last_sent
+
+ last_sent = out
+ send_event.set()
+
+ # patch the zeroconf send
+ with patch.object(zc, "async_send", send):
+
+ def mock_incoming_msg(records) -> r.DNSIncoming:
+
+ generated = r.DNSOutgoing(const._FLAGS_QR_RESPONSE)
+
+ for record in records:
+ generated.add_answer_at_time(record, 0)
+
+ return r.DNSIncoming(generated.packets()[0])
+
+ def get_service_info_helper(zc, type, name):
+ nonlocal service_info
+ service_info = zc.get_service_info(type, name)
+ service_info_event.set()
+
+ try:
+ ttl = 120
+ helper_thread = threading.Thread(
+ target=get_service_info_helper, args=(zc, service_type, service_name)
+ )
+ helper_thread.start()
+ wait_time = 1
+
+ # Expext query for SRV, TXT, A, AAAA
+ send_event.wait(wait_time)
+ assert last_sent is not None
+ assert len(last_sent.questions) == 4
+ assert r.DNSQuestion(service_name, const._TYPE_SRV, const._CLASS_IN) in last_sent.questions
+ assert r.DNSQuestion(service_name, const._TYPE_TXT, const._CLASS_IN) in last_sent.questions
+ assert r.DNSQuestion(service_name, const._TYPE_A, const._CLASS_IN) in last_sent.questions
+ assert r.DNSQuestion(service_name, const._TYPE_AAAA, const._CLASS_IN) in last_sent.questions
+ assert service_info is None
+
+ # Expext query for SRV, A, AAAA
+ last_sent = None
+ send_event.clear()
+ _inject_response(
+ zc,
+ mock_incoming_msg(
+ [
+ r.DNSText(
+ service_name,
+ const._TYPE_TXT,
+ const._CLASS_IN | const._CLASS_UNIQUE,
+ ttl,
+ service_text,
+ )
+ ]
+ ),
+ )
+ send_event.wait(wait_time)
+ assert last_sent is not None
+ assert len(last_sent.questions) == 3
+ assert r.DNSQuestion(service_name, const._TYPE_SRV, const._CLASS_IN) in last_sent.questions
+ assert r.DNSQuestion(service_name, const._TYPE_A, const._CLASS_IN) in last_sent.questions
+ assert r.DNSQuestion(service_name, const._TYPE_AAAA, const._CLASS_IN) in last_sent.questions
+ assert service_info is None
+
+ # Expext query for A, AAAA
+ last_sent = None
+ send_event.clear()
+ _inject_response(
+ zc,
+ mock_incoming_msg(
+ [
+ r.DNSService(
+ service_name,
+ const._TYPE_SRV,
+ const._CLASS_IN | const._CLASS_UNIQUE,
+ ttl,
+ 0,
+ 0,
+ 80,
+ service_server,
+ )
+ ]
+ ),
+ )
+ send_event.wait(wait_time)
+ assert last_sent is not None
+ assert len(last_sent.questions) == 2
+ assert r.DNSQuestion(service_server, const._TYPE_A, const._CLASS_IN) in last_sent.questions
+ assert r.DNSQuestion(service_server, const._TYPE_AAAA, const._CLASS_IN) in last_sent.questions
+ last_sent = None
+ assert service_info is None
+
+ # Expext no further queries
+ last_sent = None
+ send_event.clear()
+ _inject_response(
+ zc,
+ mock_incoming_msg(
+ [
+ r.DNSAddress(
+ service_server,
+ const._TYPE_A,
+ const._CLASS_IN | const._CLASS_UNIQUE,
+ ttl,
+ socket.inet_pton(socket.AF_INET, service_address),
+ ),
+ r.DNSAddress(
+ service_server,
+ const._TYPE_AAAA,
+ const._CLASS_IN | const._CLASS_UNIQUE,
+ ttl,
+ socket.inet_pton(socket.AF_INET6, service_address_v6_ll),
+ scope_id=service_scope_id,
+ ),
+ ]
+ ),
+ )
+ send_event.wait(wait_time)
+ assert last_sent is None
+ assert service_info is not None
+
+ finally:
+ helper_thread.join()
+ zc.remove_all_service_listeners()
+ zc.close()
+
+ def test_get_info_single(self):
+
+ zc = r.Zeroconf(interfaces=['127.0.0.1'])
+
+ service_name = 'name._type._tcp.local.'
+ service_type = '_type._tcp.local.'
+ service_server = 'ash-1.local.'
+ service_text = b'path=/~matt1/'
+ service_address = '10.0.1.2'
+
+ service_info = None
+ send_event = Event()
+ service_info_event = Event()
+
+ last_sent = None # type: Optional[r.DNSOutgoing]
+
+ def send(out, addr=const._MDNS_ADDR, port=const._MDNS_PORT, v6_flow_scope=()):
+ """Sends an outgoing packet."""
+ nonlocal last_sent
+
+ last_sent = out
+ send_event.set()
+
+ # patch the zeroconf send
+ with patch.object(zc, "async_send", send):
+
+ def mock_incoming_msg(records) -> r.DNSIncoming:
+
+ generated = r.DNSOutgoing(const._FLAGS_QR_RESPONSE)
+
+ for record in records:
+ generated.add_answer_at_time(record, 0)
+
+ return r.DNSIncoming(generated.packets()[0])
+
+ def get_service_info_helper(zc, type, name):
+ nonlocal service_info
+ service_info = zc.get_service_info(type, name)
+ service_info_event.set()
+
+ try:
+ ttl = 120
+ helper_thread = threading.Thread(
+ target=get_service_info_helper, args=(zc, service_type, service_name)
+ )
+ helper_thread.start()
+ wait_time = 1
+
+ # Expext query for SRV, TXT, A, AAAA
+ send_event.wait(wait_time)
+ assert last_sent is not None
+ assert len(last_sent.questions) == 4
+ assert r.DNSQuestion(service_name, const._TYPE_SRV, const._CLASS_IN) in last_sent.questions
+ assert r.DNSQuestion(service_name, const._TYPE_TXT, const._CLASS_IN) in last_sent.questions
+ assert r.DNSQuestion(service_name, const._TYPE_A, const._CLASS_IN) in last_sent.questions
+ assert r.DNSQuestion(service_name, const._TYPE_AAAA, const._CLASS_IN) in last_sent.questions
+ assert service_info is None
+
+ # Expext no further queries
+ last_sent = None
+ send_event.clear()
+ _inject_response(
+ zc,
+ mock_incoming_msg(
+ [
+ r.DNSText(
+ service_name,
+ const._TYPE_TXT,
+ const._CLASS_IN | const._CLASS_UNIQUE,
+ ttl,
+ service_text,
+ ),
+ r.DNSService(
+ service_name,
+ const._TYPE_SRV,
+ const._CLASS_IN | const._CLASS_UNIQUE,
+ ttl,
+ 0,
+ 0,
+ 80,
+ service_server,
+ ),
+ r.DNSAddress(
+ service_server,
+ const._TYPE_A,
+ const._CLASS_IN | const._CLASS_UNIQUE,
+ ttl,
+ socket.inet_pton(socket.AF_INET, service_address),
+ ),
+ ]
+ ),
+ )
+ send_event.wait(wait_time)
+ assert last_sent is None
+ assert service_info is not None
+
+ finally:
+ helper_thread.join()
+ zc.remove_all_service_listeners()
+ zc.close()
+
+ def test_service_info_duplicate_properties_txt_records(self):
+ """Verify the first property is always used when there are duplicates in a txt record."""
+
+ zc = r.Zeroconf(interfaces=['127.0.0.1'])
+ desc = {'path': '/~paulsm/'}
+ service_name = 'name._type._tcp.local.'
+ service_type = '_type._tcp.local.'
+ service_server = 'ash-1.local.'
+ service_address = socket.inet_aton("10.0.1.2")
+ ttl = 120
+ now = r.current_time_millis()
+ info = ServiceInfo(
+ service_type, service_name, 22, 0, 0, desc, service_server, addresses=[service_address]
+ )
+ info.async_update_records(
+ zc,
+ now,
+ [
+ r.RecordUpdate(
+ r.DNSText(
+ service_name,
+ const._TYPE_TXT,
+ const._CLASS_IN | const._CLASS_UNIQUE,
+ ttl,
+ b'\x04ff=0\x04ci=2\x04sf=0\x0bsh=6fLM5A==\x04dd=0\x04jl=2\x04qq=0\x0brr=6fLM5A==\x04ci=3',
+ ),
+ None,
+ )
+ ],
+ )
+ assert info.properties[b"dd"] == b"0"
+ assert info.properties[b"jl"] == b"2"
+ assert info.properties[b"ci"] == b"2"
+ zc.close()
+
+
+def test_multiple_addresses():
+ type_ = "_http._tcp.local."
+ registration_name = "xxxyyy.%s" % type_
+ desc = {'path': '/~paulsm/'}
+ address_parsed = "10.0.1.2"
+ address = socket.inet_aton(address_parsed)
+
+ # New kwarg way
+ info = ServiceInfo(type_, registration_name, 80, 0, 0, desc, "ash-2.local.", addresses=[address, address])
+
+ assert info.addresses == [address, address]
+ assert info.parsed_addresses() == [address_parsed, address_parsed]
+ assert info.parsed_scoped_addresses() == [address_parsed, address_parsed]
+
+ info = ServiceInfo(
+ type_,
+ registration_name,
+ 80,
+ 0,
+ 0,
+ desc,
+ "ash-2.local.",
+ parsed_addresses=[address_parsed, address_parsed],
+ )
+ assert info.addresses == [address, address]
+ assert info.parsed_addresses() == [address_parsed, address_parsed]
+ assert info.parsed_scoped_addresses() == [address_parsed, address_parsed]
+
+ if has_working_ipv6() and not os.environ.get('SKIP_IPV6'):
+ address_v6_parsed = "2001:db8::1"
+ address_v6 = socket.inet_pton(socket.AF_INET6, address_v6_parsed)
+ address_v6_ll_parsed = "fe80::52e:c2f2:bc5f:e9c6"
+ address_v6_ll_scoped_parsed = "fe80::52e:c2f2:bc5f:e9c6%12"
+ address_v6_ll = socket.inet_pton(socket.AF_INET6, address_v6_ll_parsed)
+ interface_index = 12
+ infos = [
+ ServiceInfo(
+ type_,
+ registration_name,
+ 80,
+ 0,
+ 0,
+ desc,
+ "ash-2.local.",
+ addresses=[address, address_v6, address_v6_ll],
+ interface_index=interface_index,
+ ),
+ ServiceInfo(
+ type_,
+ registration_name,
+ 80,
+ 0,
+ 0,
+ desc,
+ "ash-2.local.",
+ parsed_addresses=[address_parsed, address_v6_parsed, address_v6_ll_parsed],
+ interface_index=interface_index,
+ ),
+ ]
+ for info in infos:
+ assert info.addresses == [address]
+ assert info.addresses_by_version(r.IPVersion.All) == [address, address_v6, address_v6_ll]
+ assert info.addresses_by_version(r.IPVersion.V4Only) == [address]
+ assert info.addresses_by_version(r.IPVersion.V6Only) == [address_v6, address_v6_ll]
+ assert info.parsed_addresses() == [address_parsed, address_v6_parsed, address_v6_ll_parsed]
+ assert info.parsed_addresses(r.IPVersion.V4Only) == [address_parsed]
+ assert info.parsed_addresses(r.IPVersion.V6Only) == [address_v6_parsed, address_v6_ll_parsed]
+ assert info.parsed_scoped_addresses() == [
+ address_v6_ll_scoped_parsed,
+ address_parsed,
+ address_v6_parsed,
+ ]
+ assert info.parsed_scoped_addresses(r.IPVersion.V4Only) == [address_parsed]
+ assert info.parsed_scoped_addresses(r.IPVersion.V6Only) == [
+ address_v6_ll_scoped_parsed,
+ address_v6_parsed,
+ ]
+
+
+# This test uses asyncio because it needs to access the cache directly
+# which is not threadsafe
+@pytest.mark.asyncio
+async def test_multiple_a_addresses():
+ type_ = "_http._tcp.local."
+ registration_name = "multiarec.%s" % type_
+ desc = {'path': '/~paulsm/'}
+ aiozc = AsyncZeroconf(interfaces=['127.0.0.1'])
+ cache = aiozc.zeroconf.cache
+ host = "multahost.local."
+ record1 = r.DNSAddress(host, const._TYPE_A, const._CLASS_IN, 1000, b'a')
+ record2 = r.DNSAddress(host, const._TYPE_A, const._CLASS_IN, 1000, b'b')
+ cache.async_add_records([record1, record2])
+
+ # New kwarg way
+ info = ServiceInfo(type_, registration_name, 80, 0, 0, desc, host)
+ info.load_from_cache(aiozc.zeroconf)
+ assert set(info.addresses) == set([b'a', b'b'])
+ await aiozc.async_close()
+
+
+@unittest.skipIf(not has_working_ipv6(), 'Requires IPv6')
+@unittest.skipIf(os.environ.get('SKIP_IPV6'), 'IPv6 tests disabled')
+def test_filter_address_by_type_from_service_info():
+ """Verify dns_addresses can filter by ipversion."""
+ desc = {'path': '/~paulsm/'}
+ type_ = "_homeassistant._tcp.local."
+ name = "MyTestHome"
+ registration_name = "%s.%s" % (name, type_)
+ ipv4 = socket.inet_aton("10.0.1.2")
+ ipv6 = socket.inet_pton(socket.AF_INET6, "2001:db8::1")
+ info = ServiceInfo(type_, registration_name, 80, 0, 0, desc, "ash-2.local.", addresses=[ipv4, ipv6])
+
+ def dns_addresses_to_addresses(dns_address: List[DNSAddress]):
+ return [address.address for address in dns_address]
+
+ assert dns_addresses_to_addresses(info.dns_addresses()) == [ipv4, ipv6]
+ assert dns_addresses_to_addresses(info.dns_addresses(version=r.IPVersion.All)) == [ipv4, ipv6]
+ assert dns_addresses_to_addresses(info.dns_addresses(version=r.IPVersion.V4Only)) == [ipv4]
+ assert dns_addresses_to_addresses(info.dns_addresses(version=r.IPVersion.V6Only)) == [ipv6]
+
+
+def test_changing_name_updates_serviceinfo_key():
+ """Verify a name change will adjust the underlying key value."""
+ type_ = "_homeassistant._tcp.local."
+ name = "MyTestHome"
+ info_service = ServiceInfo(
+ type_,
+ '%s.%s' % (name, type_),
+ 80,
+ 0,
+ 0,
+ {'path': '/~paulsm/'},
+ "ash-2.local.",
+ addresses=[socket.inet_aton("10.0.1.2")],
+ )
+ assert info_service.key == "mytesthome._homeassistant._tcp.local."
+ info_service.name = "YourTestHome._homeassistant._tcp.local."
+ assert info_service.key == "yourtesthome._homeassistant._tcp.local."
+
+
+def test_serviceinfo_address_updates():
+ """Verify adding/removing/setting addresses on ServiceInfo."""
+ type_ = "_homeassistant._tcp.local."
+ name = "MyTestHome"
+
+ # Verify addresses and parsed_addresses are mutually exclusive
+ with pytest.raises(TypeError):
+ info_service = ServiceInfo(
+ type_,
+ '%s.%s' % (name, type_),
+ 80,
+ 0,
+ 0,
+ {'path': '/~paulsm/'},
+ "ash-2.local.",
+ addresses=[socket.inet_aton("10.0.1.2")],
+ parsed_addresses=["10.0.1.2"],
+ )
+
+ info_service = ServiceInfo(
+ type_,
+ '%s.%s' % (name, type_),
+ 80,
+ 0,
+ 0,
+ {'path': '/~paulsm/'},
+ "ash-2.local.",
+ addresses=[socket.inet_aton("10.0.1.2")],
+ )
+ info_service.addresses = [socket.inet_aton("10.0.1.3")]
+ assert info_service.addresses == [socket.inet_aton("10.0.1.3")]
+
+
+def test_serviceinfo_accepts_bytes_or_string_dict():
+ """Verify a bytes or string dict can be passed to ServiceInfo."""
+ type_ = "_homeassistant._tcp.local."
+ name = "MyTestHome"
+ addresses = [socket.inet_aton("10.0.1.2")]
+ server_name = "ash-2.local."
+ info_service = ServiceInfo(
+ type_, '%s.%s' % (name, type_), 80, 0, 0, {b'path': b'/~paulsm/'}, server_name, addresses=addresses
+ )
+ assert info_service.dns_text().text == b'\x0epath=/~paulsm/'
+ info_service = ServiceInfo(
+ type_,
+ '%s.%s' % (name, type_),
+ 80,
+ 0,
+ 0,
+ {'path': '/~paulsm/'},
+ server_name,
+ addresses=addresses,
+ )
+ assert info_service.dns_text().text == b'\x0epath=/~paulsm/'
+ info_service = ServiceInfo(
+ type_,
+ '%s.%s' % (name, type_),
+ 80,
+ 0,
+ 0,
+ {b'path': '/~paulsm/'},
+ server_name,
+ addresses=addresses,
+ )
+ assert info_service.dns_text().text == b'\x0epath=/~paulsm/'
+ info_service = ServiceInfo(
+ type_,
+ '%s.%s' % (name, type_),
+ 80,
+ 0,
+ 0,
+ {'path': b'/~paulsm/'},
+ server_name,
+ addresses=addresses,
+ )
+ assert info_service.dns_text().text == b'\x0epath=/~paulsm/'
+
+
+def test_asking_qu_questions():
+ """Verify explictly asking QU questions."""
+ type_ = "_quservice._tcp.local."
+ zeroconf = r.Zeroconf(interfaces=['127.0.0.1'])
+
+ # we are going to patch the zeroconf send to check query transmission
+ old_send = zeroconf.async_send
+
+ first_outgoing = None
+
+ def send(out, addr=const._MDNS_ADDR, port=const._MDNS_PORT):
+ """Sends an outgoing packet."""
+ nonlocal first_outgoing
+ if first_outgoing is None:
+ first_outgoing = out
+ old_send(out, addr=addr, port=port)
+
+ # patch the zeroconf send
+ with patch.object(zeroconf, "async_send", send):
+ zeroconf.get_service_info(f"name.{type_}", type_, 500, question_type=r.DNSQuestionType.QU)
+ assert first_outgoing.questions[0].unicast == True
+ zeroconf.close()
+
+
+def test_asking_qm_questions():
+ """Verify explictly asking QM questions."""
+ type_ = "_quservice._tcp.local."
+ zeroconf = r.Zeroconf(interfaces=['127.0.0.1'])
+
+ # we are going to patch the zeroconf send to check query transmission
+ old_send = zeroconf.async_send
+
+ first_outgoing = None
+
+ def send(out, addr=const._MDNS_ADDR, port=const._MDNS_PORT):
+ """Sends an outgoing packet."""
+ nonlocal first_outgoing
+ if first_outgoing is None:
+ first_outgoing = out
+ old_send(out, addr=addr, port=port)
+
+ # patch the zeroconf send
+ with patch.object(zeroconf, "async_send", send):
+ zeroconf.get_service_info(f"name.{type_}", type_, 500, question_type=r.DNSQuestionType.QM)
+ assert first_outgoing.questions[0].unicast == False
+ zeroconf.close()
+
+
+def test_request_timeout():
+ """Test that the timeout does not throw an exception and finishes close to the actual timeout."""
+ zeroconf = r.Zeroconf(interfaces=['127.0.0.1'])
+ start_time = r.current_time_millis()
+ assert zeroconf.get_service_info("_notfound.local.", "notthere._notfound.local.") is None
+ end_time = r.current_time_millis()
+ zeroconf.close()
+ # 3000ms for the default timeout
+ # 1000ms for loaded systems + schedule overhead
+ assert (end_time - start_time) < 3000 + 1000
+
+
+@pytest.mark.asyncio
+async def test_we_try_four_times_with_random_delay():
+ """Verify we try four times even with the random delay."""
+ type_ = "_typethatisnothere._tcp.local."
+ aiozc = AsyncZeroconf(interfaces=['127.0.0.1'])
+
+ # we are going to patch the zeroconf send to check query transmission
+ request_count = 0
+
+ def async_send(out, addr=const._MDNS_ADDR, port=const._MDNS_PORT):
+ """Sends an outgoing packet."""
+ nonlocal request_count
+ request_count += 1
+
+ # patch the zeroconf send
+ with patch.object(aiozc.zeroconf, "async_send", async_send):
+ await aiozc.async_get_service_info(f"willnotbefound.{type_}", type_)
+
+ await aiozc.async_close()
+
+ assert request_count == 4
diff --git a/tests/services/test_registry.py b/tests/services/test_registry.py
new file mode 100644
index 00000000..3c105cbb
--- /dev/null
+++ b/tests/services/test_registry.py
@@ -0,0 +1,132 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+
+"""Unit tests for zeroconf._services.registry."""
+
+import unittest
+import socket
+
+import zeroconf as r
+from zeroconf import ServiceInfo
+
+
+class TestServiceRegistry(unittest.TestCase):
+ def test_only_register_once(self):
+ type_ = "_test-srvc-type._tcp.local."
+ name = "xxxyyy"
+ registration_name = "%s.%s" % (name, type_)
+
+ desc = {'path': '/~paulsm/'}
+ info = ServiceInfo(
+ type_, registration_name, 80, 0, 0, desc, "ash-2.local.", addresses=[socket.inet_aton("10.0.1.2")]
+ )
+
+ registry = r.ServiceRegistry()
+ registry.async_add(info)
+ self.assertRaises(r.ServiceNameAlreadyRegistered, registry.async_add, info)
+ registry.async_remove(info)
+ registry.async_add(info)
+
+ def test_register_same_server(self):
+ type_ = "_test-srvc-type._tcp.local."
+ name = "xxxyyy"
+ name2 = "xxxyyy2"
+ registration_name = "%s.%s" % (name, type_)
+ registration_name2 = "%s.%s" % (name2, type_)
+
+ desc = {'path': '/~paulsm/'}
+ info = ServiceInfo(
+ type_, registration_name, 80, 0, 0, desc, "same.local.", addresses=[socket.inet_aton("10.0.1.2")]
+ )
+ info2 = ServiceInfo(
+ type_, registration_name2, 80, 0, 0, desc, "same.local.", addresses=[socket.inet_aton("10.0.1.2")]
+ )
+ registry = r.ServiceRegistry()
+ registry.async_add(info)
+ registry.async_add(info2)
+ assert registry.async_get_infos_server("same.local.") == [info, info2]
+ registry.async_remove(info)
+ assert registry.async_get_infos_server("same.local.") == [info2]
+ registry.async_remove(info2)
+ assert registry.async_get_infos_server("same.local.") == []
+
+ def test_unregister_multiple_times(self):
+ """Verify we can unregister a service multiple times.
+
+ In production unregister_service and unregister_all_services
+ may happen at the same time during shutdown. We want to treat
+ this as non-fatal since its expected to happen and it is unlikely
+ that the callers know about each other.
+ """
+ type_ = "_test-srvc-type._tcp.local."
+ name = "xxxyyy"
+ registration_name = "%s.%s" % (name, type_)
+
+ desc = {'path': '/~paulsm/'}
+ info = ServiceInfo(
+ type_, registration_name, 80, 0, 0, desc, "ash-2.local.", addresses=[socket.inet_aton("10.0.1.2")]
+ )
+
+ registry = r.ServiceRegistry()
+ registry.async_add(info)
+ self.assertRaises(r.ServiceNameAlreadyRegistered, registry.async_add, info)
+ registry.async_remove(info)
+ registry.async_remove(info)
+
+ def test_lookups(self):
+ type_ = "_test-srvc-type._tcp.local."
+ name = "xxxyyy"
+ registration_name = "%s.%s" % (name, type_)
+
+ desc = {'path': '/~paulsm/'}
+ info = ServiceInfo(
+ type_, registration_name, 80, 0, 0, desc, "ash-2.local.", addresses=[socket.inet_aton("10.0.1.2")]
+ )
+
+ registry = r.ServiceRegistry()
+ registry.async_add(info)
+
+ assert registry.async_get_service_infos() == [info]
+ assert registry.async_get_info_name(registration_name) == info
+ assert registry.async_get_infos_type(type_) == [info]
+ assert registry.async_get_infos_server("ash-2.local.") == [info]
+ assert registry.async_get_types() == [type_]
+
+ def test_lookups_upper_case_by_lower_case(self):
+ type_ = "_test-SRVC-type._tcp.local."
+ name = "Xxxyyy"
+ registration_name = "%s.%s" % (name, type_)
+
+ desc = {'path': '/~paulsm/'}
+ info = ServiceInfo(
+ type_, registration_name, 80, 0, 0, desc, "ASH-2.local.", addresses=[socket.inet_aton("10.0.1.2")]
+ )
+
+ registry = r.ServiceRegistry()
+ registry.async_add(info)
+
+ assert registry.async_get_service_infos() == [info]
+ assert registry.async_get_info_name(registration_name.lower()) == info
+ assert registry.async_get_infos_type(type_.lower()) == [info]
+ assert registry.async_get_infos_server("ash-2.local.") == [info]
+ assert registry.async_get_types() == [type_.lower()]
+
+ def test_lookups_lower_case_by_upper_case(self):
+ type_ = "_test-srvc-type._tcp.local."
+ name = "xxxyyy"
+ registration_name = "%s.%s" % (name, type_)
+
+ desc = {'path': '/~paulsm/'}
+ info = ServiceInfo(
+ type_, registration_name, 80, 0, 0, desc, "ash-2.local.", addresses=[socket.inet_aton("10.0.1.2")]
+ )
+
+ registry = r.ServiceRegistry()
+ registry.async_add(info)
+
+ assert registry.async_get_service_infos() == [info]
+ assert registry.async_get_info_name(registration_name.upper()) == info
+ assert registry.async_get_infos_type(type_.upper()) == [info]
+ assert registry.async_get_infos_server("ASH-2.local.") == [info]
+ assert registry.async_get_types() == [type_]
diff --git a/tests/services/test_types.py b/tests/services/test_types.py
new file mode 100644
index 00000000..b1c312db
--- /dev/null
+++ b/tests/services/test_types.py
@@ -0,0 +1,177 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+
+"""Unit tests for zeroconf._services.types."""
+
+import logging
+import os
+import unittest
+import socket
+import sys
+from unittest.mock import patch
+
+import zeroconf as r
+from zeroconf import Zeroconf, ServiceInfo, ZeroconfServiceTypes
+
+from .. import _clear_cache, has_working_ipv6
+
+log = logging.getLogger('zeroconf')
+original_logging_level = logging.NOTSET
+
+
+def setup_module():
+ global original_logging_level
+ original_logging_level = log.level
+ log.setLevel(logging.DEBUG)
+
+
+def teardown_module():
+ if original_logging_level != logging.NOTSET:
+ log.setLevel(original_logging_level)
+
+
+class ServiceTypesQuery(unittest.TestCase):
+ def test_integration_with_listener(self):
+
+ type_ = "_test-listen-type._tcp.local."
+ name = "xxxyyy"
+ registration_name = "%s.%s" % (name, type_)
+
+ zeroconf_registrar = Zeroconf(interfaces=['127.0.0.1'])
+ desc = {'path': '/~paulsm/'}
+ info = ServiceInfo(
+ type_,
+ registration_name,
+ 80,
+ 0,
+ 0,
+ desc,
+ "ash-2.local.",
+ addresses=[socket.inet_aton("10.0.1.2")],
+ )
+ zeroconf_registrar.registry.async_add(info)
+ try:
+ with patch.object(
+ zeroconf_registrar.engine.protocols[0], "suppress_duplicate_packet", return_value=False
+ ), patch.object(
+ zeroconf_registrar.engine.protocols[1], "suppress_duplicate_packet", return_value=False
+ ):
+ service_types = ZeroconfServiceTypes.find(interfaces=['127.0.0.1'], timeout=2)
+ assert type_ in service_types
+ _clear_cache(zeroconf_registrar)
+ service_types = ZeroconfServiceTypes.find(zc=zeroconf_registrar, timeout=2)
+ assert type_ in service_types
+
+ finally:
+ zeroconf_registrar.close()
+
+ @unittest.skipIf(not has_working_ipv6(), 'Requires IPv6')
+ @unittest.skipIf(os.environ.get('SKIP_IPV6'), 'IPv6 tests disabled')
+ def test_integration_with_listener_v6_records(self):
+
+ type_ = "_test-listenv6rec-type._tcp.local."
+ name = "xxxyyy"
+ registration_name = "%s.%s" % (name, type_)
+ addr = "2606:2800:220:1:248:1893:25c8:1946" # example.com
+
+ zeroconf_registrar = Zeroconf(interfaces=['127.0.0.1'])
+ desc = {'path': '/~paulsm/'}
+ info = ServiceInfo(
+ type_,
+ registration_name,
+ 80,
+ 0,
+ 0,
+ desc,
+ "ash-2.local.",
+ addresses=[socket.inet_pton(socket.AF_INET6, addr)],
+ )
+ zeroconf_registrar.registry.async_add(info)
+ try:
+ with patch.object(
+ zeroconf_registrar.engine.protocols[0], "suppress_duplicate_packet", return_value=False
+ ), patch.object(
+ zeroconf_registrar.engine.protocols[1], "suppress_duplicate_packet", return_value=False
+ ):
+ service_types = ZeroconfServiceTypes.find(interfaces=['127.0.0.1'], timeout=2)
+ assert type_ in service_types
+ _clear_cache(zeroconf_registrar)
+ service_types = ZeroconfServiceTypes.find(zc=zeroconf_registrar, timeout=2)
+ assert type_ in service_types
+
+ finally:
+ zeroconf_registrar.close()
+
+ @unittest.skipIf(not has_working_ipv6() or sys.platform == 'win32', 'Requires IPv6')
+ @unittest.skipIf(os.environ.get('SKIP_IPV6'), 'IPv6 tests disabled')
+ def test_integration_with_listener_ipv6(self):
+
+ type_ = "_test-listenv6ip-type._tcp.local."
+ name = "xxxyyy"
+ registration_name = "%s.%s" % (name, type_)
+ addr = "2606:2800:220:1:248:1893:25c8:1946" # example.com
+
+ zeroconf_registrar = Zeroconf(ip_version=r.IPVersion.V6Only)
+ desc = {'path': '/~paulsm/'}
+ info = ServiceInfo(
+ type_,
+ registration_name,
+ 80,
+ 0,
+ 0,
+ desc,
+ "ash-2.local.",
+ addresses=[socket.inet_pton(socket.AF_INET6, addr)],
+ )
+ zeroconf_registrar.registry.async_add(info)
+ try:
+ with patch.object(
+ zeroconf_registrar.engine.protocols[0], "suppress_duplicate_packet", return_value=False
+ ), patch.object(
+ zeroconf_registrar.engine.protocols[1], "suppress_duplicate_packet", return_value=False
+ ):
+ service_types = ZeroconfServiceTypes.find(ip_version=r.IPVersion.V6Only, timeout=2)
+ assert type_ in service_types
+ _clear_cache(zeroconf_registrar)
+ service_types = ZeroconfServiceTypes.find(zc=zeroconf_registrar, timeout=2)
+ assert type_ in service_types
+
+ finally:
+ zeroconf_registrar.close()
+
+ def test_integration_with_subtype_and_listener(self):
+ subtype_ = "_subtype._sub"
+ type_ = "_listen._tcp.local."
+ name = "xxxyyy"
+ # Note: discovery returns only DNS-SD type not subtype
+ discovery_type = "%s.%s" % (subtype_, type_)
+ registration_name = "%s.%s" % (name, type_)
+
+ zeroconf_registrar = Zeroconf(interfaces=['127.0.0.1'])
+ desc = {'path': '/~paulsm/'}
+ info = ServiceInfo(
+ discovery_type,
+ registration_name,
+ 80,
+ 0,
+ 0,
+ desc,
+ "ash-2.local.",
+ addresses=[socket.inet_aton("10.0.1.2")],
+ )
+ zeroconf_registrar.registry.async_add(info)
+ try:
+ with patch.object(
+ zeroconf_registrar.engine.protocols[0], "suppress_duplicate_packet", return_value=False
+ ), patch.object(
+ zeroconf_registrar.engine.protocols[1], "suppress_duplicate_packet", return_value=False
+ ):
+ service_types = ZeroconfServiceTypes.find(interfaces=['127.0.0.1'], timeout=2)
+ assert discovery_type in service_types
+ _clear_cache(zeroconf_registrar)
+ service_types = ZeroconfServiceTypes.find(zc=zeroconf_registrar, timeout=2)
+ assert discovery_type in service_types
+
+ finally:
+ zeroconf_registrar.close()
diff --git a/tests/test_asyncio.py b/tests/test_asyncio.py
new file mode 100644
index 00000000..ea80d6f5
--- /dev/null
+++ b/tests/test_asyncio.py
@@ -0,0 +1,1124 @@
+#!/usr/bin/env python
+
+
+"""Unit tests for aio.py."""
+
+import asyncio
+import logging
+import os
+import socket
+import time
+import threading
+from unittest.mock import ANY, call, patch, MagicMock
+
+
+import pytest
+
+from zeroconf.asyncio import AsyncServiceBrowser, AsyncServiceInfo, AsyncZeroconf, AsyncZeroconfServiceTypes
+from zeroconf import (
+ DNSIncoming,
+ DNSOutgoing,
+ DNSQuestion,
+ DNSPointer,
+ DNSService,
+ DNSAddress,
+ DNSText,
+ ServiceStateChange,
+ Zeroconf,
+ const,
+)
+from zeroconf.const import _LISTENER_TIME
+from zeroconf._core import AsyncListener
+from zeroconf._exceptions import BadTypeInNameException, NonUniqueNameException, ServiceNameAlreadyRegistered
+from zeroconf._services import ServiceListener
+import zeroconf._services.browser as _services_browser
+from zeroconf._services.info import ServiceInfo
+from zeroconf._utils.time import current_time_millis
+
+from . import _clear_cache, has_working_ipv6
+
+log = logging.getLogger('zeroconf')
+original_logging_level = logging.NOTSET
+
+
+def setup_module():
+ global original_logging_level
+ original_logging_level = log.level
+ log.setLevel(logging.DEBUG)
+
+
+def teardown_module():
+ if original_logging_level != logging.NOTSET:
+ log.setLevel(original_logging_level)
+
+
+@pytest.fixture(autouse=True)
+def verify_threads_ended():
+ """Verify that the threads are not running after the test."""
+ threads_before = frozenset(threading.enumerate())
+ yield
+ threads_after = frozenset(threading.enumerate())
+ non_executor_threads = frozenset(
+ thread
+ for thread in threads_after
+ if "asyncio" not in thread.name and "ThreadPoolExecutor" not in thread.name
+ )
+ threads = non_executor_threads - threads_before
+ assert not threads
+
+
+@pytest.mark.asyncio
+async def test_async_basic_usage() -> None:
+ """Test we can create and close the instance."""
+ aiozc = AsyncZeroconf(interfaces=['127.0.0.1'])
+ await aiozc.async_close()
+
+
+@pytest.mark.asyncio
+async def test_async_close_twice() -> None:
+ """Test we can close twice."""
+ aiozc = AsyncZeroconf(interfaces=['127.0.0.1'])
+ await aiozc.async_close()
+ await aiozc.async_close()
+
+
+@pytest.mark.asyncio
+async def test_async_with_sync_passed_in() -> None:
+ """Test we can create and close the instance when passing in a sync Zeroconf."""
+ zc = Zeroconf(interfaces=['127.0.0.1'])
+ aiozc = AsyncZeroconf(zc=zc)
+ assert aiozc.zeroconf is zc
+ await aiozc.async_close()
+
+
+@pytest.mark.asyncio
+async def test_async_with_sync_passed_in_closed_in_async() -> None:
+ """Test caller closes the sync version in async."""
+ zc = Zeroconf(interfaces=['127.0.0.1'])
+ aiozc = AsyncZeroconf(zc=zc)
+ assert aiozc.zeroconf is zc
+ zc.close()
+ await aiozc.async_close()
+
+
+@pytest.mark.asyncio
+async def test_sync_within_event_loop_executor() -> None:
+ """Test sync version still works from an executor within an event loop."""
+
+ def sync_code():
+ zc = Zeroconf(interfaces=['127.0.0.1'])
+ assert zc.get_service_info("_neverused._tcp.local.", "xneverused._neverused._tcp.local.", 10) is None
+ zc.close()
+
+ await asyncio.get_event_loop().run_in_executor(None, sync_code)
+
+
+@pytest.mark.asyncio
+async def test_async_service_registration() -> None:
+ """Test registering services broadcasts the registration by default."""
+ aiozc = AsyncZeroconf(interfaces=['127.0.0.1'])
+ type_ = "_test1-srvc-type._tcp.local."
+ name = "xxxyyy"
+ registration_name = f"{name}.{type_}"
+
+ calls = []
+
+ class MyListener(ServiceListener):
+ def add_service(self, zeroconf: Zeroconf, type: str, name: str) -> None:
+ calls.append(("add", type, name))
+
+ def remove_service(self, zeroconf: Zeroconf, type: str, name: str) -> None:
+ calls.append(("remove", type, name))
+
+ def update_service(self, zeroconf: Zeroconf, type: str, name: str) -> None:
+ calls.append(("update", type, name))
+
+ listener = MyListener()
+
+ aiozc.zeroconf.add_service_listener(type_, listener)
+
+ desc = {'path': '/~paulsm/'}
+ info = ServiceInfo(
+ type_,
+ registration_name,
+ 80,
+ 0,
+ 0,
+ desc,
+ "ash-2.local.",
+ addresses=[socket.inet_aton("10.0.1.2")],
+ )
+ task = await aiozc.async_register_service(info)
+ await task
+ new_info = ServiceInfo(
+ type_,
+ registration_name,
+ 80,
+ 0,
+ 0,
+ desc,
+ "ash-2.local.",
+ addresses=[socket.inet_aton("10.0.1.3")],
+ )
+ task = await aiozc.async_update_service(new_info)
+ await task
+ task = await aiozc.async_unregister_service(new_info)
+ await task
+ await aiozc.async_close()
+
+ assert calls == [
+ ('add', type_, registration_name),
+ ('update', type_, registration_name),
+ ('remove', type_, registration_name),
+ ]
+
+
+@pytest.mark.asyncio
+async def test_async_service_registration_same_server_different_ports() -> None:
+ """Test registering services with the same server with different srv records."""
+ aiozc = AsyncZeroconf(interfaces=['127.0.0.1'])
+ type_ = "_test1-srvc-type._tcp.local."
+ name = "xxxyyy"
+ name2 = "xxxyyy2"
+
+ registration_name = f"{name}.{type_}"
+ registration_name2 = f"{name2}.{type_}"
+
+ calls = []
+
+ class MyListener(ServiceListener):
+ def add_service(self, zeroconf: Zeroconf, type: str, name: str) -> None:
+ calls.append(("add", type, name))
+
+ def remove_service(self, zeroconf: Zeroconf, type: str, name: str) -> None:
+ calls.append(("remove", type, name))
+
+ def update_service(self, zeroconf: Zeroconf, type: str, name: str) -> None:
+ calls.append(("update", type, name))
+
+ listener = MyListener()
+
+ aiozc.zeroconf.add_service_listener(type_, listener)
+
+ desc = {'path': '/~paulsm/'}
+ info = ServiceInfo(
+ type_,
+ registration_name,
+ 80,
+ 0,
+ 0,
+ desc,
+ "ash-2.local.",
+ addresses=[socket.inet_aton("10.0.1.2")],
+ )
+ info2 = ServiceInfo(
+ type_,
+ registration_name2,
+ 81,
+ 0,
+ 0,
+ desc,
+ "ash-2.local.",
+ addresses=[socket.inet_aton("10.0.1.2")],
+ )
+ tasks = []
+ tasks.append(await aiozc.async_register_service(info))
+ tasks.append(await aiozc.async_register_service(info2))
+ await asyncio.gather(*tasks)
+
+ task = await aiozc.async_unregister_service(info)
+ await task
+ entries = aiozc.zeroconf.cache.async_entries_with_server("ash-2.local.")
+ assert len(entries) == 1
+ assert info2.dns_service() in entries
+ await aiozc.async_close()
+ assert calls == [
+ ('add', type_, registration_name),
+ ('add', type_, registration_name2),
+ ('remove', type_, registration_name),
+ ('remove', type_, registration_name2),
+ ]
+
+
+@pytest.mark.asyncio
+async def test_async_service_registration_same_server_same_ports() -> None:
+ """Test registering services with the same server with the exact same srv record."""
+ aiozc = AsyncZeroconf(interfaces=['127.0.0.1'])
+ type_ = "_test1-srvc-type._tcp.local."
+ name = "xxxyyy"
+ name2 = "xxxyyy2"
+
+ registration_name = f"{name}.{type_}"
+ registration_name2 = f"{name2}.{type_}"
+
+ calls = []
+
+ class MyListener(ServiceListener):
+ def add_service(self, zeroconf: Zeroconf, type: str, name: str) -> None:
+ calls.append(("add", type, name))
+
+ def remove_service(self, zeroconf: Zeroconf, type: str, name: str) -> None:
+ calls.append(("remove", type, name))
+
+ def update_service(self, zeroconf: Zeroconf, type: str, name: str) -> None:
+ calls.append(("update", type, name))
+
+ listener = MyListener()
+
+ aiozc.zeroconf.add_service_listener(type_, listener)
+
+ desc = {'path': '/~paulsm/'}
+ info = ServiceInfo(
+ type_,
+ registration_name,
+ 80,
+ 0,
+ 0,
+ desc,
+ "ash-2.local.",
+ addresses=[socket.inet_aton("10.0.1.2")],
+ )
+ info2 = ServiceInfo(
+ type_,
+ registration_name2,
+ 80,
+ 0,
+ 0,
+ desc,
+ "ash-2.local.",
+ addresses=[socket.inet_aton("10.0.1.2")],
+ )
+ tasks = []
+ tasks.append(await aiozc.async_register_service(info))
+ tasks.append(await aiozc.async_register_service(info2))
+ await asyncio.gather(*tasks)
+
+ task = await aiozc.async_unregister_service(info)
+ await task
+ entries = aiozc.zeroconf.cache.async_entries_with_server("ash-2.local.")
+ assert len(entries) == 1
+ assert info2.dns_service() in entries
+ await aiozc.async_close()
+ assert calls == [
+ ('add', type_, registration_name),
+ ('add', type_, registration_name2),
+ ('remove', type_, registration_name),
+ ('remove', type_, registration_name2),
+ ]
+
+
+@pytest.mark.asyncio
+async def test_async_service_registration_name_conflict() -> None:
+ """Test registering services throws on name conflict."""
+ aiozc = AsyncZeroconf(interfaces=['127.0.0.1'])
+ type_ = "_test-srvc2-type._tcp.local."
+ name = "xxxyyy"
+ registration_name = f"{name}.{type_}"
+
+ desc = {'path': '/~paulsm/'}
+ info = ServiceInfo(
+ type_,
+ registration_name,
+ 80,
+ 0,
+ 0,
+ desc,
+ "ash-2.local.",
+ addresses=[socket.inet_aton("10.0.1.2")],
+ )
+ task = await aiozc.async_register_service(info)
+ await task
+
+ with pytest.raises(NonUniqueNameException):
+ task = await aiozc.async_register_service(info)
+ await task
+
+ with pytest.raises(ServiceNameAlreadyRegistered):
+ task = await aiozc.async_register_service(info, cooperating_responders=True)
+ await task
+
+ conflicting_info = ServiceInfo(
+ type_,
+ registration_name,
+ 80,
+ 0,
+ 0,
+ desc,
+ "ash-3.local.",
+ addresses=[socket.inet_aton("10.0.1.3")],
+ )
+
+ with pytest.raises(NonUniqueNameException):
+ task = await aiozc.async_register_service(conflicting_info)
+ await task
+
+ await aiozc.async_close()
+
+
+@pytest.mark.asyncio
+async def test_async_service_registration_name_does_not_match_type() -> None:
+ """Test registering services throws when the name does not match the type."""
+ aiozc = AsyncZeroconf(interfaces=['127.0.0.1'])
+ type_ = "_test-srvc3-type._tcp.local."
+ name = "xxxyyy"
+ registration_name = f"{name}.{type_}"
+
+ desc = {'path': '/~paulsm/'}
+ info = ServiceInfo(
+ type_,
+ registration_name,
+ 80,
+ 0,
+ 0,
+ desc,
+ "ash-2.local.",
+ addresses=[socket.inet_aton("10.0.1.2")],
+ )
+ info.type = "_wrong._tcp.local."
+ with pytest.raises(BadTypeInNameException):
+ task = await aiozc.async_register_service(info)
+ await task
+ await aiozc.async_close()
+
+
+@pytest.mark.asyncio
+async def test_async_tasks() -> None:
+ """Test awaiting broadcast tasks"""
+
+ aiozc = AsyncZeroconf(interfaces=['127.0.0.1'])
+ type_ = "_test-srvc4-type._tcp.local."
+ name = "xxxyyy"
+ registration_name = f"{name}.{type_}"
+
+ calls = []
+
+ class MyListener(ServiceListener):
+ def add_service(self, zeroconf: Zeroconf, type: str, name: str) -> None:
+ calls.append(("add", type, name))
+
+ def remove_service(self, zeroconf: Zeroconf, type: str, name: str) -> None:
+ calls.append(("remove", type, name))
+
+ def update_service(self, zeroconf: Zeroconf, type: str, name: str) -> None:
+ calls.append(("update", type, name))
+
+ listener = MyListener()
+ aiozc.zeroconf.add_service_listener(type_, listener)
+
+ desc = {'path': '/~paulsm/'}
+ info = ServiceInfo(
+ type_,
+ registration_name,
+ 80,
+ 0,
+ 0,
+ desc,
+ "ash-2.local.",
+ addresses=[socket.inet_aton("10.0.1.2")],
+ )
+ task = await aiozc.async_register_service(info)
+ assert isinstance(task, asyncio.Task)
+ await task
+
+ new_info = ServiceInfo(
+ type_,
+ registration_name,
+ 80,
+ 0,
+ 0,
+ desc,
+ "ash-2.local.",
+ addresses=[socket.inet_aton("10.0.1.3")],
+ )
+ task = await aiozc.async_update_service(new_info)
+ assert isinstance(task, asyncio.Task)
+ await task
+
+ task = await aiozc.async_unregister_service(new_info)
+ assert isinstance(task, asyncio.Task)
+ await task
+
+ await aiozc.async_close()
+
+ assert calls == [
+ ('add', type_, registration_name),
+ ('update', type_, registration_name),
+ ('remove', type_, registration_name),
+ ]
+
+
+@pytest.mark.asyncio
+async def test_async_wait_unblocks_on_update() -> None:
+ """Test async_wait will unblock on update."""
+
+ aiozc = AsyncZeroconf(interfaces=['127.0.0.1'])
+ type_ = "_test-srvc4-type._tcp.local."
+ name = "xxxyyy"
+ registration_name = f"{name}.{type_}"
+
+ desc = {'path': '/~paulsm/'}
+ info = ServiceInfo(
+ type_,
+ registration_name,
+ 80,
+ 0,
+ 0,
+ desc,
+ "ash-2.local.",
+ addresses=[socket.inet_aton("10.0.1.2")],
+ )
+ task = await aiozc.async_register_service(info)
+
+ # Should unblock due to update from the
+ # registration
+ now = current_time_millis()
+ await aiozc.zeroconf.async_wait(50000)
+ assert current_time_millis() - now < 3000
+ await task
+
+ now = current_time_millis()
+ await aiozc.zeroconf.async_wait(50)
+ assert current_time_millis() - now < 1000
+
+ await aiozc.async_close()
+
+
+@pytest.mark.asyncio
+async def test_service_info_async_request() -> None:
+ """Test registering services broadcasts and query with AsyncServceInfo.async_request."""
+ if not has_working_ipv6() or os.environ.get('SKIP_IPV6'):
+ pytest.skip('Requires IPv6')
+
+ aiozc = AsyncZeroconf(interfaces=['127.0.0.1'])
+ type_ = "_test1-srvc-type._tcp.local."
+ name = "xxxyyy"
+ name2 = "abc"
+ registration_name = f"{name}.{type_}"
+ registration_name2 = f"{name2}.{type_}"
+
+ # Start a tasks BEFORE the registration that will keep trying
+ # and see the registration a bit later
+ get_service_info_task1 = asyncio.ensure_future(aiozc.async_get_service_info(type_, registration_name))
+ await asyncio.sleep(_LISTENER_TIME / 1000 / 2)
+ get_service_info_task2 = asyncio.ensure_future(aiozc.async_get_service_info(type_, registration_name))
+
+ desc = {'path': '/~paulsm/'}
+ info = ServiceInfo(
+ type_,
+ registration_name,
+ 80,
+ 0,
+ 0,
+ desc,
+ "ash-1.local.",
+ addresses=[socket.inet_aton("10.0.1.2")],
+ )
+ info2 = ServiceInfo(
+ type_,
+ registration_name2,
+ 80,
+ 0,
+ 0,
+ desc,
+ "ash-5.local.",
+ addresses=[socket.inet_aton("10.0.1.5")],
+ )
+ tasks = []
+ tasks.append(await aiozc.async_register_service(info))
+ tasks.append(await aiozc.async_register_service(info2))
+ await asyncio.gather(*tasks)
+
+ aiosinfo = await get_service_info_task1
+ assert aiosinfo is not None
+ assert aiosinfo.addresses == [socket.inet_aton("10.0.1.2")]
+
+ aiosinfo = await get_service_info_task2
+ assert aiosinfo is not None
+ assert aiosinfo.addresses == [socket.inet_aton("10.0.1.2")]
+
+ aiosinfo = await aiozc.async_get_service_info(type_, registration_name)
+ assert aiosinfo is not None
+ assert aiosinfo.addresses == [socket.inet_aton("10.0.1.2")]
+
+ new_info = ServiceInfo(
+ type_,
+ registration_name,
+ 80,
+ 0,
+ 0,
+ desc,
+ "ash-2.local.",
+ addresses=[socket.inet_aton("10.0.1.3"), socket.inet_pton(socket.AF_INET6, "6001:db8::1")],
+ )
+
+ task = await aiozc.async_update_service(new_info)
+ await task
+
+ aiosinfo = await aiozc.async_get_service_info(type_, registration_name)
+ assert aiosinfo is not None
+ assert aiosinfo.addresses == [socket.inet_aton("10.0.1.3")]
+
+ aiosinfos = await asyncio.gather(
+ aiozc.async_get_service_info(type_, registration_name),
+ aiozc.async_get_service_info(type_, registration_name2),
+ )
+ assert aiosinfos[0] is not None
+ assert aiosinfos[0].addresses == [socket.inet_aton("10.0.1.3")]
+ assert aiosinfos[1] is not None
+ assert aiosinfos[1].addresses == [socket.inet_aton("10.0.1.5")]
+
+ aiosinfo = AsyncServiceInfo(type_, registration_name)
+ _clear_cache(aiozc.zeroconf)
+ # Generating the race condition is almost impossible
+ # without patching since its a TOCTOU race
+ with patch("zeroconf.asyncio.AsyncServiceInfo._is_complete", False):
+ await aiosinfo.async_request(aiozc.zeroconf, 3000)
+ assert aiosinfo is not None
+ assert aiosinfo.addresses == [socket.inet_aton("10.0.1.3")]
+
+ task = await aiozc.async_unregister_service(new_info)
+ await task
+
+ aiosinfo = await aiozc.async_get_service_info(type_, registration_name)
+ assert aiosinfo is None
+
+ await aiozc.async_close()
+
+
+@pytest.mark.asyncio
+async def test_async_service_browser() -> None:
+ """Test AsyncServiceBrowser."""
+ aiozc = AsyncZeroconf(interfaces=['127.0.0.1'])
+ type_ = "_test9-srvc-type._tcp.local."
+ name = "xxxyyy"
+ registration_name = f"{name}.{type_}"
+
+ calls = []
+
+ class MyListener(ServiceListener):
+ def add_service(self, aiozc: AsyncZeroconf, type: str, name: str) -> None:
+ calls.append(("add", type, name))
+
+ def remove_service(self, aiozc: AsyncZeroconf, type: str, name: str) -> None:
+ calls.append(("remove", type, name))
+
+ def update_service(self, aiozc: AsyncZeroconf, type: str, name: str) -> None:
+ calls.append(("update", type, name))
+
+ listener = MyListener()
+ await aiozc.async_add_service_listener(type_, listener)
+
+ desc = {'path': '/~paulsm/'}
+ info = ServiceInfo(
+ type_,
+ registration_name,
+ 80,
+ 0,
+ 0,
+ desc,
+ "ash-2.local.",
+ addresses=[socket.inet_aton("10.0.1.2")],
+ )
+ task = await aiozc.async_register_service(info)
+ await task
+ new_info = ServiceInfo(
+ type_,
+ registration_name,
+ 80,
+ 0,
+ 0,
+ desc,
+ "ash-2.local.",
+ addresses=[socket.inet_aton("10.0.1.3")],
+ )
+ task = await aiozc.async_update_service(new_info)
+ await task
+ task = await aiozc.async_unregister_service(new_info)
+ await task
+ await aiozc.zeroconf.async_wait(1)
+ await aiozc.async_close()
+
+ assert calls == [
+ ('add', type_, registration_name),
+ ('update', type_, registration_name),
+ ('remove', type_, registration_name),
+ ]
+
+
+@pytest.mark.asyncio
+async def test_async_context_manager() -> None:
+ """Test using an async context manager."""
+ type_ = "_test10-sr-type._tcp.local."
+ name = "xxxyyy"
+ registration_name = f"{name}.{type_}"
+
+ async with AsyncZeroconf(interfaces=['127.0.0.1']) as aiozc:
+ info = ServiceInfo(
+ type_,
+ registration_name,
+ 80,
+ 0,
+ 0,
+ {'path': '/~paulsm/'},
+ "ash-2.local.",
+ addresses=[socket.inet_aton("10.0.1.2")],
+ )
+ task = await aiozc.async_register_service(info)
+ await task
+ aiosinfo = await aiozc.async_get_service_info(type_, registration_name)
+ assert aiosinfo is not None
+
+
+@pytest.mark.asyncio
+async def test_async_unregister_all_services() -> None:
+ """Test unregistering all services."""
+ aiozc = AsyncZeroconf(interfaces=['127.0.0.1'])
+ type_ = "_test1-srvc-type._tcp.local."
+ name = "xxxyyy"
+ name2 = "abc"
+ registration_name = f"{name}.{type_}"
+ registration_name2 = f"{name2}.{type_}"
+
+ desc = {'path': '/~paulsm/'}
+ info = ServiceInfo(
+ type_,
+ registration_name,
+ 80,
+ 0,
+ 0,
+ desc,
+ "ash-1.local.",
+ addresses=[socket.inet_aton("10.0.1.2")],
+ )
+ info2 = ServiceInfo(
+ type_,
+ registration_name2,
+ 80,
+ 0,
+ 0,
+ desc,
+ "ash-5.local.",
+ addresses=[socket.inet_aton("10.0.1.5")],
+ )
+ tasks = []
+ tasks.append(await aiozc.async_register_service(info))
+ tasks.append(await aiozc.async_register_service(info2))
+ await asyncio.gather(*tasks)
+
+ tasks = []
+ tasks.append(aiozc.async_get_service_info(type_, registration_name))
+ tasks.append(aiozc.async_get_service_info(type_, registration_name2))
+ results = await asyncio.gather(*tasks)
+ assert results[0] is not None
+ assert results[1] is not None
+
+ await aiozc.async_unregister_all_services()
+ _clear_cache(aiozc.zeroconf)
+
+ tasks = []
+ tasks.append(aiozc.async_get_service_info(type_, registration_name))
+ tasks.append(aiozc.async_get_service_info(type_, registration_name2))
+ results = await asyncio.gather(*tasks)
+ assert results[0] is None
+ assert results[1] is None
+
+ # Verify we can call again
+ await aiozc.async_unregister_all_services()
+
+ await aiozc.async_close()
+
+
+@pytest.mark.asyncio
+async def test_async_zeroconf_service_types():
+ type_ = "_test-srvc-type._tcp.local."
+ name = "xxxyyy"
+ registration_name = f"{name}.{type_}"
+
+ zeroconf_registrar = AsyncZeroconf(interfaces=['127.0.0.1'])
+ desc = {'path': '/~paulsm/'}
+ info = ServiceInfo(
+ type_,
+ registration_name,
+ 80,
+ 0,
+ 0,
+ desc,
+ "ash-2.local.",
+ addresses=[socket.inet_aton("10.0.1.2")],
+ )
+ task = await zeroconf_registrar.async_register_service(info)
+ await task
+ # Ensure we do not clear the cache until after the last broadcast is processed
+ await asyncio.sleep(0.2)
+ _clear_cache(zeroconf_registrar.zeroconf)
+ try:
+ service_types = await AsyncZeroconfServiceTypes.async_find(interfaces=['127.0.0.1'], timeout=2)
+ assert type_ in service_types
+ _clear_cache(zeroconf_registrar.zeroconf)
+ service_types = await AsyncZeroconfServiceTypes.async_find(aiozc=zeroconf_registrar, timeout=2)
+ assert type_ in service_types
+
+ finally:
+ await zeroconf_registrar.async_close()
+
+
+@pytest.mark.asyncio
+async def test_guard_against_running_serviceinfo_request_event_loop() -> None:
+ """Test that running ServiceInfo.request from the event loop throws."""
+ aiozc = AsyncZeroconf(interfaces=['127.0.0.1'])
+
+ service_info = AsyncServiceInfo("_hap._tcp.local.", "doesnotmatter._hap._tcp.local.")
+ with pytest.raises(RuntimeError):
+ service_info.request(aiozc.zeroconf, 3000)
+ await aiozc.async_close()
+
+
+@pytest.mark.asyncio
+async def test_service_browser_instantiation_generates_add_events_from_cache():
+ """Test that the ServiceBrowser will generate Add events with the existing cache when starting."""
+
+ # instantiate a zeroconf instance
+ aiozc = AsyncZeroconf(interfaces=['127.0.0.1'])
+ zc = aiozc.zeroconf
+ type_ = "_hap._tcp.local."
+ registration_name = "xxxyyy.%s" % type_
+ callbacks = []
+
+ class MyServiceListener(ServiceListener):
+ def add_service(self, zc, type_, name) -> None:
+ nonlocal callbacks
+ if name == registration_name:
+ callbacks.append(("add", type_, name))
+
+ def remove_service(self, zc, type_, name) -> None:
+ nonlocal callbacks
+ if name == registration_name:
+ callbacks.append(("remove", type_, name))
+
+ def update_service(self, zc, type_, name) -> None:
+ nonlocal callbacks
+ if name == registration_name:
+ callbacks.append(("update", type_, name))
+
+ listener = MyServiceListener()
+
+ desc = {'path': '/~paulsm/'}
+ address_parsed = "10.0.1.2"
+ address = socket.inet_aton(address_parsed)
+ info = ServiceInfo(type_, registration_name, 80, 0, 0, desc, "ash-2.local.", addresses=[address])
+ zc.cache.async_add_records(
+ [info.dns_pointer(), info.dns_service(), *info.dns_addresses(), info.dns_text()]
+ )
+
+ browser = AsyncServiceBrowser(zc, type_, None, listener)
+
+ await asyncio.sleep(0)
+
+ assert callbacks == [
+ ('add', type_, registration_name),
+ ]
+ await browser.async_cancel()
+
+ await aiozc.async_close()
+
+
+@pytest.mark.asyncio
+async def test_integration():
+ service_added = asyncio.Event()
+ service_removed = asyncio.Event()
+ unexpected_ttl = asyncio.Event()
+ got_query = asyncio.Event()
+
+ type_ = "_http._tcp.local."
+ registration_name = "xxxyyy.%s" % type_
+
+ def on_service_state_change(zeroconf, service_type, state_change, name):
+ if name == registration_name:
+ if state_change is ServiceStateChange.Added:
+ service_added.set()
+ elif state_change is ServiceStateChange.Removed:
+ service_removed.set()
+
+ aiozc = AsyncZeroconf(interfaces=['127.0.0.1'])
+ zeroconf_browser = aiozc.zeroconf
+ await zeroconf_browser.async_wait_for_start()
+
+ # we are going to patch the zeroconf send to check packet sizes
+ old_send = zeroconf_browser.async_send
+
+ time_offset = 0.0
+
+ def _new_current_time_millis():
+ """Current system time in milliseconds"""
+ return (time.time() * 1000) + (time_offset * 1000)
+
+ expected_ttl = const._DNS_HOST_TTL
+ nbr_answers = 0
+
+ def send(out, addr=const._MDNS_ADDR, port=const._MDNS_PORT, v6_flow_scope=()):
+ """Sends an outgoing packet."""
+ pout = DNSIncoming(out.packets()[0])
+ nonlocal nbr_answers
+ for answer in pout.answers:
+ nbr_answers += 1
+ if not answer.ttl > expected_ttl / 2:
+ unexpected_ttl.set()
+
+ got_query.set()
+
+ old_send(out, addr=addr, port=port, v6_flow_scope=v6_flow_scope)
+
+ assert len(zeroconf_browser.engine.protocols) == 2
+
+ aio_zeroconf_registrar = AsyncZeroconf(interfaces=['127.0.0.1'])
+ zeroconf_registrar = aio_zeroconf_registrar.zeroconf
+ await aio_zeroconf_registrar.zeroconf.async_wait_for_start()
+
+ assert len(zeroconf_registrar.engine.protocols) == 2
+ # patch the zeroconf send
+ # patch the zeroconf current_time_millis
+ # patch the backoff limit to ensure we always get one query every 1/4 of the DNS TTL
+ # Disable duplicate question suppression and duplicate packet suppression for this test as it works
+ # by asking the same question over and over
+ with patch.object(
+ zeroconf_registrar.engine.protocols[0], "suppress_duplicate_packet", return_value=False
+ ), patch.object(
+ zeroconf_registrar.engine.protocols[1], "suppress_duplicate_packet", return_value=False
+ ), patch.object(
+ zeroconf_browser.engine.protocols[0], "suppress_duplicate_packet", return_value=False
+ ), patch.object(
+ zeroconf_browser.engine.protocols[1], "suppress_duplicate_packet", return_value=False
+ ), patch.object(
+ zeroconf_browser.question_history, "suppresses", return_value=False
+ ), patch.object(
+ zeroconf_browser, "async_send", send
+ ), patch(
+ "zeroconf._services.browser.current_time_millis", _new_current_time_millis
+ ), patch.object(
+ _services_browser, "_BROWSER_BACKOFF_LIMIT", int(expected_ttl / 4)
+ ):
+ service_added = asyncio.Event()
+ service_removed = asyncio.Event()
+
+ browser = AsyncServiceBrowser(zeroconf_browser, type_, [on_service_state_change])
+
+ desc = {'path': '/~paulsm/'}
+ info = ServiceInfo(
+ type_, registration_name, 80, 0, 0, desc, "ash-2.local.", addresses=[socket.inet_aton("10.0.1.2")]
+ )
+ task = await aio_zeroconf_registrar.async_register_service(info)
+ await task
+
+ try:
+ await asyncio.wait_for(service_added.wait(), 1)
+ assert service_added.is_set()
+
+ # Test that we receive queries containing answers only if the remaining TTL
+ # is greater than half the original TTL
+ sleep_count = 0
+ test_iterations = 50
+
+ while nbr_answers < test_iterations:
+ # Increase simulated time shift by 1/4 of the TTL in seconds
+ time_offset += expected_ttl / 4
+ now = _new_current_time_millis()
+ browser.reschedule_type(type_, now)
+ sleep_count += 1
+ await asyncio.wait_for(got_query.wait(), 1)
+ got_query.clear()
+ # Prevent the test running indefinitely in an error condition
+ assert sleep_count < test_iterations * 4
+ assert not unexpected_ttl.is_set()
+ # Don't remove service, allow close() to cleanup
+ finally:
+ await aio_zeroconf_registrar.async_close()
+ await asyncio.wait_for(service_removed.wait(), 1)
+ assert service_removed.is_set()
+ await browser.async_cancel()
+ await aiozc.async_close()
+
+
+@pytest.mark.asyncio
+async def test_info_asking_default_is_asking_qm_questions_after_the_first_qu():
+ """Verify the service info first question is QU and subsequent ones are QM questions."""
+ type_ = "_quservice._tcp.local."
+ aiozc = AsyncZeroconf(interfaces=['127.0.0.1'])
+ zeroconf_info = aiozc.zeroconf
+
+ name = "xxxyyy"
+ registration_name = f"{name}.{type_}"
+
+ desc = {'path': '/~paulsm/'}
+ info = ServiceInfo(
+ type_, registration_name, 80, 0, 0, desc, "ash-2.local.", addresses=[socket.inet_aton("10.0.1.2")]
+ )
+
+ zeroconf_info.registry.async_add(info)
+
+ # we are going to patch the zeroconf send to check query transmission
+ old_send = zeroconf_info.async_send
+
+ first_outgoing = None
+ second_outgoing = None
+
+ def send(out, addr=const._MDNS_ADDR, port=const._MDNS_PORT):
+ """Sends an outgoing packet."""
+ nonlocal first_outgoing
+ nonlocal second_outgoing
+ if out.questions:
+ if first_outgoing is not None and second_outgoing is None:
+ second_outgoing = out
+ if first_outgoing is None:
+ first_outgoing = out
+ old_send(out, addr=addr, port=port)
+
+ # patch the zeroconf send
+ with patch.object(zeroconf_info, "async_send", send):
+ aiosinfo = AsyncServiceInfo(type_, registration_name)
+ # Patch _is_complete so we send multiple times
+ with patch("zeroconf.asyncio.AsyncServiceInfo._is_complete", False):
+ await aiosinfo.async_request(aiozc.zeroconf, 1200)
+ try:
+ assert first_outgoing.questions[0].unicast == True
+ assert second_outgoing.questions[0].unicast == False
+ finally:
+ await aiozc.async_close()
+
+
+@pytest.mark.asyncio
+async def test_service_browser_ignores_unrelated_updates():
+ """Test that the ServiceBrowser ignores unrelated updates."""
+
+ # instantiate a zeroconf instance
+ aiozc = AsyncZeroconf(interfaces=['127.0.0.1'])
+ zc = aiozc.zeroconf
+ type_ = "_veryuniqueone._tcp.local."
+ registration_name = "xxxyyy.%s" % type_
+ callbacks = []
+
+ class MyServiceListener(ServiceListener):
+ def add_service(self, zc, type_, name) -> None:
+ nonlocal callbacks
+ if name == registration_name:
+ callbacks.append(("add", type_, name))
+
+ def remove_service(self, zc, type_, name) -> None:
+ nonlocal callbacks
+ if name == registration_name:
+ callbacks.append(("remove", type_, name))
+
+ def update_service(self, zc, type_, name) -> None:
+ nonlocal callbacks
+ if name == registration_name:
+ callbacks.append(("update", type_, name))
+
+ listener = MyServiceListener()
+
+ desc = {'path': '/~paulsm/'}
+ address_parsed = "10.0.1.2"
+ address = socket.inet_aton(address_parsed)
+ info = ServiceInfo(type_, registration_name, 80, 0, 0, desc, "ash-2.local.", addresses=[address])
+ zc.cache.async_add_records(
+ [
+ info.dns_pointer(),
+ info.dns_service(),
+ *info.dns_addresses(),
+ info.dns_text(),
+ DNSService(
+ "zoom._unrelated._tcp.local.",
+ const._TYPE_SRV,
+ const._CLASS_IN,
+ const._DNS_HOST_TTL,
+ 0,
+ 0,
+ 81,
+ 'unrelated.local.',
+ ),
+ ]
+ )
+
+ browser = AsyncServiceBrowser(zc, type_, None, listener)
+
+ generated = DNSOutgoing(const._FLAGS_QR_RESPONSE)
+ generated.add_answer_at_time(
+ DNSPointer(
+ "_unrelated._tcp.local.",
+ const._TYPE_PTR,
+ const._CLASS_IN,
+ const._DNS_OTHER_TTL,
+ "zoom._unrelated._tcp.local.",
+ ),
+ 0,
+ )
+ generated.add_answer_at_time(
+ DNSAddress("unrelated.local.", const._TYPE_A, const._CLASS_IN, const._DNS_HOST_TTL, b"1234"),
+ 0,
+ )
+ generated.add_answer_at_time(
+ DNSText(
+ "zoom._unrelated._tcp.local.",
+ const._TYPE_TXT,
+ const._CLASS_IN | const._CLASS_UNIQUE,
+ const._DNS_OTHER_TTL,
+ b"zoom",
+ ),
+ 0,
+ )
+
+ zc.handle_response(DNSIncoming(generated.packets()[0]))
+
+ await browser.async_cancel()
+ await asyncio.sleep(0)
+
+ assert callbacks == [
+ ('add', type_, registration_name),
+ ]
+ await aiozc.async_close()
+
+
+@pytest.mark.asyncio
+async def test_async_request_timeout():
+ """Test that the timeout does not throw an exception and finishes close to the actual timeout."""
+ aiozc = AsyncZeroconf(interfaces=['127.0.0.1'])
+ await aiozc.zeroconf.async_wait_for_start()
+ start_time = current_time_millis()
+ assert await aiozc.async_get_service_info("_notfound.local.", "notthere._notfound.local.") is None
+ end_time = current_time_millis()
+ await aiozc.async_close()
+ # 3000ms for the default timeout
+ # 1000ms for loaded systems + schedule overhead
+ assert (end_time - start_time) < 3000 + 1000
+
+
+@pytest.mark.asyncio
+async def test_legacy_unicast_response(run_isolated):
+ """Verify legacy unicast responses include questions and correct id."""
+ type_ = "_mservice._tcp.local."
+ aiozc = AsyncZeroconf(interfaces=['127.0.0.1'])
+ await aiozc.zeroconf.async_wait_for_start()
+
+ name = "xxxyyy"
+ registration_name = f"{name}.{type_}"
+
+ desc = {'path': '/~paulsm/'}
+ info = ServiceInfo(
+ type_, registration_name, 80, 0, 0, desc, "ash-2.local.", addresses=[socket.inet_aton("10.0.1.2")]
+ )
+
+ aiozc.zeroconf.registry.async_add(info)
+ query = DNSOutgoing(const._FLAGS_QR_QUERY, multicast=False, id_=888)
+ question = DNSQuestion(info.type, const._TYPE_PTR, const._CLASS_IN)
+ query.add_question(question)
+ protocol = aiozc.zeroconf.engine.protocols[0]
+
+ with patch.object(aiozc.zeroconf, "async_send") as send_mock:
+ protocol.datagram_received(query.packets()[0], ('127.0.0.1', 6503))
+
+ calls = send_mock.mock_calls
+ # Verify the response is sent back on the socket it was recieved from
+ assert calls == [call(ANY, '127.0.0.1', 6503, (), protocol.transport)]
+ outgoing = send_mock.call_args[0][0]
+ assert isinstance(outgoing, DNSOutgoing)
+ assert outgoing.questions == [question]
+ assert outgoing.id == query.id
+ await aiozc.async_close()
diff --git a/tests/test_cache.py b/tests/test_cache.py
new file mode 100644
index 00000000..559b4357
--- /dev/null
+++ b/tests/test_cache.py
@@ -0,0 +1,200 @@
+#!/usr/bin/env python
+
+
+""" Unit tests for zeroconf._cache. """
+
+import logging
+import unittest
+import unittest.mock
+
+import zeroconf as r
+from zeroconf import const
+
+log = logging.getLogger('zeroconf')
+original_logging_level = logging.NOTSET
+
+
+def setup_module():
+ global original_logging_level
+ original_logging_level = log.level
+ log.setLevel(logging.DEBUG)
+
+
+def teardown_module():
+ if original_logging_level != logging.NOTSET:
+ log.setLevel(original_logging_level)
+
+
+class TestDNSCache(unittest.TestCase):
+ def test_order(self):
+ record1 = r.DNSAddress('a', const._TYPE_SOA, const._CLASS_IN, 1, b'a')
+ record2 = r.DNSAddress('a', const._TYPE_SOA, const._CLASS_IN, 1, b'b')
+ cache = r.DNSCache()
+ cache.async_add_records([record1, record2])
+ entry = r.DNSEntry('a', const._TYPE_SOA, const._CLASS_IN)
+ cached_record = cache.get(entry)
+ assert cached_record == record2
+
+ def test_adding_same_record_to_cache_different_ttls(self):
+ """We should always get back the last entry we added if there are different TTLs.
+
+ This ensures we only have one source of truth for TTLs as a record cannot
+ be both expired and not expired.
+ """
+ record1 = r.DNSAddress('a', const._TYPE_A, const._CLASS_IN, 1, b'a')
+ record2 = r.DNSAddress('a', const._TYPE_A, const._CLASS_IN, 10, b'a')
+ cache = r.DNSCache()
+ cache.async_add_records([record1, record2])
+ entry = r.DNSEntry(record2)
+ cached_record = cache.get(entry)
+ assert cached_record == record2
+
+ def test_adding_same_record_to_cache_different_ttls(self):
+ """Verify we only get one record back.
+
+ The last record added should replace the previous since two
+ records with different ttls are __eq__. This ensures we
+ only have one source of truth for TTLs as a record cannot
+ be both expired and not expired.
+ """
+ record1 = r.DNSAddress('a', const._TYPE_A, const._CLASS_IN, 1, b'a')
+ record2 = r.DNSAddress('a', const._TYPE_A, const._CLASS_IN, 10, b'a')
+ cache = r.DNSCache()
+ cache.async_add_records([record1, record2])
+ cached_records = cache.get_all_by_details('a', const._TYPE_A, const._CLASS_IN)
+ assert cached_records == [record2]
+
+ def test_cache_empty_does_not_leak_memory_by_leaving_empty_list(self):
+ record1 = r.DNSAddress('a', const._TYPE_SOA, const._CLASS_IN, 1, b'a')
+ record2 = r.DNSAddress('a', const._TYPE_SOA, const._CLASS_IN, 1, b'b')
+ cache = r.DNSCache()
+ cache.async_add_records([record1, record2])
+ assert 'a' in cache.cache
+ cache.async_remove_records([record1, record2])
+ assert 'a' not in cache.cache
+
+ def test_cache_empty_multiple_calls(self):
+ record1 = r.DNSAddress('a', const._TYPE_SOA, const._CLASS_IN, 1, b'a')
+ record2 = r.DNSAddress('a', const._TYPE_SOA, const._CLASS_IN, 1, b'b')
+ cache = r.DNSCache()
+ cache.async_add_records([record1, record2])
+ assert 'a' in cache.cache
+ cache.async_remove_records([record1, record2])
+ assert 'a' not in cache.cache
+
+
+class TestDNSAsyncCacheAPI(unittest.TestCase):
+ def test_async_get_unique(self):
+ record1 = r.DNSAddress('a', const._TYPE_A, const._CLASS_IN, 1, b'a')
+ record2 = r.DNSAddress('a', const._TYPE_A, const._CLASS_IN, 1, b'b')
+ cache = r.DNSCache()
+ cache.async_add_records([record1, record2])
+ assert cache.async_get_unique(record1) == record1
+ assert cache.async_get_unique(record2) == record2
+
+ def test_async_all_by_details(self):
+ record1 = r.DNSAddress('a', const._TYPE_A, const._CLASS_IN, 1, b'a')
+ record2 = r.DNSAddress('a', const._TYPE_A, const._CLASS_IN, 1, b'b')
+ cache = r.DNSCache()
+ cache.async_add_records([record1, record2])
+ assert set(cache.async_all_by_details('a', const._TYPE_A, const._CLASS_IN)) == {record1, record2}
+
+ def test_async_entries_with_server(self):
+ record1 = r.DNSService(
+ 'irrelevant', const._TYPE_SRV, const._CLASS_IN, const._DNS_HOST_TTL, 0, 0, 85, 'ab'
+ )
+ record2 = r.DNSService(
+ 'irrelevant', const._TYPE_SRV, const._CLASS_IN, const._DNS_HOST_TTL, 0, 0, 80, 'ab'
+ )
+ cache = r.DNSCache()
+ cache.async_add_records([record1, record2])
+ assert set(cache.async_entries_with_server('ab')) == {record1, record2}
+ assert set(cache.async_entries_with_server('AB')) == {record1, record2}
+
+ def test_async_entries_with_name(self):
+ record1 = r.DNSService(
+ 'irrelevant', const._TYPE_SRV, const._CLASS_IN, const._DNS_HOST_TTL, 0, 0, 85, 'ab'
+ )
+ record2 = r.DNSService(
+ 'irrelevant', const._TYPE_SRV, const._CLASS_IN, const._DNS_HOST_TTL, 0, 0, 80, 'ab'
+ )
+ cache = r.DNSCache()
+ cache.async_add_records([record1, record2])
+ assert set(cache.async_entries_with_name('irrelevant')) == {record1, record2}
+ assert set(cache.async_entries_with_name('Irrelevant')) == {record1, record2}
+
+
+# These functions have been seen in other projects so
+# we try to maintain a stable API for all the threadsafe getters
+class TestDNSCacheAPI(unittest.TestCase):
+ def test_get(self):
+ record1 = r.DNSAddress('a', const._TYPE_A, const._CLASS_IN, 1, b'a')
+ record2 = r.DNSAddress('a', const._TYPE_A, const._CLASS_IN, 1, b'b')
+ record3 = r.DNSAddress('a', const._TYPE_AAAA, const._CLASS_IN, 1, b'ipv6')
+ cache = r.DNSCache()
+ cache.async_add_records([record1, record2, record3])
+ assert cache.get(record1) == record1
+ assert cache.get(record2) == record2
+ assert cache.get(r.DNSEntry('a', const._TYPE_A, const._CLASS_IN)) == record2
+ assert cache.get(r.DNSEntry('a', const._TYPE_AAAA, const._CLASS_IN)) == record3
+ assert cache.get(r.DNSEntry('notthere', const._TYPE_A, const._CLASS_IN)) is None
+
+ def test_get_by_details(self):
+ record1 = r.DNSAddress('a', const._TYPE_A, const._CLASS_IN, 1, b'a')
+ record2 = r.DNSAddress('a', const._TYPE_A, const._CLASS_IN, 1, b'b')
+ cache = r.DNSCache()
+ cache.async_add_records([record1, record2])
+ assert cache.get_by_details('a', const._TYPE_A, const._CLASS_IN) == record2
+
+ def test_get_all_by_details(self):
+ record1 = r.DNSAddress('a', const._TYPE_A, const._CLASS_IN, 1, b'a')
+ record2 = r.DNSAddress('a', const._TYPE_A, const._CLASS_IN, 1, b'b')
+ cache = r.DNSCache()
+ cache.async_add_records([record1, record2])
+ assert set(cache.get_all_by_details('a', const._TYPE_A, const._CLASS_IN)) == {record1, record2}
+
+ def test_entries_with_server(self):
+ record1 = r.DNSService(
+ 'irrelevant', const._TYPE_SRV, const._CLASS_IN, const._DNS_HOST_TTL, 0, 0, 85, 'ab'
+ )
+ record2 = r.DNSService(
+ 'irrelevant', const._TYPE_SRV, const._CLASS_IN, const._DNS_HOST_TTL, 0, 0, 80, 'ab'
+ )
+ cache = r.DNSCache()
+ cache.async_add_records([record1, record2])
+ assert set(cache.entries_with_server('ab')) == {record1, record2}
+ assert set(cache.entries_with_server('AB')) == {record1, record2}
+
+ def test_entries_with_name(self):
+ record1 = r.DNSService(
+ 'irrelevant', const._TYPE_SRV, const._CLASS_IN, const._DNS_HOST_TTL, 0, 0, 85, 'ab'
+ )
+ record2 = r.DNSService(
+ 'irrelevant', const._TYPE_SRV, const._CLASS_IN, const._DNS_HOST_TTL, 0, 0, 80, 'ab'
+ )
+ cache = r.DNSCache()
+ cache.async_add_records([record1, record2])
+ assert set(cache.entries_with_name('irrelevant')) == {record1, record2}
+ assert set(cache.entries_with_name('Irrelevant')) == {record1, record2}
+
+ def test_current_entry_with_name_and_alias(self):
+ record1 = r.DNSPointer(
+ 'irrelevant', const._TYPE_PTR, const._CLASS_IN, const._DNS_OTHER_TTL, 'x.irrelevant'
+ )
+ record2 = r.DNSPointer(
+ 'irrelevant', const._TYPE_PTR, const._CLASS_IN, const._DNS_OTHER_TTL, 'y.irrelevant'
+ )
+ cache = r.DNSCache()
+ cache.async_add_records([record1, record2])
+ assert cache.current_entry_with_name_and_alias('irrelevant', 'x.irrelevant') == record1
+
+ def test_name(self):
+ record1 = r.DNSService(
+ 'irrelevant', const._TYPE_SRV, const._CLASS_IN, const._DNS_HOST_TTL, 0, 0, 85, 'ab'
+ )
+ record2 = r.DNSService(
+ 'irrelevant', const._TYPE_SRV, const._CLASS_IN, const._DNS_HOST_TTL, 0, 0, 80, 'ab'
+ )
+ cache = r.DNSCache()
+ cache.async_add_records([record1, record2])
+ assert cache.names() == ['irrelevant']
diff --git a/tests/test_core.py b/tests/test_core.py
new file mode 100644
index 00000000..eab769be
--- /dev/null
+++ b/tests/test_core.py
@@ -0,0 +1,800 @@
+#!/usr/bin/env python
+
+
+""" Unit tests for zeroconf._core """
+
+import asyncio
+import itertools
+import logging
+import os
+import pytest
+import socket
+import sys
+import time
+import threading
+import unittest
+import unittest.mock
+from typing import cast
+from unittest.mock import patch
+
+import zeroconf as r
+from zeroconf import _core, const, Zeroconf, current_time_millis
+from zeroconf.asyncio import AsyncZeroconf
+from zeroconf._protocol import outgoing
+
+from . import has_working_ipv6, _clear_cache, _inject_response, _wait_for_start
+
+log = logging.getLogger('zeroconf')
+original_logging_level = logging.NOTSET
+
+
+def setup_module():
+ global original_logging_level
+ original_logging_level = log.level
+ log.setLevel(logging.DEBUG)
+
+
+def teardown_module():
+ if original_logging_level != logging.NOTSET:
+ log.setLevel(original_logging_level)
+
+
+def threadsafe_query(zc, protocol, *args):
+ async def make_query():
+ protocol.handle_query_or_defer(*args)
+
+ asyncio.run_coroutine_threadsafe(make_query(), zc.loop).result()
+
+
+# This test uses asyncio because it needs to access the cache directly
+# which is not threadsafe
+@pytest.mark.asyncio
+async def test_reaper():
+ with patch.object(_core, "_CACHE_CLEANUP_INTERVAL", 10):
+ assert _core._CACHE_CLEANUP_INTERVAL == 10
+ aiozc = AsyncZeroconf(interfaces=['127.0.0.1'])
+ zeroconf = aiozc.zeroconf
+ cache = zeroconf.cache
+ original_entries = list(itertools.chain(*(cache.entries_with_name(name) for name in cache.names())))
+ record_with_10s_ttl = r.DNSAddress('a', const._TYPE_SOA, const._CLASS_IN, 10, b'a')
+ record_with_1s_ttl = r.DNSAddress('a', const._TYPE_SOA, const._CLASS_IN, 1, b'b')
+ zeroconf.cache.async_add_records([record_with_10s_ttl, record_with_1s_ttl])
+ question = r.DNSQuestion("_hap._tcp._local.", const._TYPE_PTR, const._CLASS_IN)
+ now = r.current_time_millis()
+ other_known_answers = {
+ r.DNSPointer(
+ "_hap._tcp.local.",
+ const._TYPE_PTR,
+ const._CLASS_IN,
+ 10000,
+ 'known-to-other._hap._tcp.local.',
+ )
+ }
+ zeroconf.question_history.add_question_at_time(question, now, other_known_answers)
+ assert zeroconf.question_history.suppresses(question, now, other_known_answers)
+ entries_with_cache = list(itertools.chain(*(cache.entries_with_name(name) for name in cache.names())))
+ await asyncio.sleep(1.2)
+ entries = list(itertools.chain(*(cache.entries_with_name(name) for name in cache.names())))
+ assert zeroconf.cache.get(record_with_1s_ttl) is None
+ await aiozc.async_close()
+ assert not zeroconf.question_history.suppresses(question, now, other_known_answers)
+ assert entries != original_entries
+ assert entries_with_cache != original_entries
+ assert record_with_10s_ttl in entries
+ assert record_with_1s_ttl not in entries
+
+
+@pytest.mark.asyncio
+async def test_reaper_aborts_when_done():
+ """Ensure cache cleanup stops when zeroconf is done."""
+ with patch.object(_core, "_CACHE_CLEANUP_INTERVAL", 10):
+ assert _core._CACHE_CLEANUP_INTERVAL == 10
+ aiozc = AsyncZeroconf(interfaces=['127.0.0.1'])
+ zeroconf = aiozc.zeroconf
+ record_with_10s_ttl = r.DNSAddress('a', const._TYPE_SOA, const._CLASS_IN, 10, b'a')
+ record_with_1s_ttl = r.DNSAddress('a', const._TYPE_SOA, const._CLASS_IN, 1, b'b')
+ zeroconf.cache.async_add_records([record_with_10s_ttl, record_with_1s_ttl])
+ assert zeroconf.cache.get(record_with_10s_ttl) is not None
+ assert zeroconf.cache.get(record_with_1s_ttl) is not None
+ await aiozc.async_close()
+ await asyncio.sleep(1.2)
+ assert zeroconf.cache.get(record_with_10s_ttl) is not None
+ assert zeroconf.cache.get(record_with_1s_ttl) is not None
+
+
+class Framework(unittest.TestCase):
+ def test_launch_and_close(self):
+ rv = r.Zeroconf(interfaces=r.InterfaceChoice.All)
+ rv.close()
+ rv = r.Zeroconf(interfaces=r.InterfaceChoice.Default)
+ rv.close()
+
+ def test_launch_and_close_context_manager(self):
+ with r.Zeroconf(interfaces=r.InterfaceChoice.All) as rv:
+ assert rv.done is False
+ assert rv.done is True
+
+ with r.Zeroconf(interfaces=r.InterfaceChoice.Default) as rv:
+ assert rv.done is False
+ assert rv.done is True
+
+ def test_launch_and_close_unicast(self):
+ rv = r.Zeroconf(interfaces=r.InterfaceChoice.All, unicast=True)
+ rv.close()
+ rv = r.Zeroconf(interfaces=r.InterfaceChoice.Default, unicast=True)
+ rv.close()
+
+ def test_close_multiple_times(self):
+ rv = r.Zeroconf(interfaces=r.InterfaceChoice.Default)
+ rv.close()
+ rv.close()
+
+ @unittest.skipIf(not has_working_ipv6(), 'Requires IPv6')
+ @unittest.skipIf(os.environ.get('SKIP_IPV6'), 'IPv6 tests disabled')
+ def test_launch_and_close_v4_v6(self):
+ rv = r.Zeroconf(interfaces=r.InterfaceChoice.All, ip_version=r.IPVersion.All)
+ rv.close()
+ rv = r.Zeroconf(interfaces=r.InterfaceChoice.Default, ip_version=r.IPVersion.All)
+ rv.close()
+
+ @unittest.skipIf(not has_working_ipv6(), 'Requires IPv6')
+ @unittest.skipIf(os.environ.get('SKIP_IPV6'), 'IPv6 tests disabled')
+ def test_launch_and_close_v6_only(self):
+ rv = r.Zeroconf(interfaces=r.InterfaceChoice.All, ip_version=r.IPVersion.V6Only)
+ rv.close()
+ rv = r.Zeroconf(interfaces=r.InterfaceChoice.Default, ip_version=r.IPVersion.V6Only)
+ rv.close()
+
+ @unittest.skipIf(sys.platform == 'darwin', reason="apple_p2p failure path not testable on mac")
+ def test_launch_and_close_apple_p2p_not_mac(self):
+ with pytest.raises(RuntimeError):
+ r.Zeroconf(apple_p2p=True)
+
+ @unittest.skipIf(sys.platform != 'darwin', reason="apple_p2p happy path only testable on mac")
+ def test_launch_and_close_apple_p2p_on_mac(self):
+ rv = r.Zeroconf(apple_p2p=True)
+ rv.close()
+
+ def test_handle_response(self):
+ def mock_incoming_msg(service_state_change: r.ServiceStateChange) -> r.DNSIncoming:
+ ttl = 120
+ generated = r.DNSOutgoing(const._FLAGS_QR_RESPONSE)
+
+ if service_state_change == r.ServiceStateChange.Updated:
+ generated.add_answer_at_time(
+ r.DNSText(
+ service_name,
+ const._TYPE_TXT,
+ const._CLASS_IN | const._CLASS_UNIQUE,
+ ttl,
+ service_text,
+ ),
+ 0,
+ )
+ return r.DNSIncoming(generated.packets()[0])
+
+ if service_state_change == r.ServiceStateChange.Removed:
+ ttl = 0
+
+ generated.add_answer_at_time(
+ r.DNSPointer(service_type, const._TYPE_PTR, const._CLASS_IN, ttl, service_name), 0
+ )
+ generated.add_answer_at_time(
+ r.DNSService(
+ service_name,
+ const._TYPE_SRV,
+ const._CLASS_IN | const._CLASS_UNIQUE,
+ ttl,
+ 0,
+ 0,
+ 80,
+ service_server,
+ ),
+ 0,
+ )
+ generated.add_answer_at_time(
+ r.DNSText(
+ service_name, const._TYPE_TXT, const._CLASS_IN | const._CLASS_UNIQUE, ttl, service_text
+ ),
+ 0,
+ )
+ generated.add_answer_at_time(
+ r.DNSAddress(
+ service_server,
+ const._TYPE_A,
+ const._CLASS_IN | const._CLASS_UNIQUE,
+ ttl,
+ socket.inet_aton(service_address),
+ ),
+ 0,
+ )
+
+ return r.DNSIncoming(generated.packets()[0])
+
+ def mock_split_incoming_msg(service_state_change: r.ServiceStateChange) -> r.DNSIncoming:
+ """Mock an incoming message for the case where the packet is split."""
+ ttl = 120
+ generated = r.DNSOutgoing(const._FLAGS_QR_RESPONSE)
+ generated.add_answer_at_time(
+ r.DNSAddress(
+ service_server,
+ const._TYPE_A,
+ const._CLASS_IN | const._CLASS_UNIQUE,
+ ttl,
+ socket.inet_aton(service_address),
+ ),
+ 0,
+ )
+ generated.add_answer_at_time(
+ r.DNSService(
+ service_name,
+ const._TYPE_SRV,
+ const._CLASS_IN | const._CLASS_UNIQUE,
+ ttl,
+ 0,
+ 0,
+ 80,
+ service_server,
+ ),
+ 0,
+ )
+ return r.DNSIncoming(generated.packets()[0])
+
+ service_name = 'name._type._tcp.local.'
+ service_type = '_type._tcp.local.'
+ service_server = 'ash-2.local.'
+ service_text = b'path=/~paulsm/'
+ service_address = '10.0.1.2'
+
+ zeroconf = r.Zeroconf(interfaces=['127.0.0.1'])
+
+ try:
+ # service added
+ _inject_response(zeroconf, mock_incoming_msg(r.ServiceStateChange.Added))
+ dns_text = zeroconf.cache.get_by_details(service_name, const._TYPE_TXT, const._CLASS_IN)
+ assert dns_text is not None
+ assert cast(r.DNSText, dns_text).text == service_text # service_text is b'path=/~paulsm/'
+ all_dns_text = zeroconf.cache.get_all_by_details(service_name, const._TYPE_TXT, const._CLASS_IN)
+ assert [dns_text] == all_dns_text
+
+ # https://tools.ietf.org/html/rfc6762#section-10.2
+ # Instead of merging this new record additively into the cache in addition
+ # to any previous records with the same name, rrtype, and rrclass,
+ # all old records with that name, rrtype, and rrclass that were received
+ # more than one second ago are declared invalid,
+ # and marked to expire from the cache in one second.
+ time.sleep(1.1)
+
+ # service updated. currently only text record can be updated
+ service_text = b'path=/~humingchun/'
+ _inject_response(zeroconf, mock_incoming_msg(r.ServiceStateChange.Updated))
+ dns_text = zeroconf.cache.get_by_details(service_name, const._TYPE_TXT, const._CLASS_IN)
+ assert dns_text is not None
+ assert cast(r.DNSText, dns_text).text == service_text # service_text is b'path=/~humingchun/'
+
+ time.sleep(1.1)
+
+ # The split message only has a SRV and A record.
+ # This should not evict TXT records from the cache
+ _inject_response(zeroconf, mock_split_incoming_msg(r.ServiceStateChange.Updated))
+ time.sleep(1.1)
+ dns_text = zeroconf.cache.get_by_details(service_name, const._TYPE_TXT, const._CLASS_IN)
+ assert dns_text is not None
+ assert cast(r.DNSText, dns_text).text == service_text # service_text is b'path=/~humingchun/'
+
+ # service removed
+ _inject_response(zeroconf, mock_incoming_msg(r.ServiceStateChange.Removed))
+ dns_text = zeroconf.cache.get_by_details(service_name, const._TYPE_TXT, const._CLASS_IN)
+ assert dns_text.is_expired(current_time_millis() + 1000)
+
+ finally:
+ zeroconf.close()
+
+
+def test_generate_service_query_set_qu_bit():
+ """Test generate_service_query sets the QU bit."""
+
+ zeroconf_registrar = Zeroconf(interfaces=['127.0.0.1'])
+ desc = {'path': '/~paulsm/'}
+ type_ = "._hap._tcp.local."
+ registration_name = "this-host-is-not-used._hap._tcp.local."
+ info = r.ServiceInfo(
+ type_, registration_name, 80, 0, 0, desc, "ash-2.local.", addresses=[socket.inet_aton("10.0.1.2")]
+ )
+ out = zeroconf_registrar.generate_service_query(info)
+ assert out.questions[0].unicast is True
+ zeroconf_registrar.close()
+
+
+def test_invalid_packets_ignored_and_does_not_cause_loop_exception():
+ """Ensure an invalid packet cannot cause the loop to collapse."""
+ zc = Zeroconf(interfaces=['127.0.0.1'])
+ generated = r.DNSOutgoing(0)
+ packet = generated.packets()[0]
+ packet = packet[:8] + b'deadbeef' + packet[8:]
+ parsed = r.DNSIncoming(packet)
+ assert parsed.valid is False
+
+ # Invalid Packet
+ mock_out = unittest.mock.Mock()
+ mock_out.packets = lambda: [packet]
+ zc.send(mock_out)
+
+ # Invalid oversized packet
+ mock_out = unittest.mock.Mock()
+ mock_out.packets = lambda: [packet * 1000]
+ zc.send(mock_out)
+
+ generated = r.DNSOutgoing(const._FLAGS_QR_RESPONSE)
+ entry = r.DNSText(
+ "didnotcrashincoming._crash._tcp.local.",
+ const._TYPE_TXT,
+ const._CLASS_IN | const._CLASS_UNIQUE,
+ 500,
+ b'path=/~paulsm/',
+ )
+ assert isinstance(entry, r.DNSText)
+ assert isinstance(entry, r.DNSRecord)
+ assert isinstance(entry, r.DNSEntry)
+
+ generated.add_answer_at_time(entry, 0)
+ zc.send(generated)
+ time.sleep(0.2)
+ zc.close()
+ assert zc.cache.get(entry) is not None
+
+
+def test_goodbye_all_services():
+ """Verify generating the goodbye query does not change with time."""
+ zc = Zeroconf(interfaces=['127.0.0.1'])
+ out = zc.generate_unregister_all_services()
+ assert out is None
+ type_ = "_http._tcp.local."
+ registration_name = "xxxyyy.%s" % type_
+ desc = {'path': '/~paulsm/'}
+ info = r.ServiceInfo(
+ type_, registration_name, 80, 0, 0, desc, "ash-2.local.", addresses=[socket.inet_aton("10.0.1.2")]
+ )
+ zc.registry.async_add(info)
+ out = zc.generate_unregister_all_services()
+ assert out is not None
+ first_packet = out.packets()
+ zc.registry.async_add(info)
+ out2 = zc.generate_unregister_all_services()
+ assert out2 is not None
+ second_packet = out.packets()
+ assert second_packet == first_packet
+
+ # Verify the registery is empty
+ out3 = zc.generate_unregister_all_services()
+ assert out3 is None
+ assert zc.registry.async_get_service_infos() == []
+
+ zc.close()
+
+
+def test_register_service_with_custom_ttl():
+ """Test a registering a service with a custom ttl."""
+
+ # instantiate a zeroconf instance
+ zc = Zeroconf(interfaces=['127.0.0.1'])
+
+ # start a browser
+ type_ = "_homeassistant._tcp.local."
+ name = "MyTestHome"
+ info_service = r.ServiceInfo(
+ type_,
+ f'{name}.{type_}',
+ 80,
+ 0,
+ 0,
+ {'path': '/~paulsm/'},
+ "ash-90.local.",
+ addresses=[socket.inet_aton("10.0.1.2")],
+ )
+
+ zc.register_service(info_service, ttl=3000)
+ assert zc.cache.get(info_service.dns_pointer()).ttl == 3000
+ zc.close()
+
+
+def test_logging_packets(caplog):
+ """Test packets are only logged with debug logging."""
+
+ # instantiate a zeroconf instance
+ zc = Zeroconf(interfaces=['127.0.0.1'])
+
+ # start a browser
+ type_ = "_logging._tcp.local."
+ name = "TLD"
+ info_service = r.ServiceInfo(
+ type_,
+ f'{name}.{type_}',
+ 80,
+ 0,
+ 0,
+ {'path': '/~paulsm/'},
+ "ash-90.local.",
+ addresses=[socket.inet_aton("10.0.1.2")],
+ )
+
+ logging.getLogger('zeroconf').setLevel(logging.DEBUG)
+ caplog.clear()
+ zc.register_service(info_service, ttl=3000)
+ assert "Sending to" in caplog.text
+ assert zc.cache.get(info_service.dns_pointer()).ttl == 3000
+ logging.getLogger('zeroconf').setLevel(logging.INFO)
+ caplog.clear()
+ zc.unregister_service(info_service)
+ assert "Sending to" not in caplog.text
+ logging.getLogger('zeroconf').setLevel(logging.DEBUG)
+
+ zc.close()
+
+
+def test_get_service_info_failure_path():
+ """Verify get_service_info return None when the underlying call returns False."""
+ zc = Zeroconf(interfaces=['127.0.0.1'])
+ assert zc.get_service_info("_neverused._tcp.local.", "xneverused._neverused._tcp.local.", 10) is None
+ zc.close()
+
+
+def test_sending_unicast():
+ """Test sending unicast response."""
+ zc = Zeroconf(interfaces=['127.0.0.1'])
+ generated = r.DNSOutgoing(const._FLAGS_QR_RESPONSE)
+ entry = r.DNSText(
+ "didnotcrashincoming._crash._tcp.local.",
+ const._TYPE_TXT,
+ const._CLASS_IN | const._CLASS_UNIQUE,
+ 500,
+ b'path=/~paulsm/',
+ )
+ generated.add_answer_at_time(entry, 0)
+ zc.send(generated, "2001:db8::1", const._MDNS_PORT) # https://www.iana.org/go/rfc3849
+ time.sleep(0.2)
+ assert zc.cache.get(entry) is None
+
+ zc.send(generated, "198.51.100.0", const._MDNS_PORT) # Documentation (TEST-NET-2)
+ time.sleep(0.2)
+ assert zc.cache.get(entry) is None
+
+ zc.send(generated)
+ time.sleep(0.2)
+ assert zc.cache.get(entry) is not None
+
+ zc.close()
+
+
+def test_tc_bit_defers():
+ zc = Zeroconf(interfaces=['127.0.0.1'])
+ _wait_for_start(zc)
+ type_ = "_tcbitdefer._tcp.local."
+ name = "knownname"
+ name2 = "knownname2"
+ name3 = "knownname3"
+
+ registration_name = f"{name}.{type_}"
+ registration2_name = f"{name2}.{type_}"
+ registration3_name = f"{name3}.{type_}"
+
+ desc = {'path': '/~paulsm/'}
+ server_name = "ash-2.local."
+ server_name2 = "ash-3.local."
+ server_name3 = "ash-4.local."
+
+ info = r.ServiceInfo(
+ type_, registration_name, 80, 0, 0, desc, server_name, addresses=[socket.inet_aton("10.0.1.2")]
+ )
+ info2 = r.ServiceInfo(
+ type_, registration2_name, 80, 0, 0, desc, server_name2, addresses=[socket.inet_aton("10.0.1.2")]
+ )
+ info3 = r.ServiceInfo(
+ type_, registration3_name, 80, 0, 0, desc, server_name3, addresses=[socket.inet_aton("10.0.1.2")]
+ )
+ zc.registry.async_add(info)
+ zc.registry.async_add(info2)
+ zc.registry.async_add(info3)
+
+ protocol = zc.engine.protocols[0]
+ now = r.current_time_millis()
+ _clear_cache(zc)
+
+ generated = r.DNSOutgoing(const._FLAGS_QR_QUERY)
+ question = r.DNSQuestion(type_, const._TYPE_PTR, const._CLASS_IN)
+ generated.add_question(question)
+ for _ in range(300):
+ # Add so many answers we end up with another packet
+ generated.add_answer_at_time(info.dns_pointer(), now)
+ generated.add_answer_at_time(info2.dns_pointer(), now)
+ generated.add_answer_at_time(info3.dns_pointer(), now)
+ packets = generated.packets()
+ assert len(packets) == 4
+ expected_deferred = []
+ source_ip = '203.0.113.13'
+
+ next_packet = r.DNSIncoming(packets.pop(0))
+ expected_deferred.append(next_packet)
+ threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT, None)
+ assert protocol._deferred[source_ip] == expected_deferred
+ assert source_ip in protocol._timers
+
+ next_packet = r.DNSIncoming(packets.pop(0))
+ expected_deferred.append(next_packet)
+ threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT, None)
+ assert protocol._deferred[source_ip] == expected_deferred
+ assert source_ip in protocol._timers
+ threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT, None)
+ assert protocol._deferred[source_ip] == expected_deferred
+ assert source_ip in protocol._timers
+
+ next_packet = r.DNSIncoming(packets.pop(0))
+ expected_deferred.append(next_packet)
+ threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT, None)
+ assert protocol._deferred[source_ip] == expected_deferred
+ assert source_ip in protocol._timers
+
+ next_packet = r.DNSIncoming(packets.pop(0))
+ expected_deferred.append(next_packet)
+ threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT, None)
+ assert source_ip not in protocol._deferred
+ assert source_ip not in protocol._timers
+
+ # unregister
+ zc.unregister_service(info)
+ zc.close()
+
+
+def test_tc_bit_defers_last_response_missing():
+ zc = Zeroconf(interfaces=['127.0.0.1'])
+ _wait_for_start(zc)
+ type_ = "_knowndefer._tcp.local."
+ name = "knownname"
+ name2 = "knownname2"
+ name3 = "knownname3"
+
+ registration_name = f"{name}.{type_}"
+ registration2_name = f"{name2}.{type_}"
+ registration3_name = f"{name3}.{type_}"
+
+ desc = {'path': '/~paulsm/'}
+ server_name = "ash-2.local."
+ server_name2 = "ash-3.local."
+ server_name3 = "ash-4.local."
+
+ info = r.ServiceInfo(
+ type_, registration_name, 80, 0, 0, desc, server_name, addresses=[socket.inet_aton("10.0.1.2")]
+ )
+ info2 = r.ServiceInfo(
+ type_, registration2_name, 80, 0, 0, desc, server_name2, addresses=[socket.inet_aton("10.0.1.2")]
+ )
+ info3 = r.ServiceInfo(
+ type_, registration3_name, 80, 0, 0, desc, server_name3, addresses=[socket.inet_aton("10.0.1.2")]
+ )
+ zc.registry.async_add(info)
+ zc.registry.async_add(info2)
+ zc.registry.async_add(info3)
+
+ protocol = zc.engine.protocols[0]
+ now = r.current_time_millis()
+ _clear_cache(zc)
+ source_ip = '203.0.113.12'
+
+ generated = r.DNSOutgoing(const._FLAGS_QR_QUERY)
+ question = r.DNSQuestion(type_, const._TYPE_PTR, const._CLASS_IN)
+ generated.add_question(question)
+ for _ in range(300):
+ # Add so many answers we end up with another packet
+ generated.add_answer_at_time(info.dns_pointer(), now)
+ generated.add_answer_at_time(info2.dns_pointer(), now)
+ generated.add_answer_at_time(info3.dns_pointer(), now)
+ packets = generated.packets()
+ assert len(packets) == 4
+ expected_deferred = []
+
+ next_packet = r.DNSIncoming(packets.pop(0))
+ expected_deferred.append(next_packet)
+ threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT, None)
+ assert protocol._deferred[source_ip] == expected_deferred
+ timer1 = protocol._timers[source_ip]
+
+ next_packet = r.DNSIncoming(packets.pop(0))
+ expected_deferred.append(next_packet)
+ threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT, None)
+ assert protocol._deferred[source_ip] == expected_deferred
+ timer2 = protocol._timers[source_ip]
+ if sys.version_info >= (3, 7):
+ assert timer1.cancelled()
+ assert timer2 != timer1
+
+ # Send the same packet again to similar multi interfaces
+ threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT, None)
+ assert protocol._deferred[source_ip] == expected_deferred
+ assert source_ip in protocol._timers
+ timer3 = protocol._timers[source_ip]
+ if sys.version_info >= (3, 7):
+ assert not timer3.cancelled()
+ assert timer3 == timer2
+
+ next_packet = r.DNSIncoming(packets.pop(0))
+ expected_deferred.append(next_packet)
+ threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT, None)
+ assert protocol._deferred[source_ip] == expected_deferred
+ assert source_ip in protocol._timers
+ timer4 = protocol._timers[source_ip]
+ if sys.version_info >= (3, 7):
+ assert timer3.cancelled()
+ assert timer4 != timer3
+
+ for _ in range(8):
+ time.sleep(0.1)
+ if source_ip not in protocol._timers and source_ip not in protocol._deferred:
+ break
+
+ assert source_ip not in protocol._deferred
+ assert source_ip not in protocol._timers
+
+ # unregister
+ zc.registry.async_remove(info)
+ zc.close()
+
+
+@pytest.mark.asyncio
+async def test_open_close_twice_from_async() -> None:
+ """Test we can close twice from a coroutine when using Zeroconf.
+
+ Ideally callers switch to using AsyncZeroconf, however there will
+ be a peroid where they still call the sync wrapper that we want
+ to ensure will not deadlock on shutdown.
+
+ This test is expected to throw warnings about tasks being destroyed
+ since we force shutdown right away since we don't want to block
+ callers event loops and since they aren't using the AsyncZeroconf
+ version they won't yield with an await like async_close we don't
+ have much choice but to force things down.
+ """
+ zc = Zeroconf(interfaces=['127.0.0.1'])
+ zc.close()
+ zc.close()
+ await asyncio.sleep(0)
+
+
+@pytest.mark.asyncio
+async def test_multiple_sync_instances_stared_from_async_close():
+ """Test we can shutdown multiple sync instances from async."""
+
+ # instantiate a zeroconf instance
+ zc = Zeroconf(interfaces=['127.0.0.1'])
+ zc2 = Zeroconf(interfaces=['127.0.0.1'])
+
+ assert zc.loop == zc2.loop
+
+ zc.close()
+ assert zc.loop.is_running()
+ zc2.close()
+ assert zc2.loop.is_running()
+
+ zc3 = Zeroconf(interfaces=['127.0.0.1'])
+ assert zc3.loop == zc2.loop
+
+ zc3.close()
+ assert zc3.loop.is_running()
+
+ await asyncio.sleep(0)
+
+
+def test_guard_against_oversized_packets():
+ """Ensure we do not process oversized packets.
+
+ These packets can quickly overwhelm the system.
+ """
+ zc = Zeroconf(interfaces=['127.0.0.1'])
+
+ generated = r.DNSOutgoing(const._FLAGS_QR_RESPONSE)
+
+ for i in range(5000):
+ generated.add_answer_at_time(
+ r.DNSText(
+ "packet{i}.local.",
+ const._TYPE_TXT,
+ const._CLASS_IN | const._CLASS_UNIQUE,
+ 500,
+ b'path=/~paulsm/',
+ ),
+ 0,
+ )
+
+ # We are patching to generate an oversized packet
+ with patch.object(outgoing, "_MAX_MSG_ABSOLUTE", 100000), patch.object(
+ outgoing, "_MAX_MSG_TYPICAL", 100000
+ ):
+ over_sized_packet = generated.packets()[0]
+ assert len(over_sized_packet) > const._MAX_MSG_ABSOLUTE
+
+ generated = r.DNSOutgoing(const._FLAGS_QR_RESPONSE)
+ okpacket_record = r.DNSText(
+ "okpacket.local.",
+ const._TYPE_TXT,
+ const._CLASS_IN | const._CLASS_UNIQUE,
+ 500,
+ b'path=/~paulsm/',
+ )
+
+ generated.add_answer_at_time(
+ okpacket_record,
+ 0,
+ )
+ ok_packet = generated.packets()[0]
+
+ # We cannot test though the network interface as some operating systems
+ # will guard against the oversized packet and we won't see it.
+ listener = _core.AsyncListener(zc)
+ listener.transport = unittest.mock.MagicMock()
+
+ listener.datagram_received(ok_packet, ('127.0.0.1', const._MDNS_PORT))
+ assert zc.cache.async_get_unique(okpacket_record) is not None
+
+ listener.datagram_received(over_sized_packet, ('127.0.0.1', const._MDNS_PORT))
+ assert (
+ zc.cache.async_get_unique(
+ r.DNSText(
+ "packet0.local.",
+ const._TYPE_TXT,
+ const._CLASS_IN | const._CLASS_UNIQUE,
+ 500,
+ b'path=/~paulsm/',
+ )
+ )
+ is None
+ )
+
+ zc.close()
+
+
+def test_guard_against_duplicate_packets():
+ """Ensure we do not process duplicate packets.
+ These packets can quickly overwhelm the system.
+ """
+ zc = Zeroconf(interfaces=['127.0.0.1'])
+ listener = _core.AsyncListener(zc)
+ assert listener.suppress_duplicate_packet(b"first packet", current_time_millis()) is False
+ assert listener.suppress_duplicate_packet(b"first packet", current_time_millis()) is True
+ assert listener.suppress_duplicate_packet(b"first packet", current_time_millis()) is True
+ assert listener.suppress_duplicate_packet(b"first packet", current_time_millis() + 1000) is False
+ assert listener.suppress_duplicate_packet(b"first packet", current_time_millis()) is True
+ assert listener.suppress_duplicate_packet(b"other packet", current_time_millis()) is False
+ assert listener.suppress_duplicate_packet(b"other packet", current_time_millis()) is True
+ assert listener.suppress_duplicate_packet(b"other packet", current_time_millis() + 1000) is False
+ assert listener.suppress_duplicate_packet(b"first packet", current_time_millis()) is False
+ zc.close()
+
+
+def test_shutdown_while_register_in_process():
+ """Test we can shutdown while registering a service in another thread."""
+
+ # instantiate a zeroconf instance
+ zc = Zeroconf(interfaces=['127.0.0.1'])
+
+ # start a browser
+ type_ = "_homeassistant._tcp.local."
+ name = "MyTestHome"
+ info_service = r.ServiceInfo(
+ type_,
+ f'{name}.{type_}',
+ 80,
+ 0,
+ 0,
+ {'path': '/~paulsm/'},
+ "ash-90.local.",
+ addresses=[socket.inet_aton("10.0.1.2")],
+ )
+
+ def _background_register():
+ zc.register_service(info_service)
+
+ bgthread = threading.Thread(target=_background_register, daemon=True)
+ bgthread.start()
+ time.sleep(0.3)
+
+ zc.close()
+ bgthread.join()
diff --git a/tests/test_dns.py b/tests/test_dns.py
new file mode 100644
index 00000000..c2669205
--- /dev/null
+++ b/tests/test_dns.py
@@ -0,0 +1,386 @@
+#!/usr/bin/env python
+
+
+""" Unit tests for zeroconf._dns. """
+
+import logging
+import os
+import socket
+import time
+import unittest
+import unittest.mock
+
+import zeroconf as r
+from zeroconf import const, current_time_millis
+from zeroconf._dns import DNSRRSet
+from zeroconf import (
+ DNSHinfo,
+ DNSText,
+ ServiceInfo,
+)
+
+from . import has_working_ipv6
+
+log = logging.getLogger('zeroconf')
+original_logging_level = logging.NOTSET
+
+
+def setup_module():
+ global original_logging_level
+ original_logging_level = log.level
+ log.setLevel(logging.DEBUG)
+
+
+def teardown_module():
+ if original_logging_level != logging.NOTSET:
+ log.setLevel(original_logging_level)
+
+
+class TestDunder(unittest.TestCase):
+ def test_dns_text_repr(self):
+ # There was an issue on Python 3 that prevented DNSText's repr
+ # from working when the text was longer than 10 bytes
+ text = DNSText('irrelevant', 0, 0, 0, b'12345678901')
+ repr(text)
+
+ text = DNSText('irrelevant', 0, 0, 0, b'123')
+ repr(text)
+
+ def test_dns_hinfo_repr_eq(self):
+ hinfo = DNSHinfo('irrelevant', const._TYPE_HINFO, 0, 0, 'cpu', 'os')
+ assert hinfo == hinfo
+ repr(hinfo)
+
+ def test_dns_pointer_repr(self):
+ pointer = r.DNSPointer('irrelevant', const._TYPE_PTR, const._CLASS_IN, const._DNS_OTHER_TTL, '123')
+ repr(pointer)
+
+ @unittest.skipIf(not has_working_ipv6(), 'Requires IPv6')
+ @unittest.skipIf(os.environ.get('SKIP_IPV6'), 'IPv6 tests disabled')
+ def test_dns_address_repr(self):
+ address = r.DNSAddress('irrelevant', const._TYPE_SOA, const._CLASS_IN, 1, b'a')
+ assert repr(address).endswith("b'a'")
+
+ address_ipv4 = r.DNSAddress(
+ 'irrelevant', const._TYPE_SOA, const._CLASS_IN, 1, socket.inet_pton(socket.AF_INET, '127.0.0.1')
+ )
+ assert repr(address_ipv4).endswith('127.0.0.1')
+
+ address_ipv6 = r.DNSAddress(
+ 'irrelevant', const._TYPE_SOA, const._CLASS_IN, 1, socket.inet_pton(socket.AF_INET6, '::1')
+ )
+ assert repr(address_ipv6).endswith('::1')
+
+ def test_dns_question_repr(self):
+ question = r.DNSQuestion('irrelevant', const._TYPE_SRV, const._CLASS_IN | const._CLASS_UNIQUE)
+ repr(question)
+ assert not question != question
+
+ def test_dns_service_repr(self):
+ service = r.DNSService(
+ 'irrelevant', const._TYPE_SRV, const._CLASS_IN, const._DNS_HOST_TTL, 0, 0, 80, 'a'
+ )
+ repr(service)
+
+ def test_dns_record_abc(self):
+ record = r.DNSRecord('irrelevant', const._TYPE_SRV, const._CLASS_IN, const._DNS_HOST_TTL)
+ self.assertRaises(r.AbstractMethodException, record.__eq__, record)
+ self.assertRaises(r.AbstractMethodException, record.write, None)
+
+ def test_dns_record_reset_ttl(self):
+ record = r.DNSRecord('irrelevant', const._TYPE_SRV, const._CLASS_IN, const._DNS_HOST_TTL)
+ time.sleep(1)
+ record2 = r.DNSRecord('irrelevant', const._TYPE_SRV, const._CLASS_IN, const._DNS_HOST_TTL)
+ now = r.current_time_millis()
+
+ assert record.created != record2.created
+ assert record.get_remaining_ttl(now) != record2.get_remaining_ttl(now)
+
+ record.reset_ttl(record2)
+
+ assert record.ttl == record2.ttl
+ assert record.created == record2.created
+ assert record.get_remaining_ttl(now) == record2.get_remaining_ttl(now)
+
+ def test_service_info_dunder(self):
+ type_ = "_test-srvc-type._tcp.local."
+ name = "xxxyyy"
+ registration_name = f"{name}.{type_}"
+ info = ServiceInfo(
+ type_,
+ registration_name,
+ 80,
+ 0,
+ 0,
+ b'',
+ "ash-2.local.",
+ addresses=[socket.inet_aton("10.0.1.2")],
+ )
+
+ assert not info != info
+ repr(info)
+
+ def test_service_info_text_properties_not_given(self):
+ type_ = "_test-srvc-type._tcp.local."
+ name = "xxxyyy"
+ registration_name = f"{name}.{type_}"
+ info = ServiceInfo(
+ type_=type_,
+ name=registration_name,
+ addresses=[socket.inet_aton("10.0.1.2")],
+ port=80,
+ server="ash-2.local.",
+ )
+
+ assert isinstance(info.text, bytes)
+ repr(info)
+
+ def test_dns_outgoing_repr(self):
+ dns_outgoing = r.DNSOutgoing(const._FLAGS_QR_QUERY)
+ repr(dns_outgoing)
+
+ def test_dns_record_is_expired(self):
+ record = r.DNSRecord('irrelevant', const._TYPE_SRV, const._CLASS_IN, 8)
+ now = current_time_millis()
+ assert record.is_expired(now) is False
+ assert record.is_expired(now + (8 / 2 * 1000)) is False
+ assert record.is_expired(now + (8 * 1000)) is True
+
+ def test_dns_record_is_stale(self):
+ record = r.DNSRecord('irrelevant', const._TYPE_SRV, const._CLASS_IN, 8)
+ now = current_time_millis()
+ assert record.is_stale(now) is False
+ assert record.is_stale(now + (8 / 4.1 * 1000)) is False
+ assert record.is_stale(now + (8 / 2 * 1000)) is True
+ assert record.is_stale(now + (8 * 1000)) is True
+
+ def test_dns_record_is_recent(self):
+ now = current_time_millis()
+ record = r.DNSRecord('irrelevant', const._TYPE_SRV, const._CLASS_IN, 8)
+ assert record.is_recent(now + (8 / 4.1 * 1000)) is True
+ assert record.is_recent(now + (8 / 3 * 1000)) is False
+ assert record.is_recent(now + (8 / 2 * 1000)) is False
+ assert record.is_recent(now + (8 * 1000)) is False
+
+
+def test_dns_question_hashablity():
+ """Test DNSQuestions are hashable."""
+
+ record1 = r.DNSQuestion('irrelevant', const._TYPE_A, const._CLASS_IN)
+ record2 = r.DNSQuestion('irrelevant', const._TYPE_A, const._CLASS_IN)
+
+ record_set = {record1, record2}
+ assert len(record_set) == 1
+
+ record_set.add(record1)
+ assert len(record_set) == 1
+
+ record3_dupe = r.DNSQuestion('irrelevant', const._TYPE_A, const._CLASS_IN)
+ assert record2 == record3_dupe
+ assert record2.__hash__() == record3_dupe.__hash__()
+
+ record_set.add(record3_dupe)
+ assert len(record_set) == 1
+
+ record4_dupe = r.DNSQuestion('notsame', const._TYPE_A, const._CLASS_IN)
+ assert record2 != record4_dupe
+ assert record2.__hash__() != record4_dupe.__hash__()
+
+ record_set.add(record4_dupe)
+ assert len(record_set) == 2
+
+
+def test_dns_record_hashablity_does_not_consider_ttl():
+ """Test DNSRecord are hashable."""
+
+ # Verify the TTL is not considered in the hash
+ record1 = r.DNSAddress('irrelevant', const._TYPE_A, const._CLASS_IN, const._DNS_OTHER_TTL, b'same')
+ record2 = r.DNSAddress('irrelevant', const._TYPE_A, const._CLASS_IN, const._DNS_HOST_TTL, b'same')
+
+ record_set = {record1, record2}
+ assert len(record_set) == 1
+
+ record_set.add(record1)
+ assert len(record_set) == 1
+
+ record3_dupe = r.DNSAddress('irrelevant', const._TYPE_A, const._CLASS_IN, const._DNS_HOST_TTL, b'same')
+ assert record2 == record3_dupe
+ assert record2.__hash__() == record3_dupe.__hash__()
+
+ record_set.add(record3_dupe)
+ assert len(record_set) == 1
+
+
+def test_dns_record_hashablity_does_not_consider_unique():
+ """Test DNSRecord are hashable and unique is ignored."""
+
+ # Verify the unique value is not considered in the hash
+ record1 = r.DNSAddress(
+ 'irrelevant', const._TYPE_A, const._CLASS_IN | const._CLASS_UNIQUE, const._DNS_OTHER_TTL, b'same'
+ )
+ record2 = r.DNSAddress('irrelevant', const._TYPE_A, const._CLASS_IN, const._DNS_OTHER_TTL, b'same')
+
+ assert record1.class_ == record2.class_
+ assert record1.__hash__() == record2.__hash__()
+ record_set = {record1, record2}
+ assert len(record_set) == 1
+
+
+def test_dns_address_record_hashablity():
+ """Test DNSAddress are hashable."""
+ address1 = r.DNSAddress('irrelevant', const._TYPE_A, const._CLASS_IN, 1, b'a')
+ address2 = r.DNSAddress('irrelevant', const._TYPE_A, const._CLASS_IN, 1, b'b')
+ address3 = r.DNSAddress('irrelevant', const._TYPE_A, const._CLASS_IN, 1, b'c')
+ address4 = r.DNSAddress('irrelevant', const._TYPE_AAAA, const._CLASS_IN, 1, b'c')
+
+ record_set = {address1, address2, address3, address4}
+ assert len(record_set) == 4
+
+ record_set.add(address1)
+ assert len(record_set) == 4
+
+ address3_dupe = r.DNSAddress('irrelevant', const._TYPE_A, const._CLASS_IN, 1, b'c')
+
+ record_set.add(address3_dupe)
+ assert len(record_set) == 4
+
+ # Verify we can remove records
+ additional_set = {address1, address2}
+ record_set -= additional_set
+ assert record_set == {address3, address4}
+
+
+def test_dns_hinfo_record_hashablity():
+ """Test DNSHinfo are hashable."""
+ hinfo1 = r.DNSHinfo('irrelevant', const._TYPE_HINFO, 0, 0, 'cpu1', 'os')
+ hinfo2 = r.DNSHinfo('irrelevant', const._TYPE_HINFO, 0, 0, 'cpu2', 'os')
+
+ record_set = {hinfo1, hinfo2}
+ assert len(record_set) == 2
+
+ record_set.add(hinfo1)
+ assert len(record_set) == 2
+
+ hinfo2_dupe = r.DNSHinfo('irrelevant', const._TYPE_HINFO, 0, 0, 'cpu2', 'os')
+ assert hinfo2 == hinfo2_dupe
+ assert hinfo2.__hash__() == hinfo2_dupe.__hash__()
+
+ record_set.add(hinfo2_dupe)
+ assert len(record_set) == 2
+
+
+def test_dns_pointer_record_hashablity():
+ """Test DNSPointer are hashable."""
+ ptr1 = r.DNSPointer('irrelevant', const._TYPE_PTR, const._CLASS_IN, const._DNS_OTHER_TTL, '123')
+ ptr2 = r.DNSPointer('irrelevant', const._TYPE_PTR, const._CLASS_IN, const._DNS_OTHER_TTL, '456')
+
+ record_set = {ptr1, ptr2}
+ assert len(record_set) == 2
+
+ record_set.add(ptr1)
+ assert len(record_set) == 2
+
+ ptr2_dupe = r.DNSPointer('irrelevant', const._TYPE_PTR, const._CLASS_IN, const._DNS_OTHER_TTL, '456')
+ assert ptr2 == ptr2
+ assert ptr2.__hash__() == ptr2_dupe.__hash__()
+
+ record_set.add(ptr2_dupe)
+ assert len(record_set) == 2
+
+
+def test_dns_text_record_hashablity():
+ """Test DNSText are hashable."""
+ text1 = r.DNSText('irrelevant', 0, 0, const._DNS_OTHER_TTL, b'12345678901')
+ text2 = r.DNSText('irrelevant', 1, 0, const._DNS_OTHER_TTL, b'12345678901')
+ text3 = r.DNSText('irrelevant', 0, 1, const._DNS_OTHER_TTL, b'12345678901')
+ text4 = r.DNSText('irrelevant', 0, 0, const._DNS_OTHER_TTL, b'ABCDEFGHIJK')
+
+ record_set = {text1, text2, text3, text4}
+
+ assert len(record_set) == 4
+
+ record_set.add(text1)
+ assert len(record_set) == 4
+
+ text1_dupe = r.DNSText('irrelevant', 0, 0, const._DNS_OTHER_TTL, b'12345678901')
+ assert text1 == text1_dupe
+ assert text1.__hash__() == text1_dupe.__hash__()
+
+ record_set.add(text1_dupe)
+ assert len(record_set) == 4
+
+
+def test_dns_service_record_hashablity():
+ """Test DNSService are hashable."""
+ srv1 = r.DNSService('irrelevant', const._TYPE_SRV, const._CLASS_IN, const._DNS_HOST_TTL, 0, 0, 80, 'a')
+ srv2 = r.DNSService('irrelevant', const._TYPE_SRV, const._CLASS_IN, const._DNS_HOST_TTL, 0, 1, 80, 'a')
+ srv3 = r.DNSService('irrelevant', const._TYPE_SRV, const._CLASS_IN, const._DNS_HOST_TTL, 0, 0, 81, 'a')
+ srv4 = r.DNSService('irrelevant', const._TYPE_SRV, const._CLASS_IN, const._DNS_HOST_TTL, 0, 0, 80, 'ab')
+
+ record_set = {srv1, srv2, srv3, srv4}
+
+ assert len(record_set) == 4
+
+ record_set.add(srv1)
+ assert len(record_set) == 4
+
+ srv1_dupe = r.DNSService(
+ 'irrelevant', const._TYPE_SRV, const._CLASS_IN, const._DNS_HOST_TTL, 0, 0, 80, 'a'
+ )
+ assert srv1 == srv1_dupe
+ assert srv1.__hash__() == srv1_dupe.__hash__()
+
+ record_set.add(srv1_dupe)
+ assert len(record_set) == 4
+
+
+def test_dns_nsec_record_hashablity():
+ """Test DNSNsec are hashable."""
+ nsec1 = r.DNSNsec(
+ 'irrelevant', const._TYPE_PTR, const._CLASS_IN, const._DNS_OTHER_TTL, 'irrelevant', [1, 2, 3]
+ )
+ nsec2 = r.DNSNsec(
+ 'irrelevant', const._TYPE_PTR, const._CLASS_IN, const._DNS_OTHER_TTL, 'irrelevant', [1, 2]
+ )
+
+ record_set = {nsec1, nsec2}
+ assert len(record_set) == 2
+
+ record_set.add(nsec1)
+ assert len(record_set) == 2
+
+ nsec2_dupe = r.DNSNsec(
+ 'irrelevant', const._TYPE_PTR, const._CLASS_IN, const._DNS_OTHER_TTL, 'irrelevant', [1, 2]
+ )
+ assert nsec2 == nsec2_dupe
+ assert nsec2.__hash__() == nsec2_dupe.__hash__()
+
+ record_set.add(nsec2_dupe)
+ assert len(record_set) == 2
+
+
+def test_rrset_does_not_consider_ttl():
+ """Test DNSRRSet does not consider the ttl in the hash."""
+
+ longarec = r.DNSAddress('irrelevant', const._TYPE_A, const._CLASS_IN, 100, b'same')
+ shortarec = r.DNSAddress('irrelevant', const._TYPE_A, const._CLASS_IN, 10, b'same')
+ longaaaarec = r.DNSAddress('irrelevant', const._TYPE_AAAA, const._CLASS_IN, 100, b'same')
+ shortaaaarec = r.DNSAddress('irrelevant', const._TYPE_AAAA, const._CLASS_IN, 10, b'same')
+
+ rrset = DNSRRSet([longarec, shortaaaarec])
+
+ assert rrset.suppresses(longarec)
+ assert rrset.suppresses(shortarec)
+ assert not rrset.suppresses(longaaaarec)
+ assert rrset.suppresses(shortaaaarec)
+
+ verylongarec = r.DNSAddress('irrelevant', const._TYPE_A, const._CLASS_IN, 1000, b'same')
+ longarec = r.DNSAddress('irrelevant', const._TYPE_A, const._CLASS_IN, 100, b'same')
+ mediumarec = r.DNSAddress('irrelevant', const._TYPE_A, const._CLASS_IN, 60, b'same')
+ shortarec = r.DNSAddress('irrelevant', const._TYPE_A, const._CLASS_IN, 10, b'same')
+
+ rrset2 = DNSRRSet([mediumarec])
+ assert not rrset2.suppresses(verylongarec)
+ assert rrset2.suppresses(longarec)
+ assert rrset2.suppresses(mediumarec)
+ assert rrset2.suppresses(shortarec)
diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py
new file mode 100644
index 00000000..47e68b75
--- /dev/null
+++ b/tests/test_exceptions.py
@@ -0,0 +1,154 @@
+#!/usr/bin/env python
+
+
+""" Unit tests for zeroconf._exceptions """
+
+import logging
+import unittest
+import unittest.mock
+
+import zeroconf as r
+from zeroconf import (
+ ServiceInfo,
+ Zeroconf,
+)
+
+
+log = logging.getLogger('zeroconf')
+original_logging_level = logging.NOTSET
+
+
+def setup_module():
+ global original_logging_level
+ original_logging_level = log.level
+ log.setLevel(logging.DEBUG)
+
+
+def teardown_module():
+ if original_logging_level != logging.NOTSET:
+ log.setLevel(original_logging_level)
+
+
+class Exceptions(unittest.TestCase):
+
+ browser = None # type: Zeroconf
+
+ @classmethod
+ def setUpClass(cls):
+ cls.browser = Zeroconf(interfaces=['127.0.0.1'])
+
+ @classmethod
+ def tearDownClass(cls):
+ cls.browser.close()
+ del cls.browser
+
+ def test_bad_service_info_name(self):
+ self.assertRaises(r.BadTypeInNameException, self.browser.get_service_info, "type", "type_not")
+
+ def test_bad_service_names(self):
+ bad_names_to_try = (
+ '',
+ 'local',
+ '_tcp.local.',
+ '_udp.local.',
+ '._udp.local.',
+ '_@._tcp.local.',
+ '_A@._tcp.local.',
+ '_x--x._tcp.local.',
+ '_-x._udp.local.',
+ '_x-._tcp.local.',
+ '_22._udp.local.',
+ '_2-2._tcp.local.',
+ '\x00._x._udp.local.',
+ )
+ for name in bad_names_to_try:
+ self.assertRaises(r.BadTypeInNameException, self.browser.get_service_info, name, 'x.' + name)
+
+ def test_bad_local_names_for_get_service_info(self):
+ bad_names_to_try = (
+ 'homekitdev._nothttp._tcp.local.',
+ 'homekitdev._http._udp.local.',
+ )
+ for name in bad_names_to_try:
+ self.assertRaises(
+ r.BadTypeInNameException, self.browser.get_service_info, '_http._tcp.local.', name
+ )
+
+ def test_good_instance_names(self):
+ assert r.service_type_name('.._x._tcp.local.') == '_x._tcp.local.'
+ assert r.service_type_name('x.sub._http._tcp.local.') == '_http._tcp.local.'
+ assert (
+ r.service_type_name('6d86f882b90facee9170ad3439d72a4d6ee9f511._zget._http._tcp.local.')
+ == '_http._tcp.local.'
+ )
+
+ def test_good_instance_names_without_protocol(self):
+ good_names_to_try = (
+ "Rachio-C73233.local.",
+ 'YeelightColorBulb-3AFD.local.',
+ 'YeelightTunableBulb-7220.local.',
+ "AlexanderHomeAssistant 74651D.local.",
+ 'iSmartGate-152.local.',
+ 'MyQ-FGA.local.',
+ 'lutron-02c4392a.local.',
+ 'WICED-hap-3E2734.local.',
+ 'MyHost.local.',
+ 'MyHost.sub.local.',
+ )
+ for name in good_names_to_try:
+ assert r.service_type_name(name, strict=False) == 'local.'
+
+ for name in good_names_to_try:
+ # Raises without strict=False
+ self.assertRaises(r.BadTypeInNameException, r.service_type_name, name)
+
+ def test_bad_types(self):
+ bad_names_to_try = (
+ '._x._tcp.local.',
+ 'a' * 64 + '._sub._http._tcp.local.',
+ 'a' * 62 + 'â._sub._http._tcp.local.',
+ )
+ for name in bad_names_to_try:
+ self.assertRaises(r.BadTypeInNameException, r.service_type_name, name)
+
+ def test_bad_sub_types(self):
+ bad_names_to_try = (
+ '_sub._http._tcp.local.',
+ '._sub._http._tcp.local.',
+ '\x7f._sub._http._tcp.local.',
+ '\x1f._sub._http._tcp.local.',
+ )
+ for name in bad_names_to_try:
+ self.assertRaises(r.BadTypeInNameException, r.service_type_name, name)
+
+ def test_good_service_names(self):
+ good_names_to_try = (
+ ('_x._tcp.local.', '_x._tcp.local.'),
+ ('_x._udp.local.', '_x._udp.local.'),
+ ('_12345-67890-abc._udp.local.', '_12345-67890-abc._udp.local.'),
+ ('x._sub._http._tcp.local.', '_http._tcp.local.'),
+ ('a' * 63 + '._sub._http._tcp.local.', '_http._tcp.local.'),
+ ('a' * 61 + 'â._sub._http._tcp.local.', '_http._tcp.local.'),
+ )
+
+ for name, result in good_names_to_try:
+ assert r.service_type_name(name) == result
+
+ assert r.service_type_name('_one_two._tcp.local.', strict=False) == '_one_two._tcp.local.'
+
+ def test_invalid_addresses(self):
+ type_ = "_test-srvc-type._tcp.local."
+ name = "xxxyyy"
+ registration_name = f"{name}.{type_}"
+
+ bad = ('127.0.0.1', '::1', 42)
+ for addr in bad:
+ self.assertRaisesRegex(
+ TypeError,
+ 'Addresses must be bytes',
+ ServiceInfo,
+ type_,
+ registration_name,
+ port=80,
+ addresses=[addr],
+ )
diff --git a/tests/test_handlers.py b/tests/test_handlers.py
new file mode 100644
index 00000000..44ee1d5a
--- /dev/null
+++ b/tests/test_handlers.py
@@ -0,0 +1,1540 @@
+#!/usr/bin/env python
+
+
+""" Unit tests for zeroconf._handlers """
+
+import asyncio
+import logging
+import os
+import pytest
+import socket
+import time
+import unittest
+import unittest.mock
+from typing import List
+
+import zeroconf as r
+from zeroconf import _handlers, ServiceInfo, Zeroconf, current_time_millis
+from zeroconf import const
+from zeroconf._handlers import construct_outgoing_multicast_answers, MulticastOutgoingQueue
+from zeroconf._utils.time import millis_to_seconds
+from zeroconf.asyncio import AsyncZeroconf
+
+
+from . import _clear_cache, _inject_response, has_working_ipv6
+
+log = logging.getLogger('zeroconf')
+original_logging_level = logging.NOTSET
+
+
+def setup_module():
+ global original_logging_level
+ original_logging_level = log.level
+ log.setLevel(logging.DEBUG)
+
+
+def teardown_module():
+ if original_logging_level != logging.NOTSET:
+ log.setLevel(original_logging_level)
+
+
+class TestRegistrar(unittest.TestCase):
+ def test_ttl(self):
+
+ # instantiate a zeroconf instance
+ zc = Zeroconf(interfaces=['127.0.0.1'])
+
+ # service definition
+ type_ = "_test-srvc-type._tcp.local."
+ name = "xxxyyy"
+ registration_name = f"{name}.{type_}"
+
+ desc = {'path': '/~paulsm/'}
+ info = ServiceInfo(
+ type_,
+ registration_name,
+ 80,
+ 0,
+ 0,
+ desc,
+ "ash-2.local.",
+ addresses=[socket.inet_aton("10.0.1.2")],
+ )
+
+ nbr_answers = nbr_additionals = nbr_authorities = 0
+
+ def get_ttl(record_type):
+ if expected_ttl is not None:
+ return expected_ttl
+ elif record_type in [const._TYPE_A, const._TYPE_SRV]:
+ return const._DNS_HOST_TTL
+ else:
+ return const._DNS_OTHER_TTL
+
+ def _process_outgoing_packet(out):
+ """Sends an outgoing packet."""
+ nonlocal nbr_answers, nbr_additionals, nbr_authorities
+
+ for answer, time_ in out.answers:
+ nbr_answers += 1
+ assert answer.ttl == get_ttl(answer.type)
+ for answer in out.additionals:
+ nbr_additionals += 1
+ assert answer.ttl == get_ttl(answer.type)
+ for answer in out.authorities:
+ nbr_authorities += 1
+ assert answer.ttl == get_ttl(answer.type)
+
+ # register service with default TTL
+ expected_ttl = None
+ for _ in range(3):
+ _process_outgoing_packet(zc.generate_service_query(info))
+ zc.registry.async_add(info)
+ for _ in range(3):
+ _process_outgoing_packet(zc.generate_service_broadcast(info, None))
+ assert nbr_answers == 12 and nbr_additionals == 0 and nbr_authorities == 3
+ nbr_answers = nbr_additionals = nbr_authorities = 0
+
+ # query
+ query = r.DNSOutgoing(const._FLAGS_QR_QUERY | const._FLAGS_AA)
+ assert query.is_query() is True
+ query.add_question(r.DNSQuestion(info.type, const._TYPE_PTR, const._CLASS_IN))
+ query.add_question(r.DNSQuestion(info.name, const._TYPE_SRV, const._CLASS_IN))
+ query.add_question(r.DNSQuestion(info.name, const._TYPE_TXT, const._CLASS_IN))
+ query.add_question(r.DNSQuestion(info.server, const._TYPE_A, const._CLASS_IN))
+ question_answers = zc.query_handler.async_response(
+ [r.DNSIncoming(packet) for packet in query.packets()], False
+ )
+ _process_outgoing_packet(construct_outgoing_multicast_answers(question_answers.mcast_aggregate))
+
+ # The additonals should all be suppresed since they are all in the answers section
+ # There will be one NSEC additional to indicate the lack of AAAA record
+ #
+ assert nbr_answers == 4 and nbr_additionals == 1 and nbr_authorities == 0
+ nbr_answers = nbr_additionals = nbr_authorities = 0
+
+ # unregister
+ expected_ttl = 0
+ zc.registry.async_remove(info)
+ for _ in range(3):
+ _process_outgoing_packet(zc.generate_service_broadcast(info, 0))
+ assert nbr_answers == 12 and nbr_additionals == 0 and nbr_authorities == 0
+ nbr_answers = nbr_additionals = nbr_authorities = 0
+
+ expected_ttl = None
+ for _ in range(3):
+ _process_outgoing_packet(zc.generate_service_query(info))
+ zc.registry.async_add(info)
+ # register service with custom TTL
+ expected_ttl = const._DNS_HOST_TTL * 2
+ assert expected_ttl != const._DNS_HOST_TTL
+ for _ in range(3):
+ _process_outgoing_packet(zc.generate_service_broadcast(info, expected_ttl))
+ assert nbr_answers == 12 and nbr_additionals == 0 and nbr_authorities == 3
+ nbr_answers = nbr_additionals = nbr_authorities = 0
+
+ # query
+ expected_ttl = None
+ query = r.DNSOutgoing(const._FLAGS_QR_QUERY | const._FLAGS_AA)
+ query.add_question(r.DNSQuestion(info.type, const._TYPE_PTR, const._CLASS_IN))
+ query.add_question(r.DNSQuestion(info.name, const._TYPE_SRV, const._CLASS_IN))
+ query.add_question(r.DNSQuestion(info.name, const._TYPE_TXT, const._CLASS_IN))
+ query.add_question(r.DNSQuestion(info.server, const._TYPE_A, const._CLASS_IN))
+ question_answers = zc.query_handler.async_response(
+ [r.DNSIncoming(packet) for packet in query.packets()], False
+ )
+ _process_outgoing_packet(construct_outgoing_multicast_answers(question_answers.mcast_aggregate))
+
+ # There will be one NSEC additional to indicate the lack of AAAA record
+ assert nbr_answers == 4 and nbr_additionals == 1 and nbr_authorities == 0
+ nbr_answers = nbr_additionals = nbr_authorities = 0
+
+ # unregister
+ expected_ttl = 0
+ zc.registry.async_remove(info)
+ for _ in range(3):
+ _process_outgoing_packet(zc.generate_service_broadcast(info, 0))
+ assert nbr_answers == 12 and nbr_additionals == 0 and nbr_authorities == 0
+ nbr_answers = nbr_additionals = nbr_authorities = 0
+ zc.close()
+
+ def test_name_conflicts(self):
+ # instantiate a zeroconf instance
+ zc = Zeroconf(interfaces=['127.0.0.1'])
+ type_ = "_homeassistant._tcp.local."
+ name = "Home"
+ registration_name = f"{name}.{type_}"
+
+ info = ServiceInfo(
+ type_,
+ name=registration_name,
+ server="random123.local.",
+ addresses=[socket.inet_pton(socket.AF_INET, "1.2.3.4")],
+ port=80,
+ properties={"version": "1.0"},
+ )
+ zc.register_service(info)
+
+ conflicting_info = ServiceInfo(
+ type_,
+ name=registration_name,
+ server="random456.local.",
+ addresses=[socket.inet_pton(socket.AF_INET, "4.5.6.7")],
+ port=80,
+ properties={"version": "1.0"},
+ )
+ with pytest.raises(r.NonUniqueNameException):
+ zc.register_service(conflicting_info)
+ zc.close()
+
+ def test_register_and_lookup_type_by_uppercase_name(self):
+ # instantiate a zeroconf instance
+ zc = Zeroconf(interfaces=['127.0.0.1'])
+ type_ = "_mylowertype._tcp.local."
+ name = "Home"
+ registration_name = f"{name}.{type_}"
+
+ info = ServiceInfo(
+ type_,
+ name=registration_name,
+ server="random123.local.",
+ addresses=[socket.inet_pton(socket.AF_INET, "1.2.3.4")],
+ port=80,
+ properties={"version": "1.0"},
+ )
+ zc.register_service(info)
+ _clear_cache(zc)
+ info = ServiceInfo(type_, registration_name)
+ info.load_from_cache(zc)
+ assert info.addresses == []
+
+ out = r.DNSOutgoing(const._FLAGS_QR_QUERY)
+ out.add_question(r.DNSQuestion(type_.upper(), const._TYPE_PTR, const._CLASS_IN))
+ zc.send(out)
+ time.sleep(1)
+ info = ServiceInfo(type_, registration_name)
+ info.load_from_cache(zc)
+ assert info.addresses == [socket.inet_pton(socket.AF_INET, "1.2.3.4")]
+ assert info.properties == {b"version": b"1.0"}
+ zc.close()
+
+
+def test_ptr_optimization():
+
+ # instantiate a zeroconf instance
+ zc = Zeroconf(interfaces=['127.0.0.1'])
+
+ # service definition
+ type_ = "_test-srvc-type._tcp.local."
+ name = "xxxyyy"
+ registration_name = f"{name}.{type_}"
+
+ desc = {'path': '/~paulsm/'}
+ info = ServiceInfo(
+ type_, registration_name, 80, 0, 0, desc, "ash-2.local.", addresses=[socket.inet_aton("10.0.1.2")]
+ )
+
+ # register
+ zc.register_service(info)
+
+ # Verify we won't respond for 1s with the same multicast
+ query = r.DNSOutgoing(const._FLAGS_QR_QUERY)
+ query.add_question(r.DNSQuestion(info.type, const._TYPE_PTR, const._CLASS_IN))
+ question_answers = zc.query_handler.async_response(
+ [r.DNSIncoming(packet) for packet in query.packets()], False
+ )
+ assert not question_answers.ucast
+ assert not question_answers.mcast_now
+ assert not question_answers.mcast_aggregate
+ # Since we sent the PTR in the last second, they
+ # should end up in the delayed at least one second bucket
+ assert question_answers.mcast_aggregate_last_second
+
+ # Clear the cache to allow responding again
+ _clear_cache(zc)
+
+ # Verify we will now respond
+ query = r.DNSOutgoing(const._FLAGS_QR_QUERY)
+ query.add_question(r.DNSQuestion(info.type, const._TYPE_PTR, const._CLASS_IN))
+ question_answers = zc.query_handler.async_response(
+ [r.DNSIncoming(packet) for packet in query.packets()], False
+ )
+ assert not question_answers.ucast
+ assert not question_answers.mcast_now
+ assert not question_answers.mcast_aggregate_last_second
+ has_srv = has_txt = has_a = False
+ nbr_additionals = 0
+ nbr_answers = len(question_answers.mcast_aggregate)
+ additionals = set().union(*question_answers.mcast_aggregate.values())
+ for answer in additionals:
+ nbr_additionals += 1
+ if answer.type == const._TYPE_SRV:
+ has_srv = True
+ elif answer.type == const._TYPE_TXT:
+ has_txt = True
+ elif answer.type == const._TYPE_A:
+ has_a = True
+ assert nbr_answers == 1 and nbr_additionals == 4
+ # There will be one NSEC additional to indicate the lack of AAAA record
+
+ assert has_srv and has_txt and has_a
+
+ # unregister
+ zc.unregister_service(info)
+ zc.close()
+
+
+@unittest.skipIf(not has_working_ipv6(), 'Requires IPv6')
+@unittest.skipIf(os.environ.get('SKIP_IPV6'), 'IPv6 tests disabled')
+def test_any_query_for_ptr():
+ """Test that queries for ANY will return PTR records and the response is aggregated."""
+ zc = Zeroconf(interfaces=['127.0.0.1'])
+ type_ = "_anyptr._tcp.local."
+ name = "knownname"
+ registration_name = f"{name}.{type_}"
+ desc = {'path': '/~paulsm/'}
+ server_name = "ash-2.local."
+ ipv6_address = socket.inet_pton(socket.AF_INET6, "2001:db8::1")
+ info = ServiceInfo(type_, registration_name, 80, 0, 0, desc, server_name, addresses=[ipv6_address])
+ zc.registry.async_add(info)
+
+ _clear_cache(zc)
+ generated = r.DNSOutgoing(const._FLAGS_QR_QUERY)
+ question = r.DNSQuestion(type_, const._TYPE_ANY, const._CLASS_IN)
+ generated.add_question(question)
+ packets = generated.packets()
+ question_answers = zc.query_handler.async_response([r.DNSIncoming(packet) for packet in packets], False)
+ mcast_answers = list(question_answers.mcast_aggregate)
+ assert mcast_answers[0].name == type_
+ assert mcast_answers[0].alias == registration_name
+ # unregister
+ zc.registry.async_remove(info)
+ zc.close()
+
+
+@unittest.skipIf(not has_working_ipv6(), 'Requires IPv6')
+@unittest.skipIf(os.environ.get('SKIP_IPV6'), 'IPv6 tests disabled')
+def test_aaaa_query():
+ """Test that queries for AAAA records work and should respond right away."""
+ zc = Zeroconf(interfaces=['127.0.0.1'])
+ type_ = "_knownaaaservice._tcp.local."
+ name = "knownname"
+ registration_name = f"{name}.{type_}"
+ desc = {'path': '/~paulsm/'}
+ server_name = "ash-2.local."
+ ipv6_address = socket.inet_pton(socket.AF_INET6, "2001:db8::1")
+ info = ServiceInfo(type_, registration_name, 80, 0, 0, desc, server_name, addresses=[ipv6_address])
+ zc.registry.async_add(info)
+
+ generated = r.DNSOutgoing(const._FLAGS_QR_QUERY)
+ question = r.DNSQuestion(server_name, const._TYPE_AAAA, const._CLASS_IN)
+ generated.add_question(question)
+ packets = generated.packets()
+ question_answers = zc.query_handler.async_response([r.DNSIncoming(packet) for packet in packets], False)
+ mcast_answers = list(question_answers.mcast_now)
+ assert mcast_answers[0].address == ipv6_address
+ # unregister
+ zc.registry.async_remove(info)
+ zc.close()
+
+
+@unittest.skipIf(not has_working_ipv6(), 'Requires IPv6')
+@unittest.skipIf(os.environ.get('SKIP_IPV6'), 'IPv6 tests disabled')
+def test_a_and_aaaa_record_fate_sharing():
+ """Test that queries for AAAA always return A records in the additionals and should respond right away."""
+ zc = Zeroconf(interfaces=['127.0.0.1'])
+ type_ = "_a-and-aaaa-service._tcp.local."
+ name = "knownname"
+ registration_name = f"{name}.{type_}"
+ desc = {'path': '/~paulsm/'}
+ server_name = "ash-2.local."
+ ipv6_address = socket.inet_pton(socket.AF_INET6, "2001:db8::1")
+ ipv4_address = socket.inet_aton("10.0.1.2")
+ info = ServiceInfo(
+ type_, registration_name, 80, 0, 0, desc, server_name, addresses=[ipv6_address, ipv4_address]
+ )
+ aaaa_record = info.dns_addresses(version=r.IPVersion.V6Only)[0]
+ a_record = info.dns_addresses(version=r.IPVersion.V4Only)[0]
+
+ zc.registry.async_add(info)
+
+ # Test AAAA query
+ generated = r.DNSOutgoing(const._FLAGS_QR_QUERY)
+ question = r.DNSQuestion(server_name, const._TYPE_AAAA, const._CLASS_IN)
+ generated.add_question(question)
+ packets = generated.packets()
+ question_answers = zc.query_handler.async_response([r.DNSIncoming(packet) for packet in packets], False)
+ additionals = set().union(*question_answers.mcast_now.values())
+ assert aaaa_record in question_answers.mcast_now
+ assert a_record in additionals
+ assert len(question_answers.mcast_now) == 1
+ assert len(additionals) == 1
+
+ # Test A query
+ generated = r.DNSOutgoing(const._FLAGS_QR_QUERY)
+ question = r.DNSQuestion(server_name, const._TYPE_A, const._CLASS_IN)
+ generated.add_question(question)
+ packets = generated.packets()
+ question_answers = zc.query_handler.async_response([r.DNSIncoming(packet) for packet in packets], False)
+ additionals = set().union(*question_answers.mcast_now.values())
+ assert a_record in question_answers.mcast_now
+ assert aaaa_record in additionals
+ assert len(question_answers.mcast_now) == 1
+ assert len(additionals) == 1
+
+ # unregister
+ zc.registry.async_remove(info)
+ zc.close()
+
+
+def test_unicast_response():
+ """Ensure we send a unicast response when the source port is not the MDNS port."""
+ # instantiate a zeroconf instance
+ zc = Zeroconf(interfaces=['127.0.0.1'])
+
+ # service definition
+ type_ = "_test-srvc-type._tcp.local."
+ name = "xxxyyy"
+ registration_name = f"{name}.{type_}"
+ desc = {'path': '/~paulsm/'}
+ info = ServiceInfo(
+ type_, registration_name, 80, 0, 0, desc, "ash-2.local.", addresses=[socket.inet_aton("10.0.1.2")]
+ )
+ # register
+ zc.registry.async_add(info)
+ _clear_cache(zc)
+
+ # query
+ query = r.DNSOutgoing(const._FLAGS_QR_QUERY)
+ query.add_question(r.DNSQuestion(info.type, const._TYPE_PTR, const._CLASS_IN))
+ question_answers = zc.query_handler.async_response(
+ [r.DNSIncoming(packet) for packet in query.packets()], True
+ )
+ for answers in (question_answers.ucast, question_answers.mcast_aggregate):
+ has_srv = has_txt = has_a = has_aaaa = has_nsec = False
+ nbr_additionals = 0
+ nbr_answers = len(answers)
+ additionals = set().union(*answers.values())
+ for answer in additionals:
+ nbr_additionals += 1
+ if answer.type == const._TYPE_SRV:
+ has_srv = True
+ elif answer.type == const._TYPE_TXT:
+ has_txt = True
+ elif answer.type == const._TYPE_A:
+ has_a = True
+ elif answer.type == const._TYPE_AAAA:
+ has_aaaa = True
+ elif answer.type == const._TYPE_NSEC:
+ has_nsec = True
+ # There will be one NSEC additional to indicate the lack of AAAA record
+ assert nbr_answers == 1 and nbr_additionals == 4
+ assert has_srv and has_txt and has_a and has_nsec
+ assert not has_aaaa
+
+ # unregister
+ zc.registry.async_remove(info)
+ zc.close()
+
+
+@pytest.mark.asyncio
+async def test_probe_answered_immediately():
+ """Verify probes are responded to immediately."""
+ # instantiate a zeroconf instance
+ zc = Zeroconf(interfaces=['127.0.0.1'])
+
+ # service definition
+ type_ = "_test-srvc-type._tcp.local."
+ name = "xxxyyy"
+ registration_name = f"{name}.{type_}"
+ desc = {'path': '/~paulsm/'}
+ info = ServiceInfo(
+ type_, registration_name, 80, 0, 0, desc, "ash-2.local.", addresses=[socket.inet_aton("10.0.1.2")]
+ )
+ zc.registry.async_add(info)
+ query = r.DNSOutgoing(const._FLAGS_QR_QUERY)
+ question = r.DNSQuestion(info.type, const._TYPE_PTR, const._CLASS_IN)
+ query.add_question(question)
+ query.add_authorative_answer(info.dns_pointer())
+ question_answers = zc.query_handler.async_response(
+ [r.DNSIncoming(packet) for packet in query.packets()], False
+ )
+ assert not question_answers.ucast
+ assert not question_answers.mcast_aggregate
+ assert not question_answers.mcast_aggregate_last_second
+ assert question_answers.mcast_now
+
+ query = r.DNSOutgoing(const._FLAGS_QR_QUERY)
+ question = r.DNSQuestion(info.type, const._TYPE_PTR, const._CLASS_IN)
+ question.unicast = True
+ query.add_question(question)
+ query.add_authorative_answer(info.dns_pointer())
+ question_answers = zc.query_handler.async_response(
+ [r.DNSIncoming(packet) for packet in query.packets()], False
+ )
+ assert question_answers.ucast
+ assert question_answers.mcast_now
+ assert not question_answers.mcast_aggregate
+ assert not question_answers.mcast_aggregate_last_second
+ zc.close()
+
+
+def test_qu_response():
+ """Handle multicast incoming with the QU bit set."""
+ # instantiate a zeroconf instance
+ zc = Zeroconf(interfaces=['127.0.0.1'])
+
+ # service definition
+ type_ = "_test-srvc-type._tcp.local."
+ other_type_ = "_notthesame._tcp.local."
+ name = "xxxyyy"
+ registration_name = f"{name}.{type_}"
+ registration_name2 = f"{name}.{other_type_}"
+ desc = {'path': '/~paulsm/'}
+ info = ServiceInfo(
+ type_, registration_name, 80, 0, 0, desc, "ash-2.local.", addresses=[socket.inet_aton("10.0.1.2")]
+ )
+ info2 = ServiceInfo(
+ other_type_,
+ registration_name2,
+ 80,
+ 0,
+ 0,
+ desc,
+ "ash-other.local.",
+ addresses=[socket.inet_aton("10.0.4.2")],
+ )
+ # register
+ zc.register_service(info)
+
+ def _validate_complete_response(answers):
+ has_srv = has_txt = has_a = has_aaaa = has_nsec = False
+ nbr_answers = len(answers.keys())
+ additionals = set().union(*answers.values())
+ nbr_additionals = len(additionals)
+
+ for answer in additionals:
+ if answer.type == const._TYPE_SRV:
+ has_srv = True
+ elif answer.type == const._TYPE_TXT:
+ has_txt = True
+ elif answer.type == const._TYPE_A:
+ has_a = True
+ elif answer.type == const._TYPE_AAAA:
+ has_aaaa = True
+ elif answer.type == const._TYPE_NSEC:
+ has_nsec = True
+ assert nbr_answers == 1 and nbr_additionals == 4
+ assert has_srv and has_txt and has_a and has_nsec
+ assert not has_aaaa
+
+ # With QU should respond to only unicast when the answer has been recently multicast
+ query = r.DNSOutgoing(const._FLAGS_QR_QUERY)
+ question = r.DNSQuestion(info.type, const._TYPE_PTR, const._CLASS_IN)
+ question.unicast = True # Set the QU bit
+ assert question.unicast is True
+ query.add_question(question)
+
+ question_answers = zc.query_handler.async_response(
+ [r.DNSIncoming(packet) for packet in query.packets()], False
+ )
+ _validate_complete_response(question_answers.ucast)
+ assert not question_answers.mcast_now
+ assert not question_answers.mcast_aggregate
+ assert not question_answers.mcast_aggregate_last_second
+
+ _clear_cache(zc)
+ # With QU should respond to only multicast since the response hasn't been seen since 75% of the ttl
+ query = r.DNSOutgoing(const._FLAGS_QR_QUERY)
+ question = r.DNSQuestion(info.type, const._TYPE_PTR, const._CLASS_IN)
+ question.unicast = True # Set the QU bit
+ assert question.unicast is True
+ query.add_question(question)
+ question_answers = zc.query_handler.async_response(
+ [r.DNSIncoming(packet) for packet in query.packets()], False
+ )
+ assert not question_answers.ucast
+ assert not question_answers.mcast_aggregate
+ assert not question_answers.mcast_aggregate
+ _validate_complete_response(question_answers.mcast_now)
+
+ # With QU set and an authorative answer (probe) should respond to both unitcast and multicast since the response hasn't been seen since 75% of the ttl
+ query = r.DNSOutgoing(const._FLAGS_QR_QUERY)
+ question = r.DNSQuestion(info.type, const._TYPE_PTR, const._CLASS_IN)
+ question.unicast = True # Set the QU bit
+ assert question.unicast is True
+ query.add_question(question)
+ query.add_authorative_answer(info2.dns_pointer())
+ question_answers = zc.query_handler.async_response(
+ [r.DNSIncoming(packet) for packet in query.packets()], False
+ )
+ _validate_complete_response(question_answers.ucast)
+ _validate_complete_response(question_answers.mcast_now)
+
+ _inject_response(
+ zc, r.DNSIncoming(construct_outgoing_multicast_answers(question_answers.mcast_now).packets()[0])
+ )
+ # With the cache repopulated; should respond to only unicast when the answer has been recently multicast
+ query = r.DNSOutgoing(const._FLAGS_QR_QUERY)
+ question = r.DNSQuestion(info.type, const._TYPE_PTR, const._CLASS_IN)
+ question.unicast = True # Set the QU bit
+ assert question.unicast is True
+ query.add_question(question)
+ question_answers = zc.query_handler.async_response(
+ [r.DNSIncoming(packet) for packet in query.packets()], False
+ )
+ assert not question_answers.mcast_now
+ assert not question_answers.mcast_aggregate
+ assert not question_answers.mcast_aggregate_last_second
+ _validate_complete_response(question_answers.ucast)
+ # unregister
+ zc.unregister_service(info)
+ zc.close()
+
+
+def test_known_answer_supression():
+ zc = Zeroconf(interfaces=['127.0.0.1'])
+ type_ = "_knownanswersv8._tcp.local."
+ name = "knownname"
+ registration_name = f"{name}.{type_}"
+ desc = {'path': '/~paulsm/'}
+ server_name = "ash-2.local."
+ info = ServiceInfo(
+ type_, registration_name, 80, 0, 0, desc, server_name, addresses=[socket.inet_aton("10.0.1.2")]
+ )
+ zc.registry.async_add(info)
+
+ now = current_time_millis()
+ _clear_cache(zc)
+ # Test PTR supression
+ generated = r.DNSOutgoing(const._FLAGS_QR_QUERY)
+ question = r.DNSQuestion(type_, const._TYPE_PTR, const._CLASS_IN)
+ generated.add_question(question)
+ packets = generated.packets()
+ question_answers = zc.query_handler.async_response([r.DNSIncoming(packet) for packet in packets], False)
+ assert not question_answers.ucast
+ assert not question_answers.mcast_now
+ assert question_answers.mcast_aggregate
+ assert not question_answers.mcast_aggregate_last_second
+
+ generated = r.DNSOutgoing(const._FLAGS_QR_QUERY)
+ question = r.DNSQuestion(type_, const._TYPE_PTR, const._CLASS_IN)
+ generated.add_question(question)
+ generated.add_answer_at_time(info.dns_pointer(), now)
+ packets = generated.packets()
+ question_answers = zc.query_handler.async_response([r.DNSIncoming(packet) for packet in packets], False)
+ assert not question_answers.ucast
+ assert not question_answers.mcast_now
+ assert not question_answers.mcast_aggregate
+ assert not question_answers.mcast_aggregate_last_second
+
+ # Test A supression
+ generated = r.DNSOutgoing(const._FLAGS_QR_QUERY)
+ question = r.DNSQuestion(server_name, const._TYPE_A, const._CLASS_IN)
+ generated.add_question(question)
+ packets = generated.packets()
+ question_answers = zc.query_handler.async_response([r.DNSIncoming(packet) for packet in packets], False)
+ assert not question_answers.ucast
+ assert question_answers.mcast_now
+ assert not question_answers.mcast_aggregate
+ assert not question_answers.mcast_aggregate_last_second
+
+ generated = r.DNSOutgoing(const._FLAGS_QR_QUERY)
+ question = r.DNSQuestion(server_name, const._TYPE_A, const._CLASS_IN)
+ generated.add_question(question)
+ for dns_address in info.dns_addresses():
+ generated.add_answer_at_time(dns_address, now)
+ packets = generated.packets()
+ question_answers = zc.query_handler.async_response([r.DNSIncoming(packet) for packet in packets], False)
+ assert not question_answers.ucast
+ assert not question_answers.mcast_now
+ assert not question_answers.mcast_aggregate
+ assert not question_answers.mcast_aggregate_last_second
+
+ # Test NSEC record returned when there is no AAAA record and we expectly ask
+ generated = r.DNSOutgoing(const._FLAGS_QR_QUERY)
+ question = r.DNSQuestion(server_name, const._TYPE_AAAA, const._CLASS_IN)
+ generated.add_question(question)
+ for dns_address in info.dns_addresses():
+ generated.add_answer_at_time(dns_address, now)
+ packets = generated.packets()
+ question_answers = zc.query_handler.async_response([r.DNSIncoming(packet) for packet in packets], False)
+ assert not question_answers.ucast
+ expected_nsec_record: r.DNSNsec = list(question_answers.mcast_now)[0]
+ assert const._TYPE_A not in expected_nsec_record.rdtypes
+ assert const._TYPE_AAAA in expected_nsec_record.rdtypes
+ assert not question_answers.mcast_aggregate
+ assert not question_answers.mcast_aggregate_last_second
+
+ # Test SRV supression
+ generated = r.DNSOutgoing(const._FLAGS_QR_QUERY)
+ question = r.DNSQuestion(registration_name, const._TYPE_SRV, const._CLASS_IN)
+ generated.add_question(question)
+ packets = generated.packets()
+ question_answers = zc.query_handler.async_response([r.DNSIncoming(packet) for packet in packets], False)
+ assert not question_answers.ucast
+ assert question_answers.mcast_now
+ assert not question_answers.mcast_aggregate
+ assert not question_answers.mcast_aggregate_last_second
+
+ generated = r.DNSOutgoing(const._FLAGS_QR_QUERY)
+ question = r.DNSQuestion(registration_name, const._TYPE_SRV, const._CLASS_IN)
+ generated.add_question(question)
+ generated.add_answer_at_time(info.dns_service(), now)
+ packets = generated.packets()
+ question_answers = zc.query_handler.async_response([r.DNSIncoming(packet) for packet in packets], False)
+ assert not question_answers.ucast
+ assert not question_answers.mcast_now
+ assert not question_answers.mcast_aggregate
+ assert not question_answers.mcast_aggregate_last_second
+
+ # Test TXT supression
+ generated = r.DNSOutgoing(const._FLAGS_QR_QUERY)
+ question = r.DNSQuestion(registration_name, const._TYPE_TXT, const._CLASS_IN)
+ generated.add_question(question)
+ packets = generated.packets()
+ question_answers = zc.query_handler.async_response([r.DNSIncoming(packet) for packet in packets], False)
+ assert not question_answers.ucast
+ assert not question_answers.mcast_now
+ assert question_answers.mcast_aggregate
+ assert not question_answers.mcast_aggregate_last_second
+
+ generated = r.DNSOutgoing(const._FLAGS_QR_QUERY)
+ question = r.DNSQuestion(registration_name, const._TYPE_TXT, const._CLASS_IN)
+ generated.add_question(question)
+ generated.add_answer_at_time(info.dns_text(), now)
+ packets = generated.packets()
+ question_answers = zc.query_handler.async_response([r.DNSIncoming(packet) for packet in packets], False)
+ assert not question_answers.ucast
+ assert not question_answers.mcast_now
+ assert not question_answers.mcast_aggregate
+ assert not question_answers.mcast_aggregate_last_second
+
+ # unregister
+ zc.registry.async_remove(info)
+ zc.close()
+
+
+def test_multi_packet_known_answer_supression():
+ zc = Zeroconf(interfaces=['127.0.0.1'])
+ type_ = "_handlermultis._tcp.local."
+ name = "knownname"
+ name2 = "knownname2"
+ name3 = "knownname3"
+
+ registration_name = f"{name}.{type_}"
+ registration2_name = f"{name2}.{type_}"
+ registration3_name = f"{name3}.{type_}"
+
+ desc = {'path': '/~paulsm/'}
+ server_name = "ash-2.local."
+ server_name2 = "ash-3.local."
+ server_name3 = "ash-4.local."
+
+ info = ServiceInfo(
+ type_, registration_name, 80, 0, 0, desc, server_name, addresses=[socket.inet_aton("10.0.1.2")]
+ )
+ info2 = ServiceInfo(
+ type_, registration2_name, 80, 0, 0, desc, server_name2, addresses=[socket.inet_aton("10.0.1.2")]
+ )
+ info3 = ServiceInfo(
+ type_, registration3_name, 80, 0, 0, desc, server_name3, addresses=[socket.inet_aton("10.0.1.2")]
+ )
+ zc.registry.async_add(info)
+ zc.registry.async_add(info2)
+ zc.registry.async_add(info3)
+
+ now = current_time_millis()
+ _clear_cache(zc)
+ # Test PTR supression
+ generated = r.DNSOutgoing(const._FLAGS_QR_QUERY)
+ question = r.DNSQuestion(type_, const._TYPE_PTR, const._CLASS_IN)
+ generated.add_question(question)
+ for _ in range(1000):
+ # Add so many answers we end up with another packet
+ generated.add_answer_at_time(info.dns_pointer(), now)
+ generated.add_answer_at_time(info2.dns_pointer(), now)
+ generated.add_answer_at_time(info3.dns_pointer(), now)
+ packets = generated.packets()
+ assert len(packets) > 1
+ question_answers = zc.query_handler.async_response([r.DNSIncoming(packet) for packet in packets], False)
+ assert not question_answers.ucast
+ assert not question_answers.mcast_now
+ assert not question_answers.mcast_aggregate
+ assert not question_answers.mcast_aggregate_last_second
+ # unregister
+ zc.registry.async_remove(info)
+ zc.registry.async_remove(info2)
+ zc.registry.async_remove(info3)
+ zc.close()
+
+
+def test_known_answer_supression_service_type_enumeration_query():
+ zc = Zeroconf(interfaces=['127.0.0.1'])
+ type_ = "_otherknown._tcp.local."
+ name = "knownname"
+ registration_name = f"{name}.{type_}"
+ desc = {'path': '/~paulsm/'}
+ server_name = "ash-2.local."
+ info = ServiceInfo(
+ type_, registration_name, 80, 0, 0, desc, server_name, addresses=[socket.inet_aton("10.0.1.2")]
+ )
+ zc.registry.async_add(info)
+
+ type_2 = "_otherknown2._tcp.local."
+ name = "knownname"
+ registration_name2 = f"{name}.{type_2}"
+ desc = {'path': '/~paulsm/'}
+ server_name2 = "ash-3.local."
+ info2 = ServiceInfo(
+ type_2, registration_name2, 80, 0, 0, desc, server_name2, addresses=[socket.inet_aton("10.0.1.2")]
+ )
+ zc.registry.async_add(info2)
+ now = current_time_millis()
+ _clear_cache(zc)
+
+ # Test PTR supression
+ generated = r.DNSOutgoing(const._FLAGS_QR_QUERY)
+ question = r.DNSQuestion(const._SERVICE_TYPE_ENUMERATION_NAME, const._TYPE_PTR, const._CLASS_IN)
+ generated.add_question(question)
+ packets = generated.packets()
+ question_answers = zc.query_handler.async_response([r.DNSIncoming(packet) for packet in packets], False)
+ assert not question_answers.ucast
+ assert not question_answers.mcast_now
+ assert question_answers.mcast_aggregate
+ assert not question_answers.mcast_aggregate_last_second
+
+ generated = r.DNSOutgoing(const._FLAGS_QR_QUERY)
+ question = r.DNSQuestion(const._SERVICE_TYPE_ENUMERATION_NAME, const._TYPE_PTR, const._CLASS_IN)
+ generated.add_question(question)
+ generated.add_answer_at_time(
+ r.DNSPointer(
+ const._SERVICE_TYPE_ENUMERATION_NAME,
+ const._TYPE_PTR,
+ const._CLASS_IN,
+ const._DNS_OTHER_TTL,
+ type_,
+ ),
+ now,
+ )
+ generated.add_answer_at_time(
+ r.DNSPointer(
+ const._SERVICE_TYPE_ENUMERATION_NAME,
+ const._TYPE_PTR,
+ const._CLASS_IN,
+ const._DNS_OTHER_TTL,
+ type_2,
+ ),
+ now,
+ )
+ packets = generated.packets()
+ question_answers = zc.query_handler.async_response([r.DNSIncoming(packet) for packet in packets], False)
+ assert not question_answers.ucast
+ assert not question_answers.mcast_now
+ assert not question_answers.mcast_aggregate
+ assert not question_answers.mcast_aggregate_last_second
+
+ # unregister
+ zc.registry.async_remove(info)
+ zc.registry.async_remove(info2)
+ zc.close()
+
+
+# This test uses asyncio because it needs to access the cache directly
+# which is not threadsafe
+@pytest.mark.asyncio
+async def test_qu_response_only_sends_additionals_if_sends_answer():
+ """Test that a QU response does not send additionals unless it sends the answer as well."""
+ # instantiate a zeroconf instance
+ aiozc = AsyncZeroconf(interfaces=['127.0.0.1'])
+ zc = aiozc.zeroconf
+
+ type_ = "_addtest1._tcp.local."
+ name = "knownname"
+ registration_name = f"{name}.{type_}"
+ desc = {'path': '/~paulsm/'}
+ server_name = "ash-2.local."
+ info = ServiceInfo(
+ type_, registration_name, 80, 0, 0, desc, server_name, addresses=[socket.inet_aton("10.0.1.2")]
+ )
+ zc.registry.async_add(info)
+
+ type_2 = "_addtest2._tcp.local."
+ name = "knownname"
+ registration_name2 = f"{name}.{type_2}"
+ desc = {'path': '/~paulsm/'}
+ server_name2 = "ash-3.local."
+ info2 = ServiceInfo(
+ type_2, registration_name2, 80, 0, 0, desc, server_name2, addresses=[socket.inet_aton("10.0.1.2")]
+ )
+ zc.registry.async_add(info2)
+
+ ptr_record = info.dns_pointer()
+
+ # Add the PTR record to the cache
+ zc.cache.async_add_records([ptr_record])
+
+ # Add the A record to the cache with 50% ttl remaining
+ a_record = info.dns_addresses()[0]
+ a_record.set_created_ttl(current_time_millis() - (a_record.ttl * 1000 / 2), a_record.ttl)
+ assert not a_record.is_recent(current_time_millis())
+ zc.cache.async_add_records([a_record])
+
+ # With QU should respond to only unicast when the answer has been recently multicast
+ # even if the additional has not been recently multicast
+ query = r.DNSOutgoing(const._FLAGS_QR_QUERY)
+ question = r.DNSQuestion(info.type, const._TYPE_PTR, const._CLASS_IN)
+ question.unicast = True # Set the QU bit
+ assert question.unicast is True
+ query.add_question(question)
+
+ question_answers = zc.query_handler.async_response(
+ [r.DNSIncoming(packet) for packet in query.packets()], False
+ )
+ assert not question_answers.mcast_now
+ assert not question_answers.mcast_aggregate
+ assert not question_answers.mcast_aggregate_last_second
+
+ additionals = set().union(*question_answers.ucast.values())
+ assert a_record in additionals
+ assert ptr_record in question_answers.ucast
+
+ # Remove the 50% A record and add a 100% A record
+ zc.cache.async_remove_records([a_record])
+ a_record = info.dns_addresses()[0]
+ assert a_record.is_recent(current_time_millis())
+ zc.cache.async_add_records([a_record])
+ # With QU should respond to only unicast when the answer has been recently multicast
+ # even if the additional has not been recently multicast
+ query = r.DNSOutgoing(const._FLAGS_QR_QUERY)
+ question = r.DNSQuestion(info.type, const._TYPE_PTR, const._CLASS_IN)
+ question.unicast = True # Set the QU bit
+ assert question.unicast is True
+ query.add_question(question)
+
+ question_answers = zc.query_handler.async_response(
+ [r.DNSIncoming(packet) for packet in query.packets()], False
+ )
+ assert not question_answers.mcast_now
+ assert not question_answers.mcast_aggregate
+ assert not question_answers.mcast_aggregate_last_second
+ additionals = set().union(*question_answers.ucast.values())
+ assert a_record in additionals
+ assert ptr_record in question_answers.ucast
+
+ # Remove the 100% PTR record and add a 50% PTR record
+ zc.cache.async_remove_records([ptr_record])
+ ptr_record.set_created_ttl(current_time_millis() - (ptr_record.ttl * 1000 / 2), ptr_record.ttl)
+ assert not ptr_record.is_recent(current_time_millis())
+ zc.cache.async_add_records([ptr_record])
+ # With QU should respond to only multicast since the has less
+ # than 75% of its ttl remaining
+ query = r.DNSOutgoing(const._FLAGS_QR_QUERY)
+ question = r.DNSQuestion(info.type, const._TYPE_PTR, const._CLASS_IN)
+ question.unicast = True # Set the QU bit
+ assert question.unicast is True
+ query.add_question(question)
+
+ question_answers = zc.query_handler.async_response(
+ [r.DNSIncoming(packet) for packet in query.packets()], False
+ )
+ assert not question_answers.ucast
+ assert not question_answers.mcast_aggregate
+ assert not question_answers.mcast_aggregate_last_second
+ additionals = set().union(*question_answers.mcast_now.values())
+ assert a_record in additionals
+ assert info.dns_text() in additionals
+ assert info.dns_service() in additionals
+ assert ptr_record in question_answers.mcast_now
+
+ # Ask 2 QU questions, with info the PTR is at 50%, with info2 the PTR is at 100%
+ # We should get back a unicast reply for info2, but info should be multicasted since its within 75% of its TTL
+ # With QU should respond to only multicast since the has less
+ # than 75% of its ttl remaining
+ query = r.DNSOutgoing(const._FLAGS_QR_QUERY)
+ question = r.DNSQuestion(info.type, const._TYPE_PTR, const._CLASS_IN)
+ question.unicast = True # Set the QU bit
+ assert question.unicast is True
+ query.add_question(question)
+
+ question = r.DNSQuestion(info2.type, const._TYPE_PTR, const._CLASS_IN)
+ question.unicast = True # Set the QU bit
+ assert question.unicast is True
+ query.add_question(question)
+ zc.cache.async_add_records([info2.dns_pointer()]) # Add 100% TTL for info2 to the cache
+
+ question_answers = zc.query_handler.async_response(
+ [r.DNSIncoming(packet) for packet in query.packets()], False
+ )
+ assert not question_answers.mcast_aggregate
+ assert not question_answers.mcast_aggregate_last_second
+
+ mcast_now_additionals = set().union(*question_answers.mcast_now.values())
+ assert a_record in mcast_now_additionals
+ assert info.dns_text() in mcast_now_additionals
+ assert info.dns_addresses()[0] in mcast_now_additionals
+ assert info.dns_pointer() in question_answers.mcast_now
+
+ ucast_additionals = set().union(*question_answers.ucast.values())
+ assert info2.dns_pointer() in question_answers.ucast
+ assert info2.dns_text() in ucast_additionals
+ assert info2.dns_service() in ucast_additionals
+ assert info2.dns_addresses()[0] in ucast_additionals
+
+ # unregister
+ zc.registry.async_remove(info)
+ await aiozc.async_close()
+
+
+# This test uses asyncio because it needs to access the cache directly
+# which is not threadsafe
+@pytest.mark.asyncio
+async def test_cache_flush_bit():
+ """Test that the cache flush bit sets the TTL to one for matching records."""
+ # instantiate a zeroconf instance
+ aiozc = AsyncZeroconf(interfaces=['127.0.0.1'])
+ zc = aiozc.zeroconf
+
+ type_ = "_cacheflush._tcp.local."
+ name = "knownname"
+ registration_name = f"{name}.{type_}"
+ desc = {'path': '/~paulsm/'}
+ server_name = "server-uu1.local."
+ info = ServiceInfo(
+ type_, registration_name, 80, 0, 0, desc, server_name, addresses=[socket.inet_aton("10.0.1.2")]
+ )
+ a_record = info.dns_addresses()[0]
+ zc.cache.async_add_records([info.dns_pointer(), a_record, info.dns_text(), info.dns_service()])
+
+ info.addresses = [socket.inet_aton("10.0.1.5"), socket.inet_aton("10.0.1.6")]
+ new_records = info.dns_addresses()
+ for new_record in new_records:
+ assert new_record.unique is True
+
+ original_a_record = zc.cache.async_get_unique(a_record)
+ # Do the run within 1s to verify the original record is not going to be expired
+ out = r.DNSOutgoing(const._FLAGS_QR_RESPONSE | const._FLAGS_AA, multicast=True)
+ for answer in new_records:
+ out.add_answer_at_time(answer, 0)
+ for packet in out.packets():
+ zc.record_manager.async_updates_from_response(r.DNSIncoming(packet))
+ assert zc.cache.async_get_unique(a_record) is original_a_record
+ assert original_a_record.ttl != 1
+ for record in new_records:
+ assert zc.cache.async_get_unique(record) is not None
+
+ original_a_record.created = current_time_millis() - 1001
+
+ # Do the run within 1s to verify the original record is not going to be expired
+ out = r.DNSOutgoing(const._FLAGS_QR_RESPONSE | const._FLAGS_AA, multicast=True)
+ for answer in new_records:
+ out.add_answer_at_time(answer, 0)
+ for packet in out.packets():
+ zc.record_manager.async_updates_from_response(r.DNSIncoming(packet))
+ assert original_a_record.ttl == 1
+ for record in new_records:
+ assert zc.cache.async_get_unique(record) is not None
+
+ cached_records = [zc.cache.async_get_unique(record) for record in new_records]
+ for record in cached_records:
+ record.created = current_time_millis() - 1001
+
+ fresh_address = socket.inet_aton("4.4.4.4")
+ info.addresses = [fresh_address]
+ # Do the run within 1s to verify the two new records get marked as expired
+ out = r.DNSOutgoing(const._FLAGS_QR_RESPONSE | const._FLAGS_AA, multicast=True)
+ for answer in info.dns_addresses():
+ out.add_answer_at_time(answer, 0)
+ for packet in out.packets():
+ zc.record_manager.async_updates_from_response(r.DNSIncoming(packet))
+ for record in cached_records:
+ assert record.ttl == 1
+
+ for entry in zc.cache.async_all_by_details(server_name, const._TYPE_A, const._CLASS_IN):
+ if entry.address == fresh_address:
+ assert entry.ttl > 1
+ else:
+ assert entry.ttl == 1
+
+ # Wait for the ttl 1 records to expire
+ await asyncio.sleep(1.1)
+
+ loaded_info = r.ServiceInfo(type_, registration_name)
+ loaded_info.load_from_cache(zc)
+ assert loaded_info.addresses == info.addresses
+
+ await aiozc.async_close()
+
+
+# This test uses asyncio because it needs to access the cache directly
+# which is not threadsafe
+@pytest.mark.asyncio
+async def test_record_update_manager_add_listener_callsback_existing_records():
+ """Test that the RecordUpdateManager will callback existing records."""
+
+ aiozc = AsyncZeroconf(interfaces=['127.0.0.1'])
+ zc: Zeroconf = aiozc.zeroconf
+ updated = []
+
+ class MyListener(r.RecordUpdateListener):
+ """A RecordUpdateListener that does not implement update_records."""
+
+ def async_update_records(self, zc: 'Zeroconf', now: float, records: List[r.RecordUpdate]) -> None:
+ """Update multiple records in one shot."""
+ updated.extend(records)
+
+ type_ = "_cacheflush._tcp.local."
+ name = "knownname"
+ registration_name = f"{name}.{type_}"
+ desc = {'path': '/~paulsm/'}
+ server_name = "server-uu1.local."
+ info = ServiceInfo(
+ type_, registration_name, 80, 0, 0, desc, server_name, addresses=[socket.inet_aton("10.0.1.2")]
+ )
+ a_record = info.dns_addresses()[0]
+ ptr_record = info.dns_pointer()
+ zc.cache.async_add_records([ptr_record, a_record, info.dns_text(), info.dns_service()])
+
+ listener = MyListener()
+
+ zc.add_listener(
+ listener,
+ [
+ r.DNSQuestion(type_, const._TYPE_PTR, const._CLASS_IN),
+ r.DNSQuestion(server_name, const._TYPE_A, const._CLASS_IN),
+ ],
+ )
+ await asyncio.sleep(0) # flush out the call_soon_threadsafe
+
+ assert {record.new for record in updated} == {ptr_record, a_record}
+
+ # The old records should be None so we trigger Add events
+ # in service browsers instead of Update events
+ assert {record.old for record in updated} == {None}
+
+ await aiozc.async_close()
+
+
+@pytest.mark.asyncio
+async def test_questions_query_handler_populates_the_question_history_from_qm_questions():
+ aiozc = AsyncZeroconf(interfaces=['127.0.0.1'])
+ zc = aiozc.zeroconf
+ now = current_time_millis()
+ _clear_cache(zc)
+
+ generated = r.DNSOutgoing(const._FLAGS_QR_QUERY)
+ question = r.DNSQuestion("_hap._tcp._local.", const._TYPE_PTR, const._CLASS_IN)
+ question.unicast = False
+ known_answer = r.DNSPointer(
+ "_hap._tcp.local.", const._TYPE_PTR, const._CLASS_IN, 10000, 'known-to-other._hap._tcp.local.'
+ )
+ generated.add_question(question)
+ generated.add_answer_at_time(known_answer, 0)
+ now = r.current_time_millis()
+ packets = generated.packets()
+ question_answers = zc.query_handler.async_response([r.DNSIncoming(packet) for packet in packets], False)
+ assert not question_answers.ucast
+ assert not question_answers.mcast_now
+ assert not question_answers.mcast_aggregate
+ assert not question_answers.mcast_aggregate_last_second
+ assert zc.question_history.suppresses(question, now, {known_answer})
+
+ await aiozc.async_close()
+
+
+@pytest.mark.asyncio
+async def test_questions_query_handler_does_not_put_qu_questions_in_history():
+ aiozc = AsyncZeroconf(interfaces=['127.0.0.1'])
+ zc = aiozc.zeroconf
+ now = current_time_millis()
+ _clear_cache(zc)
+
+ generated = r.DNSOutgoing(const._FLAGS_QR_QUERY)
+ question = r.DNSQuestion("_hap._tcp._local.", const._TYPE_PTR, const._CLASS_IN)
+ question.unicast = True
+ known_answer = r.DNSPointer(
+ "_hap._tcp.local.", const._TYPE_PTR, const._CLASS_IN, 10000, 'known-to-other._hap._tcp.local.'
+ )
+ generated.add_question(question)
+ generated.add_answer_at_time(known_answer, 0)
+ now = r.current_time_millis()
+ packets = generated.packets()
+ question_answers = zc.query_handler.async_response([r.DNSIncoming(packet) for packet in packets], False)
+ assert not question_answers.ucast
+ assert not question_answers.mcast_now
+ assert not question_answers.mcast_aggregate
+ assert not question_answers.mcast_aggregate_last_second
+ assert not zc.question_history.suppresses(question, now, {known_answer})
+
+ await aiozc.async_close()
+
+
+@pytest.mark.asyncio
+async def test_guard_against_low_ptr_ttl():
+ """Ensure we enforce a minimum for PTR record ttls to avoid excessive refresh queries from ServiceBrowsers.
+
+ Some poorly designed IoT devices can set excessively low PTR
+ TTLs would will cause ServiceBrowsers to flood the network
+ with excessive refresh queries.
+ """
+ aiozc = AsyncZeroconf(interfaces=['127.0.0.1'])
+ zc = aiozc.zeroconf
+ # Apple uses a 15s minimum TTL, however we do not have the same
+ # level of rate limit and safe guards so we use 1/4 of the recommended value
+ answer_with_low_ttl = r.DNSPointer(
+ "myservicelow_tcp._tcp.local.",
+ const._TYPE_PTR,
+ const._CLASS_IN | const._CLASS_UNIQUE,
+ 2,
+ 'low.local.',
+ )
+ answer_with_normal_ttl = r.DNSPointer(
+ "myservicelow_tcp._tcp.local.",
+ const._TYPE_PTR,
+ const._CLASS_IN | const._CLASS_UNIQUE,
+ const._DNS_OTHER_TTL,
+ 'normal.local.',
+ )
+ good_bye_answer = r.DNSPointer(
+ "myservicelow_tcp._tcp.local.",
+ const._TYPE_PTR,
+ const._CLASS_IN | const._CLASS_UNIQUE,
+ 0,
+ 'goodbye.local.',
+ )
+ # TTL should be adjusted to a safe value
+ response = r.DNSOutgoing(const._FLAGS_QR_RESPONSE)
+ response.add_answer_at_time(answer_with_low_ttl, 0)
+ response.add_answer_at_time(answer_with_normal_ttl, 0)
+ response.add_answer_at_time(good_bye_answer, 0)
+ incoming = r.DNSIncoming(response.packets()[0])
+ zc.record_manager.async_updates_from_response(incoming)
+
+ incoming_answer_low = zc.cache.async_get_unique(answer_with_low_ttl)
+ assert incoming_answer_low.ttl == const._DNS_PTR_MIN_TTL
+ incoming_answer_normal = zc.cache.async_get_unique(answer_with_normal_ttl)
+ assert incoming_answer_normal.ttl == const._DNS_OTHER_TTL
+ assert zc.cache.async_get_unique(good_bye_answer) is None
+ await aiozc.async_close()
+
+
+@pytest.mark.asyncio
+async def test_duplicate_goodbye_answers_in_packet():
+ """Ensure we do not throw an exception when there are duplicate goodbye records in a packet."""
+ aiozc = AsyncZeroconf(interfaces=['127.0.0.1'])
+ zc = aiozc.zeroconf
+ answer_with_normal_ttl = r.DNSPointer(
+ "myservicelow_tcp._tcp.local.",
+ const._TYPE_PTR,
+ const._CLASS_IN | const._CLASS_UNIQUE,
+ const._DNS_OTHER_TTL,
+ 'host.local.',
+ )
+ good_bye_answer = r.DNSPointer(
+ "myservicelow_tcp._tcp.local.",
+ const._TYPE_PTR,
+ const._CLASS_IN | const._CLASS_UNIQUE,
+ 0,
+ 'host.local.',
+ )
+ response = r.DNSOutgoing(const._FLAGS_QR_RESPONSE)
+ response.add_answer_at_time(answer_with_normal_ttl, 0)
+ incoming = r.DNSIncoming(response.packets()[0])
+ zc.record_manager.async_updates_from_response(incoming)
+
+ response = r.DNSOutgoing(const._FLAGS_QR_RESPONSE)
+ response.add_answer_at_time(good_bye_answer, 0)
+ response.add_answer_at_time(good_bye_answer, 0)
+ incoming = r.DNSIncoming(response.packets()[0])
+ zc.record_manager.async_updates_from_response(incoming)
+ await aiozc.async_close()
+
+
+@pytest.mark.asyncio
+async def test_response_aggregation_timings(run_isolated):
+ """Verify multicast respones are aggregated."""
+ type_ = "_mservice._tcp.local."
+ type_2 = "_mservice2._tcp.local."
+ type_3 = "_mservice3._tcp.local."
+
+ aiozc = AsyncZeroconf(interfaces=['127.0.0.1'])
+ await aiozc.zeroconf.async_wait_for_start()
+
+ name = "xxxyyy"
+ registration_name = f"{name}.{type_}"
+ registration_name2 = f"{name}.{type_2}"
+ registration_name3 = f"{name}.{type_3}"
+
+ desc = {'path': '/~paulsm/'}
+ info = ServiceInfo(
+ type_, registration_name, 80, 0, 0, desc, "ash-2.local.", addresses=[socket.inet_aton("10.0.1.2")]
+ )
+ info2 = ServiceInfo(
+ type_2, registration_name2, 80, 0, 0, desc, "ash-4.local.", addresses=[socket.inet_aton("10.0.1.3")]
+ )
+ info3 = ServiceInfo(
+ type_3, registration_name3, 80, 0, 0, desc, "ash-4.local.", addresses=[socket.inet_aton("10.0.1.3")]
+ )
+ aiozc.zeroconf.registry.async_add(info)
+ aiozc.zeroconf.registry.async_add(info2)
+ aiozc.zeroconf.registry.async_add(info3)
+
+ query = r.DNSOutgoing(const._FLAGS_QR_QUERY, multicast=True)
+ question = r.DNSQuestion(info.type, const._TYPE_PTR, const._CLASS_IN)
+ query.add_question(question)
+
+ query2 = r.DNSOutgoing(const._FLAGS_QR_QUERY, multicast=True)
+ question2 = r.DNSQuestion(info2.type, const._TYPE_PTR, const._CLASS_IN)
+ query2.add_question(question2)
+
+ query3 = r.DNSOutgoing(const._FLAGS_QR_QUERY, multicast=True)
+ question3 = r.DNSQuestion(info3.type, const._TYPE_PTR, const._CLASS_IN)
+ query3.add_question(question3)
+
+ query4 = r.DNSOutgoing(const._FLAGS_QR_QUERY, multicast=True)
+ query4.add_question(question)
+ query4.add_question(question2)
+
+ zc = aiozc.zeroconf
+ protocol = zc.engine.protocols[0]
+
+ with unittest.mock.patch.object(aiozc.zeroconf, "async_send") as send_mock:
+ protocol.datagram_received(query.packets()[0], ('127.0.0.1', const._MDNS_PORT))
+ protocol.datagram_received(query2.packets()[0], ('127.0.0.1', const._MDNS_PORT))
+ protocol.datagram_received(query.packets()[0], ('127.0.0.1', const._MDNS_PORT))
+ await asyncio.sleep(0.7)
+
+ # Should aggregate into a single answer with up to a 500ms + 120ms delay
+ calls = send_mock.mock_calls
+ assert len(calls) == 1
+ outgoing = send_mock.call_args[0][0]
+ incoming = r.DNSIncoming(outgoing.packets()[0])
+ zc.handle_response(incoming)
+ assert info.dns_pointer() in incoming.answers
+ assert info2.dns_pointer() in incoming.answers
+ send_mock.reset_mock()
+
+ protocol.datagram_received(query3.packets()[0], ('127.0.0.1', const._MDNS_PORT))
+ await asyncio.sleep(0.3)
+
+ # Should send within 120ms since there are no other
+ # answers to aggregate with
+ calls = send_mock.mock_calls
+ assert len(calls) == 1
+ outgoing = send_mock.call_args[0][0]
+ incoming = r.DNSIncoming(outgoing.packets()[0])
+ zc.handle_response(incoming)
+ assert info3.dns_pointer() in incoming.answers
+ send_mock.reset_mock()
+
+ # Because the response was sent in the last second we need to make
+ # sure the next answer is delayed at least a second
+ aiozc.zeroconf.engine.protocols[0].datagram_received(
+ query4.packets()[0], ('127.0.0.1', const._MDNS_PORT)
+ )
+ await asyncio.sleep(0.5)
+
+ # After 0.5 seconds it should not have been sent
+ # Protect the network against excessive packet flooding
+ # https://datatracker.ietf.org/doc/html/rfc6762#section-14
+ calls = send_mock.mock_calls
+ assert len(calls) == 0
+ send_mock.reset_mock()
+
+ await asyncio.sleep(1.2)
+ calls = send_mock.mock_calls
+ assert len(calls) == 1
+ outgoing = send_mock.call_args[0][0]
+ incoming = r.DNSIncoming(outgoing.packets()[0])
+ assert info.dns_pointer() in incoming.answers
+
+ await aiozc.async_close()
+
+
+@pytest.mark.asyncio
+async def test_response_aggregation_timings_multiple(run_isolated):
+ """Verify multicast responses that are aggregated do not take longer than 620ms to send.
+
+ 620ms is the maximum random delay of 120ms and 500ms additional for aggregation."""
+ type_2 = "_mservice2._tcp.local."
+
+ aiozc = AsyncZeroconf(interfaces=['127.0.0.1'])
+ await aiozc.zeroconf.async_wait_for_start()
+
+ name = "xxxyyy"
+ registration_name2 = f"{name}.{type_2}"
+
+ desc = {'path': '/~paulsm/'}
+ info2 = ServiceInfo(
+ type_2, registration_name2, 80, 0, 0, desc, "ash-4.local.", addresses=[socket.inet_aton("10.0.1.3")]
+ )
+ aiozc.zeroconf.registry.async_add(info2)
+
+ query2 = r.DNSOutgoing(const._FLAGS_QR_QUERY, multicast=True)
+ question2 = r.DNSQuestion(info2.type, const._TYPE_PTR, const._CLASS_IN)
+ query2.add_question(question2)
+
+ zc = aiozc.zeroconf
+ protocol = zc.engine.protocols[0]
+
+ with unittest.mock.patch.object(aiozc.zeroconf, "async_send") as send_mock, unittest.mock.patch.object(
+ protocol, "suppress_duplicate_packet", return_value=False
+ ):
+ send_mock.reset_mock()
+ protocol.datagram_received(query2.packets()[0], ('127.0.0.1', const._MDNS_PORT))
+ await asyncio.sleep(0.2)
+ calls = send_mock.mock_calls
+ assert len(calls) == 1
+ outgoing = send_mock.call_args[0][0]
+ incoming = r.DNSIncoming(outgoing.packets()[0])
+ zc.handle_response(incoming)
+ assert info2.dns_pointer() in incoming.answers
+
+ send_mock.reset_mock()
+ protocol.datagram_received(query2.packets()[0], ('127.0.0.1', const._MDNS_PORT))
+ await asyncio.sleep(1.2)
+ calls = send_mock.mock_calls
+ assert len(calls) == 1
+ outgoing = send_mock.call_args[0][0]
+ incoming = r.DNSIncoming(outgoing.packets()[0])
+ zc.handle_response(incoming)
+ assert info2.dns_pointer() in incoming.answers
+
+ send_mock.reset_mock()
+ protocol.datagram_received(query2.packets()[0], ('127.0.0.1', const._MDNS_PORT))
+ protocol.datagram_received(query2.packets()[0], ('127.0.0.1', const._MDNS_PORT))
+ # The delay should increase with two packets and
+ # 900ms is beyond the maximum aggregation delay
+ # when there is no network protection delay
+ await asyncio.sleep(0.9)
+ calls = send_mock.mock_calls
+ assert len(calls) == 0
+
+ # 1000ms (1s network protection delays)
+ # - 900ms (already slept)
+ # + 120ms (maximum random delay)
+ # + 200ms (maximum protected aggregation delay)
+ # + 20ms (execution time)
+ await asyncio.sleep(millis_to_seconds(1000 - 900 + 120 + 200 + 20))
+ calls = send_mock.mock_calls
+ assert len(calls) == 1
+ outgoing = send_mock.call_args[0][0]
+ incoming = r.DNSIncoming(outgoing.packets()[0])
+ zc.handle_response(incoming)
+ assert info2.dns_pointer() in incoming.answers
+
+
+@pytest.mark.asyncio
+async def test_response_aggregation_random_delay():
+ """Verify the random delay for outgoing multicast will coalesce into a single group
+
+ When the random delay is shorter than the last outgoing group,
+ the groups should be combined.
+ """
+ type_ = "_mservice._tcp.local."
+ type_2 = "_mservice2._tcp.local."
+ type_3 = "_mservice3._tcp.local."
+ type_4 = "_mservice4._tcp.local."
+ type_5 = "_mservice5._tcp.local."
+
+ name = "xxxyyy"
+ registration_name = f"{name}.{type_}"
+ registration_name2 = f"{name}.{type_2}"
+ registration_name3 = f"{name}.{type_3}"
+ registration_name4 = f"{name}.{type_4}"
+ registration_name5 = f"{name}.{type_5}"
+
+ desc = {'path': '/~paulsm/'}
+ info = ServiceInfo(
+ type_, registration_name, 80, 0, 0, desc, "ash-1.local.", addresses=[socket.inet_aton("10.0.1.2")]
+ )
+ info2 = ServiceInfo(
+ type_2, registration_name2, 80, 0, 0, desc, "ash-2.local.", addresses=[socket.inet_aton("10.0.1.3")]
+ )
+ info3 = ServiceInfo(
+ type_3, registration_name3, 80, 0, 0, desc, "ash-3.local.", addresses=[socket.inet_aton("10.0.1.2")]
+ )
+ info4 = ServiceInfo(
+ type_4, registration_name4, 80, 0, 0, desc, "ash-4.local.", addresses=[socket.inet_aton("10.0.1.2")]
+ )
+ info5 = ServiceInfo(
+ type_5, registration_name5, 80, 0, 0, desc, "ash-5.local.", addresses=[socket.inet_aton("10.0.1.2")]
+ )
+ mocked_zc = unittest.mock.MagicMock()
+ outgoing_queue = MulticastOutgoingQueue(mocked_zc, 0, 500)
+
+ now = current_time_millis()
+ with unittest.mock.patch.object(_handlers, "_MULTICAST_DELAY_RANDOM_INTERVAL", (500, 600)):
+ outgoing_queue.async_add(now, {info.dns_pointer(): set()})
+
+ # The second group should always be coalesced into first group since it will always come before
+ with unittest.mock.patch.object(_handlers, "_MULTICAST_DELAY_RANDOM_INTERVAL", (300, 400)):
+ outgoing_queue.async_add(now, {info2.dns_pointer(): set()})
+
+ # The third group should always be coalesced into first group since it will always come before
+ with unittest.mock.patch.object(_handlers, "_MULTICAST_DELAY_RANDOM_INTERVAL", (100, 200)):
+ outgoing_queue.async_add(now, {info3.dns_pointer(): set(), info4.dns_pointer(): set()})
+
+ assert len(outgoing_queue.queue) == 1
+ assert info.dns_pointer() in outgoing_queue.queue[0].answers
+ assert info2.dns_pointer() in outgoing_queue.queue[0].answers
+ assert info3.dns_pointer() in outgoing_queue.queue[0].answers
+ assert info4.dns_pointer() in outgoing_queue.queue[0].answers
+
+ # The forth group should not be coalesced because its scheduled after the last group in the queue
+ with unittest.mock.patch.object(_handlers, "_MULTICAST_DELAY_RANDOM_INTERVAL", (700, 800)):
+ outgoing_queue.async_add(now, {info5.dns_pointer(): set()})
+
+ assert len(outgoing_queue.queue) == 2
+ assert info.dns_pointer() not in outgoing_queue.queue[1].answers
+ assert info2.dns_pointer() not in outgoing_queue.queue[1].answers
+ assert info3.dns_pointer() not in outgoing_queue.queue[1].answers
+ assert info4.dns_pointer() not in outgoing_queue.queue[1].answers
+ assert info5.dns_pointer() in outgoing_queue.queue[1].answers
+
+
+@pytest.mark.asyncio
+async def test_future_answers_are_removed_on_send():
+ """Verify any future answers scheduled to be sent are removed when we send."""
+ type_ = "_mservice._tcp.local."
+ type_2 = "_mservice2._tcp.local."
+ name = "xxxyyy"
+ registration_name = f"{name}.{type_}"
+ registration_name2 = f"{name}.{type_2}"
+
+ desc = {'path': '/~paulsm/'}
+ info = ServiceInfo(
+ type_, registration_name, 80, 0, 0, desc, "ash-1.local.", addresses=[socket.inet_aton("10.0.1.2")]
+ )
+ info2 = ServiceInfo(
+ type_2, registration_name2, 80, 0, 0, desc, "ash-2.local.", addresses=[socket.inet_aton("10.0.1.3")]
+ )
+ mocked_zc = unittest.mock.MagicMock()
+ outgoing_queue = MulticastOutgoingQueue(mocked_zc, 0, 0)
+
+ now = current_time_millis()
+ with unittest.mock.patch.object(_handlers, "_MULTICAST_DELAY_RANDOM_INTERVAL", (1, 1)):
+ outgoing_queue.async_add(now, {info.dns_pointer(): set()})
+
+ assert len(outgoing_queue.queue) == 1
+
+ with unittest.mock.patch.object(_handlers, "_MULTICAST_DELAY_RANDOM_INTERVAL", (2, 2)):
+ outgoing_queue.async_add(now, {info.dns_pointer(): set()})
+
+ assert len(outgoing_queue.queue) == 2
+
+ with unittest.mock.patch.object(_handlers, "_MULTICAST_DELAY_RANDOM_INTERVAL", (1000, 1000)):
+ outgoing_queue.async_add(now, {info2.dns_pointer(): set()})
+ outgoing_queue.async_add(now, {info.dns_pointer(): set()})
+
+ assert len(outgoing_queue.queue) == 3
+
+ await asyncio.sleep(0.1)
+ outgoing_queue.async_ready()
+
+ assert len(outgoing_queue.queue) == 1
+ # The answer should get removed because we just sent it
+ assert info.dns_pointer() not in outgoing_queue.queue[0].answers
+
+ # But the one we have not sent yet shoudl still go out later
+ assert info2.dns_pointer() in outgoing_queue.queue[0].answers
diff --git a/tests/test_history.py b/tests/test_history.py
new file mode 100644
index 00000000..9da6b567
--- /dev/null
+++ b/tests/test_history.py
@@ -0,0 +1,68 @@
+#!/usr/bin/env python
+
+
+"""Unit tests for _history.py."""
+
+from zeroconf._history import QuestionHistory
+import zeroconf as r
+import zeroconf.const as const
+
+
+def test_question_suppression():
+ history = QuestionHistory()
+
+ question = r.DNSQuestion("_hap._tcp._local.", const._TYPE_PTR, const._CLASS_IN)
+ now = r.current_time_millis()
+ other_known_answers = {
+ r.DNSPointer(
+ "_hap._tcp.local.", const._TYPE_PTR, const._CLASS_IN, 10000, 'known-to-other._hap._tcp.local.'
+ )
+ }
+ our_known_answers = {
+ r.DNSPointer(
+ "_hap._tcp.local.", const._TYPE_PTR, const._CLASS_IN, 10000, 'known-to-us._hap._tcp.local.'
+ )
+ }
+
+ history.add_question_at_time(question, now, other_known_answers)
+
+ # Verify the question is suppressed if the known answers are the same
+ assert history.suppresses(question, now, other_known_answers)
+
+ # Verify the question is suppressed if we know the answer to all the known answers
+ assert history.suppresses(question, now, other_known_answers | our_known_answers)
+
+ # Verify the question is not suppressed if our known answers do no include the ones in the last question
+ assert not history.suppresses(question, now, set())
+
+ # Verify the question is not suppressed if our known answers do no include the ones in the last question
+ assert not history.suppresses(question, now, our_known_answers)
+
+ # Verify the question is no longer suppressed after 1s
+ assert not history.suppresses(question, now + 1000, other_known_answers)
+
+
+def test_question_expire():
+ history = QuestionHistory()
+
+ question = r.DNSQuestion("_hap._tcp._local.", const._TYPE_PTR, const._CLASS_IN)
+ now = r.current_time_millis()
+ other_known_answers = {
+ r.DNSPointer(
+ "_hap._tcp.local.", const._TYPE_PTR, const._CLASS_IN, 10000, 'known-to-other._hap._tcp.local.'
+ )
+ }
+ history.add_question_at_time(question, now, other_known_answers)
+
+ # Verify the question is suppressed if the known answers are the same
+ assert history.suppresses(question, now, other_known_answers)
+
+ history.async_expire(now)
+
+ # Verify the question is suppressed if the known answers are the same since the cache hasn't expired
+ assert history.suppresses(question, now, other_known_answers)
+
+ history.async_expire(now + 1000)
+
+ # Verify the question not longer suppressed since the cache has expired
+ assert not history.suppresses(question, now, other_known_answers)
diff --git a/tests/test_init.py b/tests/test_init.py
new file mode 100644
index 00000000..1d1f7086
--- /dev/null
+++ b/tests/test_init.py
@@ -0,0 +1,188 @@
+#!/usr/bin/env python
+
+
+""" Unit tests for zeroconf.py """
+
+import logging
+import socket
+import time
+import unittest
+import unittest.mock
+from unittest.mock import patch
+
+import zeroconf as r
+from zeroconf import ServiceInfo, Zeroconf, const
+
+from . import _inject_responses
+
+log = logging.getLogger('zeroconf')
+original_logging_level = logging.NOTSET
+
+
+def setup_module():
+ global original_logging_level
+ original_logging_level = log.level
+ log.setLevel(logging.DEBUG)
+
+
+def teardown_module():
+ if original_logging_level != logging.NOTSET:
+ log.setLevel(original_logging_level)
+
+
+class Names(unittest.TestCase):
+ def test_long_name(self):
+ generated = r.DNSOutgoing(const._FLAGS_QR_RESPONSE)
+ question = r.DNSQuestion(
+ "this.is.a.very.long.name.with.lots.of.parts.in.it.local.", const._TYPE_SRV, const._CLASS_IN
+ )
+ generated.add_question(question)
+ r.DNSIncoming(generated.packets()[0])
+
+ def test_exceedingly_long_name(self):
+ generated = r.DNSOutgoing(const._FLAGS_QR_RESPONSE)
+ name = "%slocal." % ("part." * 1000)
+ question = r.DNSQuestion(name, const._TYPE_SRV, const._CLASS_IN)
+ generated.add_question(question)
+ r.DNSIncoming(generated.packets()[0])
+
+ def test_extra_exceedingly_long_name(self):
+ generated = r.DNSOutgoing(const._FLAGS_QR_RESPONSE)
+ name = "%slocal." % ("part." * 4000)
+ question = r.DNSQuestion(name, const._TYPE_SRV, const._CLASS_IN)
+ generated.add_question(question)
+ r.DNSIncoming(generated.packets()[0])
+
+ def test_exceedingly_long_name_part(self):
+ name = "%s.local." % ("a" * 1000)
+ generated = r.DNSOutgoing(const._FLAGS_QR_RESPONSE)
+ question = r.DNSQuestion(name, const._TYPE_SRV, const._CLASS_IN)
+ generated.add_question(question)
+ self.assertRaises(r.NamePartTooLongException, generated.packets)
+
+ def test_same_name(self):
+ name = "paired.local."
+ generated = r.DNSOutgoing(const._FLAGS_QR_RESPONSE)
+ question = r.DNSQuestion(name, const._TYPE_SRV, const._CLASS_IN)
+ generated.add_question(question)
+ generated.add_question(question)
+ r.DNSIncoming(generated.packets()[0])
+
+ def test_verify_name_change_with_lots_of_names(self):
+ # instantiate a zeroconf instance
+ zc = Zeroconf(interfaces=['127.0.0.1'])
+
+ # create a bunch of servers
+ type_ = "_my-service._tcp.local."
+ name = 'a wonderful service'
+ server_count = 300
+ self.generate_many_hosts(zc, type_, name, server_count)
+
+ # verify that name changing works
+ self.verify_name_change(zc, type_, name, server_count)
+
+ zc.close()
+
+ def test_large_packet_exception_log_handling(self):
+ """Verify we downgrade debug after warning."""
+
+ # instantiate a zeroconf instance
+ zc = Zeroconf(interfaces=['127.0.0.1'])
+
+ with patch('zeroconf._logger.log.warning') as mocked_log_warn, patch(
+ 'zeroconf._logger.log.debug'
+ ) as mocked_log_debug:
+ # now that we have a long packet in our possession, let's verify the
+ # exception handling.
+ out = r.DNSOutgoing(const._FLAGS_QR_RESPONSE | const._FLAGS_AA)
+ out.data.append(b'\0' * 10000)
+
+ # mock the zeroconf logger and check for the correct logging backoff
+ call_counts = mocked_log_warn.call_count, mocked_log_debug.call_count
+ # try to send an oversized packet
+ zc.send(out)
+ assert mocked_log_warn.call_count == call_counts[0]
+ zc.send(out)
+ assert mocked_log_warn.call_count == call_counts[0]
+
+ # mock the zeroconf logger and check for the correct logging backoff
+ call_counts = mocked_log_warn.call_count, mocked_log_debug.call_count
+ # force receive on oversized packet
+ zc.send(out, const._MDNS_ADDR, const._MDNS_PORT)
+ zc.send(out, const._MDNS_ADDR, const._MDNS_PORT)
+ time.sleep(0.3)
+ r.log.debug(
+ 'warn %d debug %d was %s',
+ mocked_log_warn.call_count,
+ mocked_log_debug.call_count,
+ call_counts,
+ )
+ assert mocked_log_debug.call_count > call_counts[0]
+
+ # close our zeroconf which will close the sockets
+ zc.close()
+
+ def verify_name_change(self, zc, type_, name, number_hosts):
+ desc = {'path': '/~paulsm/'}
+ info_service = ServiceInfo(
+ type_,
+ f'{name}.{type_}',
+ 80,
+ 0,
+ 0,
+ desc,
+ "ash-2.local.",
+ addresses=[socket.inet_aton("10.0.1.2")],
+ )
+
+ # verify name conflict
+ self.assertRaises(r.NonUniqueNameException, zc.register_service, info_service)
+
+ # verify no name conflict https://tools.ietf.org/html/rfc6762#section-6.6
+ zc.register_service(info_service, cooperating_responders=True)
+
+ # Create a new object since allow_name_change will mutate the
+ # original object and then we will have the wrong service
+ # in the registry
+ info_service2 = ServiceInfo(
+ type_,
+ f'{name}.{type_}',
+ 80,
+ 0,
+ 0,
+ desc,
+ "ash-2.local.",
+ addresses=[socket.inet_aton("10.0.1.2")],
+ )
+ zc.register_service(info_service2, allow_name_change=True)
+ assert info_service2.name.split('.')[0] == '%s-%d' % (name, number_hosts + 1)
+
+ def generate_many_hosts(self, zc, type_, name, number_hosts):
+ block_size = 25
+ number_hosts = int((number_hosts - 1) / block_size + 1) * block_size
+ out = r.DNSOutgoing(const._FLAGS_QR_RESPONSE | const._FLAGS_AA)
+ for i in range(1, number_hosts + 1):
+ next_name = name if i == 1 else '%s-%d' % (name, i)
+ self.generate_host(out, next_name, type_)
+
+ _inject_responses(zc, [r.DNSIncoming(packet) for packet in out.packets()])
+
+ @staticmethod
+ def generate_host(out, host_name, type_):
+ name = '.'.join((host_name, type_))
+ out.add_answer_at_time(
+ r.DNSPointer(type_, const._TYPE_PTR, const._CLASS_IN, const._DNS_OTHER_TTL, name), 0
+ )
+ out.add_answer_at_time(
+ r.DNSService(
+ type_,
+ const._TYPE_SRV,
+ const._CLASS_IN | const._CLASS_UNIQUE,
+ const._DNS_HOST_TTL,
+ 0,
+ 0,
+ 80,
+ name,
+ ),
+ 0,
+ )
diff --git a/tests/test_logger.py b/tests/test_logger.py
new file mode 100644
index 00000000..2d8bbb08
--- /dev/null
+++ b/tests/test_logger.py
@@ -0,0 +1,84 @@
+#!/usr/bin/env python
+
+
+"""Unit tests for logger.py."""
+
+import logging
+from unittest.mock import patch
+from zeroconf._logger import QuietLogger, set_logger_level_if_unset
+
+
+def test_loading_logger():
+ """Test loading logger does not change level unless it is unset."""
+ log = logging.getLogger('zeroconf')
+ log.setLevel(logging.CRITICAL)
+ set_logger_level_if_unset()
+ log = logging.getLogger('zeroconf')
+ assert log.level == logging.CRITICAL
+
+ log = logging.getLogger('zeroconf')
+ log.setLevel(logging.NOTSET)
+ set_logger_level_if_unset()
+ log = logging.getLogger('zeroconf')
+ assert log.level == logging.WARNING
+
+
+def test_log_warning_once():
+ """Test we only log with warning level once."""
+ quiet_logger = QuietLogger()
+ with patch("zeroconf._logger.log.warning") as mock_log_warning, patch(
+ "zeroconf._logger.log.debug"
+ ) as mock_log_debug:
+ quiet_logger.log_warning_once("the warning")
+
+ assert mock_log_warning.mock_calls
+ assert not mock_log_debug.mock_calls
+
+ with patch("zeroconf._logger.log.warning") as mock_log_warning, patch(
+ "zeroconf._logger.log.debug"
+ ) as mock_log_debug:
+ quiet_logger.log_warning_once("the warning")
+
+ assert not mock_log_warning.mock_calls
+ assert mock_log_debug.mock_calls
+
+
+def test_log_exception_warning():
+ """Test we only log with warning level once."""
+ quiet_logger = QuietLogger()
+ with patch("zeroconf._logger.log.warning") as mock_log_warning, patch(
+ "zeroconf._logger.log.debug"
+ ) as mock_log_debug:
+ quiet_logger.log_exception_warning("the exception warning")
+
+ assert mock_log_warning.mock_calls
+ assert not mock_log_debug.mock_calls
+
+ with patch("zeroconf._logger.log.warning") as mock_log_warning, patch(
+ "zeroconf._logger.log.debug"
+ ) as mock_log_debug:
+ quiet_logger.log_exception_warning("the exception warning")
+
+ assert not mock_log_warning.mock_calls
+ assert mock_log_debug.mock_calls
+
+
+def test_log_exception_once():
+ """Test we only log with warning level once."""
+ quiet_logger = QuietLogger()
+ exc = Exception()
+ with patch("zeroconf._logger.log.warning") as mock_log_warning, patch(
+ "zeroconf._logger.log.debug"
+ ) as mock_log_debug:
+ quiet_logger.log_exception_once(exc, "the exceptional exception warning")
+
+ assert mock_log_warning.mock_calls
+ assert not mock_log_debug.mock_calls
+
+ with patch("zeroconf._logger.log.warning") as mock_log_warning, patch(
+ "zeroconf._logger.log.debug"
+ ) as mock_log_debug:
+ quiet_logger.log_exception_once(exc, "the exceptional exception warning")
+
+ assert not mock_log_warning.mock_calls
+ assert mock_log_debug.mock_calls
diff --git a/tests/test_protocol.py b/tests/test_protocol.py
new file mode 100644
index 00000000..55dbbe4d
--- /dev/null
+++ b/tests/test_protocol.py
@@ -0,0 +1,1026 @@
+#!/usr/bin/env python
+
+
+""" Unit tests for zeroconf._protocol """
+
+import copy
+import logging
+import os
+import socket
+import struct
+import unittest
+import unittest.mock
+from typing import cast
+
+import zeroconf as r
+from zeroconf import DNSIncoming, const, current_time_millis
+from zeroconf import (
+ DNSHinfo,
+ DNSText,
+)
+
+from . import has_working_ipv6
+
+log = logging.getLogger('zeroconf')
+original_logging_level = logging.NOTSET
+
+
+def setup_module():
+ global original_logging_level
+ original_logging_level = log.level
+ log.setLevel(logging.DEBUG)
+
+
+def teardown_module():
+ if original_logging_level != logging.NOTSET:
+ log.setLevel(original_logging_level)
+
+
+class PacketGeneration(unittest.TestCase):
+ def test_parse_own_packet_simple(self):
+ generated = r.DNSOutgoing(0)
+ r.DNSIncoming(generated.packets()[0])
+
+ def test_parse_own_packet_simple_unicast(self):
+ generated = r.DNSOutgoing(0, False)
+ r.DNSIncoming(generated.packets()[0])
+
+ def test_parse_own_packet_flags(self):
+ generated = r.DNSOutgoing(const._FLAGS_QR_QUERY)
+ r.DNSIncoming(generated.packets()[0])
+
+ def test_parse_own_packet_question(self):
+ generated = r.DNSOutgoing(const._FLAGS_QR_QUERY)
+ generated.add_question(r.DNSQuestion("testname.local.", const._TYPE_SRV, const._CLASS_IN))
+ r.DNSIncoming(generated.packets()[0])
+
+ def test_parse_own_packet_nsec(self):
+ answer = r.DNSNsec(
+ 'eufy HomeBase2-2464._hap._tcp.local.',
+ const._TYPE_NSEC,
+ const._CLASS_IN | const._CLASS_UNIQUE,
+ const._DNS_OTHER_TTL,
+ 'eufy HomeBase2-2464._hap._tcp.local.',
+ [const._TYPE_TXT, const._TYPE_SRV],
+ )
+
+ generated = r.DNSOutgoing(const._FLAGS_QR_RESPONSE)
+ generated.add_answer_at_time(answer, 0)
+ parsed = r.DNSIncoming(generated.packets()[0])
+ assert answer in parsed.answers
+
+ # Types > 255 should be ignored
+ answer_invalid_types = r.DNSNsec(
+ 'eufy HomeBase2-2464._hap._tcp.local.',
+ const._TYPE_NSEC,
+ const._CLASS_IN | const._CLASS_UNIQUE,
+ const._DNS_OTHER_TTL,
+ 'eufy HomeBase2-2464._hap._tcp.local.',
+ [const._TYPE_TXT, const._TYPE_SRV, 1000],
+ )
+ generated = r.DNSOutgoing(const._FLAGS_QR_RESPONSE)
+ generated.add_answer_at_time(answer_invalid_types, 0)
+ parsed = r.DNSIncoming(generated.packets()[0])
+ assert answer in parsed.answers
+
+ def test_parse_own_packet_response(self):
+ generated = r.DNSOutgoing(const._FLAGS_QR_RESPONSE)
+ generated.add_answer_at_time(
+ r.DNSService(
+ "æøå.local.",
+ const._TYPE_SRV,
+ const._CLASS_IN | const._CLASS_UNIQUE,
+ const._DNS_HOST_TTL,
+ 0,
+ 0,
+ 80,
+ "foo.local.",
+ ),
+ 0,
+ )
+ parsed = r.DNSIncoming(generated.packets()[0])
+ assert len(generated.answers) == 1
+ assert len(generated.answers) == len(parsed.answers)
+
+ def test_adding_empty_answer(self):
+ generated = r.DNSOutgoing(const._FLAGS_QR_RESPONSE)
+ generated.add_answer_at_time(
+ None,
+ 0,
+ )
+ generated.add_answer_at_time(
+ r.DNSService(
+ "æøå.local.",
+ const._TYPE_SRV,
+ const._CLASS_IN | const._CLASS_UNIQUE,
+ const._DNS_HOST_TTL,
+ 0,
+ 0,
+ 80,
+ "foo.local.",
+ ),
+ 0,
+ )
+ parsed = r.DNSIncoming(generated.packets()[0])
+ assert len(generated.answers) == 1
+ assert len(generated.answers) == len(parsed.answers)
+
+ def test_adding_expired_answer(self):
+ generated = r.DNSOutgoing(const._FLAGS_QR_RESPONSE)
+ generated.add_answer_at_time(
+ r.DNSService(
+ "æøå.local.",
+ const._TYPE_SRV,
+ const._CLASS_IN | const._CLASS_UNIQUE,
+ const._DNS_HOST_TTL,
+ 0,
+ 0,
+ 80,
+ "foo.local.",
+ ),
+ current_time_millis() + 1000000,
+ )
+ parsed = r.DNSIncoming(generated.packets()[0])
+ assert len(generated.answers) == 0
+ assert len(generated.answers) == len(parsed.answers)
+
+ def test_match_question(self):
+ generated = r.DNSOutgoing(const._FLAGS_QR_QUERY)
+ question = r.DNSQuestion("testname.local.", const._TYPE_SRV, const._CLASS_IN)
+ generated.add_question(question)
+ parsed = r.DNSIncoming(generated.packets()[0])
+ assert len(generated.questions) == 1
+ assert len(generated.questions) == len(parsed.questions)
+ assert question == parsed.questions[0]
+
+ def test_suppress_answer(self):
+ query_generated = r.DNSOutgoing(const._FLAGS_QR_QUERY)
+ question = r.DNSQuestion("testname.local.", const._TYPE_SRV, const._CLASS_IN)
+ query_generated.add_question(question)
+ answer1 = r.DNSService(
+ "testname1.local.",
+ const._TYPE_SRV,
+ const._CLASS_IN | const._CLASS_UNIQUE,
+ const._DNS_HOST_TTL,
+ 0,
+ 0,
+ 80,
+ "foo.local.",
+ )
+ staleanswer2 = r.DNSService(
+ "testname2.local.",
+ const._TYPE_SRV,
+ const._CLASS_IN | const._CLASS_UNIQUE,
+ const._DNS_HOST_TTL / 2,
+ 0,
+ 0,
+ 80,
+ "foo.local.",
+ )
+ answer2 = r.DNSService(
+ "testname2.local.",
+ const._TYPE_SRV,
+ const._CLASS_IN | const._CLASS_UNIQUE,
+ const._DNS_HOST_TTL,
+ 0,
+ 0,
+ 80,
+ "foo.local.",
+ )
+ query_generated.add_answer_at_time(answer1, 0)
+ query_generated.add_answer_at_time(staleanswer2, 0)
+ query = r.DNSIncoming(query_generated.packets()[0])
+
+ # Should be suppressed
+ response = r.DNSOutgoing(const._FLAGS_QR_RESPONSE)
+ response.add_answer(query, answer1)
+ assert len(response.answers) == 0
+
+ # Should not be suppressed, TTL in query is too short
+ response.add_answer(query, answer2)
+ assert len(response.answers) == 1
+
+ # Should not be suppressed, name is different
+ tmp = copy.copy(answer1)
+ tmp.key = "testname3.local."
+ tmp.name = "testname3.local."
+ response.add_answer(query, tmp)
+ assert len(response.answers) == 2
+
+ # Should not be suppressed, type is different
+ tmp = copy.copy(answer1)
+ tmp.type = const._TYPE_A
+ response.add_answer(query, tmp)
+ assert len(response.answers) == 3
+
+ # Should not be suppressed, class is different
+ tmp = copy.copy(answer1)
+ tmp.class_ = const._CLASS_NONE
+ response.add_answer(query, tmp)
+ assert len(response.answers) == 4
+
+ # ::TODO:: could add additional tests for DNSAddress, DNSHinfo, DNSPointer, DNSText, DNSService
+
+ def test_dns_hinfo(self):
+ generated = r.DNSOutgoing(0)
+ generated.add_additional_answer(DNSHinfo('irrelevant', const._TYPE_HINFO, 0, 0, 'cpu', 'os'))
+ parsed = r.DNSIncoming(generated.packets()[0])
+ answer = cast(r.DNSHinfo, parsed.answers[0])
+ assert answer.cpu == 'cpu'
+ assert answer.os == 'os'
+
+ generated = r.DNSOutgoing(0)
+ generated.add_additional_answer(DNSHinfo('irrelevant', const._TYPE_HINFO, 0, 0, 'cpu', 'x' * 257))
+ self.assertRaises(r.NamePartTooLongException, generated.packets)
+
+ def test_many_questions(self):
+ """Test many questions get seperated into multiple packets."""
+ generated = r.DNSOutgoing(const._FLAGS_QR_QUERY)
+ questions = []
+ for i in range(100):
+ question = r.DNSQuestion(f"testname{i}.local.", const._TYPE_SRV, const._CLASS_IN)
+ generated.add_question(question)
+ questions.append(question)
+ assert len(generated.questions) == 100
+
+ packets = generated.packets()
+ assert len(packets) == 2
+ assert len(packets[0]) < const._MAX_MSG_TYPICAL
+ assert len(packets[1]) < const._MAX_MSG_TYPICAL
+
+ parsed1 = r.DNSIncoming(packets[0])
+ assert len(parsed1.questions) == 85
+ parsed2 = r.DNSIncoming(packets[1])
+ assert len(parsed2.questions) == 15
+
+ def test_many_questions_with_many_known_answers(self):
+ """Test many questions and known answers get seperated into multiple packets."""
+ generated = r.DNSOutgoing(const._FLAGS_QR_QUERY)
+ questions = []
+ for _ in range(30):
+ question = r.DNSQuestion(f"_hap._tcp.local.", const._TYPE_PTR, const._CLASS_IN)
+ generated.add_question(question)
+ questions.append(question)
+ assert len(generated.questions) == 30
+ now = current_time_millis()
+ for _ in range(200):
+ known_answer = r.DNSPointer(
+ "myservice{i}_tcp._tcp.local.",
+ const._TYPE_PTR,
+ const._CLASS_IN | const._CLASS_UNIQUE,
+ const._DNS_OTHER_TTL,
+ '123.local.',
+ )
+ generated.add_answer_at_time(known_answer, now)
+ packets = generated.packets()
+ assert len(packets) == 3
+ assert len(packets[0]) <= const._MAX_MSG_TYPICAL
+ assert len(packets[1]) <= const._MAX_MSG_TYPICAL
+ assert len(packets[2]) <= const._MAX_MSG_TYPICAL
+
+ parsed1 = r.DNSIncoming(packets[0])
+ assert len(parsed1.questions) == 30
+ assert len(parsed1.answers) == 88
+ assert parsed1.truncated
+ parsed2 = r.DNSIncoming(packets[1])
+ assert len(parsed2.questions) == 0
+ assert len(parsed2.answers) == 101
+ assert parsed2.truncated
+ parsed3 = r.DNSIncoming(packets[2])
+ assert len(parsed3.questions) == 0
+ assert len(parsed3.answers) == 11
+ assert not parsed3.truncated
+
+ def test_massive_probe_packet_split(self):
+ """Test probe with many authorative answers."""
+ generated = r.DNSOutgoing(const._FLAGS_QR_QUERY | const._FLAGS_AA)
+ questions = []
+ for _ in range(30):
+ question = r.DNSQuestion(
+ f"_hap._tcp.local.", const._TYPE_PTR, const._CLASS_IN | const._CLASS_UNIQUE
+ )
+ generated.add_question(question)
+ questions.append(question)
+ assert len(generated.questions) == 30
+ now = current_time_millis()
+ for _ in range(200):
+ authorative_answer = r.DNSPointer(
+ "myservice{i}_tcp._tcp.local.",
+ const._TYPE_PTR,
+ const._CLASS_IN | const._CLASS_UNIQUE,
+ const._DNS_OTHER_TTL,
+ '123.local.',
+ )
+ generated.add_authorative_answer(authorative_answer)
+ packets = generated.packets()
+ assert len(packets) == 3
+ assert len(packets[0]) <= const._MAX_MSG_TYPICAL
+ assert len(packets[1]) <= const._MAX_MSG_TYPICAL
+ assert len(packets[2]) <= const._MAX_MSG_TYPICAL
+
+ parsed1 = r.DNSIncoming(packets[0])
+ assert parsed1.questions[0].unicast is True
+ assert len(parsed1.questions) == 30
+ assert parsed1.num_authorities == 88
+ assert parsed1.truncated
+ parsed2 = r.DNSIncoming(packets[1])
+ assert len(parsed2.questions) == 0
+ assert parsed2.num_authorities == 101
+ assert parsed2.truncated
+ parsed3 = r.DNSIncoming(packets[2])
+ assert len(parsed3.questions) == 0
+ assert parsed3.num_authorities == 11
+ assert not parsed3.truncated
+
+ def test_only_one_answer_can_by_large(self):
+ """Test that only the first answer in each packet can be large.
+
+ https://datatracker.ietf.org/doc/html/rfc6762#section-17
+ """
+ generated = r.DNSOutgoing(const._FLAGS_QR_RESPONSE)
+ query = r.DNSIncoming(r.DNSOutgoing(const._FLAGS_QR_QUERY).packets()[0])
+ for i in range(3):
+ generated.add_answer(
+ query,
+ r.DNSText(
+ "zoom._hap._tcp.local.",
+ const._TYPE_TXT,
+ const._CLASS_IN | const._CLASS_UNIQUE,
+ 1200,
+ b'\x04ff=0\x04ci=2\x04sf=0\x0bsh=6fLM5A==' * 100,
+ ),
+ )
+ generated.add_answer(
+ query,
+ r.DNSService(
+ "testname1.local.",
+ const._TYPE_SRV,
+ const._CLASS_IN | const._CLASS_UNIQUE,
+ const._DNS_HOST_TTL,
+ 0,
+ 0,
+ 80,
+ "foo.local.",
+ ),
+ )
+ assert len(generated.answers) == 4
+
+ packets = generated.packets()
+ assert len(packets) == 4
+ assert len(packets[0]) <= const._MAX_MSG_ABSOLUTE
+ assert len(packets[0]) > const._MAX_MSG_TYPICAL
+
+ assert len(packets[1]) <= const._MAX_MSG_ABSOLUTE
+ assert len(packets[1]) > const._MAX_MSG_TYPICAL
+
+ assert len(packets[2]) <= const._MAX_MSG_ABSOLUTE
+ assert len(packets[2]) > const._MAX_MSG_TYPICAL
+
+ assert len(packets[3]) <= const._MAX_MSG_TYPICAL
+
+ for packet in packets:
+ parsed = r.DNSIncoming(packet)
+ assert len(parsed.answers) == 1
+
+ def test_questions_do_not_end_up_every_packet(self):
+ """Test that questions are not sent again when multiple packets are needed.
+
+ https://datatracker.ietf.org/doc/html/rfc6762#section-7.2
+ Sometimes a Multicast DNS querier will already have too many answers
+ to fit in the Known-Answer Section of its query packets.... It MUST
+ immediately follow the packet with another query packet containing no
+ questions and as many more Known-Answer records as will fit.
+ """
+
+ generated = r.DNSOutgoing(const._FLAGS_QR_QUERY)
+ for i in range(35):
+ question = r.DNSQuestion(f"testname{i}.local.", const._TYPE_SRV, const._CLASS_IN)
+ generated.add_question(question)
+ answer = r.DNSService(
+ f"testname{i}.local.",
+ const._TYPE_SRV,
+ const._CLASS_IN | const._CLASS_UNIQUE,
+ const._DNS_HOST_TTL,
+ 0,
+ 0,
+ 80,
+ f"foo{i}.local.",
+ )
+ generated.add_answer_at_time(answer, 0)
+
+ assert len(generated.questions) == 35
+ assert len(generated.answers) == 35
+
+ packets = generated.packets()
+ assert len(packets) == 2
+ assert len(packets[0]) <= const._MAX_MSG_TYPICAL
+ assert len(packets[1]) <= const._MAX_MSG_TYPICAL
+
+ parsed1 = r.DNSIncoming(packets[0])
+ assert len(parsed1.questions) == 35
+ assert len(parsed1.answers) == 33
+
+ parsed2 = r.DNSIncoming(packets[1])
+ assert len(parsed2.questions) == 0
+ assert len(parsed2.answers) == 2
+
+
+class PacketForm(unittest.TestCase):
+ def test_transaction_id(self):
+ """ID must be zero in a DNS-SD packet"""
+ generated = r.DNSOutgoing(const._FLAGS_QR_QUERY)
+ bytes = generated.packets()[0]
+ id = bytes[0] << 8 | bytes[1]
+ assert id == 0
+
+ def test_setting_id(self):
+ """Test setting id in the constructor"""
+ generated = r.DNSOutgoing(const._FLAGS_QR_QUERY, id_=4444)
+ assert generated.id == 4444
+
+ def test_query_header_bits(self):
+ generated = r.DNSOutgoing(const._FLAGS_QR_QUERY)
+ bytes = generated.packets()[0]
+ flags = bytes[2] << 8 | bytes[3]
+ assert flags == 0x0
+
+ def test_response_header_bits(self):
+ generated = r.DNSOutgoing(const._FLAGS_QR_RESPONSE)
+ bytes = generated.packets()[0]
+ flags = bytes[2] << 8 | bytes[3]
+ assert flags == 0x8000
+
+ def test_numbers(self):
+ generated = r.DNSOutgoing(const._FLAGS_QR_RESPONSE)
+ bytes = generated.packets()[0]
+ (num_questions, num_answers, num_authorities, num_additionals) = struct.unpack('!4H', bytes[4:12])
+ assert num_questions == 0
+ assert num_answers == 0
+ assert num_authorities == 0
+ assert num_additionals == 0
+
+ def test_numbers_questions(self):
+ generated = r.DNSOutgoing(const._FLAGS_QR_RESPONSE)
+ question = r.DNSQuestion("testname.local.", const._TYPE_SRV, const._CLASS_IN)
+ for i in range(10):
+ generated.add_question(question)
+ bytes = generated.packets()[0]
+ (num_questions, num_answers, num_authorities, num_additionals) = struct.unpack('!4H', bytes[4:12])
+ assert num_questions == 10
+ assert num_answers == 0
+ assert num_authorities == 0
+ assert num_additionals == 0
+
+
+class TestDnsIncoming(unittest.TestCase):
+ def test_incoming_exception_handling(self):
+ generated = r.DNSOutgoing(0)
+ packet = generated.packets()[0]
+ packet = packet[:8] + b'deadbeef' + packet[8:]
+ parsed = r.DNSIncoming(packet)
+ parsed = r.DNSIncoming(packet)
+ assert parsed.valid is False
+
+ def test_incoming_unknown_type(self):
+ generated = r.DNSOutgoing(0)
+ answer = r.DNSAddress('a', const._TYPE_SOA, const._CLASS_IN, 1, b'a')
+ generated.add_additional_answer(answer)
+ packet = generated.packets()[0]
+ parsed = r.DNSIncoming(packet)
+ assert len(parsed.answers) == 0
+ assert parsed.is_query() != parsed.is_response()
+
+ def test_incoming_circular_reference(self):
+ assert not r.DNSIncoming(
+ bytes.fromhex(
+ '01005e0000fb542a1bf0577608004500006897934000ff11d81bc0a86a31e00000fb'
+ '14e914e90054f9b2000084000000000100000000095f7365727669636573075f646e'
+ '732d7364045f756470056c6f63616c00000c0001000011940018105f73706f746966'
+ '792d636f6e6e656374045f746370c023'
+ )
+ ).valid
+
+ @unittest.skipIf(not has_working_ipv6(), 'Requires IPv6')
+ @unittest.skipIf(os.environ.get('SKIP_IPV6'), 'IPv6 tests disabled')
+ def test_incoming_ipv6(self):
+ addr = "2606:2800:220:1:248:1893:25c8:1946" # example.com
+ packed = socket.inet_pton(socket.AF_INET6, addr)
+ generated = r.DNSOutgoing(0)
+ answer = r.DNSAddress('domain', const._TYPE_AAAA, const._CLASS_IN | const._CLASS_UNIQUE, 1, packed)
+ generated.add_additional_answer(answer)
+ packet = generated.packets()[0]
+ parsed = r.DNSIncoming(packet)
+ record = parsed.answers[0]
+ assert isinstance(record, r.DNSAddress)
+ assert record.address == packed
+
+
+def test_dns_compression_rollback_for_corruption():
+ """Verify rolling back does not lead to dns compression corruption."""
+ out = r.DNSOutgoing(const._FLAGS_QR_RESPONSE | const._FLAGS_AA)
+ address = socket.inet_pton(socket.AF_INET, "192.168.208.5")
+
+ additionals = [
+ {
+ "name": "HASS Bridge ZJWH FF5137._hap._tcp.local.",
+ "address": address,
+ "port": 51832,
+ "text": b"\x13md=HASS Bridge"
+ b" ZJWH\x06pv=1.0\x14id=01:6B:30:FF:51:37\x05c#=12\x04s#=1\x04ff=0\x04"
+ b"ci=2\x04sf=0\x0bsh=L0m/aQ==",
+ },
+ {
+ "name": "HASS Bridge 3K9A C2582A._hap._tcp.local.",
+ "address": address,
+ "port": 51834,
+ "text": b"\x13md=HASS Bridge"
+ b" 3K9A\x06pv=1.0\x14id=E2:AA:5B:C2:58:2A\x05c#=12\x04s#=1\x04ff=0\x04"
+ b"ci=2\x04sf=0\x0bsh=b2CnzQ==",
+ },
+ {
+ "name": "Master Bed TV CEDB27._hap._tcp.local.",
+ "address": address,
+ "port": 51830,
+ "text": b"\x10md=Master Bed"
+ b" TV\x06pv=1.0\x14id=9E:B7:44:CE:DB:27\x05c#=18\x04s#=1\x04ff=0\x05"
+ b"ci=31\x04sf=0\x0bsh=CVj1kw==",
+ },
+ {
+ "name": "Living Room TV 921B77._hap._tcp.local.",
+ "address": address,
+ "port": 51833,
+ "text": b"\x11md=Living Room"
+ b" TV\x06pv=1.0\x14id=11:61:E7:92:1B:77\x05c#=17\x04s#=1\x04ff=0\x05"
+ b"ci=31\x04sf=0\x0bsh=qU77SQ==",
+ },
+ {
+ "name": "HASS Bridge ZC8X FF413D._hap._tcp.local.",
+ "address": address,
+ "port": 51829,
+ "text": b"\x13md=HASS Bridge"
+ b" ZC8X\x06pv=1.0\x14id=96:14:45:FF:41:3D\x05c#=12\x04s#=1\x04ff=0\x04"
+ b"ci=2\x04sf=0\x0bsh=b0QZlg==",
+ },
+ {
+ "name": "HASS Bridge WLTF 4BE61F._hap._tcp.local.",
+ "address": address,
+ "port": 51837,
+ "text": b"\x13md=HASS Bridge"
+ b" WLTF\x06pv=1.0\x14id=E0:E7:98:4B:E6:1F\x04c#=2\x04s#=1\x04ff=0\x04"
+ b"ci=2\x04sf=0\x0bsh=ahAISA==",
+ },
+ {
+ "name": "FrontdoorCamera 8941D1._hap._tcp.local.",
+ "address": address,
+ "port": 54898,
+ "text": b"\x12md=FrontdoorCamera\x06pv=1.0\x14id=9F:B7:DC:89:41:D1\x04c#=2\x04"
+ b"s#=1\x04ff=0\x04ci=2\x04sf=0\x0bsh=0+MXmA==",
+ },
+ {
+ "name": "HASS Bridge W9DN 5B5CC5._hap._tcp.local.",
+ "address": address,
+ "port": 51836,
+ "text": b"\x13md=HASS Bridge"
+ b" W9DN\x06pv=1.0\x14id=11:8E:DB:5B:5C:C5\x05c#=12\x04s#=1\x04ff=0\x04"
+ b"ci=2\x04sf=0\x0bsh=6fLM5A==",
+ },
+ {
+ "name": "HASS Bridge Y9OO EFF0A7._hap._tcp.local.",
+ "address": address,
+ "port": 51838,
+ "text": b"\x13md=HASS Bridge"
+ b" Y9OO\x06pv=1.0\x14id=D3:FE:98:EF:F0:A7\x04c#=2\x04s#=1\x04ff=0\x04"
+ b"ci=2\x04sf=0\x0bsh=u3bdfw==",
+ },
+ {
+ "name": "Snooze Room TV 6B89B0._hap._tcp.local.",
+ "address": address,
+ "port": 51835,
+ "text": b"\x11md=Snooze Room"
+ b" TV\x06pv=1.0\x14id=5F:D5:70:6B:89:B0\x05c#=17\x04s#=1\x04ff=0\x05"
+ b"ci=31\x04sf=0\x0bsh=xNTqsg==",
+ },
+ {
+ "name": "AlexanderHomeAssistant 74651D._hap._tcp.local.",
+ "address": address,
+ "port": 54811,
+ "text": b"\x19md=AlexanderHomeAssistant\x06pv=1.0\x14id=59:8A:0B:74:65:1D\x05"
+ b"c#=14\x04s#=1\x04ff=0\x04ci=2\x04sf=0\x0bsh=ccZLPA==",
+ },
+ {
+ "name": "HASS Bridge OS95 39C053._hap._tcp.local.",
+ "address": address,
+ "port": 51831,
+ "text": b"\x13md=HASS Bridge"
+ b" OS95\x06pv=1.0\x14id=7E:8C:E6:39:C0:53\x05c#=12\x04s#=1\x04ff=0\x04ci=2"
+ b"\x04sf=0\x0bsh=Xfe5LQ==",
+ },
+ ]
+
+ out.add_answer_at_time(
+ DNSText(
+ "HASS Bridge W9DN 5B5CC5._hap._tcp.local.",
+ const._TYPE_TXT,
+ const._CLASS_IN | const._CLASS_UNIQUE,
+ const._DNS_OTHER_TTL,
+ b'\x13md=HASS Bridge W9DN\x06pv=1.0\x14id=11:8E:DB:5B:5C:C5\x05c#=12\x04s#=1'
+ b'\x04ff=0\x04ci=2\x04sf=0\x0bsh=6fLM5A==',
+ ),
+ 0,
+ )
+
+ for record in additionals:
+ out.add_additional_answer(
+ r.DNSService(
+ record["name"], # type: ignore
+ const._TYPE_SRV,
+ const._CLASS_IN | const._CLASS_UNIQUE,
+ const._DNS_HOST_TTL,
+ 0,
+ 0,
+ record["port"], # type: ignore
+ record["name"], # type: ignore
+ )
+ )
+ out.add_additional_answer(
+ r.DNSText(
+ record["name"], # type: ignore
+ const._TYPE_TXT,
+ const._CLASS_IN | const._CLASS_UNIQUE,
+ const._DNS_OTHER_TTL,
+ record["text"], # type: ignore
+ )
+ )
+ out.add_additional_answer(
+ r.DNSAddress(
+ record["name"], # type: ignore
+ const._TYPE_A,
+ const._CLASS_IN | const._CLASS_UNIQUE,
+ const._DNS_HOST_TTL,
+ record["address"], # type: ignore
+ )
+ )
+
+ for packet in out.packets():
+ # Verify we can process the packets we created to
+ # ensure there is no corruption with the dns compression
+ incoming = r.DNSIncoming(packet)
+ assert incoming.valid is True
+ assert (
+ len(incoming.answers)
+ == incoming.num_answers + incoming.num_authorities + incoming.num_additionals
+ )
+
+
+def test_tc_bit_in_query_packet():
+ """Verify the TC bit is set when known answers exceed the packet size."""
+ out = r.DNSOutgoing(const._FLAGS_QR_QUERY | const._FLAGS_AA)
+ type_ = "_hap._tcp.local."
+ out.add_question(r.DNSQuestion(type_, const._TYPE_PTR, const._CLASS_IN))
+
+ for i in range(30):
+ out.add_answer_at_time(
+ DNSText(
+ ("HASS Bridge W9DN %s._hap._tcp.local." % i),
+ const._TYPE_TXT,
+ const._CLASS_IN | const._CLASS_UNIQUE,
+ const._DNS_OTHER_TTL,
+ b'\x13md=HASS Bridge W9DN\x06pv=1.0\x14id=11:8E:DB:5B:5C:C5\x05c#=12\x04s#=1'
+ b'\x04ff=0\x04ci=2\x04sf=0\x0bsh=6fLM5A==',
+ ),
+ 0,
+ )
+
+ packets = out.packets()
+ assert len(packets) == 3
+
+ first_packet = r.DNSIncoming(packets[0])
+ assert first_packet.truncated
+ assert first_packet.valid is True
+
+ second_packet = r.DNSIncoming(packets[1])
+ assert second_packet.truncated
+ assert second_packet.valid is True
+
+ third_packet = r.DNSIncoming(packets[2])
+ assert not third_packet.truncated
+ assert third_packet.valid is True
+
+
+def test_tc_bit_not_set_in_answer_packet():
+ """Verify the TC bit is not set when there are no questions and answers exceed the packet size."""
+ out = r.DNSOutgoing(const._FLAGS_QR_RESPONSE | const._FLAGS_AA)
+ for i in range(30):
+ out.add_answer_at_time(
+ DNSText(
+ ("HASS Bridge W9DN %s._hap._tcp.local." % i),
+ const._TYPE_TXT,
+ const._CLASS_IN | const._CLASS_UNIQUE,
+ const._DNS_OTHER_TTL,
+ b'\x13md=HASS Bridge W9DN\x06pv=1.0\x14id=11:8E:DB:5B:5C:C5\x05c#=12\x04s#=1'
+ b'\x04ff=0\x04ci=2\x04sf=0\x0bsh=6fLM5A==',
+ ),
+ 0,
+ )
+
+ packets = out.packets()
+ assert len(packets) == 3
+
+ first_packet = r.DNSIncoming(packets[0])
+ assert not first_packet.truncated
+ assert first_packet.valid is True
+
+ second_packet = r.DNSIncoming(packets[1])
+ assert not second_packet.truncated
+ assert second_packet.valid is True
+
+ third_packet = r.DNSIncoming(packets[2])
+ assert not third_packet.truncated
+ assert third_packet.valid is True
+
+
+# 4003 15.973052 192.168.107.68 224.0.0.251 MDNS 76 Standard query 0xffc4 PTR _raop._tcp.local, "QM" question
+def test_qm_packet_parser():
+ """Test we can parse a query packet with the QM bit."""
+ qm_packet = (
+ b'\xff\xc4\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x05_raop\x04_tcp\x05local\x00\x00\x0c\x00\x01'
+ )
+ parsed = DNSIncoming(qm_packet)
+ assert parsed.questions[0].unicast is False
+ assert ",QM," in str(parsed.questions[0])
+
+
+# 389951 1450.577370 192.168.107.111 224.0.0.251 MDNS 115 Standard query 0x0000 PTR _companion-link._tcp.local, "QU" question OPT
+def test_qu_packet_parser():
+ """Test we can parse a query packet with the QU bit."""
+ qu_packet = b'\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x01\x0f_companion-link\x04_tcp\x05local\x00\x00\x0c\x80\x01\x00\x00)\x05\xa0\x00\x00\x11\x94\x00\x12\x00\x04\x00\x0e\x00dz{\x8a6\x9czF\x84,\xcaQ\xff'
+ parsed = DNSIncoming(qu_packet)
+ assert parsed.questions[0].unicast is True
+ assert ",QU," in str(parsed.questions[0])
+
+
+def test_parse_packet_with_nsec_record():
+ """Test we can parse a packet with an NSEC record."""
+ nsec_packet = (
+ b"\x00\x00\x84\x00\x00\x00\x00\x01\x00\x00\x00\x03\x08_meshcop\x04_udp\x05local\x00\x00\x0c\x00"
+ b"\x01\x00\x00\x11\x94\x00\x0f\x0cMyHome54 (2)\xc0\x0c\xc0+\x00\x10\x80\x01\x00\x00\x11\x94\x00"
+ b")\x0bnn=MyHome54\x13xp=695034D148CC4784\x08tv=0.0.0\xc0+\x00!\x80\x01\x00\x00\x00x\x00\x15\x00"
+ b"\x00\x00\x00\xc0'\x0cMaster-Bed-2\xc0\x1a\xc0+\x00/\x80\x01\x00\x00\x11\x94\x00\t\xc0+\x00\x05"
+ b"\x00\x00\x80\x00@"
+ )
+ parsed = DNSIncoming(nsec_packet)
+ nsec_record = parsed.answers[3]
+ assert "nsec," in str(nsec_record)
+ assert nsec_record.rdtypes == [16, 33]
+ assert nsec_record.next_name == "MyHome54 (2)._meshcop._udp.local."
+
+
+def test_records_same_packet_share_fate():
+ """Test records in the same packet all have the same created time."""
+ out = r.DNSOutgoing(const._FLAGS_QR_QUERY | const._FLAGS_AA)
+ type_ = "_hap._tcp.local."
+ out.add_question(r.DNSQuestion(type_, const._TYPE_PTR, const._CLASS_IN))
+
+ for i in range(30):
+ out.add_answer_at_time(
+ DNSText(
+ ("HASS Bridge W9DN %s._hap._tcp.local." % i),
+ const._TYPE_TXT,
+ const._CLASS_IN | const._CLASS_UNIQUE,
+ const._DNS_OTHER_TTL,
+ b'\x13md=HASS Bridge W9DN\x06pv=1.0\x14id=11:8E:DB:5B:5C:C5\x05c#=12\x04s#=1'
+ b'\x04ff=0\x04ci=2\x04sf=0\x0bsh=6fLM5A==',
+ ),
+ 0,
+ )
+
+ for packet in out.packets():
+ dnsin = DNSIncoming(packet)
+ first_time = dnsin.answers[0].created
+ for answer in dnsin.answers:
+ assert answer.created == first_time
+
+
+def test_dns_compression_invalid_skips_bad_name_compress_in_question():
+ """Test our wire parser can skip bad compression in questions."""
+ packet = (
+ b'\x00\x00\x00\x00\x00\x04\x00\x00\x00\x07\x00\x00\x11homeassistant1128\x05l'
+ b'ocal\x00\x00\xff\x00\x014homeassistant1128 [534a4794e5ed41879ecf012252d3e02'
+ b'a]\x0c_workstation\x04_tcp\xc0\x1e\x00\xff\x00\x014homeassistant1127 [534a47'
+ b'94e5ed41879ecf012252d3e02a]\xc0^\x00\xff\x00\x014homeassistant1123 [534a479'
+ b'4e5ed41879ecf012252d3e02a]\xc0^\x00\xff\x00\x014homeassistant1118 [534a4794'
+ b'e5ed41879ecf012252d3e02a]\xc0^\x00\xff\x00\x01\xc0\x0c\x00\x01\x80'
+ b'\x01\x00\x00\x00x\x00\x04\xc0\xa8<\xc3\xc0v\x00\x10\x80\x01\x00\x00\x00'
+ b'x\x00\x01\x00\xc0v\x00!\x80\x01\x00\x00\x00x\x00\x1f\x00\x00\x00\x00'
+ b'\x00\x00\x11homeassistant1127\x05local\x00\xc0\xb1\x00\x10\x80'
+ b'\x01\x00\x00\x00x\x00\x01\x00\xc0\xb1\x00!\x80\x01\x00\x00\x00x\x00\x1f'
+ b'\x00\x00\x00\x00\x00\x00\x11homeassistant1123\x05local\x00\xc0)\x00\x10\x80'
+ b'\x01\x00\x00\x00x\x00\x01\x00\xc0)\x00!\x80\x01\x00\x00\x00x\x00\x1f'
+ b'\x00\x00\x00\x00\x00\x00\x11homeassistant1128\x05local\x00'
+ )
+ parsed = r.DNSIncoming(packet)
+ assert len(parsed.questions) == 4
+
+
+def test_dns_compression_all_invalid(caplog):
+ """Test our wire parser can skip all invalid data."""
+ packet = (
+ b'\x00\x00\x84\x00\x00\x00\x00\x01\x00\x00\x00\x00!roborock-vacuum-s5e_miio416'
+ b'112328\x00\x00/\x80\x01\x00\x00\x00x\x00\t\xc0P\x00\x05@\x00\x00\x00\x00'
+ )
+ parsed = r.DNSIncoming(packet, ("2.4.5.4", 5353))
+ assert len(parsed.questions) == 0
+ assert len(parsed.answers) == 0
+
+ assert " Unable to parse; skipping record" in caplog.text
+
+
+def test_invalid_next_name_ignored():
+ """Test our wire parser does not throw an an invalid next name.
+
+ The RFC states it should be ignored when used with mDNS.
+ """
+ packet = (
+ b'\x00\x00\x00\x00\x00\x01\x00\x02\x00\x00\x00\x00\x07Android\x05local\x00\x00'
+ b'\xff\x00\x01\xc0\x0c\x00/\x00\x01\x00\x00\x00x\x00\x08\xc02\x00\x04@'
+ b'\x00\x00\x08\xc0\x0c\x00\x01\x00\x01\x00\x00\x00x\x00\x04\xc0\xa8X<'
+ )
+ parsed = r.DNSIncoming(packet)
+ assert len(parsed.questions) == 1
+ assert len(parsed.answers) == 2
+
+
+def test_dns_compression_invalid_skips_record():
+ """Test our wire parser can skip records we do not know how to parse."""
+ packet = (
+ b"\x00\x00\x84\x00\x00\x00\x00\x06\x00\x00\x00\x00\x04_hap\x04_tcp\x05local\x00\x00\x0c"
+ b"\x00\x01\x00\x00\x11\x94\x00\x16\x13eufy HomeBase2-2464\xc0\x0c\x04Eufy\xc0\x16\x00/"
+ b"\x80\x01\x00\x00\x00x\x00\x08\xc0\xa6\x00\x04@\x00\x00\x08\xc0'\x00/\x80\x01\x00\x00"
+ b"\x11\x94\x00\t\xc0'\x00\x05\x00\x00\x80\x00@\xc0=\x00\x01\x80\x01\x00\x00\x00x\x00\x04"
+ b"\xc0\xa8Dp\xc0'\x00!\x80\x01\x00\x00\x00x\x00\x08\x00\x00\x00\x00\xd1_\xc0=\xc0'\x00"
+ b"\x10\x80\x01\x00\x00\x11\x94\x00K\x04c#=1\x04ff=2\x14id=38:71:4F:6B:76:00\x08md=T8010"
+ b"\x06pv=1.1\x05s#=75\x04sf=1\x04ci=2\x0bsh=xaQk4g=="
+ )
+ parsed = r.DNSIncoming(packet)
+ answer = r.DNSNsec(
+ 'eufy HomeBase2-2464._hap._tcp.local.',
+ const._TYPE_NSEC,
+ const._CLASS_IN | const._CLASS_UNIQUE,
+ const._DNS_OTHER_TTL,
+ 'eufy HomeBase2-2464._hap._tcp.local.',
+ [const._TYPE_TXT, const._TYPE_SRV],
+ )
+ assert answer in parsed.answers
+
+
+def test_dns_compression_points_forward():
+ """Test our wire parser can unpack nsec records with compression."""
+ packet = (
+ b"\x00\x00\x84\x00\x00\x00\x00\x07\x00\x00\x00\x00\x0eTV Beneden (2)"
+ b"\x10_androidtvremote\x04_tcp\x05local\x00\x00\x10\x80\x01\x00\x00\x11"
+ b"\x94\x00\x15\x14bt=D8:13:99:AC:98:F1\xc0\x0c\x00/\x80\x01\x00\x00\x11"
+ b"\x94\x00\t\xc0\x0c\x00\x05\x00\x00\x80\x00@\tAndroid-3\xc01\x00/\x80"
+ b"\x01\x00\x00\x00x\x00\x08\xc0\x9c\x00\x04@\x00\x00\x08\xc0l\x00\x01\x80"
+ b"\x01\x00\x00\x00x\x00\x04\xc0\xa8X\x0f\xc0\x0c\x00!\x80\x01\x00\x00\x00"
+ b"x\x00\x08\x00\x00\x00\x00\x19B\xc0l\xc0\x1b\x00\x0c\x00\x01\x00\x00\x11"
+ b"\x94\x00\x02\xc0\x0c\t_services\x07_dns-sd\x04_udp\xc01\x00\x0c\x00\x01"
+ b"\x00\x00\x11\x94\x00\x02\xc0\x1b"
+ )
+ parsed = r.DNSIncoming(packet)
+ answer = r.DNSNsec(
+ 'TV Beneden (2)._androidtvremote._tcp.local.',
+ const._TYPE_NSEC,
+ const._CLASS_IN | const._CLASS_UNIQUE,
+ const._DNS_OTHER_TTL,
+ 'TV Beneden (2)._androidtvremote._tcp.local.',
+ [const._TYPE_TXT, const._TYPE_SRV],
+ )
+ assert answer in parsed.answers
+
+
+def test_dns_compression_points_to_itself():
+ """Test our wire parser does not loop forever when a compression pointer points to itself."""
+ packet = (
+ b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02\x06domain\x05local\x00\x00\x01"
+ b"\x80\x01\x00\x00\x00\x01\x00\x04\xc0\xa8\xd0\x05\xc0(\x00\x01\x80\x01\x00\x00\x00"
+ b"\x01\x00\x04\xc0\xa8\xd0\x06"
+ )
+ parsed = r.DNSIncoming(packet)
+ assert len(parsed.answers) == 1
+
+
+def test_dns_compression_points_beyond_packet():
+ """Test our wire parser does not fail when the compression pointer points beyond the packet."""
+ packet = (
+ b'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02\x06domain\x05local\x00\x00\x01'
+ b'\x80\x01\x00\x00\x00\x01\x00\x04\xc0\xa8\xd0\x05\xe7\x0f\x00\x01\x80\x01\x00\x00'
+ b'\x00\x01\x00\x04\xc0\xa8\xd0\x06'
+ )
+ parsed = r.DNSIncoming(packet)
+ assert len(parsed.answers) == 1
+
+
+def test_dns_compression_generic_failure(caplog):
+ """Test our wire parser does not loop forever when dns compression is corrupt."""
+ packet = (
+ b'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02\x06domain\x05local\x00\x00\x01'
+ b'\x80\x01\x00\x00\x00\x01\x00\x04\xc0\xa8\xd0\x05-\x0c\x00\x01\x80\x01\x00\x00'
+ b'\x00\x01\x00\x04\xc0\xa8\xd0\x06'
+ )
+ parsed = r.DNSIncoming(packet, ("1.2.3.4", 5353))
+ assert len(parsed.answers) == 1
+ assert "Received invalid packet from ('1.2.3.4', 5353)" in caplog.text
+
+
+def test_label_length_attack():
+ """Test our wire parser does not loop forever when the name exceeds 253 chars."""
+ packet = (
+ b'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02\x01d\x01d\x01d\x01d\x01d\x01d'
+ b'\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d'
+ b'\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d'
+ b'\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d'
+ b'\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d'
+ b'\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d'
+ b'\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d'
+ b'\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d'
+ b'\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x00\x00\x01\x80'
+ b'\x01\x00\x00\x00\x01\x00\x04\xc0\xa8\xd0\x05\xc0\x0c\x00\x01\x80\x01\x00\x00\x00'
+ b'\x01\x00\x04\xc0\xa8\xd0\x06'
+ )
+ parsed = r.DNSIncoming(packet)
+ assert len(parsed.answers) == 0
+
+
+def test_label_compression_attack():
+ """Test our wire parser does not loop forever when exceeding the maximum number of labels."""
+ packet = (
+ b'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02\x03atk\x00\x00\x01\x80'
+ b'\x01\x00\x00\x00\x01\x00\x04\xc0\xa8\xd0\x05\x03atk\x03atk\x03atk\x03atk\x03'
+ b'atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03'
+ b'atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03'
+ b'atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03'
+ b'atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03'
+ b'atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03'
+ b'atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03'
+ b'atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03'
+ b'atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03'
+ b'atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03'
+ b'atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03'
+ b'atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03'
+ b'atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03'
+ b'atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03'
+ b'atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03'
+ b'atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03'
+ b'atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03'
+ b'atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03'
+ b'atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03'
+ b'atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\xc0'
+ b'\x0c\x00\x01\x80\x01\x00\x00\x00\x01\x00\x04\xc0\xa8\xd0\x06'
+ )
+ parsed = r.DNSIncoming(packet)
+ assert len(parsed.answers) == 1
+
+
+def test_dns_compression_loop_attack():
+ """Test our wire parser does not loop forever when dns compression is in a loop."""
+ packet = (
+ b'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x07\x03atk\x03dns\x05loc'
+ b'al\xc0\x10\x00\x01\x80\x01\x00\x00\x00\x01\x00\x04\xc0\xa8\xd0\x05\x04a'
+ b'tk2\x04dns2\xc0\x14\x00\x01\x80\x01\x00\x00\x00\x01\x00\x04\xc0\xa8\xd0\x05'
+ b'\x04atk3\xc0\x10\x00\x01\x80\x01\x00\x00\x00\x01\x00\x04\xc0\xa8\xd0'
+ b'\x05\x04atk4\x04dns5\xc0\x14\x00\x01\x80\x01\x00\x00\x00\x01\x00\x04\xc0'
+ b'\xa8\xd0\x05\x04atk5\x04dns2\xc0^\x00\x01\x80\x01\x00\x00\x00\x01\x00'
+ b'\x04\xc0\xa8\xd0\x05\xc0s\x00\x01\x80\x01\x00\x00\x00\x01\x00'
+ b'\x04\xc0\xa8\xd0\x05\xc0s\x00\x01\x80\x01\x00\x00\x00\x01\x00'
+ b'\x04\xc0\xa8\xd0\x05'
+ )
+ parsed = r.DNSIncoming(packet)
+ assert len(parsed.answers) == 0
+
+
+def test_txt_after_invalid_nsec_name_still_usable():
+ """Test that we can see the txt record after the invalid nsec record."""
+ packet = (
+ b'\x00\x00\x84\x00\x00\x00\x00\x06\x00\x00\x00\x00\x06_sonos\x04_tcp\x05loc'
+ b'al\x00\x00\x0c\x00\x01\x00\x00\x11\x94\x00\x15\x12Sonos-542A1BC9220E'
+ b'\xc0\x0c\x12Sonos-542A1BC9220E\xc0\x18\x00/\x80\x01\x00\x00\x00x\x00'
+ b'\x08\xc1t\x00\x04@\x00\x00\x08\xc0)\x00/\x80\x01\x00\x00\x11\x94\x00'
+ b'\t\xc0)\x00\x05\x00\x00\x80\x00@\xc0)\x00!\x80\x01\x00\x00\x00x'
+ b'\x00\x08\x00\x00\x00\x00\x05\xa3\xc0>\xc0>\x00\x01\x80\x01\x00\x00\x00x'
+ b'\x00\x04\xc0\xa8\x02:\xc0)\x00\x10\x80\x01\x00\x00\x11\x94\x01*2info=/api'
+ b'/v1/players/RINCON_542A1BC9220E01400/info\x06vers=3\x10protovers=1.24.1\nbo'
+ b'otseq=11%hhid=Sonos_rYn9K9DLXJe0f3LP9747lbvFvh;mhhid=Sonos_rYn9K9DLXJe0f3LP9'
+ b'747lbvFvh.Q45RuMaeC07rfXh7OJGm None:
+ nonlocal updates
+ updates.append(record)
+
+ listener = LegacyRecordUpdateListener()
+
+ zc.add_listener(listener, None)
+
+ # dummy service callback
+ def on_service_state_change(zeroconf, service_type, state_change, name):
+ pass
+
+ # start a browser
+ type_ = "_homeassistant._tcp.local."
+ name = "MyTestHome"
+ browser = ServiceBrowser(zc, type_, [on_service_state_change])
+
+ info_service = ServiceInfo(
+ type_,
+ f'{name}.{type_}',
+ 80,
+ 0,
+ 0,
+ {'path': '/~paulsm/'},
+ "ash-2.local.",
+ addresses=[socket.inet_aton("10.0.1.2")],
+ )
+
+ zc.register_service(info_service)
+
+ time.sleep(0.001)
+
+ browser.cancel()
+
+ assert len(updates)
+ assert len([isinstance(update, r.DNSPointer) and update.name == type_ for update in updates]) >= 1
+
+ zc.remove_listener(listener)
+ # Removing a second time should not throw
+ zc.remove_listener(listener)
+
+ zc.close()
diff --git a/tests/utils/__init__.py b/tests/utils/__init__.py
new file mode 100644
index 00000000..2ef4b15b
--- /dev/null
+++ b/tests/utils/__init__.py
@@ -0,0 +1,21 @@
+""" Multicast DNS Service Discovery for Python, v0.14-wmcbrine
+ Copyright 2003 Paul Scott-Murphy, 2014 William McBrine
+
+ This module provides a framework for the use of DNS Service Discovery
+ using IP multicast.
+
+ This library is free software; you can redistribute it and/or
+ modify it under the terms of the GNU Lesser General Public
+ License as published by the Free Software Foundation; either
+ version 2.1 of the License, or (at your option) any later version.
+
+ This library is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+ Lesser General Public License for more details.
+
+ You should have received a copy of the GNU Lesser General Public
+ License along with this library; if not, write to the Free Software
+ Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301
+ USA
+"""
diff --git a/tests/utils/test_asyncio.py b/tests/utils/test_asyncio.py
new file mode 100644
index 00000000..2939b5ab
--- /dev/null
+++ b/tests/utils/test_asyncio.py
@@ -0,0 +1,114 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+
+"""Unit tests for zeroconf._utils.asyncio."""
+
+import asyncio
+import concurrent.futures
+import contextlib
+import threading
+import time
+from unittest.mock import patch
+
+import pytest
+
+from zeroconf._core import _CLOSE_TIMEOUT
+from zeroconf._utils import asyncio as aioutils
+from zeroconf.const import _LOADED_SYSTEM_TIMEOUT
+
+
+@pytest.mark.asyncio
+async def test_async_get_all_tasks() -> None:
+ """Test we can get all tasks in the event loop.
+
+ We make sure we handle RuntimeError here as
+ this is not thread safe under PyPy
+ """
+ await aioutils._async_get_all_tasks(aioutils.get_running_loop())
+ if not hasattr(asyncio, 'all_tasks'):
+ return
+ with patch("zeroconf._utils.asyncio.asyncio.all_tasks", side_effect=RuntimeError):
+ await aioutils._async_get_all_tasks(aioutils.get_running_loop())
+
+
+@pytest.mark.asyncio
+async def test_get_running_loop_from_async() -> None:
+ """Test we can get the event loop."""
+ assert isinstance(aioutils.get_running_loop(), asyncio.AbstractEventLoop)
+
+
+def test_get_running_loop_no_loop() -> None:
+ """Test we get None when there is no loop running."""
+ assert aioutils.get_running_loop() is None
+
+
+@pytest.mark.asyncio
+async def test_wait_event_or_timeout_times_out() -> None:
+ """Test wait_event_or_timeout will timeout."""
+ test_event = asyncio.Event()
+ await aioutils.wait_event_or_timeout(test_event, 0.1)
+
+ task = asyncio.ensure_future(test_event.wait())
+ await asyncio.sleep(0.1)
+
+ async def _async_wait_or_timeout():
+ await aioutils.wait_event_or_timeout(test_event, 0.1)
+
+ # Test high lock contention
+ await asyncio.gather(*[_async_wait_or_timeout() for _ in range(100)])
+
+ task.cancel()
+ with contextlib.suppress(asyncio.CancelledError):
+ await task
+
+
+def test_shutdown_loop() -> None:
+ """Test shutting down an event loop."""
+ loop = None
+ loop_thread_ready = threading.Event()
+ runcoro_thread_ready = threading.Event()
+
+ def _run_loop() -> None:
+ nonlocal loop
+ loop = asyncio.new_event_loop()
+ asyncio.set_event_loop(loop)
+ loop_thread_ready.set()
+ loop.run_forever()
+
+ loop_thread = threading.Thread(target=_run_loop, daemon=True)
+ loop_thread.start()
+ loop_thread_ready.wait()
+
+ async def _still_running():
+ await asyncio.sleep(5)
+
+ def _run_coro() -> None:
+ runcoro_thread_ready.set()
+ with contextlib.suppress(concurrent.futures.TimeoutError):
+ asyncio.run_coroutine_threadsafe(_still_running(), loop).result(1)
+
+ runcoro_thread = threading.Thread(target=_run_coro, daemon=True)
+ runcoro_thread.start()
+ runcoro_thread_ready.wait()
+
+ time.sleep(0.1)
+ aioutils.shutdown_loop(loop)
+ for _ in range(5):
+ if not loop.is_running():
+ break
+ time.sleep(0.05)
+
+ assert loop.is_running() is False
+ runcoro_thread.join()
+
+
+def test_cumulative_timeouts_less_than_close_plus_buffer():
+ """Test that the combined async timeouts are shorter than the close timeout with the buffer.
+
+ We want to make sure that the close timeout is the one that gets
+ raised if something goes wrong.
+ """
+ assert (
+ aioutils._TASK_AWAIT_TIMEOUT + aioutils._GET_ALL_TASKS_TIMEOUT + aioutils._WAIT_FOR_LOOP_TASKS_TIMEOUT
+ ) < 1 + _CLOSE_TIMEOUT + _LOADED_SYSTEM_TIMEOUT
diff --git a/tests/utils/test_name.py b/tests/utils/test_name.py
new file mode 100644
index 00000000..6f8b417d
--- /dev/null
+++ b/tests/utils/test_name.py
@@ -0,0 +1,26 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+
+"""Unit tests for zeroconf._utils.name."""
+
+import pytest
+
+from zeroconf._utils import name as nameutils
+from zeroconf import BadTypeInNameException
+
+
+def test_service_type_name_overlong_type():
+ """Test overlong service_type_name type."""
+ with pytest.raises(BadTypeInNameException):
+ nameutils.service_type_name("Tivo1._tivo-videostream._tcp.local.")
+ nameutils.service_type_name("Tivo1._tivo-videostream._tcp.local.", strict=False)
+
+
+def test_service_type_name_overlong_full_name():
+ """Test overlong service_type_name full name."""
+ long_name = "Tivo1Tivo1Tivo1Tivo1Tivo1Tivo1Tivo1Tivo1" * 100
+ with pytest.raises(BadTypeInNameException):
+ nameutils.service_type_name(f"{long_name}._tivo-videostream._tcp.local.")
+ with pytest.raises(BadTypeInNameException):
+ nameutils.service_type_name(f"{long_name}._tivo-videostream._tcp.local.", strict=False)
diff --git a/tests/utils/test_net.py b/tests/utils/test_net.py
new file mode 100644
index 00000000..41fdb7aa
--- /dev/null
+++ b/tests/utils/test_net.py
@@ -0,0 +1,199 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+
+"""Unit tests for zeroconf._utils.net."""
+from unittest.mock import Mock, patch
+
+import errno
+import ifaddr
+import pytest
+import socket
+import unittest
+
+from zeroconf._utils import net as netutils
+import zeroconf as r
+
+
+def _generate_mock_adapters():
+ mock_lo0 = Mock(spec=ifaddr.Adapter)
+ mock_lo0.nice_name = "lo0"
+ mock_lo0.ips = [ifaddr.IP("127.0.0.1", 8, "lo0")]
+ mock_lo0.index = 0
+ mock_eth0 = Mock(spec=ifaddr.Adapter)
+ mock_eth0.nice_name = "eth0"
+ mock_eth0.ips = [ifaddr.IP(("2001:db8::", 1, 1), 8, "eth0")]
+ mock_eth0.index = 1
+ mock_eth1 = Mock(spec=ifaddr.Adapter)
+ mock_eth1.nice_name = "eth1"
+ mock_eth1.ips = [ifaddr.IP("192.168.1.5", 23, "eth1")]
+ mock_eth1.index = 2
+ mock_vtun0 = Mock(spec=ifaddr.Adapter)
+ mock_vtun0.nice_name = "vtun0"
+ mock_vtun0.ips = [ifaddr.IP("169.254.3.2", 16, "vtun0")]
+ mock_vtun0.index = 3
+ return [mock_eth0, mock_lo0, mock_eth1, mock_vtun0]
+
+
+def test_ip6_to_address_and_index():
+ """Test we can extract from mocked adapters."""
+ adapters = _generate_mock_adapters()
+ assert netutils.ip6_to_address_and_index(adapters, "2001:db8::") == (('2001:db8::', 1, 1), 1)
+ with pytest.raises(RuntimeError):
+ assert netutils.ip6_to_address_and_index(adapters, "2005:db8::")
+
+
+def test_interface_index_to_ip6_address():
+ """Test we can extract from mocked adapters."""
+ adapters = _generate_mock_adapters()
+ assert netutils.interface_index_to_ip6_address(adapters, 1) == ('2001:db8::', 1, 1)
+
+ # call with invalid adapter
+ with pytest.raises(RuntimeError):
+ assert netutils.interface_index_to_ip6_address(adapters, 6)
+
+ # call with adapter that has ipv4 address only
+ with pytest.raises(RuntimeError):
+ assert netutils.interface_index_to_ip6_address(adapters, 2)
+
+
+def test_ip6_addresses_to_indexes():
+ """Test we can extract from mocked adapters."""
+ interfaces = [1]
+ with patch("zeroconf._utils.net.ifaddr.get_adapters", return_value=_generate_mock_adapters()):
+ assert netutils.ip6_addresses_to_indexes(interfaces) == [(('2001:db8::', 1, 1), 1)]
+
+ interfaces = ['2001:db8::']
+ with patch("zeroconf._utils.net.ifaddr.get_adapters", return_value=_generate_mock_adapters()):
+ assert netutils.ip6_addresses_to_indexes(interfaces) == [(('2001:db8::', 1, 1), 1)]
+
+
+def test_normalize_interface_choice_errors():
+ """Test we generate exception on invalid input."""
+ with patch("zeroconf._utils.net.get_all_addresses", return_value=[]), patch(
+ "zeroconf._utils.net.get_all_addresses_v6", return_value=[]
+ ), pytest.raises(RuntimeError):
+ netutils.normalize_interface_choice(r.InterfaceChoice.All)
+
+ with pytest.raises(TypeError):
+ netutils.normalize_interface_choice("1.2.3.4")
+
+
+@pytest.mark.parametrize(
+ "errno,expected_result",
+ [(errno.EADDRINUSE, False), (errno.EADDRNOTAVAIL, False), (errno.EINVAL, False), (0, True)],
+)
+def test_add_multicast_member_socket_errors(errno, expected_result):
+ """Test we handle socket errors when adding multicast members."""
+ if errno:
+ setsockopt_mock = unittest.mock.Mock(side_effect=OSError(errno, "Error: {}".format(errno)))
+ else:
+ setsockopt_mock = unittest.mock.Mock()
+ fileno_mock = unittest.mock.PropertyMock(return_value=10)
+ socket_mock = unittest.mock.Mock(setsockopt=setsockopt_mock, fileno=fileno_mock)
+ assert r.add_multicast_member(socket_mock, "0.0.0.0") == expected_result
+
+
+def test_autodetect_ip_version():
+ """Tests for auto detecting IPVersion based on interface ips."""
+ assert r.autodetect_ip_version(["1.3.4.5"]) is r.IPVersion.V4Only
+ assert r.autodetect_ip_version([]) is r.IPVersion.V4Only
+ assert r.autodetect_ip_version(["::1", "1.2.3.4"]) is r.IPVersion.All
+ assert r.autodetect_ip_version(["::1"]) is r.IPVersion.V6Only
+
+
+def test_disable_ipv6_only_or_raise():
+ """Test that IPV6_V6ONLY failing logs a nice error message and still raises."""
+ errors_logged = []
+
+ def _log_error(*args):
+ nonlocal errors_logged
+ errors_logged.append(args)
+
+ sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
+ with pytest.raises(OSError), patch.object(netutils.log, "error", _log_error), patch(
+ "socket.socket.setsockopt", side_effect=OSError
+ ):
+ netutils.disable_ipv6_only_or_raise(sock)
+
+ assert (
+ errors_logged[0][0]
+ == 'Support for dual V4-V6 sockets is not present, use IPVersion.V4 or IPVersion.V6'
+ )
+
+
+@pytest.mark.skipif(not hasattr(socket, 'SO_REUSEPORT'), reason="System does not have SO_REUSEPORT")
+def test_set_so_reuseport_if_available_is_present():
+ """Test that setting socket.SO_REUSEPORT only OSError errno.ENOPROTOOPT is trapped."""
+ sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
+ with pytest.raises(OSError), patch("socket.socket.setsockopt", side_effect=OSError):
+ netutils.set_so_reuseport_if_available(sock)
+
+ with patch("socket.socket.setsockopt", side_effect=OSError(errno.ENOPROTOOPT, None)):
+ netutils.set_so_reuseport_if_available(sock)
+
+
+@pytest.mark.skipif(hasattr(socket, 'SO_REUSEPORT'), reason="System has SO_REUSEPORT")
+def test_set_so_reuseport_if_available_not_present():
+ """Test that we do not try to set SO_REUSEPORT if it is not present."""
+ sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
+ with patch("socket.socket.setsockopt", side_effect=OSError):
+ netutils.set_so_reuseport_if_available(sock)
+
+
+def test_set_mdns_port_socket_options_for_ip_version():
+ """Test OSError with errno with EINVAL and bind address '' from setsockopt IP_MULTICAST_TTL does not raise."""
+ sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
+
+ # Should raise on EPERM always
+ with pytest.raises(OSError), patch("socket.socket.setsockopt", side_effect=OSError(errno.EPERM, None)):
+ netutils.set_mdns_port_socket_options_for_ip_version(sock, ('',), r.IPVersion.V4Only)
+
+ # Should raise on EINVAL always when bind address is not ''
+ with pytest.raises(OSError), patch("socket.socket.setsockopt", side_effect=OSError(errno.EINVAL, None)):
+ netutils.set_mdns_port_socket_options_for_ip_version(sock, ('127.0.0.1',), r.IPVersion.V4Only)
+
+ # Should not raise on EINVAL when bind address is ''
+ with patch("socket.socket.setsockopt", side_effect=OSError(errno.EINVAL, None)):
+ netutils.set_mdns_port_socket_options_for_ip_version(sock, ('',), r.IPVersion.V4Only)
+
+
+def test_add_multicast_member():
+ sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
+ interface = '127.0.0.1'
+
+ # EPERM should always raise
+ with pytest.raises(OSError), patch("socket.socket.setsockopt", side_effect=OSError(errno.EPERM, None)):
+ netutils.add_multicast_member(sock, interface)
+
+ # EADDRINUSE should return False
+ with patch("socket.socket.setsockopt", side_effect=OSError(errno.EADDRINUSE, None)):
+ assert netutils.add_multicast_member(sock, interface) is False
+
+ # EADDRNOTAVAIL should return False
+ with patch("socket.socket.setsockopt", side_effect=OSError(errno.EADDRNOTAVAIL, None)):
+ assert netutils.add_multicast_member(sock, interface) is False
+
+ # EINVAL should return False
+ with patch("socket.socket.setsockopt", side_effect=OSError(errno.EINVAL, None)):
+ assert netutils.add_multicast_member(sock, interface) is False
+
+ # ENOPROTOOPT should return False
+ with patch("socket.socket.setsockopt", side_effect=OSError(errno.ENOPROTOOPT, None)):
+ assert netutils.add_multicast_member(sock, interface) is False
+
+ # ENODEV should raise for ipv4
+ with pytest.raises(OSError), patch("socket.socket.setsockopt", side_effect=OSError(errno.ENODEV, None)):
+ netutils.add_multicast_member(sock, interface) is False
+
+ # ENODEV should return False for ipv6
+ with patch("socket.socket.setsockopt", side_effect=OSError(errno.ENODEV, None)):
+ assert netutils.add_multicast_member(sock, ('2001:db8::', 1, 1)) is False
+
+ # No IPv6 support should return False for IPv6
+ with patch("socket.inet_pton", side_effect=OSError()):
+ assert netutils.add_multicast_member(sock, ('2001:db8::', 1, 1)) is False
+
+ # No error should return True
+ with patch("socket.socket.setsockopt"):
+ assert netutils.add_multicast_member(sock, interface) is True
diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py
index 35fac80a..4821cbb8 100644
--- a/zeroconf/__init__.py
+++ b/zeroconf/__init__.py
@@ -20,2608 +20,88 @@
USA
"""
-import enum
-import errno
-import ipaddress
-import itertools
-import logging
-import os
-import platform
-import re
-import select
-import socket
-import struct
import sys
-import threading
-import time
-import warnings
-from typing import Dict, List, Optional, Sequence, Union, cast
-from typing import Any, Callable, Set, Tuple # noqa # used in type hints
-import ifaddr
+from ._cache import DNSCache # noqa # import needed for backwards compat
+from ._core import Zeroconf # noqa # import needed for backwards compat
+from ._dns import ( # noqa # import needed for backwards compat
+ DNSAddress,
+ DNSEntry,
+ DNSHinfo,
+ DNSNsec,
+ DNSPointer,
+ DNSQuestion,
+ DNSRecord,
+ DNSService,
+ DNSText,
+ DNSQuestionType,
+)
+from ._logger import QuietLogger, log # noqa # import needed for backwards compat
+from ._exceptions import ( # noqa # import needed for backwards compat
+ AbstractMethodException,
+ BadTypeInNameException,
+ Error,
+ IncomingDecodeError,
+ NamePartTooLongException,
+ NonUniqueNameException,
+ ServiceNameAlreadyRegistered,
+)
+from ._protocol.incoming import DNSIncoming # noqa # import needed for backwards compat
+from ._protocol.outgoing import DNSOutgoing # noqa # import needed for backwards compat
+from ._services import ( # noqa # import needed for backwards compat
+ Signal,
+ SignalRegistrationInterface,
+ ServiceListener,
+ ServiceStateChange,
+)
+from ._services.browser import ( # noqa # import needed for backwards compat
+ ServiceBrowser,
+)
+from ._services.info import ( # noqa # import needed for backwards compat
+ instance_name_from_service_info,
+ ServiceInfo,
+)
+from ._services.registry import ServiceRegistry # noqa # import needed for backwards compat
+from ._services.types import ZeroconfServiceTypes
+from ._updates import RecordUpdate, RecordUpdateListener # noqa # import needed for backwards compat
+from ._utils.name import service_type_name # noqa # import needed for backwards compat
+from ._utils.net import ( # noqa # import needed for backwards compat
+ add_multicast_member,
+ autodetect_ip_version,
+ create_sockets,
+ get_all_addresses_v6,
+ InterfaceChoice,
+ InterfacesType,
+ IPVersion,
+ get_all_addresses,
+)
+from ._utils.time import current_time_millis, millis_to_seconds # noqa # import needed for backwards compat
__author__ = 'Paul Scott-Murphy, William McBrine'
__maintainer__ = 'Jakub Stasiak '
-__version__ = '0.24.4'
+__version__ = '0.36.7'
__license__ = 'LGPL'
__all__ = [
"__version__",
+ "DNSQuestionType",
"Zeroconf",
"ServiceInfo",
"ServiceBrowser",
+ "ServiceListener",
"Error",
"InterfaceChoice",
"ServiceStateChange",
"IPVersion",
+ "ZeroconfServiceTypes",
]
-if sys.version_info <= (3, 3):
- raise ImportError(
+if sys.version_info <= (3, 6): # pragma: no cover
+ raise ImportError( # pragma: no cover
'''
-Python version > 3.3 required for python-zeroconf.
-If you need support for Python 2 or Python 3.3 please use version 19.1
+Python version > 3.6 required for python-zeroconf.
+If you need support for Python 2 or Python 3.3-3.4 please use version 19.1
+If you need support for Python 3.5 please use version 0.28.0
'''
)
-
-log = logging.getLogger(__name__)
-log.addHandler(logging.NullHandler())
-
-if log.level == logging.NOTSET:
- log.setLevel(logging.WARN)
-
-# Some timing constants
-
-_UNREGISTER_TIME = 125 # ms
-_CHECK_TIME = 175 # ms
-_REGISTER_TIME = 225 # ms
-_LISTENER_TIME = 200 # ms
-_BROWSER_TIME = 1000 # ms
-_BROWSER_BACKOFF_LIMIT = 3600 # s
-
-# Some DNS constants
-
-_MDNS_ADDR = '224.0.0.251'
-_MDNS_ADDR_BYTES = socket.inet_aton(_MDNS_ADDR)
-_MDNS_ADDR6 = 'ff02::fb'
-_MDNS_ADDR6_BYTES = socket.inet_pton(socket.AF_INET6, _MDNS_ADDR6)
-_MDNS_PORT = 5353
-_DNS_PORT = 53
-_DNS_HOST_TTL = 120 # two minute for host records (A, SRV etc) as-per RFC6762
-_DNS_OTHER_TTL = 4500 # 75 minutes for non-host records (PTR, TXT etc) as-per RFC6762
-
-_MAX_MSG_TYPICAL = 1460 # unused
-_MAX_MSG_ABSOLUTE = 8966
-
-_FLAGS_QR_MASK = 0x8000 # query response mask
-_FLAGS_QR_QUERY = 0x0000 # query
-_FLAGS_QR_RESPONSE = 0x8000 # response
-
-_FLAGS_AA = 0x0400 # Authoritative answer
-_FLAGS_TC = 0x0200 # Truncated
-_FLAGS_RD = 0x0100 # Recursion desired
-_FLAGS_RA = 0x8000 # Recursion available
-
-_FLAGS_Z = 0x0040 # Zero
-_FLAGS_AD = 0x0020 # Authentic data
-_FLAGS_CD = 0x0010 # Checking disabled
-
-_CLASS_IN = 1
-_CLASS_CS = 2
-_CLASS_CH = 3
-_CLASS_HS = 4
-_CLASS_NONE = 254
-_CLASS_ANY = 255
-_CLASS_MASK = 0x7FFF
-_CLASS_UNIQUE = 0x8000
-
-_TYPE_A = 1
-_TYPE_NS = 2
-_TYPE_MD = 3
-_TYPE_MF = 4
-_TYPE_CNAME = 5
-_TYPE_SOA = 6
-_TYPE_MB = 7
-_TYPE_MG = 8
-_TYPE_MR = 9
-_TYPE_NULL = 10
-_TYPE_WKS = 11
-_TYPE_PTR = 12
-_TYPE_HINFO = 13
-_TYPE_MINFO = 14
-_TYPE_MX = 15
-_TYPE_TXT = 16
-_TYPE_AAAA = 28
-_TYPE_SRV = 33
-_TYPE_ANY = 255
-
-# Mapping constants to names
-
-_CLASSES = {
- _CLASS_IN: "in",
- _CLASS_CS: "cs",
- _CLASS_CH: "ch",
- _CLASS_HS: "hs",
- _CLASS_NONE: "none",
- _CLASS_ANY: "any",
-}
-
-_TYPES = {
- _TYPE_A: "a",
- _TYPE_NS: "ns",
- _TYPE_MD: "md",
- _TYPE_MF: "mf",
- _TYPE_CNAME: "cname",
- _TYPE_SOA: "soa",
- _TYPE_MB: "mb",
- _TYPE_MG: "mg",
- _TYPE_MR: "mr",
- _TYPE_NULL: "null",
- _TYPE_WKS: "wks",
- _TYPE_PTR: "ptr",
- _TYPE_HINFO: "hinfo",
- _TYPE_MINFO: "minfo",
- _TYPE_MX: "mx",
- _TYPE_TXT: "txt",
- _TYPE_AAAA: "quada",
- _TYPE_SRV: "srv",
- _TYPE_ANY: "any",
-}
-
-_HAS_A_TO_Z = re.compile(r'[A-Za-z]')
-_HAS_ONLY_A_TO_Z_NUM_HYPHEN = re.compile(r'^[A-Za-z0-9\-]+$')
-_HAS_ONLY_A_TO_Z_NUM_HYPHEN_UNDERSCORE = re.compile(r'^[A-Za-z0-9\-\_]+$')
-_HAS_ASCII_CONTROL_CHARS = re.compile(r'[\x00-\x1f\x7f]')
-
-try:
- _IPPROTO_IPV6 = socket.IPPROTO_IPV6
-except AttributeError:
- # Sigh: https://bugs.python.org/issue29515
- _IPPROTO_IPV6 = 41
-
-int2byte = struct.Struct(">B").pack
-
-
-@enum.unique
-class InterfaceChoice(enum.Enum):
- Default = 1
- All = 2
-
-
-InterfacesType = Union[List[Union[str, int]], InterfaceChoice]
-
-
-@enum.unique
-class ServiceStateChange(enum.Enum):
- Added = 1
- Removed = 2
- Updated = 3
-
-
-@enum.unique
-class IPVersion(enum.Enum):
- V4Only = 1
- V6Only = 2
- All = 3
-
-
-# utility functions
-
-
-def current_time_millis() -> float:
- """Current system time in milliseconds"""
- return time.time() * 1000
-
-
-def _is_v6_address(addr: bytes) -> bool:
- return len(addr) == 16
-
-
-def service_type_name(type_: str, *, allow_underscores: bool = False) -> str:
- """
- Validate a fully qualified service name, instance or subtype. [rfc6763]
-
- Returns fully qualified service name.
-
- Domain names used by mDNS-SD take the following forms:
-
- . <_tcp|_udp> . local.
- . . <_tcp|_udp> . local.
- ._sub . . <_tcp|_udp> . local.
-
- 1) must end with 'local.'
-
- This is true because we are implementing mDNS and since the 'm' means
- multi-cast, the 'local.' domain is mandatory.
-
- 2) local is preceded with either '_udp.' or '_tcp.'
-
- 3) service name precedes <_tcp|_udp>
-
- The rules for Service Names [RFC6335] state that they may be no more
- than fifteen characters long (not counting the mandatory underscore),
- consisting of only letters, digits, and hyphens, must begin and end
- with a letter or digit, must not contain consecutive hyphens, and
- must contain at least one letter.
-
- The instance name and sub type may be up to 63 bytes.
-
- The portion of the Service Instance Name is a user-
- friendly name consisting of arbitrary Net-Unicode text [RFC5198]. It
- MUST NOT contain ASCII control characters (byte values 0x00-0x1F and
- 0x7F) [RFC20] but otherwise is allowed to contain any characters,
- without restriction, including spaces, uppercase, lowercase,
- punctuation -- including dots -- accented characters, non-Roman text,
- and anything else that may be represented using Net-Unicode.
-
- :param type_: Type, SubType or service name to validate
- :return: fully qualified service name (eg: _http._tcp.local.)
- """
- if not (type_.endswith('._tcp.local.') or type_.endswith('._udp.local.')):
- raise BadTypeInNameException("Type '%s' must end with '._tcp.local.' or '._udp.local.'" % type_)
-
- remaining = type_[: -len('._tcp.local.')].split('.')
- name = remaining.pop()
- if not name:
- raise BadTypeInNameException("No Service name found")
-
- if len(remaining) == 1 and len(remaining[0]) == 0:
- raise BadTypeInNameException("Type '%s' must not start with '.'" % type_)
-
- if name[0] != '_':
- raise BadTypeInNameException("Service name (%s) must start with '_'" % name)
-
- # remove leading underscore
- name = name[1:]
-
- if len(name) > 15:
- raise BadTypeInNameException("Service name (%s) must be <= 15 bytes" % name)
-
- if '--' in name:
- raise BadTypeInNameException("Service name (%s) must not contain '--'" % name)
-
- if '-' in (name[0], name[-1]):
- raise BadTypeInNameException("Service name (%s) may not start or end with '-'" % name)
-
- if not _HAS_A_TO_Z.search(name):
- raise BadTypeInNameException("Service name (%s) must contain at least one letter (eg: 'A-Z')" % name)
-
- allowed_characters_re = (
- _HAS_ONLY_A_TO_Z_NUM_HYPHEN_UNDERSCORE if allow_underscores else _HAS_ONLY_A_TO_Z_NUM_HYPHEN
- )
-
- if not allowed_characters_re.search(name):
- raise BadTypeInNameException(
- "Service name (%s) must contain only these characters: "
- "A-Z, a-z, 0-9, hyphen ('-')%s" % (name, ", underscore ('_')" if allow_underscores else "")
- )
-
- if remaining and remaining[-1] == '_sub':
- remaining.pop()
- if len(remaining) == 0 or len(remaining[0]) == 0:
- raise BadTypeInNameException("_sub requires a subtype name")
-
- if len(remaining) > 1:
- remaining = ['.'.join(remaining)]
-
- if remaining:
- length = len(remaining[0].encode('utf-8'))
- if length > 63:
- raise BadTypeInNameException("Too long: '%s'" % remaining[0])
-
- if _HAS_ASCII_CONTROL_CHARS.search(remaining[0]):
- raise BadTypeInNameException(
- "Ascii control character 0x00-0x1F and 0x7F illegal in '%s'" % remaining[0]
- )
-
- return '_' + name + type_[-len('._tcp.local.') :]
-
-
-# Exceptions
-
-
-class Error(Exception):
- pass
-
-
-class IncomingDecodeError(Error):
- pass
-
-
-class NonUniqueNameException(Error):
- pass
-
-
-class NamePartTooLongException(Error):
- pass
-
-
-class AbstractMethodException(Error):
- pass
-
-
-class BadTypeInNameException(Error):
- pass
-
-
-# implementation classes
-
-
-class QuietLogger:
- _seen_logs = {} # type: Dict[str, Union[int, tuple]]
-
- @classmethod
- def log_exception_warning(cls, logger_data: Optional[Tuple] = None) -> None:
- exc_info = sys.exc_info()
- exc_str = str(exc_info[1])
- if exc_str not in cls._seen_logs:
- # log at warning level the first time this is seen
- cls._seen_logs[exc_str] = exc_info
- logger = log.warning
- else:
- logger = log.debug
- if logger_data is not None:
- logger(*logger_data)
- logger('Exception occurred:', exc_info=True)
-
- @classmethod
- def log_warning_once(cls, *args: Any) -> None:
- msg_str = args[0]
- if msg_str not in cls._seen_logs:
- cls._seen_logs[msg_str] = 0
- logger = log.warning
- else:
- logger = log.debug
- cls._seen_logs[msg_str] = cast(int, cls._seen_logs[msg_str]) + 1
- logger(*args)
-
-
-class DNSEntry:
-
- """A DNS entry"""
-
- def __init__(self, name: str, type_: int, class_: int) -> None:
- self.key = name.lower()
- self.name = name
- self.type = type_
- self.class_ = class_ & _CLASS_MASK
- self.unique = (class_ & _CLASS_UNIQUE) != 0
-
- def __eq__(self, other: Any) -> bool:
- """Equality test on name, type, and class"""
- return (
- self.name == other.name
- and self.type == other.type
- and self.class_ == other.class_
- and isinstance(other, DNSEntry)
- )
-
- def __ne__(self, other: Any) -> bool:
- """Non-equality test"""
- return not self.__eq__(other)
-
- @staticmethod
- def get_class_(class_: int) -> str:
- """Class accessor"""
- return _CLASSES.get(class_, "?(%s)" % class_)
-
- @staticmethod
- def get_type(t: int) -> str:
- """Type accessor"""
- return _TYPES.get(t, "?(%s)" % t)
-
- def entry_to_string(self, hdr: str, other: Optional[Union[bytes, str]]) -> str:
- """String representation with additional information"""
- result = "%s[%s,%s" % (hdr, self.get_type(self.type), self.get_class_(self.class_))
- if self.unique:
- result += "-unique,"
- else:
- result += ","
- result += self.name
- if other is not None:
- result += "]=%s" % cast(Any, other)
- else:
- result += "]"
- return result
-
-
-class DNSQuestion(DNSEntry):
-
- """A DNS question entry"""
-
- def __init__(self, name: str, type_: int, class_: int) -> None:
- DNSEntry.__init__(self, name, type_, class_)
-
- def answered_by(self, rec: 'DNSRecord') -> bool:
- """Returns true if the question is answered by the record"""
- return (
- self.class_ == rec.class_
- and (self.type == rec.type or self.type == _TYPE_ANY)
- and self.name == rec.name
- )
-
- def __repr__(self) -> str:
- """String representation"""
- return DNSEntry.entry_to_string(self, "question", None)
-
-
-class DNSRecord(DNSEntry):
-
- """A DNS record - like a DNS entry, but has a TTL"""
-
- # TODO: Switch to just int ttl
- def __init__(self, name: str, type_: int, class_: int, ttl: Union[float, int]) -> None:
- DNSEntry.__init__(self, name, type_, class_)
- self.ttl = ttl
- self.created = current_time_millis()
- self._expiration_time = self.get_expiration_time(100)
- self._stale_time = self.get_expiration_time(50)
-
- def __eq__(self, other: Any) -> bool:
- """Abstract method"""
- raise AbstractMethodException
-
- def __ne__(self, other: Any) -> bool:
- """Non-equality test"""
- return not self.__eq__(other)
-
- def suppressed_by(self, msg: 'DNSIncoming') -> bool:
- """Returns true if any answer in a message can suffice for the
- information held in this record."""
- for record in msg.answers:
- if self.suppressed_by_answer(record):
- return True
- return False
-
- def suppressed_by_answer(self, other: 'DNSRecord') -> bool:
- """Returns true if another record has same name, type and class,
- and if its TTL is at least half of this record's."""
- return self == other and other.ttl > (self.ttl / 2)
-
- def get_expiration_time(self, percent: int) -> float:
- """Returns the time at which this record will have expired
- by a certain percentage."""
- return self.created + (percent * self.ttl * 10)
-
- # TODO: Switch to just int here
- def get_remaining_ttl(self, now: float) -> Union[int, float]:
- """Returns the remaining TTL in seconds."""
- return max(0, (self._expiration_time - now) / 1000.0)
-
- def is_expired(self, now: float) -> bool:
- """Returns true if this record has expired."""
- return self._expiration_time <= now
-
- def is_stale(self, now: float) -> bool:
- """Returns true if this record is at least half way expired."""
- return self._stale_time <= now
-
- def reset_ttl(self, other: 'DNSRecord') -> None:
- """Sets this record's TTL and created time to that of
- another record."""
- self.created = other.created
- self.ttl = other.ttl
- self._expiration_time = self.get_expiration_time(100)
- self._stale_time = self.get_expiration_time(50)
-
- def write(self, out: 'DNSOutgoing') -> None:
- """Abstract method"""
- raise AbstractMethodException
-
- def to_string(self, other: Union[bytes, str]) -> str:
- """String representation with additional information"""
- arg = "%s/%s,%s" % (self.ttl, int(self.get_remaining_ttl(current_time_millis())), cast(Any, other))
- return DNSEntry.entry_to_string(self, "record", arg)
-
-
-class DNSAddress(DNSRecord):
-
- """A DNS address record"""
-
- def __init__(self, name: str, type_: int, class_: int, ttl: int, address: bytes) -> None:
- DNSRecord.__init__(self, name, type_, class_, ttl)
- self.address = address
-
- def write(self, out: 'DNSOutgoing') -> None:
- """Used in constructing an outgoing packet"""
- out.write_string(self.address)
-
- def __eq__(self, other: Any) -> bool:
- """Tests equality on address"""
- return (
- isinstance(other, DNSAddress) and DNSEntry.__eq__(self, other) and self.address == other.address
- )
-
- def __ne__(self, other: Any) -> bool:
- """Non-equality test"""
- return not self.__eq__(other)
-
- def __repr__(self) -> str:
- """String representation"""
- try:
- return self.to_string(str(socket.inet_ntoa(self.address)))
- except Exception: # TODO stop catching all Exceptions
- return self.to_string(str(self.address))
-
-
-class DNSHinfo(DNSRecord):
-
- """A DNS host information record"""
-
- def __init__(
- self, name: str, type_: int, class_: int, ttl: int, cpu: Union[bytes, str], os: Union[bytes, str]
- ) -> None:
- DNSRecord.__init__(self, name, type_, class_, ttl)
- try:
- self.cpu = cast(bytes, cpu).decode('utf-8')
- except AttributeError:
- self.cpu = cast(str, cpu)
- try:
- self.os = cast(bytes, os).decode('utf-8')
- except AttributeError:
- self.os = cast(str, os)
-
- def write(self, out: 'DNSOutgoing') -> None:
- """Used in constructing an outgoing packet"""
- out.write_character_string(self.cpu.encode('utf-8'))
- out.write_character_string(self.os.encode('utf-8'))
-
- def __eq__(self, other: Any) -> bool:
- """Tests equality on cpu and os"""
- return (
- isinstance(other, DNSHinfo)
- and DNSEntry.__eq__(self, other)
- and self.cpu == other.cpu
- and self.os == other.os
- )
-
- def __ne__(self, other: Any) -> bool:
- """Non-equality test"""
- return not self.__eq__(other)
-
- def __repr__(self) -> str:
- """String representation"""
- return self.to_string(self.cpu + " " + self.os)
-
-
-class DNSPointer(DNSRecord):
-
- """A DNS pointer record"""
-
- def __init__(self, name: str, type_: int, class_: int, ttl: int, alias: str) -> None:
- DNSRecord.__init__(self, name, type_, class_, ttl)
- self.alias = alias
-
- def write(self, out: 'DNSOutgoing') -> None:
- """Used in constructing an outgoing packet"""
- out.write_name(self.alias)
-
- def __eq__(self, other: Any) -> bool:
- """Tests equality on alias"""
- return isinstance(other, DNSPointer) and self.alias == other.alias and DNSEntry.__eq__(self, other)
-
- def __ne__(self, other: Any) -> bool:
- """Non-equality test"""
- return not self.__eq__(other)
-
- def __repr__(self) -> str:
- """String representation"""
- return self.to_string(self.alias)
-
-
-class DNSText(DNSRecord):
-
- """A DNS text record"""
-
- def __init__(self, name: str, type_: int, class_: int, ttl: int, text: bytes) -> None:
- assert isinstance(text, (bytes, type(None)))
- DNSRecord.__init__(self, name, type_, class_, ttl)
- self.text = text
-
- def write(self, out: 'DNSOutgoing') -> None:
- """Used in constructing an outgoing packet"""
- out.write_string(self.text)
-
- def __eq__(self, other: Any) -> bool:
- """Tests equality on text"""
- return isinstance(other, DNSText) and self.text == other.text and DNSEntry.__eq__(self, other)
-
- def __ne__(self, other: Any) -> bool:
- """Non-equality test"""
- return not self.__eq__(other)
-
- def __repr__(self) -> str:
- """String representation"""
- if len(self.text) > 10:
- return self.to_string(self.text[:7]) + "..."
- else:
- return self.to_string(self.text)
-
-
-class DNSService(DNSRecord):
-
- """A DNS service record"""
-
- def __init__(
- self,
- name: str,
- type_: int,
- class_: int,
- ttl: Union[float, int],
- priority: int,
- weight: int,
- port: int,
- server: str,
- ) -> None:
- DNSRecord.__init__(self, name, type_, class_, ttl)
- self.priority = priority
- self.weight = weight
- self.port = port
- self.server = server
-
- def write(self, out: 'DNSOutgoing') -> None:
- """Used in constructing an outgoing packet"""
- out.write_short(self.priority)
- out.write_short(self.weight)
- out.write_short(self.port)
- out.write_name(self.server)
-
- def __eq__(self, other: Any) -> bool:
- """Tests equality on priority, weight, port and server"""
- return (
- isinstance(other, DNSService)
- and self.priority == other.priority
- and self.weight == other.weight
- and self.port == other.port
- and self.server == other.server
- and DNSEntry.__eq__(self, other)
- )
-
- def __ne__(self, other: Any) -> bool:
- """Non-equality test"""
- return not self.__eq__(other)
-
- def __repr__(self) -> str:
- """String representation"""
- return self.to_string("%s:%s" % (self.server, self.port))
-
-
-class DNSIncoming(QuietLogger):
-
- """Object representation of an incoming DNS packet"""
-
- def __init__(self, data: bytes) -> None:
- """Constructor from string holding bytes of packet"""
- self.offset = 0
- self.data = data
- self.questions = [] # type: List[DNSQuestion]
- self.answers = [] # type: List[DNSRecord]
- self.id = 0
- self.flags = 0 # type: int
- self.num_questions = 0
- self.num_answers = 0
- self.num_authorities = 0
- self.num_additionals = 0
- self.valid = False
-
- try:
- self.read_header()
- self.read_questions()
- self.read_others()
- self.valid = True
-
- except (IndexError, struct.error, IncomingDecodeError):
- self.log_exception_warning(('Choked at offset %d while unpacking %r', self.offset, data))
-
- def unpack(self, format_: bytes) -> tuple:
- length = struct.calcsize(format_)
- info = struct.unpack(format_, self.data[self.offset : self.offset + length])
- self.offset += length
- return info
-
- def read_header(self) -> None:
- """Reads header portion of packet"""
- (
- self.id,
- self.flags,
- self.num_questions,
- self.num_answers,
- self.num_authorities,
- self.num_additionals,
- ) = self.unpack(b'!6H')
-
- def read_questions(self) -> None:
- """Reads questions section of packet"""
- for i in range(self.num_questions):
- name = self.read_name()
- type_, class_ = self.unpack(b'!HH')
-
- question = DNSQuestion(name, type_, class_)
- self.questions.append(question)
-
- # def read_int(self):
- # """Reads an integer from the packet"""
- # return self.unpack(b'!I')[0]
-
- def read_character_string(self) -> bytes:
- """Reads a character string from the packet"""
- length = self.data[self.offset]
- self.offset += 1
- return self.read_string(length)
-
- def read_string(self, length: int) -> bytes:
- """Reads a string of a given length from the packet"""
- info = self.data[self.offset : self.offset + length]
- self.offset += length
- return info
-
- def read_unsigned_short(self) -> int:
- """Reads an unsigned short from the packet"""
- return cast(int, self.unpack(b'!H')[0])
-
- def read_others(self) -> None:
- """Reads the answers, authorities and additionals section of the
- packet"""
- n = self.num_answers + self.num_authorities + self.num_additionals
- for i in range(n):
- domain = self.read_name()
- type_, class_, ttl, length = self.unpack(b'!HHiH')
-
- rec = None # type: Optional[DNSRecord]
- if type_ == _TYPE_A:
- rec = DNSAddress(domain, type_, class_, ttl, self.read_string(4))
- elif type_ == _TYPE_CNAME or type_ == _TYPE_PTR:
- rec = DNSPointer(domain, type_, class_, ttl, self.read_name())
- elif type_ == _TYPE_TXT:
- rec = DNSText(domain, type_, class_, ttl, self.read_string(length))
- elif type_ == _TYPE_SRV:
- rec = DNSService(
- domain,
- type_,
- class_,
- ttl,
- self.read_unsigned_short(),
- self.read_unsigned_short(),
- self.read_unsigned_short(),
- self.read_name(),
- )
- elif type_ == _TYPE_HINFO:
- rec = DNSHinfo(
- domain, type_, class_, ttl, self.read_character_string(), self.read_character_string()
- )
- elif type_ == _TYPE_AAAA:
- rec = DNSAddress(domain, type_, class_, ttl, self.read_string(16))
- else:
- # Try to ignore types we don't know about
- # Skip the payload for the resource record so the next
- # records can be parsed correctly
- self.offset += length
-
- if rec is not None:
- self.answers.append(rec)
-
- def is_query(self) -> bool:
- """Returns true if this is a query"""
- return (self.flags & _FLAGS_QR_MASK) == _FLAGS_QR_QUERY
-
- def is_response(self) -> bool:
- """Returns true if this is a response"""
- return (self.flags & _FLAGS_QR_MASK) == _FLAGS_QR_RESPONSE
-
- def read_utf(self, offset: int, length: int) -> str:
- """Reads a UTF-8 string of a given length from the packet"""
- return str(self.data[offset : offset + length], 'utf-8', 'replace')
-
- def read_name(self) -> str:
- """Reads a domain name from the packet"""
- result = ''
- off = self.offset
- next_ = -1
- first = off
-
- while True:
- length = self.data[off]
- off += 1
- if length == 0:
- break
- t = length & 0xC0
- if t == 0x00:
- result = ''.join((result, self.read_utf(off, length) + '.'))
- off += length
- elif t == 0xC0:
- if next_ < 0:
- next_ = off + 1
- off = ((length & 0x3F) << 8) | self.data[off]
- if off >= first:
- raise IncomingDecodeError("Bad domain name (circular) at %s" % (off,))
- first = off
- else:
- raise IncomingDecodeError("Bad domain name at %s" % (off,))
-
- if next_ >= 0:
- self.offset = next_
- else:
- self.offset = off
-
- return result
-
-
-class DNSOutgoing:
-
- """Object representation of an outgoing packet"""
-
- def __init__(self, flags: int, multicast: bool = True) -> None:
- self.finished = False
- self.id = 0
- self.multicast = multicast
- self.flags = flags
- self.names = {} # type: Dict[str, int]
- self.data = [] # type: List[bytes]
- self.size = 12
- self.state = self.State.init
-
- self.questions = [] # type: List[DNSQuestion]
- self.answers = [] # type: List[Tuple[DNSRecord, float]]
- self.authorities = [] # type: List[DNSPointer]
- self.additionals = [] # type: List[DNSRecord]
-
- def __repr__(self) -> str:
- return '' % ', '.join(
- [
- 'multicast=%s' % self.multicast,
- 'flags=%s' % self.flags,
- 'questions=%s' % self.questions,
- 'answers=%s' % self.answers,
- 'authorities=%s' % self.authorities,
- 'additionals=%s' % self.additionals,
- ]
- )
-
- class State(enum.Enum):
- init = 0
- finished = 1
-
- def add_question(self, record: DNSQuestion) -> None:
- """Adds a question"""
- self.questions.append(record)
-
- def add_answer(self, inp: DNSIncoming, record: DNSRecord) -> None:
- """Adds an answer"""
- if not record.suppressed_by(inp):
- self.add_answer_at_time(record, 0)
-
- def add_answer_at_time(self, record: Optional[DNSRecord], now: Union[float, int]) -> None:
- """Adds an answer if it does not expire by a certain time"""
- if record is not None:
- if now == 0 or not record.is_expired(now):
- self.answers.append((record, now))
-
- def add_authorative_answer(self, record: DNSPointer) -> None:
- """Adds an authoritative answer"""
- self.authorities.append(record)
-
- def add_additional_answer(self, record: DNSRecord) -> None:
- """ Adds an additional answer
-
- From: RFC 6763, DNS-Based Service Discovery, February 2013
-
- 12. DNS Additional Record Generation
-
- DNS has an efficiency feature whereby a DNS server may place
- additional records in the additional section of the DNS message.
- These additional records are records that the client did not
- explicitly request, but the server has reasonable grounds to expect
- that the client might request them shortly, so including them can
- save the client from having to issue additional queries.
-
- This section recommends which additional records SHOULD be generated
- to improve network efficiency, for both Unicast and Multicast DNS-SD
- responses.
-
- 12.1. PTR Records
-
- When including a DNS-SD Service Instance Enumeration or Selective
- Instance Enumeration (subtype) PTR record in a response packet, the
- server/responder SHOULD include the following additional records:
-
- o The SRV record(s) named in the PTR rdata.
- o The TXT record(s) named in the PTR rdata.
- o All address records (type "A" and "AAAA") named in the SRV rdata.
-
- 12.2. SRV Records
-
- When including an SRV record in a response packet, the
- server/responder SHOULD include the following additional records:
-
- o All address records (type "A" and "AAAA") named in the SRV rdata.
-
- """
- self.additionals.append(record)
-
- def pack(self, format_: Union[bytes, str], value: Any) -> None:
- self.data.append(struct.pack(format_, value))
- self.size += struct.calcsize(format_)
-
- def write_byte(self, value: int) -> None:
- """Writes a single byte to the packet"""
- self.pack(b'!c', int2byte(value))
-
- def insert_short(self, index: int, value: int) -> None:
- """Inserts an unsigned short in a certain position in the packet"""
- self.data.insert(index, struct.pack(b'!H', value))
- self.size += 2
-
- def write_short(self, value: int) -> None:
- """Writes an unsigned short to the packet"""
- self.pack(b'!H', value)
-
- def write_int(self, value: Union[float, int]) -> None:
- """Writes an unsigned integer to the packet"""
- self.pack(b'!I', int(value))
-
- def write_string(self, value: bytes) -> None:
- """Writes a string to the packet"""
- assert isinstance(value, bytes)
- self.data.append(value)
- self.size += len(value)
-
- def write_utf(self, s: str) -> None:
- """Writes a UTF-8 string of a given length to the packet"""
- utfstr = s.encode('utf-8')
- length = len(utfstr)
- if length > 64:
- raise NamePartTooLongException
- self.write_byte(length)
- self.write_string(utfstr)
-
- def write_character_string(self, value: bytes) -> None:
- assert isinstance(value, bytes)
- length = len(value)
- if length > 256:
- raise NamePartTooLongException
- self.write_byte(length)
- self.write_string(value)
-
- def write_name(self, name: str) -> None:
- """
- Write names to packet
-
- 18.14. Name Compression
-
- When generating Multicast DNS messages, implementations SHOULD use
- name compression wherever possible to compress the names of resource
- records, by replacing some or all of the resource record name with a
- compact two-byte reference to an appearance of that data somewhere
- earlier in the message [RFC1035].
- """
-
- # split name into each label
- parts = name.split('.')
- if not parts[-1]:
- parts.pop()
-
- # construct each suffix
- name_suffices = ['.'.join(parts[i:]) for i in range(len(parts))]
-
- # look for an existing name or suffix
- for count, sub_name in enumerate(name_suffices):
- if sub_name in self.names:
- break
- else:
- count = len(name_suffices)
-
- # note the new names we are saving into the packet
- name_length = len(name.encode('utf-8'))
- for suffix in name_suffices[:count]:
- self.names[suffix] = self.size + name_length - len(suffix.encode('utf-8')) - 1
-
- # write the new names out.
- for part in parts[:count]:
- self.write_utf(part)
-
- # if we wrote part of the name, create a pointer to the rest
- if count != len(name_suffices):
- # Found substring in packet, create pointer
- index = self.names[name_suffices[count]]
- self.write_byte((index >> 8) | 0xC0)
- self.write_byte(index & 0xFF)
- else:
- # this is the end of a name
- self.write_byte(0)
-
- def write_question(self, question: DNSQuestion) -> None:
- """Writes a question to the packet"""
- self.write_name(question.name)
- self.write_short(question.type)
- self.write_short(question.class_)
-
- def write_record(self, record: DNSRecord, now: float) -> int:
- """Writes a record (answer, authoritative answer, additional) to
- the packet"""
- if self.state == self.State.finished:
- return 1
-
- start_data_length, start_size = len(self.data), self.size
- self.write_name(record.name)
- self.write_short(record.type)
- if record.unique and self.multicast:
- self.write_short(record.class_ | _CLASS_UNIQUE)
- else:
- self.write_short(record.class_)
- if now == 0:
- self.write_int(record.ttl)
- else:
- self.write_int(record.get_remaining_ttl(now))
- index = len(self.data)
-
- # Adjust size for the short we will write before this record
- self.size += 2
- record.write(self)
- self.size -= 2
-
- length = sum((len(d) for d in self.data[index:]))
- # Here is the short we adjusted for
- self.insert_short(index, length)
-
- # if we go over, then rollback and quit
- if self.size > _MAX_MSG_ABSOLUTE:
- while len(self.data) > start_data_length:
- self.data.pop()
- self.size = start_size
- self.state = self.State.finished
- return 1
- return 0
-
- def packet(self) -> bytes:
- """Returns a string containing the packet's bytes
-
- No further parts should be added to the packet once this
- is done."""
-
- overrun_answers, overrun_authorities, overrun_additionals = 0, 0, 0
-
- if self.state != self.State.finished:
- for question in self.questions:
- self.write_question(question)
- for answer, time_ in self.answers:
- overrun_answers += self.write_record(answer, time_)
- for authority in self.authorities:
- overrun_authorities += self.write_record(authority, 0)
- for additional in self.additionals:
- overrun_additionals += self.write_record(additional, 0)
- self.state = self.State.finished
-
- self.insert_short(0, len(self.additionals) - overrun_additionals)
- self.insert_short(0, len(self.authorities) - overrun_authorities)
- self.insert_short(0, len(self.answers) - overrun_answers)
- self.insert_short(0, len(self.questions))
- self.insert_short(0, self.flags)
- if self.multicast:
- self.insert_short(0, 0)
- else:
- self.insert_short(0, self.id)
- return b''.join(self.data)
-
-
-class DNSCache:
-
- """A cache of DNS entries"""
-
- def __init__(self) -> None:
- self.cache = {} # type: Dict[str, List[DNSRecord]]
-
- def add(self, entry: DNSRecord) -> None:
- """Adds an entry"""
- # Insert first in list so get returns newest entry
- self.cache.setdefault(entry.key, []).insert(0, entry)
-
- def remove(self, entry: DNSRecord) -> None:
- """Removes an entry"""
- try:
- list_ = self.cache[entry.key]
- list_.remove(entry)
- except (KeyError, ValueError):
- pass
-
- def get(self, entry: DNSEntry) -> Optional[DNSRecord]:
- """Gets an entry by key. Will return None if there is no
- matching entry."""
- try:
- list_ = self.cache[entry.key]
- for cached_entry in list_:
- if entry.__eq__(cached_entry):
- return cached_entry
- return None
- except (KeyError, ValueError):
- return None
-
- def get_by_details(self, name: str, type_: int, class_: int) -> Optional[DNSRecord]:
- """Gets an entry by details. Will return None if there is
- no matching entry."""
- entry = DNSEntry(name, type_, class_)
- return self.get(entry)
-
- def entries_with_name(self, name: str) -> List[DNSRecord]:
- """Returns a list of entries whose key matches the name."""
- try:
- return self.cache[name.lower()]
- except KeyError:
- return []
-
- def current_entry_with_name_and_alias(self, name: str, alias: str) -> Optional[DNSRecord]:
- now = current_time_millis()
- for record in self.entries_with_name(name):
- if (
- record.type == _TYPE_PTR
- and not record.is_expired(now)
- and cast(DNSPointer, record).alias == alias
- ):
- return record
- return None
-
- def entries(self) -> List[DNSRecord]:
- """Returns a list of all entries"""
- if not self.cache:
- return []
- else:
- # avoid size change during iteration by copying the cache
- values = list(self.cache.values())
- return list(itertools.chain.from_iterable(values))
-
-
-class Engine(threading.Thread):
-
- """An engine wraps read access to sockets, allowing objects that
- need to receive data from sockets to be called back when the
- sockets are ready.
-
- A reader needs a handle_read() method, which is called when the socket
- it is interested in is ready for reading.
-
- Writers are not implemented here, because we only send short
- packets.
- """
-
- def __init__(self, zc: 'Zeroconf') -> None:
- threading.Thread.__init__(self, name='zeroconf-Engine')
- self.daemon = True
- self.zc = zc
- self.readers = {} # type: Dict[socket.socket, Listener]
- self.timeout = 5
- self.condition = threading.Condition()
- self.start()
-
- def run(self) -> None:
- while not self.zc.done:
- with self.condition:
- rs = self.readers.keys()
- if len(rs) == 0:
- # No sockets to manage, but we wait for the timeout
- # or addition of a socket
- self.condition.wait(self.timeout)
-
- if len(rs) != 0:
- try:
- rr, wr, er = select.select(cast(Sequence[Any], rs), [], [], self.timeout)
- if not self.zc.done:
- for socket_ in rr:
- reader = self.readers.get(socket_)
- if reader:
- reader.handle_read(socket_)
-
- except (select.error, socket.error) as e:
- # If the socket was closed by another thread, during
- # shutdown, ignore it and exit
- if e.args[0] not in (errno.EBADF, errno.ENOTCONN) or not self.zc.done:
- raise
-
- def add_reader(self, reader: 'Listener', socket_: socket.socket) -> None:
- with self.condition:
- self.readers[socket_] = reader
- self.condition.notify()
-
- def del_reader(self, socket_: socket.socket) -> None:
- with self.condition:
- del self.readers[socket_]
- self.condition.notify()
-
-
-class Listener(QuietLogger):
-
- """A Listener is used by this module to listen on the multicast
- group to which DNS messages are sent, allowing the implementation
- to cache information as it arrives.
-
- It requires registration with an Engine object in order to have
- the read() method called when a socket is available for reading."""
-
- def __init__(self, zc: 'Zeroconf') -> None:
- self.zc = zc
- self.data = None # type: Optional[bytes]
-
- def handle_read(self, socket_: socket.socket) -> None:
- try:
- data, (addr, port, *_v6) = socket_.recvfrom(_MAX_MSG_ABSOLUTE)
- except Exception:
- self.log_exception_warning()
- return
-
- log.debug('Received from %r:%r: %r ', addr, port, data)
-
- self.data = data
- msg = DNSIncoming(data)
- if not msg.valid:
- pass
-
- elif msg.is_query():
- # Always multicast responses
- if port == _MDNS_PORT:
- self.zc.handle_query(msg, None, _MDNS_PORT)
-
- # If it's not a multicast query, reply via unicast
- # and multicast
- elif port == _DNS_PORT:
- self.zc.handle_query(msg, addr, port)
- self.zc.handle_query(msg, None, _MDNS_PORT)
-
- else:
- self.zc.handle_response(msg)
-
-
-class Reaper(threading.Thread):
-
- """A Reaper is used by this module to remove cache entries that
- have expired."""
-
- def __init__(self, zc: 'Zeroconf') -> None:
- threading.Thread.__init__(self, name='zeroconf-Reaper')
- self.daemon = True
- self.zc = zc
- self.start()
-
- def run(self) -> None:
- while True:
- self.zc.wait(10 * 1000)
- if self.zc.done:
- return
- now = current_time_millis()
- for record in self.zc.cache.entries():
- if record.is_expired(now):
- self.zc.update_record(now, record)
- self.zc.cache.remove(record)
-
-
-class Signal:
- def __init__(self) -> None:
- self._handlers = [] # type: List[Callable[..., None]]
-
- def fire(self, **kwargs: Any) -> None:
- for h in list(self._handlers):
- h(**kwargs)
-
- @property
- def registration_interface(self) -> 'SignalRegistrationInterface':
- return SignalRegistrationInterface(self._handlers)
-
-
-# NOTE: Callable quoting needed on Python 3.5.2, see
-# https://github.com/jstasiak/python-zeroconf/issues/208 for details.
-class SignalRegistrationInterface:
- def __init__(self, handlers: List['Callable[..., None]']) -> None:
- self._handlers = handlers
-
- def register_handler(self, handler: 'Callable[..., None]') -> 'SignalRegistrationInterface':
- self._handlers.append(handler)
- return self
-
- def unregister_handler(self, handler: 'Callable[..., None]') -> 'SignalRegistrationInterface':
- self._handlers.remove(handler)
- return self
-
-
-class RecordUpdateListener:
- def update_record(self, zc: 'Zeroconf', now: float, record: DNSRecord) -> None:
- raise NotImplementedError()
-
-
-class ServiceListener:
- def add_service(self, zc: 'Zeroconf', type_: str, name: str) -> None:
- raise NotImplementedError()
-
- def remove_service(self, zc: 'Zeroconf', type_: str, name: str) -> None:
- raise NotImplementedError()
-
- def update_service(self, zc: 'Zeroconf', type_: str, name: str) -> None:
- raise NotImplementedError()
-
-
-class ServiceBrowser(RecordUpdateListener, threading.Thread):
-
- """Used to browse for a service of a specific type.
-
- The listener object will have its add_service() and
- remove_service() methods called when this browser
- discovers changes in the services availability."""
-
- def __init__(
- self,
- zc: 'Zeroconf',
- type_: str,
- # NOTE: Callable quoting needed on Python 3.5.2, see
- # https://github.com/jstasiak/python-zeroconf/issues/208 for details.
- handlers: Optional[Union[ServiceListener, List['Callable[..., None]']]] = None,
- listener: Optional[ServiceListener] = None,
- addr: Optional[str] = None,
- port: int = _MDNS_PORT,
- delay: int = _BROWSER_TIME,
- ) -> None:
- """Creates a browser for a specific type"""
- assert handlers or listener, 'You need to specify at least one handler'
- if not type_.endswith(service_type_name(type_, allow_underscores=True)):
- raise BadTypeInNameException
- threading.Thread.__init__(self, name='zeroconf-ServiceBrowser_' + type_)
- self.daemon = True
- self.zc = zc
- self.type = type_
- self.addr = addr
- self.port = port
- self.multicast = self.addr in (None, _MDNS_ADDR, _MDNS_ADDR6)
- self.services = {} # type: Dict[str, DNSRecord]
- self.next_time = current_time_millis()
- self.delay = delay
- self._handlers_to_call = [] # type: List[Callable[[Zeroconf], None]]
-
- self._service_state_changed = Signal()
-
- self.done = False
-
- if hasattr(handlers, 'add_service'):
- listener = cast(ServiceListener, handlers)
- handlers = None
-
- # NOTE: Callable quoting needed on Python 3.5.2, see
- # https://github.com/jstasiak/python-zeroconf/issues/208 for details.
- handlers = cast(List['Callable[..., None]'], handlers or [])
-
- if listener:
-
- def on_change(
- zeroconf: 'Zeroconf', service_type: str, name: str, state_change: ServiceStateChange
- ) -> None:
- assert listener is not None
- args = (zeroconf, service_type, name)
- if state_change is ServiceStateChange.Added:
- listener.add_service(*args)
- elif state_change is ServiceStateChange.Removed:
- listener.remove_service(*args)
- elif state_change is ServiceStateChange.Updated:
- if hasattr(listener, 'update_service'):
- listener.update_service(*args)
- else:
- raise NotImplementedError(state_change)
-
- handlers.append(on_change)
-
- for h in handlers:
- self.service_state_changed.register_handler(h)
-
- self.start()
-
- @property
- def service_state_changed(self) -> SignalRegistrationInterface:
- return self._service_state_changed.registration_interface
-
- def update_record(self, zc: 'Zeroconf', now: float, record: DNSRecord) -> None:
- """Callback invoked by Zeroconf when new information arrives.
-
- Updates information required by browser in the Zeroconf cache."""
-
- def enqueue_callback(state_change: ServiceStateChange, name: str) -> None:
- self._handlers_to_call.append(
- lambda zeroconf: self._service_state_changed.fire(
- zeroconf=zeroconf, service_type=self.type, name=name, state_change=state_change
- )
- )
-
- if record.type == _TYPE_PTR and record.name == self.type:
- assert isinstance(record, DNSPointer)
- expired = record.is_expired(now)
- service_key = record.alias.lower()
- try:
- old_record = self.services[service_key]
- except KeyError:
- if not expired:
- self.services[service_key] = record
- enqueue_callback(ServiceStateChange.Added, record.alias)
- else:
- if not expired:
- old_record.reset_ttl(record)
- else:
- del self.services[service_key]
- enqueue_callback(ServiceStateChange.Removed, record.alias)
- return
-
- expires = record.get_expiration_time(75)
- if expires < self.next_time:
- self.next_time = expires
-
- elif record.type == _TYPE_TXT and record.name.endswith(self.type):
- assert isinstance(record, DNSText)
- expired = record.is_expired(now)
- if not expired:
- enqueue_callback(ServiceStateChange.Updated, record.name)
-
- def cancel(self) -> None:
- self.done = True
- self.zc.remove_listener(self)
- self.join()
-
- def run(self) -> None:
- self.zc.add_listener(self, DNSQuestion(self.type, _TYPE_PTR, _CLASS_IN))
-
- while True:
- now = current_time_millis()
- if len(self._handlers_to_call) == 0 and self.next_time > now:
- self.zc.wait(self.next_time - now)
- if self.zc.done or self.done:
- return
- now = current_time_millis()
- if self.next_time <= now:
- out = DNSOutgoing(_FLAGS_QR_QUERY, multicast=self.multicast)
- out.add_question(DNSQuestion(self.type, _TYPE_PTR, _CLASS_IN))
- for record in self.services.values():
- if not record.is_stale(now):
- out.add_answer_at_time(record, now)
-
- self.zc.send(out, addr=self.addr, port=self.port)
- self.next_time = now + self.delay
- self.delay = min(_BROWSER_BACKOFF_LIMIT * 1000, self.delay * 2)
-
- if len(self._handlers_to_call) > 0 and not self.zc.done:
- handler = self._handlers_to_call.pop(0)
- handler(self.zc)
-
-
-class ServiceInfo(RecordUpdateListener):
- text = b''
-
- """Service information"""
-
- # FIXME(dtantsur): black 19.3b0 produces code that is not valid syntax on
- # Python 3.5: https://github.com/python/black/issues/759
- # fmt: off
- def __init__(
- self,
- type_: str,
- name: str,
- address: Optional[Union[bytes, List[bytes]]] = None,
- port: Optional[int] = None,
- weight: int = 0,
- priority: int = 0,
- properties: Union[bytes, Dict] = b'',
- server: Optional[str] = None,
- host_ttl: int = _DNS_HOST_TTL,
- other_ttl: int = _DNS_OTHER_TTL,
- *,
- addresses: Optional[List[bytes]] = None
- ) -> None:
- """Create a service description.
-
- type_: fully qualified service type name
- name: fully qualified service name
- address: IP address as unsigned short, network byte order (deprecated, use addresses)
- port: port that the service runs on
- weight: weight of the service
- priority: priority of the service
- properties: dictionary of properties (or a string holding the
- bytes for the text field)
- server: fully qualified name for service host (defaults to name)
- host_ttl: ttl used for A/SRV records
- other_ttl: ttl used for PTR/TXT records
- addresses: List of IP addresses as unsigned short (IPv4) or unsigned
- 128 bit number (IPv6), network byte order
- """
-
- # Accept both none, or one, but not both.
- if address is not None and addresses is not None:
- raise TypeError("address and addresses cannot be provided together")
-
- if not type_.endswith(service_type_name(name, allow_underscores=True)):
- raise BadTypeInNameException
- self.type = type_
- self.name = name
- if addresses is not None:
- self._addresses = addresses
- elif address is not None:
- warnings.warn("address is deprecated, use addresses instead", DeprecationWarning)
- if isinstance(address, list):
- self._addresses = address
- else:
- self._addresses = [address]
- else:
- self._addresses = []
- # This results in an ugly error when registering, better check now
- invalid = [a for a in self._addresses
- if not isinstance(a, bytes) or len(a) not in (4, 16)]
- if invalid:
- raise TypeError('Addresses must be bytes, got %s. Hint: convert string addresses '
- 'with socket.inet_pton' % invalid)
- self.port = port
- self.weight = weight
- self.priority = priority
- if server:
- self.server = server
- else:
- self.server = name
- self._properties = {} # type: Dict
- self._set_properties(properties)
- self.host_ttl = host_ttl
- self.other_ttl = other_ttl
- # fmt: on
-
- @property
- def address(self) -> Optional[bytes]:
- warnings.warn("ServiceInfo.address is deprecated, use addresses instead", DeprecationWarning)
- try:
- # Return the first V4 address for compatibility
- return self.addresses[0]
- except IndexError:
- return None
-
- @address.setter
- def address(self, value: bytes) -> None:
- warnings.warn("ServiceInfo.address is deprecated, use addresses instead", DeprecationWarning)
- if value is None:
- self._addresses = []
- else:
- self._addresses = [value]
-
- @property
- def addresses(self) -> List[bytes]:
- """IPv4 addresses of this service.
-
- Only IPv4 addresses are returned for backward compatibility.
- Use :meth:`addresses_by_version` or :meth:`parsed_addresses` to
- include IPv6 addresses as well.
- """
- return self.addresses_by_version(IPVersion.V4Only)
-
- @addresses.setter
- def addresses(self, value: List[bytes]) -> None:
- """Replace the addresses list.
-
- This replaces all currently stored addresses, both IPv4 and IPv6.
- """
- self._addresses = value
-
- @property
- def properties(self) -> Dict:
- return self._properties
-
- def addresses_by_version(self, version: IPVersion) -> List[bytes]:
- """List addresses matching IP version."""
- if version == IPVersion.V4Only:
- return [addr for addr in self._addresses if not _is_v6_address(addr)]
- elif version == IPVersion.V6Only:
- return list(filter(_is_v6_address, self._addresses))
- else:
- return self._addresses
-
- def parsed_addresses(self, version: IPVersion = IPVersion.All) -> List[str]:
- """List addresses in their parsed string form."""
- result = self.addresses_by_version(version)
- return [
- socket.inet_ntop(socket.AF_INET6 if _is_v6_address(addr) else socket.AF_INET, addr)
- for addr in result
- ]
-
- def _set_properties(self, properties: Union[bytes, Dict]) -> None:
- """Sets properties and text of this info from a dictionary"""
- if isinstance(properties, dict):
- self._properties = properties
- list_ = []
- result = b''
- for key, value in properties.items():
- if isinstance(key, str):
- key = key.encode('utf-8')
-
- if value is None:
- suffix = b''
- elif isinstance(value, str):
- suffix = value.encode('utf-8')
- elif isinstance(value, bytes):
- suffix = value
- elif isinstance(value, int):
- if value:
- suffix = b'true'
- else:
- suffix = b'false'
- else:
- suffix = b''
- list_.append(b'='.join((key, suffix)))
- for item in list_:
- result = b''.join((result, int2byte(len(item)), item))
- self.text = result
- else:
- self.text = properties
-
- def _set_text(self, text: bytes) -> None:
- """Sets properties and text given a text field"""
- self.text = text
- result = {} # type: Dict
- end = len(text)
- index = 0
- strs = []
- while index < end:
- length = text[index]
- index += 1
- strs.append(text[index : index + length])
- index += length
-
- for s in strs:
- parts = s.split(b'=', 1)
- try:
- key, value = parts # type: Tuple[bytes, Union[bool, bytes]]
- except ValueError:
- # No equals sign at all
- key = s
- value = False
- else:
- if value == b'true':
- value = True
- elif value == b'false' or not value:
- value = False
-
- # Only update non-existent properties
- if key and result.get(key) is None:
- result[key] = value
-
- self._properties = result
-
- def get_name(self) -> str:
- """Name accessor"""
- if self.type is not None and self.name.endswith("." + self.type):
- return self.name[: len(self.name) - len(self.type) - 1]
- return self.name
-
- def update_record(self, zc: 'Zeroconf', now: float, record: Optional[DNSRecord]) -> None:
- """Updates service information from a DNS record"""
- if record is not None and not record.is_expired(now):
- if record.type in [_TYPE_A, _TYPE_AAAA]:
- assert isinstance(record, DNSAddress)
- # if record.name == self.name:
- if record.name == self.server:
- if record.address not in self._addresses:
- self._addresses.append(record.address)
- elif record.type == _TYPE_SRV:
- assert isinstance(record, DNSService)
- if record.name == self.name:
- self.server = record.server
- self.port = record.port
- self.weight = record.weight
- self.priority = record.priority
- # self.address = None
- self.update_record(zc, now, zc.cache.get_by_details(self.server, _TYPE_A, _CLASS_IN))
- self.update_record(zc, now, zc.cache.get_by_details(self.server, _TYPE_AAAA, _CLASS_IN))
- elif record.type == _TYPE_TXT:
- assert isinstance(record, DNSText)
- if record.name == self.name:
- self._set_text(record.text)
-
- def request(self, zc: 'Zeroconf', timeout: float) -> bool:
- """Returns true if the service could be discovered on the
- network, and updates this object with details discovered.
- """
- now = current_time_millis()
- delay = _LISTENER_TIME
- next_ = now + delay
- last = now + timeout
-
- record_types_for_check_cache = [(_TYPE_SRV, _CLASS_IN), (_TYPE_TXT, _CLASS_IN)]
- if self.server is not None:
- record_types_for_check_cache.append((_TYPE_A, _CLASS_IN))
- record_types_for_check_cache.append((_TYPE_AAAA, _CLASS_IN))
- for record_type in record_types_for_check_cache:
- cached = zc.cache.get_by_details(self.name, *record_type)
- if cached:
- self.update_record(zc, now, cached)
-
- if self.server is not None and self.text is not None and self._addresses:
- return True
-
- try:
- zc.add_listener(self, DNSQuestion(self.name, _TYPE_ANY, _CLASS_IN))
- while self.server is None or self.text is None or not self._addresses:
- if last <= now:
- return False
- if next_ <= now:
- out = DNSOutgoing(_FLAGS_QR_QUERY)
- out.add_question(DNSQuestion(self.name, _TYPE_SRV, _CLASS_IN))
- out.add_answer_at_time(zc.cache.get_by_details(self.name, _TYPE_SRV, _CLASS_IN), now)
-
- out.add_question(DNSQuestion(self.name, _TYPE_TXT, _CLASS_IN))
- out.add_answer_at_time(zc.cache.get_by_details(self.name, _TYPE_TXT, _CLASS_IN), now)
-
- if self.server is not None:
- out.add_question(DNSQuestion(self.server, _TYPE_A, _CLASS_IN))
- out.add_answer_at_time(zc.cache.get_by_details(self.server, _TYPE_A, _CLASS_IN), now)
- out.add_question(DNSQuestion(self.server, _TYPE_AAAA, _CLASS_IN))
- out.add_answer_at_time(
- zc.cache.get_by_details(self.server, _TYPE_AAAA, _CLASS_IN), now
- )
- zc.send(out)
- next_ = now + delay
- delay *= 2
-
- zc.wait(min(next_, last) - now)
- now = current_time_millis()
- finally:
- zc.remove_listener(self)
-
- return True
-
- def __eq__(self, other: object) -> bool:
- """Tests equality of service name"""
- return isinstance(other, ServiceInfo) and other.name == self.name
-
- def __ne__(self, other: object) -> bool:
- """Non-equality test"""
- return not self.__eq__(other)
-
- def __repr__(self) -> str:
- """String representation"""
- return '%s(%s)' % (
- type(self).__name__,
- ', '.join(
- '%s=%r' % (name, getattr(self, name))
- for name in (
- 'type',
- 'name',
- 'addresses',
- 'port',
- 'weight',
- 'priority',
- 'server',
- 'properties',
- )
- ),
- )
-
-
-class ZeroconfServiceTypes(ServiceListener):
- """
- Return all of the advertised services on any local networks
- """
-
- def __init__(self) -> None:
- self.found_services = set() # type: Set[str]
-
- def add_service(self, zc: 'Zeroconf', type_: str, name: str) -> None:
- self.found_services.add(name)
-
- def remove_service(self, zc: 'Zeroconf', type_: str, name: str) -> None:
- pass
-
- @classmethod
- def find(
- cls,
- zc: Optional['Zeroconf'] = None,
- timeout: Union[int, float] = 5,
- interfaces: InterfacesType = InterfaceChoice.All,
- ip_version: Optional[IPVersion] = None,
- ) -> Tuple[str, ...]:
- """
- Return all of the advertised services on any local networks.
-
- :param zc: Zeroconf() instance. Pass in if already have an
- instance running or if non-default interfaces are needed
- :param timeout: seconds to wait for any responses
- :param interfaces: interfaces to listen on.
- :param ip_version: IP protocol version to use.
- :return: tuple of service type strings
- """
- local_zc = zc or Zeroconf(interfaces=interfaces, ip_version=ip_version)
- listener = cls()
- browser = ServiceBrowser(local_zc, '_services._dns-sd._udp.local.', listener=listener)
-
- # wait for responses
- time.sleep(timeout)
-
- # close down anything we opened
- if zc is None:
- local_zc.close()
- else:
- browser.cancel()
-
- return tuple(sorted(listener.found_services))
-
-
-def get_all_addresses() -> List[str]:
- return list(
- set(
- addr.ip
- for iface in ifaddr.get_adapters()
- for addr in iface.ips
- if addr.is_IPv4 and addr.network_prefix != 32 # Host only netmask 255.255.255.255
- )
- )
-
-
-def get_all_addresses_v6() -> List[int]:
- # IPv6 multicast uses positive indexes for interfaces
- try:
- nameindex = socket.if_nameindex
- except AttributeError:
- # Requires Python 3.8 on Windows. Fall back to Default.
- QuietLogger.log_warning_once(
- 'if_nameindex is not available, falling back to using the default IPv6 interface'
- )
- return [0]
-
- return [tpl[0] for tpl in nameindex()]
-
-
-def ip_to_index(adapters: List[Any], ip: str) -> int:
- if os.name != 'posix':
- # Adapter names that ifaddr reports are not compatible with what if_nametoindex expects on Windows.
- # We need https://github.com/pydron/ifaddr/pull/21 but it seems stuck on review.
- raise RuntimeError('Converting from IP addresses to indexes is not supported on non-POSIX systems')
-
- ipaddr = ipaddress.ip_address(ip)
- for adapter in adapters:
- for adapter_ip in adapter.ips:
- # IPv6 addresses are represented as tuples
- if isinstance(adapter_ip.ip, tuple) and ipaddress.ip_address(adapter_ip.ip[0]) == ipaddr:
- return socket.if_nametoindex(adapter.name)
-
- raise RuntimeError('No adapter found for IP address %s' % ip)
-
-
-def ip6_addresses_to_indexes(interfaces: List[Union[str, int]]) -> List[int]:
- """Convert IPv6 interface addresses to interface indexes.
-
- IPv4 addresses are ignored. The conversion currently only works on POSIX
- systems.
-
- :param interfaces: List of IP addresses and indexes.
- :returns: List of indexes.
- """
- result = []
- adapters = ifaddr.get_adapters()
-
- for iface in interfaces:
- if isinstance(iface, int):
- result.append(iface)
- elif isinstance(iface, str) and ipaddress.ip_address(iface).version == 6:
- result.append(ip_to_index(adapters, iface))
-
- return result
-
-
-def normalize_interface_choice(
- choice: InterfacesType, ip_version: IPVersion = IPVersion.V4Only
-) -> List[Union[str, int]]:
- """Convert the interfaces choice into internal representation.
-
- :param choice: `InterfaceChoice` or list of interface addresses or indexes (IPv6 only).
- :param ip_address: IP version to use (ignored if `choice` is a list).
- :returns: List of IP addresses (for IPv4) and indexes (for IPv6).
- """
- result = [] # type: List[Union[str, int]]
- if choice is InterfaceChoice.Default:
- if ip_version != IPVersion.V4Only:
- # IPv6 multicast uses interface 0 to mean the default
- result.append(0)
- if ip_version != IPVersion.V6Only:
- result.append('0.0.0.0')
- elif choice is InterfaceChoice.All:
- if ip_version != IPVersion.V4Only:
- result.extend(get_all_addresses_v6())
- if ip_version != IPVersion.V6Only:
- result.extend(get_all_addresses())
- if not result:
- raise RuntimeError(
- 'No interfaces to listen on, check that any interfaces have IP version %s' % ip_version
- )
- elif isinstance(choice, list):
- # First, take IPv4 addresses.
- result = [i for i in choice if isinstance(i, str) and ipaddress.ip_address(i).version == 4]
- # Unlike IP_ADD_MEMBERSHIP, IPV6_JOIN_GROUP requires interface indexes.
- result += ip6_addresses_to_indexes(choice)
- else:
- raise TypeError("choice must be a list or InterfaceChoice, got %r" % choice)
- return result
-
-
-def new_socket(
- port: int = _MDNS_PORT, ip_version: IPVersion = IPVersion.V4Only, apple_p2p: bool = False
-) -> socket.socket:
- if ip_version == IPVersion.V4Only:
- s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
- else:
- s = socket.socket(socket.AF_INET6, socket.SOCK_DGRAM)
-
- if ip_version == IPVersion.All:
- # make V6 sockets work for both V4 and V6 (required for Windows)
- try:
- s.setsockopt(_IPPROTO_IPV6, socket.IPV6_V6ONLY, False)
- except OSError:
- log.error('Support for dual V4-V6 sockets is not present, use IPVersion.V4 or IPVersion.V6')
- raise
-
- s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
-
- # SO_REUSEADDR should be equivalent to SO_REUSEPORT for
- # multicast UDP sockets (p 731, "TCP/IP Illustrated,
- # Volume 2"), but some BSD-derived systems require
- # SO_REUSEPORT to be specified explicitly. Also, not all
- # versions of Python have SO_REUSEPORT available.
- # Catch OSError and socket.error for kernel versions <3.9 because lacking
- # SO_REUSEPORT support.
- try:
- reuseport = socket.SO_REUSEPORT
- except AttributeError:
- pass
- else:
- try:
- s.setsockopt(socket.SOL_SOCKET, reuseport, 1)
- except (OSError, socket.error) as err:
- # OSError on python 3, socket.error on python 2
- if not err.errno == errno.ENOPROTOOPT:
- raise
-
- if port is _MDNS_PORT:
- ttl = struct.pack(b'B', 255)
- loop = struct.pack(b'B', 1)
- if ip_version != IPVersion.V6Only:
- # OpenBSD needs the ttl and loop values for the IP_MULTICAST_TTL and
- # IP_MULTICAST_LOOP socket options as an unsigned char.
- s.setsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_TTL, ttl)
- s.setsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_LOOP, loop)
- if ip_version != IPVersion.V4Only:
- # However, char doesn't work here (at least on Linux)
- s.setsockopt(_IPPROTO_IPV6, socket.IPV6_MULTICAST_HOPS, 255)
- s.setsockopt(_IPPROTO_IPV6, socket.IPV6_MULTICAST_LOOP, True)
-
- if apple_p2p:
- # SO_RECV_ANYIF = 0x1104
- # https://opensource.apple.com/source/xnu/xnu-4570.41.2/bsd/sys/socket.h
- s.setsockopt(socket.SOL_SOCKET, 0x1104, 1)
-
- s.bind(('', port))
- return s
-
-
-def add_multicast_member(
- listen_socket: socket.socket, interface: Union[str, int], apple_p2p: bool = False
-) -> Optional[socket.socket]:
- # This is based on assumptions in normalize_interface_choice
- is_v6 = isinstance(interface, int)
- log.debug('Adding %r to multicast group', interface)
- try:
- if is_v6:
- iface_bin = struct.pack('@I', cast(int, interface))
- _value = _MDNS_ADDR6_BYTES + iface_bin
- listen_socket.setsockopt(_IPPROTO_IPV6, socket.IPV6_JOIN_GROUP, _value)
- else:
- _value = _MDNS_ADDR_BYTES + socket.inet_aton(cast(str, interface))
- listen_socket.setsockopt(socket.IPPROTO_IP, socket.IP_ADD_MEMBERSHIP, _value)
- except socket.error as e:
- _errno = get_errno(e)
- if _errno == errno.EADDRINUSE:
- log.info(
- 'Address in use when adding %s to multicast group, '
- 'it is expected to happen on some systems',
- interface,
- )
- return None
- elif _errno == errno.EADDRNOTAVAIL:
- log.info(
- 'Address not available when adding %s to multicast '
- 'group, it is expected to happen on some systems',
- interface,
- )
- return None
- elif _errno == errno.EINVAL:
- log.info('Interface of %s does not support multicast, ' 'it is expected in WSL', interface)
- return None
- else:
- raise
-
- respond_socket = new_socket(
- ip_version=(IPVersion.V6Only if is_v6 else IPVersion.V4Only), apple_p2p=apple_p2p
- )
- log.debug('Configuring %s with multicast interface %s', respond_socket, interface)
- if is_v6:
- respond_socket.setsockopt(_IPPROTO_IPV6, socket.IPV6_MULTICAST_IF, iface_bin)
- else:
- respond_socket.setsockopt(
- socket.IPPROTO_IP, socket.IP_MULTICAST_IF, socket.inet_aton(cast(str, interface))
- )
- return respond_socket
-
-
-def create_sockets(
- interfaces: InterfacesType = InterfaceChoice.All,
- unicast: bool = False,
- ip_version: IPVersion = IPVersion.V4Only,
- apple_p2p: bool = False,
-) -> Tuple[Optional[socket.socket], List[socket.socket]]:
- if unicast:
- listen_socket = None
- else:
- listen_socket = new_socket(ip_version=ip_version, apple_p2p=apple_p2p)
-
- interfaces = normalize_interface_choice(interfaces, ip_version)
-
- respond_sockets = []
-
- for i in interfaces:
- if not unicast:
- respond_socket = add_multicast_member(cast(socket.socket, listen_socket), i, apple_p2p=apple_p2p)
- else:
- respond_socket = new_socket(port=0, ip_version=ip_version, apple_p2p=apple_p2p)
-
- if respond_socket is not None:
- respond_sockets.append(respond_socket)
-
- return listen_socket, respond_sockets
-
-
-def get_errno(e: Exception) -> int:
- assert isinstance(e, socket.error)
- return cast(int, e.args[0])
-
-
-def can_send_to(sock: socket.socket, address: str) -> bool:
- addr = ipaddress.ip_address(address)
- return cast(bool, addr.version == 6 if sock.family == socket.AF_INET6 else addr.version == 4)
-
-
-class Zeroconf(QuietLogger):
-
- """Implementation of Zeroconf Multicast DNS Service Discovery
-
- Supports registration, unregistration, queries and browsing.
- """
-
- def __init__(
- self,
- interfaces: InterfacesType = InterfaceChoice.All,
- unicast: bool = False,
- ip_version: Optional[IPVersion] = None,
- apple_p2p: bool = False,
- ) -> None:
- """Creates an instance of the Zeroconf class, establishing
- multicast communications, listening and reaping threads.
-
- :param interfaces: :class:`InterfaceChoice` or a list of IP addresses
- (IPv4 and IPv6) and interface indexes (IPv6 only).
-
- IPv6 notes for non-POSIX systems:
- * IPv6 addresses are not supported, use indexes instead.
- * `InterfaceChoice.All` is an alias for `InterfaceChoice.Default`
- on Python versions before 3.8.
-
- Also listening on loopback (``::1``) doesn't work, use a real address.
- :param ip_version: IP versions to support. If `choice` is a list, the default is detected
- from it. Otherwise defaults to V4 only for backward compatibility.
- :param apple_p2p: use AWDL interface (only macOS)
- """
- if ip_version is None and isinstance(interfaces, list):
- has_v6 = any(
- isinstance(i, int) or (isinstance(i, str) and ipaddress.ip_address(i).version == 6)
- for i in interfaces
- )
- has_v4 = any(isinstance(i, str) and ipaddress.ip_address(i).version == 4 for i in interfaces)
- if has_v4 and has_v6:
- ip_version = IPVersion.All
- elif has_v6:
- ip_version = IPVersion.V6Only
-
- if ip_version is None:
- ip_version = IPVersion.V4Only
-
- # hook for threads
- self._GLOBAL_DONE = False
- self.unicast = unicast
-
- if apple_p2p and not platform.system() == 'Darwin':
- raise RuntimeError('Option `apple_p2p` is not supported on non-Apple platforms.')
-
- self._listen_socket, self._respond_sockets = create_sockets(
- interfaces, unicast, ip_version, apple_p2p=apple_p2p
- )
-
- self.listeners = [] # type: List[RecordUpdateListener]
- self.browsers = {} # type: Dict[ServiceListener, ServiceBrowser]
- self.services = {} # type: Dict[str, ServiceInfo]
- self.servicetypes = {} # type: Dict[str, int]
-
- self.cache = DNSCache()
-
- self.condition = threading.Condition()
-
- self.engine = Engine(self)
- self.listener = Listener(self)
- if not unicast:
- self.engine.add_reader(self.listener, cast(socket.socket, self._listen_socket))
- else:
- for s in self._respond_sockets:
- self.engine.add_reader(self.listener, s)
- self.reaper = Reaper(self)
-
- self.debug = None # type: Optional[DNSOutgoing]
-
- @property
- def done(self) -> bool:
- return self._GLOBAL_DONE
-
- def wait(self, timeout: float) -> None:
- """Calling thread waits for a given number of milliseconds or
- until notified."""
- with self.condition:
- self.condition.wait(timeout / 1000.0)
-
- def notify_all(self) -> None:
- """Notifies all waiting threads"""
- with self.condition:
- self.condition.notify_all()
-
- def get_service_info(self, type_: str, name: str, timeout: int = 3000) -> Optional[ServiceInfo]:
- """Returns network's service information for a particular
- name and type, or None if no service matches by the timeout,
- which defaults to 3 seconds."""
- info = ServiceInfo(type_, name)
- if info.request(self, timeout):
- return info
- return None
-
- def add_service_listener(self, type_: str, listener: ServiceListener) -> None:
- """Adds a listener for a particular service type. This object
- will then have its add_service and remove_service methods called when
- services of that type become available and unavailable."""
- self.remove_service_listener(listener)
- self.browsers[listener] = ServiceBrowser(self, type_, listener)
-
- def remove_service_listener(self, listener: ServiceListener) -> None:
- """Removes a listener from the set that is currently listening."""
- if listener in self.browsers:
- self.browsers[listener].cancel()
- del self.browsers[listener]
-
- def remove_all_service_listeners(self) -> None:
- """Removes a listener from the set that is currently listening."""
- for listener in [k for k in self.browsers]:
- self.remove_service_listener(listener)
-
- def register_service(
- self, info: ServiceInfo, ttl: Optional[int] = None, allow_name_change: bool = False
- ) -> None:
- """Registers service information to the network with a default TTL.
- Zeroconf will then respond to requests for information for that
- service. The name of the service may be changed if needed to make
- it unique on the network."""
- if ttl is not None:
- # ttl argument is used to maintain backward compatibility
- # Setting TTLs via ServiceInfo is preferred
- info.host_ttl = ttl
- info.other_ttl = ttl
- self.check_service(info, allow_name_change)
- self.services[info.name.lower()] = info
- if info.type in self.servicetypes:
- self.servicetypes[info.type] += 1
- else:
- self.servicetypes[info.type] = 1
-
- self._broadcast_service(info)
-
- def update_service(self, info: ServiceInfo) -> None:
- """Registers service information to the network with a default TTL.
- Zeroconf will then respond to requests for information for that
- service."""
-
- assert self.services[info.name.lower()] is not None
-
- self.services[info.name.lower()] = info
-
- self._broadcast_service(info)
-
- def _broadcast_service(self, info: ServiceInfo) -> None:
-
- now = current_time_millis()
- next_time = now
- i = 0
- while i < 3:
- if now < next_time:
- self.wait(next_time - now)
- now = current_time_millis()
- continue
- out = DNSOutgoing(_FLAGS_QR_RESPONSE | _FLAGS_AA)
- out.add_answer_at_time(DNSPointer(info.type, _TYPE_PTR, _CLASS_IN, info.other_ttl, info.name), 0)
- out.add_answer_at_time(
- DNSService(
- info.name,
- _TYPE_SRV,
- _CLASS_IN,
- info.host_ttl,
- info.priority,
- info.weight,
- cast(int, info.port),
- info.server,
- ),
- 0,
- )
-
- out.add_answer_at_time(DNSText(info.name, _TYPE_TXT, _CLASS_IN, info.other_ttl, info.text), 0)
- for address in info.addresses_by_version(IPVersion.All):
- type_ = _TYPE_AAAA if _is_v6_address(address) else _TYPE_A
- out.add_answer_at_time(DNSAddress(info.server, type_, _CLASS_IN, info.host_ttl, address), 0)
- self.send(out)
- i += 1
- next_time += _REGISTER_TIME
-
- def unregister_service(self, info: ServiceInfo) -> None:
- """Unregister a service."""
- try:
- del self.services[info.name.lower()]
- if self.servicetypes[info.type] > 1:
- self.servicetypes[info.type] -= 1
- else:
- del self.servicetypes[info.type]
- except Exception as e: # TODO stop catching all Exceptions
- log.exception('Unknown error, possibly benign: %r', e)
- now = current_time_millis()
- next_time = now
- i = 0
- while i < 3:
- if now < next_time:
- self.wait(next_time - now)
- now = current_time_millis()
- continue
- out = DNSOutgoing(_FLAGS_QR_RESPONSE | _FLAGS_AA)
- out.add_answer_at_time(DNSPointer(info.type, _TYPE_PTR, _CLASS_IN, 0, info.name), 0)
- out.add_answer_at_time(
- DNSService(
- info.name,
- _TYPE_SRV,
- _CLASS_IN,
- 0,
- info.priority,
- info.weight,
- cast(int, info.port),
- info.name,
- ),
- 0,
- )
- out.add_answer_at_time(DNSText(info.name, _TYPE_TXT, _CLASS_IN, 0, info.text), 0)
-
- for address in info.addresses_by_version(IPVersion.All):
- type_ = _TYPE_AAAA if _is_v6_address(address) else _TYPE_A
- out.add_answer_at_time(DNSAddress(info.server, type_, _CLASS_IN, 0, address), 0)
- self.send(out)
- i += 1
- next_time += _UNREGISTER_TIME
-
- def unregister_all_services(self) -> None:
- """Unregister all registered services."""
- if len(self.services) > 0:
- now = current_time_millis()
- next_time = now
- i = 0
- while i < 3:
- if now < next_time:
- self.wait(next_time - now)
- now = current_time_millis()
- continue
- out = DNSOutgoing(_FLAGS_QR_RESPONSE | _FLAGS_AA)
- for info in self.services.values():
- out.add_answer_at_time(DNSPointer(info.type, _TYPE_PTR, _CLASS_IN, 0, info.name), 0)
- out.add_answer_at_time(
- DNSService(
- info.name,
- _TYPE_SRV,
- _CLASS_IN,
- 0,
- info.priority,
- info.weight,
- cast(int, info.port),
- info.server,
- ),
- 0,
- )
- out.add_answer_at_time(DNSText(info.name, _TYPE_TXT, _CLASS_IN, 0, info.text), 0)
- for address in info.addresses_by_version(IPVersion.All):
- type_ = _TYPE_AAAA if _is_v6_address(address) else _TYPE_A
- out.add_answer_at_time(DNSAddress(info.server, type_, _CLASS_IN, 0, address), 0)
- self.send(out)
- i += 1
- next_time += _UNREGISTER_TIME
-
- def check_service(self, info: ServiceInfo, allow_name_change: bool) -> None:
- """Checks the network for a unique service name, modifying the
- ServiceInfo passed in if it is not unique."""
-
- # This is kind of funky because of the subtype based tests
- # need to make subtypes a first class citizen
- service_name = service_type_name(info.name)
- if not info.type.endswith(service_name):
- raise BadTypeInNameException
-
- instance_name = info.name[: -len(service_name) - 1]
- next_instance_number = 2
-
- now = current_time_millis()
- next_time = now
- i = 0
- while i < 3:
- # check for a name conflict
- while self.cache.current_entry_with_name_and_alias(info.type, info.name):
- if not allow_name_change:
- raise NonUniqueNameException
-
- # change the name and look for a conflict
- info.name = '%s-%s.%s' % (instance_name, next_instance_number, info.type)
- next_instance_number += 1
- service_type_name(info.name)
- next_time = now
- i = 0
-
- if now < next_time:
- self.wait(next_time - now)
- now = current_time_millis()
- continue
-
- out = DNSOutgoing(_FLAGS_QR_QUERY | _FLAGS_AA)
- self.debug = out
- out.add_question(DNSQuestion(info.type, _TYPE_PTR, _CLASS_IN))
- out.add_authorative_answer(DNSPointer(info.type, _TYPE_PTR, _CLASS_IN, info.other_ttl, info.name))
- self.send(out)
- i += 1
- next_time += _CHECK_TIME
-
- def add_listener(self, listener: RecordUpdateListener, question: Optional[DNSQuestion]) -> None:
- """Adds a listener for a given question. The listener will have
- its update_record method called when information is available to
- answer the question."""
- now = current_time_millis()
- self.listeners.append(listener)
- if question is not None:
- for record in self.cache.entries_with_name(question.name):
- if question.answered_by(record) and not record.is_expired(now):
- listener.update_record(self, now, record)
- self.notify_all()
-
- def remove_listener(self, listener: RecordUpdateListener) -> None:
- """Removes a listener."""
- try:
- self.listeners.remove(listener)
- self.notify_all()
- except Exception as e: # TODO stop catching all Exceptions
- log.exception('Unknown error, possibly benign: %r', e)
-
- def update_record(self, now: float, rec: DNSRecord) -> None:
- """Used to notify listeners of new information that has updated
- a record."""
- for listener in self.listeners:
- listener.update_record(self, now, rec)
- self.notify_all()
-
- def handle_response(self, msg: DNSIncoming) -> None:
- """Deal with incoming response packets. All answers
- are held in the cache, and listeners are notified."""
- now = current_time_millis()
- for record in msg.answers:
- if record.unique: # https://tools.ietf.org/html/rfc6762#section-10.2
- for entry in self.cache.entries():
- if DNSEntry.__eq__(entry, record) and (record.created - entry.created > 1000):
- self.cache.remove(entry)
-
- expired = record.is_expired(now)
- maybe_entry = self.cache.get(record)
- if not expired:
- if maybe_entry is not None:
- maybe_entry.reset_ttl(record)
- else:
- self.cache.add(record)
- self.update_record(now, record)
- else:
- if maybe_entry is not None:
- self.update_record(now, record)
- self.cache.remove(maybe_entry)
-
- def handle_query(self, msg: DNSIncoming, addr: Optional[str], port: int) -> None:
- """Deal with incoming query packets. Provides a response if
- possible."""
- out = None
-
- # Support unicast client responses
- #
- if port != _MDNS_PORT:
- out = DNSOutgoing(_FLAGS_QR_RESPONSE | _FLAGS_AA, multicast=False)
- for question in msg.questions:
- out.add_question(question)
-
- for question in msg.questions:
- if question.type == _TYPE_PTR:
- if question.name == "_services._dns-sd._udp.local.":
- for stype in self.servicetypes.keys():
- if out is None:
- out = DNSOutgoing(_FLAGS_QR_RESPONSE | _FLAGS_AA)
- out.add_answer(
- msg,
- DNSPointer(
- "_services._dns-sd._udp.local.", _TYPE_PTR, _CLASS_IN, _DNS_OTHER_TTL, stype
- ),
- )
- for service in self.services.values():
- if question.name == service.type:
- if out is None:
- out = DNSOutgoing(_FLAGS_QR_RESPONSE | _FLAGS_AA)
- out.add_answer(
- msg,
- DNSPointer(service.type, _TYPE_PTR, _CLASS_IN, service.other_ttl, service.name),
- )
-
- # Add recommended additional answers according to
- # https://tools.ietf.org/html/rfc6763#section-12.1.
- out.add_additional_answer(
- DNSService(
- service.name,
- _TYPE_SRV,
- _CLASS_IN | _CLASS_UNIQUE,
- service.host_ttl,
- service.priority,
- service.weight,
- cast(int, service.port),
- service.server,
- )
- )
- out.add_additional_answer(
- DNSText(
- service.name,
- _TYPE_TXT,
- _CLASS_IN | _CLASS_UNIQUE,
- service.other_ttl,
- service.text,
- )
- )
- for address in service.addresses_by_version(IPVersion.All):
- type_ = _TYPE_AAAA if _is_v6_address(address) else _TYPE_A
- out.add_additional_answer(
- DNSAddress(
- service.server,
- type_,
- _CLASS_IN | _CLASS_UNIQUE,
- service.host_ttl,
- address,
- )
- )
- else:
- try:
- if out is None:
- out = DNSOutgoing(_FLAGS_QR_RESPONSE | _FLAGS_AA)
-
- # Answer A record queries for any service addresses we know
- if question.type in (_TYPE_A, _TYPE_ANY):
- for service in self.services.values():
- if service.server == question.name.lower():
- for address in service.addresses_by_version(IPVersion.All):
- type_ = _TYPE_AAAA if _is_v6_address(address) else _TYPE_A
- out.add_answer(
- msg,
- DNSAddress(
- question.name,
- type_,
- _CLASS_IN | _CLASS_UNIQUE,
- service.host_ttl,
- address,
- ),
- )
-
- name_to_find = question.name.lower()
- if name_to_find not in self.services:
- continue
- service = self.services[name_to_find]
-
- if question.type in (_TYPE_SRV, _TYPE_ANY):
- out.add_answer(
- msg,
- DNSService(
- question.name,
- _TYPE_SRV,
- _CLASS_IN | _CLASS_UNIQUE,
- service.host_ttl,
- service.priority,
- service.weight,
- cast(int, service.port),
- service.server,
- ),
- )
- if question.type in (_TYPE_TXT, _TYPE_ANY):
- out.add_answer(
- msg,
- DNSText(
- question.name,
- _TYPE_TXT,
- _CLASS_IN | _CLASS_UNIQUE,
- service.other_ttl,
- service.text,
- ),
- )
- if question.type == _TYPE_SRV:
- for address in service.addresses_by_version(IPVersion.All):
- type_ = _TYPE_AAAA if _is_v6_address(address) else _TYPE_A
- out.add_additional_answer(
- DNSAddress(
- service.server,
- type_,
- _CLASS_IN | _CLASS_UNIQUE,
- service.host_ttl,
- address,
- )
- )
- except Exception: # TODO stop catching all Exceptions
- self.log_exception_warning()
-
- if out is not None and out.answers:
- out.id = msg.id
- self.send(out, addr, port)
-
- def send(self, out: DNSOutgoing, addr: Optional[str] = None, port: int = _MDNS_PORT) -> None:
- """Sends an outgoing packet."""
- packet = out.packet()
- if len(packet) > _MAX_MSG_ABSOLUTE:
- self.log_warning_once("Dropping %r over-sized packet (%d bytes) %r", out, len(packet), packet)
- return
- log.debug('Sending %r (%d bytes) as %r...', out, len(packet), packet)
- for s in self._respond_sockets:
- if self._GLOBAL_DONE:
- return
- try:
- if addr is None:
- real_addr = _MDNS_ADDR6 if s.family == socket.AF_INET6 else _MDNS_ADDR
- elif not can_send_to(s, addr):
- continue
- else:
- real_addr = addr
- bytes_sent = s.sendto(packet, 0, (real_addr, port))
- except Exception as exc: # TODO stop catching all Exceptions
- if (
- isinstance(exc, OSError)
- and exc.errno == errno.ENETUNREACH
- and s.family == socket.AF_INET6
- ):
- # with IPv6 we don't have a reliable way to determine if an interface actually has IPv6
- # support, so we have to try and ignore errors.
- continue
- # on send errors, log the exception and keep going
- self.log_exception_warning()
- else:
- if bytes_sent != len(packet):
- self.log_warning_once('!!! sent %d out of %d bytes to %r' % (bytes_sent, len(packet), s))
-
- def close(self) -> None:
- """Ends the background threads, and prevent this instance from
- servicing further queries."""
- if not self._GLOBAL_DONE:
- # remove service listeners
- self.remove_all_service_listeners()
- self.unregister_all_services()
- self._GLOBAL_DONE = True
-
- # shutdown recv socket and thread
- if not self.unicast:
- self.engine.del_reader(cast(socket.socket, self._listen_socket))
- cast(socket.socket, self._listen_socket).close()
- else:
- for s in self._respond_sockets:
- self.engine.del_reader(s)
- self.engine.join()
-
- # shutdown the rest
- self.notify_all()
- self.reaper.join()
- for s in self._respond_sockets:
- s.close()
diff --git a/zeroconf/_cache.py b/zeroconf/_cache.py
new file mode 100644
index 00000000..24b6a233
--- /dev/null
+++ b/zeroconf/_cache.py
@@ -0,0 +1,209 @@
+""" Multicast DNS Service Discovery for Python, v0.14-wmcbrine
+ Copyright 2003 Paul Scott-Murphy, 2014 William McBrine
+
+ This module provides a framework for the use of DNS Service Discovery
+ using IP multicast.
+
+ This library is free software; you can redistribute it and/or
+ modify it under the terms of the GNU Lesser General Public
+ License as published by the Free Software Foundation; either
+ version 2.1 of the License, or (at your option) any later version.
+
+ This library is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+ Lesser General Public License for more details.
+
+ You should have received a copy of the GNU Lesser General Public
+ License along with this library; if not, write to the Free Software
+ Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301
+ USA
+"""
+
+import itertools
+from typing import Dict, Iterable, Iterator, List, Optional, Union, cast
+
+from ._dns import (
+ DNSAddress,
+ DNSEntry,
+ DNSHinfo,
+ DNSPointer,
+ DNSRecord,
+ DNSService,
+ DNSText,
+ dns_entry_matches,
+)
+from ._utils.time import current_time_millis
+from .const import _TYPE_PTR
+
+_UNIQUE_RECORD_TYPES = (DNSAddress, DNSHinfo, DNSPointer, DNSText, DNSService)
+_UniqueRecordsType = Union[DNSAddress, DNSHinfo, DNSPointer, DNSText, DNSService]
+_DNSRecordCacheType = Dict[str, Dict[DNSRecord, DNSRecord]]
+
+
+def _remove_key(cache: _DNSRecordCacheType, key: str, entry: DNSRecord) -> None:
+ """Remove a key from a DNSRecord cache
+
+ This function must be run in from event loop.
+ """
+ del cache[key][entry]
+ if not cache[key]:
+ del cache[key]
+
+
+class DNSCache:
+ """A cache of DNS entries."""
+
+ def __init__(self) -> None:
+ self.cache: _DNSRecordCacheType = {}
+ self.service_cache: _DNSRecordCacheType = {}
+
+ # Functions prefixed with async_ are NOT threadsafe and must
+ # be run in the event loop.
+
+ def _async_add(self, entry: DNSRecord) -> None:
+ """Adds an entry.
+
+ This function must be run in from event loop.
+ """
+ # Previously storage of records was implemented as a list
+ # instead a dict. Since DNSRecords are now hashable, the implementation
+ # uses a dict to ensure that adding a new record to the cache
+ # replaces any existing records that are __eq__ to each other which
+ # removes the risk that accessing the cache from the wrong
+ # direction would return the old incorrect entry.
+ self.cache.setdefault(entry.key, {})[entry] = entry
+ if isinstance(entry, DNSService):
+ self.service_cache.setdefault(entry.server, {})[entry] = entry
+
+ def async_add_records(self, entries: Iterable[DNSRecord]) -> None:
+ """Add multiple records.
+
+ This function must be run in from event loop.
+ """
+ for entry in entries:
+ self._async_add(entry)
+
+ def _async_remove(self, entry: DNSRecord) -> None:
+ """Removes an entry.
+
+ This function must be run in from event loop.
+ """
+ if isinstance(entry, DNSService):
+ _remove_key(self.service_cache, entry.server, entry)
+ _remove_key(self.cache, entry.key, entry)
+
+ def async_remove_records(self, entries: Iterable[DNSRecord]) -> None:
+ """Remove multiple records.
+
+ This function must be run in from event loop.
+ """
+ for entry in entries:
+ self._async_remove(entry)
+
+ def async_expire(self, now: float) -> List[DNSRecord]:
+ """Purge expired entries from the cache.
+
+ This function must be run in from event loop.
+ """
+ expired = [record for record in itertools.chain(*self.cache.values()) if record.is_expired(now)]
+ self.async_remove_records(expired)
+ return expired
+
+ def async_get_unique(self, entry: _UniqueRecordsType) -> Optional[DNSRecord]:
+ """Gets a unique entry by key. Will return None if there is no
+ matching entry.
+
+ This function is not threadsafe and must be called from
+ the event loop.
+ """
+ return self.cache.get(entry.key, {}).get(entry)
+
+ def async_all_by_details(self, name: str, type_: int, class_: int) -> Iterator[DNSRecord]:
+ """Gets all matching entries by details.
+
+ This function is not threadsafe and must be called from
+ the event loop.
+ """
+ key = name.lower()
+ for entry in self.cache.get(key, []):
+ if dns_entry_matches(entry, key, type_, class_):
+ yield entry
+
+ def async_entries_with_name(self, name: str) -> Dict[DNSRecord, DNSRecord]:
+ """Returns a dict of entries whose key matches the name.
+
+ This function is not threadsafe and must be called from
+ the event loop.
+ """
+ return self.cache.get(name.lower(), {})
+
+ def async_entries_with_server(self, name: str) -> Dict[DNSRecord, DNSRecord]:
+ """Returns a dict of entries whose key matches the server.
+
+ This function is not threadsafe and must be called from
+ the event loop.
+ """
+ return self.service_cache.get(name.lower(), {})
+
+ # The below functions are threadsafe and do not need to be run in the
+ # event loop, however they all make copies so they significantly
+ # inefficent
+
+ def get(self, entry: DNSEntry) -> Optional[DNSRecord]:
+ """Gets an entry by key. Will return None if there is no
+ matching entry."""
+ if isinstance(entry, _UNIQUE_RECORD_TYPES):
+ return self.cache.get(entry.key, {}).get(entry)
+ for cached_entry in reversed(list(self.cache.get(entry.key, []))):
+ if entry.__eq__(cached_entry):
+ return cached_entry
+ return None
+
+ def get_by_details(self, name: str, type_: int, class_: int) -> Optional[DNSRecord]:
+ """Gets the first matching entry by details. Returns None if no entries match.
+
+ Calling this function is not recommended as it will only
+ return one record even if there are multiple entries.
+
+ For example if there are multiple A or AAAA addresses this
+ function will return the last one that was added to the cache
+ which may not be the one you expect.
+
+ Use get_all_by_details instead.
+ """
+ key = name.lower()
+ for cached_entry in reversed(list(self.cache.get(key, []))):
+ if dns_entry_matches(cached_entry, key, type_, class_):
+ return cached_entry
+ return None
+
+ def get_all_by_details(self, name: str, type_: int, class_: int) -> List[DNSRecord]:
+ """Gets all matching entries by details."""
+ key = name.lower()
+ return [
+ entry for entry in list(self.cache.get(key, [])) if dns_entry_matches(entry, key, type_, class_)
+ ]
+
+ def entries_with_server(self, server: str) -> List[DNSRecord]:
+ """Returns a list of entries whose server matches the name."""
+ return list(self.service_cache.get(server.lower(), []))
+
+ def entries_with_name(self, name: str) -> List[DNSRecord]:
+ """Returns a list of entries whose key matches the name."""
+ return list(self.cache.get(name.lower(), []))
+
+ def current_entry_with_name_and_alias(self, name: str, alias: str) -> Optional[DNSRecord]:
+ now = current_time_millis()
+ for record in reversed(self.entries_with_name(name)):
+ if (
+ record.type == _TYPE_PTR
+ and not record.is_expired(now)
+ and cast(DNSPointer, record).alias == alias
+ ):
+ return record
+ return None
+
+ def names(self) -> List[str]:
+ """Return a copy of the list of current cache names."""
+ return list(self.cache)
diff --git a/zeroconf/_core.py b/zeroconf/_core.py
new file mode 100644
index 00000000..1575eba2
--- /dev/null
+++ b/zeroconf/_core.py
@@ -0,0 +1,928 @@
+""" Multicast DNS Service Discovery for Python, v0.14-wmcbrine
+ Copyright 2003 Paul Scott-Murphy, 2014 William McBrine
+
+ This module provides a framework for the use of DNS Service Discovery
+ using IP multicast.
+
+ This library is free software; you can redistribute it and/or
+ modify it under the terms of the GNU Lesser General Public
+ License as published by the Free Software Foundation; either
+ version 2.1 of the License, or (at your option) any later version.
+
+ This library is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+ Lesser General Public License for more details.
+
+ You should have received a copy of the GNU Lesser General Public
+ License along with this library; if not, write to the Free Software
+ Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301
+ USA
+"""
+
+import asyncio
+import itertools
+import logging
+import random
+import socket
+import sys
+import threading
+from types import TracebackType # noqa # used in type hints
+from typing import Awaitable, Dict, List, Optional, Tuple, Type, Union, cast
+
+from ._cache import DNSCache
+from ._dns import DNSQuestion, DNSQuestionType
+from ._exceptions import NonUniqueNameException
+from ._handlers import (
+ MulticastOutgoingQueue,
+ QueryHandler,
+ RecordManager,
+ construct_outgoing_multicast_answers,
+ construct_outgoing_unicast_answers,
+)
+from ._history import QuestionHistory
+from ._logger import QuietLogger, log
+from ._protocol.incoming import DNSIncoming
+from ._protocol.outgoing import DNSOutgoing
+from ._services import ServiceListener
+from ._services.browser import ServiceBrowser
+from ._services.info import ServiceInfo, instance_name_from_service_info
+from ._services.registry import ServiceRegistry
+from ._updates import RecordUpdate, RecordUpdateListener
+from ._utils.asyncio import (
+ await_awaitable,
+ get_running_loop,
+ run_coro_with_timeout,
+ shutdown_loop,
+ wait_event_or_timeout,
+)
+from ._utils.name import service_type_name
+from ._utils.net import (
+ IPVersion,
+ InterfaceChoice,
+ InterfacesType,
+ autodetect_ip_version,
+ can_send_to,
+ create_sockets,
+)
+from ._utils.time import current_time_millis, millis_to_seconds
+from .const import (
+ _CACHE_CLEANUP_INTERVAL,
+ _CHECK_TIME,
+ _CLASS_IN,
+ _CLASS_UNIQUE,
+ _FLAGS_AA,
+ _FLAGS_QR_QUERY,
+ _FLAGS_QR_RESPONSE,
+ _MAX_MSG_ABSOLUTE,
+ _MDNS_ADDR,
+ _MDNS_ADDR6,
+ _MDNS_PORT,
+ _ONE_SECOND,
+ _REGISTER_TIME,
+ _TYPE_PTR,
+ _UNREGISTER_TIME,
+)
+
+_TC_DELAY_RANDOM_INTERVAL = (400, 500)
+# The maximum amont of time to delay a multicast
+# response in order to aggregate answers
+_AGGREGATION_DELAY = 500 # ms
+# The maximum amont of time to delay a multicast
+# response in order to aggregate answers after
+# it has already been delayed to protect the network
+# from excessive traffic. We use a shorter time
+# window here as we want to _try_ to answer all
+# queries in under 1350ms while protecting
+# the network from excessive traffic to ensure
+# a service info request with two questions
+# can be answered in the default timeout of
+# 3000ms
+_PROTECTED_AGGREGATION_DELAY = 200 # ms
+
+_CLOSE_TIMEOUT = 3000 # ms
+_REGISTER_BROADCASTS = 3
+
+
+class AsyncEngine:
+ """An engine wraps sockets in the event loop."""
+
+ def __init__(
+ self,
+ zeroconf: 'Zeroconf',
+ listen_socket: Optional[socket.socket],
+ respond_sockets: List[socket.socket],
+ ) -> None:
+ self.loop: Optional[asyncio.AbstractEventLoop] = None
+ self.zc = zeroconf
+ self.protocols: List[AsyncListener] = []
+ self.readers: List[asyncio.DatagramTransport] = []
+ self.senders: List[asyncio.DatagramTransport] = []
+ self._listen_socket = listen_socket
+ self._respond_sockets = respond_sockets
+ self._cleanup_timer: Optional[asyncio.TimerHandle] = None
+ self._running_event: Optional[asyncio.Event] = None
+
+ def setup(self, loop: asyncio.AbstractEventLoop, loop_thread_ready: Optional[threading.Event]) -> None:
+ """Set up the instance."""
+ self.loop = loop
+ self._running_event = asyncio.Event()
+ self.loop.create_task(self._async_setup(loop_thread_ready))
+
+ async def _async_setup(self, loop_thread_ready: Optional[threading.Event]) -> None:
+ """Set up the instance."""
+ assert self.loop is not None
+ self._cleanup_timer = self.loop.call_later(
+ millis_to_seconds(_CACHE_CLEANUP_INTERVAL), self._async_cache_cleanup
+ )
+ await self._async_create_endpoints()
+ assert self._running_event is not None
+ self._running_event.set()
+ if loop_thread_ready:
+ loop_thread_ready.set()
+
+ async def async_wait_for_start(self) -> None:
+ """Wait for start up."""
+ assert self._running_event is not None
+ await self._running_event.wait()
+
+ async def _async_create_endpoints(self) -> None:
+ """Create endpoints to send and receive."""
+ assert self.loop is not None
+ loop = self.loop
+ reader_sockets = []
+ sender_sockets = []
+ if self._listen_socket:
+ reader_sockets.append(self._listen_socket)
+ for s in self._respond_sockets:
+ if s not in reader_sockets:
+ reader_sockets.append(s)
+ sender_sockets.append(s)
+
+ for s in reader_sockets:
+ transport, protocol = await loop.create_datagram_endpoint(lambda: AsyncListener(self.zc), sock=s)
+ self.protocols.append(cast(AsyncListener, protocol))
+ self.readers.append(cast(asyncio.DatagramTransport, transport))
+ if s in sender_sockets:
+ self.senders.append(cast(asyncio.DatagramTransport, transport))
+
+ def _async_cache_cleanup(self) -> None:
+ """Periodic cache cleanup."""
+ now = current_time_millis()
+ self.zc.question_history.async_expire(now)
+ self.zc.record_manager.async_updates(
+ now, [RecordUpdate(record, None) for record in self.zc.cache.async_expire(now)]
+ )
+ self.zc.record_manager.async_updates_complete()
+ assert self.loop is not None
+ self._cleanup_timer = self.loop.call_later(
+ millis_to_seconds(_CACHE_CLEANUP_INTERVAL), self._async_cache_cleanup
+ )
+
+ async def _async_close(self) -> None:
+ """Cancel and wait for the cleanup task to finish."""
+ self._async_shutdown()
+ await asyncio.sleep(0) # flush out any call soons
+ assert self._cleanup_timer is not None
+ self._cleanup_timer.cancel()
+
+ def _async_shutdown(self) -> None:
+ """Shutdown transports and sockets."""
+ for transport in itertools.chain(self.senders, self.readers):
+ transport.close()
+
+ def close(self) -> None:
+ """Close from sync context."""
+ assert self.loop is not None
+ # Guard against Zeroconf.close() being called from the eventloop
+ if get_running_loop() == self.loop:
+ self._async_shutdown()
+ return
+ if not self.loop.is_running():
+ return
+ run_coro_with_timeout(self._async_close(), self.loop, _CLOSE_TIMEOUT)
+
+
+class AsyncListener(asyncio.Protocol, QuietLogger):
+
+ """A Listener is used by this module to listen on the multicast
+ group to which DNS messages are sent, allowing the implementation
+ to cache information as it arrives.
+
+ It requires registration with an Engine object in order to have
+ the read() method called when a socket is available for reading."""
+
+ __slots__ = ('zc', 'data', 'last_time', 'transport', 'sock_description', '_deferred', '_timers')
+
+ def __init__(self, zc: 'Zeroconf') -> None:
+ self.zc = zc
+ self.data: Optional[bytes] = None
+ self.last_time: float = 0
+ self.transport: Optional[asyncio.DatagramTransport] = None
+ self.sock_description: Optional[str] = None
+ self._deferred: Dict[str, List[DNSIncoming]] = {}
+ self._timers: Dict[str, asyncio.TimerHandle] = {}
+ super().__init__()
+
+ def suppress_duplicate_packet(self, data: bytes, now: float) -> bool:
+ """Suppress duplicate packet if the last one was the same in the last second."""
+ if self.data == data and (now - 1000) < self.last_time:
+ return True
+ self.data = data
+ self.last_time = now
+ return False
+
+ def datagram_received(
+ self, data: bytes, addrs: Union[Tuple[str, int], Tuple[str, int, int, int]]
+ ) -> None:
+ assert self.transport is not None
+ v6_flow_scope: Union[Tuple[()], Tuple[int, int]] = ()
+ data_len = len(data)
+
+ if len(addrs) == 2:
+ # https://github.com/python/mypy/issues/1178
+ addr, port = addrs # type: ignore
+ scope = None
+ else:
+ # https://github.com/python/mypy/issues/1178
+ addr, port, flow, scope = addrs # type: ignore
+ log.debug('IPv6 scope_id %d associated to the receiving interface', scope)
+ v6_flow_scope = (flow, scope)
+
+ now = current_time_millis()
+ if self.suppress_duplicate_packet(data, now):
+ # Guard against duplicate packets
+ log.debug(
+ 'Ignoring duplicate message received from %r:%r [socket %s] (%d bytes) as [%r]',
+ addr,
+ port,
+ self.sock_description,
+ data_len,
+ data,
+ )
+ return
+
+ if data_len > _MAX_MSG_ABSOLUTE:
+ # Guard against oversized packets to ensure bad implementations cannot overwhelm
+ # the system.
+ log.debug(
+ "Discarding incoming packet with length %s, which is larger "
+ "than the absolute maximum size of %s",
+ data_len,
+ _MAX_MSG_ABSOLUTE,
+ )
+ return
+
+ msg = DNSIncoming(data, (addr, port), scope, now)
+ if msg.valid:
+ log.debug(
+ 'Received from %r:%r [socket %s]: %r (%d bytes) as [%r]',
+ addr,
+ port,
+ self.sock_description,
+ msg,
+ data_len,
+ data,
+ )
+ else:
+ log.debug(
+ 'Received from %r:%r [socket %s]: (%d bytes) [%r]',
+ addr,
+ port,
+ self.sock_description,
+ data_len,
+ data,
+ )
+ return
+
+ if not msg.is_query():
+ self.zc.handle_response(msg)
+ return
+
+ self.handle_query_or_defer(msg, addr, port, self.transport, v6_flow_scope)
+
+ def handle_query_or_defer(
+ self,
+ msg: DNSIncoming,
+ addr: str,
+ port: int,
+ transport: asyncio.DatagramTransport,
+ v6_flow_scope: Union[Tuple[()], Tuple[int, int]] = (),
+ ) -> None:
+ """Deal with incoming query packets. Provides a response if
+ possible."""
+ if not msg.truncated:
+ self._respond_query(msg, addr, port, transport, v6_flow_scope)
+ return
+
+ deferred = self._deferred.setdefault(addr, [])
+ # If we get the same packet we ignore it
+ for incoming in reversed(deferred):
+ if incoming.data == msg.data:
+ return
+ deferred.append(msg)
+ delay = millis_to_seconds(random.randint(*_TC_DELAY_RANDOM_INTERVAL))
+ assert self.zc.loop is not None
+ self._cancel_any_timers_for_addr(addr)
+ self._timers[addr] = self.zc.loop.call_later(
+ delay, self._respond_query, None, addr, port, transport, v6_flow_scope
+ )
+
+ def _cancel_any_timers_for_addr(self, addr: str) -> None:
+ """Cancel any future truncated packet timers for the address."""
+ if addr in self._timers:
+ self._timers.pop(addr).cancel()
+
+ def _respond_query(
+ self,
+ msg: Optional[DNSIncoming],
+ addr: str,
+ port: int,
+ transport: asyncio.DatagramTransport,
+ v6_flow_scope: Union[Tuple[()], Tuple[int, int]] = (),
+ ) -> None:
+ """Respond to a query and reassemble any truncated deferred packets."""
+ self._cancel_any_timers_for_addr(addr)
+ packets = self._deferred.pop(addr, [])
+ if msg:
+ packets.append(msg)
+
+ self.zc.handle_assembled_query(packets, addr, port, transport, v6_flow_scope)
+
+ def error_received(self, exc: Exception) -> None:
+ """Likely socket closed or IPv6."""
+ # We preformat the message string with the socket as we want
+ # log_exception_once to log a warrning message once PER EACH
+ # different socket in case there are problems with multiple
+ # sockets
+ msg_str = f"Error with socket {self.sock_description}): %s"
+ self.log_exception_once(exc, msg_str, exc)
+
+ def connection_made(self, transport: asyncio.BaseTransport) -> None:
+ self.transport = cast(asyncio.DatagramTransport, transport)
+ sock_name = self.transport.get_extra_info('sockname')
+ sock_fileno = self.transport.get_extra_info('socket').fileno()
+ self.sock_description = f"{sock_fileno} ({sock_name})"
+
+ def connection_lost(self, exc: Optional[Exception]) -> None:
+ """Handle connection lost."""
+
+
+def async_send_with_transport(
+ log_debug: bool,
+ transport: asyncio.DatagramTransport,
+ packet: bytes,
+ packet_num: int,
+ out: DNSOutgoing,
+ addr: Optional[str],
+ port: int,
+ v6_flow_scope: Union[Tuple[()], Tuple[int, int]] = (),
+) -> None:
+ s = transport.get_extra_info('socket')
+ ipv6_socket = s.family == socket.AF_INET6
+ if addr is None:
+ real_addr = _MDNS_ADDR6 if ipv6_socket else _MDNS_ADDR
+ else:
+ real_addr = addr
+ if not can_send_to(ipv6_socket, real_addr):
+ return
+ if log_debug:
+ log.debug(
+ 'Sending to (%s, %d) via [socket %s (%s)] (%d bytes #%d) %r as %r...',
+ real_addr,
+ port or _MDNS_PORT,
+ s.fileno(),
+ transport.get_extra_info('sockname'),
+ len(packet),
+ packet_num + 1,
+ out,
+ packet,
+ )
+ # Get flowinfo and scopeid for the IPV6 socket to create a complete IPv6
+ # address tuple: https://docs.python.org/3.6/library/socket.html#socket-families
+ if ipv6_socket and not v6_flow_scope:
+ _, _, sock_flowinfo, sock_scopeid = s.getsockname()
+ v6_flow_scope = (sock_flowinfo, sock_scopeid)
+ transport.sendto(packet, (real_addr, port or _MDNS_PORT, *v6_flow_scope))
+
+
+class Zeroconf(QuietLogger):
+
+ """Implementation of Zeroconf Multicast DNS Service Discovery
+
+ Supports registration, unregistration, queries and browsing.
+ """
+
+ def __init__(
+ self,
+ interfaces: InterfacesType = InterfaceChoice.All,
+ unicast: bool = False,
+ ip_version: Optional[IPVersion] = None,
+ apple_p2p: bool = False,
+ ) -> None:
+ """Creates an instance of the Zeroconf class, establishing
+ multicast communications, listening and reaping threads.
+
+ :param interfaces: :class:`InterfaceChoice` or a list of IP addresses
+ (IPv4 and IPv6) and interface indexes (IPv6 only).
+
+ IPv6 notes for non-POSIX systems:
+ * `InterfaceChoice.All` is an alias for `InterfaceChoice.Default`
+ on Python versions before 3.8.
+
+ Also listening on loopback (``::1``) doesn't work, use a real address.
+ :param ip_version: IP versions to support. If `choice` is a list, the default is detected
+ from it. Otherwise defaults to V4 only for backward compatibility.
+ :param apple_p2p: use AWDL interface (only macOS)
+ """
+ if ip_version is None:
+ ip_version = autodetect_ip_version(interfaces)
+
+ self.done = False
+
+ if apple_p2p and sys.platform != 'darwin':
+ raise RuntimeError('Option `apple_p2p` is not supported on non-Apple platforms.')
+
+ self.unicast = unicast
+ listen_socket, respond_sockets = create_sockets(interfaces, unicast, ip_version, apple_p2p=apple_p2p)
+ log.debug('Listen socket %s, respond sockets %s', listen_socket, respond_sockets)
+
+ self.engine = AsyncEngine(self, listen_socket, respond_sockets)
+
+ self.browsers: Dict[ServiceListener, ServiceBrowser] = {}
+ self.registry = ServiceRegistry()
+ self.cache = DNSCache()
+ self.question_history = QuestionHistory()
+ self.query_handler = QueryHandler(self.registry, self.cache, self.question_history)
+ self.record_manager = RecordManager(self)
+
+ self.notify_event: Optional[asyncio.Event] = None
+ self.loop: Optional[asyncio.AbstractEventLoop] = None
+ self._loop_thread: Optional[threading.Thread] = None
+
+ self._out_queue = MulticastOutgoingQueue(self, 0, _AGGREGATION_DELAY)
+ self._out_delay_queue = MulticastOutgoingQueue(self, _ONE_SECOND, _PROTECTED_AGGREGATION_DELAY)
+
+ self.start()
+
+ def start(self) -> None:
+ """Start Zeroconf."""
+ self.loop = get_running_loop()
+ if self.loop:
+ self.notify_event = asyncio.Event()
+ self.engine.setup(self.loop, None)
+ return
+ self._start_thread()
+
+ def _start_thread(self) -> None:
+ """Start a thread with a running event loop."""
+ loop_thread_ready = threading.Event()
+
+ def _run_loop() -> None:
+ self.loop = asyncio.new_event_loop()
+ asyncio.set_event_loop(self.loop)
+ self.notify_event = asyncio.Event()
+ self.engine.setup(self.loop, loop_thread_ready)
+ self.loop.run_forever()
+
+ self._loop_thread = threading.Thread(target=_run_loop, daemon=True)
+ self._loop_thread.start()
+ loop_thread_ready.wait()
+
+ async def async_wait_for_start(self) -> None:
+ """Wait for start up."""
+ await self.engine.async_wait_for_start()
+
+ @property
+ def listeners(self) -> List[RecordUpdateListener]:
+ return self.record_manager.listeners
+
+ async def async_wait(self, timeout: float) -> None:
+ """Calling task waits for a given number of milliseconds or until notified."""
+ assert self.notify_event is not None
+ await wait_event_or_timeout(self.notify_event, timeout=millis_to_seconds(timeout))
+
+ def notify_all(self) -> None:
+ """Notifies all waiting threads and notify listeners."""
+ assert self.loop is not None
+ self.loop.call_soon_threadsafe(self.async_notify_all)
+
+ def async_notify_all(self) -> None:
+ """Schedule an async_notify_all."""
+ assert self.notify_event is not None
+ self.notify_event.set()
+ self.notify_event.clear()
+
+ def get_service_info(
+ self, type_: str, name: str, timeout: int = 3000, question_type: Optional[DNSQuestionType] = None
+ ) -> Optional[ServiceInfo]:
+ """Returns network's service information for a particular
+ name and type, or None if no service matches by the timeout,
+ which defaults to 3 seconds."""
+ info = ServiceInfo(type_, name)
+ if info.request(self, timeout, question_type):
+ return info
+ return None
+
+ def add_service_listener(self, type_: str, listener: ServiceListener) -> None:
+ """Adds a listener for a particular service type. This object
+ will then have its add_service and remove_service methods called when
+ services of that type become available and unavailable."""
+ self.remove_service_listener(listener)
+ self.browsers[listener] = ServiceBrowser(self, type_, listener)
+
+ def remove_service_listener(self, listener: ServiceListener) -> None:
+ """Removes a listener from the set that is currently listening."""
+ if listener in self.browsers:
+ self.browsers[listener].cancel()
+ del self.browsers[listener]
+
+ def remove_all_service_listeners(self) -> None:
+ """Removes a listener from the set that is currently listening."""
+ for listener in list(self.browsers):
+ self.remove_service_listener(listener)
+
+ def register_service(
+ self,
+ info: ServiceInfo,
+ ttl: Optional[int] = None,
+ allow_name_change: bool = False,
+ cooperating_responders: bool = False,
+ ) -> None:
+ """Registers service information to the network with a default TTL.
+ Zeroconf will then respond to requests for information for that
+ service. The name of the service may be changed if needed to make
+ it unique on the network. Additionally multiple cooperating responders
+ can register the same service on the network for resilience
+ (if you want this behavior set `cooperating_responders` to `True`)."""
+ assert self.loop is not None
+ run_coro_with_timeout(
+ await_awaitable(
+ self.async_register_service(info, ttl, allow_name_change, cooperating_responders)
+ ),
+ self.loop,
+ _REGISTER_TIME * _REGISTER_BROADCASTS,
+ )
+
+ async def async_register_service(
+ self,
+ info: ServiceInfo,
+ ttl: Optional[int] = None,
+ allow_name_change: bool = False,
+ cooperating_responders: bool = False,
+ ) -> Awaitable:
+ """Registers service information to the network with a default TTL.
+ Zeroconf will then respond to requests for information for that
+ service. The name of the service may be changed if needed to make
+ it unique on the network. Additionally multiple cooperating responders
+ can register the same service on the network for resilience
+ (if you want this behavior set `cooperating_responders` to `True`)."""
+ if ttl is not None:
+ # ttl argument is used to maintain backward compatibility
+ # Setting TTLs via ServiceInfo is preferred
+ info.host_ttl = ttl
+ info.other_ttl = ttl
+
+ await self.async_wait_for_start()
+ await self.async_check_service(info, allow_name_change, cooperating_responders)
+ self.registry.async_add(info)
+ return asyncio.ensure_future(self._async_broadcast_service(info, _REGISTER_TIME, None))
+
+ def update_service(self, info: ServiceInfo) -> None:
+ """Registers service information to the network with a default TTL.
+ Zeroconf will then respond to requests for information for that
+ service."""
+ assert self.loop is not None
+ run_coro_with_timeout(
+ await_awaitable(self.async_update_service(info)), self.loop, _REGISTER_TIME * _REGISTER_BROADCASTS
+ )
+
+ async def async_update_service(self, info: ServiceInfo) -> Awaitable:
+ """Registers service information to the network with a default TTL.
+ Zeroconf will then respond to requests for information for that
+ service."""
+ self.registry.async_update(info)
+ return asyncio.ensure_future(self._async_broadcast_service(info, _REGISTER_TIME, None))
+
+ async def _async_broadcast_service(
+ self,
+ info: ServiceInfo,
+ interval: int,
+ ttl: Optional[int],
+ broadcast_addresses: bool = True,
+ ) -> None:
+ """Send a broadcasts to announce a service at intervals."""
+ for i in range(_REGISTER_BROADCASTS):
+ if i != 0:
+ await asyncio.sleep(millis_to_seconds(interval))
+ self.async_send(self.generate_service_broadcast(info, ttl, broadcast_addresses))
+
+ def generate_service_broadcast(
+ self,
+ info: ServiceInfo,
+ ttl: Optional[int],
+ broadcast_addresses: bool = True,
+ ) -> DNSOutgoing:
+ """Generate a broadcast to announce a service."""
+ out = DNSOutgoing(_FLAGS_QR_RESPONSE | _FLAGS_AA)
+ self._add_broadcast_answer(out, info, ttl, broadcast_addresses)
+ return out
+
+ def generate_service_query(self, info: ServiceInfo) -> DNSOutgoing: # pylint: disable=no-self-use
+ """Generate a query to lookup a service."""
+ out = DNSOutgoing(_FLAGS_QR_QUERY | _FLAGS_AA)
+ # https://datatracker.ietf.org/doc/html/rfc6762#section-8.1
+ # Because of the mDNS multicast rate-limiting
+ # rules, the probes SHOULD be sent as "QU" questions with the unicast-
+ # response bit set, to allow a defending host to respond immediately
+ # via unicast, instead of potentially having to wait before replying
+ # via multicast.
+ #
+ # _CLASS_UNIQUE is the "QU" bit
+ out.add_question(DNSQuestion(info.type, _TYPE_PTR, _CLASS_IN | _CLASS_UNIQUE))
+ out.add_authorative_answer(info.dns_pointer(created=current_time_millis()))
+ return out
+
+ def _add_broadcast_answer( # pylint: disable=no-self-use
+ self,
+ out: DNSOutgoing,
+ info: ServiceInfo,
+ override_ttl: Optional[int],
+ broadcast_addresses: bool = True,
+ ) -> None:
+ """Add answers to broadcast a service."""
+ now = current_time_millis()
+ other_ttl = info.other_ttl if override_ttl is None else override_ttl
+ host_ttl = info.host_ttl if override_ttl is None else override_ttl
+ out.add_answer_at_time(info.dns_pointer(override_ttl=other_ttl, created=now), 0)
+ out.add_answer_at_time(info.dns_service(override_ttl=host_ttl, created=now), 0)
+ out.add_answer_at_time(info.dns_text(override_ttl=other_ttl, created=now), 0)
+ if broadcast_addresses:
+ for dns_address in info.dns_addresses(override_ttl=host_ttl, created=now):
+ out.add_answer_at_time(dns_address, 0)
+
+ def unregister_service(self, info: ServiceInfo) -> None:
+ """Unregister a service."""
+ assert self.loop is not None
+ run_coro_with_timeout(
+ self.async_unregister_service(info), self.loop, _UNREGISTER_TIME * _REGISTER_BROADCASTS
+ )
+
+ async def async_unregister_service(self, info: ServiceInfo) -> Awaitable:
+ """Unregister a service."""
+ self.registry.async_remove(info)
+ # If another server uses the same addresses, we do not want to send
+ # goodbye packets for the address records
+
+ entries = self.registry.async_get_infos_server(info.server)
+ broadcast_addresses = not bool(entries)
+ return asyncio.ensure_future(
+ self._async_broadcast_service(info, _UNREGISTER_TIME, 0, broadcast_addresses)
+ )
+
+ def generate_unregister_all_services(self) -> Optional[DNSOutgoing]:
+ """Generate a DNSOutgoing goodbye for all services and remove them from the registry."""
+ service_infos = self.registry.async_get_service_infos()
+ if not service_infos:
+ return None
+ out = DNSOutgoing(_FLAGS_QR_RESPONSE | _FLAGS_AA)
+ for info in service_infos:
+ self._add_broadcast_answer(out, info, 0)
+ self.registry.async_remove(service_infos)
+ return out
+
+ async def async_unregister_all_services(self) -> None:
+ """Unregister all registered services.
+
+ Unlike async_register_service and async_unregister_service, this
+ method does not return a future and is always expected to be
+ awaited since its only called at shutdown.
+ """
+ # Send Goodbye packets https://datatracker.ietf.org/doc/html/rfc6762#section-10.1
+ out = self.generate_unregister_all_services()
+ if not out:
+ return
+ for i in range(_REGISTER_BROADCASTS):
+ if i != 0:
+ await asyncio.sleep(millis_to_seconds(_UNREGISTER_TIME))
+ self.async_send(out)
+
+ def unregister_all_services(self) -> None:
+ """Unregister all registered services."""
+ assert self.loop is not None
+ run_coro_with_timeout(
+ self.async_unregister_all_services(), self.loop, _UNREGISTER_TIME * _REGISTER_BROADCASTS
+ )
+
+ async def async_check_service(
+ self, info: ServiceInfo, allow_name_change: bool, cooperating_responders: bool = False
+ ) -> None:
+ """Checks the network for a unique service name, modifying the
+ ServiceInfo passed in if it is not unique."""
+ instance_name = instance_name_from_service_info(info)
+ if cooperating_responders:
+ return
+ next_instance_number = 2
+ next_time = now = current_time_millis()
+ i = 0
+ while i < _REGISTER_BROADCASTS:
+ # check for a name conflict
+ while self.cache.current_entry_with_name_and_alias(info.type, info.name):
+ if not allow_name_change:
+ raise NonUniqueNameException
+
+ # change the name and look for a conflict
+ info.name = f'{instance_name}-{next_instance_number}.{info.type}'
+ next_instance_number += 1
+ service_type_name(info.name)
+ next_time = now
+ i = 0
+
+ if now < next_time:
+ await self.async_wait(next_time - now)
+ now = current_time_millis()
+ continue
+
+ self.async_send(self.generate_service_query(info))
+ i += 1
+ next_time += _CHECK_TIME
+
+ def add_listener(
+ self, listener: RecordUpdateListener, question: Optional[Union[DNSQuestion, List[DNSQuestion]]]
+ ) -> None:
+ """Adds a listener for a given question. The listener will have
+ its update_record method called when information is available to
+ answer the question(s).
+
+ This function is threadsafe
+ """
+ assert self.loop is not None
+ self.loop.call_soon_threadsafe(self.record_manager.async_add_listener, listener, question)
+
+ def remove_listener(self, listener: RecordUpdateListener) -> None:
+ """Removes a listener.
+
+ This function is threadsafe
+ """
+ assert self.loop is not None
+ self.loop.call_soon_threadsafe(self.record_manager.async_remove_listener, listener)
+
+ def async_add_listener(
+ self, listener: RecordUpdateListener, question: Optional[Union[DNSQuestion, List[DNSQuestion]]]
+ ) -> None:
+ """Adds a listener for a given question. The listener will have
+ its update_record method called when information is available to
+ answer the question(s).
+
+ This function is not threadsafe and must be called in the eventloop.
+ """
+ self.record_manager.async_add_listener(listener, question)
+
+ def async_remove_listener(self, listener: RecordUpdateListener) -> None:
+ """Removes a listener.
+
+ This function is not threadsafe and must be called in the eventloop.
+ """
+ self.record_manager.async_remove_listener(listener)
+
+ def handle_response(self, msg: DNSIncoming) -> None:
+ """Deal with incoming response packets. All answers
+ are held in the cache, and listeners are notified."""
+ self.record_manager.async_updates_from_response(msg)
+
+ def handle_assembled_query(
+ self,
+ packets: List[DNSIncoming],
+ addr: str,
+ port: int,
+ transport: asyncio.DatagramTransport,
+ v6_flow_scope: Union[Tuple[()], Tuple[int, int]] = (),
+ ) -> None:
+ """Respond to a (re)assembled query.
+
+ If the protocol recieved packets with the TC bit set, it will
+ wait a bit for the rest of the packets and only call
+ handle_assembled_query once it has a complete set of packets
+ or the timer expires. If the TC bit is not set, a single
+ packet will be in packets.
+ """
+ now = packets[0].now
+ ucast_source = port != _MDNS_PORT
+ question_answers = self.query_handler.async_response(packets, ucast_source)
+ if question_answers.ucast:
+ questions = packets[0].questions
+ id_ = packets[0].id
+ out = construct_outgoing_unicast_answers(question_answers.ucast, ucast_source, questions, id_)
+ # When sending unicast, only send back the reply
+ # via the same socket that it was recieved from
+ # as we know its reachable from that socket
+ self.async_send(out, addr, port, v6_flow_scope, transport)
+ if question_answers.mcast_now:
+ self.async_send(construct_outgoing_multicast_answers(question_answers.mcast_now))
+ if question_answers.mcast_aggregate:
+ self._out_queue.async_add(now, question_answers.mcast_aggregate)
+ if question_answers.mcast_aggregate_last_second:
+ # https://datatracker.ietf.org/doc/html/rfc6762#section-14
+ # If we broadcast it in the last second, we have to delay
+ # at least a second before we send it again
+ self._out_delay_queue.async_add(now, question_answers.mcast_aggregate_last_second)
+
+ def send(
+ self,
+ out: DNSOutgoing,
+ addr: Optional[str] = None,
+ port: int = _MDNS_PORT,
+ v6_flow_scope: Union[Tuple[()], Tuple[int, int]] = (),
+ transport: Optional[asyncio.DatagramTransport] = None,
+ ) -> None:
+ """Sends an outgoing packet threadsafe."""
+ assert self.loop is not None
+ self.loop.call_soon_threadsafe(self.async_send, out, addr, port, v6_flow_scope, transport)
+
+ def async_send(
+ self,
+ out: DNSOutgoing,
+ addr: Optional[str] = None,
+ port: int = _MDNS_PORT,
+ v6_flow_scope: Union[Tuple[()], Tuple[int, int]] = (),
+ transport: Optional[asyncio.DatagramTransport] = None,
+ ) -> None:
+ """Sends an outgoing packet."""
+ if self.done:
+ return
+
+ # If no transport is specified, we send to all the ones
+ # with the same address family
+ transports = [transport] if transport else self.engine.senders
+ log_debug = log.isEnabledFor(logging.DEBUG)
+
+ for packet_num, packet in enumerate(out.packets()):
+ if len(packet) > _MAX_MSG_ABSOLUTE:
+ self.log_warning_once("Dropping %r over-sized packet (%d bytes) %r", out, len(packet), packet)
+ return
+ for send_transport in transports:
+ async_send_with_transport(
+ log_debug, send_transport, packet, packet_num, out, addr, port, v6_flow_scope
+ )
+
+ def _close(self) -> None:
+ """Set global done and remove all service listeners."""
+ if self.done:
+ return
+ self.remove_all_service_listeners()
+ self.done = True
+
+ def _shutdown_threads(self) -> None:
+ """Shutdown any threads."""
+ self.notify_all()
+ if not self._loop_thread:
+ return
+ assert self.loop is not None
+ shutdown_loop(self.loop)
+ self._loop_thread.join()
+ self._loop_thread = None
+
+ def close(self) -> None:
+ """Ends the background threads, and prevent this instance from
+ servicing further queries.
+
+ This method is idempotent and irreversible.
+ """
+ assert self.loop is not None
+ if self.loop.is_running():
+ if self.loop == get_running_loop():
+ log.warning(
+ "unregister_all_services skipped as it does blocking i/o; use AsyncZeroconf with asyncio"
+ )
+ else:
+ self.unregister_all_services()
+ self._close()
+ self.engine.close()
+ self._shutdown_threads()
+
+ async def _async_close(self) -> None:
+ """Ends the background threads, and prevent this instance from
+ servicing further queries.
+
+ This method is idempotent and irreversible.
+
+ This call only intended to be used by AsyncZeroconf
+
+ Callers are responsible for unregistering all services
+ before calling this function
+ """
+ self._close()
+ await self.engine._async_close() # pylint: disable=protected-access
+ self._shutdown_threads()
+
+ def __enter__(self) -> 'Zeroconf':
+ return self
+
+ def __exit__( # pylint: disable=useless-return
+ self,
+ exc_type: Optional[Type[BaseException]],
+ exc_val: Optional[BaseException],
+ exc_tb: Optional[TracebackType],
+ ) -> Optional[bool]:
+ self.close()
+ return None
diff --git a/zeroconf/_dns.py b/zeroconf/_dns.py
new file mode 100644
index 00000000..a551a7da
--- /dev/null
+++ b/zeroconf/_dns.py
@@ -0,0 +1,534 @@
+""" Multicast DNS Service Discovery for Python, v0.14-wmcbrine
+ Copyright 2003 Paul Scott-Murphy, 2014 William McBrine
+
+ This module provides a framework for the use of DNS Service Discovery
+ using IP multicast.
+
+ This library is free software; you can redistribute it and/or
+ modify it under the terms of the GNU Lesser General Public
+ License as published by the Free Software Foundation; either
+ version 2.1 of the License, or (at your option) any later version.
+
+ This library is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+ Lesser General Public License for more details.
+
+ You should have received a copy of the GNU Lesser General Public
+ License along with this library; if not, write to the Free Software
+ Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301
+ USA
+"""
+
+import enum
+import socket
+from typing import Any, Dict, Iterable, List, Optional, TYPE_CHECKING, Union, cast
+
+from ._exceptions import AbstractMethodException
+from ._utils.net import _is_v6_address
+from ._utils.time import current_time_millis, millis_to_seconds
+from .const import (
+ _CLASSES,
+ _CLASS_MASK,
+ _CLASS_UNIQUE,
+ _TYPES,
+ _TYPE_ANY,
+)
+
+_LEN_BYTE = 1
+_LEN_SHORT = 2
+_LEN_INT = 4
+
+_BASE_MAX_SIZE = _LEN_SHORT + _LEN_SHORT + _LEN_INT + _LEN_SHORT # type # class # ttl # length
+_NAME_COMPRESSION_MIN_SIZE = _LEN_BYTE * 2
+
+_EXPIRE_FULL_TIME_MS = 1000
+_EXPIRE_STALE_TIME_MS = 500
+_RECENT_TIME_MS = 250
+
+
+if TYPE_CHECKING:
+ from ._protocol.incoming import DNSIncoming
+ from ._protocol.outgoing import DNSOutgoing
+
+
+@enum.unique
+class DNSQuestionType(enum.Enum):
+ """An MDNS question type.
+
+ "QU" - questions requesting unicast responses
+ "QM" - questions requesting multicast responses
+ https://datatracker.ietf.org/doc/html/rfc6762#section-5.4
+ """
+
+ QU = 1
+ QM = 2
+
+
+def dns_entry_matches(record: 'DNSEntry', key: str, type_: int, class_: int) -> bool:
+ return key == record.key and type_ == record.type and class_ == record.class_
+
+
+class DNSEntry:
+
+ """A DNS entry"""
+
+ __slots__ = ('key', 'name', 'type', 'class_', 'unique')
+
+ def __init__(self, name: str, type_: int, class_: int) -> None:
+ self.key = name.lower()
+ self.name = name
+ self.type = type_
+ self.class_ = class_ & _CLASS_MASK
+ self.unique = (class_ & _CLASS_UNIQUE) != 0
+
+ def __eq__(self, other: Any) -> bool:
+ """Equality test on key (lowercase name), type, and class"""
+ return dns_entry_matches(other, self.key, self.type, self.class_) and isinstance(other, DNSEntry)
+
+ @staticmethod
+ def get_class_(class_: int) -> str:
+ """Class accessor"""
+ return _CLASSES.get(class_, f"?({class_})")
+
+ @staticmethod
+ def get_type(t: int) -> str:
+ """Type accessor"""
+ return _TYPES.get(t, f"?({t})")
+
+ def entry_to_string(self, hdr: str, other: Optional[Union[bytes, str]]) -> str:
+ """String representation with additional information"""
+ return "{}[{},{}{},{}]{}".format(
+ hdr,
+ self.get_type(self.type),
+ self.get_class_(self.class_),
+ "-unique" if self.unique else "",
+ self.name,
+ "=%s" % cast(Any, other) if other is not None else "",
+ )
+
+
+class DNSQuestion(DNSEntry):
+
+ """A DNS question entry"""
+
+ __slots__ = ('_hash',)
+
+ def __init__(self, name: str, type_: int, class_: int) -> None:
+ super().__init__(name, type_, class_)
+ self._hash = hash((self.key, type_, self.class_))
+
+ def answered_by(self, rec: 'DNSRecord') -> bool:
+ """Returns true if the question is answered by the record"""
+ return self.class_ == rec.class_ and self.type in (rec.type, _TYPE_ANY) and self.name == rec.name
+
+ def __hash__(self) -> int:
+ return self._hash
+
+ def __eq__(self, other: Any) -> bool:
+ """Tests equality on dns question."""
+ return isinstance(other, DNSQuestion) and dns_entry_matches(other, self.key, self.type, self.class_)
+
+ @property
+ def max_size(self) -> int:
+ """Maximum size of the question in the packet."""
+ return len(self.name.encode('utf-8')) + _LEN_BYTE + _LEN_SHORT + _LEN_SHORT # type # class
+
+ @property
+ def unicast(self) -> bool:
+ """Returns true if the QU (not QM) is set.
+
+ unique shares the same mask as the one
+ used for unicast.
+ """
+ return self.unique
+
+ @unicast.setter
+ def unicast(self, value: bool) -> None:
+ """Sets the QU bit (not QM)."""
+ self.unique = value
+
+ def __repr__(self) -> str:
+ """String representation"""
+ return "{}[question,{},{},{}]".format(
+ self.get_type(self.type),
+ "QU" if self.unicast else "QM",
+ self.get_class_(self.class_),
+ self.name,
+ )
+
+
+class DNSRecord(DNSEntry):
+
+ """A DNS record - like a DNS entry, but has a TTL"""
+
+ __slots__ = ('ttl', 'created')
+
+ # TODO: Switch to just int ttl
+ def __init__(
+ self, name: str, type_: int, class_: int, ttl: Union[float, int], created: Optional[float] = None
+ ) -> None:
+ super().__init__(name, type_, class_)
+ self.ttl = ttl
+ self.created = created or current_time_millis()
+
+ def __eq__(self, other: Any) -> bool: # pylint: disable=no-self-use
+ """Abstract method"""
+ raise AbstractMethodException
+
+ def suppressed_by(self, msg: 'DNSIncoming') -> bool:
+ """Returns true if any answer in a message can suffice for the
+ information held in this record."""
+ return any(self.suppressed_by_answer(record) for record in msg.answers)
+
+ def suppressed_by_answer(self, other: 'DNSRecord') -> bool:
+ """Returns true if another record has same name, type and class,
+ and if its TTL is at least half of this record's."""
+ return self == other and other.ttl > (self.ttl / 2)
+
+ def get_expiration_time(self, percent: int) -> float:
+ """Returns the time at which this record will have expired
+ by a certain percentage."""
+ return self.created + (percent * self.ttl * 10)
+
+ # TODO: Switch to just int here
+ def get_remaining_ttl(self, now: float) -> Union[int, float]:
+ """Returns the remaining TTL in seconds."""
+ return max(0, millis_to_seconds((self.created + (_EXPIRE_FULL_TIME_MS * self.ttl)) - now))
+
+ def is_expired(self, now: float) -> bool:
+ """Returns true if this record has expired."""
+ return self.created + (_EXPIRE_FULL_TIME_MS * self.ttl) <= now
+
+ def is_stale(self, now: float) -> bool:
+ """Returns true if this record is at least half way expired."""
+ return self.created + (_EXPIRE_STALE_TIME_MS * self.ttl) <= now
+
+ def is_recent(self, now: float) -> bool:
+ """Returns true if the record more than one quarter of its TTL remaining."""
+ return self.created + (_RECENT_TIME_MS * self.ttl) > now
+
+ def reset_ttl(self, other: 'DNSRecord') -> None:
+ """Sets this record's TTL and created time to that of
+ another record."""
+ self.set_created_ttl(other.created, other.ttl)
+
+ def set_created_ttl(self, created: float, ttl: Union[float, int]) -> None:
+ """Set the created and ttl of a record."""
+ self.created = created
+ self.ttl = ttl
+
+ def write(self, out: 'DNSOutgoing') -> None: # pylint: disable=no-self-use
+ """Abstract method"""
+ raise AbstractMethodException
+
+ def to_string(self, other: Union[bytes, str]) -> str:
+ """String representation with additional information"""
+ arg = f"{self.ttl}/{int(self.get_remaining_ttl(current_time_millis()))},{cast(Any, other)}"
+ return DNSEntry.entry_to_string(self, "record", arg)
+
+
+class DNSAddress(DNSRecord):
+
+ """A DNS address record"""
+
+ __slots__ = ('_hash', 'address', 'scope_id')
+
+ def __init__(
+ self,
+ name: str,
+ type_: int,
+ class_: int,
+ ttl: int,
+ address: bytes,
+ *,
+ scope_id: Optional[int] = None,
+ created: Optional[float] = None,
+ ) -> None:
+ super().__init__(name, type_, class_, ttl, created)
+ self.address = address
+ self.scope_id = scope_id
+ self._hash = hash((self.key, type_, self.class_, address, scope_id))
+
+ def write(self, out: 'DNSOutgoing') -> None:
+ """Used in constructing an outgoing packet"""
+ out.write_string(self.address)
+
+ def __eq__(self, other: Any) -> bool:
+ """Tests equality on address"""
+ return (
+ isinstance(other, DNSAddress)
+ and self.address == other.address
+ and self.scope_id == other.scope_id
+ and dns_entry_matches(other, self.key, self.type, self.class_)
+ )
+
+ def __hash__(self) -> int:
+ """Hash to compare like DNSAddresses."""
+ return self._hash
+
+ def __repr__(self) -> str:
+ """String representation"""
+ try:
+ return self.to_string(
+ socket.inet_ntop(
+ socket.AF_INET6 if _is_v6_address(self.address) else socket.AF_INET, self.address
+ )
+ )
+ except (ValueError, OSError):
+ return self.to_string(str(self.address))
+
+
+class DNSHinfo(DNSRecord):
+
+ """A DNS host information record"""
+
+ __slots__ = ('_hash', 'cpu', 'os')
+
+ def __init__(
+ self, name: str, type_: int, class_: int, ttl: int, cpu: str, os: str, created: Optional[float] = None
+ ) -> None:
+ super().__init__(name, type_, class_, ttl, created)
+ self.cpu = cpu
+ self.os = os
+ self._hash = hash((self.key, type_, self.class_, cpu, os))
+
+ def write(self, out: 'DNSOutgoing') -> None:
+ """Used in constructing an outgoing packet"""
+ out.write_character_string(self.cpu.encode('utf-8'))
+ out.write_character_string(self.os.encode('utf-8'))
+
+ def __eq__(self, other: Any) -> bool:
+ """Tests equality on cpu and os"""
+ return (
+ isinstance(other, DNSHinfo)
+ and self.cpu == other.cpu
+ and self.os == other.os
+ and dns_entry_matches(other, self.key, self.type, self.class_)
+ )
+
+ def __hash__(self) -> int:
+ """Hash to compare like DNSHinfo."""
+ return self._hash
+
+ def __repr__(self) -> str:
+ """String representation"""
+ return self.to_string(self.cpu + " " + self.os)
+
+
+class DNSPointer(DNSRecord):
+
+ """A DNS pointer record"""
+
+ __slots__ = ('_hash', 'alias')
+
+ def __init__(
+ self, name: str, type_: int, class_: int, ttl: int, alias: str, created: Optional[float] = None
+ ) -> None:
+ super().__init__(name, type_, class_, ttl, created)
+ self.alias = alias
+ self._hash = hash((self.key, type_, self.class_, alias))
+
+ @property
+ def max_size_compressed(self) -> int:
+ """Maximum size of the record in the packet assuming the name has been compressed."""
+ return (
+ _BASE_MAX_SIZE
+ + _NAME_COMPRESSION_MIN_SIZE
+ + (len(self.alias) - len(self.name))
+ + _NAME_COMPRESSION_MIN_SIZE
+ )
+
+ def write(self, out: 'DNSOutgoing') -> None:
+ """Used in constructing an outgoing packet"""
+ out.write_name(self.alias)
+
+ def __eq__(self, other: Any) -> bool:
+ """Tests equality on alias"""
+ return (
+ isinstance(other, DNSPointer)
+ and self.alias == other.alias
+ and dns_entry_matches(other, self.key, self.type, self.class_)
+ )
+
+ def __hash__(self) -> int:
+ """Hash to compare like DNSPointer."""
+ return self._hash
+
+ def __repr__(self) -> str:
+ """String representation"""
+ return self.to_string(self.alias)
+
+
+class DNSText(DNSRecord):
+
+ """A DNS text record"""
+
+ __slots__ = ('_hash', 'text')
+
+ def __init__(
+ self, name: str, type_: int, class_: int, ttl: int, text: bytes, created: Optional[float] = None
+ ) -> None:
+ assert isinstance(text, (bytes, type(None)))
+ super().__init__(name, type_, class_, ttl, created)
+ self.text = text
+ self._hash = hash((self.key, type_, self.class_, text))
+
+ def write(self, out: 'DNSOutgoing') -> None:
+ """Used in constructing an outgoing packet"""
+ out.write_string(self.text)
+
+ def __hash__(self) -> int:
+ """Hash to compare like DNSText."""
+ return self._hash
+
+ def __eq__(self, other: Any) -> bool:
+ """Tests equality on text"""
+ return (
+ isinstance(other, DNSText)
+ and self.text == other.text
+ and dns_entry_matches(other, self.key, self.type, self.class_)
+ )
+
+ def __repr__(self) -> str:
+ """String representation"""
+ if len(self.text) > 10:
+ return self.to_string(self.text[:7]) + "..."
+ return self.to_string(self.text)
+
+
+class DNSService(DNSRecord):
+
+ """A DNS service record"""
+
+ __slots__ = ('_hash', 'priority', 'weight', 'port', 'server')
+
+ def __init__(
+ self,
+ name: str,
+ type_: int,
+ class_: int,
+ ttl: Union[float, int],
+ priority: int,
+ weight: int,
+ port: int,
+ server: str,
+ created: Optional[float] = None,
+ ) -> None:
+ super().__init__(name, type_, class_, ttl, created)
+ self.priority = priority
+ self.weight = weight
+ self.port = port
+ self.server = server
+ self._hash = hash((self.key, type_, self.class_, priority, weight, port, server))
+
+ def write(self, out: 'DNSOutgoing') -> None:
+ """Used in constructing an outgoing packet"""
+ out.write_short(self.priority)
+ out.write_short(self.weight)
+ out.write_short(self.port)
+ out.write_name(self.server)
+
+ def __eq__(self, other: Any) -> bool:
+ """Tests equality on priority, weight, port and server"""
+ return (
+ isinstance(other, DNSService)
+ and self.priority == other.priority
+ and self.weight == other.weight
+ and self.port == other.port
+ and self.server == other.server
+ and dns_entry_matches(other, self.key, self.type, self.class_)
+ )
+
+ def __hash__(self) -> int:
+ """Hash to compare like DNSService."""
+ return self._hash
+
+ def __repr__(self) -> str:
+ """String representation"""
+ return self.to_string(f"{self.server}:{self.port}")
+
+
+class DNSNsec(DNSRecord):
+
+ """A DNS NSEC record"""
+
+ __slots__ = ('_hash', 'next_name', 'rdtypes')
+
+ def __init__(
+ self,
+ name: str,
+ type_: int,
+ class_: int,
+ ttl: int,
+ next_name: str,
+ rdtypes: List[int],
+ created: Optional[float] = None,
+ ) -> None:
+ super().__init__(name, type_, class_, ttl, created)
+ self.next_name = next_name
+ self.rdtypes = sorted(rdtypes)
+ self._hash = hash((self.key, type_, self.class_, next_name, *self.rdtypes))
+
+ def write(self, out: 'DNSOutgoing') -> None:
+ """Used in constructing an outgoing packet."""
+ bitmap = bytearray(b'\0' * 32)
+ for rdtype in self.rdtypes:
+ if rdtype > 255: # mDNS only supports window 0
+ continue
+ offset = rdtype % 256
+ byte = offset // 8
+ total_octets = byte + 1
+ bitmap[byte] |= 0x80 >> (offset % 8)
+ out_bytes = bytes(bitmap[0:total_octets])
+ out.write_name(self.next_name)
+ out.write_short(0)
+ out.write_short(len(out_bytes))
+ out.write_string(out_bytes)
+
+ def __eq__(self, other: Any) -> bool:
+ """Tests equality on cpu and os"""
+ return (
+ isinstance(other, DNSNsec)
+ and self.next_name == other.next_name
+ and self.rdtypes == other.rdtypes
+ and dns_entry_matches(other, self.key, self.type, self.class_)
+ )
+
+ def __hash__(self) -> int:
+ """Hash to compare like DNSNSec."""
+ return self._hash
+
+ def __repr__(self) -> str:
+ """String representation"""
+ return self.to_string(
+ self.next_name + "," + "|".join([self.get_type(type_) for type_ in self.rdtypes])
+ )
+
+
+class DNSRRSet:
+ """A set of dns records independent of the ttl."""
+
+ __slots__ = ('_records', '_lookup')
+
+ def __init__(self, records: Iterable[DNSRecord]) -> None:
+ """Create an RRset from records."""
+ self._records = records
+ self._lookup: Optional[Dict[DNSRecord, DNSRecord]] = None
+
+ @property
+ def lookup(self) -> Dict[DNSRecord, DNSRecord]:
+ if self._lookup is None:
+ # Build the hash table so we can lookup the record independent of the ttl
+ self._lookup = {record: record for record in self._records}
+ return self._lookup
+
+ def suppresses(self, record: DNSRecord) -> bool:
+ """Returns true if any answer in the rrset can suffice for the
+ information held in this record."""
+ other = self.lookup.get(record)
+ return bool(other and other.ttl > (record.ttl / 2))
+
+ def __contains__(self, record: DNSRecord) -> bool:
+ """Returns true if the rrset contains the record."""
+ return record in self.lookup
diff --git a/zeroconf/_exceptions.py b/zeroconf/_exceptions.py
new file mode 100644
index 00000000..02771140
--- /dev/null
+++ b/zeroconf/_exceptions.py
@@ -0,0 +1,49 @@
+""" Multicast DNS Service Discovery for Python, v0.14-wmcbrine
+ Copyright 2003 Paul Scott-Murphy, 2014 William McBrine
+
+ This module provides a framework for the use of DNS Service Discovery
+ using IP multicast.
+
+ This library is free software; you can redistribute it and/or
+ modify it under the terms of the GNU Lesser General Public
+ License as published by the Free Software Foundation; either
+ version 2.1 of the License, or (at your option) any later version.
+
+ This library is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+ Lesser General Public License for more details.
+
+ You should have received a copy of the GNU Lesser General Public
+ License along with this library; if not, write to the Free Software
+ Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301
+ USA
+"""
+
+
+class Error(Exception):
+ pass
+
+
+class IncomingDecodeError(Error):
+ pass
+
+
+class NonUniqueNameException(Error):
+ pass
+
+
+class NamePartTooLongException(Error):
+ pass
+
+
+class AbstractMethodException(Error):
+ pass
+
+
+class BadTypeInNameException(Error):
+ pass
+
+
+class ServiceNameAlreadyRegistered(Error):
+ pass
diff --git a/zeroconf/_handlers.py b/zeroconf/_handlers.py
new file mode 100644
index 00000000..b4c31e2d
--- /dev/null
+++ b/zeroconf/_handlers.py
@@ -0,0 +1,597 @@
+""" Multicast DNS Service Discovery for Python, v0.14-wmcbrine
+ Copyright 2003 Paul Scott-Murphy, 2014 William McBrine
+
+ This module provides a framework for the use of DNS Service Discovery
+ using IP multicast.
+
+ This library is free software; you can redistribute it and/or
+ modify it under the terms of the GNU Lesser General Public
+ License as published by the Free Software Foundation; either
+ version 2.1 of the License, or (at your option) any later version.
+
+ This library is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+ Lesser General Public License for more details.
+
+ You should have received a copy of the GNU Lesser General Public
+ License along with this library; if not, write to the Free Software
+ Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301
+ USA
+"""
+
+import itertools
+import random
+from collections import deque
+from typing import Dict, Iterable, List, NamedTuple, Optional, Set, TYPE_CHECKING, Tuple, Union, cast
+
+from ._cache import DNSCache, _UniqueRecordsType
+from ._dns import DNSAddress, DNSNsec, DNSPointer, DNSQuestion, DNSRRSet, DNSRecord
+from ._history import QuestionHistory
+from ._logger import log
+from ._protocol.incoming import DNSIncoming
+from ._protocol.outgoing import DNSOutgoing
+from ._services.info import ServiceInfo
+from ._services.registry import ServiceRegistry
+from ._updates import RecordUpdate, RecordUpdateListener
+from ._utils.time import current_time_millis, millis_to_seconds
+from .const import (
+ _CLASS_IN,
+ _CLASS_UNIQUE,
+ _DNS_OTHER_TTL,
+ _DNS_PTR_MIN_TTL,
+ _FLAGS_AA,
+ _FLAGS_QR_RESPONSE,
+ _ONE_SECOND,
+ _SERVICE_TYPE_ENUMERATION_NAME,
+ _TYPE_A,
+ _TYPE_AAAA,
+ _TYPE_ANY,
+ _TYPE_NSEC,
+ _TYPE_PTR,
+ _TYPE_SRV,
+ _TYPE_TXT,
+)
+
+if TYPE_CHECKING:
+ from ._core import Zeroconf
+
+
+_AnswerWithAdditionalsType = Dict[DNSRecord, Set[DNSRecord]]
+
+_MULTICAST_DELAY_RANDOM_INTERVAL = (20, 120)
+_ADDRESS_RECORD_TYPES = {_TYPE_A, _TYPE_AAAA}
+_RESPOND_IMMEDIATE_TYPES = {_TYPE_NSEC, _TYPE_SRV, *_ADDRESS_RECORD_TYPES}
+
+
+class QuestionAnswers(NamedTuple):
+ ucast: _AnswerWithAdditionalsType
+ mcast_now: _AnswerWithAdditionalsType
+ mcast_aggregate: _AnswerWithAdditionalsType
+ mcast_aggregate_last_second: _AnswerWithAdditionalsType
+
+
+class AnswerGroup(NamedTuple):
+ """A group of answers scheduled to be sent at the same time."""
+
+ send_after: float # Must be sent after this time
+ send_before: float # Must be sent before this time
+ answers: _AnswerWithAdditionalsType
+
+
+def _message_is_probe(msg: DNSIncoming) -> bool:
+ return msg.num_authorities > 0
+
+
+def construct_nsec_record(name: str, types: List[int], now: float) -> DNSNsec:
+ """Construct an NSEC record for name and a list of dns types.
+
+ This function should only be used for SRV/A/AAAA records
+ which have a TTL of _DNS_OTHER_TTL
+ """
+ return DNSNsec(name, _TYPE_NSEC, _CLASS_IN | _CLASS_UNIQUE, _DNS_OTHER_TTL, name, types, created=now)
+
+
+def construct_outgoing_multicast_answers(answers: _AnswerWithAdditionalsType) -> DNSOutgoing:
+ """Add answers and additionals to a DNSOutgoing."""
+ out = DNSOutgoing(_FLAGS_QR_RESPONSE | _FLAGS_AA, multicast=True)
+ _add_answers_additionals(out, answers)
+ return out
+
+
+def construct_outgoing_unicast_answers(
+ answers: _AnswerWithAdditionalsType, ucast_source: bool, questions: List[DNSQuestion], id_: int
+) -> DNSOutgoing:
+ """Add answers and additionals to a DNSOutgoing."""
+ out = DNSOutgoing(_FLAGS_QR_RESPONSE | _FLAGS_AA, multicast=False, id_=id_)
+ # Adding the questions back when the source is legacy unicast behavior
+ if ucast_source:
+ for question in questions:
+ out.add_question(question)
+ _add_answers_additionals(out, answers)
+ return out
+
+
+def _add_answers_additionals(out: DNSOutgoing, answers: _AnswerWithAdditionalsType) -> None:
+ # Find additionals and suppress any additionals that are already in answers
+ sending: Set[DNSRecord] = set(answers.keys())
+ # Answers are sorted to group names together to increase the chance
+ # that similar names will end up in the same packet and can reduce the
+ # overall size of the outgoing response via name compression
+ for answer, additionals in sorted(answers.items(), key=lambda kv: kv[0].name):
+ out.add_answer_at_time(answer, 0)
+ for additional in additionals:
+ if additional not in sending:
+ out.add_additional_answer(additional)
+ sending.add(additional)
+
+
+def sanitize_incoming_record(record: DNSRecord) -> None:
+ """Protect zeroconf from records that can cause denial of service.
+
+ We enforce a minimum TTL for PTR records to avoid
+ ServiceBrowsers generating excessive queries refresh queries.
+ Apple uses a 15s minimum TTL, however we do not have the same
+ level of rate limit and safe guards so we use 1/4 of the recommended value.
+ """
+ if record.ttl and record.ttl < _DNS_PTR_MIN_TTL and isinstance(record, DNSPointer):
+ log.debug(
+ "Increasing effective ttl of %s to minimum of %s to protect against excessive refreshes.",
+ record,
+ _DNS_PTR_MIN_TTL,
+ )
+ record.set_created_ttl(record.created, _DNS_PTR_MIN_TTL)
+
+
+class _QueryResponse:
+ """A pair for unicast and multicast DNSOutgoing responses."""
+
+ def __init__(self, cache: DNSCache, msgs: List[DNSIncoming]) -> None:
+ """Build a query response."""
+ self._is_probe = any(_message_is_probe(msg) for msg in msgs)
+ self._msg = msgs[0]
+ self._now = self._msg.now
+ self._cache = cache
+ self._additionals: _AnswerWithAdditionalsType = {}
+ self._ucast: Set[DNSRecord] = set()
+ self._mcast_now: Set[DNSRecord] = set()
+ self._mcast_aggregate: Set[DNSRecord] = set()
+ self._mcast_aggregate_last_second: Set[DNSRecord] = set()
+
+ def add_qu_question_response(self, answers: _AnswerWithAdditionalsType) -> None:
+ """Generate a response to a multicast QU query."""
+ for record, additionals in answers.items():
+ self._additionals[record] = additionals
+ if self._is_probe:
+ self._ucast.add(record)
+ if not self._has_mcast_within_one_quarter_ttl(record):
+ self._mcast_now.add(record)
+ elif not self._is_probe:
+ self._ucast.add(record)
+
+ def add_ucast_question_response(self, answers: _AnswerWithAdditionalsType) -> None:
+ """Generate a response to a unicast query."""
+ self._additionals.update(answers)
+ self._ucast.update(answers.keys())
+
+ def add_mcast_question_response(self, answers: _AnswerWithAdditionalsType) -> None:
+ """Generate a response to a multicast query."""
+ self._additionals.update(answers)
+ for answer in answers:
+ if self._is_probe:
+ self._mcast_now.add(answer)
+ continue
+
+ if self._has_mcast_record_in_last_second(answer):
+ self._mcast_aggregate_last_second.add(answer)
+ elif len(self._msg.questions) == 1 and self._msg.questions[0].type in _RESPOND_IMMEDIATE_TYPES:
+ self._mcast_now.add(answer)
+ else:
+ self._mcast_aggregate.add(answer)
+
+ def _generate_answers_with_additionals(self, rrset: Set[DNSRecord]) -> _AnswerWithAdditionalsType:
+ """Create answers with additionals from an rrset."""
+ return {record: self._additionals[record] for record in rrset}
+
+ def answers(
+ self,
+ ) -> QuestionAnswers:
+ """Return answer sets that will be queued."""
+ return QuestionAnswers(
+ self._generate_answers_with_additionals(self._ucast),
+ self._generate_answers_with_additionals(self._mcast_now),
+ self._generate_answers_with_additionals(self._mcast_aggregate),
+ self._generate_answers_with_additionals(self._mcast_aggregate_last_second),
+ )
+
+ def _has_mcast_within_one_quarter_ttl(self, record: DNSRecord) -> bool:
+ """Check to see if a record has been mcasted recently.
+
+ https://datatracker.ietf.org/doc/html/rfc6762#section-5.4
+ When receiving a question with the unicast-response bit set, a
+ responder SHOULD usually respond with a unicast packet directed back
+ to the querier. However, if the responder has not multicast that
+ record recently (within one quarter of its TTL), then the responder
+ SHOULD instead multicast the response so as to keep all the peer
+ caches up to date
+ """
+ maybe_entry = self._cache.async_get_unique(cast(_UniqueRecordsType, record))
+ return bool(maybe_entry and maybe_entry.is_recent(self._now))
+
+ def _has_mcast_record_in_last_second(self, record: DNSRecord) -> bool:
+ """Check if an answer was seen in the last second.
+ Protect the network against excessive packet flooding
+ https://datatracker.ietf.org/doc/html/rfc6762#section-14
+ """
+ maybe_entry = self._cache.async_get_unique(cast(_UniqueRecordsType, record))
+ return bool(maybe_entry and self._now - maybe_entry.created < _ONE_SECOND)
+
+
+def _get_address_and_nsec_records(service: ServiceInfo, now: float) -> Set[DNSRecord]:
+ """Build a set of address records and NSEC records for non-present record types."""
+ seen_types: Set[int] = set()
+ records: Set[DNSRecord] = set()
+ for dns_address in service.dns_addresses(created=now):
+ seen_types.add(dns_address.type)
+ records.add(dns_address)
+ missing_types: Set[int] = _ADDRESS_RECORD_TYPES - seen_types
+ if missing_types:
+ records.add(construct_nsec_record(service.server, list(missing_types), now))
+ return records
+
+
+class QueryHandler:
+ """Query the ServiceRegistry."""
+
+ def __init__(self, registry: ServiceRegistry, cache: DNSCache, question_history: QuestionHistory) -> None:
+ """Init the query handler."""
+ self.registry = registry
+ self.cache = cache
+ self.question_history = question_history
+
+ def _add_service_type_enumeration_query_answers(
+ self, answer_set: _AnswerWithAdditionalsType, known_answers: DNSRRSet, now: float
+ ) -> None:
+ """Provide an answer to a service type enumeration query.
+
+ https://datatracker.ietf.org/doc/html/rfc6763#section-9
+ """
+ for stype in self.registry.async_get_types():
+ dns_pointer = DNSPointer(
+ _SERVICE_TYPE_ENUMERATION_NAME, _TYPE_PTR, _CLASS_IN, _DNS_OTHER_TTL, stype, now
+ )
+ if not known_answers.suppresses(dns_pointer):
+ answer_set[dns_pointer] = set()
+
+ def _add_pointer_answers(
+ self, name: str, answer_set: _AnswerWithAdditionalsType, known_answers: DNSRRSet, now: float
+ ) -> None:
+ """Answer PTR/ANY question."""
+ for service in self.registry.async_get_infos_type(name):
+ # Add recommended additional answers according to
+ # https://tools.ietf.org/html/rfc6763#section-12.1.
+ dns_pointer = service.dns_pointer(created=now)
+ if known_answers.suppresses(dns_pointer):
+ continue
+ additionals: Set[DNSRecord] = {service.dns_service(created=now), service.dns_text(created=now)}
+ additionals |= _get_address_and_nsec_records(service, now)
+ answer_set[dns_pointer] = additionals
+
+ def _add_address_answers(
+ self,
+ name: str,
+ answer_set: _AnswerWithAdditionalsType,
+ known_answers: DNSRRSet,
+ now: float,
+ type_: int,
+ ) -> None:
+ """Answer A/AAAA/ANY question."""
+ for service in self.registry.async_get_infos_server(name):
+ answers: List[DNSAddress] = []
+ additionals: Set[DNSRecord] = set()
+ seen_types: Set[int] = set()
+ for dns_address in service.dns_addresses(created=now):
+ seen_types.add(dns_address.type)
+ if dns_address.type != type_:
+ additionals.add(dns_address)
+ elif not known_answers.suppresses(dns_address):
+ answers.append(dns_address)
+ missing_types: Set[int] = _ADDRESS_RECORD_TYPES - seen_types
+ if answers:
+ if missing_types:
+ additionals.add(construct_nsec_record(service.server, list(missing_types), now))
+ for answer in answers:
+ answer_set[answer] = additionals
+ elif type_ in missing_types:
+ answer_set[construct_nsec_record(service.server, list(missing_types), now)] = set()
+
+ def _answer_question(
+ self,
+ question: DNSQuestion,
+ known_answers: DNSRRSet,
+ now: float,
+ ) -> _AnswerWithAdditionalsType:
+ answer_set: _AnswerWithAdditionalsType = {}
+
+ if question.type == _TYPE_PTR and question.name.lower() == _SERVICE_TYPE_ENUMERATION_NAME:
+ self._add_service_type_enumeration_query_answers(answer_set, known_answers, now)
+ return answer_set
+
+ type_ = question.type
+
+ if type_ in (_TYPE_PTR, _TYPE_ANY):
+ self._add_pointer_answers(question.name, answer_set, known_answers, now)
+
+ if type_ in (_TYPE_A, _TYPE_AAAA, _TYPE_ANY):
+ self._add_address_answers(question.name, answer_set, known_answers, now, type_)
+
+ if type_ in (_TYPE_SRV, _TYPE_TXT, _TYPE_ANY):
+ service = self.registry.async_get_info_name(question.name) # type: ignore
+ if service is not None:
+ if type_ in (_TYPE_SRV, _TYPE_ANY):
+ # Add recommended additional answers according to
+ # https://tools.ietf.org/html/rfc6763#section-12.2.
+ dns_service = service.dns_service(created=now)
+ if not known_answers.suppresses(dns_service):
+ answer_set[dns_service] = _get_address_and_nsec_records(service, now)
+ if type_ in (_TYPE_TXT, _TYPE_ANY):
+ dns_text = service.dns_text(created=now)
+ if not known_answers.suppresses(dns_text):
+ answer_set[dns_text] = set()
+
+ return answer_set
+
+ def async_response( # pylint: disable=unused-argument
+ self, msgs: List[DNSIncoming], ucast_source: bool
+ ) -> QuestionAnswers:
+ """Deal with incoming query packets. Provides a response if possible.
+
+ This function must be run in the event loop as it is not
+ threadsafe.
+ """
+ known_answers = DNSRRSet(
+ itertools.chain.from_iterable(msg.answers for msg in msgs if not _message_is_probe(msg))
+ )
+ query_res = _QueryResponse(self.cache, msgs)
+
+ for msg in msgs:
+ for question in msg.questions:
+ if not question.unicast:
+ self.question_history.add_question_at_time(question, msg.now, set(known_answers.lookup))
+ answer_set = self._answer_question(question, known_answers, msg.now)
+ if not ucast_source and question.unicast:
+ query_res.add_qu_question_response(answer_set)
+ continue
+ if ucast_source:
+ query_res.add_ucast_question_response(answer_set)
+ # We always multicast as well even if its a unicast
+ # source as long as we haven't done it recently (75% of ttl)
+ query_res.add_mcast_question_response(answer_set)
+
+ return query_res.answers()
+
+
+class RecordManager:
+ """Process records into the cache and notify listeners."""
+
+ def __init__(self, zeroconf: 'Zeroconf') -> None:
+ """Init the record manager."""
+ self.zc = zeroconf
+ self.cache = zeroconf.cache
+ self.listeners: List[RecordUpdateListener] = []
+
+ def async_updates(self, now: float, records: List[RecordUpdate]) -> None:
+ """Used to notify listeners of new information that has updated
+ a record.
+
+ This method must be called before the cache is updated.
+
+ This method will be run in the event loop.
+ """
+ for listener in self.listeners:
+ listener.async_update_records(self.zc, now, records)
+
+ def async_updates_complete(self) -> None:
+ """Used to notify listeners of new information that has updated
+ a record.
+
+ This method must be called after the cache is updated.
+
+ This method will be run in the event loop.
+ """
+ for listener in self.listeners:
+ listener.async_update_records_complete()
+ self.zc.async_notify_all()
+
+ def async_updates_from_response(self, msg: DNSIncoming) -> None:
+ """Deal with incoming response packets. All answers
+ are held in the cache, and listeners are notified.
+
+ This function must be run in the event loop as it is not
+ threadsafe.
+ """
+ updates: List[RecordUpdate] = []
+ address_adds: List[DNSAddress] = []
+ other_adds: List[DNSRecord] = []
+ removes: Set[DNSRecord] = set()
+ now = msg.now
+ unique_types: Set[Tuple[str, int, int]] = set()
+
+ for record in msg.answers:
+ sanitize_incoming_record(record)
+
+ if record.unique: # https://tools.ietf.org/html/rfc6762#section-10.2
+ unique_types.add((record.name, record.type, record.class_))
+
+ maybe_entry = self.cache.async_get_unique(cast(_UniqueRecordsType, record))
+ if not record.is_expired(now):
+ if maybe_entry is not None:
+ maybe_entry.reset_ttl(record)
+ else:
+ if isinstance(record, DNSAddress):
+ address_adds.append(record)
+ else:
+ other_adds.append(record)
+ updates.append(RecordUpdate(record, maybe_entry))
+ # This is likely a goodbye since the record is
+ # expired and exists in the cache
+ elif maybe_entry is not None:
+ updates.append(RecordUpdate(record, maybe_entry))
+ removes.add(record)
+
+ if unique_types:
+ self._async_mark_unique_cached_records_older_than_1s_to_expire(unique_types, msg.answers, now)
+
+ if updates:
+ self.async_updates(now, updates)
+ # The cache adds must be processed AFTER we trigger
+ # the updates since we compare existing data
+ # with the new data and updating the cache
+ # ahead of update_record will cause listeners
+ # to miss changes
+ #
+ # We must process address adds before non-addresses
+ # otherwise a fetch of ServiceInfo may miss an address
+ # because it thinks the cache is complete
+ #
+ # The cache is processed under the context manager to ensure
+ # that any ServiceBrowser that is going to call
+ # zc.get_service_info will see the cached value
+ # but ONLY after all the record updates have been
+ # processsed.
+ if other_adds or address_adds:
+ self.cache.async_add_records(itertools.chain(address_adds, other_adds))
+ # Removes are processed last since
+ # ServiceInfo could generate an un-needed query
+ # because the data was not yet populated.
+ if removes:
+ self.cache.async_remove_records(removes)
+ if updates:
+ self.async_updates_complete()
+
+ def _async_mark_unique_cached_records_older_than_1s_to_expire(
+ self, unique_types: Set[Tuple[str, int, int]], answers: Iterable[DNSRecord], now: float
+ ) -> None:
+ # rfc6762#section-10.2 para 2
+ # Since unique is set, all old records with that name, rrtype,
+ # and rrclass that were received more than one second ago are declared
+ # invalid, and marked to expire from the cache in one second.
+ answers_rrset = DNSRRSet(answers)
+ for name, type_, class_ in unique_types:
+ for entry in self.cache.async_all_by_details(name, type_, class_):
+ if (now - entry.created > _ONE_SECOND) and entry not in answers_rrset:
+ # Expire in 1s
+ entry.set_created_ttl(now, 1)
+
+ def async_add_listener(
+ self, listener: RecordUpdateListener, question: Optional[Union[DNSQuestion, List[DNSQuestion]]]
+ ) -> None:
+ """Adds a listener for a given question. The listener will have
+ its update_record method called when information is available to
+ answer the question(s).
+
+ This function is not threadsafe and must be called in the eventloop.
+ """
+ self.listeners.append(listener)
+
+ if question is None:
+ return
+
+ questions = [question] if isinstance(question, DNSQuestion) else question
+ assert self.zc.loop is not None
+ self._async_update_matching_records(listener, questions)
+
+ def _async_update_matching_records(
+ self, listener: RecordUpdateListener, questions: List[DNSQuestion]
+ ) -> None:
+ """Calls back any existing entries in the cache that answer the question.
+
+ This function must be run from the event loop.
+ """
+ now = current_time_millis()
+ records: List[RecordUpdate] = []
+ for question in questions:
+ for record in self.cache.async_entries_with_name(question.name):
+ if not record.is_expired(now) and question.answered_by(record):
+ records.append(RecordUpdate(record, None))
+
+ if not records:
+ return
+ listener.async_update_records(self.zc, now, records)
+ listener.async_update_records_complete()
+ self.zc.async_notify_all()
+
+ def async_remove_listener(self, listener: RecordUpdateListener) -> None:
+ """Removes a listener.
+
+ This function is not threadsafe and must be called in the eventloop.
+ """
+ try:
+ self.listeners.remove(listener)
+ self.zc.async_notify_all()
+ except ValueError as e:
+ log.exception('Failed to remove listener: %r', e)
+
+
+class MulticastOutgoingQueue:
+ """An outgoing queue used to aggregate multicast responses."""
+
+ def __init__(self, zeroconf: 'Zeroconf', additional_delay: int, max_aggregation_delay: int) -> None:
+ self.zc = zeroconf
+ self.queue: deque = deque()
+ # Additional delay is used to implement
+ # Protect the network against excessive packet flooding
+ # https://datatracker.ietf.org/doc/html/rfc6762#section-14
+ self.additional_delay = additional_delay
+ self.aggregation_delay = max_aggregation_delay
+
+ def async_add(self, now: float, answers: _AnswerWithAdditionalsType) -> None:
+ """Add a group of answers with additionals to the outgoing queue."""
+ assert self.zc.loop is not None
+ random_delay = random.randint(*_MULTICAST_DELAY_RANDOM_INTERVAL) + self.additional_delay
+ send_after = now + random_delay
+ send_before = now + self.aggregation_delay + self.additional_delay
+ if len(self.queue):
+ # If we calculate a random delay for the send after time
+ # that is less than the last group scheduled to go out,
+ # we instead add the answers to the last group as this
+ # allows aggregating additonal responses
+ last_group = self.queue[-1]
+ if send_after <= last_group.send_after:
+ last_group.answers.update(answers)
+ return
+ else:
+ self.zc.loop.call_later(millis_to_seconds(random_delay), self.async_ready)
+ self.queue.append(AnswerGroup(send_after, send_before, answers))
+
+ def _remove_answers_from_queue(self, answers: _AnswerWithAdditionalsType) -> None:
+ """Remove a set of answers from the outgoing queue."""
+ for pending in self.queue:
+ for record in answers:
+ pending.answers.pop(record, None)
+
+ def async_ready(self) -> None:
+ """Process anything in the queue that is ready."""
+ assert self.zc.loop is not None
+ now = current_time_millis()
+
+ if len(self.queue) > 1 and self.queue[0].send_before > now:
+ # There is more than one answer in the queue,
+ # delay until we have to send it (first answer group reaches send_before)
+ self.zc.loop.call_later(millis_to_seconds(self.queue[0].send_before - now), self.async_ready)
+ return
+
+ answers: _AnswerWithAdditionalsType = {}
+ # Add all groups that can be sent now
+ while len(self.queue) and self.queue[0].send_after <= now:
+ answers.update(self.queue.popleft().answers)
+
+ if len(self.queue):
+ # If there are still groups in the queue that are not ready to send
+ # be sure we schedule them to go out later
+ self.zc.loop.call_later(millis_to_seconds(self.queue[0].send_after - now), self.async_ready)
+
+ if answers:
+ # If we have the same answer scheduled to go out, remove them
+ self._remove_answers_from_queue(answers)
+ self.zc.async_send(construct_outgoing_multicast_answers(answers))
diff --git a/zeroconf/_history.py b/zeroconf/_history.py
new file mode 100644
index 00000000..cbb36144
--- /dev/null
+++ b/zeroconf/_history.py
@@ -0,0 +1,70 @@
+""" Multicast DNS Service Discovery for Python, v0.14-wmcbrine
+ Copyright 2003 Paul Scott-Murphy, 2014 William McBrine
+
+ This module provides a framework for the use of DNS Service Discovery
+ using IP multicast.
+
+ This library is free software; you can redistribute it and/or
+ modify it under the terms of the GNU Lesser General Public
+ License as published by the Free Software Foundation; either
+ version 2.1 of the License, or (at your option) any later version.
+
+ This library is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+ Lesser General Public License for more details.
+
+ You should have received a copy of the GNU Lesser General Public
+ License along with this library; if not, write to the Free Software
+ Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301
+ USA
+"""
+
+from typing import Dict, Set, Tuple
+
+from ._dns import DNSQuestion, DNSRecord
+from .const import _DUPLICATE_QUESTION_INTERVAL
+
+# The QuestionHistory is used to implement Duplicate Question Suppression
+# https://datatracker.ietf.org/doc/html/rfc6762#section-7.3
+
+
+class QuestionHistory:
+ def __init__(self) -> None:
+ self._history: Dict[DNSQuestion, Tuple[float, Set[DNSRecord]]] = {}
+
+ def add_question_at_time(self, question: DNSQuestion, now: float, known_answers: Set[DNSRecord]) -> None:
+ """Remember a question with known answers."""
+ self._history[question] = (now, known_answers)
+
+ def suppresses(self, question: DNSQuestion, now: float, known_answers: Set[DNSRecord]) -> bool:
+ """Check to see if a question should be suppressed.
+
+ https://datatracker.ietf.org/doc/html/rfc6762#section-7.3
+ When multiple queriers on the network are querying
+ for the same resource records, there is no need for them to all be
+ repeatedly asking the same question.
+ """
+ previous_question = self._history.get(question)
+ # There was not previous question in the history
+ if not previous_question:
+ return False
+ than, previous_known_answers = previous_question
+ # The last question was older than 999ms
+ if now - than > _DUPLICATE_QUESTION_INTERVAL:
+ return False
+ # The last question has more known answers than
+ # we knew so we have to ask
+ if previous_known_answers - known_answers:
+ return False
+ return True
+
+ def async_expire(self, now: float) -> None:
+ """Expire the history of old questions."""
+ removes = [
+ question
+ for question, now_known_answers in self._history.items()
+ if now - now_known_answers[0] > _DUPLICATE_QUESTION_INTERVAL
+ ]
+ for question in removes:
+ del self._history[question]
diff --git a/zeroconf/_logger.py b/zeroconf/_logger.py
new file mode 100644
index 00000000..932d1a2f
--- /dev/null
+++ b/zeroconf/_logger.py
@@ -0,0 +1,74 @@
+""" Multicast DNS Service Discovery for Python, v0.14-wmcbrine
+ Copyright 2003 Paul Scott-Murphy, 2014 William McBrine
+
+ This module provides a framework for the use of DNS Service Discovery
+ using IP multicast.
+
+ This library is free software; you can redistribute it and/or
+ modify it under the terms of the GNU Lesser General Public
+ License as published by the Free Software Foundation; either
+ version 2.1 of the License, or (at your option) any later version.
+
+ This library is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+ Lesser General Public License for more details.
+
+ You should have received a copy of the GNU Lesser General Public
+ License along with this library; if not, write to the Free Software
+ Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301
+ USA
+"""
+
+import logging
+import sys
+from typing import Any, Dict, Union, cast
+
+log = logging.getLogger(__name__.split('.', maxsplit=1)[0])
+log.addHandler(logging.NullHandler())
+
+
+def set_logger_level_if_unset() -> None:
+ if log.level == logging.NOTSET:
+ log.setLevel(logging.WARN)
+
+
+set_logger_level_if_unset()
+
+
+class QuietLogger:
+ _seen_logs: Dict[str, Union[int, tuple]] = {}
+
+ @classmethod
+ def log_exception_warning(cls, *logger_data: Any) -> None:
+ exc_info = sys.exc_info()
+ exc_str = str(exc_info[1])
+ if exc_str not in cls._seen_logs:
+ # log at warning level the first time this is seen
+ cls._seen_logs[exc_str] = exc_info
+ logger = log.warning
+ else:
+ logger = log.debug
+ logger(*(logger_data or ['Exception occurred']), exc_info=True)
+
+ @classmethod
+ def log_warning_once(cls, *args: Any) -> None:
+ msg_str = args[0]
+ if msg_str not in cls._seen_logs:
+ cls._seen_logs[msg_str] = 0
+ logger = log.warning
+ else:
+ logger = log.debug
+ cls._seen_logs[msg_str] = cast(int, cls._seen_logs[msg_str]) + 1
+ logger(*args)
+
+ @classmethod
+ def log_exception_once(cls, exc: Exception, *args: Any) -> None:
+ msg_str = args[0]
+ if msg_str not in cls._seen_logs:
+ cls._seen_logs[msg_str] = 0
+ logger = log.warning
+ else:
+ logger = log.debug
+ cls._seen_logs[msg_str] = cast(int, cls._seen_logs[msg_str]) + 1
+ logger(*args, exc_info=exc)
diff --git a/zeroconf/_protocol/__init__.py b/zeroconf/_protocol/__init__.py
new file mode 100644
index 00000000..360b599d
--- /dev/null
+++ b/zeroconf/_protocol/__init__.py
@@ -0,0 +1,51 @@
+""" Multicast DNS Service Discovery for Python, v0.14-wmcbrine
+ Copyright 2003 Paul Scott-Murphy, 2014 William McBrine
+
+ This module provides a framework for the use of DNS Service Discovery
+ using IP multicast.
+
+ This library is free software; you can redistribute it and/or
+ modify it under the terms of the GNU Lesser General Public
+ License as published by the Free Software Foundation; either
+ version 2.1 of the License, or (at your option) any later version.
+
+ This library is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+ Lesser General Public License for more details.
+
+ You should have received a copy of the GNU Lesser General Public
+ License along with this library; if not, write to the Free Software
+ Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301
+ USA
+"""
+
+from ..const import (
+ _FLAGS_QR_MASK,
+ _FLAGS_QR_QUERY,
+ _FLAGS_QR_RESPONSE,
+ _FLAGS_TC,
+)
+
+
+class DNSMessage:
+ """A base class for DNS messages."""
+
+ __slots__ = ('flags',)
+
+ def __init__(self, flags: int) -> None:
+ """Construct a DNS message."""
+ self.flags = flags
+
+ def is_query(self) -> bool:
+ """Returns true if this is a query."""
+ return (self.flags & _FLAGS_QR_MASK) == _FLAGS_QR_QUERY
+
+ def is_response(self) -> bool:
+ """Returns true if this is a response."""
+ return (self.flags & _FLAGS_QR_MASK) == _FLAGS_QR_RESPONSE
+
+ @property
+ def truncated(self) -> bool:
+ """Returns true if this is a truncated."""
+ return (self.flags & _FLAGS_TC) == _FLAGS_TC
diff --git a/zeroconf/_protocol/incoming.py b/zeroconf/_protocol/incoming.py
new file mode 100644
index 00000000..6d7a6153
--- /dev/null
+++ b/zeroconf/_protocol/incoming.py
@@ -0,0 +1,315 @@
+""" Multicast DNS Service Discovery for Python, v0.14-wmcbrine
+ Copyright 2003 Paul Scott-Murphy, 2014 William McBrine
+
+ This module provides a framework for the use of DNS Service Discovery
+ using IP multicast.
+
+ This library is free software; you can redistribute it and/or
+ modify it under the terms of the GNU Lesser General Public
+ License as published by the Free Software Foundation; either
+ version 2.1 of the License, or (at your option) any later version.
+
+ This library is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+ Lesser General Public License for more details.
+
+ You should have received a copy of the GNU Lesser General Public
+ License along with this library; if not, write to the Free Software
+ Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301
+ USA
+"""
+
+import struct
+from typing import Callable, Dict, List, Optional, Set, Tuple, cast
+
+from . import DNSMessage
+from .._dns import DNSAddress, DNSHinfo, DNSNsec, DNSPointer, DNSQuestion, DNSRecord, DNSService, DNSText
+from .._exceptions import IncomingDecodeError
+from .._logger import QuietLogger, log
+from .._utils.time import current_time_millis
+from ..const import (
+ _TYPES,
+ _TYPE_A,
+ _TYPE_AAAA,
+ _TYPE_CNAME,
+ _TYPE_HINFO,
+ _TYPE_NSEC,
+ _TYPE_PTR,
+ _TYPE_SRV,
+ _TYPE_TXT,
+)
+
+DNS_COMPRESSION_HEADER_LEN = 1
+DNS_COMPRESSION_POINTER_LEN = 2
+MAX_DNS_LABELS = 128
+MAX_NAME_LENGTH = 253
+
+DECODE_EXCEPTIONS = (IndexError, struct.error, IncomingDecodeError)
+
+
+class DNSIncoming(DNSMessage, QuietLogger):
+
+ """Object representation of an incoming DNS packet"""
+
+ __slots__ = (
+ 'offset',
+ 'data',
+ 'data_len',
+ 'name_cache',
+ 'questions',
+ '_answers',
+ 'id',
+ 'num_questions',
+ 'num_answers',
+ 'num_authorities',
+ 'num_additionals',
+ 'valid',
+ 'now',
+ 'scope_id',
+ 'source',
+ )
+
+ def __init__(
+ self,
+ data: bytes,
+ source: Optional[Tuple[str, int]] = None,
+ scope_id: Optional[int] = None,
+ now: Optional[float] = None,
+ ) -> None:
+ """Constructor from string holding bytes of packet"""
+ super().__init__(0)
+ self.offset = 0
+ self.data = data
+ self.data_len = len(data)
+ self.name_cache: Dict[int, List[str]] = {}
+ self.questions: List[DNSQuestion] = []
+ self._answers: List[DNSRecord] = []
+ self.id = 0
+ self.num_questions = 0
+ self.num_answers = 0
+ self.num_authorities = 0
+ self.num_additionals = 0
+ self.valid = False
+ self._read_others = False
+ self.now = now or current_time_millis()
+ self.source = source
+ self.scope_id = scope_id
+ self._parse_data(self._initial_parse)
+
+ def _initial_parse(self) -> None:
+ """Parse the data needed to initalize the packet object."""
+ self.read_header()
+ self.read_questions()
+ if not self.num_questions:
+ self.read_others()
+ self.valid = True
+
+ def _parse_data(self, parser_call: Callable) -> None:
+ """Parse part of the packet and catch exceptions."""
+ try:
+ parser_call()
+ except DECODE_EXCEPTIONS:
+ self.log_exception_warning(
+ 'Received invalid packet from %s at offset %d while unpacking %r',
+ self.source,
+ self.offset,
+ self.data,
+ )
+
+ @property
+ def answers(self) -> List[DNSRecord]:
+ """Answers in the packet."""
+ if not self._read_others:
+ self._parse_data(self.read_others)
+ return self._answers
+
+ def __repr__(self) -> str:
+ return '' % ', '.join(
+ [
+ 'id=%s' % self.id,
+ 'flags=%s' % self.flags,
+ 'truncated=%s' % self.truncated,
+ 'n_q=%s' % self.num_questions,
+ 'n_ans=%s' % self.num_answers,
+ 'n_auth=%s' % self.num_authorities,
+ 'n_add=%s' % self.num_additionals,
+ 'questions=%s' % self.questions,
+ 'answers=%s' % self.answers,
+ ]
+ )
+
+ def unpack(self, format_: bytes, length: int) -> tuple:
+ self.offset += length
+ return struct.unpack(format_, self.data[self.offset - length : self.offset])
+
+ def read_header(self) -> None:
+ """Reads header portion of packet"""
+ (
+ self.id,
+ self.flags,
+ self.num_questions,
+ self.num_answers,
+ self.num_authorities,
+ self.num_additionals,
+ ) = self.unpack(b'!6H', 12)
+
+ def read_questions(self) -> None:
+ """Reads questions section of packet"""
+ self.questions = [
+ DNSQuestion(self.read_name(), *self.unpack(b'!HH', 4)) for _ in range(self.num_questions)
+ ]
+
+ def read_character_string(self) -> bytes:
+ """Reads a character string from the packet"""
+ length = self.data[self.offset]
+ self.offset += 1
+ return self.read_string(length)
+
+ def read_string(self, length: int) -> bytes:
+ """Reads a string of a given length from the packet"""
+ info = self.data[self.offset : self.offset + length]
+ self.offset += length
+ return info
+
+ def read_unsigned_short(self) -> int:
+ """Reads an unsigned short from the packet"""
+ return cast(int, self.unpack(b'!H', 2)[0])
+
+ def read_others(self) -> None:
+ """Reads the answers, authorities and additionals section of the
+ packet"""
+ self._read_others = True
+ n = self.num_answers + self.num_authorities + self.num_additionals
+ for _ in range(n):
+ domain = self.read_name()
+ type_, class_, ttl, length = self.unpack(b'!HHiH', 10)
+ end = self.offset + length
+ rec = None
+ try:
+ rec = self.read_record(domain, type_, class_, ttl, length)
+ except DECODE_EXCEPTIONS:
+ # Skip records that fail to decode if we know the length
+ # If the packet is really corrupt read_name and the unpack
+ # above would fail and hit the exception catch in read_others
+ self.offset = end
+ log.debug(
+ 'Unable to parse; skipping record for %s with type %s at offset %d while unpacking %r',
+ domain,
+ _TYPES.get(type_, type_),
+ self.offset,
+ self.data,
+ exc_info=True,
+ )
+ if rec is not None:
+ self._answers.append(rec)
+
+ def read_record(self, domain: str, type_: int, class_: int, ttl: int, length: int) -> Optional[DNSRecord]:
+ """Read known records types and skip unknown ones."""
+ if type_ == _TYPE_A:
+ return DNSAddress(domain, type_, class_, ttl, self.read_string(4), created=self.now)
+ if type_ in (_TYPE_CNAME, _TYPE_PTR):
+ return DNSPointer(domain, type_, class_, ttl, self.read_name(), self.now)
+ if type_ == _TYPE_TXT:
+ return DNSText(domain, type_, class_, ttl, self.read_string(length), self.now)
+ if type_ == _TYPE_SRV:
+ return DNSService(
+ domain,
+ type_,
+ class_,
+ ttl,
+ self.read_unsigned_short(),
+ self.read_unsigned_short(),
+ self.read_unsigned_short(),
+ self.read_name(),
+ self.now,
+ )
+ if type_ == _TYPE_HINFO:
+ return DNSHinfo(
+ domain,
+ type_,
+ class_,
+ ttl,
+ self.read_character_string().decode('utf-8'),
+ self.read_character_string().decode('utf-8'),
+ self.now,
+ )
+ if type_ == _TYPE_AAAA:
+ return DNSAddress(
+ domain, type_, class_, ttl, self.read_string(16), created=self.now, scope_id=self.scope_id
+ )
+ if type_ == _TYPE_NSEC:
+ name_start = self.offset
+ return DNSNsec(
+ domain,
+ type_,
+ class_,
+ ttl,
+ self.read_name(),
+ self.read_bitmap(name_start + length),
+ self.now,
+ )
+ # Try to ignore types we don't know about
+ # Skip the payload for the resource record so the next
+ # records can be parsed correctly
+ self.offset += length
+ return None
+
+ def read_bitmap(self, end: int) -> List[int]:
+ """Reads an NSEC bitmap from the packet."""
+ rdtypes = []
+ while self.offset < end:
+ window = self.data[self.offset]
+ bitmap_length = self.data[self.offset + 1]
+ for i, byte in enumerate(self.data[self.offset + 2 : self.offset + 2 + bitmap_length]):
+ for bit in range(0, 8):
+ if byte & (0x80 >> bit):
+ rdtypes.append(bit + window * 256 + i * 8)
+ self.offset += 2 + bitmap_length
+ return rdtypes
+
+ def read_name(self) -> str:
+ """Reads a domain name from the packet."""
+ labels: List[str] = []
+ seen_pointers: Set[int] = set()
+ self.offset = self._decode_labels_at_offset(self.offset, labels, seen_pointers)
+ name = ".".join(labels) + "."
+ if len(name) > MAX_NAME_LENGTH:
+ raise IncomingDecodeError(f"DNS name {name} exceeds maximum length of {MAX_NAME_LENGTH}")
+ return name
+
+ def _decode_labels_at_offset(self, off: int, labels: List[str], seen_pointers: Set[int]) -> int:
+ # This is a tight loop that is called frequently, small optimizations can make a difference.
+ while off < self.data_len:
+ length = self.data[off]
+ if length == 0:
+ return off + DNS_COMPRESSION_HEADER_LEN
+
+ if length < 0x40:
+ label_idx = off + DNS_COMPRESSION_HEADER_LEN
+ labels.append(str(self.data[label_idx : label_idx + length], 'utf-8', 'replace'))
+ off += DNS_COMPRESSION_HEADER_LEN + length
+ continue
+
+ if length < 0xC0:
+ raise IncomingDecodeError(f"DNS compression type {length} is unknown at {off}")
+
+ # We have a DNS compression pointer
+ link = (length & 0x3F) * 256 + self.data[off + 1]
+ if link > self.data_len:
+ raise IncomingDecodeError(f"DNS compression pointer at {off} points to {link} beyond packet")
+ if link == off:
+ raise IncomingDecodeError(f"DNS compression pointer at {off} points to itself")
+ if link in seen_pointers:
+ raise IncomingDecodeError(f"DNS compression pointer at {off} was seen again")
+ seen_pointers.add(link)
+ linked_labels = self.name_cache.get(link, [])
+ if not linked_labels:
+ self._decode_labels_at_offset(link, linked_labels, seen_pointers)
+ self.name_cache[link] = linked_labels
+ labels.extend(linked_labels)
+ if len(labels) > MAX_DNS_LABELS:
+ raise IncomingDecodeError(f"Maximum dns labels reached while processing pointer at {off}")
+ return off + DNS_COMPRESSION_POINTER_LEN
+
+ raise IncomingDecodeError("Corrupt packet received while decoding name")
diff --git a/zeroconf/_protocol/outgoing.py b/zeroconf/_protocol/outgoing.py
new file mode 100644
index 00000000..59c8382e
--- /dev/null
+++ b/zeroconf/_protocol/outgoing.py
@@ -0,0 +1,442 @@
+""" Multicast DNS Service Discovery for Python, v0.14-wmcbrine
+ Copyright 2003 Paul Scott-Murphy, 2014 William McBrine
+
+ This module provides a framework for the use of DNS Service Discovery
+ using IP multicast.
+
+ This library is free software; you can redistribute it and/or
+ modify it under the terms of the GNU Lesser General Public
+ License as published by the Free Software Foundation; either
+ version 2.1 of the License, or (at your option) any later version.
+
+ This library is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+ Lesser General Public License for more details.
+
+ You should have received a copy of the GNU Lesser General Public
+ License along with this library; if not, write to the Free Software
+ Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301
+ USA
+"""
+
+import enum
+import struct
+from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
+
+from . import DNSMessage
+from .incoming import DNSIncoming
+from .._cache import DNSCache
+from .._dns import DNSPointer, DNSQuestion, DNSRecord
+from .._exceptions import NamePartTooLongException
+from .._logger import log
+from ..const import (
+ _CLASS_UNIQUE,
+ _DNS_PACKET_HEADER_LEN,
+ _FLAGS_TC,
+ _MAX_MSG_ABSOLUTE,
+ _MAX_MSG_TYPICAL,
+)
+
+
+class DNSOutgoing(DNSMessage):
+
+ """Object representation of an outgoing packet"""
+
+ __slots__ = (
+ 'finished',
+ 'id',
+ 'multicast',
+ 'packets_data',
+ 'names',
+ 'data',
+ 'size',
+ 'allow_long',
+ 'state',
+ 'questions',
+ 'answers',
+ 'authorities',
+ 'additionals',
+ )
+
+ def __init__(self, flags: int, multicast: bool = True, id_: int = 0) -> None:
+ super().__init__(flags)
+ self.finished = False
+ self.id = id_
+ self.multicast = multicast
+ self.packets_data: List[bytes] = []
+
+ # these 3 are per-packet -- see also _reset_for_next_packet()
+ self.names: Dict[str, int] = {}
+ self.data: List[bytes] = []
+ self.size: int = _DNS_PACKET_HEADER_LEN
+ self.allow_long: bool = True
+
+ self.state = self.State.init
+
+ self.questions: List[DNSQuestion] = []
+ self.answers: List[Tuple[DNSRecord, float]] = []
+ self.authorities: List[DNSPointer] = []
+ self.additionals: List[DNSRecord] = []
+
+ def _reset_for_next_packet(self) -> None:
+ self.names = {}
+ self.data = []
+ self.size = _DNS_PACKET_HEADER_LEN
+ self.allow_long = True
+
+ def __repr__(self) -> str:
+ return '' % ', '.join(
+ [
+ 'multicast=%s' % self.multicast,
+ 'flags=%s' % self.flags,
+ 'questions=%s' % self.questions,
+ 'answers=%s' % self.answers,
+ 'authorities=%s' % self.authorities,
+ 'additionals=%s' % self.additionals,
+ ]
+ )
+
+ class State(enum.Enum):
+ init = 0
+ finished = 1
+
+ def add_question(self, record: DNSQuestion) -> None:
+ """Adds a question"""
+ self.questions.append(record)
+
+ def add_answer(self, inp: DNSIncoming, record: DNSRecord) -> None:
+ """Adds an answer"""
+ if not record.suppressed_by(inp):
+ self.add_answer_at_time(record, 0)
+
+ def add_answer_at_time(self, record: Optional[DNSRecord], now: Union[float, int]) -> None:
+ """Adds an answer if it does not expire by a certain time"""
+ if record is not None and (now == 0 or not record.is_expired(now)):
+ self.answers.append((record, now))
+
+ def add_authorative_answer(self, record: DNSPointer) -> None:
+ """Adds an authoritative answer"""
+ self.authorities.append(record)
+
+ def add_additional_answer(self, record: DNSRecord) -> None:
+ """Adds an additional answer
+
+ From: RFC 6763, DNS-Based Service Discovery, February 2013
+
+ 12. DNS Additional Record Generation
+
+ DNS has an efficiency feature whereby a DNS server may place
+ additional records in the additional section of the DNS message.
+ These additional records are records that the client did not
+ explicitly request, but the server has reasonable grounds to expect
+ that the client might request them shortly, so including them can
+ save the client from having to issue additional queries.
+
+ This section recommends which additional records SHOULD be generated
+ to improve network efficiency, for both Unicast and Multicast DNS-SD
+ responses.
+
+ 12.1. PTR Records
+
+ When including a DNS-SD Service Instance Enumeration or Selective
+ Instance Enumeration (subtype) PTR record in a response packet, the
+ server/responder SHOULD include the following additional records:
+
+ o The SRV record(s) named in the PTR rdata.
+ o The TXT record(s) named in the PTR rdata.
+ o All address records (type "A" and "AAAA") named in the SRV rdata.
+
+ 12.2. SRV Records
+
+ When including an SRV record in a response packet, the
+ server/responder SHOULD include the following additional records:
+
+ o All address records (type "A" and "AAAA") named in the SRV rdata.
+
+ """
+ self.additionals.append(record)
+
+ def add_question_or_one_cache(
+ self, cache: DNSCache, now: float, name: str, type_: int, class_: int
+ ) -> None:
+ """Add a question if it is not already cached."""
+ cached_entry = cache.get_by_details(name, type_, class_)
+ if not cached_entry:
+ self.add_question(DNSQuestion(name, type_, class_))
+ else:
+ self.add_answer_at_time(cached_entry, now)
+
+ def add_question_or_all_cache(
+ self, cache: DNSCache, now: float, name: str, type_: int, class_: int
+ ) -> None:
+ """Add a question if it is not already cached.
+ This is currently only used for IPv6 addresses.
+ """
+ cached_entries = cache.get_all_by_details(name, type_, class_)
+ if not cached_entries:
+ self.add_question(DNSQuestion(name, type_, class_))
+ return
+ for cached_entry in cached_entries:
+ self.add_answer_at_time(cached_entry, now)
+
+ def _pack(self, format_: Union[bytes, str], size: int, value: Any) -> None:
+ self.data.append(struct.pack(format_, value))
+ self.size += size
+
+ def _write_byte(self, value: int) -> None:
+ """Writes a single byte to the packet"""
+ self._pack(b'!c', 1, bytes((value,)))
+
+ def _insert_short_at_start(self, value: int) -> None:
+ """Inserts an unsigned short at the start of the packet"""
+ self.data.insert(0, struct.pack(b'!H', value))
+
+ def _replace_short(self, index: int, value: int) -> None:
+ """Replaces an unsigned short in a certain position in the packet"""
+ self.data[index] = struct.pack(b'!H', value)
+
+ def write_short(self, value: int) -> None:
+ """Writes an unsigned short to the packet"""
+ self._pack(b'!H', 2, value)
+
+ def _write_int(self, value: Union[float, int]) -> None:
+ """Writes an unsigned integer to the packet"""
+ self._pack(b'!I', 4, int(value))
+
+ def write_string(self, value: bytes) -> None:
+ """Writes a string to the packet"""
+ assert isinstance(value, bytes)
+ self.data.append(value)
+ self.size += len(value)
+
+ def _write_utf(self, s: str) -> None:
+ """Writes a UTF-8 string of a given length to the packet"""
+ utfstr = s.encode('utf-8')
+ length = len(utfstr)
+ if length > 64:
+ raise NamePartTooLongException
+ self._write_byte(length)
+ self.write_string(utfstr)
+
+ def write_character_string(self, value: bytes) -> None:
+ assert isinstance(value, bytes)
+ length = len(value)
+ if length > 256:
+ raise NamePartTooLongException
+ self._write_byte(length)
+ self.write_string(value)
+
+ def write_name(self, name: str) -> None:
+ """
+ Write names to packet
+
+ 18.14. Name Compression
+
+ When generating Multicast DNS messages, implementations SHOULD use
+ name compression wherever possible to compress the names of resource
+ records, by replacing some or all of the resource record name with a
+ compact two-byte reference to an appearance of that data somewhere
+ earlier in the message [RFC1035].
+ """
+
+ # split name into each label
+ name_length = None
+ if name.endswith('.'):
+ name = name[: len(name) - 1]
+ labels = name.split('.')
+ # Write each new label or a pointer to the existing
+ # on in the packet
+ start_size = self.size
+ for count in range(len(labels)):
+ label = name if count == 0 else '.'.join(labels[count:])
+ index = self.names.get(label)
+ if index:
+ # If part of the name already exists in the packet,
+ # create a pointer to it
+ self._write_byte((index >> 8) | 0xC0)
+ self._write_byte(index & 0xFF)
+ return
+ if name_length is None:
+ name_length = len(name.encode('utf-8'))
+ self.names[label] = start_size + name_length - len(label.encode('utf-8'))
+ self._write_utf(labels[count])
+
+ # this is the end of a name
+ self._write_byte(0)
+
+ def _write_question(self, question: DNSQuestion) -> bool:
+ """Writes a question to the packet"""
+ start_data_length, start_size = len(self.data), self.size
+ self.write_name(question.name)
+ self.write_short(question.type)
+ self._write_record_class(question)
+ return self._check_data_limit_or_rollback(start_data_length, start_size)
+
+ def _write_record_class(self, record: Union[DNSQuestion, DNSRecord]) -> None:
+ """Write out the record class including the unique/unicast (QU) bit."""
+ if record.unique and self.multicast:
+ self.write_short(record.class_ | _CLASS_UNIQUE)
+ else:
+ self.write_short(record.class_)
+
+ def _write_ttl(self, record: DNSRecord, now: float) -> None:
+ """Write out the record ttl."""
+ self._write_int(record.ttl if now == 0 else record.get_remaining_ttl(now))
+
+ def _write_record(self, record: DNSRecord, now: float) -> bool:
+ """Writes a record (answer, authoritative answer, additional) to
+ the packet. Returns True on success, or False if we did not
+ because the packet because the record does not fit."""
+ start_data_length, start_size = len(self.data), self.size
+ self.write_name(record.name)
+ self.write_short(record.type)
+ self._write_record_class(record)
+ self._write_ttl(record, now)
+ index = len(self.data)
+ self.write_short(0) # Will get replaced with the actual size
+ record.write(self)
+ # Adjust size for the short we will write before this record
+ length = sum(len(d) for d in self.data[index + 1 :])
+ # Here we replace the 0 length short we wrote
+ # before with the actual length
+ self._replace_short(index, length)
+ return self._check_data_limit_or_rollback(start_data_length, start_size)
+
+ def _check_data_limit_or_rollback(self, start_data_length: int, start_size: int) -> bool:
+ """Check data limit, if we go over, then rollback and return False."""
+ len_limit = _MAX_MSG_ABSOLUTE if self.allow_long else _MAX_MSG_TYPICAL
+ self.allow_long = False
+
+ if self.size <= len_limit:
+ return True
+
+ log.debug("Reached data limit (size=%d) > (limit=%d) - rolling back", self.size, len_limit)
+ del self.data[start_data_length:]
+ self.size = start_size
+
+ rollback_names = [name for name, idx in self.names.items() if idx >= start_size]
+ for name in rollback_names:
+ del self.names[name]
+ return False
+
+ def _write_questions_from_offset(self, questions_offset: int) -> int:
+ questions_written = 0
+ for question in self.questions[questions_offset:]:
+ if not self._write_question(question):
+ break
+ questions_written += 1
+ return questions_written
+
+ def _write_answers_from_offset(self, answer_offset: int) -> int:
+ answers_written = 0
+ for answer, time_ in self.answers[answer_offset:]:
+ if not self._write_record(answer, time_):
+ break
+ answers_written += 1
+ return answers_written
+
+ def _write_records_from_offset(self, records: Sequence[DNSRecord], offset: int) -> int:
+ records_written = 0
+ for record in records[offset:]:
+ if not self._write_record(record, 0):
+ break
+ records_written += 1
+ return records_written
+
+ def _has_more_to_add(
+ self, questions_offset: int, answer_offset: int, authority_offset: int, additional_offset: int
+ ) -> bool:
+ """Check if all questions, answers, authority, and additionals have been written to the packet."""
+ return (
+ questions_offset < len(self.questions)
+ or answer_offset < len(self.answers)
+ or authority_offset < len(self.authorities)
+ or additional_offset < len(self.additionals)
+ )
+
+ def packets(self) -> List[bytes]:
+ """Returns a list of bytestrings containing the packets' bytes
+
+ No further parts should be added to the packet once this
+ is done. The packets are each restricted to _MAX_MSG_TYPICAL
+ or less in length, except for the case of a single answer which
+ will be written out to a single oversized packet no more than
+ _MAX_MSG_ABSOLUTE in length (and hence will be subject to IP
+ fragmentation potentially)."""
+
+ if self.state == self.State.finished:
+ return self.packets_data
+
+ questions_offset = 0
+ answer_offset = 0
+ authority_offset = 0
+ additional_offset = 0
+ # we have to at least write out the question
+ first_time = True
+
+ while first_time or self._has_more_to_add(
+ questions_offset, answer_offset, authority_offset, additional_offset
+ ):
+ first_time = False
+ log.debug(
+ "offsets = questions=%d, answers=%d, authorities=%d, additionals=%d",
+ questions_offset,
+ answer_offset,
+ authority_offset,
+ additional_offset,
+ )
+ log.debug(
+ "lengths = questions=%d, answers=%d, authorities=%d, additionals=%d",
+ len(self.questions),
+ len(self.answers),
+ len(self.authorities),
+ len(self.additionals),
+ )
+
+ questions_written = self._write_questions_from_offset(questions_offset)
+ answers_written = self._write_answers_from_offset(answer_offset)
+ authorities_written = self._write_records_from_offset(self.authorities, authority_offset)
+ additionals_written = self._write_records_from_offset(self.additionals, additional_offset)
+
+ self._insert_short_at_start(additionals_written)
+ self._insert_short_at_start(authorities_written)
+ self._insert_short_at_start(answers_written)
+ self._insert_short_at_start(questions_written)
+
+ questions_offset += questions_written
+ answer_offset += answers_written
+ authority_offset += authorities_written
+ additional_offset += additionals_written
+ log.debug(
+ "now offsets = questions=%d, answers=%d, authorities=%d, additionals=%d",
+ questions_offset,
+ answer_offset,
+ authority_offset,
+ additional_offset,
+ )
+
+ if self.is_query() and self._has_more_to_add(
+ questions_offset, answer_offset, authority_offset, additional_offset
+ ):
+ # https://datatracker.ietf.org/doc/html/rfc6762#section-7.2
+ log.debug("Setting TC flag")
+ self._insert_short_at_start(self.flags | _FLAGS_TC)
+ else:
+ self._insert_short_at_start(self.flags)
+
+ if self.multicast:
+ self._insert_short_at_start(0)
+ else:
+ self._insert_short_at_start(self.id)
+
+ self.packets_data.append(b''.join(self.data))
+ self._reset_for_next_packet()
+
+ if (questions_written + answers_written + authorities_written + additionals_written) == 0 and (
+ len(self.questions) + len(self.answers) + len(self.authorities) + len(self.additionals)
+ ) > 0:
+ log.warning("packets() made no progress adding records; returning")
+ break
+ self.state = self.State.finished
+ return self.packets_data
diff --git a/zeroconf/_services/__init__.py b/zeroconf/_services/__init__.py
new file mode 100644
index 00000000..5b9fbf01
--- /dev/null
+++ b/zeroconf/_services/__init__.py
@@ -0,0 +1,72 @@
+""" Multicast DNS Service Discovery for Python, v0.14-wmcbrine
+ Copyright 2003 Paul Scott-Murphy, 2014 William McBrine
+
+ This module provides a framework for the use of DNS Service Discovery
+ using IP multicast.
+
+ This library is free software; you can redistribute it and/or
+ modify it under the terms of the GNU Lesser General Public
+ License as published by the Free Software Foundation; either
+ version 2.1 of the License, or (at your option) any later version.
+
+ This library is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+ Lesser General Public License for more details.
+
+ You should have received a copy of the GNU Lesser General Public
+ License along with this library; if not, write to the Free Software
+ Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301
+ USA
+"""
+
+import enum
+from typing import Any, Callable, List, TYPE_CHECKING
+
+
+if TYPE_CHECKING:
+ from .._core import Zeroconf
+
+
+@enum.unique
+class ServiceStateChange(enum.Enum):
+ Added = 1
+ Removed = 2
+ Updated = 3
+
+
+class ServiceListener:
+ def add_service(self, zc: 'Zeroconf', type_: str, name: str) -> None:
+ raise NotImplementedError()
+
+ def remove_service(self, zc: 'Zeroconf', type_: str, name: str) -> None:
+ raise NotImplementedError()
+
+ def update_service(self, zc: 'Zeroconf', type_: str, name: str) -> None:
+ raise NotImplementedError()
+
+
+class Signal:
+ def __init__(self) -> None:
+ self._handlers: List[Callable[..., None]] = []
+
+ def fire(self, **kwargs: Any) -> None:
+ for h in list(self._handlers):
+ h(**kwargs)
+
+ @property
+ def registration_interface(self) -> 'SignalRegistrationInterface':
+ return SignalRegistrationInterface(self._handlers)
+
+
+class SignalRegistrationInterface:
+ def __init__(self, handlers: List[Callable[..., None]]) -> None:
+ self._handlers = handlers
+
+ def register_handler(self, handler: Callable[..., None]) -> 'SignalRegistrationInterface':
+ self._handlers.append(handler)
+ return self
+
+ def unregister_handler(self, handler: Callable[..., None]) -> 'SignalRegistrationInterface':
+ self._handlers.remove(handler)
+ return self
diff --git a/zeroconf/_services/browser.py b/zeroconf/_services/browser.py
new file mode 100644
index 00000000..f6448fd2
--- /dev/null
+++ b/zeroconf/_services/browser.py
@@ -0,0 +1,539 @@
+""" Multicast DNS Service Discovery for Python, v0.14-wmcbrine
+ Copyright 2003 Paul Scott-Murphy, 2014 William McBrine
+
+ This module provides a framework for the use of DNS Service Discovery
+ using IP multicast.
+
+ This library is free software; you can redistribute it and/or
+ modify it under the terms of the GNU Lesser General Public
+ License as published by the Free Software Foundation; either
+ version 2.1 of the License, or (at your option) any later version.
+
+ This library is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+ Lesser General Public License for more details.
+
+ You should have received a copy of the GNU Lesser General Public
+ License along with this library; if not, write to the Free Software
+ Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301
+ USA
+"""
+
+import asyncio
+import queue
+import random
+import threading
+import warnings
+from collections import OrderedDict
+from typing import Callable, Dict, List, Optional, Set, TYPE_CHECKING, Tuple, Union, cast
+
+from .._dns import DNSAddress, DNSPointer, DNSQuestion, DNSQuestionType, DNSRecord
+from .._logger import log
+from .._protocol.outgoing import DNSOutgoing
+from .._services import (
+ ServiceListener,
+ ServiceStateChange,
+ Signal,
+ SignalRegistrationInterface,
+)
+from .._updates import RecordUpdate, RecordUpdateListener
+from .._utils.asyncio import get_best_available_queue
+from .._utils.name import service_type_name
+from .._utils.time import current_time_millis, millis_to_seconds
+from ..const import (
+ _BROWSER_BACKOFF_LIMIT,
+ _BROWSER_TIME,
+ _CLASS_IN,
+ _DNS_PACKET_HEADER_LEN,
+ _EXPIRE_REFRESH_TIME_PERCENT,
+ _FLAGS_QR_QUERY,
+ _MAX_MSG_TYPICAL,
+ _MDNS_ADDR,
+ _MDNS_ADDR6,
+ _MDNS_PORT,
+ _TYPE_PTR,
+)
+
+# https://datatracker.ietf.org/doc/html/rfc6762#section-5.2
+_FIRST_QUERY_DELAY_RANDOM_INTERVAL = (20, 120) # ms
+
+_ON_CHANGE_DISPATCH = {
+ ServiceStateChange.Added: "add_service",
+ ServiceStateChange.Removed: "remove_service",
+ ServiceStateChange.Updated: "update_service",
+}
+
+if TYPE_CHECKING:
+ from .._core import Zeroconf
+
+
+_QuestionWithKnownAnswers = Dict[DNSQuestion, Set[DNSPointer]]
+
+
+class _DNSPointerOutgoingBucket:
+ """A DNSOutgoing bucket."""
+
+ def __init__(self, now: float, multicast: bool) -> None:
+ """Create a bucke to wrap a DNSOutgoing."""
+ self.now = now
+ self.out = DNSOutgoing(_FLAGS_QR_QUERY, multicast=multicast)
+ self.bytes = 0
+
+ def add(self, max_compressed_size: int, question: DNSQuestion, answers: Set[DNSPointer]) -> None:
+ """Add a new set of questions and known answers to the outgoing."""
+ self.out.add_question(question)
+ for answer in answers:
+ self.out.add_answer_at_time(answer, self.now)
+ self.bytes += max_compressed_size
+
+
+def _group_ptr_queries_with_known_answers(
+ now: float, multicast: bool, question_with_known_answers: _QuestionWithKnownAnswers
+) -> List[DNSOutgoing]:
+ """Aggregate queries so that as many known answers as possible fit in the same packet
+ without having known answers spill over into the next packet unless the
+ question and known answers are always going to exceed the packet size.
+
+ Some responders do not implement multi-packet known answer suppression
+ so we try to keep all the known answers in the same packet as the
+ questions.
+ """
+ # This is the maximum size the query + known answers can be with name compression.
+ # The actual size of the query + known answers may be a bit smaller since other
+ # parts may be shared when the final DNSOutgoing packets are constructed. The
+ # goal of this algorithm is to quickly bucket the query + known answers without
+ # the overhead of actually constructing the packets.
+ query_by_size: Dict[DNSQuestion, int] = {
+ question: (question.max_size + sum([answer.max_size_compressed for answer in known_answers]))
+ for question, known_answers in question_with_known_answers.items()
+ }
+ max_bucket_size = _MAX_MSG_TYPICAL - _DNS_PACKET_HEADER_LEN
+ query_buckets: List[_DNSPointerOutgoingBucket] = []
+ for question in sorted(
+ query_by_size,
+ key=query_by_size.get, # type: ignore
+ reverse=True,
+ ):
+ max_compressed_size = query_by_size[question]
+ answers = question_with_known_answers[question]
+ for query_bucket in query_buckets:
+ if query_bucket.bytes + max_compressed_size <= max_bucket_size:
+ query_bucket.add(max_compressed_size, question, answers)
+ break
+ else:
+ # If a single question and known answers won't fit in a packet
+ # we will end up generating multiple packets, but there will never
+ # be multiple questions
+ query_bucket = _DNSPointerOutgoingBucket(now, multicast)
+ query_bucket.add(max_compressed_size, question, answers)
+ query_buckets.append(query_bucket)
+
+ return [query_bucket.out for query_bucket in query_buckets]
+
+
+def generate_service_query(
+ zc: 'Zeroconf',
+ now: float,
+ types_: List[str],
+ multicast: bool = True,
+ question_type: Optional[DNSQuestionType] = None,
+) -> List[DNSOutgoing]:
+ """Generate a service query for sending with zeroconf.send."""
+ questions_with_known_answers: _QuestionWithKnownAnswers = {}
+ qu_question = not multicast if question_type is None else question_type == DNSQuestionType.QU
+ for type_ in types_:
+ question = DNSQuestion(type_, _TYPE_PTR, _CLASS_IN)
+ question.unicast = qu_question
+ known_answers = set(
+ cast(DNSPointer, record)
+ for record in zc.cache.get_all_by_details(type_, _TYPE_PTR, _CLASS_IN)
+ if not record.is_stale(now)
+ )
+ if not qu_question and zc.question_history.suppresses(
+ question, now, cast(Set[DNSRecord], known_answers)
+ ):
+ log.debug("Asking %s was suppressed by the question history", question)
+ continue
+ questions_with_known_answers[question] = known_answers
+ if not qu_question:
+ zc.question_history.add_question_at_time(question, now, cast(Set[DNSRecord], known_answers))
+
+ return _group_ptr_queries_with_known_answers(now, multicast, questions_with_known_answers)
+
+
+def _service_state_changed_from_listener(listener: ServiceListener) -> Callable[..., None]:
+ """Generate a service_state_changed handlers from a listener."""
+ assert listener is not None
+ if not hasattr(listener, 'update_service'):
+ warnings.warn(
+ "%r has no update_service method. Provide one (it can be empty if you "
+ "don't care about the updates), it'll become mandatory." % (listener,),
+ FutureWarning,
+ )
+
+ def on_change(
+ zeroconf: 'Zeroconf', service_type: str, name: str, state_change: ServiceStateChange
+ ) -> None:
+ getattr(listener, _ON_CHANGE_DISPATCH[state_change])(zeroconf, service_type, name)
+
+ return on_change
+
+
+class QueryScheduler:
+ """Schedule outgoing PTR queries for Continuous Multicast DNS Querying
+
+ https://datatracker.ietf.org/doc/html/rfc6762#section-5.2
+
+ """
+
+ def __init__(
+ self,
+ types: Set[str],
+ delay: int,
+ first_random_delay_interval: Tuple[int, int],
+ ):
+ self._schedule_changed_event: Optional[asyncio.Event] = None
+ self._types = types
+ self._next_time: Dict[str, float] = {}
+ self._first_random_delay_interval = first_random_delay_interval
+ self._delay: Dict[str, float] = {check_type_: delay for check_type_ in self._types}
+
+ def start(self, now: float) -> None:
+ """Start the scheduler."""
+ self._schedule_changed_event = asyncio.Event()
+ self._generate_first_next_time(now)
+
+ def _generate_first_next_time(self, now: float) -> None:
+ """Generate the initial next query times.
+
+ https://datatracker.ietf.org/doc/html/rfc6762#section-5.2
+ To avoid accidental synchronization when, for some reason, multiple
+ clients begin querying at exactly the same moment (e.g., because of
+ some common external trigger event), a Multicast DNS querier SHOULD
+ also delay the first query of the series by a randomly chosen amount
+ in the range 20-120 ms.
+ """
+ delay = millis_to_seconds(random.randint(*self._first_random_delay_interval))
+ next_time = now + delay
+ self._next_time = {check_type_: next_time for check_type_ in self._types}
+
+ def millis_to_wait(self, now: float) -> float:
+ """Returns the number of milliseconds to wait for the next event."""
+ # Wait for the type has the smallest next time
+ next_time = min(self._next_time.values())
+ return 0 if next_time <= now else next_time - now
+
+ def reschedule_type(self, type_: str, next_time: float) -> bool:
+ """Reschedule the query for a type to happen sooner."""
+ if next_time >= self._next_time[type_]:
+ return False
+ self._next_time[type_] = next_time
+ return True
+
+ def process_ready_types(self, now: float) -> List[str]:
+ """Generate a list of ready types that is due and schedule the next time."""
+ if self.millis_to_wait(now):
+ return []
+
+ ready_types: List[str] = []
+
+ for type_, due in self._next_time.items():
+ if due > now:
+ continue
+
+ ready_types.append(type_)
+ self._next_time[type_] = now + self._delay[type_]
+ self._delay[type_] = min(_BROWSER_BACKOFF_LIMIT * 1000, self._delay[type_] * 2)
+
+ return ready_types
+
+
+class _ServiceBrowserBase(RecordUpdateListener):
+ """Base class for ServiceBrowser."""
+
+ def __init__(
+ self,
+ zc: 'Zeroconf',
+ type_: Union[str, list],
+ handlers: Optional[Union[ServiceListener, List[Callable[..., None]]]] = None,
+ listener: Optional[ServiceListener] = None,
+ addr: Optional[str] = None,
+ port: int = _MDNS_PORT,
+ delay: int = _BROWSER_TIME,
+ question_type: Optional[DNSQuestionType] = None,
+ ) -> None:
+ """Used to browse for a service for specific type(s).
+
+ Constructor parameters are as follows:
+
+ * `zc`: A Zeroconf instance
+ * `type_`: fully qualified service type name
+ * `handler`: ServiceListener or Callable that knows how to process ServiceStateChange events
+ * `listener`: ServiceListener
+ * `addr`: address to send queries (will default to multicast)
+ * `port`: port to send queries (will default to mdns 5353)
+ * `delay`: The initial delay between answering questions
+ * `question_type`: The type of questions to ask (DNSQuestionType.QM or DNSQuestionType.QU)
+
+ The listener object will have its add_service() and
+ remove_service() methods called when this browser
+ discovers changes in the services availability.
+ """
+ assert handlers or listener, 'You need to specify at least one handler'
+ self.types: Set[str] = set(type_ if isinstance(type_, list) else [type_])
+ for check_type_ in self.types:
+ # Will generate BadTypeInNameException on a bad name
+ service_type_name(check_type_, strict=False)
+ self.zc = zc
+ self.addr = addr
+ self.port = port
+ self.multicast = self.addr in (None, _MDNS_ADDR, _MDNS_ADDR6)
+ self.question_type = question_type
+ self._pending_handlers: OrderedDict[Tuple[str, str], ServiceStateChange] = OrderedDict()
+ self._service_state_changed = Signal()
+ self.query_scheduler = QueryScheduler(self.types, delay, _FIRST_QUERY_DELAY_RANDOM_INTERVAL)
+ self.queue: Optional[queue.Queue] = None
+ self.done = False
+ self._first_request: bool = True
+ self._next_send_timer: Optional[asyncio.TimerHandle] = None
+
+ if hasattr(handlers, 'add_service'):
+ listener = cast('ServiceListener', handlers)
+ handlers = None
+
+ handlers = cast(List[Callable[..., None]], handlers or [])
+
+ if listener:
+ handlers.append(_service_state_changed_from_listener(listener))
+
+ for h in handlers:
+ self.service_state_changed.register_handler(h)
+
+ def _async_start(self) -> None:
+ """Generate the next time and setup listeners.
+
+ Must be called by uses of this base class after they
+ have finished setting their properties.
+ """
+ self.query_scheduler.start(current_time_millis())
+ self.zc.async_add_listener(self, [DNSQuestion(type_, _TYPE_PTR, _CLASS_IN) for type_ in self.types])
+ # Only start queries after the listener is installed
+ asyncio.ensure_future(self._async_start_query_sender())
+
+ @property
+ def service_state_changed(self) -> SignalRegistrationInterface:
+ return self._service_state_changed.registration_interface
+
+ def _record_matching_type(self, record: DNSRecord) -> Optional[str]:
+ """Return the type if the record matches one of the types we are browsing."""
+ return next((type_ for type_ in self.types if record.name.endswith(type_)), None)
+
+ def _enqueue_callback(
+ self,
+ state_change: ServiceStateChange,
+ type_: str,
+ name: str,
+ ) -> None:
+ # Code to ensure we only do a single update message
+ # Precedence is; Added, Remove, Update
+ key = (name, type_)
+ if (
+ state_change is ServiceStateChange.Added
+ or (
+ state_change is ServiceStateChange.Removed
+ and self._pending_handlers.get(key) != ServiceStateChange.Added
+ )
+ or (state_change is ServiceStateChange.Updated and key not in self._pending_handlers)
+ ):
+ self._pending_handlers[key] = state_change
+
+ def _async_process_record_update(
+ self, now: float, record: DNSRecord, old_record: Optional[DNSRecord]
+ ) -> None:
+ """Process a single record update from a batch of updates."""
+ if isinstance(record, DNSPointer):
+ if record.name not in self.types:
+ return
+ if old_record is None:
+ self._enqueue_callback(ServiceStateChange.Added, record.name, record.alias)
+ elif record.is_expired(now):
+ self._enqueue_callback(ServiceStateChange.Removed, record.name, record.alias)
+ else:
+ self.reschedule_type(record.name, record.get_expiration_time(_EXPIRE_REFRESH_TIME_PERCENT))
+ return
+
+ # If its expired or already exists in the cache it cannot be updated.
+ if old_record or record.is_expired(now):
+ return
+
+ if isinstance(record, DNSAddress):
+ # Iterate through the DNSCache and callback any services that use this address
+ for service in self.zc.cache.async_entries_with_server(record.name):
+ type_ = self._record_matching_type(service)
+ if type_:
+ self._enqueue_callback(ServiceStateChange.Updated, type_, service.name)
+ break
+
+ return
+
+ type_ = self._record_matching_type(record)
+ if type_:
+ self._enqueue_callback(ServiceStateChange.Updated, type_, record.name)
+
+ def async_update_records(self, zc: 'Zeroconf', now: float, records: List[RecordUpdate]) -> None:
+ """Callback invoked by Zeroconf when new information arrives.
+
+ Updates information required by browser in the Zeroconf cache.
+
+ Ensures that there is are no unecessary duplicates in the list.
+
+ This method will be run in the event loop.
+ """
+ for record in records:
+ self._async_process_record_update(now, record[0], record[1])
+
+ def async_update_records_complete(self) -> None:
+ """Called when a record update has completed for all handlers.
+
+ At this point the cache will have the new records.
+
+ This method will be run in the event loop.
+ """
+ while self._pending_handlers:
+ event = self._pending_handlers.popitem(False)
+ # If there is a queue running (ServiceBrowser)
+ # get fired in dedicated thread
+ if self.queue:
+ self.queue.put(event)
+ else:
+ self._fire_service_state_changed_event(event)
+
+ def _fire_service_state_changed_event(self, event: Tuple[Tuple[str, str], ServiceStateChange]) -> None:
+ """Fire a service state changed event.
+
+ When running with ServiceBrowser, this will happen in the dedicated
+ thread.
+
+ When running with AsyncServiceBrowser, this will happen in the event loop.
+ """
+ name_type, state_change = event
+ self._service_state_changed.fire(
+ zeroconf=self.zc,
+ service_type=name_type[1],
+ name=name_type[0],
+ state_change=state_change,
+ )
+
+ def _async_cancel(self) -> None:
+ """Cancel the browser."""
+ self.done = True
+ self._cancel_send_timer()
+ self.zc.async_remove_listener(self)
+
+ def _generate_ready_queries(self, first_request: bool) -> List[DNSOutgoing]:
+ """Generate the service browser query for any type that is due."""
+ now = current_time_millis()
+ ready_types = self.query_scheduler.process_ready_types(now)
+ if not ready_types:
+ return []
+
+ # If they did not specify and this is the first request, ask QU questions
+ # https://datatracker.ietf.org/doc/html/rfc6762#section-5.4 since we are
+ # just starting up and we know our cache is likely empty. This ensures
+ # the next outgoing will be sent with the known answers list.
+ question_type = DNSQuestionType.QU if not self.question_type and first_request else self.question_type
+ return generate_service_query(self.zc, now, ready_types, self.multicast, question_type)
+
+ async def _async_start_query_sender(self) -> None:
+ """Start scheduling queries."""
+ await self.zc.async_wait_for_start()
+ self._async_send_ready_queries()
+ self._async_schedule_next()
+
+ def _cancel_send_timer(self) -> None:
+ """Cancel the next send."""
+ if self._next_send_timer:
+ self._next_send_timer.cancel()
+
+ def reschedule_type(self, type_: str, next_time: float) -> None:
+ """Reschedule a type to be refreshed in the future."""
+ if self.query_scheduler.reschedule_type(type_, next_time):
+ self._cancel_send_timer()
+ self._async_schedule_next()
+ self._async_send_ready_queries()
+
+ def _async_send_ready_queries(self) -> None:
+ """Send any ready queries."""
+ if self.done or self.zc.done:
+ return
+
+ outs = self._generate_ready_queries(self._first_request)
+ if outs:
+ self._first_request = False
+ for out in outs:
+ self.zc.async_send(out, addr=self.addr, port=self.port)
+
+ def _async_send_ready_queries_schedule_next(self) -> None:
+ """Send ready queries and schedule next one."""
+ self._async_send_ready_queries()
+ self._async_schedule_next()
+
+ def _async_schedule_next(self) -> None:
+ """Scheule the next time."""
+ assert self.zc.loop is not None
+ delay = millis_to_seconds(self.query_scheduler.millis_to_wait(current_time_millis()))
+ self._next_send_timer = self.zc.loop.call_later(delay, self._async_send_ready_queries_schedule_next)
+
+
+class ServiceBrowser(_ServiceBrowserBase, threading.Thread):
+ """Used to browse for a service of a specific type.
+
+ The listener object will have its add_service() and
+ remove_service() methods called when this browser
+ discovers changes in the services availability."""
+
+ def __init__(
+ self,
+ zc: 'Zeroconf',
+ type_: Union[str, list],
+ handlers: Optional[Union[ServiceListener, List[Callable[..., None]]]] = None,
+ listener: Optional[ServiceListener] = None,
+ addr: Optional[str] = None,
+ port: int = _MDNS_PORT,
+ delay: int = _BROWSER_TIME,
+ question_type: Optional[DNSQuestionType] = None,
+ ) -> None:
+ assert zc.loop is not None
+ if not zc.loop.is_running():
+ raise RuntimeError("The event loop is not running")
+ threading.Thread.__init__(self)
+ super().__init__(zc, type_, handlers, listener, addr, port, delay, question_type)
+ # Add the queue before the listener is installed in _setup
+ # to ensure that events run in the dedicated thread and do
+ # not block the event loop
+ self.queue = get_best_available_queue()
+ self.daemon = True
+ self.start()
+ zc.loop.call_soon_threadsafe(self._async_start)
+ self.name = "zeroconf-ServiceBrowser-%s-%s" % (
+ '-'.join([type_[:-7] for type_ in self.types]),
+ getattr(self, 'native_id', self.ident),
+ )
+
+ def cancel(self) -> None:
+ """Cancel the browser."""
+ assert self.zc.loop is not None
+ assert self.queue is not None
+ self.queue.put(None)
+ self.zc.loop.call_soon_threadsafe(self._async_cancel)
+ self.join()
+
+ def run(self) -> None:
+ """Run the browser thread."""
+ assert self.queue is not None
+ while True:
+ event = self.queue.get()
+ if event is None:
+ return
+ self._fire_service_state_changed_event(event)
diff --git a/zeroconf/_services/info.py b/zeroconf/_services/info.py
new file mode 100644
index 00000000..beaf0678
--- /dev/null
+++ b/zeroconf/_services/info.py
@@ -0,0 +1,513 @@
+""" Multicast DNS Service Discovery for Python, v0.14-wmcbrine
+ Copyright 2003 Paul Scott-Murphy, 2014 William McBrine
+
+ This module provides a framework for the use of DNS Service Discovery
+ using IP multicast.
+
+ This library is free software; you can redistribute it and/or
+ modify it under the terms of the GNU Lesser General Public
+ License as published by the Free Software Foundation; either
+ version 2.1 of the License, or (at your option) any later version.
+
+ This library is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+ Lesser General Public License for more details.
+
+ You should have received a copy of the GNU Lesser General Public
+ License along with this library; if not, write to the Free Software
+ Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301
+ USA
+"""
+
+import ipaddress
+import random
+import socket
+from typing import Any, Dict, List, Optional, TYPE_CHECKING, Union, cast
+
+from .._dns import DNSAddress, DNSPointer, DNSQuestionType, DNSRecord, DNSService, DNSText
+from .._exceptions import BadTypeInNameException
+from .._protocol.outgoing import DNSOutgoing
+from .._updates import RecordUpdate, RecordUpdateListener
+from .._utils.asyncio import get_running_loop, run_coro_with_timeout
+from .._utils.name import service_type_name
+from .._utils.net import (
+ IPVersion,
+ _encode_address,
+ _is_v6_address,
+)
+from .._utils.time import current_time_millis
+from ..const import (
+ _CLASS_IN,
+ _CLASS_UNIQUE,
+ _DNS_HOST_TTL,
+ _DNS_OTHER_TTL,
+ _FLAGS_QR_QUERY,
+ _LISTENER_TIME,
+ _TYPE_A,
+ _TYPE_AAAA,
+ _TYPE_PTR,
+ _TYPE_SRV,
+ _TYPE_TXT,
+)
+
+
+# https://datatracker.ietf.org/doc/html/rfc6762#section-5.2
+# The most common case for calling ServiceInfo is from a
+# ServiceBrowser. After the first request we add a few random
+# milliseconds to the delay between requests to reduce the chance
+# that there are multiple ServiceBrowser callbacks running on
+# the network that are firing at the same time when they
+# see the same multicast response and decide to refresh
+# the A/AAAA/SRV records for a host.
+_AVOID_SYNC_DELAY_RANDOM_INTERVAL = (20, 120)
+
+if TYPE_CHECKING:
+ from .._core import Zeroconf
+
+
+def instance_name_from_service_info(info: "ServiceInfo") -> str:
+ """Calculate the instance name from the ServiceInfo."""
+ # This is kind of funky because of the subtype based tests
+ # need to make subtypes a first class citizen
+ service_name = service_type_name(info.name)
+ if not info.type.endswith(service_name):
+ raise BadTypeInNameException
+ return info.name[: -len(service_name) - 1]
+
+
+class ServiceInfo(RecordUpdateListener):
+ """Service information.
+
+ Constructor parameters are as follows:
+
+ * `type_`: fully qualified service type name
+ * `name`: fully qualified service name
+ * `port`: port that the service runs on
+ * `weight`: weight of the service
+ * `priority`: priority of the service
+ * `properties`: dictionary of properties (or a bytes object holding the contents of the `text` field).
+ converted to str and then encoded to bytes using UTF-8. Keys with `None` values are converted to
+ value-less attributes.
+ * `server`: fully qualified name for service host (defaults to name)
+ * `host_ttl`: ttl used for A/SRV records
+ * `other_ttl`: ttl used for PTR/TXT records
+ * `addresses` and `parsed_addresses`: List of IP addresses (either as bytes, network byte order,
+ or in parsed form as text; at most one of those parameters can be provided)
+ * interface_index: scope_id or zone_id for IPv6 link-local addresses i.e. an identifier of the interface
+ where the peer is connected to
+ """
+
+ text = b''
+
+ def __init__(
+ self,
+ type_: str,
+ name: str,
+ port: Optional[int] = None,
+ weight: int = 0,
+ priority: int = 0,
+ properties: Union[bytes, Dict] = b'',
+ server: Optional[str] = None,
+ host_ttl: int = _DNS_HOST_TTL,
+ other_ttl: int = _DNS_OTHER_TTL,
+ *,
+ addresses: Optional[List[bytes]] = None,
+ parsed_addresses: Optional[List[str]] = None,
+ interface_index: Optional[int] = None,
+ ) -> None:
+ # Accept both none, or one, but not both.
+ if addresses is not None and parsed_addresses is not None:
+ raise TypeError("addresses and parsed_addresses cannot be provided together")
+ if not type_.endswith(service_type_name(name, strict=False)):
+ raise BadTypeInNameException
+ self.type = type_
+ self._name = name
+ self.key = name.lower()
+ if addresses is not None:
+ self._addresses = addresses
+ elif parsed_addresses is not None:
+ self._addresses = [_encode_address(a) for a in parsed_addresses]
+ else:
+ self._addresses = []
+ # This results in an ugly error when registering, better check now
+ invalid = [a for a in self._addresses if not isinstance(a, bytes) or len(a) not in (4, 16)]
+ if invalid:
+ raise TypeError(
+ 'Addresses must be bytes, got %s. Hint: convert string addresses '
+ 'with socket.inet_pton' % invalid
+ )
+ self.port = port
+ self.weight = weight
+ self.priority = priority
+ self.server = server if server else name
+ self.server_key = self.server.lower()
+ self._properties: Dict[Union[str, bytes], Optional[Union[str, bytes]]] = {}
+ if isinstance(properties, bytes):
+ self._set_text(properties)
+ else:
+ self._set_properties(properties)
+ self.host_ttl = host_ttl
+ self.other_ttl = other_ttl
+ self.interface_index = interface_index
+
+ @property
+ def name(self) -> str:
+ """The name of the service."""
+ return self._name
+
+ @name.setter
+ def name(self, name: str) -> None:
+ """Replace the the name and reset the key."""
+ self._name = name
+ self.key = name.lower()
+
+ @property
+ def addresses(self) -> List[bytes]:
+ """IPv4 addresses of this service.
+
+ Only IPv4 addresses are returned for backward compatibility.
+ Use :meth:`addresses_by_version` or :meth:`parsed_addresses` to
+ include IPv6 addresses as well.
+ """
+ return self.addresses_by_version(IPVersion.V4Only)
+
+ @addresses.setter
+ def addresses(self, value: List[bytes]) -> None:
+ """Replace the addresses list.
+
+ This replaces all currently stored addresses, both IPv4 and IPv6.
+ """
+ self._addresses = value
+
+ @property
+ def properties(self) -> Dict:
+ """If properties were set in the constructor this property returns the original dictionary
+ of type `Dict[Union[bytes, str], Any]`.
+
+ If properties are coming from the network, after decoding a TXT record, the keys are always
+ bytes and the values are either bytes, if there was a value, even empty, or `None`, if there
+ was none. No further decoding is attempted. The type returned is `Dict[bytes, Optional[bytes]]`.
+ """
+ return self._properties
+
+ def addresses_by_version(self, version: IPVersion) -> List[bytes]:
+ """List addresses matching IP version."""
+ if version == IPVersion.V4Only:
+ return [addr for addr in self._addresses if not _is_v6_address(addr)]
+ if version == IPVersion.V6Only:
+ return list(filter(_is_v6_address, self._addresses))
+ return self._addresses
+
+ def parsed_addresses(self, version: IPVersion = IPVersion.All) -> List[str]:
+ """List addresses in their parsed string form."""
+ result = self.addresses_by_version(version)
+ return [
+ socket.inet_ntop(socket.AF_INET6 if _is_v6_address(addr) else socket.AF_INET, addr)
+ for addr in result
+ ]
+
+ def parsed_scoped_addresses(self, version: IPVersion = IPVersion.All) -> List[str]:
+ """Equivalent to parsed_addresses, with the exception that IPv6 Link-Local
+ addresses are qualified with % when available
+ """
+ if self.interface_index is None:
+ return self.parsed_addresses(version)
+
+ def is_link_local(addr_str: str) -> Any:
+ addr = ipaddress.ip_address(addr_str)
+ return addr.version == 6 and addr.is_link_local
+
+ ll_addrs = list(filter(is_link_local, self.parsed_addresses(version)))
+ other_addrs = list(filter(lambda addr: not is_link_local(addr), self.parsed_addresses(version)))
+ return ["{}%{}".format(addr, self.interface_index) for addr in ll_addrs] + other_addrs
+
+ def _set_properties(self, properties: Dict) -> None:
+ """Sets properties and text of this info from a dictionary"""
+ self._properties = properties
+ list_ = []
+ result = b''
+ for key, value in properties.items():
+ if isinstance(key, str):
+ key = key.encode('utf-8')
+
+ record = key
+ if value is not None:
+ if not isinstance(value, bytes):
+ value = str(value).encode('utf-8')
+ record += b'=' + value
+ list_.append(record)
+ for item in list_:
+ result = b''.join((result, bytes((len(item),)), item))
+ self.text = result
+
+ def _set_text(self, text: bytes) -> None:
+ """Sets properties and text given a text field"""
+ self.text = text
+ end = len(text)
+ if end == 0:
+ self._properties = {}
+ return
+ result: Dict[Union[str, bytes], Optional[Union[str, bytes]]] = {}
+ index = 0
+ strs = []
+ while index < end:
+ length = text[index]
+ index += 1
+ strs.append(text[index : index + length])
+ index += length
+
+ key: bytes
+ value: Optional[bytes]
+ for s in strs:
+ try:
+ key, value = s.split(b'=', 1)
+ except ValueError:
+ # No equals sign at all
+ key = s
+ value = None
+
+ # Only update non-existent properties
+ if key and result.get(key) is None:
+ result[key] = value
+
+ self._properties = result
+
+ def get_name(self) -> str:
+ """Name accessor"""
+ return self.name[: len(self.name) - len(self.type) - 1]
+
+ def update_record(self, zc: 'Zeroconf', now: float, record: Optional[DNSRecord]) -> None:
+ """Updates service information from a DNS record.
+
+ This method is deprecated and will be removed in a future version.
+ update_records should be implemented instead.
+
+ This method will be run in the event loop.
+ """
+ if record is not None:
+ self._process_records_threadsafe(zc, now, [RecordUpdate(record, None)])
+
+ def async_update_records(self, zc: 'Zeroconf', now: float, records: List[RecordUpdate]) -> None:
+ """Updates service information from a DNS record.
+
+ This method will be run in the event loop.
+ """
+ self._process_records_threadsafe(zc, now, records)
+
+ def _process_records_threadsafe(self, zc: 'Zeroconf', now: float, records: List[RecordUpdate]) -> None:
+ """Thread safe record updating."""
+ update_addresses = False
+ for record_update in records:
+ if isinstance(record_update[0], DNSService):
+ update_addresses = True
+ self._process_record_threadsafe(record_update[0], now)
+
+ # Only update addresses if the DNSService (.server) has changed
+ if not update_addresses:
+ return
+
+ for record in self._get_address_records_from_cache(zc):
+ self._process_record_threadsafe(record, now)
+
+ def _process_record_threadsafe(self, record: DNSRecord, now: float) -> None:
+ if record.is_expired(now):
+ return
+
+ if isinstance(record, DNSAddress):
+ if record.key == self.server_key and record.address not in self._addresses:
+ self._addresses.append(record.address)
+ if record.type is _TYPE_AAAA and ipaddress.IPv6Address(record.address).is_link_local:
+ self.interface_index = record.scope_id
+ return
+
+ if isinstance(record, DNSService):
+ if record.key != self.key:
+ return
+ self.name = record.name
+ self.server = record.server
+ self.server_key = record.server.lower()
+ self.port = record.port
+ self.weight = record.weight
+ self.priority = record.priority
+ return
+
+ if isinstance(record, DNSText):
+ if record.key == self.key:
+ self._set_text(record.text)
+
+ def dns_addresses(
+ self,
+ override_ttl: Optional[int] = None,
+ version: IPVersion = IPVersion.All,
+ created: Optional[float] = None,
+ ) -> List[DNSAddress]:
+ """Return matching DNSAddress from ServiceInfo."""
+ return [
+ DNSAddress(
+ self.server,
+ _TYPE_AAAA if _is_v6_address(address) else _TYPE_A,
+ _CLASS_IN | _CLASS_UNIQUE,
+ override_ttl if override_ttl is not None else self.host_ttl,
+ address,
+ created=created,
+ )
+ for address in self.addresses_by_version(version)
+ ]
+
+ def dns_pointer(self, override_ttl: Optional[int] = None, created: Optional[float] = None) -> DNSPointer:
+ """Return DNSPointer from ServiceInfo."""
+ return DNSPointer(
+ self.type,
+ _TYPE_PTR,
+ _CLASS_IN,
+ override_ttl if override_ttl is not None else self.other_ttl,
+ self.name,
+ created,
+ )
+
+ def dns_service(self, override_ttl: Optional[int] = None, created: Optional[float] = None) -> DNSService:
+ """Return DNSService from ServiceInfo."""
+ return DNSService(
+ self.name,
+ _TYPE_SRV,
+ _CLASS_IN | _CLASS_UNIQUE,
+ override_ttl if override_ttl is not None else self.host_ttl,
+ self.priority,
+ self.weight,
+ cast(int, self.port),
+ self.server,
+ created,
+ )
+
+ def dns_text(self, override_ttl: Optional[int] = None, created: Optional[float] = None) -> DNSText:
+ """Return DNSText from ServiceInfo."""
+ return DNSText(
+ self.name,
+ _TYPE_TXT,
+ _CLASS_IN | _CLASS_UNIQUE,
+ override_ttl if override_ttl is not None else self.other_ttl,
+ self.text,
+ created,
+ )
+
+ def _get_address_records_from_cache(self, zc: 'Zeroconf') -> List[DNSRecord]:
+ """Get the address records from the cache."""
+ return [
+ *zc.cache.get_all_by_details(self.server, _TYPE_A, _CLASS_IN),
+ *zc.cache.get_all_by_details(self.server, _TYPE_AAAA, _CLASS_IN),
+ ]
+
+ def load_from_cache(self, zc: 'Zeroconf') -> bool:
+ """Populate the service info from the cache.
+
+ This method is designed to be threadsafe.
+ """
+ now = current_time_millis()
+ record_updates: List[RecordUpdate] = []
+ cached_srv_record = zc.cache.get_by_details(self.name, _TYPE_SRV, _CLASS_IN)
+ if cached_srv_record:
+ # If there is a srv record, A and AAAA will already
+ # be called and we do not want to do it twice
+ record_updates.append(RecordUpdate(cached_srv_record, None))
+ else:
+ for record in self._get_address_records_from_cache(zc):
+ record_updates.append(RecordUpdate(record, None))
+ cached_txt_record = zc.cache.get_by_details(self.name, _TYPE_TXT, _CLASS_IN)
+ if cached_txt_record:
+ record_updates.append(RecordUpdate(cached_txt_record, None))
+ self._process_records_threadsafe(zc, now, record_updates)
+ return self._is_complete
+
+ @property
+ def _is_complete(self) -> bool:
+ """The ServiceInfo has all expected properties."""
+ return not (self.text is None or not self._addresses)
+
+ def request(
+ self, zc: 'Zeroconf', timeout: float, question_type: Optional[DNSQuestionType] = None
+ ) -> bool:
+ """Returns true if the service could be discovered on the
+ network, and updates this object with details discovered.
+ """
+ assert zc.loop is not None and zc.loop.is_running()
+ if zc.loop == get_running_loop():
+ raise RuntimeError("Use AsyncServiceInfo.async_request from the event loop")
+ return bool(run_coro_with_timeout(self.async_request(zc, timeout, question_type), zc.loop, timeout))
+
+ async def async_request(
+ self, zc: 'Zeroconf', timeout: float, question_type: Optional[DNSQuestionType] = None
+ ) -> bool:
+ """Returns true if the service could be discovered on the
+ network, and updates this object with details discovered.
+ """
+ if self.load_from_cache(zc):
+ return True
+
+ first_request = True
+ now = current_time_millis()
+ delay = _LISTENER_TIME
+ next_ = now
+ last = now + timeout
+ await zc.async_wait_for_start()
+ try:
+ zc.async_add_listener(self, None)
+ while not self._is_complete:
+ if last <= now:
+ return False
+ if next_ <= now:
+ out = self.generate_request_query(
+ zc, now, question_type or DNSQuestionType.QU if first_request else DNSQuestionType.QM
+ )
+ first_request = False
+ if not out.questions:
+ return self.load_from_cache(zc)
+ zc.async_send(out)
+ next_ = now + delay
+ delay *= 2
+ next_ += random.randint(*_AVOID_SYNC_DELAY_RANDOM_INTERVAL)
+
+ await zc.async_wait(min(next_, last) - now)
+ now = current_time_millis()
+ finally:
+ zc.async_remove_listener(self)
+
+ return True
+
+ def generate_request_query(
+ self, zc: 'Zeroconf', now: float, question_type: Optional[DNSQuestionType] = None
+ ) -> DNSOutgoing:
+ """Generate the request query."""
+ out = DNSOutgoing(_FLAGS_QR_QUERY)
+ out.add_question_or_one_cache(zc.cache, now, self.name, _TYPE_SRV, _CLASS_IN)
+ out.add_question_or_one_cache(zc.cache, now, self.name, _TYPE_TXT, _CLASS_IN)
+ out.add_question_or_all_cache(zc.cache, now, self.server, _TYPE_A, _CLASS_IN)
+ out.add_question_or_all_cache(zc.cache, now, self.server, _TYPE_AAAA, _CLASS_IN)
+ if question_type == DNSQuestionType.QU:
+ for question in out.questions:
+ question.unicast = True
+ return out
+
+ def __eq__(self, other: object) -> bool:
+ """Tests equality of service name"""
+ return isinstance(other, ServiceInfo) and other.name == self.name
+
+ def __repr__(self) -> str:
+ """String representation"""
+ return '%s(%s)' % (
+ type(self).__name__,
+ ', '.join(
+ '%s=%r' % (name, getattr(self, name))
+ for name in (
+ 'type',
+ 'name',
+ 'addresses',
+ 'port',
+ 'weight',
+ 'priority',
+ 'server',
+ 'properties',
+ 'interface_index',
+ )
+ ),
+ )
diff --git a/zeroconf/_services/registry.py b/zeroconf/_services/registry.py
new file mode 100644
index 00000000..203b3b39
--- /dev/null
+++ b/zeroconf/_services/registry.py
@@ -0,0 +1,99 @@
+""" Multicast DNS Service Discovery for Python, v0.14-wmcbrine
+ Copyright 2003 Paul Scott-Murphy, 2014 William McBrine
+
+ This module provides a framework for the use of DNS Service Discovery
+ using IP multicast.
+
+ This library is free software; you can redistribute it and/or
+ modify it under the terms of the GNU Lesser General Public
+ License as published by the Free Software Foundation; either
+ version 2.1 of the License, or (at your option) any later version.
+
+ This library is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+ Lesser General Public License for more details.
+
+ You should have received a copy of the GNU Lesser General Public
+ License along with this library; if not, write to the Free Software
+ Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301
+ USA
+"""
+
+from typing import Dict, List, Optional, Union
+
+
+from .info import ServiceInfo
+from .._exceptions import ServiceNameAlreadyRegistered
+
+
+class ServiceRegistry:
+ """A registry to keep track of services.
+
+ The registry must only be accessed from
+ the event loop as it is not thread safe.
+ """
+
+ def __init__(
+ self,
+ ) -> None:
+ """Create the ServiceRegistry class."""
+ self._services: Dict[str, ServiceInfo] = {}
+ self.types: Dict[str, List] = {}
+ self.servers: Dict[str, List] = {}
+
+ def async_add(self, info: ServiceInfo) -> None:
+ """Add a new service to the registry."""
+ self._add(info)
+
+ def async_remove(self, info: Union[List[ServiceInfo], ServiceInfo]) -> None:
+ """Remove a new service from the registry."""
+ self._remove(info if isinstance(info, list) else [info])
+
+ def async_update(self, info: ServiceInfo) -> None:
+ """Update new service in the registry."""
+ self._remove([info])
+ self._add(info)
+
+ def async_get_service_infos(self) -> List[ServiceInfo]:
+ """Return all ServiceInfo."""
+ return list(self._services.values())
+
+ def async_get_info_name(self, name: str) -> Optional[ServiceInfo]:
+ """Return all ServiceInfo for the name."""
+ return self._services.get(name.lower())
+
+ def async_get_types(self) -> List[str]:
+ """Return all types."""
+ return list(self.types.keys())
+
+ def async_get_infos_type(self, type_: str) -> List[ServiceInfo]:
+ """Return all ServiceInfo matching type."""
+ return self._async_get_by_index(self.types, type_)
+
+ def async_get_infos_server(self, server: str) -> List[ServiceInfo]:
+ """Return all ServiceInfo matching server."""
+ return self._async_get_by_index(self.servers, server)
+
+ def _async_get_by_index(self, records: Dict[str, List], key: str) -> List[ServiceInfo]:
+ """Return all ServiceInfo matching the index."""
+ return [self._services[name] for name in records.get(key.lower(), [])]
+
+ def _add(self, info: ServiceInfo) -> None:
+ """Add a new service under the lock."""
+ if info.key in self._services:
+ raise ServiceNameAlreadyRegistered
+
+ self._services[info.key] = info
+ self.types.setdefault(info.type.lower(), []).append(info.key)
+ self.servers.setdefault(info.server_key, []).append(info.key)
+
+ def _remove(self, infos: List[ServiceInfo]) -> None:
+ """Remove a services under the lock."""
+ for info in infos:
+ if info.key not in self._services:
+ continue
+ old_service_info = self._services[info.key]
+ self.types[old_service_info.type.lower()].remove(info.key)
+ self.servers[old_service_info.server_key].remove(info.key)
+ del self._services[info.key]
diff --git a/zeroconf/_services/types.py b/zeroconf/_services/types.py
new file mode 100644
index 00000000..34b000f1
--- /dev/null
+++ b/zeroconf/_services/types.py
@@ -0,0 +1,83 @@
+""" Multicast DNS Service Discovery for Python, v0.14-wmcbrine
+ Copyright 2003 Paul Scott-Murphy, 2014 William McBrine
+
+ This module provides a framework for the use of DNS Service Discovery
+ using IP multicast.
+
+ This library is free software; you can redistribute it and/or
+ modify it under the terms of the GNU Lesser General Public
+ License as published by the Free Software Foundation; either
+ version 2.1 of the License, or (at your option) any later version.
+
+ This library is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+ Lesser General Public License for more details.
+
+ You should have received a copy of the GNU Lesser General Public
+ License along with this library; if not, write to the Free Software
+ Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301
+ USA
+"""
+
+import time
+from typing import Optional, Set, Tuple, Union
+
+from .browser import ServiceBrowser
+from .._core import Zeroconf
+from .._services import ServiceListener
+from .._utils.net import IPVersion, InterfaceChoice, InterfacesType
+from ..const import _SERVICE_TYPE_ENUMERATION_NAME
+
+
+class ZeroconfServiceTypes(ServiceListener):
+ """
+ Return all of the advertised services on any local networks
+ """
+
+ def __init__(self) -> None:
+ """Keep track of found services in a set."""
+ self.found_services: Set[str] = set()
+
+ def add_service(self, zc: Zeroconf, type_: str, name: str) -> None:
+ """Service added."""
+ self.found_services.add(name)
+
+ def update_service(self, zc: Zeroconf, type_: str, name: str) -> None:
+ """Service updated."""
+
+ def remove_service(self, zc: Zeroconf, type_: str, name: str) -> None:
+ """Service removed."""
+
+ @classmethod
+ def find(
+ cls,
+ zc: Optional[Zeroconf] = None,
+ timeout: Union[int, float] = 5,
+ interfaces: InterfacesType = InterfaceChoice.All,
+ ip_version: Optional[IPVersion] = None,
+ ) -> Tuple[str, ...]:
+ """
+ Return all of the advertised services on any local networks.
+
+ :param zc: Zeroconf() instance. Pass in if already have an
+ instance running or if non-default interfaces are needed
+ :param timeout: seconds to wait for any responses
+ :param interfaces: interfaces to listen on.
+ :param ip_version: IP protocol version to use.
+ :return: tuple of service type strings
+ """
+ local_zc = zc or Zeroconf(interfaces=interfaces, ip_version=ip_version)
+ listener = cls()
+ browser = ServiceBrowser(local_zc, _SERVICE_TYPE_ENUMERATION_NAME, listener=listener)
+
+ # wait for responses
+ time.sleep(timeout)
+
+ browser.cancel()
+
+ # close down anything we opened
+ if zc is None:
+ local_zc.close()
+
+ return tuple(sorted(listener.found_services))
diff --git a/zeroconf/_updates.py b/zeroconf/_updates.py
new file mode 100644
index 00000000..bc7dcab5
--- /dev/null
+++ b/zeroconf/_updates.py
@@ -0,0 +1,76 @@
+""" Multicast DNS Service Discovery for Python, v0.14-wmcbrine
+ Copyright 2003 Paul Scott-Murphy, 2014 William McBrine
+
+ This module provides a framework for the use of DNS Service Discovery
+ using IP multicast.
+
+ This library is free software; you can redistribute it and/or
+ modify it under the terms of the GNU Lesser General Public
+ License as published by the Free Software Foundation; either
+ version 2.1 of the License, or (at your option) any later version.
+
+ This library is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+ Lesser General Public License for more details.
+
+ You should have received a copy of the GNU Lesser General Public
+ License along with this library; if not, write to the Free Software
+ Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301
+ USA
+"""
+
+from typing import List, NamedTuple, Optional, TYPE_CHECKING
+
+
+from ._dns import DNSRecord
+
+
+if TYPE_CHECKING:
+ from ._core import Zeroconf
+
+
+class RecordUpdate(NamedTuple):
+ new: DNSRecord
+ old: Optional[DNSRecord]
+
+
+class RecordUpdateListener:
+ def update_record( # pylint: disable=no-self-use
+ self, zc: 'Zeroconf', now: float, record: DNSRecord
+ ) -> None:
+ """Update a single record.
+
+ This method is deprecated and will be removed in a future version.
+ update_records should be implemented instead.
+ """
+ raise RuntimeError("update_record is deprecated and will be removed in a future version.")
+
+ def async_update_records(self, zc: 'Zeroconf', now: float, records: List[RecordUpdate]) -> None:
+ """Update multiple records in one shot.
+
+ All records that are received in a single packet are passed
+ to update_records.
+
+ This implementation is a compatiblity shim to ensure older code
+ that uses RecordUpdateListener as a base class will continue to
+ get calls to update_record. This method will raise
+ NotImplementedError in a future version.
+
+ At this point the cache will not have the new records
+
+ Records are passed as a list of RecordUpdate. This
+ allows consumers of async_update_records to avoid cache lookups.
+
+ This method will be run in the event loop.
+ """
+ for record in records:
+ self.update_record(zc, now, record[0])
+
+ def async_update_records_complete(self) -> None:
+ """Called when a record update has completed for all handlers.
+
+ At this point the cache will have the new records.
+
+ This method will be run in the event loop.
+ """
diff --git a/zeroconf/_utils/__init__.py b/zeroconf/_utils/__init__.py
new file mode 100644
index 00000000..2ef4b15b
--- /dev/null
+++ b/zeroconf/_utils/__init__.py
@@ -0,0 +1,21 @@
+""" Multicast DNS Service Discovery for Python, v0.14-wmcbrine
+ Copyright 2003 Paul Scott-Murphy, 2014 William McBrine
+
+ This module provides a framework for the use of DNS Service Discovery
+ using IP multicast.
+
+ This library is free software; you can redistribute it and/or
+ modify it under the terms of the GNU Lesser General Public
+ License as published by the Free Software Foundation; either
+ version 2.1 of the License, or (at your option) any later version.
+
+ This library is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+ Lesser General Public License for more details.
+
+ You should have received a copy of the GNU Lesser General Public
+ License along with this library; if not, write to the Free Software
+ Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301
+ USA
+"""
diff --git a/zeroconf/_utils/asyncio.py b/zeroconf/_utils/asyncio.py
new file mode 100644
index 00000000..10b8b3d9
--- /dev/null
+++ b/zeroconf/_utils/asyncio.py
@@ -0,0 +1,123 @@
+""" Multicast DNS Service Discovery for Python, v0.14-wmcbrine
+ Copyright 2003 Paul Scott-Murphy, 2014 William McBrine
+
+ This module provides a framework for the use of DNS Service Discovery
+ using IP multicast.
+
+ This library is free software; you can redistribute it and/or
+ modify it under the terms of the GNU Lesser General Public
+ License as published by the Free Software Foundation; either
+ version 2.1 of the License, or (at your option) any later version.
+
+ This library is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+ Lesser General Public License for more details.
+
+ You should have received a copy of the GNU Lesser General Public
+ License along with this library; if not, write to the Free Software
+ Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301
+ USA
+"""
+
+import asyncio
+import contextlib
+import queue
+from typing import Any, Awaitable, Coroutine, List, Optional, Set, cast
+
+from .time import millis_to_seconds
+from ..const import _LOADED_SYSTEM_TIMEOUT
+
+# The combined timeouts should be lower than _CLOSE_TIMEOUT + _WAIT_FOR_LOOP_TASKS_TIMEOUT
+_TASK_AWAIT_TIMEOUT = 1
+_GET_ALL_TASKS_TIMEOUT = 3
+_WAIT_FOR_LOOP_TASKS_TIMEOUT = 3 # Must be larger than _TASK_AWAIT_TIMEOUT
+
+
+def get_best_available_queue() -> queue.Queue:
+ """Create the best available queue type."""
+ if hasattr(queue, "SimpleQueue"):
+ return queue.SimpleQueue() # type: ignore # pylint: disable=all
+ return queue.Queue()
+
+
+# Switch to asyncio.wait_for once https://bugs.python.org/issue39032 is fixed
+async def wait_event_or_timeout(event: asyncio.Event, timeout: float) -> None:
+ """Wait for an event or timeout."""
+ loop = asyncio.get_event_loop()
+ future = loop.create_future()
+
+ def _handle_timeout_or_wait_complete(*_: Any) -> None:
+ if not future.done():
+ future.set_result(None)
+
+ timer_handle = loop.call_later(timeout, _handle_timeout_or_wait_complete)
+ event_wait = loop.create_task(event.wait())
+ event_wait.add_done_callback(_handle_timeout_or_wait_complete)
+
+ try:
+ await future
+ finally:
+ timer_handle.cancel()
+ if not event_wait.done():
+ event_wait.cancel()
+ with contextlib.suppress(asyncio.CancelledError):
+ await event_wait
+
+
+async def _async_get_all_tasks(loop: asyncio.AbstractEventLoop) -> List[asyncio.Task]:
+ """Return all tasks running."""
+ await asyncio.sleep(0) # flush out any call_soon_threadsafe
+ # If there are multiple event loops running, all_tasks is not
+ # safe EVEN WHEN CALLED FROM THE EVENTLOOP
+ # under PyPy so we have to try a few times.
+ for _ in range(3):
+ with contextlib.suppress(RuntimeError):
+ if hasattr(asyncio, 'all_tasks'):
+ return asyncio.all_tasks(loop) # type: ignore # pylint: disable=no-member
+ return asyncio.Task.all_tasks(loop) # type: ignore # pylint: disable=no-member
+ return []
+
+
+async def _wait_for_loop_tasks(wait_tasks: Set[asyncio.Task]) -> None:
+ """Wait for the event loop thread we started to shutdown."""
+ await asyncio.wait(wait_tasks, timeout=_TASK_AWAIT_TIMEOUT)
+
+
+async def await_awaitable(aw: Awaitable) -> None:
+ """Wait on an awaitable and the task it returns."""
+ task = await aw
+ await task
+
+
+def run_coro_with_timeout(aw: Coroutine, loop: asyncio.AbstractEventLoop, timeout: float) -> Any:
+ """Run a coroutine with a timeout."""
+ return asyncio.run_coroutine_threadsafe(aw, loop).result(
+ millis_to_seconds(timeout) + _LOADED_SYSTEM_TIMEOUT
+ )
+
+
+def shutdown_loop(loop: asyncio.AbstractEventLoop) -> None:
+ """Wait for pending tasks and stop an event loop."""
+ pending_tasks = set(
+ asyncio.run_coroutine_threadsafe(_async_get_all_tasks(loop), loop).result(_GET_ALL_TASKS_TIMEOUT)
+ )
+ pending_tasks -= set(task for task in pending_tasks if task.done())
+ if pending_tasks:
+ asyncio.run_coroutine_threadsafe(_wait_for_loop_tasks(pending_tasks), loop).result(
+ _WAIT_FOR_LOOP_TASKS_TIMEOUT
+ )
+ loop.call_soon_threadsafe(loop.stop)
+
+
+# Remove the call to _get_running_loop once we drop python 3.6 support
+def get_running_loop() -> Optional[asyncio.AbstractEventLoop]:
+ """Check if an event loop is already running."""
+ with contextlib.suppress(RuntimeError):
+ if hasattr(asyncio, "get_running_loop"):
+ return cast(
+ asyncio.AbstractEventLoop,
+ asyncio.get_running_loop(), # type: ignore # pylint: disable=no-member # noqa
+ )
+ return asyncio._get_running_loop() # pylint: disable=no-member,protected-access
+ return None
diff --git a/zeroconf/_utils/name.py b/zeroconf/_utils/name.py
new file mode 100644
index 00000000..c59ac33a
--- /dev/null
+++ b/zeroconf/_utils/name.py
@@ -0,0 +1,157 @@
+""" Multicast DNS Service Discovery for Python, v0.14-wmcbrine
+ Copyright 2003 Paul Scott-Murphy, 2014 William McBrine
+
+ This module provides a framework for the use of DNS Service Discovery
+ using IP multicast.
+
+ This library is free software; you can redistribute it and/or
+ modify it under the terms of the GNU Lesser General Public
+ License as published by the Free Software Foundation; either
+ version 2.1 of the License, or (at your option) any later version.
+
+ This library is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+ Lesser General Public License for more details.
+
+ You should have received a copy of the GNU Lesser General Public
+ License along with this library; if not, write to the Free Software
+ Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301
+ USA
+"""
+
+from .._exceptions import BadTypeInNameException
+from ..const import (
+ _HAS_ASCII_CONTROL_CHARS,
+ _HAS_A_TO_Z,
+ _HAS_ONLY_A_TO_Z_NUM_HYPHEN,
+ _HAS_ONLY_A_TO_Z_NUM_HYPHEN_UNDERSCORE,
+ _LOCAL_TRAILER,
+ _NONTCP_PROTOCOL_LOCAL_TRAILER,
+ _TCP_PROTOCOL_LOCAL_TRAILER,
+)
+
+
+def service_type_name(type_: str, *, strict: bool = True) -> str: # pylint: disable=too-many-branches
+ """
+ Validate a fully qualified service name, instance or subtype. [rfc6763]
+
+ Returns fully qualified service name.
+
+ Domain names used by mDNS-SD take the following forms:
+
+ . <_tcp|_udp> . local.
+ . . <_tcp|_udp> . local.
+ ._sub . . <_tcp|_udp> . local.
+
+ 1) must end with 'local.'
+
+ This is true because we are implementing mDNS and since the 'm' means
+ multi-cast, the 'local.' domain is mandatory.
+
+ 2) local is preceded with either '_udp.' or '_tcp.' unless
+ strict is False
+
+ 3) service name precedes <_tcp|_udp> unless
+ strict is False
+
+ The rules for Service Names [RFC6335] state that they may be no more
+ than fifteen characters long (not counting the mandatory underscore),
+ consisting of only letters, digits, and hyphens, must begin and end
+ with a letter or digit, must not contain consecutive hyphens, and
+ must contain at least one letter.
+
+ The instance name and sub type may be up to 63 bytes.
+
+ The portion of the Service Instance Name is a user-
+ friendly name consisting of arbitrary Net-Unicode text [RFC5198]. It
+ MUST NOT contain ASCII control characters (byte values 0x00-0x1F and
+ 0x7F) [RFC20] but otherwise is allowed to contain any characters,
+ without restriction, including spaces, uppercase, lowercase,
+ punctuation -- including dots -- accented characters, non-Roman text,
+ and anything else that may be represented using Net-Unicode.
+
+ :param type_: Type, SubType or service name to validate
+ :return: fully qualified service name (eg: _http._tcp.local.)
+ """
+ if len(type_) > 256:
+ # https://datatracker.ietf.org/doc/html/rfc6763#section-7.2
+ raise BadTypeInNameException("Full name (%s) must be > 256 bytes" % type_)
+
+ if type_.endswith((_TCP_PROTOCOL_LOCAL_TRAILER, _NONTCP_PROTOCOL_LOCAL_TRAILER)):
+ remaining = type_[: -len(_TCP_PROTOCOL_LOCAL_TRAILER)].split('.')
+ trailer = type_[-len(_TCP_PROTOCOL_LOCAL_TRAILER) :]
+ has_protocol = True
+ elif strict:
+ raise BadTypeInNameException(
+ "Type '%s' must end with '%s' or '%s'"
+ % (type_, _TCP_PROTOCOL_LOCAL_TRAILER, _NONTCP_PROTOCOL_LOCAL_TRAILER)
+ )
+ elif type_.endswith(_LOCAL_TRAILER):
+ remaining = type_[: -len(_LOCAL_TRAILER)].split('.')
+ trailer = type_[-len(_LOCAL_TRAILER) + 1 :]
+ has_protocol = False
+ else:
+ raise BadTypeInNameException("Type '%s' must end with '%s'" % (type_, _LOCAL_TRAILER))
+
+ if strict or has_protocol:
+ service_name = remaining.pop()
+ if not service_name:
+ raise BadTypeInNameException("No Service name found")
+
+ if len(remaining) == 1 and len(remaining[0]) == 0:
+ raise BadTypeInNameException("Type '%s' must not start with '.'" % type_)
+
+ if service_name[0] != '_':
+ raise BadTypeInNameException("Service name (%s) must start with '_'" % service_name)
+
+ test_service_name = service_name[1:]
+
+ if strict and len(test_service_name) > 15:
+ # https://datatracker.ietf.org/doc/html/rfc6763#section-7.2
+ raise BadTypeInNameException("Service name (%s) must be <= 15 bytes" % test_service_name)
+
+ if '--' in test_service_name:
+ raise BadTypeInNameException("Service name (%s) must not contain '--'" % test_service_name)
+
+ if '-' in (test_service_name[0], test_service_name[-1]):
+ raise BadTypeInNameException(
+ "Service name (%s) may not start or end with '-'" % test_service_name
+ )
+
+ if not _HAS_A_TO_Z.search(test_service_name):
+ raise BadTypeInNameException(
+ "Service name (%s) must contain at least one letter (eg: 'A-Z')" % test_service_name
+ )
+
+ allowed_characters_re = (
+ _HAS_ONLY_A_TO_Z_NUM_HYPHEN if strict else _HAS_ONLY_A_TO_Z_NUM_HYPHEN_UNDERSCORE
+ )
+
+ if not allowed_characters_re.search(test_service_name):
+ raise BadTypeInNameException(
+ "Service name (%s) must contain only these characters: "
+ "A-Z, a-z, 0-9, hyphen ('-')%s" % (test_service_name, "" if strict else ", underscore ('_')")
+ )
+ else:
+ service_name = ''
+
+ if remaining and remaining[-1] == '_sub':
+ remaining.pop()
+ if len(remaining) == 0 or len(remaining[0]) == 0:
+ raise BadTypeInNameException("_sub requires a subtype name")
+
+ if len(remaining) > 1:
+ remaining = ['.'.join(remaining)]
+
+ if remaining:
+ length = len(remaining[0].encode('utf-8'))
+ if length > 63:
+ raise BadTypeInNameException("Too long: '%s'" % remaining[0])
+
+ if _HAS_ASCII_CONTROL_CHARS.search(remaining[0]):
+ raise BadTypeInNameException(
+ "Ascii control character 0x00-0x1F and 0x7F illegal in '%s'" % remaining[0]
+ )
+
+ return service_name + trailer
diff --git a/zeroconf/_utils/net.py b/zeroconf/_utils/net.py
new file mode 100644
index 00000000..c53ec978
--- /dev/null
+++ b/zeroconf/_utils/net.py
@@ -0,0 +1,404 @@
+""" Multicast DNS Service Discovery for Python, v0.14-wmcbrine
+ Copyright 2003 Paul Scott-Murphy, 2014 William McBrine
+
+ This module provides a framework for the use of DNS Service Discovery
+ using IP multicast.
+
+ This library is free software; you can redistribute it and/or
+ modify it under the terms of the GNU Lesser General Public
+ License as published by the Free Software Foundation; either
+ version 2.1 of the License, or (at your option) any later version.
+
+ This library is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+ Lesser General Public License for more details.
+
+ You should have received a copy of the GNU Lesser General Public
+ License along with this library; if not, write to the Free Software
+ Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301
+ USA
+"""
+
+import enum
+import errno
+import ipaddress
+import socket
+import struct
+import sys
+from typing import Any, List, Optional, Tuple, Union, cast
+
+import ifaddr
+
+from .._logger import log
+from ..const import _IPPROTO_IPV6, _MDNS_ADDR, _MDNS_ADDR6, _MDNS_PORT
+
+
+@enum.unique
+class InterfaceChoice(enum.Enum):
+ Default = 1
+ All = 2
+
+
+InterfacesType = Union[List[Union[str, int, Tuple[Tuple[str, int, int], int]]], InterfaceChoice]
+
+
+@enum.unique
+class ServiceStateChange(enum.Enum):
+ Added = 1
+ Removed = 2
+ Updated = 3
+
+
+@enum.unique
+class IPVersion(enum.Enum):
+ V4Only = 1
+ V6Only = 2
+ All = 3
+
+
+# utility functions
+
+
+def _is_v6_address(addr: bytes) -> bool:
+ return len(addr) == 16
+
+
+def _encode_address(address: str) -> bytes:
+ is_ipv6 = ':' in address
+ address_family = socket.AF_INET6 if is_ipv6 else socket.AF_INET
+ return socket.inet_pton(address_family, address)
+
+
+def get_all_addresses() -> List[str]:
+ return list(set(addr.ip for iface in ifaddr.get_adapters() for addr in iface.ips if addr.is_IPv4))
+
+
+def get_all_addresses_v6() -> List[Tuple[Tuple[str, int, int], int]]:
+ # IPv6 multicast uses positive indexes for interfaces
+ # TODO: What about multi-address interfaces?
+ return list(
+ set((addr.ip, iface.index) for iface in ifaddr.get_adapters() for addr in iface.ips if addr.is_IPv6)
+ )
+
+
+def ip6_to_address_and_index(adapters: List[Any], ip: str) -> Tuple[Tuple[str, int, int], int]:
+ ipaddr = ipaddress.ip_address(ip)
+ for adapter in adapters:
+ for adapter_ip in adapter.ips:
+ # IPv6 addresses are represented as tuples
+ if isinstance(adapter_ip.ip, tuple) and ipaddress.ip_address(adapter_ip.ip[0]) == ipaddr:
+ return (cast(Tuple[str, int, int], adapter_ip.ip), cast(int, adapter.index))
+
+ raise RuntimeError('No adapter found for IP address %s' % ip)
+
+
+def interface_index_to_ip6_address(adapters: List[Any], index: int) -> Tuple[str, int, int]:
+ for adapter in adapters:
+ if adapter.index == index:
+ for adapter_ip in adapter.ips:
+ # IPv6 addresses are represented as tuples
+ if isinstance(adapter_ip.ip, tuple):
+ return cast(Tuple[str, int, int], adapter_ip.ip)
+
+ raise RuntimeError('No adapter found for index %s' % index)
+
+
+def ip6_addresses_to_indexes(
+ interfaces: List[Union[str, int, Tuple[Tuple[str, int, int], int]]]
+) -> List[Tuple[Tuple[str, int, int], int]]:
+ """Convert IPv6 interface addresses to interface indexes.
+
+ IPv4 addresses are ignored.
+
+ :param interfaces: List of IP addresses and indexes.
+ :returns: List of indexes.
+ """
+ result = []
+ adapters = ifaddr.get_adapters()
+
+ for iface in interfaces:
+ if isinstance(iface, int):
+ result.append((interface_index_to_ip6_address(adapters, iface), iface))
+ elif isinstance(iface, str) and ipaddress.ip_address(iface).version == 6:
+ result.append(ip6_to_address_and_index(adapters, iface))
+
+ return result
+
+
+def normalize_interface_choice(
+ choice: InterfacesType, ip_version: IPVersion = IPVersion.V4Only
+) -> List[Union[str, Tuple[Tuple[str, int, int], int]]]:
+ """Convert the interfaces choice into internal representation.
+
+ :param choice: `InterfaceChoice` or list of interface addresses or indexes (IPv6 only).
+ :param ip_address: IP version to use (ignored if `choice` is a list).
+ :returns: List of IP addresses (for IPv4) and indexes (for IPv6).
+ """
+ result: List[Union[str, Tuple[Tuple[str, int, int], int]]] = []
+ if choice is InterfaceChoice.Default:
+ if ip_version != IPVersion.V4Only:
+ # IPv6 multicast uses interface 0 to mean the default
+ result.append((('', 0, 0), 0))
+ if ip_version != IPVersion.V6Only:
+ result.append('0.0.0.0')
+ elif choice is InterfaceChoice.All:
+ if ip_version != IPVersion.V4Only:
+ result.extend(get_all_addresses_v6())
+ if ip_version != IPVersion.V6Only:
+ result.extend(get_all_addresses())
+ if not result:
+ raise RuntimeError(
+ 'No interfaces to listen on, check that any interfaces have IP version %s' % ip_version
+ )
+ elif isinstance(choice, list):
+ # First, take IPv4 addresses.
+ result = [i for i in choice if isinstance(i, str) and ipaddress.ip_address(i).version == 4]
+ # Unlike IP_ADD_MEMBERSHIP, IPV6_JOIN_GROUP requires interface indexes.
+ result += ip6_addresses_to_indexes(choice)
+ else:
+ raise TypeError("choice must be a list or InterfaceChoice, got %r" % choice)
+ return result
+
+
+def disable_ipv6_only_or_raise(s: socket.socket) -> None:
+ """Make V6 sockets work for both V4 and V6 (required for Windows)."""
+ try:
+ s.setsockopt(_IPPROTO_IPV6, socket.IPV6_V6ONLY, False)
+ except OSError:
+ log.error('Support for dual V4-V6 sockets is not present, use IPVersion.V4 or IPVersion.V6')
+ raise
+
+
+def set_so_reuseport_if_available(s: socket.socket) -> None:
+ """Set SO_REUSEADDR on a socket if available."""
+ # SO_REUSEADDR should be equivalent to SO_REUSEPORT for
+ # multicast UDP sockets (p 731, "TCP/IP Illustrated,
+ # Volume 2"), but some BSD-derived systems require
+ # SO_REUSEPORT to be specified explicitly. Also, not all
+ # versions of Python have SO_REUSEPORT available.
+ # Catch OSError and socket.error for kernel versions <3.9 because lacking
+ # SO_REUSEPORT support.
+ if not hasattr(socket, 'SO_REUSEPORT'):
+ return
+
+ try:
+ s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1) # pylint: disable=no-member
+ except OSError as err:
+ if err.errno != errno.ENOPROTOOPT:
+ raise
+
+
+def set_mdns_port_socket_options_for_ip_version(
+ s: socket.socket, bind_addr: Union[Tuple[str], Tuple[str, int, int]], ip_version: IPVersion
+) -> None:
+ """Set ttl/hops and loop for mdns port."""
+ if ip_version != IPVersion.V6Only:
+ ttl = struct.pack(b'B', 255)
+ loop = struct.pack(b'B', 1)
+ # OpenBSD needs the ttl and loop values for the IP_MULTICAST_TTL and
+ # IP_MULTICAST_LOOP socket options as an unsigned char.
+ try:
+ s.setsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_TTL, ttl)
+ s.setsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_LOOP, loop)
+ except socket.error as e:
+ if bind_addr[0] != '' or get_errno(e) != errno.EINVAL: # Fails to set on MacOS
+ raise
+
+ if ip_version != IPVersion.V4Only:
+ # However, char doesn't work here (at least on Linux)
+ s.setsockopt(_IPPROTO_IPV6, socket.IPV6_MULTICAST_HOPS, 255)
+ s.setsockopt(_IPPROTO_IPV6, socket.IPV6_MULTICAST_LOOP, True)
+
+
+def new_socket(
+ bind_addr: Union[Tuple[str], Tuple[str, int, int]],
+ port: int = _MDNS_PORT,
+ ip_version: IPVersion = IPVersion.V4Only,
+ apple_p2p: bool = False,
+) -> socket.socket:
+ log.debug(
+ 'Creating new socket with port %s, ip_version %s, apple_p2p %s and bind_addr %r',
+ port,
+ ip_version,
+ apple_p2p,
+ bind_addr,
+ )
+ socket_family = socket.AF_INET if ip_version == IPVersion.V4Only else socket.AF_INET6
+ s = socket.socket(socket_family, socket.SOCK_DGRAM)
+
+ if ip_version == IPVersion.All:
+ disable_ipv6_only_or_raise(s)
+
+ s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+ set_so_reuseport_if_available(s)
+
+ if port == _MDNS_PORT:
+ set_mdns_port_socket_options_for_ip_version(s, bind_addr, ip_version)
+
+ if apple_p2p:
+ # SO_RECV_ANYIF = 0x1104
+ # https://opensource.apple.com/source/xnu/xnu-4570.41.2/bsd/sys/socket.h
+ s.setsockopt(socket.SOL_SOCKET, 0x1104, 1)
+
+ s.bind((bind_addr[0], port, *bind_addr[1:]))
+ log.debug('Created socket %s', s)
+ return s
+
+
+def add_multicast_member(
+ listen_socket: socket.socket,
+ interface: Union[str, Tuple[Tuple[str, int, int], int]],
+) -> bool:
+ # This is based on assumptions in normalize_interface_choice
+ is_v6 = isinstance(interface, tuple)
+ err_einval = {errno.EINVAL}
+ if sys.platform == 'win32':
+ # No WSAEINVAL definition in typeshed
+ err_einval |= {cast(Any, errno).WSAEINVAL} # pylint: disable=no-member
+ log.debug('Adding %r (socket %d) to multicast group', interface, listen_socket.fileno())
+ try:
+ if is_v6:
+ try:
+ mdns_addr6_bytes = socket.inet_pton(socket.AF_INET6, _MDNS_ADDR6)
+ except OSError:
+ log.info(
+ 'Unable to translate IPv6 address when adding %s to multicast group, '
+ 'this can happen if IPv6 is disabled on the system',
+ interface,
+ )
+ return False
+ iface_bin = struct.pack('@I', cast(int, interface[1]))
+ _value = mdns_addr6_bytes + iface_bin
+ listen_socket.setsockopt(_IPPROTO_IPV6, socket.IPV6_JOIN_GROUP, _value)
+ else:
+ _value = socket.inet_aton(_MDNS_ADDR) + socket.inet_aton(cast(str, interface))
+ listen_socket.setsockopt(socket.IPPROTO_IP, socket.IP_ADD_MEMBERSHIP, _value)
+ except socket.error as e:
+ _errno = get_errno(e)
+ if _errno == errno.EADDRINUSE:
+ log.info(
+ 'Address in use when adding %s to multicast group, '
+ 'it is expected to happen on some systems',
+ interface,
+ )
+ return False
+ if _errno == errno.EADDRNOTAVAIL:
+ log.info(
+ 'Address not available when adding %s to multicast '
+ 'group, it is expected to happen on some systems',
+ interface,
+ )
+ return False
+ if _errno in err_einval:
+ log.info('Interface of %s does not support multicast, ' 'it is expected in WSL', interface)
+ return False
+ if _errno == errno.ENOPROTOOPT:
+ log.info(
+ 'Failed to set socket option on %s, this can happen if '
+ 'the network adapter is in a disconnected state',
+ interface,
+ )
+ return False
+ if is_v6 and _errno == errno.ENODEV:
+ log.info(
+ 'Address in use when adding %s to multicast group, '
+ 'it is expected to happen when the device does not have ipv6',
+ interface,
+ )
+ return False
+ raise
+ return True
+
+
+def new_respond_socket(
+ interface: Union[str, Tuple[Tuple[str, int, int], int]],
+ apple_p2p: bool = False,
+) -> Optional[socket.socket]:
+ is_v6 = isinstance(interface, tuple)
+ respond_socket = new_socket(
+ ip_version=(IPVersion.V6Only if is_v6 else IPVersion.V4Only),
+ apple_p2p=apple_p2p,
+ bind_addr=cast(Tuple[Tuple[str, int, int], int], interface)[0] if is_v6 else (cast(str, interface),),
+ )
+ log.debug('Configuring socket %s with multicast interface %s', respond_socket, interface)
+ if is_v6:
+ iface_bin = struct.pack('@I', cast(int, interface[1]))
+ respond_socket.setsockopt(_IPPROTO_IPV6, socket.IPV6_MULTICAST_IF, iface_bin)
+ else:
+ respond_socket.setsockopt(
+ socket.IPPROTO_IP, socket.IP_MULTICAST_IF, socket.inet_aton(cast(str, interface))
+ )
+ return respond_socket
+
+
+def create_sockets(
+ interfaces: InterfacesType = InterfaceChoice.All,
+ unicast: bool = False,
+ ip_version: IPVersion = IPVersion.V4Only,
+ apple_p2p: bool = False,
+) -> Tuple[Optional[socket.socket], List[socket.socket]]:
+ if unicast:
+ listen_socket = None
+ else:
+ listen_socket = new_socket(ip_version=ip_version, apple_p2p=apple_p2p, bind_addr=('',))
+
+ normalized_interfaces = normalize_interface_choice(interfaces, ip_version)
+
+ # If we are using InterfaceChoice.Default we can use
+ # a single socket to listen and respond.
+ if not unicast and interfaces is InterfaceChoice.Default:
+ for i in normalized_interfaces:
+ add_multicast_member(cast(socket.socket, listen_socket), i)
+ return listen_socket, [cast(socket.socket, listen_socket)]
+
+ respond_sockets = []
+
+ for i in normalized_interfaces:
+ if not unicast:
+ if add_multicast_member(cast(socket.socket, listen_socket), i):
+ respond_socket = new_respond_socket(i, apple_p2p=apple_p2p)
+ else:
+ respond_socket = None
+ else:
+ respond_socket = new_socket(
+ port=0,
+ ip_version=ip_version,
+ apple_p2p=apple_p2p,
+ bind_addr=i[0] if isinstance(i, tuple) else (i,),
+ )
+
+ if respond_socket is not None:
+ respond_sockets.append(respond_socket)
+
+ return listen_socket, respond_sockets
+
+
+def get_errno(e: Exception) -> int:
+ assert isinstance(e, socket.error)
+ return cast(int, e.args[0])
+
+
+def can_send_to(ipv6_socket: bool, address: str) -> bool:
+ """Check if the address type matches the socket type.
+
+ This function does not validate if the address is a valid
+ ipv6 or ipv4 address.
+ """
+ return ":" in address if ipv6_socket else ":" not in address
+
+
+def autodetect_ip_version(interfaces: InterfacesType) -> IPVersion:
+ """Auto detect the IP version when it is not provided."""
+ if isinstance(interfaces, list):
+ has_v6 = any(
+ isinstance(i, int) or (isinstance(i, str) and ipaddress.ip_address(i).version == 6)
+ for i in interfaces
+ )
+ has_v4 = any(isinstance(i, str) and ipaddress.ip_address(i).version == 4 for i in interfaces)
+ if has_v4 and has_v6:
+ return IPVersion.All
+ if has_v6:
+ return IPVersion.V6Only
+
+ return IPVersion.V4Only
diff --git a/zeroconf/_utils/time.py b/zeroconf/_utils/time.py
new file mode 100644
index 00000000..0ba91ead
--- /dev/null
+++ b/zeroconf/_utils/time.py
@@ -0,0 +1,34 @@
+""" Multicast DNS Service Discovery for Python, v0.14-wmcbrine
+ Copyright 2003 Paul Scott-Murphy, 2014 William McBrine
+
+ This module provides a framework for the use of DNS Service Discovery
+ using IP multicast.
+
+ This library is free software; you can redistribute it and/or
+ modify it under the terms of the GNU Lesser General Public
+ License as published by the Free Software Foundation; either
+ version 2.1 of the License, or (at your option) any later version.
+
+ This library is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+ Lesser General Public License for more details.
+
+ You should have received a copy of the GNU Lesser General Public
+ License along with this library; if not, write to the Free Software
+ Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301
+ USA
+"""
+
+
+import time
+
+
+def current_time_millis() -> float:
+ """Current system time in milliseconds"""
+ return time.time() * 1000
+
+
+def millis_to_seconds(millis: float) -> float:
+ """Convert milliseconds to seconds."""
+ return millis / 1000.0
diff --git a/zeroconf/asyncio.py b/zeroconf/asyncio.py
new file mode 100644
index 00000000..ef7e7f64
--- /dev/null
+++ b/zeroconf/asyncio.py
@@ -0,0 +1,267 @@
+""" Multicast DNS Service Discovery for Python, v0.14-wmcbrine
+ Copyright 2003 Paul Scott-Murphy, 2014 William McBrine
+
+ This module provides a framework for the use of DNS Service Discovery
+ using IP multicast.
+
+ This library is free software; you can redistribute it and/or
+ modify it under the terms of the GNU Lesser General Public
+ License as published by the Free Software Foundation; either
+ version 2.1 of the License, or (at your option) any later version.
+
+ This library is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+ Lesser General Public License for more details.
+
+ You should have received a copy of the GNU Lesser General Public
+ License along with this library; if not, write to the Free Software
+ Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301
+ USA
+"""
+import asyncio
+import contextlib
+from types import TracebackType # noqa # used in type hints
+from typing import Awaitable, Callable, Dict, List, Optional, Tuple, Type, Union
+
+from ._core import Zeroconf
+from ._dns import DNSQuestionType
+from ._services import ServiceListener
+from ._services.browser import _ServiceBrowserBase
+from ._services.info import ServiceInfo
+from ._services.types import ZeroconfServiceTypes
+from ._utils.net import IPVersion, InterfaceChoice, InterfacesType
+from .const import (
+ _BROWSER_TIME,
+ _MDNS_PORT,
+ _SERVICE_TYPE_ENUMERATION_NAME,
+)
+
+
+__all__ = [
+ "AsyncZeroconf",
+ "AsyncServiceInfo",
+ "AsyncServiceBrowser",
+ "AsyncZeroconfServiceTypes",
+]
+
+
+class AsyncServiceInfo(ServiceInfo):
+ """An async version of ServiceInfo."""
+
+
+class AsyncServiceBrowser(_ServiceBrowserBase):
+ """Used to browse for a service for specific type(s).
+
+ Constructor parameters are as follows:
+
+ * `zc`: A Zeroconf instance
+ * `type_`: fully qualified service type name
+ * `handler`: ServiceListener or Callable that knows how to process ServiceStateChange events
+ * `listener`: ServiceListener
+ * `addr`: address to send queries (will default to multicast)
+ * `port`: port to send queries (will default to mdns 5353)
+ * `delay`: The initial delay between answering questions
+ * `question_type`: The type of questions to ask (DNSQuestionType.QM or DNSQuestionType.QU)
+
+ The listener object will have its add_service() and
+ remove_service() methods called when this browser
+ discovers changes in the services availability.
+ """
+
+ def __init__(
+ self,
+ zeroconf: 'Zeroconf',
+ type_: Union[str, list],
+ handlers: Optional[Union[ServiceListener, List[Callable[..., None]]]] = None,
+ listener: Optional[ServiceListener] = None,
+ addr: Optional[str] = None,
+ port: int = _MDNS_PORT,
+ delay: int = _BROWSER_TIME,
+ question_type: Optional[DNSQuestionType] = None,
+ ) -> None:
+ super().__init__(zeroconf, type_, handlers, listener, addr, port, delay, question_type)
+ self._async_start()
+
+ async def async_cancel(self) -> None:
+ """Cancel the browser."""
+ self._async_cancel()
+
+
+class AsyncZeroconfServiceTypes(ZeroconfServiceTypes):
+ """An async version of ZeroconfServiceTypes."""
+
+ @classmethod
+ async def async_find(
+ cls,
+ aiozc: Optional['AsyncZeroconf'] = None,
+ timeout: Union[int, float] = 5,
+ interfaces: InterfacesType = InterfaceChoice.All,
+ ip_version: Optional[IPVersion] = None,
+ ) -> Tuple[str, ...]:
+ """
+ Return all of the advertised services on any local networks.
+
+ :param aiozc: AsyncZeroconf() instance. Pass in if already have an
+ instance running or if non-default interfaces are needed
+ :param timeout: seconds to wait for any responses
+ :param interfaces: interfaces to listen on.
+ :param ip_version: IP protocol version to use.
+ :return: tuple of service type strings
+ """
+ local_zc = aiozc or AsyncZeroconf(interfaces=interfaces, ip_version=ip_version)
+ listener = cls()
+ async_browser = AsyncServiceBrowser(
+ local_zc.zeroconf, _SERVICE_TYPE_ENUMERATION_NAME, listener=listener
+ )
+
+ # wait for responses
+ await asyncio.sleep(timeout)
+
+ await async_browser.async_cancel()
+
+ # close down anything we opened
+ if aiozc is None:
+ await local_zc.async_close()
+
+ return tuple(sorted(listener.found_services))
+
+
+class AsyncZeroconf:
+ """Implementation of Zeroconf Multicast DNS Service Discovery
+
+ Supports registration, unregistration, queries and browsing.
+
+ The async version is currently a wrapper around the sync version
+ with I/O being done in the executor for backwards compatibility.
+ """
+
+ def __init__(
+ self,
+ interfaces: InterfacesType = InterfaceChoice.All,
+ unicast: bool = False,
+ ip_version: Optional[IPVersion] = None,
+ apple_p2p: bool = False,
+ zc: Optional[Zeroconf] = None,
+ ) -> None:
+ """Creates an instance of the Zeroconf class, establishing
+ multicast communications, listening and reaping threads.
+
+ :param interfaces: :class:`InterfaceChoice` or a list of IP addresses
+ (IPv4 and IPv6) and interface indexes (IPv6 only).
+
+ IPv6 notes for non-POSIX systems:
+ * `InterfaceChoice.All` is an alias for `InterfaceChoice.Default`
+ on Python versions before 3.8.
+
+ Also listening on loopback (``::1``) doesn't work, use a real address.
+ :param ip_version: IP versions to support. If `choice` is a list, the default is detected
+ from it. Otherwise defaults to V4 only for backward compatibility.
+ :param apple_p2p: use AWDL interface (only macOS)
+ """
+ self.zeroconf = zc or Zeroconf(
+ interfaces=interfaces,
+ unicast=unicast,
+ ip_version=ip_version,
+ apple_p2p=apple_p2p,
+ )
+ self.async_browsers: Dict[ServiceListener, AsyncServiceBrowser] = {}
+
+ async def async_register_service(
+ self,
+ info: ServiceInfo,
+ ttl: Optional[int] = None,
+ allow_name_change: bool = False,
+ cooperating_responders: bool = False,
+ ) -> Awaitable:
+ """Registers service information to the network with a default TTL.
+ Zeroconf will then respond to requests for information for that
+ service. The name of the service may be changed if needed to make
+ it unique on the network. Additionally multiple cooperating responders
+ can register the same service on the network for resilience
+ (if you want this behavior set `cooperating_responders` to `True`).
+
+ The service will be broadcast in a task. This task is returned
+ and therefore can be awaited if necessary.
+ """
+ return await self.zeroconf.async_register_service(
+ info, ttl, allow_name_change, cooperating_responders
+ )
+
+ async def async_unregister_all_services(self) -> None:
+ """Unregister all registered services.
+
+ Unlike async_register_service and async_unregister_service, this
+ method does not return a future and is always expected to be
+ awaited since its only called at shutdown.
+ """
+ await self.zeroconf.async_unregister_all_services()
+
+ async def async_unregister_service(self, info: ServiceInfo) -> Awaitable:
+ """Unregister a service.
+
+ The service will be broadcast in a task. This task is returned
+ and therefore can be awaited if necessary.
+ """
+ return await self.zeroconf.async_unregister_service(info)
+
+ async def async_update_service(self, info: ServiceInfo) -> Awaitable:
+ """Registers service information to the network with a default TTL.
+ Zeroconf will then respond to requests for information for that
+ service.
+
+ The service will be broadcast in a task. This task is returned
+ and therefore can be awaited if necessary.
+ """
+ return await self.zeroconf.async_update_service(info)
+
+ async def async_close(self) -> None:
+ """Ends the background threads, and prevent this instance from
+ servicing further queries."""
+ with contextlib.suppress(asyncio.TimeoutError):
+ await asyncio.wait_for(self.zeroconf.async_wait_for_start(), timeout=1)
+ await self.async_remove_all_service_listeners()
+ await self.async_unregister_all_services()
+ await self.zeroconf._async_close() # pylint: disable=protected-access
+
+ async def async_get_service_info(
+ self, type_: str, name: str, timeout: int = 3000, question_type: Optional[DNSQuestionType] = None
+ ) -> Optional[AsyncServiceInfo]:
+ """Returns network's service information for a particular
+ name and type, or None if no service matches by the timeout,
+ which defaults to 3 seconds."""
+ info = AsyncServiceInfo(type_, name)
+ if await info.async_request(self.zeroconf, timeout, question_type):
+ return info
+ return None
+
+ async def async_add_service_listener(self, type_: str, listener: ServiceListener) -> None:
+ """Adds a listener for a particular service type. This object
+ will then have its add_service and remove_service methods called when
+ services of that type become available and unavailable."""
+ await self.async_remove_service_listener(listener)
+ self.async_browsers[listener] = AsyncServiceBrowser(self.zeroconf, type_, listener)
+
+ async def async_remove_service_listener(self, listener: ServiceListener) -> None:
+ """Removes a listener from the set that is currently listening."""
+ if listener in self.async_browsers:
+ await self.async_browsers[listener].async_cancel()
+ del self.async_browsers[listener]
+
+ async def async_remove_all_service_listeners(self) -> None:
+ """Removes a listener from the set that is currently listening."""
+ await asyncio.gather(
+ *(self.async_remove_service_listener(listener) for listener in list(self.async_browsers))
+ )
+
+ async def __aenter__(self) -> 'AsyncZeroconf':
+ return self
+
+ async def __aexit__(
+ self,
+ exc_type: Optional[Type[BaseException]],
+ exc_val: Optional[BaseException],
+ exc_tb: Optional[TracebackType],
+ ) -> Optional[bool]:
+ await self.async_close()
+ return None
diff --git a/zeroconf/const.py b/zeroconf/const.py
new file mode 100644
index 00000000..ff5cc3a2
--- /dev/null
+++ b/zeroconf/const.py
@@ -0,0 +1,159 @@
+""" Multicast DNS Service Discovery for Python, v0.14-wmcbrine
+ Copyright 2003 Paul Scott-Murphy, 2014 William McBrine
+
+ This module provides a framework for the use of DNS Service Discovery
+ using IP multicast.
+
+ This library is free software; you can redistribute it and/or
+ modify it under the terms of the GNU Lesser General Public
+ License as published by the Free Software Foundation; either
+ version 2.1 of the License, or (at your option) any later version.
+
+ This library is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+ Lesser General Public License for more details.
+
+ You should have received a copy of the GNU Lesser General Public
+ License along with this library; if not, write to the Free Software
+ Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301
+ USA
+"""
+
+import re
+import socket
+
+# Some timing constants
+
+_UNREGISTER_TIME = 125 # ms
+_CHECK_TIME = 175 # ms
+_REGISTER_TIME = 225 # ms
+_LISTENER_TIME = 200 # ms
+_BROWSER_TIME = 1000 # ms
+_DUPLICATE_QUESTION_INTERVAL = _BROWSER_TIME - 1 # ms
+_BROWSER_BACKOFF_LIMIT = 3600 # s
+_CACHE_CLEANUP_INTERVAL = 10000 # ms
+_LOADED_SYSTEM_TIMEOUT = 10 # s
+_ONE_SECOND = 1000 # ms
+
+# If the system is loaded or the event
+# loop was blocked by another task that was doing I/O in the loop
+# (shouldn't happen but it does in practice) we need to give
+# a buffer timeout to ensure a coroutine can finish before
+# the future times out
+
+# Some DNS constants
+
+_MDNS_ADDR = '224.0.0.251'
+_MDNS_ADDR6 = 'ff02::fb'
+_MDNS_PORT = 5353
+_DNS_PORT = 53
+_DNS_HOST_TTL = 120 # two minute for host records (A, SRV etc) as-per RFC6762
+_DNS_OTHER_TTL = 4500 # 75 minutes for non-host records (PTR, TXT etc) as-per RFC6762
+# Currently we enforce a minimum TTL for PTR records to avoid
+# ServiceBrowsers generating excessive queries refresh queries.
+# Apple uses a 15s minimum TTL, however we do not have the same
+# level of rate limit and safe guards so we use 1/4 of the recommended value
+_DNS_PTR_MIN_TTL = _DNS_OTHER_TTL / 4
+
+_DNS_PACKET_HEADER_LEN = 12
+
+_MAX_MSG_TYPICAL = 1460 # unused
+_MAX_MSG_ABSOLUTE = 8966
+
+_FLAGS_QR_MASK = 0x8000 # query response mask
+_FLAGS_QR_QUERY = 0x0000 # query
+_FLAGS_QR_RESPONSE = 0x8000 # response
+
+_FLAGS_AA = 0x0400 # Authoritative answer
+_FLAGS_TC = 0x0200 # Truncated
+_FLAGS_RD = 0x0100 # Recursion desired
+_FLAGS_RA = 0x8000 # Recursion available
+
+_FLAGS_Z = 0x0040 # Zero
+_FLAGS_AD = 0x0020 # Authentic data
+_FLAGS_CD = 0x0010 # Checking disabled
+
+_CLASS_IN = 1
+_CLASS_CS = 2
+_CLASS_CH = 3
+_CLASS_HS = 4
+_CLASS_NONE = 254
+_CLASS_ANY = 255
+_CLASS_MASK = 0x7FFF
+_CLASS_UNIQUE = 0x8000
+
+_TYPE_A = 1
+_TYPE_NS = 2
+_TYPE_MD = 3
+_TYPE_MF = 4
+_TYPE_CNAME = 5
+_TYPE_SOA = 6
+_TYPE_MB = 7
+_TYPE_MG = 8
+_TYPE_MR = 9
+_TYPE_NULL = 10
+_TYPE_WKS = 11
+_TYPE_PTR = 12
+_TYPE_HINFO = 13
+_TYPE_MINFO = 14
+_TYPE_MX = 15
+_TYPE_TXT = 16
+_TYPE_AAAA = 28
+_TYPE_SRV = 33
+_TYPE_NSEC = 47
+_TYPE_ANY = 255
+
+# Mapping constants to names
+
+_CLASSES = {
+ _CLASS_IN: "in",
+ _CLASS_CS: "cs",
+ _CLASS_CH: "ch",
+ _CLASS_HS: "hs",
+ _CLASS_NONE: "none",
+ _CLASS_ANY: "any",
+}
+
+_TYPES = {
+ _TYPE_A: "a",
+ _TYPE_NS: "ns",
+ _TYPE_MD: "md",
+ _TYPE_MF: "mf",
+ _TYPE_CNAME: "cname",
+ _TYPE_SOA: "soa",
+ _TYPE_MB: "mb",
+ _TYPE_MG: "mg",
+ _TYPE_MR: "mr",
+ _TYPE_NULL: "null",
+ _TYPE_WKS: "wks",
+ _TYPE_PTR: "ptr",
+ _TYPE_HINFO: "hinfo",
+ _TYPE_MINFO: "minfo",
+ _TYPE_MX: "mx",
+ _TYPE_TXT: "txt",
+ _TYPE_AAAA: "quada",
+ _TYPE_SRV: "srv",
+ _TYPE_ANY: "any",
+ _TYPE_NSEC: "nsec",
+}
+
+_HAS_A_TO_Z = re.compile(r'[A-Za-z]')
+_HAS_ONLY_A_TO_Z_NUM_HYPHEN = re.compile(r'^[A-Za-z0-9\-]+$')
+_HAS_ONLY_A_TO_Z_NUM_HYPHEN_UNDERSCORE = re.compile(r'^[A-Za-z0-9\-\_]+$')
+_HAS_ASCII_CONTROL_CHARS = re.compile(r'[\x00-\x1f\x7f]')
+
+_EXPIRE_REFRESH_TIME_PERCENT = 75
+
+_LOCAL_TRAILER = '.local.'
+_TCP_PROTOCOL_LOCAL_TRAILER = '._tcp.local.'
+_NONTCP_PROTOCOL_LOCAL_TRAILER = '._udp.local.'
+
+# https://datatracker.ietf.org/doc/html/rfc6763#section-9
+_SERVICE_TYPE_ENUMERATION_NAME = "_services._dns-sd._udp.local."
+
+try:
+ _IPPROTO_IPV6 = socket.IPPROTO_IPV6
+except AttributeError:
+ # Sigh: https://bugs.python.org/issue29515
+ _IPPROTO_IPV6 = 41
diff --git a/zeroconf/test.py b/zeroconf/test.py
deleted file mode 100644
index 5a89c9f6..00000000
--- a/zeroconf/test.py
+++ /dev/null
@@ -1,1370 +0,0 @@
-#!/usr/bin/env python
-# -*- coding: utf-8 -*-
-
-
-""" Unit tests for zeroconf.py """
-
-import copy
-import logging
-import socket
-import struct
-import time
-import unittest
-from threading import Event
-from typing import Dict, Optional # noqa # used in type hints
-from typing import cast
-
-from nose.plugins.attrib import attr
-
-import zeroconf as r
-from zeroconf import (
- DNSHinfo,
- DNSText,
- ServiceBrowser,
- ServiceInfo,
- ServiceStateChange,
- Zeroconf,
- ZeroconfServiceTypes,
-)
-
-log = logging.getLogger('zeroconf')
-original_logging_level = logging.NOTSET
-
-
-def setup_module():
- global original_logging_level
- original_logging_level = log.level
- log.setLevel(logging.DEBUG)
-
-
-def teardown_module():
- if original_logging_level != logging.NOTSET:
- log.setLevel(original_logging_level)
-
-
-class TestDunder(unittest.TestCase):
- def test_dns_text_repr(self):
- # There was an issue on Python 3 that prevented DNSText's repr
- # from working when the text was longer than 10 bytes
- text = DNSText('irrelevant', 0, 0, 0, b'12345678901')
- repr(text)
-
- text = DNSText('irrelevant', 0, 0, 0, b'123')
- repr(text)
-
- def test_dns_hinfo_repr_eq(self):
- hinfo = DNSHinfo('irrelevant', r._TYPE_HINFO, 0, 0, 'cpu', 'os')
- assert hinfo == hinfo
- repr(hinfo)
-
- def test_dns_pointer_repr(self):
- pointer = r.DNSPointer('irrelevant', r._TYPE_PTR, r._CLASS_IN, r._DNS_OTHER_TTL, '123')
- repr(pointer)
-
- def test_dns_address_repr(self):
- address = r.DNSAddress('irrelevant', r._TYPE_SOA, r._CLASS_IN, 1, b'a')
- repr(address)
-
- def test_dns_question_repr(self):
- question = r.DNSQuestion('irrelevant', r._TYPE_SRV, r._CLASS_IN | r._CLASS_UNIQUE)
- repr(question)
- assert not question != question
-
- def test_dns_service_repr(self):
- service = r.DNSService('irrelevant', r._TYPE_SRV, r._CLASS_IN, r._DNS_HOST_TTL, 0, 0, 80, 'a')
- repr(service)
-
- def test_dns_record_abc(self):
- record = r.DNSRecord('irrelevant', r._TYPE_SRV, r._CLASS_IN, r._DNS_HOST_TTL)
- self.assertRaises(r.AbstractMethodException, record.__eq__, record)
- self.assertRaises(r.AbstractMethodException, record.write, None)
-
- def test_dns_record_reset_ttl(self):
- record = r.DNSRecord('irrelevant', r._TYPE_SRV, r._CLASS_IN, r._DNS_HOST_TTL)
- time.sleep(1)
- record2 = r.DNSRecord('irrelevant', r._TYPE_SRV, r._CLASS_IN, r._DNS_HOST_TTL)
- now = r.current_time_millis()
-
- assert record.created != record2.created
- assert record.get_remaining_ttl(now) != record2.get_remaining_ttl(now)
-
- record.reset_ttl(record2)
-
- assert record.ttl == record2.ttl
- assert record.created == record2.created
- assert record.get_remaining_ttl(now) == record2.get_remaining_ttl(now)
-
- def test_service_info_dunder(self):
- type_ = "_test-srvc-type._tcp.local."
- name = "xxxyyy"
- registration_name = "%s.%s" % (name, type_)
- info = ServiceInfo(
- type_, registration_name, socket.inet_aton("10.0.1.2"), 80, 0, 0, b'', "ash-2.local."
- )
-
- assert not info != info
- repr(info)
-
- def test_service_info_text_properties_not_given(self):
- type_ = "_test-srvc-type._tcp.local."
- name = "xxxyyy"
- registration_name = "%s.%s" % (name, type_)
- info = ServiceInfo(
- type_=type_,
- name=registration_name,
- address=socket.inet_aton("10.0.1.2"),
- port=80,
- server="ash-2.local.",
- )
-
- assert isinstance(info.text, bytes)
- repr(info)
-
- def test_dns_outgoing_repr(self):
- dns_outgoing = r.DNSOutgoing(r._FLAGS_QR_QUERY)
- repr(dns_outgoing)
-
-
-class PacketGeneration(unittest.TestCase):
- def test_parse_own_packet_simple(self):
- generated = r.DNSOutgoing(0)
- r.DNSIncoming(generated.packet())
-
- def test_parse_own_packet_simple_unicast(self):
- generated = r.DNSOutgoing(0, False)
- r.DNSIncoming(generated.packet())
-
- def test_parse_own_packet_flags(self):
- generated = r.DNSOutgoing(r._FLAGS_QR_QUERY)
- r.DNSIncoming(generated.packet())
-
- def test_parse_own_packet_question(self):
- generated = r.DNSOutgoing(r._FLAGS_QR_QUERY)
- generated.add_question(r.DNSQuestion("testname.local.", r._TYPE_SRV, r._CLASS_IN))
- r.DNSIncoming(generated.packet())
-
- def test_parse_own_packet_response(self):
- generated = r.DNSOutgoing(r._FLAGS_QR_RESPONSE)
- generated.add_answer_at_time(
- r.DNSService("æøå.local.", r._TYPE_SRV, r._CLASS_IN, r._DNS_HOST_TTL, 0, 0, 80, "foo.local."), 0
- )
- parsed = r.DNSIncoming(generated.packet())
- self.assertEqual(len(generated.answers), 1)
- self.assertEqual(len(generated.answers), len(parsed.answers))
-
- def test_match_question(self):
- generated = r.DNSOutgoing(r._FLAGS_QR_QUERY)
- question = r.DNSQuestion("testname.local.", r._TYPE_SRV, r._CLASS_IN)
- generated.add_question(question)
- parsed = r.DNSIncoming(generated.packet())
- self.assertEqual(len(generated.questions), 1)
- self.assertEqual(len(generated.questions), len(parsed.questions))
- self.assertEqual(question, parsed.questions[0])
-
- def test_suppress_answer(self):
- query_generated = r.DNSOutgoing(r._FLAGS_QR_QUERY)
- question = r.DNSQuestion("testname.local.", r._TYPE_SRV, r._CLASS_IN)
- query_generated.add_question(question)
- answer1 = r.DNSService(
- "testname1.local.", r._TYPE_SRV, r._CLASS_IN, r._DNS_HOST_TTL, 0, 0, 80, "foo.local."
- )
- staleanswer2 = r.DNSService(
- "testname2.local.", r._TYPE_SRV, r._CLASS_IN, r._DNS_HOST_TTL / 2, 0, 0, 80, "foo.local."
- )
- answer2 = r.DNSService(
- "testname2.local.", r._TYPE_SRV, r._CLASS_IN, r._DNS_HOST_TTL, 0, 0, 80, "foo.local."
- )
- query_generated.add_answer_at_time(answer1, 0)
- query_generated.add_answer_at_time(staleanswer2, 0)
- query = r.DNSIncoming(query_generated.packet())
-
- # Should be suppressed
- response = r.DNSOutgoing(r._FLAGS_QR_RESPONSE)
- response.add_answer(query, answer1)
- assert len(response.answers) == 0
-
- # Should not be suppressed, TTL in query is too short
- response.add_answer(query, answer2)
- assert len(response.answers) == 1
-
- # Should not be suppressed, name is different
- tmp = copy.copy(answer1)
- tmp.name = "testname3.local."
- response.add_answer(query, tmp)
- assert len(response.answers) == 2
-
- # Should not be suppressed, type is different
- tmp = copy.copy(answer1)
- tmp.type = r._TYPE_A
- response.add_answer(query, tmp)
- assert len(response.answers) == 3
-
- # Should not be suppressed, class is different
- tmp = copy.copy(answer1)
- tmp.class_ = r._CLASS_NONE
- response.add_answer(query, tmp)
- assert len(response.answers) == 4
-
- # ::TODO:: could add additional tests for DNSAddress, DNSHinfo, DNSPointer, DNSText, DNSService
-
- def test_dns_hinfo(self):
- generated = r.DNSOutgoing(0)
- generated.add_additional_answer(DNSHinfo('irrelevant', r._TYPE_HINFO, 0, 0, 'cpu', 'os'))
- parsed = r.DNSIncoming(generated.packet())
- answer = cast(r.DNSHinfo, parsed.answers[0])
- self.assertEqual(answer.cpu, u'cpu')
- self.assertEqual(answer.os, u'os')
-
- generated = r.DNSOutgoing(0)
- generated.add_additional_answer(DNSHinfo('irrelevant', r._TYPE_HINFO, 0, 0, 'cpu', 'x' * 257))
- self.assertRaises(r.NamePartTooLongException, generated.packet)
-
-
-class PacketForm(unittest.TestCase):
- def test_transaction_id(self):
- """ID must be zero in a DNS-SD packet"""
- generated = r.DNSOutgoing(r._FLAGS_QR_QUERY)
- bytes = generated.packet()
- id = bytes[0] << 8 | bytes[1]
- self.assertEqual(id, 0)
-
- def test_query_header_bits(self):
- generated = r.DNSOutgoing(r._FLAGS_QR_QUERY)
- bytes = generated.packet()
- flags = bytes[2] << 8 | bytes[3]
- self.assertEqual(flags, 0x0)
-
- def test_response_header_bits(self):
- generated = r.DNSOutgoing(r._FLAGS_QR_RESPONSE)
- bytes = generated.packet()
- flags = bytes[2] << 8 | bytes[3]
- self.assertEqual(flags, 0x8000)
-
- def test_numbers(self):
- generated = r.DNSOutgoing(r._FLAGS_QR_RESPONSE)
- bytes = generated.packet()
- (num_questions, num_answers, num_authorities, num_additionals) = struct.unpack('!4H', bytes[4:12])
- self.assertEqual(num_questions, 0)
- self.assertEqual(num_answers, 0)
- self.assertEqual(num_authorities, 0)
- self.assertEqual(num_additionals, 0)
-
- def test_numbers_questions(self):
- generated = r.DNSOutgoing(r._FLAGS_QR_RESPONSE)
- question = r.DNSQuestion("testname.local.", r._TYPE_SRV, r._CLASS_IN)
- for i in range(10):
- generated.add_question(question)
- bytes = generated.packet()
- (num_questions, num_answers, num_authorities, num_additionals) = struct.unpack('!4H', bytes[4:12])
- self.assertEqual(num_questions, 10)
- self.assertEqual(num_answers, 0)
- self.assertEqual(num_authorities, 0)
- self.assertEqual(num_additionals, 0)
-
-
-class Names(unittest.TestCase):
- def test_long_name(self):
- generated = r.DNSOutgoing(r._FLAGS_QR_RESPONSE)
- question = r.DNSQuestion(
- "this.is.a.very.long.name.with.lots.of.parts.in.it.local.", r._TYPE_SRV, r._CLASS_IN
- )
- generated.add_question(question)
- r.DNSIncoming(generated.packet())
-
- def test_exceedingly_long_name(self):
- generated = r.DNSOutgoing(r._FLAGS_QR_RESPONSE)
- name = "%slocal." % ("part." * 1000)
- question = r.DNSQuestion(name, r._TYPE_SRV, r._CLASS_IN)
- generated.add_question(question)
- r.DNSIncoming(generated.packet())
-
- def test_exceedingly_long_name_part(self):
- name = "%s.local." % ("a" * 1000)
- generated = r.DNSOutgoing(r._FLAGS_QR_RESPONSE)
- question = r.DNSQuestion(name, r._TYPE_SRV, r._CLASS_IN)
- generated.add_question(question)
- self.assertRaises(r.NamePartTooLongException, generated.packet)
-
- def test_same_name(self):
- name = "paired.local."
- generated = r.DNSOutgoing(r._FLAGS_QR_RESPONSE)
- question = r.DNSQuestion(name, r._TYPE_SRV, r._CLASS_IN)
- generated.add_question(question)
- generated.add_question(question)
- r.DNSIncoming(generated.packet())
-
- def test_lots_of_names(self):
-
- # instantiate a zeroconf instance
- zc = Zeroconf(interfaces=['127.0.0.1'])
-
- # create a bunch of servers
- type_ = "_my-service._tcp.local."
- name = 'a wonderful service'
- server_count = 300
- self.generate_many_hosts(zc, type_, name, server_count)
-
- # verify that name changing works
- self.verify_name_change(zc, type_, name, server_count)
-
- # we are going to monkey patch the zeroconf send to check packet sizes
- old_send = zc.send
-
- longest_packet_len = 0
- longest_packet = None # type: Optional[r.DNSOutgoing]
-
- def send(out, addr=r._MDNS_ADDR, port=r._MDNS_PORT):
- """Sends an outgoing packet."""
- packet = out.packet()
- nonlocal longest_packet_len, longest_packet
- if longest_packet_len < len(packet):
- longest_packet_len = len(packet)
- longest_packet = out
- old_send(out, addr=addr, port=port)
-
- # monkey patch the zeroconf send
- setattr(zc, "send", send)
-
- # dummy service callback
- def on_service_state_change(zeroconf, service_type, state_change, name):
- pass
-
- # start a browser
- browser = ServiceBrowser(zc, type_, [on_service_state_change])
-
- # wait until the browse request packet has maxed out in size
- sleep_count = 0
- while sleep_count < 100 and longest_packet_len < r._MAX_MSG_ABSOLUTE - 100:
- sleep_count += 1
- time.sleep(0.1)
-
- browser.cancel()
- time.sleep(0.5)
-
- import zeroconf
-
- zeroconf.log.debug('sleep_count %d, sized %d', sleep_count, longest_packet_len)
-
- # now the browser has sent at least one request, verify the size
- assert longest_packet_len <= r._MAX_MSG_ABSOLUTE
- assert longest_packet_len >= r._MAX_MSG_ABSOLUTE - 100
-
- # mock zeroconf's logger warning() and debug()
- from unittest.mock import patch
-
- patch_warn = patch('zeroconf.log.warning')
- patch_debug = patch('zeroconf.log.debug')
- mocked_log_warn = patch_warn.start()
- mocked_log_debug = patch_debug.start()
-
- # now that we have a long packet in our possession, let's verify the
- # exception handling.
- out = longest_packet
- assert out is not None
- out.data.append(b'\0' * 1000)
-
- # mock the zeroconf logger and check for the correct logging backoff
- call_counts = mocked_log_warn.call_count, mocked_log_debug.call_count
- # try to send an oversized packet
- zc.send(out)
- assert mocked_log_warn.call_count == call_counts[0] + 1
- assert mocked_log_debug.call_count == call_counts[0]
- zc.send(out)
- assert mocked_log_warn.call_count == call_counts[0] + 1
- assert mocked_log_debug.call_count == call_counts[0] + 1
-
- # force a receive of an oversized packet
- packet = out.packet()
- s = zc._respond_sockets[0]
-
- # mock the zeroconf logger and check for the correct logging backoff
- call_counts = mocked_log_warn.call_count, mocked_log_debug.call_count
- # force receive on oversized packet
- s.sendto(packet, 0, (r._MDNS_ADDR, r._MDNS_PORT))
- s.sendto(packet, 0, (r._MDNS_ADDR, r._MDNS_PORT))
- time.sleep(2.0)
- zeroconf.log.debug(
- 'warn %d debug %d was %s', mocked_log_warn.call_count, mocked_log_debug.call_count, call_counts
- )
- assert mocked_log_debug.call_count > call_counts[0]
-
- # close our zeroconf which will close the sockets
- zc.close()
-
- # pop the big chunk off the end of the data and send on a closed socket
- out.data.pop()
- zc._GLOBAL_DONE = False
-
- # mock the zeroconf logger and check for the correct logging backoff
- call_counts = mocked_log_warn.call_count, mocked_log_debug.call_count
- # send on a closed socket (force a socket error)
- zc.send(out)
- zeroconf.log.debug(
- 'warn %d debug %d was %s', mocked_log_warn.call_count, mocked_log_debug.call_count, call_counts
- )
- assert mocked_log_warn.call_count > call_counts[0]
- assert mocked_log_debug.call_count > call_counts[0]
- zc.send(out)
- zeroconf.log.debug(
- 'warn %d debug %d was %s', mocked_log_warn.call_count, mocked_log_debug.call_count, call_counts
- )
- assert mocked_log_debug.call_count > call_counts[0] + 2
-
- mocked_log_warn.stop()
- mocked_log_debug.stop()
-
- def verify_name_change(self, zc, type_, name, number_hosts):
- desc = {'path': '/~paulsm/'}
- info_service = ServiceInfo(
- type_, '%s.%s' % (name, type_), socket.inet_aton("10.0.1.2"), 80, 0, 0, desc, "ash-2.local."
- )
-
- # verify name conflict
- self.assertRaises(r.NonUniqueNameException, zc.register_service, info_service)
-
- zc.register_service(info_service, allow_name_change=True)
- assert info_service.name.split('.')[0] == '%s-%d' % (name, number_hosts + 1)
-
- def generate_many_hosts(self, zc, type_, name, number_hosts):
- records_per_server = 2
- block_size = 25
- number_hosts = int(((number_hosts - 1) / block_size + 1)) * block_size
- for i in range(1, number_hosts + 1):
- next_name = name if i == 1 else '%s-%d' % (name, i)
- self.generate_host(zc, next_name, type_)
- if i % block_size == 0:
- sleep_count = 0
- while sleep_count < 40 and i * records_per_server > len(zc.cache.entries_with_name(type_)):
- sleep_count += 1
- time.sleep(0.05)
-
- @staticmethod
- def generate_host(zc, host_name, type_):
- name = '.'.join((host_name, type_))
- out = r.DNSOutgoing(r._FLAGS_QR_RESPONSE | r._FLAGS_AA)
- out.add_answer_at_time(r.DNSPointer(type_, r._TYPE_PTR, r._CLASS_IN, r._DNS_OTHER_TTL, name), 0)
- out.add_answer_at_time(
- r.DNSService(type_, r._TYPE_SRV, r._CLASS_IN, r._DNS_HOST_TTL, 0, 0, 80, name), 0
- )
- zc.send(out)
-
-
-class Framework(unittest.TestCase):
- def test_launch_and_close(self):
- rv = r.Zeroconf(interfaces=r.InterfaceChoice.All)
- rv.close()
- rv = r.Zeroconf(interfaces=r.InterfaceChoice.Default)
- rv.close()
-
- @unittest.skipIf(not socket.has_ipv6, 'Requires IPv6')
- @attr('IPv6')
- def test_launch_and_close_v4_v6(self):
- rv = r.Zeroconf(interfaces=r.InterfaceChoice.All, ip_version=r.IPVersion.All)
- rv.close()
- rv = r.Zeroconf(interfaces=r.InterfaceChoice.Default, ip_version=r.IPVersion.All)
- rv.close()
-
- @unittest.skipIf(not socket.has_ipv6, 'Requires IPv6')
- @attr('IPv6')
- def test_launch_and_close_v6_only(self):
- rv = r.Zeroconf(interfaces=r.InterfaceChoice.All, ip_version=r.IPVersion.V6Only)
- rv.close()
- rv = r.Zeroconf(interfaces=r.InterfaceChoice.Default, ip_version=r.IPVersion.V6Only)
- rv.close()
-
- def test_handle_response(self):
- def mock_incoming_msg(service_state_change: r.ServiceStateChange) -> r.DNSIncoming:
- ttl = 120
- generated = r.DNSOutgoing(r._FLAGS_QR_RESPONSE)
-
- if service_state_change == r.ServiceStateChange.Updated:
- generated.add_answer_at_time(
- r.DNSText(service_name, r._TYPE_TXT, r._CLASS_IN | r._CLASS_UNIQUE, ttl, service_text), 0
- )
- return r.DNSIncoming(generated.packet())
-
- if service_state_change == r.ServiceStateChange.Removed:
- ttl = 0
-
- generated.add_answer_at_time(
- r.DNSPointer(service_type, r._TYPE_PTR, r._CLASS_IN | r._CLASS_UNIQUE, ttl, service_name), 0
- )
- generated.add_answer_at_time(
- r.DNSService(
- service_name, r._TYPE_SRV, r._CLASS_IN | r._CLASS_UNIQUE, ttl, 0, 0, 80, service_server
- ),
- 0,
- )
- generated.add_answer_at_time(
- r.DNSText(service_name, r._TYPE_TXT, r._CLASS_IN | r._CLASS_UNIQUE, ttl, service_text), 0
- )
- generated.add_answer_at_time(
- r.DNSAddress(
- service_server,
- r._TYPE_A,
- r._CLASS_IN | r._CLASS_UNIQUE,
- ttl,
- socket.inet_aton(service_address),
- ),
- 0,
- )
-
- return r.DNSIncoming(generated.packet())
-
- service_name = 'name._type._tcp.local.'
- service_type = '_type._tcp.local.'
- service_server = 'ash-2.local.'
- service_text = b'path=/~paulsm/'
- service_address = '10.0.1.2'
-
- zeroconf = r.Zeroconf(interfaces=['127.0.0.1'])
-
- try:
- # service added
- zeroconf.handle_response(mock_incoming_msg(r.ServiceStateChange.Added))
- dns_text = zeroconf.cache.get_by_details(service_name, r._TYPE_TXT, r._CLASS_IN)
- assert dns_text is not None
- assert cast(DNSText, dns_text).text == service_text # service_text is b'path=/~paulsm/'
-
- # https://tools.ietf.org/html/rfc6762#section-10.2
- # Instead of merging this new record additively into the cache in addition
- # to any previous records with the same name, rrtype, and rrclass,
- # all old records with that name, rrtype, and rrclass that were received
- # more than one second ago are declared invalid,
- # and marked to expire from the cache in one second.
- time.sleep(1.1)
-
- # service updated. currently only text record can be updated
- service_text = b'path=/~humingchun/'
- zeroconf.handle_response(mock_incoming_msg(r.ServiceStateChange.Updated))
- dns_text = zeroconf.cache.get_by_details(service_name, r._TYPE_TXT, r._CLASS_IN)
- assert dns_text is not None
- assert cast(DNSText, dns_text).text == service_text # service_text is b'path=/~humingchun/'
-
- time.sleep(1.1)
-
- # service removed
- zeroconf.handle_response(mock_incoming_msg(r.ServiceStateChange.Removed))
- dns_text = zeroconf.cache.get_by_details(service_name, r._TYPE_TXT, r._CLASS_IN)
- assert dns_text is None
-
- finally:
- zeroconf.close()
-
-
-class Exceptions(unittest.TestCase):
-
- browser = None # type: Zeroconf
-
- @classmethod
- def setUpClass(cls):
- cls.browser = Zeroconf(interfaces=['127.0.0.1'])
-
- @classmethod
- def tearDownClass(cls):
- cls.browser.close()
- del cls.browser
-
- def test_bad_service_info_name(self):
- self.assertRaises(r.BadTypeInNameException, self.browser.get_service_info, "type", "type_not")
-
- def test_bad_service_names(self):
- bad_names_to_try = (
- '',
- 'local',
- '_tcp.local.',
- '_udp.local.',
- '._udp.local.',
- '_@._tcp.local.',
- '_A@._tcp.local.',
- '_x--x._tcp.local.',
- '_-x._udp.local.',
- '_x-._tcp.local.',
- '_22._udp.local.',
- '_2-2._tcp.local.',
- '_1234567890-abcde._udp.local.',
- '\x00._x._udp.local.',
- )
- for name in bad_names_to_try:
- self.assertRaises(r.BadTypeInNameException, self.browser.get_service_info, name, 'x.' + name)
-
- def test_good_instance_names(self):
- good_names_to_try = (
- '.._x._tcp.local.',
- 'x.sub._http._tcp.local.',
- '6d86f882b90facee9170ad3439d72a4d6ee9f511._zget._http._tcp.local.',
- )
- for name in good_names_to_try:
- r.service_type_name(name)
-
- def test_bad_types(self):
- bad_names_to_try = (
- '._x._tcp.local.',
- 'a' * 64 + '._sub._http._tcp.local.',
- 'a' * 62 + u'â._sub._http._tcp.local.',
- )
- for name in bad_names_to_try:
- self.assertRaises(r.BadTypeInNameException, r.service_type_name, name)
-
- def test_bad_sub_types(self):
- bad_names_to_try = (
- '_sub._http._tcp.local.',
- '._sub._http._tcp.local.',
- '\x7f._sub._http._tcp.local.',
- '\x1f._sub._http._tcp.local.',
- )
- for name in bad_names_to_try:
- self.assertRaises(r.BadTypeInNameException, r.service_type_name, name)
-
- def test_good_service_names(self):
- good_names_to_try = (
- '_x._tcp.local.',
- '_x._udp.local.',
- '_12345-67890-abc._udp.local.',
- 'x._sub._http._tcp.local.',
- 'a' * 63 + '._sub._http._tcp.local.',
- 'a' * 61 + u'â._sub._http._tcp.local.',
- )
- for name in good_names_to_try:
- r.service_type_name(name)
-
- r.service_type_name('_one_two._tcp.local.', allow_underscores=True)
-
- def test_invalid_addresses(self):
- type_ = "_test-srvc-type._tcp.local."
- name = "xxxyyy"
- registration_name = "%s.%s" % (name, type_)
-
- bad = ('127.0.0.1', '::1', 42)
- for addr in bad:
- self.assertRaisesRegex(
- TypeError,
- 'Addresses must be bytes',
- ServiceInfo,
- type_,
- registration_name,
- port=80,
- addresses=[addr],
- )
-
-
-class TestDnsIncoming(unittest.TestCase):
- def test_incoming_exception_handling(self):
- generated = r.DNSOutgoing(0)
- packet = generated.packet()
- packet = packet[:8] + b'deadbeef' + packet[8:]
- parsed = r.DNSIncoming(packet)
- parsed = r.DNSIncoming(packet)
- assert parsed.valid is False
-
- def test_incoming_unknown_type(self):
- generated = r.DNSOutgoing(0)
- answer = r.DNSAddress('a', r._TYPE_SOA, r._CLASS_IN, 1, b'a')
- generated.add_additional_answer(answer)
- packet = generated.packet()
- parsed = r.DNSIncoming(packet)
- assert len(parsed.answers) == 0
- assert parsed.is_query() != parsed.is_response()
-
- def test_incoming_ipv6(self):
- addr = "2606:2800:220:1:248:1893:25c8:1946" # example.com
- packed = socket.inet_pton(socket.AF_INET6, addr)
- generated = r.DNSOutgoing(0)
- answer = r.DNSAddress('domain', r._TYPE_AAAA, r._CLASS_IN, 1, packed)
- generated.add_additional_answer(answer)
- packet = generated.packet()
- parsed = r.DNSIncoming(packet)
- record = parsed.answers[0]
- assert isinstance(record, r.DNSAddress)
- assert record.address == packed
-
-
-class TestRegistrar(unittest.TestCase):
- def test_ttl(self):
-
- # instantiate a zeroconf instance
- zc = Zeroconf(interfaces=['127.0.0.1'])
-
- # service definition
- type_ = "_test-srvc-type._tcp.local."
- name = "xxxyyy"
- registration_name = "%s.%s" % (name, type_)
-
- desc = {'path': '/~paulsm/'}
- info = ServiceInfo(
- type_, registration_name, socket.inet_aton("10.0.1.2"), 80, 0, 0, desc, "ash-2.local."
- )
-
- # we are going to monkey patch the zeroconf send to check packet sizes
- old_send = zc.send
-
- nbr_answers = nbr_additionals = nbr_authorities = 0
-
- def get_ttl(record_type):
- if expected_ttl is not None:
- return expected_ttl
- elif record_type in [r._TYPE_A, r._TYPE_SRV]:
- return r._DNS_HOST_TTL
- else:
- return r._DNS_OTHER_TTL
-
- def send(out, addr=r._MDNS_ADDR, port=r._MDNS_PORT):
- """Sends an outgoing packet."""
- nonlocal nbr_answers, nbr_additionals, nbr_authorities
-
- for answer, time_ in out.answers:
- nbr_answers += 1
- assert answer.ttl == get_ttl(answer.type)
- for answer in out.additionals:
- nbr_additionals += 1
- assert answer.ttl == get_ttl(answer.type)
- for answer in out.authorities:
- nbr_authorities += 1
- assert answer.ttl == get_ttl(answer.type)
- old_send(out, addr=addr, port=port)
-
- # monkey patch the zeroconf send
- setattr(zc, "send", send)
-
- # register service with default TTL
- expected_ttl = None
- zc.register_service(info)
- assert nbr_answers == 12 and nbr_additionals == 0 and nbr_authorities == 3
- nbr_answers = nbr_additionals = nbr_authorities = 0
-
- # query
- query = r.DNSOutgoing(r._FLAGS_QR_QUERY | r._FLAGS_AA)
- query.add_question(r.DNSQuestion(info.type, r._TYPE_PTR, r._CLASS_IN))
- query.add_question(r.DNSQuestion(info.name, r._TYPE_SRV, r._CLASS_IN))
- query.add_question(r.DNSQuestion(info.name, r._TYPE_TXT, r._CLASS_IN))
- query.add_question(r.DNSQuestion(info.server, r._TYPE_A, r._CLASS_IN))
- zc.handle_query(r.DNSIncoming(query.packet()), r._MDNS_ADDR, r._MDNS_PORT)
- assert nbr_answers == 4 and nbr_additionals == 4 and nbr_authorities == 0
- nbr_answers = nbr_additionals = nbr_authorities = 0
-
- # unregister
- expected_ttl = 0
- zc.unregister_service(info)
- assert nbr_answers == 12 and nbr_additionals == 0 and nbr_authorities == 0
- nbr_answers = nbr_additionals = nbr_authorities = 0
-
- # register service with custom TTL
- expected_ttl = r._DNS_HOST_TTL * 2
- assert expected_ttl != r._DNS_HOST_TTL
- zc.register_service(info, ttl=expected_ttl)
- assert nbr_answers == 12 and nbr_additionals == 0 and nbr_authorities == 3
- nbr_answers = nbr_additionals = nbr_authorities = 0
-
- # query
- query = r.DNSOutgoing(r._FLAGS_QR_QUERY | r._FLAGS_AA)
- query.add_question(r.DNSQuestion(info.type, r._TYPE_PTR, r._CLASS_IN))
- query.add_question(r.DNSQuestion(info.name, r._TYPE_SRV, r._CLASS_IN))
- query.add_question(r.DNSQuestion(info.name, r._TYPE_TXT, r._CLASS_IN))
- query.add_question(r.DNSQuestion(info.server, r._TYPE_A, r._CLASS_IN))
- zc.handle_query(r.DNSIncoming(query.packet()), r._MDNS_ADDR, r._MDNS_PORT)
- assert nbr_answers == 4 and nbr_additionals == 4 and nbr_authorities == 0
- nbr_answers = nbr_additionals = nbr_authorities = 0
-
- # unregister
- expected_ttl = 0
- zc.unregister_service(info)
- assert nbr_answers == 12 and nbr_additionals == 0 and nbr_authorities == 0
- nbr_answers = nbr_additionals = nbr_authorities = 0
-
-
-class TestDNSCache(unittest.TestCase):
- def test_order(self):
- record1 = r.DNSAddress('a', r._TYPE_SOA, r._CLASS_IN, 1, b'a')
- record2 = r.DNSAddress('a', r._TYPE_SOA, r._CLASS_IN, 1, b'b')
- cache = r.DNSCache()
- cache.add(record1)
- cache.add(record2)
- entry = r.DNSEntry('a', r._TYPE_SOA, r._CLASS_IN)
- cached_record = cache.get(entry)
- self.assertEqual(cached_record, record2)
-
-
-class ServiceTypesQuery(unittest.TestCase):
- def test_integration_with_listener(self):
-
- type_ = "_test-srvc-type._tcp.local."
- name = "xxxyyy"
- registration_name = "%s.%s" % (name, type_)
-
- zeroconf_registrar = Zeroconf(interfaces=['127.0.0.1'])
- desc = {'path': '/~paulsm/'}
- info = ServiceInfo(
- type_, registration_name, socket.inet_aton("10.0.1.2"), 80, 0, 0, desc, "ash-2.local."
- )
- zeroconf_registrar.register_service(info)
-
- try:
- service_types = ZeroconfServiceTypes.find(interfaces=['127.0.0.1'], timeout=0.5)
- assert type_ in service_types
- service_types = ZeroconfServiceTypes.find(zc=zeroconf_registrar, timeout=0.5)
- assert type_ in service_types
-
- finally:
- zeroconf_registrar.close()
-
- @unittest.skipIf(not socket.has_ipv6, 'Requires IPv6')
- def test_integration_with_listener_v6_records(self):
-
- type_ = "_test-srvc-type._tcp.local."
- name = "xxxyyy"
- registration_name = "%s.%s" % (name, type_)
- addr = "2606:2800:220:1:248:1893:25c8:1946" # example.com
-
- zeroconf_registrar = Zeroconf(interfaces=['127.0.0.1'])
- desc = {'path': '/~paulsm/'}
- info = ServiceInfo(
- type_, registration_name, socket.inet_pton(socket.AF_INET6, addr), 80, 0, 0, desc, "ash-2.local."
- )
- zeroconf_registrar.register_service(info)
-
- try:
- service_types = ZeroconfServiceTypes.find(interfaces=['127.0.0.1'], timeout=0.5)
- assert type_ in service_types
- service_types = ZeroconfServiceTypes.find(zc=zeroconf_registrar, timeout=0.5)
- assert type_ in service_types
-
- finally:
- zeroconf_registrar.close()
-
- @unittest.skipIf(not socket.has_ipv6, 'Requires IPv6')
- @attr('IPv6')
- def test_integration_with_listener_ipv6(self):
-
- type_ = "_test-srvc-type._tcp.local."
- name = "xxxyyy"
- registration_name = "%s.%s" % (name, type_)
-
- zeroconf_registrar = Zeroconf(ip_version=r.IPVersion.V6Only)
- desc = {'path': '/~paulsm/'}
- info = ServiceInfo(
- type_, registration_name, socket.inet_aton("10.0.1.2"), 80, 0, 0, desc, "ash-2.local."
- )
- zeroconf_registrar.register_service(info)
-
- try:
- service_types = ZeroconfServiceTypes.find(ip_version=r.IPVersion.V6Only, timeout=0.5)
- assert type_ in service_types, service_types
- service_types = ZeroconfServiceTypes.find(zc=zeroconf_registrar, timeout=0.5)
- assert type_ in service_types, service_types
-
- finally:
- zeroconf_registrar.close()
-
- def test_integration_with_subtype_and_listener(self):
- subtype_ = "_subtype._sub"
- type_ = "_type._tcp.local."
- name = "xxxyyy"
- # Note: discovery returns only DNS-SD type not subtype
- discovery_type = "%s.%s" % (subtype_, type_)
- registration_name = "%s.%s" % (name, type_)
-
- zeroconf_registrar = Zeroconf(interfaces=['127.0.0.1'])
- desc = {'path': '/~paulsm/'}
- info = ServiceInfo(
- discovery_type, registration_name, socket.inet_aton("10.0.1.2"), 80, 0, 0, desc, "ash-2.local."
- )
- zeroconf_registrar.register_service(info)
-
- try:
- service_types = ZeroconfServiceTypes.find(interfaces=['127.0.0.1'], timeout=0.5)
- assert discovery_type in service_types
- service_types = ZeroconfServiceTypes.find(zc=zeroconf_registrar, timeout=0.5)
- assert discovery_type in service_types
-
- finally:
- zeroconf_registrar.close()
-
-
-class ListenerTest(unittest.TestCase):
- def test_integration_with_listener_class(self):
-
- service_added = Event()
- service_removed = Event()
- service_updated = Event()
-
- subtype_name = "My special Subtype"
- type_ = "_http._tcp.local."
- subtype = subtype_name + "._sub." + type_
- name = "xxxyyyæøå"
- registration_name = "%s.%s" % (name, subtype)
-
- class MyListener(r.ServiceListener):
- def add_service(self, zeroconf, type, name):
- zeroconf.get_service_info(type, name)
- service_added.set()
-
- def remove_service(self, zeroconf, type, name):
- service_removed.set()
-
- def update_service(self, zeroconf, type, name):
- pass
-
- class MySubListener(r.ServiceListener):
- def add_service(self, zeroconf, type, name):
- pass
-
- def remove_service(self, zeroconf, type, name):
- pass
-
- def update_service(self, zeroconf, type, name):
- service_updated.set()
-
- listener = MyListener()
- zeroconf_browser = Zeroconf(interfaces=['127.0.0.1'])
- zeroconf_browser.add_service_listener(subtype, listener)
-
- properties = dict(
- prop_none=None,
- prop_string=b'a_prop',
- prop_float=1.0,
- prop_blank=b'a blanked string',
- prop_true=1,
- prop_false=0,
- )
-
- zeroconf_registrar = Zeroconf(interfaces=['127.0.0.1'])
- desc = {'path': '/~paulsm/'} # type: Dict
- desc.update(properties)
- addresses = [socket.inet_aton("10.0.1.2")]
- if socket.has_ipv6:
- addresses.append(socket.inet_pton(socket.AF_INET6, "2001:db8::1"))
- info_service = ServiceInfo(
- subtype, registration_name, port=80, properties=desc, server="ash-2.local.", addresses=addresses
- )
- zeroconf_registrar.register_service(info_service)
-
- try:
- service_added.wait(1)
- assert service_added.is_set()
-
- # short pause to allow multicast timers to expire
- time.sleep(3)
-
- # clear the answer cache to force query
- for record in zeroconf_browser.cache.entries():
- zeroconf_browser.cache.remove(record)
-
- # get service info without answer cache
- info = zeroconf_browser.get_service_info(type_, registration_name)
- assert info is not None
- assert info.properties[b'prop_none'] is False
- assert info.properties[b'prop_string'] == properties['prop_string']
- assert info.properties[b'prop_float'] is False
- assert info.properties[b'prop_blank'] == properties['prop_blank']
- assert info.properties[b'prop_true'] is True
- assert info.properties[b'prop_false'] is False
- assert info.addresses == addresses[:1] # no V6 by default
- all_addresses = info.addresses_by_version(r.IPVersion.All)
- assert all_addresses == addresses, all_addresses
-
- info = zeroconf_browser.get_service_info(subtype, registration_name)
- assert info is not None
- assert info.properties[b'prop_none'] is False
-
- # Begin material test addition
- sublistener = MySubListener()
- zeroconf_browser.add_service_listener(registration_name, sublistener)
- properties['prop_blank'] = b'an updated string'
- desc.update(properties)
- info_service = ServiceInfo(
- subtype, registration_name, socket.inet_aton("10.0.1.2"), 80, 0, 0, desc, "ash-2.local."
- )
- zeroconf_registrar.update_service(info_service)
- service_updated.wait(1)
- assert service_updated.is_set()
-
- info = zeroconf_browser.get_service_info(type_, registration_name)
- assert info is not None
- assert info.properties[b'prop_blank'] == properties['prop_blank']
- # End material test addition
-
- zeroconf_registrar.unregister_service(info_service)
- service_removed.wait(1)
- assert service_removed.is_set()
-
- finally:
- zeroconf_registrar.close()
- zeroconf_browser.remove_service_listener(listener)
- zeroconf_browser.close()
-
-
-class TestServiceBrowser(unittest.TestCase):
- def test_update_record(self):
-
- service_name = 'name._type._tcp.local.'
- service_type = '_type._tcp.local.'
- service_server = 'ash-2.local.'
- service_text = b'path=/~paulsm/'
- service_address = '10.0.1.2'
-
- service_added = False
- service_removed = False
- service_updated_count = 0
- service_add_event = Event()
- service_removed_event = Event()
- service_updated_event = Event()
-
- class MyServiceListener(r.ServiceListener):
- def add_service(self, zc, type_, name) -> None:
- nonlocal service_added
- service_added = True
- service_add_event.set()
-
- def remove_service(self, zc, type_, name) -> None:
- nonlocal service_added, service_removed
- service_added = False
- service_removed = True
- service_removed_event.set()
-
- def update_service(self, zc, type_, name) -> None:
- nonlocal service_updated_count
- service_updated_count += 1
-
- service_info = zc.get_service_info(type_, name)
- assert service_info.text == service_text
- service_updated_event.set()
-
- def mock_incoming_msg(service_state_change: r.ServiceStateChange) -> r.DNSIncoming:
- ttl = 120
- generated = r.DNSOutgoing(r._FLAGS_QR_RESPONSE)
-
- if service_state_change == r.ServiceStateChange.Updated:
- generated.add_answer_at_time(
- r.DNSText(service_name, r._TYPE_TXT, r._CLASS_IN | r._CLASS_UNIQUE, ttl, service_text), 0
- )
- return r.DNSIncoming(generated.packet())
-
- if service_state_change == r.ServiceStateChange.Removed:
- ttl = 0
-
- generated.add_answer_at_time(
- r.DNSPointer(service_type, r._TYPE_PTR, r._CLASS_IN | r._CLASS_UNIQUE, ttl, service_name), 0
- )
- generated.add_answer_at_time(
- r.DNSService(
- service_name, r._TYPE_SRV, r._CLASS_IN | r._CLASS_UNIQUE, ttl, 0, 0, 80, service_server
- ),
- 0,
- )
- generated.add_answer_at_time(
- r.DNSText(service_name, r._TYPE_TXT, r._CLASS_IN | r._CLASS_UNIQUE, ttl, service_text), 0
- )
- generated.add_answer_at_time(
- r.DNSAddress(
- service_server,
- r._TYPE_A,
- r._CLASS_IN | r._CLASS_UNIQUE,
- ttl,
- socket.inet_aton(service_address),
- ),
- 0,
- )
-
- return r.DNSIncoming(generated.packet())
-
- zeroconf = r.Zeroconf(interfaces=['127.0.0.1'])
- service_browser = r.ServiceBrowser(zeroconf, service_type, listener=MyServiceListener())
-
- try:
- # service added
- zeroconf.handle_response(mock_incoming_msg(r.ServiceStateChange.Added))
- service_add_event.wait(1)
- service_updated_event.wait(1)
- assert service_added is True
- assert service_updated_count == 1
- assert service_removed is False
-
- # service updated. currently only text record can be updated
- service_updated_event.clear()
- service_text = b'path=/~humingchun/'
- zeroconf.handle_response(mock_incoming_msg(r.ServiceStateChange.Updated))
- service_updated_event.wait(1)
- assert service_added is True
- assert service_updated_count == 2
- assert service_removed is False
-
- # service removed
- zeroconf.handle_response(mock_incoming_msg(r.ServiceStateChange.Removed))
- service_removed_event.wait(1)
- assert service_added is False
- assert service_updated_count == 2
- assert service_removed is True
-
- finally:
- service_browser.cancel()
- zeroconf.remove_all_service_listeners()
- zeroconf.close()
-
-
-def test_backoff():
- got_query = Event()
-
- type_ = "_http._tcp.local."
- zeroconf_browser = Zeroconf(interfaces=['127.0.0.1'])
-
- # we are going to monkey patch the zeroconf send to check query transmission
- old_send = zeroconf_browser.send
-
- time_offset = 0.0
- start_time = time.time() * 1000
- initial_query_interval = r._BROWSER_TIME / 1000
-
- def current_time_millis():
- """Current system time in milliseconds"""
- return start_time + time_offset * 1000
-
- def send(out, addr=r._MDNS_ADDR, port=r._MDNS_PORT):
- """Sends an outgoing packet."""
- got_query.set()
- old_send(out, addr=addr, port=port)
-
- # monkey patch the zeroconf send
- setattr(zeroconf_browser, "send", send)
-
- # monkey patch the zeroconf current_time_millis
- r.current_time_millis = current_time_millis
-
- # monkey patch the backoff limit to prevent test running forever
- r._BROWSER_BACKOFF_LIMIT = 10 # seconds
-
- # dummy service callback
- def on_service_state_change(zeroconf, service_type, state_change, name):
- pass
-
- browser = ServiceBrowser(zeroconf_browser, type_, [on_service_state_change])
-
- try:
- # Test that queries are sent at increasing intervals
- sleep_count = 0
- next_query_interval = 0.0
- expected_query_time = 0.0
- while True:
- zeroconf_browser.notify_all()
- sleep_count += 1
- got_query.wait(0.1)
- if time_offset == expected_query_time:
- assert got_query.is_set()
- got_query.clear()
- if next_query_interval == r._BROWSER_BACKOFF_LIMIT:
- # Only need to test up to the point where we've seen a query
- # after the backoff limit has been hit
- break
- elif next_query_interval == 0:
- next_query_interval = initial_query_interval
- expected_query_time = initial_query_interval
- else:
- next_query_interval = min(2 * next_query_interval, r._BROWSER_BACKOFF_LIMIT)
- expected_query_time += next_query_interval
- else:
- assert not got_query.is_set()
- time_offset += initial_query_interval
-
- finally:
- browser.cancel()
- zeroconf_browser.close()
-
-
-def test_integration():
- service_added = Event()
- service_removed = Event()
- unexpected_ttl = Event()
- got_query = Event()
-
- type_ = "_http._tcp.local."
- registration_name = "xxxyyy.%s" % type_
-
- def on_service_state_change(zeroconf, service_type, state_change, name):
- if name == registration_name:
- if state_change is ServiceStateChange.Added:
- service_added.set()
- elif state_change is ServiceStateChange.Removed:
- service_removed.set()
-
- zeroconf_browser = Zeroconf(interfaces=['127.0.0.1'])
-
- # we are going to monkey patch the zeroconf send to check packet sizes
- old_send = zeroconf_browser.send
-
- time_offset = 0.0
-
- def current_time_millis():
- """Current system time in milliseconds"""
- return time.time() * 1000 + time_offset * 1000
-
- expected_ttl = r._DNS_HOST_TTL
-
- nbr_answers = 0
-
- def send(out, addr=r._MDNS_ADDR, port=r._MDNS_PORT):
- """Sends an outgoing packet."""
- pout = r.DNSIncoming(out.packet())
- nonlocal nbr_answers
- for answer in pout.answers:
- nbr_answers += 1
- if not answer.ttl > expected_ttl / 2:
- unexpected_ttl.set()
-
- got_query.set()
- old_send(out, addr=addr, port=port)
-
- # monkey patch the zeroconf send
- setattr(zeroconf_browser, "send", send)
-
- # monkey patch the zeroconf current_time_millis
- r.current_time_millis = current_time_millis
-
- # monkey patch the backoff limit to ensure we always get one query every 1/4 of the DNS TTL
- r._BROWSER_BACKOFF_LIMIT = int(expected_ttl / 4)
-
- service_added = Event()
- service_removed = Event()
-
- browser = ServiceBrowser(zeroconf_browser, type_, [on_service_state_change])
-
- zeroconf_registrar = Zeroconf(interfaces=['127.0.0.1'])
- desc = {'path': '/~paulsm/'}
- info = ServiceInfo(type_, registration_name, socket.inet_aton("10.0.1.2"), 80, 0, 0, desc, "ash-2.local.")
- zeroconf_registrar.register_service(info)
-
- try:
- service_added.wait(1)
- assert service_added.is_set()
-
- # Test that we receive queries containing answers only if the remaining TTL
- # is greater than half the original TTL
- sleep_count = 0
- test_iterations = 50
- while nbr_answers < test_iterations:
- # Increase simulated time shift by 1/4 of the TTL in seconds
- time_offset += expected_ttl / 4
- zeroconf_browser.notify_all()
- sleep_count += 1
- got_query.wait(0.1)
- got_query.clear()
- # Prevent the test running indefinitely in an error condition
- assert sleep_count < test_iterations * 4
- assert not unexpected_ttl.is_set()
-
- # Don't remove service, allow close() to cleanup
-
- finally:
- zeroconf_registrar.close()
- service_removed.wait(1)
- assert service_removed.is_set()
- browser.cancel()
- zeroconf_browser.close()
-
-
-def test_multiple_addresses():
- type_ = "_http._tcp.local."
- registration_name = "xxxyyy.%s" % type_
- desc = {'path': '/~paulsm/'}
- address_parsed = "10.0.1.2"
- address = socket.inet_aton(address_parsed)
-
- # Old way
- info = ServiceInfo(type_, registration_name, address, 80, 0, 0, desc, "ash-2.local.")
-
- assert info.address == address
- assert info.addresses == [address]
-
- # Updating works
- address2 = socket.inet_aton("10.0.1.3")
- info.address = address2
-
- assert info.address == address2
- assert info.addresses == [address2]
-
- info.address = None
-
- assert info.address is None
- assert info.addresses == []
-
- info.addresses = [address2]
-
- assert info.address == address2
- assert info.addresses == [address2]
-
- # Compatibility way
- info = ServiceInfo(type_, registration_name, [address, address], 80, 0, 0, desc, "ash-2.local.")
-
- assert info.addresses == [address, address]
-
- # New kwarg way
- info = ServiceInfo(
- type_, registration_name, None, 80, 0, 0, desc, "ash-2.local.", addresses=[address, address]
- )
-
- assert info.addresses == [address, address]
-
- if socket.has_ipv6:
- address_v6_parsed = "2001:db8::1"
- address_v6 = socket.inet_pton(socket.AF_INET6, address_v6_parsed)
- info = ServiceInfo(type_, registration_name, [address, address_v6], 80, 0, 0, desc, "ash-2.local.")
- assert info.addresses == [address]
- assert info.addresses_by_version(r.IPVersion.All) == [address, address_v6]
- assert info.addresses_by_version(r.IPVersion.V4Only) == [address]
- assert info.addresses_by_version(r.IPVersion.V6Only) == [address_v6]
- assert info.parsed_addresses() == [address_parsed, address_v6_parsed]
- assert info.parsed_addresses(r.IPVersion.V4Only) == [address_parsed]
- assert info.parsed_addresses(r.IPVersion.V6Only) == [address_v6_parsed]
-
-
-def test_ptr_optimization():
-
- # instantiate a zeroconf instance
- zc = Zeroconf(interfaces=['127.0.0.1'])
-
- # service definition
- type_ = "_test-srvc-type._tcp.local."
- name = "xxxyyy"
- registration_name = "%s.%s" % (name, type_)
-
- desc = {'path': '/~paulsm/'}
- info = ServiceInfo(type_, registration_name, socket.inet_aton("10.0.1.2"), 80, 0, 0, desc, "ash-2.local.")
-
- # we are going to monkey patch the zeroconf send to check packet sizes
- old_send = zc.send
-
- nbr_answers = nbr_additionals = nbr_authorities = 0
- has_srv = has_txt = has_a = False
-
- def send(out, addr=r._MDNS_ADDR, port=r._MDNS_PORT):
- """Sends an outgoing packet."""
- nonlocal nbr_answers, nbr_additionals, nbr_authorities
- nonlocal has_srv, has_txt, has_a
-
- nbr_answers += len(out.answers)
- nbr_authorities += len(out.authorities)
- for answer in out.additionals:
- nbr_additionals += 1
- if answer.type == r._TYPE_SRV:
- has_srv = True
- elif answer.type == r._TYPE_TXT:
- has_txt = True
- elif answer.type == r._TYPE_A:
- has_a = True
-
- old_send(out, addr=addr, port=port)
-
- # monkey patch the zeroconf send
- setattr(zc, "send", send)
-
- # register
- zc.register_service(info)
- nbr_answers = nbr_additionals = nbr_authorities = 0
-
- # query
- query = r.DNSOutgoing(r._FLAGS_QR_QUERY | r._FLAGS_AA)
- query.add_question(r.DNSQuestion(info.type, r._TYPE_PTR, r._CLASS_IN))
- zc.handle_query(r.DNSIncoming(query.packet()), r._MDNS_ADDR, r._MDNS_PORT)
- assert nbr_answers == 1 and nbr_additionals == 3 and nbr_authorities == 0
- assert has_srv and has_txt and has_a
-
- # unregister
- zc.unregister_service(info)